In [86]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

import datasets
import networks

In [83]:
BATCH_SIZE = 256
LR = 1e-3
EPOCHS = 3
DEVICE = torch.device('mps')
DATASET = 'mnist'

In [84]:
if DATASET == 'mnist':
    dataset = datasets.get_mnist()
elif DATASET == 'cifar':
    dataset = datasets.get_cifar()

train_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [87]:
model = networks.MLP(hdims=[1000,500,250,100]).to(DEVICE)
optimiser = torch.optim.Adam(model.parameters(), lr=LR)
loss = nn.CrossEntropyLoss()
prog_bar = tqdm(total=EPOCHS * len(train_loader))

for epoch in range(EPOCHS):
    for step, (x, y_hat) in enumerate(train_loader):
        x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
        y = model(x)
        l = loss(y, y_hat)
        optimiser.zero_grad()
        l.backward()
        optimiser.step()
        prog_bar.update(1)
        if step % 100 == 0:
            prog_bar.set_description(f'epoch: {epoch} loss: {l.detach().cpu().item()}')
    

  0%|          | 0/705 [00:00<?, ?it/s]

In [88]:
accuracies = []
for x, y_hat in test_loader:
    x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
    with torch.no_grad():
        y = model(x)
    pred = torch.argmax(y, dim=-1)
    accuracy = (y_hat == pred).to(torch.float).mean().cpu().item()
    accuracies.append(accuracy)
print(f'Accuracy: {sum(accuracies)/len(accuracies)}')

Accuracy: 0.9646484375


## Pruning network

In [28]:
import torch.nn.utils.prune as prune

for name, param in model.named_parameters():
    print(name, param.shape)

fc1.weight torch.Size([500, 3072])
fc1.bias torch.Size([500])
fc2.weight torch.Size([100, 500])
fc2.bias torch.Size([100])
fc3.weight torch.Size([10, 100])
fc3.bias torch.Size([10])


In [30]:
parameters_to_prune = (
    # (model.conv1, 'weight'),
    # (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.9
)

In [31]:
sparsities = []
for name, param in model.named_buffers():
    sparsity = 100 * (param == 0.0).sum() / param.nelement()
    sparsities.append((param.nelement(), sparsity))
    print(name, param.shape, sparsity)
print('global sparsity:', sum([nel*s for nel,s in sparsities])/sum([nel for nel,_ in sparsities]))

fc1.weight_mask torch.Size([500, 3072]) tensor(91.2211, device='mps:0')
fc2.weight_mask torch.Size([100, 500]) tensor(53.8440, device='mps:0')
fc3.weight_mask torch.Size([10, 100]) tensor(22.2000, device='mps:0')
global sparsity: tensor(90.0000, device='mps:0')


In [32]:
accuracies = []
for x, y_hat in test_loader:
    x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
    with torch.no_grad():
        y = model(x)
    pred = torch.argmax(y, dim=-1)
    accuracy = (y_hat == pred).to(torch.float).mean().cpu().item()
    accuracies.append(accuracy)
print(f'Accuracy: {sum(accuracies)/len(accuracies)}')

Accuracy: 0.74951171875


In [33]:
# remove 'reparameterisation' and make pruning permanent
for module, pname in parameters_to_prune:
    prune.remove(module, pname)