In [26]:
import os
from pathlib import Path
import math
from functools import partial
from argparse import Namespace

import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.nn.utils.prune import custom_from_mask
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from torch.optim import SGD
import matplotlib.pyplot as plt

from pruneshift.modules import VisionModule
from pruneshift.networks import network as create_network
from pruneshift.datamodules import datamodule
from pruneshift.prune_info import PruneInfo
from pruneshift.prune import prune


DATASET_PATH = os.environ["DATASET_PATH"]
MODEL_PATH = Path(os.environ["MODEL_PATH"])/"hydra"

In [27]:
from hydro.models import resnet18
from hydro.models.layers import SubnetLinear, SubnetConv
from hydro.utils.model import prepare_model, initialize_scaled_score


In [28]:
net = resnet18(SubnetConv, SubnetLinear, "kaiming_normal")
args = Namespace(freeze_bn = False, k = 0.02, scores_init_type = "kaiming_uniform", exp_mode="prune")
prepare_model(net, args)
state = torch.load(MODEL_PATH/"cifar100_resnet18.0")
net.load_state_dict(state, strict=False)
initialize_scaled_score(net)

#################### Pruning network ####################
===>>  gradient for weights: None  | training importance scores only
Initialization relevance score with kaiming_uniform initialization
Initialization relevance score proportional to weight magnitudes (OVERWRITING SOURCE NET SCORES)


In [29]:
T = 10000
optim_fn = partial(optim.SGD, nesterov=True, weight_decay= 0.0005, momentum= 0.9)
scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=T)

lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [lr_monitor]
trainer = pl.Trainer(gpus=1, max_epochs=12, callbacks=[lr_monitor], checkpoint_callback=False, weights_summary=None)

# net.conv1.is_protected = True
# net.fc.is_protected = False
data = datamodule("cifar100", DATASET_PATH, batch_size=128)

module = VisionModule(net, data.labels, optimizer_fn=optim_fn, learning_rate=0.1, scheduler_fn=scheduler_fn)

trainer.fit(module, datamodule=data)

GPU available: True, used: True
I0130 22:18:11.433606 140023598487360 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I0130 22:18:11.440435 140023598487360 distributed.py:49] TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
I0130 22:18:11.442483 140023598487360 accelerator_connector.py:402] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Files already downloaded and verified
Files already downloaded and verified


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [61]:
from torch.nn import Parameter
from hydro.models.layers import GetSubnet
import torch.nn.functional as F

class SubnetLinear(nn.Linear):
    # self.k is the % of weights remaining, a real number in [0,1]
    # self.popup_scores is a Parameter which has the same shape as self.weight
    # Gradients to self.weight, self.bias have been turned off.

    def __init__(self, in_features, out_features, bias=True):
        super(SubnetLinear, self).__init__(in_features, out_features, bias=True)
        self.popup_scores = Parameter(torch.Tensor(self.weight.shape))
        nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))
        self.weight.requires_grad = False
        self.bias.requires_grad = False
        self.w = 0
        # self.register_buffer('w', None)

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        # Get the subnetwork by sorting the scores.
        adj = GetSubnet.apply(self.popup_scores.abs(), self.k)
        print(">>> Mask ", adj)
        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        x = F.linear(x, self.w, self.bias)
        print(self.w)
        print(x)

        return x


pl.seed_everything(0)
module = SubnetLinear(4, 2)
net = nn.Sequential(module, nn.Linear(2, 1))
args = Namespace(freeze_bn = False, k = 0.25, scores_init_type = "kaiming_uniform", exp_mode="prune")

module.weight = nn.Parameter(torch.tensor([[1., 0.5, 0.1, 0.5],
                                           [1., -0.5, 0.1 ,0.5 ]]))
module.bias = nn.Parameter(torch.tensor([1., 2.]))

net[1].weight = nn.Parameter(torch.tensor([[1.5,
                                            1.]]))
net[1].bias = nn.Parameter(torch.tensor([1.]))

prepare_model(module, args)
initialize_scaled_score(net)


o = optim.SGD(net.parameters(), lr=1)
result = net(torch.tensor([[0.2, 1, 0.2, 1]]))

print(">>>>>>>>>>SANITY CHECK")
print(">>>>>>>>Net")
print(list(net.named_parameters()))
print(">>>>>>Forward")
print(result)
print([(n, b) for n, b in module.named_buffers()])
print("\n")
print([(n, b) for n, b in module.named_parameters()])
result.backward()
o.step()
print("\n\n")
print([(n, b) for n, b in module.named_buffers()])
print("\n")
print([(n, b) for n, b in module.named_parameters()])
print(">>>>>>>>Grad")
print(net[1].weight.grad)
print(">>>>>>>>Net")
print(list(net.named_parameters()))

#################### Pruning network ####################
===>>  gradient for weights: None  | training importance scores only
Initialization relevance score with kaiming_uniform initialization
Initialization relevance score proportional to weight magnitudes (OVERWRITING SOURCE NET SCORES)
>>> Mask  tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.]], grad_fn=<GetSubnetBackward>)
tensor([[1., 0., 0., 0.],
        [1., -0., 0., 0.]], grad_fn=<MulBackward0>)
tensor([[1.2000, 2.2000]], grad_fn=<AddmmBackward>)
>>>>>>>>>>SANITY CHECK
>>>>>>>>Net
[('0.weight', Parameter containing:
tensor([[ 1.0000,  0.5000,  0.1000,  0.5000],
        [ 1.0000, -0.5000,  0.1000,  0.5000]])), ('0.bias', Parameter containing:
tensor([1., 2.])), ('0.popup_scores', Parameter containing:
tensor([[ 1.2247,  0.6124,  0.1225,  0.6124],
        [ 1.2247, -0.6124,  0.1225,  0.6124]], requires_grad=True)), ('1.weight', Parameter containing:
tensor([[1.5000, 1.0000]], requires_grad=True)), ('1.bias', Parameter containi