In [30]:
import os
from pathlib import Path
import math
from functools import partial
from argparse import Namespace

import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.nn.utils.prune import custom_from_mask
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from torch.optim import SGD
import matplotlib.pyplot as plt

from pruneshift.modules import VisionModule
from pruneshift.networks import network as create_network
from pruneshift.datamodules import datamodule
from pruneshift.prune_info import PruneInfo
from pruneshift.prune import prune
from pruneshift.prune_hydra import hydrate, dehydrate
from pruneshift.utils import load_pruned_state_dict, load_state_dict


DATASET_PATH = os.environ["DATASET_PATH"]
MODEL_PATH = Path(os.environ["MODEL_PATH"])/"augmix"

In [69]:
net = create_network("cifar100_resnet18", model_path=MODEL_PATH)
prune(net, "layer_weight", 32)

<pruneshift.prune_info.PruneInfo at 0x7f4228afbc50>

In [70]:
# dehydrate(net)

In [71]:
T = 30
optim_fn = partial(optim.SGD, nesterov=True, weight_decay= 0.0005, momentum= 0.9)
scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=T)

lr_monitor = LearningRateMonitor(logging_interval='step')
ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_acc")
callbacks = [lr_monitor, ckpt_callback]
trainer = pl.Trainer(gpus=1, max_epochs=30, callbacks=[lr_monitor], weights_summary=None)

# net.conv1.is_protected = True
# net.fc.is_protected = False
data = datamodule("cifar100_augmix", DATASET_PATH, batch_size=128)

# info = hydrate(net, 32)
module = VisionModule(net, data.labels, optimizer_fn=optim_fn, learning_rate=0.1,
                      scheduler_fn=scheduler_fn, augmix_loss_alpha=24.)

trainer.fit(module, datamodule=data)

GPU available: True, used: True
I0131 21:19:58.065671 139927539267392 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I0131 21:19:58.067357 139927539267392 distributed.py:49] TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
I0131 21:19:58.068796 139927539267392 accelerator_connector.py:402] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Files already downloaded and verified
Files already downloaded and verified


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [55]:
dehydrate(net)

In [56]:
trainer.save_checkpoint("70_prune_without.ckpt")

In [73]:
import pandas as pd

df = pd.read_csv("csv_logs/default/version_6/metrics.csv")
1 - df.filter(regex="test_acc*").mean(axis=1)

0    0.568441
dtype: float64

In [74]:
df["test_acc_clean"]

0    0.5527
Name: test_acc_clean, dtype: float64

In [18]:
dehydrate(net)

In [57]:
T = 30
optim_fn = partial(optim.SGD, nesterov=True, weight_decay= 0.0005, momentum= 0.9)
scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=T)

data = datamodule("cifar100_augmix", DATASET_PATH, batch_size=128)

module = VisionModule(net, data.labels, optimizer_fn=optim_fn, learning_rate=0.1,
                      scheduler_fn=scheduler_fn, augmix_loss_alpha=24.)

lr_monitor = LearningRateMonitor(logging_interval='step')
ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_acc")
callbacks = [lr_monitor, ckpt_callback]
trainer = pl.Trainer(gpus=1, max_epochs=30, callbacks=[lr_monitor], weights_summary=None)

GPU available: True, used: True
I0131 19:29:34.263937 139927539267392 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I0131 19:29:34.265509 139927539267392 distributed.py:49] TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
I0131 19:29:34.266748 139927539267392 accelerator_connector.py:402] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


In [58]:
trainer.fit(module, datamodule=data)

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [59]:
trainer.save_checkpoint("70_24_finetuned.ckpt")

In [21]:
data = datamodule("cifar100_corrupted", DATASET_PATH, batch_size=128)
module = VisionModule(net, data.labels)

In [22]:
PruneInfo(module).summary()

