In [18]:
from pathlib import Path
from pruneshift.networks import create_network
from pruneshift.datamodules import ShiftDataModule
from collections import defaultdict
from functools import partial
from itertools import chain
from tqdm.notebook import tqdm
from pruneshift.losses import ActivationCollector
from pruneshift.network_markers import classifier
import pytorch_lightning as pl
import submitit
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import ConcatDataset, DataLoader


model_path = Path("/work/dlclarge2/hoffmaja-pruneshift/models/")
dataset_path = Path("/work/dlclarge2/hoffmaja-pruneshift/datasets/")
activation_path = Path("/work/dlclarge2/hoffmaja-pruneshift/activations/")
img100_path = dataset_path / "ILSVRC2012-100"
img100_train_path = img100_path / "train"
deepaugment_path = "/data/datasets/DeepAugment"
# data_amda_path = model_path / "/imagenetr_models/deepaugment_and_augmix.pth.tar"
data_swsl_path = activation_path / "img100_swsl_resnet50.npy"
data_amda_path = activation_path / "img100_amda_resnet50.npy"
data_std_path = activation_path / "img100_std_resnet50.npy"

amda_path = model_path / "imagenetr_models/augmix.pth.tar"

In [19]:
INTRESTING_NETS = {
    "ResNet18": "/work/dlclarge2/hoffmaja-pruneshift/experiments/img100/workshop/baselines/1times_amda/0/checkpoint/last.ckpt",
    "KD_ResNet18": "/work/dlclarge2/agnihotr-shashank-pruneshift/hoffmaja-pruneshift/experiments/img100/workshop/amda_distillation/0/checkpoint/last.ckpt",
    "AT_ResNet18": "/work/dlclarge2/hoffmaja-pruneshift/experiments/img100/workshop/tryout_at/1/checkpoint/last.ckpt",
    "CRD_ResNet18": "/work/dlclarge2/agnihotr-shashank-pruneshift/hoffmaja-pruneshift/shashank_runs/imagenet100_runs/amda_teacher_amda_student_crd/crd_1.0/kd_1.0/checkpoint/last.ckpt",
    "SupCon_kd_augmix": "/work/dlclarge2/agnihotr-shashank-pruneshift/runs_supcon/imagenet100_again/kd_augmix/classification_amda/checkpoint/last.ckpt",
    "SupCon_augmix" : "/work/dlclarge2/agnihotr-shashank-pruneshift/runs_supcon/imagenet100_again/augmix/classification_amda/checkpoint/last.ckpt"}

#"Amda_ResNet50": "/work/dlclarge2/hoffmaja-pruneshift/models/imagenetr_models/deepaugment_and_augmix.pth.tar",
def create_executor(folder, num_gpus):
    executor = submitit.AutoExecutor(folder=folder)
    executor.update_parameters(nodes=1,
                               gpus_per_node=num_gpus,
                               slurm_partition="lmbdlc_gpu-rtx2080",
                               slurm_array_parallelism=35,
                               timeout_min=100)
    return executor

In [20]:
MODES = ["amda", "test", "renditions", "corrupted"]
TEST_MODES = MODES[-3:]

def dataloaders(mode: str, batch_size: int):
    assert mode in MODES
    
    kws = {}
    kws["test_renditions"] = True if mode == "renditions" else False
    kws["test_corrupted"] = True if mode == "corrupted" else False
    kws["augmix"] = "no_jsd" if mode == "augmix" or mode == "amda" else False
    kws["deepaugment_path"] = deepaugment_path if mode == "deepaugment" or mode == "amda" else None
        
    dm = ShiftDataModule("imagenet", img100_path, batch_size=batch_size, **kws)
    
    if mode in TEST_MODES:
        dm.setup("test")
        
        if mode == "test":
            dataset = dm.test_datasets[0]
        elif mode == "renditions":
            dataset = dm.test_datasets[1]
        else:
            dataset = ConcatDataset(dm.test_datasets[1 :])
            
        return DataLoader(dataset, batch_size, True, num_workers=2)
    
    dm.setup("fit")
    return dm.train_dataloader()
        


