In [67]:
from pathlib import Path
from pruneshift.networks import create_network
from pruneshift.datamodules import ShiftDataModule
from collections import defaultdict
from functools import partial
from itertools import chain
from tqdm.notebook import tqdm
from pruneshift.losses import ActivationCollector
from pruneshift.network_markers import classifier
import pytorch_lightning as pl
import submitit
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import ConcatDataset, DataLoader


model_path = Path("/work/dlclarge2/hoffmaja-pruneshift/models/")
dataset_path = Path("/work/dlclarge2/hoffmaja-pruneshift/datasets/")
activation_path = Path("/work/dlclarge2/hoffmaja-pruneshift/activations/")
img100_path = dataset_path / "ILSVRC2012-100"
img100_train_path = img100_path / "train"
deepaugment_path = "/data/datasets/DeepAugment"
# data_amda_path = model_path / "/imagenetr_models/deepaugment_and_augmix.pth.tar"
data_swsl_path = activation_path / "img100_swsl_resnet50.npy"
data_amda_path = activation_path / "img100_amda_resnet50.npy"
data_std_path = activation_path / "img100_std_resnet50.npy"

amda_path = model_path / "imagenetr_models/augmix.pth.tar"

In [68]:
INTRESTING_NETS = {
    "Augmix": "/work/dlclarge2/hoffmaja-pruneshift/models/imagenetr_models/augmix.pth.tar",
    "DeepAugment": "/work/dlclarge2/hoffmaja-pruneshift/models/imagenetr_models/deepaugment.pth.tar",
    "Amda": "/work/dlclarge2/hoffmaja-pruneshift/models/imagenetr_models/deepaugment_and_augmix.pth.tar",
    "Swsl": "/work/dlclarge2/hoffmaja-pruneshift/models/torch_home/hub/checkpoints/semi_weakly_supervised_resnet50-16a12f1b.pth",
    "Standard": "/work/dlclarge2/hoffmaja-pruneshift/models/torch_home/hub/checkpoints/resnet50-19c8e357.pth"}


def create_executor(folder, num_gpus):
    executor = submitit.AutoExecutor(folder=folder)
    executor.update_parameters(nodes=1,
                               gpus_per_node=num_gpus,
                               slurm_partition="lmbdlc_gpu-rtx2080",
                               slurm_array_parallelism=35,
                               timeout_min=100)
    return executor

In [71]:
MODES = ["train", "augmix", "deepaugment", "amda", "test", "renditions", "corrupted"]
TEST_MODES = MODES[-3:]

def dataloaders(mode: str, batch_size: int):
    assert mode in MODES
    
    kws = {}
    kws["test_renditions"] = True if mode == "renditions" else False
    kws["test_corrupted"] = True if mode == "corrupted" else False
    kws["augmix"] = "no_jsd" if mode == "augmix" or mode == "amda" else False
    kws["deepaugment_path"] = deepaugment_path if mode == "deepaugment" or mode == "amda" else None
        
    dm = ShiftDataModule("imagenet", img100_path, batch_size=batch_size, **kws)
    
    if mode in TEST_MODES:
        dm.setup("test")
        
        if mode == "test":
            dataset = dm.test_datasets[0]
        elif mode == "renditions":
            dataset = dm.test_datasets[1]
        else:
            dataset = ConcatDataset(dm.test_datasets[1 :])
            
        return DataLoader(dataset, batch_size, True, num_workers=2)
    
    dm.setup("fit")
    return dm.train_dataloader()
        


