In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch 
import torchvision
import pandas as pd
import numpy as np
import functools


from pathlib import Path
from collections import defaultdict

import core.experiment
import nb_common

from sklearn.metrics.pairwise import cosine_similarity


from pytorch_utils.logging import LoggerReader
from pytorch_utils.evaluation import apply_model, argmax_and_accuracy
from nb_common import load_experiment_context, compute_latent

from pytorch_utils.ipynb import plt_img_grid

In [3]:
DEVICE = 'cuda:0'
data_root = Path('/scratch2/chofer/data/cifar10-c/')
res_root = Path('/home/pma/chofer/repositories/py_supcon_vs_ce/results_xmas_performance/')



args_white_list = {
    'num_batches',
    'batch_size',
    'tag', 
    'weight_decay',
    'ds_train',
    'ds_test' ,
    'augment',
    'label_noise_fraction',
    'scheduler',
}

args_simple = {
    'model_comp': lambda a: a['model'][1]['compactification_cfg'][0], 
    'model_lin': lambda a: a['model'][1]['linear_cfg'][0], 
    'loss': lambda a: a['losses'][0][0],
}

args_df_from_results = functools.partial(nb_common.args_df_from_results, args_white_list=args_white_list, args_simple=args_simple)

load_results = functools.partial(nb_common.load_results, root=res_root)

RESULTS = load_results()

In [4]:
class NpyImageDataset(torch.utils.data.Dataset):
    def __init__(self, image_file_pth, label_file_pth):
        self.X = torch.from_numpy(np.load(image_file_pth)).permute(0, 3, 1, 2).float()/256
        self.Y = torch.from_numpy(np.load(label_file_pth)).long()
        
    def __len__(self):
        return self.X.size(0)
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx].item()      

class CorruptedDatasetFactory():
    def __init__(self, root):
        self.root = root      
        
        self.dataset_names = [p.name.split('.')[0] for p in root.glob('*.npy') if p.name != 'labels.npy']
        
    def __iter__(self):
        return iter(self.dataset_names)
    
    def __call__(self, dataset_name):
        return self[dataset_name]
        
    def __getitem__(self, idx):
        label_file_pth = self.root/'labels.npy'
        data_file_pth = str(self.root/idx) + '.npy'
        
        return NpyImageDataset(data_file_pth, label_file_pth)

In [5]:
def compute_corruption_performance(result):

    run_i = 0
    exp_context = load_experiment_context(result.path, run_i=run_i)
    normalize = exp_context['ds_test'].transform.transforms[-1]
    assert isinstance(normalize, torchvision.transforms.Normalize)

    model = exp_context['model']
    linear = result.load_model(0, 'retrained_linear')
    model.cls = linear

    ds_factory = CorruptedDatasetFactory(data_root)

    ret = {}
    for k in ds_factory:
        print(k)
        ds = ds_factory[k]
        Y_hat, Y = apply_model(model, ds, device=DEVICE)
        acc = argmax_and_accuracy(Y_hat, Y)

        ret[k] = acc
        
    ret['score'] = np.mean(list(ret.values()))
    return ret

In [6]:
args_df = args_df_from_results(load_results())
tmp = args_df.query("ds_train == 'cifar10_train' and batch_size == 256 and scheduler == 'cosine'")
idxs = tmp.index.tolist()
idxs
args_df.iloc[idxs]

Unnamed: 0,model_comp,model_lin,loss,num_batches,tag,weight_decay,ds_train,ds_test,augment,batch_size,scheduler,progress
1,sphere_l2,Linear,SupConLoss,100000,performance,0.0001,cifar10_train,cifar10_test,none,256,cosine,True
3,sphere_l2,Linear,SupConLoss,100000,performance,0.0001,cifar10_train,cifar10_test,standard,256,cosine,True
10,none,Linear,CrossEntropy,100000,performance,0.0001,cifar10_train,cifar10_test,none,256,cosine,True
11,none,Linear,CrossEntropy,100000,performance,0.0001,cifar10_train,cifar10_test,standard,256,cosine,True
18,sphere_l2,Linear,CrossEntropy,100000,performance,0.0001,cifar10_train,cifar10_test,none,256,cosine,True
19,sphere_l2,Linear,CrossEntropy,100000,performance,0.0001,cifar10_train,cifar10_test,standard,256,cosine,True
52,none,FixedSphericalSimplexLinear,CrossEntropy,100000,performance_rerun_fixed_weights,0.0001,cifar10_train,cifar10_test,none,256,cosine,True
54,none,FixedSphericalSimplexLinear,CrossEntropy,100000,performance_rerun_fixed_weights,0.0001,cifar10_train,cifar10_test,standard,256,cosine,True


In [7]:
tmp = []
for i, idx in enumerate(idxs):
    print(i+1, '/', len(idxs))
    d = args_df.iloc[idx]
    d = d.to_dict()
    corruption_performance = compute_corruption_performance(RESULTS[idx])
    
    d.update(corruption_performance)
    d = pd.DataFrame(d, index=[idx])
    tmp.append(d)
    
result = pd.concat(tmp)

1 / 8
glass_blur
zoom_blur
pixelate
gaussian_blur
impulse_noise
spatter
frost
motion_blur
speckle_noise
snow
elastic_transform
brightness
fog
shot_noise
contrast
gaussian_noise
defocus_blur
jpeg_compression
saturate
2 / 8
glass_blur
zoom_blur
pixelate
gaussian_blur
impulse_noise
spatter
frost
motion_blur
speckle_noise
snow
elastic_transform
brightness
fog
shot_noise
contrast
gaussian_noise
defocus_blur
jpeg_compression
saturate
3 / 8
glass_blur
zoom_blur
pixelate
gaussian_blur
impulse_noise
spatter
frost
motion_blur
speckle_noise
snow
elastic_transform
brightness
fog
shot_noise
contrast
gaussian_noise
defocus_blur
jpeg_compression
saturate
4 / 8
glass_blur
zoom_blur
pixelate
gaussian_blur
impulse_noise
spatter
frost
motion_blur
speckle_noise
snow
elastic_transform
brightness
fog
shot_noise
contrast
gaussian_noise
defocus_blur
jpeg_compression
saturate
5 / 8
glass_blur
zoom_blur
pixelate
gaussian_blur
impulse_noise
spatter
frost
motion_blur
speckle_noise
snow
elastic_transform
brightnes

In [16]:
result[['loss', 'model_lin', 'model_comp', 'augment', 'score']].groupby(['loss', 'model_lin', 'model_comp', 'augment']).sum()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,score
loss,model_lin,model_comp,augment,Unnamed: 4_level_1
CrossEntropy,FixedSphericalSimplexLinear,none,none,32.351789
CrossEntropy,FixedSphericalSimplexLinear,none,standard,39.126526
CrossEntropy,Linear,none,none,29.231263
CrossEntropy,Linear,none,standard,30.463263
CrossEntropy,Linear,sphere_l2,none,36.072737
CrossEntropy,Linear,sphere_l2,standard,34.735579
SupConLoss,Linear,sphere_l2,none,35.359895
SupConLoss,Linear,sphere_l2,standard,33.799579
