In [15]:
%matplotlib inline
import matplotlib.pylab as plt
import torchvision.models as models
import torch.nn as nn
import torch
import numpy as np
import pandas as pd
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

In [16]:
from pruner import * 
from models import LeNet

In [17]:
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 512

# MNIST images have a single channel, so Normalize must receive exactly one
# mean and one std. A 3-tuple cannot be broadcast onto a (1, H, W) tensor and
# raises on current torchvision releases.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
     ])

# Training split: shuffled mini-batches.
trainset = torchvision.datasets.MNIST(root='./mnist-data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE,
                                          shuffle=True, num_workers=1)

# Held-out split: fixed order so repeated evaluations see the same batches.
testset = torchvision.datasets.MNIST(root='./mnist-data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_BATCH_SIZE,
                                         shuffle=False, num_workers=1)

In [9]:
# Pick the GPU when one is present and fall back to CPU otherwise, matching the
# device selection used by the later cells of this notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = LeNet(trainloader, testloader)
net = net.to(device)
# map_location remaps tensors stored from a CUDA session onto `device`, so the
# checkpoint also loads on a CPU-only machine.
net.load_state_dict(torch.load('./checkpoints/iterative-pruning/lt-mnist-1-trained',
                               map_location=device))
# test() returns a pair; the recorded output below shows (accuracy, loss).
print(net.test())
print(net.param_count())

(0.96435546875, 0.03764837586786598)
431078.0


In [None]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only open experiment files produced by this project.
# The context manager closes the file handle, which the bare open() call leaked.
with open('./experiment_data/iterative-pruning/lt-mnist-1.p', 'rb') as pruner_file:
    _pruner = pickle.load(pruner_file)

# Attach the stored masks to a pruner bound to the network loaded in this session.
pruner = SparsityPruner(net)
pruner.masks = _pruner.masks

# presumably mask_classifier=True extends masking to the classifier layers —
# TODO confirm against pruner.py
pruner.apply_mask(mask_classifier=True)
print(net.test())
print(net.param_count())

In [None]:
# NOTE(review): bare literal; evaluating it only echoes the value. Presumably a
# parameter count pasted in for reference — confirm or delete.
86676.0

In [7]:
# Print the registered name of every parameter tensor of the network,
# one per line; the tensors themselves are not needed here.
for param_name, _param in net.named_parameters():
    print(param_name)

conv1.weight
conv1.bias
conv2.weight
conv2.bias
classifier.0.weight
classifier.0.bias
classifier.2.weight
classifier.2.bias


#### Scratch

In [None]:
# Scratch: rebind `net` to torchvision's ResNet-50 with pretrained weights
# requested (pretrained=True). This replaces whatever `net` referred to before.
net = models.resnet50(pretrained=True)

In [None]:
# Scratch: bind a pruner to the current `net` and prune with argument 0.7.
# NOTE(review): whether 0.7 is the fraction removed or the fraction kept is
# defined in pruner.py — confirm.
pruner = SparsityPruner(net)
pruner.prune(0.7)

In [20]:
# Choose the compute device once and report it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Build a freshly initialised LeNet on that device.
net = LeNet(trainloader, testloader).to(device)

# Checkpoint helpers kept for reference (disabled):
# torch.save(net.state_dict(), './checkpoints/lenet-lt-init-02')
# net.load_state_dict(torch.load('./checkpoints/lenet-lt-03-3-trained'))

# Baseline evaluation of the untrained network; shown as the cell result.
net.test()

cuda


(0.1125, 2.301564133167267)

In [None]:
N_EPOCH = 15
optimizer = optim.Adam(
    net.parameters(),
    lr=12e-4,
    weight_decay=5e-4,
)

# Running histories; each call to train_epoch receives them and hands back
# the updated versions.
train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(N_EPOCH):
    print('Starting epoch {}'.format(epoch+1))
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    train_losses, val_losses, train_accs, val_accs = net.train_epoch(
        epoch,
        optimizer,
        plot='loss',
        data=plt_data,
        LOG=10,
    )

In [None]:
# torch.save(net.state_dict(), './checkpoints/lenet-lt-trained-02')
# net.test()

