In [1]:
!pip3 install torch torchvision



In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [0]:
# Device used for the models and for every batch moved with `.to(device)`.
# Falls back to the CPU when no CUDA device is available, so the notebook still
# runs (more slowly) on machines without a GPU instead of failing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
# Experiment configuration.
seed = 0           # RNG seed: weight init, shuffling and dropout masks are all random
batch_size = 128   # mini-batch size used by both data loaders
lr = 3e-4          # learning rate passed to the Adam optimizers
drop_rate = 0.8    # probability of zeroing a *targeted* weight during training
targ_per = 0.8     # fraction of rows used to pick the magnitude threshold below
                   # which a weight counts as "targeted"
num_epochs = 3     # number of passes over the training set

# Seed PyTorch so that a Restart & Run All gives comparable numbers.
torch.manual_seed(seed)

In [0]:
# Preprocessing shared by the train and test splits: convert each image to a
# tensor, then normalise it with a fixed mean / std.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean / std - confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Evaluation only accumulates counts over the whole split, so the order of the
# batches is irrelevant and shuffling is switched off. `download=True` lets this
# statement work even when the files are not present yet.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False)


In [0]:
class Net(nn.Module):
    """MLP with two hidden layers whose weights go through targeted weight dropout.

    On every forward pass the weight matrix of each hidden layer is masked by
    `targeted_weight_dropout` before it is applied. The final `logit_layer`
    (10 outputs) is used unmodified.

    Args:
        input_size: number of input features of the first hidden layer.
        hidden_size: width of both hidden layers.
        drop_rate: probability of zeroing a targeted weight while training.
        targ_perc: fraction (0..1) of the rows used to pick the magnitude
            threshold; weights strictly below it are "targeted".
    """

    def __init__(self, input_size, hidden_size, drop_rate, targ_perc):
        super(Net, self).__init__()
        # Layer creation order is kept (input->hidden, hidden->hidden, logits)
        # so parameter names and seeded initialisation stay the same.
        self.module_list = nn.ModuleList([nn.Linear(input_size, hidden_size)
                                          if i == 0 else nn.Linear(hidden_size, hidden_size) for i in range(2)])
        self.logit_layer = nn.Linear(hidden_size, 10)
        self.drop_rate = drop_rate
        self.targ_perc = targ_perc

    def forward(self, inputs):
        """Return the logits for `inputs`.

        In training mode (`self.training` is True) targeted weights are dropped
        at random; in eval mode every targeted weight is zeroed, so the output
        is deterministic.
        """
        for layer in self.module_list:
            # Follow the module's train/eval mode instead of always training.
            weight = self.targeted_weight_dropout(layer.weight,
                                                  self.drop_rate,
                                                  self.targ_perc,
                                                  is_training=self.training)
            inputs = F.relu(F.linear(inputs, weight, layer.bias))
        logits = self.logit_layer(inputs)

        return logits

    @staticmethod
    def targeted_weight_dropout(weight, drop_rate, targ_perc, is_training):
        """Zero low-magnitude entries of `weight` and return the masked tensor.

        Magnitudes are ranked along dim 0, i.e. independently for each column
        of `weight`; entries strictly smaller than the value at sorted position
        `int(targ_perc * weight.size(0))` are "targeted".
        NOTE(review): nn.Linear stores weights as (out_features, in_features),
        so each column groups the weights of one input feature - confirm this
        is the intended grouping.

        Args:
            weight: weight tensor to mask (not modified in place).
            drop_rate: probability of zeroing a targeted entry when training.
            targ_perc: fraction (0..1) selecting the threshold position.
            is_training: if True, targeted entries are dropped at random with
                probability `drop_rate`; if False, all of them are zeroed.

        Returns:
            A tensor shaped like `weight` with the selected entries set to zero.
        """
        num_rows = weight.size(0)
        if num_rows == 0:
            # Nothing to rank; avoids indexing into an empty sort result.
            return weight

        norm = torch.abs(weight)
        # Clamp so that targ_perc == 1.0 (or slightly out-of-range values)
        # selects the last / first row instead of raising an IndexError.
        idx = min(max(int(targ_perc * float(num_rows)), 0), num_rows - 1)
        threshold = torch.sort(norm, dim=0)[0][idx]
        targeted = norm < threshold

        if is_training:
            # Draw the random mask directly on the weight's device.
            dropped = torch.rand(weight.shape, device=weight.device) < drop_rate
            targeted = targeted & dropped

        # `1.0 - bool_tensor` is rejected by current PyTorch, so invert the
        # boolean mask first and only then convert it to the weight's dtype.
        keep = (~targeted).to(weight.dtype)
        return keep * weight

    def prune_weights(self):
        """Permanently zero every targeted weight of the hidden layers.

        The existing parameters are updated in place, so optimizers that hold
        references to them keep pointing at the live weights.
        """
        with torch.no_grad():
            for layer in self.module_list:
                pruned_weight = self.targeted_weight_dropout(layer.weight,
                                                             self.drop_rate,
                                                             self.targ_perc,
                                                             is_training=False)
                layer.weight.copy_(pruned_weight)