def collect_samples(network_name,
                    network_path,
                    mode: str,
                    with_logits: bool = True,
                    with_features: bool = True,
                    num_samples: int = 256,
                    batch_size: int = 256):
    """ Makes some forward passes on ImageNet100"""
    net = create_network("imagenet", "resnet50", 100, ckpt_path=network_path, imagenet_subset=True)
    device = torch.device("cuda:0")

    print("Create datamodule.")

    loader = dataloaders(mode, batch_size)

    num_samples = num_samples // batch_size
    collector = ActivationCollector({"fc": classifier(net)}, mode="in")

    net = net.to(device)
    # Meta Information
    info = defaultdict(list)
    # Logits and Activations
    # data = defaultdict(list)

    for _, batch in tqdm(zip(range(num_samples), loader)):
        _, x, y = batch
        x = x.to(device)

        with torch.no_grad():
            logits_batch = net(x)

        for label, activations, logits in zip(y, collector["fc"], logits_batch):
            # if train
            
            info["Class Idx"].append(label.item())
            info["Network Name"].append(network_name)
            info["Distribution"].append(mode)
            info["Logits"].append(logits.cpu().numpy())
            info["Features"].append(activations.cpu().numpy())

        collector.reset()

    df_info = pd.DataFrame(info)
    # df_info.columns = pd.MultiIndex.from_product([["Info"], df_info.columns])
    return df_info
    
    num_logits = 100
    num_features = data["Features"][0].shape[-1]

    df_logits = pd.DataFrame(np.stack(data["Logits"]))
    df_logits.columns = pd.MultiIndex.from_product([["Logits"], range(num_logits)])
    df_features = pd.DataFrame(np.stack(data["Features"]))
    df_features.columns = pd.MultiIndex.from_product([["Features"], range(num_features)])

    df_parts = [df_info]
    
    if with_features:
        df_parts.append(df_logits)
    
    if with_logits:
        df_parts.append(df_features)
    
    return pd.concat([df_info, df_logits, df_features], axis=1)


In [72]:
# We want to sweep over augmix and deepaugment, train, test
def sweep(**kws):
    executor = create_executor("./submitit_runs", 1)
    names = list(INTRESTING_NETS.keys())
    paths = [INTRESTING_NETS[n] for n in INTRESTING_NETS]

    jobs = []
    with executor.batch():
        for mode in MODES:
            for name, path in INTRESTING_NETS.items():
                job = executor.submit(collect_samples, name, path, mode, **kws)
                jobs.append(job)

    return pd.concat([j.result() for j in jobs])

In [73]:
df = sweep(num_samples=256*20)

In [66]:
df["Logits"].iloc[0]

array([-2.1403482e+00, -1.5317304e-01, -1.8368506e+00, -1.1647618e+00,
       -2.0320818e+00, -1.3396443e+00, -4.2757797e+00, -1.3229892e+00,
       -2.7466087e+00, -7.7042747e-01, -4.2533612e+00, -3.5652115e+00,
       -3.1124725e+00, -9.1610336e-01, -2.5601909e+00, -2.7281718e+00,
       -1.4540725e+00,  1.6595414e-01, -9.4827637e-02, -6.0758024e-02,
       -2.4741373e+00, -1.5408002e+00, -2.5598581e+00, -1.1964335e+00,
       -8.5329264e-01, -3.1348257e+00,  2.7121606e+00, -3.3187635e+00,
       -4.3756020e-01, -1.7142701e-01, -1.6227074e-01, -3.8952718e+00,
       -1.2373120e+00, -1.7866646e+00, -1.0065330e+00, -1.8398194e+00,
       -9.3951029e-01, -4.4713203e-02,  9.5957023e-01,  9.8182306e-02,
       -2.0947974e+00,  1.0645647e-01,  8.8812327e-01, -1.0044398e+00,
        1.6820704e+00,  5.2743971e-01,  2.4413630e-01, -4.5053613e-01,
        2.6757627e+00, -7.0196742e-01,  1.9724169e+00, -1.7297271e+00,
        7.2338301e-01, -1.1591847e+00, -5.2212596e-01,  9.0325028e-03,
      

In [74]:
df.to_pickle("short_features_logits.pkl")

In [61]:
df.to_csv("std.csv")

In [59]:
df.to_feather("short_features_logits.bird")

ValueError: feather does not support serializing a non-default index for the index; you can .reset_index() to make the index into column(s)

In [51]:
job = executor.submit(collect_samples, amda_path, num_samples=1000)

In [53]:
df = job.result()

In [54]:
df["Info"]

Unnamed: 0,Class Idx,DeepAugment,Augmix,Train
0,33,False,False,True
1,60,False,False,True
2,42,False,False,True
3,42,False,False,True
4,74,False,False,True
...,...,...,...,...
763,40,False,False,True
764,98,False,False,True
765,93,False,False,True
766,81,False,False,True


In [None]:
executor = create_executor("./test_runs", 1)
jobs = executor.map_array(collect_samples, paths)