In [1]:
import pathlib
import os
import time
import pytorch_lightning as pl
import torch.nn as nn
import torch
from typing import NamedTuple, Sequence
from pruneshift import datamodule
from pytorch_lightning.loggers import CSVLogger
from pruneshift import topology
from pytorch_lightning.metrics import Accuracy

from pruneshift.prune.utils import simple_prune
from pruneshift.prune.strategies import Absolute


DATA_PATH = os.environ["DATASET_PATH"]
TRAIN_PATH = pathlib.Path(os.environ["EXPERIMENT_PATH"]) / "train"

In [21]:
class ChptPath(NamedTuple):
    model_name: str
    datamodule_name: str
    version: int
    epoch: int
    epoch_path: pathlib.Path

def find_paths():
    train_path = TRAIN_PATH
    for config_path in train_path.iterdir():
        model_name, dataset_name = config_path.stem.split('_')
        for version_path in (config_path / "lightning_logs").iterdir():
            version = version_path.stem.split('_')[-1]
            for epoch_path in (version_path / "checkpoints").glob("*.ckpt"):
                epoch = epoch_path.stem.split('=')[-1]
                yield ChptPath(model_name, dataset_name, version, epoch, epoch_path)


def load_network(chpt: ChptPath):
    print("Recreating network for {}.".format(chpt))
    start = time.time()
    network = topology(chpt.model_name, num_classes=10)
    state = torch.load(chpt.epoch_path)
    # Remove the structure of the network 
    conv_state = {}
    for name in state["state_dict"]:
        conv_state[name[8:]] = state["state_dict"][name]
    network.load_state_dict(conv_state)
    print("Finished recreating network in {:.1f}s.".format(time.time() - start))
    return network


class TestModule(pl.LightningModule):
    """Module for training models."""
    def __init__(self,
                 network: nn.Module,
                 labels: Sequence[str]):
        super(TestModule, self).__init__()
        self.network = network
        self.labels = labels
        self.accuracy = nn.ModuleDict({l: Accuracy() for l in labels})
        self.statistics = None

    def forward(self, x):
        return self.network(x)

    def test_step(self, batch, batch_idx, dataset_idx):
        x, y = batch
        self.accuracy[self.labels[dataset_idx]](y, torch.argmax(self(x), -1))
        
    def test_epoch_end(self, output):
        self.statistics = {l: a.compute() for l, a in self.accuracy.items()}

In [22]:
example = list(find_paths())[-1]
network = load_network(example)
data = datamodule("CIFAR10Corrupted", root=DATA_PATH, num_workers=5)

Recreating network for ChptPath(model_name='vgg13', datamodule_name='CIFAR10', version='2', epoch='15', epoch_path=PosixPath('/misc/lmbraid19/hoffmaja/experiments/train/vgg13_CIFAR10/lightning_logs/version_2/checkpoints/epoch=15.ckpt')).
Finished recreating network in 4.2s.


In [26]:
module = TestModule(network, data.labels)
trainer = pl.Trainer(logger=CSVLogger(save_dir="/tmp/test_pruneshift"),
                     checkpoint_callback=False,
                     gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


In [27]:
start = time.time()
trainer.test(module, datamodule=data)
time.time() - start

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------



378.3447206020355

In [28]:
module.statistics

{'undistorted': tensor(0.8190, device='cuda:0'),
 'shot_noise_1': tensor(0.7855, device='cuda:0'),
 'shot_noise_2': tensor(0.7466, device='cuda:0'),
 'shot_noise_3': tensor(0.6405, device='cuda:0'),
 'shot_noise_4': tensor(0.5942, device='cuda:0'),
 'shot_noise_5': tensor(0.5130, device='cuda:0'),
 'gaussian_noise_1': tensor(0.7513, device='cuda:0'),
 'gaussian_noise_2': tensor(0.6682, device='cuda:0'),
 'gaussian_noise_3': tensor(0.5771, device='cuda:0'),
 'gaussian_noise_4': tensor(0.5350, device='cuda:0'),
 'gaussian_noise_5': tensor(0.4871, device='cuda:0'),
 'saturate_1': tensor(0.7953, device='cuda:0'),
 'saturate_2': tensor(0.7734, device='cuda:0'),
 'saturate_3': tensor(0.8065, device='cuda:0'),
 'saturate_4': tensor(0.7681, device='cuda:0'),
 'saturate_5': tensor(0.7130, device='cuda:0'),
 'jpeg_compression_1': tensor(0.7926, device='cuda:0'),
 'jpeg_compression_2': tensor(0.7675, device='cuda:0'),
 'jpeg_compression_3': tensor(0.7604, device='cuda:0'),
 'jpeg_compression_4': 

In [32]:
6 * 20 * 6 * 1 / 60 * 2

24.0

In [33]:
24 / 8

3.0