In [None]:
# Rebuild the network and restore the trained weights.
net = LeNet(trainloader, testloader)
net = net.to(device)
# map_location keeps the load working when `device` resolved to 'cpu' but the
# checkpoint was written from a CUDA session.
net.load_state_dict(torch.load('./checkpoints/lenet-lt-trained-02', map_location=device))

# Reference accuracy before pruning; later cells reuse `val_acc` as the
# early-stopping target.
val_acc, _ = net.test()
print('Before pruning: {}, params: {}'.format(val_acc, net.param_count()))

# NOTE(review): the meaning of 0.2 (fraction removed vs. kept) is defined in
# pruner.py — confirm.
pruner = SparsityPruner(net)
pruner.prune(0.2)
print('After pruning: {}, params: {}'.format(net.test()[0], net.param_count()))

In [None]:
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)
train_losses, val_losses, train_accs, val_accs = [], [], [], []
plt_data = (train_losses, val_losses, train_accs, val_accs)
print('After pruning: {}'.format(net.test()))

# Single fine-tuning pass of the pruned network; `val_acc` (measured before
# pruning) is forwarded as the early-stop value.
for epoch in range(1):
    train_losses, val_losses, train_accs, val_accs = net.train_epoch(
        epoch, optimizer,
        plot='acc', data=plt_data, LOG=25,
        pruner=pruner, early_stop=val_acc,
    )
    # The histories are reset straight after the pass, as in the original flow.
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    # Re-apply the masks after the optimiser steps.
    pruner.apply_mask()
    print('After retraining: accuracy: {}, params: {}'.format(net.test(), net.param_count()))

In [None]:
# Evaluate the current `net`; test() returns a pair whose first element is used
# as the validation accuracy elsewhere in this notebook.
net.test()

In [None]:
# torch.save(net.state_dict(), './checkpoints/lenet-lt-finetuned-02')

### Retrain from winning ticket initialization

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net_retrain = LeNet(trainloader, testloader)
net_retrain = net.to(device)
net_retrain.load_state_dict(torch.load('./checkpoints/lenet-lt-init-02'))
net_retrain.test()

In [None]:
val_acc

In [None]:
_masks = pruner.masks
pruner_retrain = SparsityPruner(net_retrain)
pruner_retrain.masks = _masks

In [None]:
N_EPOCH = 5
optimizer = optim.SGD(net_retrain.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
train_losses, val_losses, train_accs, val_accs = [], [], [], []

pruner_retrain.apply_mask()
print(net_retrain.param_count())

for epoch in range(N_EPOCH):
    print('Starting epoch {}'.format(epoch+1))
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    train_losses, val_losses, train_accs, val_accs = net.train_epoch(epoch, optimizer, plot='acc', data=plt_data, pruner=pruner_retrain, early_stop=val_acc)

In [None]:
# Bare literal: the same 5e-3 tolerance used by the comparison in the next
# cell. Evaluating it here only echoes the value.
5e-3

In [None]:
# True when the retrained network's accuracy (first element of test()) is no
# more than 5e-3 below `val_acc`, the accuracy measured before pruning.
net_retrain.test()[0] >= (val_acc- 5e-3)

### Retrain from random reinit

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = LeNet(trainloader, testloader)
net = net.to(device)
# Load a different initialisation checkpoint for the control run.
# map_location lets a checkpoint written from a CUDA session load when
# `device` resolved to 'cpu'.
net.load_state_dict(torch.load('./checkpoints/lenet-lt-03-init', map_location=device))
net.test()

In [None]:
# Bind a new pruner to the re-initialised network and hand it the masks that
# the existing `pruner` holds.
pruner_reinit = SparsityPruner(net)
_masks = pruner.masks
pruner_reinit.masks = _masks

In [None]:
N_EPOCH = 5
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)
train_losses = []
val_losses = []
train_accs = []
val_accs = []

# NOTE(review): the up-front mask application is disabled in this run; only
# the `pruner=` argument of train_epoch carries the masks. Confirm intended.
# pruner_reinit.apply_mask()
# print(net.param_count())

for epoch in range(N_EPOCH):
    print('Starting epoch {}'.format(epoch+1))
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    train_losses, val_losses, train_accs, val_accs = net.train_epoch(
        epoch,
        optimizer,
        plot='acc',
        data=plt_data,
        pruner=pruner_reinit,
    )

