In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import numpy as np
import time

from prune import *

In [None]:
#TODO

In [3]:
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

    

In [5]:
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

In [6]:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [7]:
def train(epoch,locked_masks,network):
    #network=Net()
    network.train()
    
    for prune_step in range(5):

        if prune_step > 0: 
            print('Start Pruning')
            prune(network,locked_masks, prune_random=False, prune_weight=True, prune_bias=False, ratio=0.5,
              threshold=None, threshold_bias=None, function=None, function_bias=None, prune_across_layers=True) 
            print('Done Pruning')

            #correct(test_loader,network)
        prune_diag(network,locked_masks) 
        test(network)
        
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            network = prune_grad(network,locked_masks) #zeros gradients of the pruned weights
            #for n, w in network.named_parameters():  
            #    if w.grad is not None and n in locked_masks: 
            #        w.grad[locked_masks[n]] = 0
            optimizer.step()
            if batch_idx % log_interval == 0:
                  print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item()))
                  train_losses.append(loss.item())
                  train_counter.append(
                      (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
                  torch.save(network.state_dict(), './results/model.pth')
                  torch.save(optimizer.state_dict(), './results/optimizer.pth')
    test(network)

In [8]:
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [9]:
#initialize locked masks:
locked_masks = {n: torch.zeros(w.size(), dtype=torch.bool) for n, w in network.named_parameters()}

In [10]:
for n, w in network.named_parameters():
    print(n)
    print(w.size())

conv1.weight
torch.Size([10, 1, 5, 5])
conv1.bias
torch.Size([10])
conv2.weight
torch.Size([20, 10, 5, 5])
conv2.bias
torch.Size([20])
fc1.weight
torch.Size([50, 320])
fc1.bias
torch.Size([50])
fc2.weight
torch.Size([10, 50])
fc2.bias
torch.Size([10])


In [11]:

for epoch in range(1, n_epochs + 1):

    train(epoch,locked_masks,network)

#    test()

name       pruned percentage
conv1.weight 0.0
conv1.bias 0.0
conv2.weight 0.0
conv2.bias 0.0
fc1.weight 0.0
fc1.bias 0.0
fc2.weight 0.0
fc2.bias 0.0





Test set: Avg. loss: 2.3316, Accuracy: 1137/10000 (11%)

Start Pruning
Done Pruning
name       pruned percentage
conv1.weight 0.124
conv1.bias 0.0
conv2.weight 0.4524
conv2.bias 0.0
fc1.weight 0.531375
fc1.bias 0.0
fc2.weight 0.16
fc2.bias 0.0

Test set: Avg. loss: 0.1602, Accuracy: 9487/10000 (95%)

Start Pruning
Done Pruning
name       pruned percentage
conv1.weight 0.212
conv1.bias 0.0
conv2.weight 0.678
conv2.bias 0.0
fc1.weight 0.7965625
fc1.bias 0.0
fc2.weight 0.25
fc2.bias 0.0

Test set: Avg. loss: 0.1516, Accuracy: 9580/10000 (96%)

Start Pruning
Done Pruning
name       pruned percentage
conv1.weight 0.248
conv1.bias 0.0
conv2.weight 0.8242
conv2.bias 0.0
fc1.weight 0.9186875
fc1.bias 0.1
fc2.weight 0.298
fc2.bias 0.0

Test set: Avg. loss: 0.4134, Accuracy: 8639/10000 (86%)

Start Pruning
Done Pruning
name       pruned percentage
conv1.weight 0.332
conv1.bias 0.0
conv2.weight 0.9178
conv2.bias 0.0
fc1.weight 0.9693125
fc1.bias 0.24
fc2.weight 0.42
fc2.bias 0.0

Test set: Avg. 

In [None]:


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

In [None]:
#initialize locked masks:
locked_masks = {n: torch.zeros(w.size(), dtype=torch.bool) for n, w in net.named_parameters()}

In [None]:
#start time
start=time.time()

for prune_step in range(5):
    
    if prune_step > 0: 
        print('Start Pruning')
        prune(net,locked_masks, prune_random=False, prune_weight=True, prune_bias=False, ratio=0.75,
          threshold=None, threshold_bias=None, function=None, function_bias=None, prune_across_layers=True) 
        print('Done Pruning')
        
        correct(test_loader,net)
    
    prune_diag(net,locked_masks)
    
    #print('prune diag time: ',time.time()-s2)
    
    for epoch in range(4):  # loop over the dataset multiple times
        
        running_loss = 0.0
        for i, data in enumerate(train_loader,0):
            
                
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.view(inputs.shape[0], -1)
                       
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + prune_grad + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            net = prune_grad(net,locked_masks) #zeros gradients of the pruned weights
            optimizer.step()
            

            # print statistics
            running_loss += loss.item()
            if i % 200 == 199:    # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
            
        correct(test_loader,net)
        
    print('Finished Training')

print('time: ',time.time()-start)

In [None]:
for n, w in network.named_parameters():
    #print(len(w.data.numpy().shape))
    print(len(w.shape))
    print(w)