The main idea is to download pre-trained Imagenet model weights from the timm repository and fit Assymetric Laplace Distribution to the final fully connected layer weights. 

Please refer to my post https://sidml.github.io/An-interpretation-of-the-final-fully-connected-layer/ for more details about the project.

In [1]:
!pip install -q timm

[0m

In [2]:
!wget https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv

--2022-05-21 09:18:11--  https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 40205 (39K) [text/plain]
Saving to: ‘results-imagenet.csv’


2022-05-21 09:18:11 (9.08 MB/s) - ‘results-imagenet.csv’ saved [40205/40205]



In [3]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from collections import OrderedDict
from tqdm.auto import tqdm
from glob import glob
import pdb, json, timm, os

## Load Imagenet class names

In [4]:
def setup_imagenet_classes(data_dir = "../input/imagenetval"):
        
    # Read the categories
    with open(f"{data_dir}/imagenet_classes.txt", "r") as f:
        categories = [s.strip().lower() for s in f.readlines()]
    
    # https://gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57
    cls_map = pd.read_csv(f"{data_dir}/map_clsloc.txt", sep=' ', header=None)
    cls_map.columns = ['imagenet_label', 'imagenet_clsnum', 'string_label']

    with open(f"{data_dir}/imagenet_class_index.json") as data_file:    
        data = json.load(data_file)              
            
    cls_stats = pd.DataFrame(data.values(), index=data.keys(),
                      columns=['imagenet_label', 'string_label'])
    cls_stats.index.name = 'pytorch_clsnum'
    cls_stats = cls_stats.reset_index()
    cls_stats = cls_stats.set_index('imagenet_label')

    cls_stats.loc[cls_map.imagenet_label, "imagenet_clsnum"] = np.arange(1, 1001)
    cls_stats['imagenet_clsnum'] = cls_stats['imagenet_clsnum'].astype(int)
    
    val_sol = pd.read_csv(f"{data_dir}/LOC_train_solution.csv")
    # val_sol = pd.read_csv("LOC_val_solution.csv")
    all_rows = val_sol.PredictionString.apply(lambda x: x.strip().split(' '))

    class_counts = val_sol.ImageId.apply(lambda x: x.split('_')[0]).value_counts()
    class_counts.name = "cls_count"
    class_counts.index.name = 'imagenet_label'
    cls_stats = cls_stats.merge(class_counts, on='imagenet_label')

    cls_stats['pytorch_clsnum'] = cls_stats['pytorch_clsnum'].astype(int)


    if cls_stats.index.name!='pytorch_clsnum':
        cls_stats = cls_stats.reset_index()
        cls_stats = cls_stats.set_index('pytorch_clsnum')

    return cls_stats

## Extract Final fully connected layer weights

In [5]:
@torch.no_grad()
def setup_weights(model_name):
    model = timm.create_model(model_name, pretrained=True)    
    state = model.state_dict()        
    # last two layers should be fc
    layer_name = list(state.keys())[-2:]
    try:
        fc_w = state[layer_name[0]].cpu().numpy()
        if len(fc_w.shape)!=2:
            print(layer_name, fc_w.shape)
            print('unable to extract fc layer')
            return None, None
        if len(layer_name)==2:
            fc_b = state[layer_name[1]].cpu().numpy()
        else:
            fc_b = None
    except Exception as e:
        print('couldnt process', model_name)
        print(e)
        return None, None
    return fc_w, fc_b

## Fit Asymmetric Laplace Distribution

In [6]:
def fit_laplace(y, fig_name=None):
    y = np.sort(y)
    x = np.arange(len(y))/len(y)
    target = np.log(y)
    coeff = np.polyfit(x, target, 1)
    predict_y = np.poly1d(coeff)(x)
    r2 = r2_score(target, predict_y)
    if fig_name:
        cls_name = fig_name.split('/')[-1]
        plt.plot(predict_y); plt.plot(target)    
        plt.legend(['Predicted', 'log(Actual Weights)'])
        plt.title(f"Class :{cls_name}, R2: {r2:.3f}"); plt.tight_layout();
        plt.savefig(f"{fig_name}.jpg"); plt.clf()
    return coeff[0], coeff[1], r2

def fit_joint_laplace(fc_w):
    pos_weights = fc_w[fc_w > 0]
    neg_weights = -fc_w[fc_w < 0]

    slope_pos, int_pos, r2_pos = fit_laplace(pos_weights)
    slope_neg, int_neg, r2_neg = fit_laplace(neg_weights)
    
    # estimate dataset mean assuming joint distribution of iid asymmetric laplace distribution
    k_est_pos = np.sqrt(np.exp(int_pos)/(slope_pos-np.exp(int_pos)))
    laplace_lam_est_pos = slope_pos * k_est_pos
    mean_pos = (1-k_est_pos**2)/(laplace_lam_est_pos*k_est_pos)


    k_est_neg = np.sqrt(np.exp(int_neg)/(slope_neg-np.exp(int_neg)))
    laplace_lam_est_neg = slope_neg * k_est_neg
    mean_neg = (1-k_est_neg**2)/(laplace_lam_est_neg*k_est_neg)

    print(f"joint fit: mean_pos {mean_pos:.2f}, r2_pos {r2_pos:.2f}")
    print(f"joint fit: mean_neg {mean_neg:.2f}, r2_neg {r2_neg:.2f}")

    return
    

def fit_exp(cls_stats_df, name, fc_w):
    
    # reset previous values if they exist and write new values
    cls_stats_df.loc[:, ['slope_pos', 'int_pos', 'r2_pos',\
                      'slope_neg', 'int_neg', 'r2_neg']] = None
    
    os.makedirs(f"./{name}/pos", exist_ok=True)
    os.makedirs(f"./{name}/neg", exist_ok=True)

    plt.figure()
    for pytorch_clsnum in range(fc_w.shape[0]):
        y = fc_w[pytorch_clsnum]
        y_pos, y_neg = y[y>=0], -1*y[y<0]
        fig_name = f"./{name}/pos/{cls_stats_df.loc[pytorch_clsnum, 'string_label']}"
        cls_stats_df.loc[pytorch_clsnum, ['slope_pos', 'int_pos', 'r2_pos']] = fit_laplace(y_pos, fig_name)
#         fig_name = f"./{name}/neg/{cls_stats_df.loc[pytorch_clsnum, 'string_label']}"
        cls_stats_df.loc[pytorch_clsnum, ['slope_neg', 'int_neg', 'r2_neg']] = fit_laplace(y_neg, fig_name=None)
    plt.close()
    
                         
    all_coeff_pos = cls_stats_df.loc[:, ['slope_pos', 'int_pos']].astype(float).values
    all_coeff_neg = cls_stats_df.loc[:, ['slope_neg', 'int_neg']].astype(float).values
                         
    # estimate mean for each class which is asymmetric laplace distribution
    # take mean of all estimated laplace means to obtain overall mean of the dataset
    k_est_pos = np.sqrt(np.exp(all_coeff_pos[:, 1])/(all_coeff_pos[:, 0]-np.exp(all_coeff_pos[:, 1])))
    laplace_lam_est_pos = all_coeff_pos[:, 0] * k_est_pos
    mean_pos = (1-k_est_pos**2)/(laplace_lam_est_pos*k_est_pos)


    k_est_neg = np.sqrt(np.exp(all_coeff_neg[:, 1])/(all_coeff_neg[:, 0]-np.exp(all_coeff_neg[:, 1])))
    laplace_lam_est_neg = all_coeff_neg[:, 0] * k_est_neg
    mean_neg = (1-k_est_neg**2)/(laplace_lam_est_neg*k_est_neg)

    marginal_mean_pos, marginal_mean_neg = mean_pos.mean(), mean_neg.mean()
    print(f"marginal_mean_pos:{marginal_mean_pos:.2f}, marginal_mean_neg:{marginal_mean_neg:.2f}")

    # estimate joint laplace
    fit_joint_laplace(fc_w)
    
    plt.figure()
    coeff = np.polyfit(all_coeff_pos[:, 0], all_coeff_pos[:, 1], 1)
    plt.scatter(all_coeff_pos[:, 0], all_coeff_pos[:, 1])
    plt.title(f"slope:{coeff[0]:.2f}, intercept:{coeff[1]:.2f}")
    plt.savefig(f"./{name}/pos_coeff_scatter.jpg")
    plt.close()
                         
    return cls_stats_df

In [7]:
timm_results = pd.read_csv('results-imagenet.csv')
model_names = timm_results.model.values
weight_dir = "/root/.cache/torch/hub/checkpoints/"

## Kaggle kernels have a disk usage limit, so we can't process all the available models on timm. We will randomly select 75 models for our analysis. Please feel free to run & share the results for all the models if you have sufficient compute.
Most of the time is taken while generating the plots. We are saving ALD plots for all the 1000 imagenet classes for every model architecture.

In [8]:
cls_stats_df = setup_imagenet_classes()
os.makedirs('./exp_fit_data/', exist_ok=True)
model_names = np.random.choice(model_names, 75)
for model_name in model_names:
    print("\nprocessing:", model_name)
    try:
        fc_w, fc_b = setup_weights(model_name)
        if fc_w is None:
            print('unable to extract weights')
            continue
        cls_stats_df = fit_exp(cls_stats_df, model_name, fc_w)
        cls_stats_df.to_csv(f"./exp_fit_data/{model_name}.csv")
    except Exception as e:
        print(e)
        continue    
    # clear the weights otherwise kernel will run out of memory
    for f in glob(f"{weight_dir}/{model_name}*.pth"):
        os.remove(f)   


processing: gluon_resnet101_v1s


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth" to /root/.cache/torch/hub/checkpoints/gluon_resnet101_v1s-60fe0cc1.pth
  after removing the cwd from sys.path.


Input contains NaN, infinity or a value too large for dtype('float32').

processing: dm_nfnet_f3


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth" to /root/.cache/torch/hub/checkpoints/dm_nfnet_f3-d74ab3aa.pth


marginal_mean_pos:522.61, marginal_mean_neg:392.78
joint fit: mean_pos 519.62, r2_pos 0.87
joint fit: mean_neg 393.91, r2_neg 0.81

processing: gluon_resnext101_32x4d


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth" to /root/.cache/torch/hub/checkpoints/gluon_resnext101_32x4d-b253c8c4.pth
  after removing the cwd from sys.path.


Input contains NaN, infinity or a value too large for dtype('float32').

processing: resnet26t


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth" to /root/.cache/torch/hub/checkpoints/resnet26t_256_ra2-6f6fa748.pth


marginal_mean_pos:243.92, marginal_mean_neg:180.13
joint fit: mean_pos 243.43, r2_pos 0.87
joint fit: mean_neg 180.37, r2_neg 0.82

processing: rexnet_100


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth" to /root/.cache/torch/hub/checkpoints/rexnetv1_100-1b4dddf4.pth


marginal_mean_pos:164.87, marginal_mean_neg:144.85
joint fit: mean_pos 164.00, r2_pos 0.84
joint fit: mean_neg 145.87, r2_neg 0.81

processing: efficientnet_el


Downloading: "https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_el.pth


marginal_mean_pos:232.81, marginal_mean_neg:196.72
joint fit: mean_pos 231.75, r2_pos 0.85
joint fit: mean_neg 197.89, r2_neg 0.82

processing: tf_efficientnet_b3_ap


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b3_ap-aad25bdd.pth


marginal_mean_pos:167.89, marginal_mean_neg:152.55
joint fit: mean_pos 167.19, r2_pos 0.82
joint fit: mean_neg 153.29, r2_neg 0.80

processing: gluon_senet154


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth" to /root/.cache/torch/hub/checkpoints/gluon_senet154-70a1a3c0.pth


marginal_mean_pos:880.54, marginal_mean_neg:508.44
joint fit: mean_pos 842.25, r2_pos 0.88
joint fit: mean_neg 519.94, r2_neg 0.81

processing: fbnetc_100


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth" to /root/.cache/torch/hub/checkpoints/fbnetc_100-c345b898.pth


marginal_mean_pos:225.42, marginal_mean_neg:191.09
joint fit: mean_pos 224.55, r2_pos 0.84
joint fit: mean_neg 192.55, r2_neg 0.82

processing: ese_vovnet19b_dw


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth" to /root/.cache/torch/hub/checkpoints/ese_vovnet19b_dw-a8741004.pth


marginal_mean_pos:205.57, marginal_mean_neg:153.11
joint fit: mean_pos 204.49, r2_pos 0.87
joint fit: mean_neg 153.61, r2_neg 0.81

processing: gluon_resnet50_v1b


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth" to /root/.cache/torch/hub/checkpoints/gluon_resnet50_v1b-0ebe02e2.pth
  after removing the cwd from sys.path.


Input contains NaN, infinity or a value too large for dtype('float32').

processing: dpn92


Downloading: "https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth" to /root/.cache/torch/hub/checkpoints/dpn92_extra-b040e4a9b.pth


['classifier.weight', 'classifier.bias'] (1000, 2688, 1, 1)
unable to extract fc layer
unable to extract weights

processing: vit_small_patch32_224
marginal_mean_pos:370.34, marginal_mean_neg:375.24
joint fit: mean_pos 371.35, r2_pos 0.83
joint fit: mean_neg 376.69, r2_neg 0.83

processing: tf_efficientnet_b8


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b8_ra-572d5dd9.pth


marginal_mean_pos:213.05, marginal_mean_neg:206.44
joint fit: mean_pos 212.80, r2_pos 0.79
joint fit: mean_neg 206.95, r2_neg 0.78

processing: resnet32ts


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth" to /root/.cache/torch/hub/checkpoints/resnet32ts_256-aacf5250.pth


marginal_mean_pos:189.86, marginal_mean_neg:150.13
joint fit: mean_pos 189.64, r2_pos 0.87
joint fit: mean_neg 150.60, r2_neg 0.82

processing: resnet33ts


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth" to /root/.cache/torch/hub/checkpoints/resnet33ts_256-e91b09a4.pth


marginal_mean_pos:155.69, marginal_mean_neg:126.92
joint fit: mean_pos 155.55, r2_pos 0.86
joint fit: mean_neg 127.52, r2_neg 0.82

processing: gluon_resnet34_v1b


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth" to /root/.cache/torch/hub/checkpoints/gluon_resnet34_v1b-c6d82d59.pth


marginal_mean_pos:318.15, marginal_mean_neg:194.65
joint fit: mean_pos 316.14, r2_pos 0.89
joint fit: mean_neg 194.93, r2_neg 0.80

processing: ig_resnext101_32x16d


Downloading: "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth" to /root/.cache/torch/hub/checkpoints/ig_resnext101_32x16-c6f796b0.pth


marginal_mean_pos:319.86, marginal_mean_neg:313.28
joint fit: mean_pos 319.82, r2_pos 0.76
joint fit: mean_neg 313.23, r2_neg 0.75

processing: beit_base_patch16_224


Downloading: "https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth" to /root/.cache/torch/hub/checkpoints/beit_base_patch16_224_pt22k_ft22kto1k.pth


invalid load key, '='.

processing: levit_384


Downloading: "https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth" to /root/.cache/torch/hub/checkpoints/LeViT-384-9bdaf2e2.pth


marginal_mean_pos:73.17, marginal_mean_neg:73.12
joint fit: mean_pos 73.18, r2_pos 0.81
joint fit: mean_neg 73.12, r2_neg 0.81

processing: efficientnet_b4


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_ra2_320-7eb33cd5.pth


marginal_mean_pos:73.80, marginal_mean_neg:60.56
joint fit: mean_pos 73.72, r2_pos 0.85
joint fit: mean_neg 60.77, r2_neg 0.81

processing: tf_efficientnet_lite1


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite1-bde8b488.pth


marginal_mean_pos:190.99, marginal_mean_neg:163.46
joint fit: mean_pos 190.57, r2_pos 0.84
joint fit: mean_neg 164.14, r2_neg 0.81

processing: swin_tiny_patch4_window7_224


Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth


marginal_mean_pos:194.80, marginal_mean_neg:194.35
joint fit: mean_pos 194.99, r2_pos 0.81
joint fit: mean_neg 194.62, r2_neg 0.82

processing: crossvit_tiny_240


Downloading: "https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth" to /root/.cache/torch/hub/checkpoints/crossvit_tiny_224.pth


marginal_mean_pos:143.86, marginal_mean_neg:145.34
joint fit: mean_pos 153.85, r2_pos 0.85
joint fit: mean_neg 153.82, r2_neg 0.85

processing: ecaresnet50d


Downloading: "https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth" to /root/.cache/torch/hub/checkpoints/ECAResNet50D_833caf58.pth


marginal_mean_pos:1410.23, marginal_mean_neg:1031.97
joint fit: mean_pos 1408.94, r2_pos 0.88
joint fit: mean_neg 1041.34, r2_neg 0.86

processing: rexnet_200


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth" to /root/.cache/torch/hub/checkpoints/rexnetv1_200-8c0b7f2d.pth


marginal_mean_pos:272.47, marginal_mean_neg:214.88
joint fit: mean_pos 270.92, r2_pos 0.86
joint fit: mean_neg 216.43, r2_neg 0.82

processing: regnetx_320


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth" to /root/.cache/torch/hub/checkpoints/regnetx_320-8ea38b93.pth


marginal_mean_pos:534.14, marginal_mean_neg:329.62
joint fit: mean_pos 528.10, r2_pos 0.89
joint fit: mean_neg 332.69, r2_neg 0.81

processing: crossvit_15_dagger_408


Downloading: "https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth" to /root/.cache/torch/hub/checkpoints/crossvit_15_dagger_384.pth


marginal_mean_pos:127.10, marginal_mean_neg:125.34
joint fit: mean_pos 131.63, r2_pos 0.84
joint fit: mean_neg 131.49, r2_neg 0.84

processing: resnext26ts


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth" to /root/.cache/torch/hub/checkpoints/resnext26ts_256_ra2-8bbd9106.pth


marginal_mean_pos:200.97, marginal_mean_neg:153.95
joint fit: mean_pos 200.72, r2_pos 0.87
joint fit: mean_neg 154.26, r2_neg 0.82

processing: tf_mixnet_m


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth" to /root/.cache/torch/hub/checkpoints/tf_mixnet_m-0f4d8805.pth


marginal_mean_pos:212.40, marginal_mean_neg:173.89
joint fit: mean_pos 210.50, r2_pos 0.84
joint fit: mean_neg 175.41, r2_neg 0.80

processing: tf_efficientnet_b7_ns


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b7_ns-1dbc32de.pth


marginal_mean_pos:120.26, marginal_mean_neg:126.52
joint fit: mean_pos 120.05, r2_pos 0.70
joint fit: mean_neg 126.90, r2_neg 0.72

processing: levit_128


Downloading: "https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth" to /root/.cache/torch/hub/checkpoints/LeViT-128-b88c2750.pth


marginal_mean_pos:64.13, marginal_mean_neg:63.97
joint fit: mean_pos 64.09, r2_pos 0.81
joint fit: mean_neg 63.97, r2_neg 0.81

processing: tv_resnet50


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


marginal_mean_pos:594.61, marginal_mean_neg:366.87
joint fit: mean_pos 590.94, r2_pos 0.89
joint fit: mean_neg 367.54, r2_neg 0.81

processing: dla60_res2net


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth" to /root/.cache/torch/hub/checkpoints/res2net_dla60_4s-d88db7f9.pth


['fc.weight', 'fc.bias'] (1000, 1024, 1, 1)
unable to extract fc layer
unable to extract weights

processing: vit_tiny_patch16_384
marginal_mean_pos:309.71, marginal_mean_neg:311.31
joint fit: mean_pos 309.49, r2_pos 0.83
joint fit: mean_neg 312.03, r2_neg 0.83

processing: nfnet_l0


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth" to /root/.cache/torch/hub/checkpoints/nfnet_l0_ra2-45c6688d.pth


marginal_mean_pos:229.24, marginal_mean_neg:178.50
joint fit: mean_pos 228.01, r2_pos 0.86
joint fit: mean_neg 179.47, r2_neg 0.80

processing: tf_efficientnetv2_m_in21ft1k


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_m_21ft1k-bf41664a.pth


marginal_mean_pos:490.70, marginal_mean_neg:374.87
joint fit: mean_pos 492.91, r2_pos 0.86
joint fit: mean_neg 376.84, r2_neg 0.82

processing: gcresnext26ts


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth" to /root/.cache/torch/hub/checkpoints/gcresnext26ts_256-e414378b.pth


marginal_mean_pos:217.26, marginal_mean_neg:177.71
joint fit: mean_pos 216.92, r2_pos 0.85
joint fit: mean_neg 178.36, r2_neg 0.81

processing: legacy_seresnet152


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth" to /root/.cache/torch/hub/checkpoints/se_resnet152-d17c99b7.pth


marginal_mean_pos:1061.08, marginal_mean_neg:768.71
joint fit: mean_pos 1059.76, r2_pos 0.89
joint fit: mean_neg 773.45, r2_neg 0.86

processing: dpn107


Downloading: "https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth" to /root/.cache/torch/hub/checkpoints/dpn107_extra-1ac7121e2.pth


['classifier.weight', 'classifier.bias'] (1000, 2688, 1, 1)
unable to extract fc layer
unable to extract weights

processing: dla46x_c


Downloading: "http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth" to /root/.cache/torch/hub/checkpoints/dla46x_c-d761bae7.pth


['fc.weight', 'fc.bias'] (1000, 256, 1, 1)
unable to extract fc layer
unable to extract weights

processing: tf_efficientnet_lite0


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth


marginal_mean_pos:185.57, marginal_mean_neg:158.30
joint fit: mean_pos 185.24, r2_pos 0.84
joint fit: mean_neg 158.84, r2_neg 0.81

processing: gernet_s


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth" to /root/.cache/torch/hub/checkpoints/gernet_s-756b4751.pth


marginal_mean_pos:314.21, marginal_mean_neg:260.61
joint fit: mean_pos 315.98, r2_pos 0.86
joint fit: mean_neg 260.85, r2_neg 0.84

processing: jx_nest_base


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth" to /root/.cache/torch/hub/checkpoints/jx_nest_base-8bc41011.pth


marginal_mean_pos:120.69, marginal_mean_neg:122.29
joint fit: mean_pos 120.79, r2_pos 0.81
joint fit: mean_neg 122.53, r2_neg 0.81

processing: hardcorenas_d


Downloading: "https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth" to /root/.cache/torch/hub/checkpoints/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth


marginal_mean_pos:140.34, marginal_mean_neg:134.81
joint fit: mean_pos 140.62, r2_pos 0.83
joint fit: mean_neg 135.11, r2_neg 0.82

processing: gluon_seresnext101_64x4d


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth" to /root/.cache/torch/hub/checkpoints/gluon_seresnext101_64x4d-f9926f93.pth


marginal_mean_pos:1670.11, marginal_mean_neg:1150.41
joint fit: mean_pos 1668.65, r2_pos 0.89
joint fit: mean_neg 1169.87, r2_neg 0.86

processing: gluon_senet154


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth" to /root/.cache/torch/hub/checkpoints/gluon_senet154-70a1a3c0.pth


marginal_mean_pos:880.54, marginal_mean_neg:508.44
joint fit: mean_pos 842.25, r2_pos 0.88
joint fit: mean_neg 519.94, r2_neg 0.81

processing: mobilenetv3_large_100


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth" to /root/.cache/torch/hub/checkpoints/mobilenetv3_large_100_ra-f55367f5.pth


marginal_mean_pos:162.29, marginal_mean_neg:156.33
joint fit: mean_pos 162.38, r2_pos 0.82
joint fit: mean_neg 156.96, r2_neg 0.82

processing: xcit_tiny_12_p16_224


Downloading: "https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth" to /root/.cache/torch/hub/checkpoints/xcit_tiny_12_p16_224.pth


marginal_mean_pos:217.00, marginal_mean_neg:216.53
joint fit: mean_pos 216.92, r2_pos 0.81
joint fit: mean_neg 216.88, r2_neg 0.81

processing: gluon_resnet50_v1s


Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth" to /root/.cache/torch/hub/checkpoints/gluon_resnet50_v1s-1762acc0.pth


marginal_mean_pos:879.64, marginal_mean_neg:463.20
joint fit: mean_pos 871.76, r2_pos 0.90
joint fit: mean_neg 464.59, r2_neg 0.80

processing: convnext_tiny_hnf
marginal_mean_pos:541.60, marginal_mean_neg:537.27
joint fit: mean_pos 541.45, r2_pos 0.83
joint fit: mean_neg 536.94, r2_neg 0.83

processing: resnet26


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth" to /root/.cache/torch/hub/checkpoints/resnet26-9aa10e23.pth


marginal_mean_pos:743.78, marginal_mean_neg:601.14
joint fit: mean_pos 748.15, r2_pos 0.88
joint fit: mean_neg 601.56, r2_neg 0.85

processing: tf_efficientnet_b7_ns


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b7_ns-1dbc32de.pth


marginal_mean_pos:120.26, marginal_mean_neg:126.52
joint fit: mean_pos 120.05, r2_pos 0.70
joint fit: mean_neg 126.90, r2_neg 0.72

processing: tf_efficientnetv2_b1


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_b1-be6e41b0.pth


marginal_mean_pos:184.74, marginal_mean_neg:156.53
joint fit: mean_pos 183.57, r2_pos 0.84
joint fit: mean_neg 157.51, r2_neg 0.81

processing: tf_efficientnet_b6


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b6_aa-80ba17e4.pth


marginal_mean_pos:191.33, marginal_mean_neg:171.96
joint fit: mean_pos 190.75, r2_pos 0.80
joint fit: mean_neg 172.81, r2_neg 0.78

processing: vit_small_patch32_224
marginal_mean_pos:370.34, marginal_mean_neg:375.24
joint fit: mean_pos 371.35, r2_pos 0.83
joint fit: mean_neg 376.69, r2_neg 0.83

processing: spnasnet_100


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth" to /root/.cache/torch/hub/checkpoints/spnasnet_100-048bc3f4.pth


marginal_mean_pos:315.89, marginal_mean_neg:265.84
joint fit: mean_pos 317.22, r2_pos 0.86
joint fit: mean_neg 267.11, r2_neg 0.83

processing: volo_d5_224
Unknown model (volo_d5_224)

processing: seresnext50_32x4d


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth" to /root/.cache/torch/hub/checkpoints/seresnext50_32x4d_racm-a304a460.pth


marginal_mean_pos:338.28, marginal_mean_neg:224.15
joint fit: mean_pos 333.70, r2_pos 0.89
joint fit: mean_neg 225.16, r2_neg 0.83

processing: convnext_large_in22ft1k


Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth" to /root/.cache/torch/hub/checkpoints/convnext_large_22k_1k_224.pth


marginal_mean_pos:437.50, marginal_mean_neg:442.05
joint fit: mean_pos 438.85, r2_pos 0.80
joint fit: mean_neg 443.51, r2_neg 0.80

processing: densenet169


Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /root/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth


marginal_mean_pos:2564693.06, marginal_mean_neg:1361543.22
joint fit: mean_pos 2534249.59, r2_pos 0.97
joint fit: mean_neg 1367003.07, r2_neg 0.95

processing: tv_resnet34


Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


marginal_mean_pos:272.32, marginal_mean_neg:173.07
joint fit: mean_pos 271.41, r2_pos 0.89
joint fit: mean_neg 172.97, r2_neg 0.80

processing: eca_nfnet_l2


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth" to /root/.cache/torch/hub/checkpoints/ecanfnet_l2_ra3-da781a61.pth


marginal_mean_pos:214.29, marginal_mean_neg:184.44
joint fit: mean_pos 213.88, r2_pos 0.83
joint fit: mean_neg 185.20, r2_neg 0.79

processing: haloregnetz_b


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth" to /root/.cache/torch/hub/checkpoints/haloregnetz_c_raa_256-c8ad7616.pth


marginal_mean_pos:370.26, marginal_mean_neg:288.11
joint fit: mean_pos 370.83, r2_pos 0.86
joint fit: mean_neg 287.37, r2_neg 0.85

processing: eca_halonext26ts


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth" to /root/.cache/torch/hub/checkpoints/eca_halonext26ts_c_256-06906299.pth


marginal_mean_pos:322.26, marginal_mean_neg:266.88
joint fit: mean_pos 321.42, r2_pos 0.86
joint fit: mean_neg 267.44, r2_neg 0.83

processing: xcit_nano_12_p16_224


Downloading: "https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth" to /root/.cache/torch/hub/checkpoints/xcit_nano_12_p16_224.pth


marginal_mean_pos:212.22, marginal_mean_neg:211.51
joint fit: mean_pos 212.60, r2_pos 0.81
joint fit: mean_neg 211.48, r2_neg 0.81

processing: volo_d2_224
Unknown model (volo_d2_224)

processing: resmlp_12_distilled_224


Downloading: "https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth" to /root/.cache/torch/hub/checkpoints/resmlp_12_dist.pth


marginal_mean_pos:203.81, marginal_mean_neg:206.77
joint fit: mean_pos 203.84, r2_pos 0.82
joint fit: mean_neg 206.75, r2_neg 0.82

processing: hrnet_w18


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth" to /root/.cache/torch/hub/checkpoints/hrnetv2_w18-8cb57bb9.pth


marginal_mean_pos:526.82, marginal_mean_neg:314.83
joint fit: mean_pos 524.45, r2_pos 0.89
joint fit: mean_neg 315.48, r2_neg 0.80

processing: resnetblur50


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth" to /root/.cache/torch/hub/checkpoints/resnetblur50-84f4748f.pth


marginal_mean_pos:994.04, marginal_mean_neg:789.60
joint fit: mean_pos 997.74, r2_pos 0.88
joint fit: mean_neg 795.49, r2_neg 0.85

processing: resnet50


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /root/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


marginal_mean_pos:94.12, marginal_mean_neg:39.65
joint fit: mean_pos 93.93, r2_pos 0.87
joint fit: mean_neg 39.67, r2_neg 0.82

processing: dla60x_c


Downloading: "http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth" to /root/.cache/torch/hub/checkpoints/dla60x_c-b870c45c.pth


['fc.weight', 'fc.bias'] (1000, 256, 1, 1)
unable to extract fc layer
unable to extract weights

processing: resnet51q


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth" to /root/.cache/torch/hub/checkpoints/resnet51q_ra2-d47dcc76.pth


marginal_mean_pos:208.97, marginal_mean_neg:165.74
joint fit: mean_pos 208.11, r2_pos 0.86
joint fit: mean_neg 166.71, r2_neg 0.81

processing: coat_lite_small


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth" to /root/.cache/torch/hub/checkpoints/coat_lite_small-fea1d5a1.pth


marginal_mean_pos:167.48, marginal_mean_neg:166.65
joint fit: mean_pos 167.70, r2_pos 0.81
joint fit: mean_neg 166.87, r2_neg 0.81

processing: mixnet_m


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth" to /root/.cache/torch/hub/checkpoints/mixnet_m-4647fc68.pth


marginal_mean_pos:217.96, marginal_mean_neg:179.54
joint fit: mean_pos 216.34, r2_pos 0.84
joint fit: mean_neg 180.98, r2_neg 0.81


<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

### Zip the generated results and clean directory

In [9]:
!zip -rq out.zip ./

In [10]:
import shutil
all_dir = glob("./*/")
for dir in all_dir:
    shutil.rmtree(dir)