### Results

In [None]:
WINDOW=1870

# Baseline run, truncated to the first WINDOW logged points.
orig_val = np.array(pd.read_csv('experiment_data/lt-mnist-init.csv')['val_accs'])
orig_val = orig_val[:WINDOW]
orig_train = np.array(pd.read_csv('experiment_data/lt-mnist-init.csv')['train_accs'])
orig_train = orig_train[:WINDOW]

# Random-initialisation control, truncated the same way.
rand_val = np.array(pd.read_csv('experiment_data/lt-mnist-rand-init.csv')['val_accs'])
rand_train = np.array(pd.read_csv('experiment_data/lt-mnist-rand-init.csv')['train_accs'])
rand_val = rand_val[:WINDOW]
rand_train = rand_train[:WINDOW]

# Retrained-ticket training curve.
# NOTE(review): 0.0958984375 is prepended as the first point — presumably the
# accuracy before any retraining step; confirm.
reinit_train = list(pd.read_csv('experiment_data/lt-mnist-reinit.csv')['train_accs'])
reinit_train = [0.0958984375] + reinit_train
reinit_train = np.array(reinit_train)
# Hold the last value so the curve spans the baseline's length. This used
# `len(orig)`, a name that is never defined; max(..., 0) keeps np.repeat from
# receiving a negative count when the retrained log is the longer one.
reinit_train = np.concatenate(
    (reinit_train,
     np.repeat(reinit_train[-1], max(len(orig_train) - len(reinit_train), 0))))
reinit_train = reinit_train[:WINDOW]

# Retrained-ticket validation curve. The plotting cell that follows reads
# `reinit_val`, which was never computed for this setting.
# assumes the CSV has a 'val_accs' column like the other reinit logs — TODO confirm
reinit_val = np.array(pd.read_csv('experiment_data/lt-mnist-reinit.csv')['val_accs'])
reinit_val = np.concatenate(
    (reinit_val,
     np.repeat(reinit_val[-1], max(len(orig_val) - len(reinit_val), 0))))
reinit_val = reinit_val[:WINDOW]


In [None]:
# Validation-accuracy curves for the 0.2 setting, one column per run.
curves_02 = pd.DataFrame({
    'base (0.2)': orig_val,
    'retrain (0.2)': reinit_val,
    'random (0.2)': rand_val,
})
ax1 = curves_02.plot(figsize=(3,3), legend=None)
# Saving is disabled:
# fig = ax.get_figure()
# fig.savefig('figures/mnist-lenet-lt-0.2.png')

In [None]:
WINDOW = 1870

# Baseline run. `orig_val` stays untruncated (it sets the padding target
# below); the truncated copy is kept under a setting-specific name.
orig_val = pd.read_csv('lt-mnist-init-0.1.csv')['val_accs']
orig_val_01 = orig_val[:WINDOW]
orig_train = np.array(pd.read_csv('lt-mnist-init-0.1.csv')['train_accs'])
orig_train = orig_train[:WINDOW]

# Random-initialisation control.
rand_val = pd.read_csv('lt-mnist-rand-init-0.1.csv')['val_accs']
rand_train = pd.read_csv('lt-mnist-rand-init-0.1.csv')['train_accs']
rand_val_01 = rand_val[:WINDOW]
rand_train = rand_train[:WINDOW]

# Retrained-ticket training curve, seeded with the baseline's last kept point.
reinit_train = list(pd.read_csv('lt-mnist-reinit-0.1.csv')['train_accs'])
reinit_train = [orig_train[-1]] + reinit_train
reinit_train = np.array(reinit_train)
# Hold the last value out to the baseline's length. This used `len(orig)`, a
# name that is never defined; max(..., 0) avoids a negative repeat count.
reinit_train = np.concatenate(
    (reinit_train,
     np.repeat(reinit_train[-1], max(len(orig_train) - len(reinit_train), 0))))
reinit_train = reinit_train[:WINDOW]

# Retrained-ticket validation curve, padded the same way and then truncated.
reinit_val = np.array(pd.read_csv('lt-mnist-reinit-0.1.csv')['val_accs'])
reinit_val = np.concatenate(
    (reinit_val,
     np.repeat(reinit_val[-1], max(len(orig_val) - len(reinit_val), 0))))
