In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import time

from prune import *

In [2]:
batch_size_train = 64
batch_size_test = 1000

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.smx = nn.LogSoftmax(dim=1)

    def forward(self, x):      
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.smx(x)
        return x

net = Net()
#net

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.smx = nn.LogSoftmax(dim=1)

    def forward(self, x):      
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.smx(x)
        return x

net = Net()
#net

In [5]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

In [6]:
#initialize locked masks:
locked_masks = {n: torch.zeros(w.size(), dtype=torch.bool) for n, w in net.named_parameters()}

In [9]:
#start time
start=time.time()

for prune_step in range(5):

    if prune_step > 0: 
        print('Start Pruning')
        prune(net,locked_masks, prune_random=False, prune_weight=True, prune_bias=True, ratio=0.25,
          threshold=None, threshold_bias=None, function=None, function_bias=None, prune_across_layers=True) 
        print('Done Pruning')
        
        correct(test_loader,net)
        
    prune_diag(net,locked_masks)
    
    for epoch in range(4):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(train_loader,0):
            
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.view(inputs.shape[0], -1)
            
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + prune_grad + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            net = prune_grad(net,locked_masks) #zeros gradients of the pruned weights
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 200 == 199:    # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0

        correct(test_loader,net)
    print('Finished Training')

print('time: ',time.time()-start)

#set prunes to zero
#non-random pruning (threshold)

name     false  true  total
fc1.weight 100352 0 100352
fc1.bias 128 0 128
fc2.weight 1280 0 1280
fc2.bias 10 0 10
[1,   200] loss: 0.889
[1,   400] loss: 0.846
[1,   600] loss: 0.834
[1,   800] loss: 0.828
Accuracy: 74 %
[2,   200] loss: 0.798
[2,   400] loss: 0.770
[2,   600] loss: 0.608
[2,   800] loss: 0.507
Accuracy: 93 %


KeyboardInterrupt: 