In [63]:
import torch
import torch.nn as nn
from utils.readData import read_dataset
from utils.ResNet import ResNet18

In [64]:
# Configuration: compute device, classifier width and loader batch size.
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

n_class = 10      # number of output classes of the classifier head
batch_size = 100  # samples per mini-batch for every loader

# read_dataset is a project helper returning (train, valid, test) loaders;
# pic_path presumably points at the local 'dataset' directory — confirm in utils.readData.
train_loader, valid_loader, test_loader = read_dataset(
    batch_size=batch_size,
    pic_path='dataset',
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [65]:
# Load the trained ResNet-18 checkpoint.
# The stem conv (3x3, stride 1) and the classifier head (n_class outputs) are replaced
# *before* loading, so the module shapes match what load_state_dict (strict by default) expects.
CHECKPOINT_PATH = 'checkpoint/resnet18_cifar10.pt'

model = ResNet18()
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
model.fc = nn.Linear(512, n_class)

# map_location remaps tensors that were saved on a GPU onto `device`. Without it, a
# CUDA-saved checkpoint fails to deserialize on a CPU-only machine, even though the
# device selection above explicitly supports a CPU fallback.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

In [66]:
# Baseline: count the trainable parameters of the dense model *before* pruning, so the
# non-zero count reported after pruning has something to be compared against.
total_params_dense = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters before pruning: ", total_params_dense)


Total number of parameters after pruning:  11173962


In [67]:
import torch.nn.utils.prune as prune
import torch.optim as optim

# Global magnitude (L1) pruning of every Conv2d weight tensor.
# The threshold is chosen over all conv weights jointly, so individual layers may end up
# more or less sparse than PRUNE_AMOUNT. Non-conv layers (e.g. the fc head) are not pruned.
PRUNE_AMOUNT = 0.56  # fraction of all conv weights to zero out

parameters_to_prune = [
    (module, 'weight')
    for module in model.modules()
    if isinstance(module, nn.Conv2d)
]

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

# Make the pruning permanent: prune.remove folds the mask into `weight` and drops the
# `weight_orig` / `weight_mask` reparametrization. The zeroed entries stay inside the
# dense tensor, so parameter shapes (and numel) do not shrink — only non-zeros do.
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

# Sanity check: at least PRUNE_AMOUNT of the conv weights are now exactly zero
# (the 1/n slack absorbs rounding amount * n to an integer number of weights).
n_conv_weights = sum(module.weight.numel() for module, _ in parameters_to_prune)
n_conv_zeros = sum((module.weight == 0).sum().item() for module, _ in parameters_to_prune)
assert n_conv_zeros / n_conv_weights >= PRUNE_AMOUNT - 1 / n_conv_weights, n_conv_zeros / n_conv_weights

# Count the non-zero trainable parameters of the whole model (all layers, not just the pruned convs)
total_params_pruned = sum(torch.count_nonzero(p).item() for p in model.parameters() if p.requires_grad)
print("Total number of non-zero parameters after pruning: ", total_params_pruned)


Total number of non-zero parameters after pruning:  4924792


In [68]:

def evaluate_accuracy(net, loader, dev):
    """Return ``(correct, total)`` top-1 counts of ``net`` over every batch in ``loader``.

    Puts ``net`` in eval mode and disables gradient tracking. Each batch is assumed to be
    an ``(inputs, targets)`` pair with integer class targets — TODO confirm for other loaders.
    """
    net.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(dev)
            target = target.to(dev)
            # Index of the highest score per sample (first maximum on ties, like torch.max).
            predicted = net(data).argmax(dim=1)
            n_total += target.size(0)
            n_correct += (predicted == target).sum().item()
    return n_correct, n_total


correct, total = evaluate_accuracy(model, test_loader, device)
if total == 0:
    raise ValueError('test_loader yielded no samples; cannot compute accuracy')
accuracy = 100 * correct / total
print('Accuracy of the pruned model on the test data: {:.2f}%'.format(accuracy))


Accuracy of the pruned model on the test data: 94.47%