reinit_val_01 = reinit_val[:WINDOW]



In [None]:
# Validation-accuracy curves for the 0.1 setting, one column per run.
curves_01 = pd.DataFrame({
    'base (0.1)': orig_val_01,
    'retrain (0.1)': reinit_val_01,
    'random (0.1)': rand_val_01,
})
ax2 = curves_01.plot(figsize=(3,3), legend=None)
# Saving is disabled:
# fig = ax.get_figure()
# fig.savefig('figures/mnist-lenet-lt-0.1.png')

In [None]:
WINDOW = 600

# Baseline run. `orig_val` stays untruncated (it sets the padding target
# below); the truncated copy is kept under a setting-specific name.
orig_val = np.array(pd.read_csv('lt-mnist-init-0.15.csv')['val_accs'])
orig_val_15 = orig_val[:WINDOW]
orig_train = np.array(pd.read_csv('lt-mnist-init-0.15.csv')['train_accs'])
orig_train = orig_train[:WINDOW]

# Random-initialisation control.
rand_val = np.array(pd.read_csv('lt-mnist-rand-init-0.15.csv')['val_accs'])
rand_train = np.array(pd.read_csv('lt-mnist-rand-init-0.15.csv')['train_accs'])
rand_val_15 = rand_val[:WINDOW]
rand_train = rand_train[:WINDOW]

# Retrained-ticket training curve, seeded with the baseline's last kept point.
reinit_train = list(pd.read_csv('lt-mnist-reinit-0.15.csv')['train_accs'])
reinit_train = [orig_train[-1]] + reinit_train
reinit_train = np.array(reinit_train)
# Hold the last value out to the baseline's length. This used `len(orig)`, a
# name that is never defined; max(..., 0) avoids a negative repeat count.
reinit_train = np.concatenate(
    (reinit_train,
     np.repeat(reinit_train[-1], max(len(orig_train) - len(reinit_train), 0))))
reinit_train = reinit_train[:WINDOW]

# Retrained-ticket validation curve, padded the same way and then truncated.
reinit_val = np.array(pd.read_csv('lt-mnist-reinit-0.15.csv')['val_accs'])
reinit_val = np.concatenate(
    (reinit_val,
     np.repeat(reinit_val[-1], max(len(orig_val) - len(reinit_val), 0))))
reinit_val_15 = reinit_val[:WINDOW]

In [None]:
# Validation-accuracy curves for the 0.15 setting, one column per run.
curves_15 = pd.DataFrame({
    'base (0.15)': orig_val_15,
    'retrain (0.15)': reinit_val_15,
    'random (0.15)': rand_val_15,
})
ax3 = curves_15.plot(figsize=(3,3), legend=None)
# Saving is disabled:
# fig = ax.get_figure()
# fig.savefig('figures/mnist-lenet-lt-0.15.png')

In [None]:
# One row of three panels, one per pruning setting.
fig, axes = plt.subplots(nrows=1, ncols=3)

# NOTE(review): this panel is labelled p=0.8 but reads the generic
# `orig_val` / `reinit_val` / `rand_val`, which the 0.1 and 0.15 cells above
# reassign. On a top-to-bottom run they hold the 0.15 data here — verify and
# snapshot the 0.2 curves under their own names.
a1 = pd.DataFrame({'base (0.15)': orig_val,
              'reinit (org.)': reinit_val,
              'reinit (rand.)': rand_val}).plot(figsize=(3,3), legend=None, ax=axes[0])
a1.set(xlabel='p=0.8')

a2 = pd.DataFrame({'base': orig_val_15,
              'reinit (org.)': reinit_val_15,
              'reinit (rand.)': rand_val_15}).plot(figsize=(3,3), legend=None, ax=axes[1])
a2.set(xlabel='p=0.85')

# Only the last panel draws a legend; it is anchored outside the axes, to the right.
a3 = pd.DataFrame({'base': orig_val_01,
              'reinit (org.)': reinit_val_01,
              'reinit (rand.)': rand_val_01}).plot(figsize=(3,3), ax=axes[2], )
