In [None]:
import torch
from tqdm.auto import tqdm
import itertools
import random
import logging
import pickle
from os.path import expanduser
# Configure the root logger: INFO level, each record prefixed with a
# "YYYY-MM-DD HH:MM:SS" timestamp and the level name padded to 8 characters.
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision import transforms, datasets
# Imports for plotting our result curves
import matplotlib
import matplotlib.pyplot as plt
import torch.nn.utils.prune as prune
# Project directory under the current user's home directory; checkpoints are
# read from and written to `home + '/outputs/...'` further down.
home = expanduser("~/Model_compression")

In [None]:
class CIFAR3(Dataset):
    """CIFAR-style image dataset backed by one pickled dictionary per split.

    The pickle for a split is expected to hold a mapping with two entries:

    * ``'images'`` -- indexed as ``images[idx, :]``; every row must contain
      3072 values laid out as three consecutive 1024-value planes (used as
      the R, G and B channels of a 32x32 image).
    * ``'labels'`` -- indexed as ``labels[idx, 0]``.

    Both are presumably NumPy arrays (the code calls ``.reshape`` on image
    rows) -- confirm against the script that produced the files.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable applied to the ``(3, 32, 32)`` float
            tensor before it is returned.
        root: Optional directory containing the split files. When ``None``
            (the default) the bare file names are used, i.e. they are
            resolved relative to the current working directory.

    Raises:
        ValueError: If ``split`` is not one of the supported names. Falling
            back to the test file for a misspelt split would silently
            evaluate on the wrong data.
    """

    # File name of the pickled data for each supported split.
    _SPLIT_FILES = {
        "train": "cifar10_hst_train",
        "val": "cifar10_hst_val",
        "test": "cifar10_hst_test",
    }

    def __init__(self, split="train", transform=None, root=None):
        import os  # local import keeps the class self-contained

        if split not in self._SPLIT_FILES:
            raise ValueError(
                "unknown split {!r}; expected one of {}".format(
                    split, sorted(self._SPLIT_FILES)))

        filename = self._SPLIT_FILES[split]
        path = filename if root is None else os.path.join(root, filename)

        # NOTE: pickle.load can execute arbitrary code -- only point this at
        # data files from a trusted source.
        with open(path, 'rb') as fo:
            self.data = pickle.load(fo)

        self.transform = transform

    def __len__(self):
        """Return the number of samples (one label row per sample)."""
        return len(self.data['labels'])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a float tensor of shape ``(3, 32, 32)`` (after the
        optional transform); the label is ``labels[idx, 0]`` unchanged.
        """
        x = self.data['images'][idx, :]

        # Flat row -> three 32x32 channel planes.
        r = x[:1024].reshape(32, 32)
        g = x[1024:2048].reshape(32, 32)
        b = x[2048:].reshape(32, 32)

        x = Tensor(np.stack([r, g, b]))

        if self.transform is not None:
            x = self.transform(x)

        y = self.data['labels'][idx, 0]
        return x, y

In [None]:
class Net(nn.Module):
    """Convolutional classifier producing 3 logits per input image.

    Architecture: two 3x3 conv layers (32 channels) with a 2x2 max-pool,
    two 3x3 conv layers (64 channels) with a 2x2 max-pool, then
    ``Linear(4096, 512) -> BatchNorm1d -> ReLU -> Linear(512, 3)``.
    ``fc1`` takes 4096 flattened features, i.e. 64 channels x 8 x 8, which is
    what a 32x32 input yields after the two poolings.

    Sub-modules are registered in a fixed order because that order defines
    the ordering of ``parameters()`` and of the ``state_dict()`` keys.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor, stage 1 (3 -> 32 -> 32 channels).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.relu2 = nn.ReLU()
        # Feature extractor, stage 2 (32 -> 64 -> 64 channels).
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu4 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        # Classifier head.
        self.fc1 = nn.Linear(4096, 512)
        self.relu5 = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)
        self.batchnorm1 = nn.BatchNorm1d(512)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 3)`` tensor of logits."""
        # Stage 1: conv -> ReLU -> conv -> pool -> ReLU.
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.relu2(x)

        # Stage 2: conv -> ReLU -> conv -> ReLU -> pool.
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.pool2(x)

        # Collapse everything but the batch dimension.
        flat = x.view(-1, self.num_flat_features(x))

        hidden = self.fc1(flat)
        hidden = self.batchnorm1(hidden)
        hidden = self.relu5(hidden)
        return self.fc2(hidden)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        return x.size()[1:].numel()


In [None]:
# Per-channel normalisation constants: Normalize computes (x - mean) / std.
# Assumes raw pixel values in [0, 255], which would map to [-1, 1] -- TODO confirm.
channel_mean = [127.5, 127.5, 127.5]
channel_std = [127.5, 127.5, 127.5]

# Training pipeline: augmentation followed by normalisation.
# NOTE(review): ColorJitter() is built with its default arguments, which
# appear to apply no jitter at all -- confirm this is intended.
train_steps = [
    transforms.ColorJitter(),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
train_transform = transforms.Compose(train_steps)

# Validation / test pipeline: normalisation only.
test_transform = transforms.Compose(
    [transforms.Normalize(mean=channel_mean, std=channel_std)]
)

train_data = CIFAR3("train", transform=train_transform)
val_data = CIFAR3("val", transform=test_transform)
test_data = CIFAR3("test", transform=test_transform)

batch_size = 256
# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
trainloader = DataLoader(train_data, shuffle=True, **loader_kwargs)
valloader = DataLoader(val_data, shuffle=False, **loader_kwargs)
testloader = DataLoader(test_data, shuffle=False, **loader_kwargs)


In [None]:
# Build the network and restore its weights from the saved checkpoint.
device = 'cpu'
model = Net()
# map_location remaps every stored tensor onto `device`, so a checkpoint that
# was written from CUDA tensors can still be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True; if the
# checkpoint stores non-tensor objects this call may need weights_only=False.
model_dict = torch.load(home+'/outputs/best_mode.pth', map_location=device)
model.load_state_dict(model_dict['model_state_dict'])


In [None]:
# Disabled scratch experiments with torch.nn.utils.prune, kept for reference only.
# NOTE(review): `model.conv5` below does not exist on Net (it defines
# conv1..conv4), so this snippet would fail if uncommented as-is.
# module = model.conv1
# module2 = model.conv2
# module3 = model.conv3
# module4 = model.conv4
# module5 = model.conv5
# #print(list(module.named_parameters()))
# #print(list(module.named_buffers()))
# prune.random_unstructured(module, name='weight', amount=0.3) 
# #print(list(module.named_parameters()))
# #print(module.weight)
# prune.l1_unstructured(module, name="bias", amount=3)

# parameters_to_prune = [
#             (model.conv1, 'weight'),
#             (model.conv2, 'weight'),
#             (model.conv3, 'weight'),
#             (model.conv4, 'weight'),
#             (model.fc1, 'weight'),
#             (model.fc2, 'weight'),
#             ]
# prune.global_unstructured(
#             parameters_to_prune,
#             pruning_method=prune.L1Unstructured,
#             amount=0.9
#             )
#print(model._forward_pre_hooks)

In [None]:
# Fine-tuning runs on the CPU.
device = torch.device('cpu')

# Cross-entropy on the logits; Adam with a small step size and L2 weight decay.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-5,
    weight_decay=0.001,
)

model.to(device)

In [None]:
# Per-batch training history and per-epoch validation history.
loss_log = []
acc_log = []
val_acc_log = []
val_loss_log = []
# A checkpoint is only written once validation accuracy exceeds this value.
best_val_acc = 0.85

# Layers whose weights take part in global magnitude pruning. The list does
# not change between epochs, so it is built once.
parameters_to_prune = [
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv3, 'weight'),
    (model.conv4, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
]

for i in range(57):
    # Mask the 45% smallest-magnitude weights, ranked jointly across all
    # listed layers, for the duration of this epoch.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.45,
    )

    # Run an epoch of training
    train_running_loss = 0
    train_running_acc = 0
    model.train()
    for batch in trainloader:
        x = batch[0].to(device)
        y = batch[1].type(torch.LongTensor).to(device)
        out = model(x)
        loss = criterion(out, y)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(out.detach(), 1)
        correct = (predicted == y).sum()

        train_running_loss += loss.item()
        train_running_acc += correct.item()
        loss_log.append(loss.item())
        acc_log.append(correct.item() / len(y))

    # Mean loss per batch: divide by the number of batches, not by the last
    # batch index (which is one too small and is zero for a single batch).
    train_running_loss /= max(len(trainloader), 1)
    train_running_acc /= len(train_data)

    # Evaluate on validation
    val_acc = 0
    val_loss = 0
    model.eval()
    # No gradients are needed for evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for batch in valloader:
            x = batch[0].to(device)
            y = batch[1].type(torch.LongTensor).to(device)

            out = model(x)

            loss = criterion(out, y)
            _, predicted = torch.max(out, 1)
            correct = (predicted == y).sum()

            val_acc += correct.item()
            val_loss += loss.item()

    val_acc /= len(val_data)
    val_loss /= max(len(valloader), 1)

    # Fold the masks into the weights before saving, so the state_dict holds
    # plain `weight` entries rather than `weight_orig` / `weight_mask`.
    for pruned_module, param_name in parameters_to_prune:
        prune.remove(pruned_module, param_name)

    # Skip the first three epochs, then keep only the best model so far.
    if val_acc > best_val_acc and i > 2:
        best_val_acc = val_acc
        print("saving model")
        torch.save({
                    'epoch': i+1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': criterion,
                    }, home+'/outputs/prune45.pth')

    val_acc_log.append(val_acc)

    val_loss_log.append(val_loss)

    logging.info("[Epoch {:3}]   Loss:  {:8.4}     Train Acc:  {:8.4}%      Val Acc:  {:8.4}%".format(i,train_running_loss, train_running_acc*100,val_acc*100))



In [9]:
# Total number of scalar entries over all model parameters.
pytorch_total_params = 0
for param in model.parameters():
    pytorch_total_params += param.numel()
print(pytorch_total_params)

2165795
