In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
import numpy as np
from utils import *
from model import *

In [20]:
class Model(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear -> softmax.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        num_classes: number of output classes.

    The submodule attribute names (``l1``, ``relu``, ``l2``) determine the
    ``state_dict`` keys, so they must stay as-is for saved checkpoints to load.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Creation order (l1 before l2) fixes the RNG draws used for init.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return per-class probabilities, normalised over dim 1.

        NOTE(review): the output is already softmax-ed; if this model is
        trained with nn.CrossEntropyLoss (which expects raw logits) the
        softmax would be applied twice — TODO confirm against training code.
        """
        hidden = self.relu(self.l1(x))
        logits = self.l2(hidden)
        return F.softmax(logits, dim=1)

In [21]:
# Load the baseline (unpruned) checkpoint.
# NOTE(review): this import rebinds `Model` to the class from model.py,
# shadowing the class defined in the previous cell, so the checkpoint is
# loaded into model.Model. TODO: keep a single definition of the class.
from model import Model
# 784 inputs (28x28 MNIST images, presumably flattened by the evaluation
# helpers — TODO confirm), 128 hidden units, 10 digit classes.
model = Model(784, 128, 10)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; the evaluation cell then moves the model to `device`.
# torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(torch.load('model.ckpt', map_location='cpu'))

<All keys matched successfully>

In [22]:
# Evaluate the baseline model on the MNIST test split: accuracy, inference
# time, sparsity, FLOPs and size (helpers come from the star imports in the
# first cell).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Evaluation mode: a no-op for this Linear/ReLU model, but keeps the metrics
# correct if dropout or batch-norm layers are ever added.
model.eval()
# shuffle=False: evaluation does not need shuffling, and a fixed order keeps
# the run reproducible.
test_loader = torch.utils.data.DataLoader(
    MNIST('data', train=False, download=False, transform=ToTensor()),
    batch_size=64,
    shuffle=False,
)
print('Acc: ', calculate_acc(model, test_loader, device))
print('Infer time: ', calculate_inference_time(model))
print('Sparsity: ', calculate_sparsity(model))
print('FLOPs: ', measure_flops(model, device=device))
print('Model size: ', get_model_size(model))

Acc:  0.9581
Infer time:  0.0
Sparsity:  0.0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
FLOPs:  101632.0
Model size:  0.38823699951171875


In [23]:
# Evaluate the pruned checkpoint with the same metrics as the baseline cell,
# reusing its `device` and `test_loader`.
# NOTE: this overwrites the baseline weights held in `model`; re-run the
# baseline-loading cell before re-running the baseline evaluation.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies the tensors onto the model's current device.
# torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(torch.load('prunedmodel/pruned_model.pth', map_location='cpu'))
model.to(device)
model.eval()
# NOTE(review): in the saved output of this cell, accuracy (0.0385) is below
# the 10% chance level for 10 classes and sparsity is ~1e-5 (essentially no
# zeroed weights), so this checkpoint does not look like a working pruned
# version of the baseline. TODO investigate how pruned_model.pth was produced.
print('Acc: ', calculate_acc(model, test_loader, device))
print('Infer time: ', calculate_inference_time(model))
print('Sparsity: ', calculate_sparsity(model))
print('FLOPs: ', measure_flops(model, device=device))
print('Model size: ', get_model_size(model))

Acc:  0.0385
Infer time:  0.0009968280792236328
Sparsity:  9.83942065491128e-06
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
FLOPs:  101632.0
Model size:  0.38823699951171875
