In [14]:
import torch
import numpy as np
from torchvision import datasets
import torchvision
import torchvision.transforms as transforms

# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 20

# convert data to torch.FloatTensor

use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu") 

PATH = "save/trained_cnn_model.pt"

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307), (0.3081))])


# choose the training and test datasets
train_data = datasets.MNIST(root='data', train=True, download=True,  transform=transform)
test_data  = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# prepare data loaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
    num_workers=num_workers)


import torch.nn as nn
import torch.nn.functional as F

## Define the NN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net().to(device)
print(model)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [None]:

# number of epochs to train the model
n_epochs = 10  
# per epoch, all the training data set is used once
model.train() # prep model for training


for epoch in range(n_epochs):
    # monitor training loss
    train_loss = 0.0
    
    ###################
    # train the model #
    ###################
    for data, target in train_loader:
        data, target = data.to(device), target.to(device) # loading to GPU
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0) # as loss is tensor, .item() needed to get the value
        
    # print training statistics 
    # calculate average loss over an epoch
    train_loss = train_loss/len(train_loader.dataset)

    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))
    
# see following link for details of state_dict   
# https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            }, PATH)

In [2]:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['loss']

model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Accuracy: 9897/10000 (99%)



In [16]:

import torch.nn.utils.prune as prune



prune.l1_unstructured(model.conv1, name='weight', amount=0.9)
prune.l1_unstructured(model.conv2, name='weight', amount=0.9)
prune.l1_unstructured(model.fc1, name='weight',   amount=0.9)
prune.l1_unstructured(model.fc2, name='weight',   amount=0.9)



'''
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.9,
)
'''

"\nparameters_to_prune = (\n    (model.conv1, 'weight'),\n    (model.conv2, 'weight'),\n    (model.fc1, 'weight'),\n    (model.fc2, 'weight'),\n)\n\nprune.global_unstructured(\n    parameters_to_prune,\n    pruning_method=prune.L1Unstructured,\n    amount=0.9,\n)\n"

In [None]:
list(model.named_parameters())

In [None]:
model.conv1.weight

In [None]:
model.conv1.weight_mask

In [None]:
print(model.conv1.weight_orig)

In [13]:
mask1 = model.conv1.weight_mask
mask2 = model.conv2.weight_mask
mask3 = model.fc1.weight_mask
mask4 = model.fc2.weight_mask
sparsity_mask1 = (mask1 == 0).sum() / mask1.nelement()
sparsity_mask2 = (mask2 == 0).sum() / mask2.nelement()
sparsity_mask3 = (mask3 == 0).sum() / mask3.nelement()
sparsity_mask4 = (mask4 == 0).sum() / mask4.nelement()
print("Conv1: ", sparsity_mask1)
print("Conv2: ", sparsity_mask2)
print("FC1:   ", sparsity_mask3)
print("FC2:   ", sparsity_mask4)

total_zeros = (mask1 == 0).sum() + (mask2 == 0).sum() + (mask3 == 0).sum() + (mask4 == 0).sum()
total_elements = mask1.nelement() + mask2.nelement() + mask3.nelement() + mask4.nelement()

print("total: ", total_zeros / total_elements)

Conv1:  tensor(0.8993, device='cuda:0')
Conv2:  tensor(0.9000, device='cuda:0')
FC1:    tensor(0.9000, device='cuda:0')
FC2:    tensor(0.9000, device='cuda:0')
total:  tensor(0.9000, device='cuda:0')


In [8]:

model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Accuracy: 9866/10000 (99%)



In [None]:
### This cell removes the weight_orig and weight_mask, but only store the pruned weight
### Note that if you want to fine-tune with the next cell, you should not run this cell
### If you run this cell and finetune, the pruned weight will be updated again.
prune.remove(model.conv1, 'weight')
prune.remove(model.conv2, 'weight')
prune.remove(model.fc1, 'weight')
prune.remove(model.fc2, 'weight')

In [6]:

# number of epochs to train the model
n_epochs = 10  
# per epoch, all the training data set is used once
model.train() # prep model for training


for epoch in range(n_epochs):
    # monitor training loss
    train_loss = 0.0
    
    ###################
    # train the model #
    ###################
    for data, target in train_loader:
        data, target = data.to(device), target.to(device) # loading to GPU
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0) # as loss is tensor, .item() needed to get the value
        
    # print training statistics 
    # calculate average loss over an epoch
    train_loss = train_loss/len(train_loader.dataset)

    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))
    


Epoch: 1 	Training Loss: 0.367446
Epoch: 2 	Training Loss: 0.211514
Epoch: 3 	Training Loss: 0.178869
Epoch: 4 	Training Loss: 0.157241
Epoch: 5 	Training Loss: 0.141307
Epoch: 6 	Training Loss: 0.135052
Epoch: 7 	Training Loss: 0.128355
Epoch: 8 	Training Loss: 0.119584
Epoch: 9 	Training Loss: 0.115202
Epoch: 10 	Training Loss: 0.110699


In [7]:
# see following link for details of state_dict   
# https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
PATH_prune = "save/trained_cnn_model_pruned.pt"

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            }, PATH_prune)

In [None]:
# After above finetuning, run the 7th cell to check the sparsity.
# Then, check the accuracy by running the 8th cell

In [17]:
## Now, after running the first cell, let's try to load the stored model.
## It won't work because the named_parameters has been changed.
## Thus, run 4th cell to make a pruned model.
## Then, run this cell.

PATH_prune = "save/trained_cnn_model_pruned.pt"
checkpoint = torch.load(PATH_prune)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['loss']

model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Accuracy: 9866/10000 (99%)

