In [20]:
import pathlib
import os
import time
from copy import deepcopy
from typing import NamedTuple, Sequence

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm.autonotebook import tqdm

from pruneshift import datamodule
from pruneshift import topology
from pruneshift.prune.utils import simple_prune
from pruneshift.prune.strategies import Absolute


DATA_PATH = os.environ["DATASET_PATH"]
TRAIN_PATH = pathlib.Path(os.environ["EXPERIMENT_PATH"]) / "train"

In [21]:
class ChptPath(NamedTuple):
    model_name: str
    datamodule_name: str
    version: int
    epoch: int
    epoch_path: pathlib.Path

def find_paths():
    train_path = TRAIN_PATH
    for config_path in train_path.iterdir():
        model_name, dataset_name = config_path.stem.split('_')
        for version_path in (config_path / "lightning_logs").iterdir():
            version = version_path.stem.split('_')[-1]
            for epoch_path in (version_path / "checkpoints").glob("*.ckpt"):
                epoch = int(epoch_path.stem.split('=')[-1])
                yield ChptPath(model_name, dataset_name, version, epoch, epoch_path)


def load_network(chpt: ChptPath):
    print("Recreating network for {}.".format(chpt))
    start = time.time()
    network = topology(chpt.model_name, num_classes=10)
    state = torch.load(chpt.epoch_path)
    # Remove the structure of the network 
    conv_state = {}
    for name in state["state_dict"]:
        conv_state[name[8:]] = state["state_dict"][name]
    network.load_state_dict(conv_state)
    print("Finished recreating network in {:.1f}s.".format(time.time() - start))
    return network


