In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import damselfly.models.cnn1d as cnn
import damselfly.utils.prune as pruning
from pathlib import Path
import glob
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns


# Directory holding the checkpoints (*.tar) and the training log (train.out) of the
# iterative-pruning run being analysed.
save_models = (
    Path.home() / 'group' / 'project' / 'scripting'
    / 'output' / '220823_iterative_prune_cnn1d' / 'model11'
)
model_paths = glob.glob(str(save_models / '*.tar'))


def _parse_log_row(row):
    """Turn one '|'-separated row of train.out into ``[epoch, loss, acc, val_acc]``.

    All spaces are removed before splitting. Field 1 is parsed as an int; fields 2-4 are
    parsed as floats from the text following their ``loss=``, ``acc=`` and ``val.acc=``
    prefixes.
    """
    cells = row.replace(' ', '').split('|')
    epoch = int(cells[1])
    loss = float(cells[2].split('loss=')[-1])
    acc = float(cells[3].split('acc=')[-1])
    val_acc = float(cells[4].split('val.acc=')[-1])
    return [epoch, loss, acc, val_acc]


# Only rows starting with '|' are metric rows; every other line of the log is ignored.
with open(save_models / 'train.out', 'r') as log_file:
    log_rows = [_parse_log_row(row) for row in log_file if row[0] == '|']

# Overwrite the logged epoch with a running 1-based row index, presumably because the
# epoch counter restarts at every pruning round; this yields one continuous x axis.
for row_number, log_row in enumerate(log_rows, start=1):
    log_row[0] = row_number

train_info = np.array(log_rows)

# 1 - 0.8**k for k = 0..20, i.e. the sparsity after k rounds if each round prunes 20% of
# the remaining weights (used as top-axis tick labels in the training-curve plot).
sparsity = [1 - (1 - 0.2) ** k for k in range(21)]


    

In [None]:
# Display the parsed log: one row per logged epoch, columns [running epoch index, loss, acc, val. acc].
train_info

In [None]:
# Training curves of the iterative-pruning run:
#   - left y axis: loss
#   - right y axis: acc and val. acc
#   - bottom x axis: running epoch index
#   - top x axis: the `sparsity` value at the start of each round
# Assumes every pruning round lasted EPOCHS_PER_ROUND epochs, so round i starts at
# epoch i * EPOCHS_PER_ROUND. TODO: confirm against the training script.
EPOCHS_PER_ROUND = 100
# One tick per entry of `sparsity`, so tick positions and sparsity labels always have
# the same length (a mismatch makes set_xticklabels fail).
round_starts = np.arange(len(sparsity)) * EPOCHS_PER_ROUND

sns.set_theme(context='talk', style='ticks', font_scale=1.2)
clist = sns.color_palette('deep', n_colors=10)
fig = plt.figure(figsize=(13, 8))
ax = fig.add_subplot(1, 1, 1)
ax2 = ax.twinx()  # shares the epoch axis, independent y axis for the accuracies

(loss_line,) = ax.plot(train_info[:, 0], train_info[:, 1], color=clist[0], label='loss')
ax.set_yticks(np.linspace(0.63, 0.69, 7))
ax.set_xticks(round_starts)
ax.set_xticklabels(round_starts, rotation=45)
xlims = ax.get_xlim()
ax.grid(axis='both')
ax.set_xlabel('epoch (running index over all rounds)')
ax.set_ylabel('loss')

(acc_line,) = ax2.plot(train_info[:, 0], train_info[:, 2], color=clist[1], label='acc')
(val_acc_line,) = ax2.plot(train_info[:, 0], train_info[:, 3], color=clist[2], label='val. acc')
ax2.set_yticks(np.linspace(0.54, 0.64, 6))
ax2.set_ylabel('accuracy')
# One legend for the lines of both y axes; placed on ax2 so it is drawn above all curves.
ax2.legend(handles=[loss_line, acc_line, val_acc_line], loc='best')

# Secondary x axis with the same limits and tick positions as the epoch axis, labelled
# with the rounded sparsity values.
ax3 = ax.twiny()
ax3.set_xlim(xlims[0], xlims[1])
ax3.set_xticks(round_starts)
ax3.set_xticklabels(np.round(np.array(sparsity), 3), rotation=45)
ax3.set_xlabel('sparsity');





In [None]:
# Load the checkpoint saved after one pruning round and rebuild the (pruned) model from it.
PRUNE_ROUND = 17  # round whose checkpoint (file name ending in `_prune{PRUNE_ROUND}.tar`) is inspected