a3.legend(loc='center left', bbox_to_anchor=(1, 0.5))
a3.set(xlabel='p=0.9')
# Final canvas size; this overrides the per-plot figsize passed above.
fig.set_size_inches(12,3)

# bbox_inches='tight' grows the saved bounding box to include the legend that
# sits outside the axes; without it the legend is cut off in the PNG.
fig.savefig('figures/mnist-base-lt-small.png', bbox_inches='tight')

In [None]:
# Echo the first Axes object of the subplot grid created above (inspection only).
axes[0]

In [None]:
# Rebind `net` to torchvision's ResNet-18 without pretrained weights
# (pretrained=False). This replaces whatever `net` referred to before.
net = models.resnet18(pretrained=False)

In [None]:
# lt-resnet18-cifar-init-.csv  lt-resnet18-cifar-rand-init.csv  lt-resnet18-cifar-reinit.csv

# Baseline run (kept at full length).
orig_val = np.array(pd.read_csv('lt-resnet18-cifar-100e-init.csv')['val_accs'])
orig_train = np.array(pd.read_csv('lt-resnet18-cifar-100e-init.csv')['train_accs'])

# Take the window from the baseline loaded in THIS cell. It was previously
# computed before the load, so it reflected whichever `orig_val` an earlier
# cell had left behind.
WINDOW = len(orig_val)

# Random-initialisation control. The padding of `rand_val` was built into a
# tuple that got overwritten before use; it is applied here so all validation
# curves share the baseline's length.
rand_val = np.array(pd.read_csv('lt-resnet18-cifar-100e-rand-init.csv')['val_accs'])
rand_train = np.array(pd.read_csv('lt-resnet18-cifar-100e-rand-init.csv')['train_accs'])
rand_val = np.concatenate(
    (rand_val,
     np.repeat(rand_val[-1], max(len(orig_val) - len(rand_val), 0))))
rand_val = rand_val[:WINDOW]
rand_train = rand_train[:WINDOW]

# Retrained-ticket training curve, seeded with the baseline's final point and
# held at its last value out to the baseline's length. max(..., 0) avoids a
# negative repeat count.
reinit_train = list(pd.read_csv('lt-resnet18-cifar-100e-reinit.csv')['train_accs'])
reinit_train = [orig_train[-1]] + reinit_train
reinit_train = np.array(reinit_train)
reinit_train = np.concatenate(
    (reinit_train,
     np.repeat(reinit_train[-1], max(len(orig_train) - len(reinit_train), 0))))

# Retrained-ticket validation curve, padded the same way and clipped to WINDOW
# so it lines up with the other validation curves.
reinit_val = np.array(pd.read_csv('lt-resnet18-cifar-100e-reinit.csv')['val_accs'])
reinit_val = np.concatenate(
    (reinit_val,
     np.repeat(reinit_val[-1], max(len(orig_val) - len(reinit_val), 0))))
reinit_val = reinit_val[:WINDOW]

In [None]:
# Sanity check: number of points in the random-init validation curve.
len(rand_val)

In [None]:
# Validation-accuracy curves for the ResNet-18 / CIFAR runs, one column per run.
# The third label previously read 'reinig (rand)'; the typo showed in the legend.
pd.DataFrame({'base': orig_val,
              'reinit (org)': reinit_val,
              'reinit (rand)': rand_val}).plot()

In [None]:
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 64, 512

# CIFAR-10 images are RGB, hence one mean/std entry per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: shuffled mini-batches.
trainset = torchvision.datasets.CIFAR10(
    root='./cifar-data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=1)

# Held-out split: fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./cifar-data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=1)

In [None]:
# Pull the first mini-batch from the training loader; its position is 0.
idx = 0
x, label = next(iter(trainloader))

In [14]:
# Ratio of 25000 to the number of batches per pass over `trainloader`.
# NOTE(review): the meaning of 25000 is not evident here — presumably an
# iteration budget being converted to epochs; confirm.
25000/len(trainloader)

13.333333333333334

In [None]:
# Forward pass of the current `net` on the batch `x` taken from trainloader.
# NOTE(review): the result depends on which model `net` is bound to when this
# cell runs — verify the intended model.
net(x)