In [1]:
import torch
import torchvision as tv
import numpy as np
from pytorch_resnet_cifar10 import resnet
import model_utils
import pruning

  from .autonotebook import tqdm as notebook_tqdm


Load the CIFAR10 Data

In [2]:
# perform the same transform on all the data
transform = tv.transforms.Compose(
    [tv.transforms.ToTensor(), # scale the data between 0..1
     tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # normalize the data

# get the cifar10 data
data_train = tv.datasets.CIFAR10(root="data/CIFAR10/train", train=True, download=True, transform=transform)
data_test = tv.datasets.CIFAR10(root="data/CIFAR10/test", train=False, download=True, transform=transform)

# using a 90-10 split of the data for training and validation
# set to a fixed seed so that results will be reproduced
data_train, data_val = torch.utils.data.random_split(data_train, [0.9, 0.1], generator=torch.Generator().manual_seed(31415))

# create the dataloaders for each dataset. These will be used for training and getting accuracy of a model
dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=128, shuffle=False)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=128, shuffle=False)
dataloader_test = torch.utils.data.DataLoader(data_test, batch_size=128, shuffle=False)

print(np.shape(data_test.data))

Files already downloaded and verified
Files already downloaded and verified
(10000, 32, 32, 3)


Load ResNet-56 model for CIFAR10

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = resnet.resnet56()

# load the model with the correct weights
# must use DataParallel since the data was saved as a DataParallel
torch.nn.DataParallel(model).load_state_dict(torch.load('pytorch_resnet_cifar10/pretrained_models/resnet56-4bfd9763.th', map_location=device)['state_dict'])

model.to(device)

None

Prepare the model:

In [4]:
model_utils.prepare_model(model)

['bias', 'weight']

Testing the pretrained model

In [5]:
# this is not data snooping. This is just verifying that the already trained model was loaded in correctly. 

print(model_utils.get_accuracy(model, dataloader_test, device))

0.9136


Apply Pruning

In [6]:
pruning.global_mag_weight_prune(model, 0.9)

Test Pruned Percentage and get the new accuracy without training

In [7]:
print(model_utils.pruned_percentage(model, 'weight'))
print(model_utils.get_accuracy(model, dataloader_val, device))

0.899999529951491
0.7044


Train

In [8]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001, weight_decay=0.0001)

for epoch in range(1):
    running_loss = 0.0
    running_acc = 0.0

    for (images, labels) in dataloader_train:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()



Make sure that the model is still mostly pruned and test the accuracy on test data

In [9]:
print(model_utils.pruned_percentage(model, 'weight'))
print(model_utils.get_accuracy(model, dataloader_test, device))

0.899999529951491
0.8895


Remove the pruning

In [None]:
model_utils.remove_pruning(model)

In [2]:
import time
start  = time.perf_counter()
time.sleep(2)
stop = time.perf_counter()

stop - start

2.002233248203993