**Pruning**

In [7]:
# Train the targeted-dropout network, prune it, then report test accuracy and
# the number of exactly-zero hidden weights.
net = Net(784, 100, drop_rate, targ_per).to(device)
opt = torch.optim.Adam(net.parameters(), lr)

criterion = nn.CrossEntropyLoss()
net.train()  # make the training mode explicit
for epoch in range(1, num_epochs + 1):
    for i, (train_x, train_y) in enumerate(train_loader):
        # Flatten each image to a 784-long vector for the first linear layer.
        train_x = train_x.to(device).view(-1, 784)
        train_y = train_y.to(device)
        preds = net(train_x)
        loss = criterion(preds, train_y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (i + 1) % 100 == 0:
            # .item() hands a plain Python float to the formatter.
            print("Train epoch:[{}/{}] loss :{:.4f}".format(epoch, num_epochs, loss.item()))

# after training is done, prune (make sparse) the hidden-layer weights
net.prune_weights()

net.eval()  # evaluation should not run with the module in training mode
with torch.no_grad():
    total = 0
    correct = 0
    for test_x, test_y in test_loader:
        test_x = test_x.view(-1, 784).to(device)
        test_y = test_y.to(device)
        logits = net(test_x)
        preds = torch.argmax(logits, 1)
        correct += (preds == test_y).sum().item()
        total += test_y.size(0)
    print("accuracy : {:.2f}".format(correct / total * 100))
    # Count the hidden-layer weights that are exactly zero after pruning.
    num_zeros = 0
    for layer in net.module_list:
        num_zeros += (layer.weight == 0.0).sum().item()
    print("number of zero weights: ", num_zeros)


Train epoch:[1/3] loss :0.9202
Train epoch:[1/3] loss :0.4517
Train epoch:[1/3] loss :0.3398
Train epoch:[1/3] loss :0.3029
Train epoch:[2/3] loss :0.3967
Train epoch:[2/3] loss :0.3733
Train epoch:[2/3] loss :0.2724
Train epoch:[2/3] loss :0.2756
Train epoch:[3/3] loss :0.2495
Train epoch:[3/3] loss :0.2526
Train epoch:[3/3] loss :0.2604
Train epoch:[3/3] loss :0.1744
accuracy : 93.62
number of zero weights:  70720


**No pruning**

In [0]:
class VanillaNet(nn.Module):
    """Baseline MLP: two ReLU hidden layers followed by a 10-way logit layer.

    `drop_rate` and `targ_perc` are only stored on the instance; the forward
    pass does not read them.
    """

    def __init__(self, input_size, hidden_size, drop_rate, targ_perc):
        super().__init__()
        # Same creation order as before: input->hidden, hidden->hidden, logits.
        hidden_layers = [
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.module_list = nn.ModuleList(hidden_layers)
        self.logit_layer = nn.Linear(hidden_size, 10)
        self.drop_rate = drop_rate
        self.targ_perc = targ_perc

    def forward(self, inputs):
        """Apply each hidden layer with a ReLU, then return the logits."""
        hidden = inputs
        for linear in self.module_list:
            hidden = F.relu(linear(hidden))
        return self.logit_layer(hidden)

In [9]:
# Baseline run: train the unpruned network, then report test accuracy and the
# number of exactly-zero hidden weights for comparison with the pruned model.
vanilla_net = VanillaNet(784, 100, drop_rate=0.0, targ_perc=0.0).to(device)
opt2 = torch.optim.Adam(vanilla_net.parameters(), lr)

criterion = nn.CrossEntropyLoss()
for epoch in range(1, num_epochs + 1):
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Flatten each image to a 784-long vector before the first layer.
        images = images.to(device).view(-1, 784)
        labels = labels.to(device)
        loss = criterion(vanilla_net(images), labels)
        opt2.zero_grad()
        loss.backward()
        opt2.step()
        # Log every 100th mini-batch.
        if step % 100 == 0:
            print("Train epoch:[{}/{}] loss :{:.4f}".format(epoch, num_epochs, loss))


with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images = images.view(-1, 784).to(device)
        labels = labels.to(device)
        predicted = vanilla_net(images).argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    print("accuracy : {:.2f}".format(correct / total * 100))
    # Hidden-layer weights that are exactly zero.
    num_zeros = sum((layer.weight == 0.0).sum().item()
                    for layer in vanilla_net.module_list)
    print("number of zero weights: ", num_zeros)

Train epoch:[1/3] loss :0.5412
Train epoch:[1/3] loss :0.4184
Train epoch:[1/3] loss :0.3136
Train epoch:[1/3] loss :0.2389
Train epoch:[2/3] loss :0.1732
Train epoch:[2/3] loss :0.1569
Train epoch:[2/3] loss :0.2125
Train epoch:[2/3] loss :0.1876
Train epoch:[3/3] loss :0.1348
Train epoch:[3/3] loss :0.1060
Train epoch:[3/3] loss :0.1090
Train epoch:[3/3] loss :0.2332
accuracy : 95.72
number of zero weights:  0
