In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from pytorch_resnet_cifar10.resnet import resnet56

  from .autonotebook import tqdm as notebook_tqdm


Load the CIFAR10 Data

In [2]:
import os
os.environ['DATAPATH'] = 'data'

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = datasets.CIFAR10(root="data/CIFAR10", train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(data, batch_size=32, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/CIFAR10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12388835.32it/s]


Extracting data/CIFAR10/cifar-10-python.tar.gz to data/CIFAR10


Load ResNet-56 model for CIFAR10

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = resnet56()

#load the model with the correct weights
check_point = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet56-4bfd9763.th', map_location=device)
data_parallel = torch.nn.DataParallel(model) # don't fully understand what this does, but it is needed
data_parallel.load_state_dict(check_point['state_dict'])

model.to(device)

None

Testing the pretrained model

In [4]:

print(np.shape(data.targets))
print(np.shape(data.data))

def get_accuracy(model, dataloader):
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, prediction = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (prediction == labels).sum().item()
    return correct / total


print(get_accuracy(model, dataloader))

(50000,)
(50000, 32, 32, 3)
0.9759


In [5]:
'''
def identityPrune(model):
    for module in model.modules():
        param_names = {name for name, _ in module.named_parameters(recurse = False)}
        for param_name in param_names:
            if not prune.is_pruned(module=module):
                prune.identity(module=module, name=param_name)

def getWeights(model):
    identityPrune(model)
    modules = list()
    weights = list()
    masks = list()
    for module in model.modules():
        if (hasattr(module, 'weight_orig') and hasattr(module, 'weight_mask')):
            modules.append(module)
            weights.append(getattr(module, 'weight_orig'))
            masks.append(getattr(module, 'weight_mask'))
    return modules, weights, masks

modules, weights, masks = getWeights(model)
'''

"\ndef identityPrune(model):\n    for module in model.modules():\n        param_names = {name for name, _ in module.named_parameters(recurse = False)}\n        for param_name in param_names:\n            if not prune.is_pruned(module=module):\n                prune.identity(module=module, name=param_name)\n\ndef getWeights(model):\n    identityPrune(model)\n    modules = list()\n    weights = list()\n    masks = list()\n    for module in model.modules():\n        if (hasattr(module, 'weight_orig') and hasattr(module, 'weight_mask')):\n            modules.append(module)\n            weights.append(getattr(module, 'weight_orig'))\n            masks.append(getattr(module, 'weight_mask'))\n    return modules, weights, masks\n\nmodules, weights, masks = getWeights(model)\n"

In [6]:
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'
    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return (torch.abs(tensor) >= self.threshold).float() * default_mask

def global_mag_weight_prune(model, amount):
    parameters = list()

    all_relevant_weights = np.array([])

    for module in model.modules():
        if (hasattr(module, 'weight')):
            prune.identity(module, 'weight')
            parameters.append((module, 'weight'))

            weight_mask = getattr(module, 'weight_mask')
            weight = getattr(module, 'weight')

            all_relevant_weights = np.append(all_relevant_weights, torch.masked_select(weight, weight_mask.bool()).flatten().cpu().detach().numpy())

    threshold = np.percentile(np.abs(all_relevant_weights), amount * 100.0)

    prune.global_unstructured(
        parameters=parameters,
        pruning_method=ThresholdPruning,
        threshold = threshold
    )

def apply_permanent_prune(model, name):
    for module in model.modules():
        if (prune.is_pruned(module) and hasattr(module, name)):
            module = prune.remove(module, name)
            


None

In [7]:
model.conv1.weight.size()

torch.Size([16, 3, 3, 3])

In [8]:
global_mag_weight_prune(model, 0.9)
# apply_permanent_prune(model, 'weight')

def zero_percentage(model, name):
    zeros = 0.0
    total = 0.0
    for module in model.modules():
        if (hasattr(module, name)):
            zeros = zeros + float(torch.sum(module.weight == 0.0))
            total = total + float(module.weight.nelement())
    return zeros / total

print(zero_percentage(model, 'weight'))
print(get_accuracy(model, dataloader))

0.899999529951491
0.67662


In [9]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001, weight_decay=0.0001)

for epoch in range(1):
    running_loss = 0.0
    running_acc = 0.0

    for i, (images, labels) in enumerate(dataloader, 0):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()



In [10]:
print(zero_percentage(model, 'weight'))
print(get_accuracy(model, dataloader))

0.899999529951491
0.95352