def collect_samples(network_name,
                    network_path,
                    mode: str,
                    with_logits: bool = True,
                    with_features: bool = True,
                    num_samples: int = 256,
                    batch_size: int = 256):
    """ Makes some forward passes on ImageNet100""" #, imagenet_subset_path=img100_path, imagnet_path=img100_train_path
    if '50' in network_name:
        net = create_network("imagenet", "resnet50", 100, ckpt_path=network_path)
    elif 'SupCon' in network_name:
        net = create_network("imagenet", "resnet18", 100, ckpt_path=network_path, supConLoss=True, classifying=True, loading_final_supcon=True)
    else:
        net = create_network("imagenet", "resnet18", 100, ckpt_path=network_path)
    device = torch.device("cuda:0")

    print("Create datamodule.")

    loader = dataloaders(mode, batch_size)

    num_samples = num_samples // batch_size
    collector = ActivationCollector({"fc": classifier(net)}, mode="in")

    net = net.to(device)
    # Meta Information
    info = defaultdict(list)
    # Logits and Activations
    # data = defaultdict(list)

    for _, batch in tqdm(zip(range(num_samples), loader)):
        _, x, y = batch
        x = x.to(device)

        with torch.no_grad():
            logits_batch = net(x)

        for label, activations, logits in zip(y, collector["fc"], logits_batch):
            # if train
            
            info["Class Idx"].append(label.item())
            info["Network Name"].append(network_name)
            info["Distribution"].append(mode)
            info["Logits"].append(logits.cpu().numpy())
            info["Features"].append(activations.cpu().numpy())

        collector.reset()

    df_info = pd.DataFrame(info)
    # df_info.columns = pd.MultiIndex.from_product([["Info"], df_info.columns])
    return df_info
    
    num_logits = 100
    num_features = data["Features"][0].shape[-1]

    df_logits = pd.DataFrame(np.stack(data["Logits"]))
    df_logits.columns = pd.MultiIndex.from_product([["Logits"], range(num_logits)])
    df_features = pd.DataFrame(np.stack(data["Features"]))
    df_features.columns = pd.MultiIndex.from_product([["Features"], range(num_features)])

    df_parts = [df_info]
    
    if with_features:
        df_parts.append(df_logits)
    
    if with_logits:
        df_parts.append(df_features)
    
    return pd.concat([df_info, df_logits, df_features], axis=1)


In [21]:
# We want to sweep over augmix and deepaugment, train, test
def sweep(**kws):
    executor = create_executor("./submitit_runs", 1)
    names = list(INTRESTING_NETS.keys())
    paths = [INTRESTING_NETS[n] for n in INTRESTING_NETS]

    jobs = []
    with executor.batch():
        for mode in MODES:
            for name, path in INTRESTING_NETS.items():
                job = executor.submit(collect_samples, name, path, mode, **kws)
                jobs.append(job)

    return pd.concat([j.result() for j in jobs])

In [23]:
df = sweep(num_samples=256*20)

In [24]:
df["Logits"].iloc[0]

array([-1.2522832 ,  8.020609  ,  3.7160256 ,  1.5858282 , -0.6066107 ,
       -1.8614795 , -2.0555408 ,  1.303268  ,  4.309605  ,  1.9008783 ,
       -2.6020973 ,  0.11717378, -2.502744  ,  0.05916312, -2.641747  ,
       -1.3596619 , -1.8055562 , -0.8236914 , -2.0011315 ,  0.22227022,
       -3.303973  , -0.3437702 , -2.0403488 ,  1.5712047 ,  4.6484623 ,
        3.1017349 , -2.068096  , -0.8709048 ,  0.46873292, -2.2186825 ,
       -1.4154774 ,  0.14736672,  3.058052  ,  5.0133076 ,  4.577844  ,
       12.068056  , 14.700689  ,  5.826222  , -0.26689857, -0.11844517,
       -0.50574195, -0.71693957, -0.38355532, -1.1683133 , -0.5376086 ,
       -0.77848536, -0.2509188 , -1.7561966 ,  0.976316  , -0.07694538,
        1.6302834 , -0.4737046 ,  0.66315603, -0.24893896,  0.19256271,
        1.659791  , -1.0942785 , -0.4849017 , -0.27781862,  0.49353683,
        0.33059588,  1.029828  , -0.58503395, -1.1760771 , -0.69789886,
       -0.60041296, -0.17096189, -0.3012425 ,  0.92729354, -0.49

In [25]:
df.to_pickle("supcon_analysis.pkl")

In [61]:
df.to_csv("std.csv")

In [59]:
df.to_feather("short_features_logits.bird")

ValueError: feather does not support serializing a non-default index for the index; you can .reset_index() to make the index into column(s)

In [51]:
job = executor.submit(collect_samples, amda_path, num_samples=1000)

In [53]:
df = job.result()

In [54]:
df["Info"]

Unnamed: 0,Class Idx,DeepAugment,Augmix,Train
0,33,False,False,True
1,60,False,False,True
2,42,False,False,True
3,42,False,False,True
4,74,False,False,True
...,...,...,...,...
763,40,False,False,True
764,98,False,False,True
765,93,False,False,True
766,81,False,False,True


In [None]:
executor = create_executor("./test_runs", 1)
jobs = executor.map_array(collect_samples, paths)