class TestModule(pl.LightningModule):
    """Module for training models."""
    def __init__(self,
                 network: nn.Module,
                 labels: Sequence[str],
                 lr: float = 0.0001):
        super(TestModule, self).__init__()
        self.network = network
        self.labels = labels
        self.accuracy = nn.ModuleDict({l: Accuracy() for l in labels})
        self.lr = lr
        self.test_statistics = None

    def _predict(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = accuracy(torch.argmax(logits, 1), y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._predict(batch)
        self.log("Training/Loss", loss)
        self.log("Training/Accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._predict(batch)
        self.log("Validation/Loss", loss)
        self.log("Validation/Accuracy", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        return self.network(x)
        
    def test_step(self, batch, batch_idx, dataset_idx):
        x, y = batch
        self.accuracy[self.labels[dataset_idx]](y, torch.argmax(self(x), -1))
        
    def test_epoch_end(self, output):
        self.test_statistics = {l: a.compute().item() for l, a in self.accuracy.items()}
        

def prune_shift(chpt, strategy, train_data, test_data, module_map):
    compression_ratios = [1, 2, 4, 8, 16]
    original_network = load_network(chpt)

    
    early_stop_callback = EarlyStopping(
        monitor="Validation/Accuracy")
    trainer = pl.Trainer(callbacks=[early_stop_callback],
                         logger=CSVLogger(save_dir="/tmp/test_pruneshift"),
                         checkpoint_callback=False,
                         gpus=1)
    statistics = []

    for ratio in compression_ratios:
        network = deepcopy(original_network)
        # Prune the network
        simple_prune(network, strategy, module_map=module_map, amount=1 - 1 / ratio)
        module = TestModule(network, test_data.labels)
        trainer.fit(module, datamodule=train_data)
        trainer.test(module, datamodule=test_data)
        statistics.append({"network": chpt.model_name,
                           "dataset": chpt.datamodule_name,
                           "epoch": chpt.epoch + 1,
                           "ratio": ratio,
                           **module.test_statistics})
    return statistics

In [22]:
example = next(find_paths())
module_map = {nn.Linear: ["weight"], nn.Conv2d: ["weight"]}
train_data = datamodule("CIFAR10", root=DATA_PATH, num_workers=5)
test_data = datamodule("CIFAR10Corrupted", root=DATA_PATH, num_workers=5, lvls=[5])
stats = prune_shift(example, Absolute, train_data, test_data, module_map)

Recreating network for ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch=19, epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/resnet18_CIFAR10/lightning_logs/version_4/checkpoints/epoch=19.ckpt')).


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Finished recreating network in 0.5s.
Files already downloaded and verified
Files already downloaded and verified



  | Name     | Type       | Params
----------------------------------------
0 | network  | ResNet     | 11 M  
1 | accuracy | ModuleDict | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




Files already downloaded and verified


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…




KeyboardInterrupt: 

In [47]:
pd.DataFrame(stats)

Unnamed: 0,brightness_5,contrast_5,dataset,defocus_blur_5,elastic_transform_5,epoch,fog_5,frost_5,gaussian_blur_5,gaussian_noise_5,...,network,pixelate_5,ratio,saturate_5,shot_noise_5,snow_5,spatter_5,speckle_noise_5,undistorted,zoom_blur_5
0,"tensor(0.6920, device='cuda:0')","tensor(0.2188, device='cuda:0')",CIFAR10,"tensor(0.5227, device='cuda:0')","tensor(0.6653, device='cuda:0')",19,"tensor(0.3945, device='cuda:0')","tensor(0.6209, device='cuda:0')","tensor(0.4810, device='cuda:0')","tensor(0.5431, device='cuda:0')",...,resnet18,"tensor(0.7219, device='cuda:0')",1,"tensor(0.5732, device='cuda:0')","tensor(0.5663, device='cuda:0')","tensor(0.6734, device='cuda:0')","tensor(0.5718, device='cuda:0')","tensor(0.5631, device='cuda:0')","tensor(0.7706, device='cuda:0')","tensor(0.5593, device='cuda:0')"
1,"tensor(0.6668, device='cuda:0')","tensor(0.2222, device='cuda:0')",CIFAR10,"tensor(0.5152, device='cuda:0')","tensor(0.6425, device='cuda:0')",19,"tensor(0.3893, device='cuda:0')","tensor(0.5751, device='cuda:0')","tensor(0.4793, device='cuda:0')","tensor(0.5521, device='cuda:0')",...,resnet18,"tensor(0.7058, device='cuda:0')",2,"tensor(0.5772, device='cuda:0')","tensor(0.5702, device='cuda:0')","tensor(0.6340, device='cuda:0')","tensor(0.5715, device='cuda:0')","tensor(0.5692, device='cuda:0')","tensor(0.7651, device='cuda:0')","tensor(0.5490, device='cuda:0')"
2,"tensor(0.4756, device='cuda:0')","tensor(0.1699, device='cuda:0')",CIFAR10,"tensor(0.4086, device='cuda:0')","tensor(0.4638, device='cuda:0')",19,"tensor(0.2487, device='cuda:0')","tensor(0.3436, device='cuda:0')","tensor(0.3723, device='cuda:0')","tensor(0.3677, device='cuda:0')",...,resnet18,"tensor(0.5473, device='cuda:0')",4,"tensor(0.4514, device='cuda:0')","tensor(0.3762, device='cuda:0')","tensor(0.3892, device='cuda:0')","tensor(0.3787, device='cuda:0')","tensor(0.3721, device='cuda:0')","tensor(0.6145, device='cuda:0')","tensor(0.4272, device='cuda:0')"
3,"tensor(0.2188, device='cuda:0')","tensor(0.1221, device='cuda:0')",CIFAR10,"tensor(0.1875, device='cuda:0')","tensor(0.1991, device='cuda:0')",19,"tensor(0.1373, device='cuda:0')","tensor(0.1691, device='cuda:0')","tensor(0.1782, device='cuda:0')","tensor(0.1925, device='cuda:0')",...,resnet18,"tensor(0.2212, device='cuda:0')",8,"tensor(0.2282, device='cuda:0')","tensor(0.1956, device='cuda:0')","tensor(0.1943, device='cuda:0')","tensor(0.1986, device='cuda:0')","tensor(0.1932, device='cuda:0')","tensor(0.2472, device='cuda:0')","tensor(0.1881, device='cuda:0')"
4,"tensor(0.1046, device='cuda:0')","tensor(0.1047, device='cuda:0')",CIFAR10,"tensor(0.1193, device='cuda:0')","tensor(0.1191, device='cuda:0')",19,"tensor(0.1048, device='cuda:0')","tensor(0.0999, device='cuda:0')","tensor(0.1150, device='cuda:0')","tensor(0.1188, device='cuda:0')",...,resnet18,"tensor(0.1194, device='cuda:0')",16,"tensor(0.1133, device='cuda:0')","tensor(0.1178, device='cuda:0')","tensor(0.1040, device='cuda:0')","tensor(0.1180, device='cuda:0')","tensor(0.1163, device='cuda:0')","tensor(0.1202, device='cuda:0')","tensor(0.1214, device='cuda:0')"


In [51]:
list(find_paths())[:10]

[ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch='19', epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/resnet18_CIFAR10/lightning_logs/version_4/checkpoints/epoch=19.ckpt')),
 ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch='13', epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/resnet18_CIFAR10/lightning_logs/version_4/checkpoints/epoch=13.ckpt')),
 ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch='9', epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/resnet18_CIFAR10/lightning_logs/version_4/checkpoints/epoch=9.ckpt')),
 ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch='2', epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/resnet18_CIFAR10/lightning_logs/version_4/checkpoints/epoch=2.ckpt')),
 ChptPath(model_name='resnet18', datamodule_name='CIFAR10', version='4', epoch='11', epoch_path=PosixPath('/misc/lmb

In [13]:
len(list(find_paths())) * 76 / 60 / 60 / 8

1.952777777777778