Unnamed: 0,module,param,comp,amount,size,shape,target,protected
0,network.conv1,weight,38.4,0.973958,1728,"(64, 3, 3, 3)",True,False
1,network.layer1.0.conv1,weight,38.480167,0.974013,36864,"(64, 64, 3, 3)",True,False
2,network.layer1.0.conv2,weight,38.480167,0.974013,36864,"(64, 64, 3, 3)",True,False
3,network.layer1.1.conv1,weight,38.480167,0.974013,36864,"(64, 64, 3, 3)",True,False
4,network.layer1.1.conv2,weight,38.480167,0.974013,36864,"(64, 64, 3, 3)",True,False
5,network.layer2.0.conv1,weight,38.500261,0.974026,73728,"(128, 64, 3, 3)",True,False
6,network.layer2.0.conv2,weight,38.510316,0.974033,147456,"(128, 128, 3, 3)",True,False
7,network.layer2.0.downsample.0,weight,38.460094,0.973999,8192,"(128, 64, 1, 1)",True,False
8,network.layer2.1.conv1,weight,38.510316,0.974033,147456,"(128, 128, 3, 3)",True,False
9,network.layer2.1.conv2,weight,38.510316,0.974033,147456,"(128, 128, 3, 3)",True,False


In [72]:
loggers = [CSVLogger("csv_logs")]
data = datamodule("cifar100_corrupted", DATASET_PATH, batch_size=128)
module = VisionModule(net, data.labels)

trainer = pl.Trainer(gpus=1, max_epochs=30, weights_summary=None, logger=loggers)
trainer.test(module, datamodule=data)

GPU available: True, used: True
I0131 21:57:58.331181 139927539267392 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I0131 21:57:58.333100 139927539267392 distributed.py:49] TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
I0131 21:57:58.334436 139927539267392 accelerator_connector.py:402] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Files already downloaded and verified
Files already downloaded and verified


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc_brightness_1': tensor(0.5499, device='cuda:0'),
 'test_acc_brightness_2': tensor(0.5389, device='cuda:0'),
 'test_acc_brightness_3': tensor(0.5173, device='cuda:0'),
 'test_acc_brightness_4': tensor(0.4904, device='cuda:0'),
 'test_acc_brightness_5': tensor(0.4243, device='cuda:0'),
 'test_acc_clean': tensor(0.5527, device='cuda:0'),
 'test_acc_contrast_1': tensor(0.5361, device='cuda:0'),
 'test_acc_contrast_2': tensor(0.4532, device='cuda:0'),
 'test_acc_contrast_3': tensor(0.3940, device='cuda:0'),
 'test_acc_contrast_4': tensor(0.3088, device='cuda:0'),
 'test_acc_contrast_5': tensor(0.1240, device='cuda:0'),
 'test_acc_defocus_blur_1': tensor(0.5485, device='cuda:0'),
 'test_acc_defocus_blur_2': tensor(0.5320, device='cuda:0'),
 'test_acc_defocus_blur_3': tensor(0.5134, device='cuda:0'),
 'test_acc_defocus_blur_4': tensor(0.4910, device='cuda:0'),
 'test_acc_defocu

[{'test_acc_clean': 0.5526999831199646,
  'test_acc_shot_noise_1': 0.46630001068115234,
  'test_acc_shot_noise_2': 0.41260001063346863,
  'test_acc_shot_noise_3': 0.32019999623298645,
  'test_acc_shot_noise_4': 0.29010000824928284,
  'test_acc_shot_noise_5': 0.2460000067949295,
  'test_acc_gaussian_noise_1': 0.41519999504089355,
  'test_acc_gaussian_noise_2': 0.32339999079704285,
  'test_acc_gaussian_noise_3': 0.2605000138282776,
  'test_acc_gaussian_noise_4': 0.23499999940395355,
  'test_acc_gaussian_noise_5': 0.2134000062942505,
  'test_acc_saturate_1': 0.40799999237060547,
  'test_acc_saturate_2': 0.33869999647140503,
  'test_acc_saturate_3': 0.5184999704360962,
  'test_acc_saturate_4': 0.4413999915122986,
  'test_acc_saturate_5': 0.3686000108718872,
  'test_acc_jpeg_compression_1': 0.5196999907493591,
  'test_acc_jpeg_compression_2': 0.5088000297546387,
  'test_acc_jpeg_compression_3': 0.5004000067710876,
  'test_acc_jpeg_compression_4': 0.5006999969482422,
  'test_acc_jpeg_compres