In [26]:
# Adding the correct hyperparameter to the experiments.
# Find the checkpoints.
# Write a load function.
# Write evaluation loop with a Trainer.
# Make a plot.

# We need a smart load pruned model function, that finds out which modules, param pairs have been pruned. 30 minutes
# We need a way to smartly 
# We should log everything into  test folder.
from pathlib import Path
import yaml
import re
import os
import torch.nn.utils.prune as torch_prune
import torch
import pruneshift.prune as prune
from pruneshift.prune_info import PruneInfo
from pruneshift.topologies import network_topology
from pruneshift.modules import VisionModule
from pruneshift.datamodules import datamodule
from pytorch_lightning.loggers import CSVLogger
import pytorch_lightning as pl


dataset_path = os.environ["DATASET_PATH"]
CKPT_FILENAME_REGEX = r"epoch=(?P<epoch>\d+)[-]val_acc=(?P<val_acc>\d+\.\d+)"
# Collect all paths.
base = Path("/misc/lmbraid19/hoffmaja/prune_shift/data")

In [27]:
exp_dirs = [ip for op in base.glob("oneshot_global_weight*") for ip in op.glob("oneshot*")]

In [87]:
data = datamodule("cifar10_corrupted", dataset_path)

def find_stuff(path: Path):
    paths, val_accs, epochs = [], [], []
    for p in (path/"checkpoints").glob("epoch*"):
        epochs.append(re.match(CKPT_FILENAME_REGEX, p.stem)["epoch"])
        val_accs.append(re.match(CKPT_FILENAME_REGEX, p.stem)["val_acc"])
        paths.append(p)
    
    with open(path/"hparams.yaml", "r") as file:
        hparams = yaml.load(file)

    idx = max(range(len(epochs)), key=epochs.__getitem__)

    return paths[idx], hparams

def load_network(path: Path, load=True):
    path, hparams = find_stuff(path)
    net = network_topology(hparams["network"])
    info = PruneInfo(net)
    prune.simple_prune(info, torch_prune.Identity)
    module = VisionModule(net, test_labels=data.labels, hparams=hparams)
    if load:
        module.load_state_dict(torch.load(path)["state_dict"], strict=False)
    return module, info, hparams

def evaluate(module):
    trainer = pl.Trainer(gpus=1, logger=CSVLogger("/tmp/debug_pruneshift"), limit_test_batches=0.2)
    trainer.test(module, datamodule=data)

In [88]:
m, i, h = load_network(exp_dirs[-5], False)
evaluate(m)

  # This is added back by InteractiveShellApp.init_path()
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Files already downloaded and verified


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc_brightness_1': tensor(0.9304, device='cuda:0'),
 'test_acc_brightness_2': tensor(0.9279, device='cuda:0'),
 'test_acc_brightness_3': tensor(0.9219, device='cuda:0'),
 'test_acc_brightness_4': tensor(0.9108, device='cuda:0'),
 'test_acc_brightness_5': tensor(0.8881, device='cuda:0'),
 'test_acc_contrast_1': tensor(0.9098, device='cuda:0'),
 'test_acc_contrast_2': tensor(0.8191, device='cuda:0'),
 'test_acc_contrast_3': tensor(0.7117, device='cuda:0'),
 'test_acc_contrast_4': tensor(0.5066, device='cuda:0'),
 'test_acc_contrast_5': tensor(0.1825, device='cuda:0'),
 'test_acc_defocus_blur_1': tensor(0.9224, device='cuda:0'),
 'test_acc_defocus_blur_2': tensor(0.9022, device='cuda:0'),
 'test_acc_defocus_blur_3': tensor(0.8528, device='cuda:0'),
 'test_acc_defocus_blur_4': tensor(0.7641, device='cuda:0'),
 'test_acc_defocus_blur_5': tensor(0.5217, device='cuda:0'),
 'test_a

In [42]:
for p in exp_dirs:
    print(find_stuff(p))

  # This is added back by InteractiveShellApp.init_path()


(PosixPath('/misc/lmbraid19/hoffmaja/prune_shift/data/oneshot_global_weight000006/oneshot_global_weight000006_00/checkpoints/epoch=139-val_acc=0.95.ckpt'), {'datamodule': 'cifar10', 'network': 'cifar10_resnet50', 'ratio': 32})
(PosixPath('/misc/lmbraid19/hoffmaja/prune_shift/data/oneshot_global_weight000006/oneshot_global_weight000006_01/checkpoints/epoch=139-val_acc=0.95.ckpt'), {'datamodule': 'cifar10', 'network': 'cifar10_resnet50', 'ratio': 32})
(PosixPath('/misc/lmbraid19/hoffmaja/prune_shift/data/oneshot_global_weight000008/oneshot_global_weight000008_00/checkpoints/epoch=79-val_acc=0.93.ckpt'), {'datamodule': 'cifar10', 'network': 'cifar10_resnet18', 'ratio': 2})
(PosixPath('/misc/lmbraid19/hoffmaja/prune_shift/data/oneshot_global_weight000008/oneshot_global_weight000008_01/checkpoints/epoch=59-val_acc=0.93.ckpt'), {'datamodule': 'cifar10', 'network': 'cifar10_resnet18', 'ratio': 2})
(PosixPath('/misc/lmbraid19/hoffmaja/prune_shift/data/oneshot_global_weight000005/oneshot_global

In [63]:
i.network_comp()

29.906243722075153

In [44]:
import pandas as pd

In [89]:
df_50_2 = pd.read_csv("/tmp/debug_pruneshift/default/version_6/metrics.csv")
df_50_1 = pd.read_csv("/tmp/debug_pruneshift/default/version_7/metrics.csv")
df_50_32 = pd.read_csv("/tmp/debug_pruneshift/default/version_5/metrics.csv")
df_50_orig = pd.read_csv("/tmp/debug_pruneshift/default/version_8/metrics.csv")

In [97]:
(df_50_orig < df_50_2).to_numpy().mean()

0.2857142857142857

In [93]:
df = pd.concat([df_50_1, df_50_2, df_50_32, df_50_orig])

In [62]:
1 - 1 / 32

0.96875

In [94]:
df["test_acc_original"]

0    0.924395
0    0.919859
0    0.927419
0    0.931452
Name: test_acc_original, dtype: float64