matching_paths = [
    path for path in model_paths
    if Path(path).name.split('_prune')[-1] == f'{PRUNE_ROUND}.tar'
]
# Fail loudly instead of silently reusing a `checkpoint` left over from an earlier run
# (no match) or picking an arbitrary file in glob order (several matches).
if len(matching_paths) != 1:
    raise ValueError(
        f'expected exactly one checkpoint for prune round {PRUNE_ROUND}, '
        f'found {len(matching_paths)}: {matching_paths}'
    )

# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
checkpoint = torch.load(matching_paths[0], map_location=torch.device('cpu'))

model_args = checkpoint['model_args']
# Wrapped with a 0.0 fraction, presumably so the model exposes the same pruning
# reparametrization (weight_orig / weight_mask) as the saved state dict. TODO: confirm.
model = pruning.PruneModel(cnn.Cnn1d(*model_args), 0.0)
model.load_state_dict(checkpoint['model_state_dict'])

# Make the loaded masks permanent: prune.remove folds the mask into `weight` and drops
# the `weight_orig` parameter and the pruning forward hook.
for module in model.modules():
    if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
        prune.remove(module, 'weight')
    

In [None]:
# Collect every Conv1d / Linear layer of the loaded model and, for each one, print the
# layer followed by the percentage of its weight entries that are exactly zero
# (after pruning these are presumably the pruned weights).
module_list = [
    layer for layer in model.modules()
    if isinstance(layer, (torch.nn.Conv1d, torch.nn.Linear))
]
for layer in module_list:
    weight = layer.weight
    zero_percent = 100 * float(torch.sum(weight == 0)) / float(weight.nelement())
    print(layer)
    print(zero_percent)


In [None]:
# Inspect one pruned layer. For every input channel, count how many output filters still
# have a kernel with at least one non-zero entry for that channel.
# Indexing weight[:, channel, :] assumes a 3-D Conv1d weight (out, in, kernel).
ind = 4
weight = module_list[ind].weight

print(weight.shape)
for channel in range(weight.shape[1]):
    kernel_alive = weight[:, channel, :].abs().sum(dim=-1) > 0
    print(kernel_alive.sum())

# Shape obtained by stacking the first two output filters along a new trailing dimension.
torch.stack((weight[0], weight[1]), dim=-1).shape

In [None]:
# First collected prunable layer (`modules` was never defined; the list built earlier is `module_list`).
module_list[0]

# Test iterative pruning on a toy LeNet

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    """Small LeNet-style CNN: two 3x3 conv + 2x2 max-pool stages, then three linear layers.

    The flattened feature size ``16 * 5 * 5`` of ``fc1`` matches a single-channel 28x28
    input (28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5). A 32x32 input would give a 6x6
    map (576 features) and fail in ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 map (for a 28x28 input)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` image batch to ``(batch, 10)`` unnormalized scores."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all but the batch dimension. Unlike `x.nelement() / x.shape[0]`, this
        # does not divide by the batch size, so an empty batch works too.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = LeNet().to(device=device)

In [None]:

# Iteratively prune the toy LeNet and, for every Conv2d / Linear layer, print the
# percentage of zero weights before and after each round.
# NOTE: `optimizer` and `loss_fcn` are not used in this cell (no training between rounds).
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)
loss_fcn = torch.nn.CrossEntropyLoss()

PER_ROUND_RATE = 0.15  # rate used in the cumulative pruning schedule below
N_ROUNDS = 10


def _print_weight_sparsity(net):
    """Print, one line per Conv2d / Linear layer of `net`, the percentage of weights equal to zero."""
    for layer in net.modules():
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
            print(100 * float(torch.sum(layer.weight == 0)) / float(layer.weight.nelement()))


for iteration in range(N_ROUNDS):
    print("Before")
    _print_weight_sparsity(model)

    # Cumulative schedule: after round k (1-based) the target is 1 - (1 - rate)**k.
    # NOTE(review): this assumes PruneModel2 treats its argument as the total fraction,
    # not the fraction of the remaining weights. Confirm, or sparsity compounds twice.
    pruning_fraction = 1 - (1 - PER_ROUND_RATE) ** (iteration + 1)
    # `prune` is torch.nn.utils.prune, which has no PruneModel2; the project wrapper
    # comes from damselfly.utils.prune (imported as `pruning`, which also provides PruneModel).
    model = pruning.PruneModel2(model, pruning_fraction)

    print("After")
    _print_weight_sparsity(model)
    
        
            
    
        
    