In [13]:
from models_to_prune import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import numpy as np

import torchvision.transforms as tf
import torchvision.datasets as ds
import torch.utils.data as data

import os
import time
import torch

# setup dataset
cifar10_trans = tf.Compose([tf.ToTensor(),
                            tf.Normalize((0.5,0.5,0.5),
                                         (0.5,0.5,0.5))])

train = ds.CIFAR10(root = os.getcwd(),
                   train = True,
                   download = True,
                   transform = cifar10_trans)
train_loader = data.DataLoader(train,
                               batch_size = 64,
                               shuffle = True,
                               num_workers = 0)

test = ds.CIFAR10(root = os.getcwd(),
                  train = False,
                  download = True,
                  transform = cifar10_trans)
test_loader = data.DataLoader(test,
                              batch_size = 64, # testing use less 
                                               # memory, can afford 
                                               # larger batch_size
                              shuffle = False,
                              num_workers = 0)

#set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Files already downloaded and verified
Files already downloaded and verified
cuda:0


In [14]:


# Train
model = BasicCNN()
model.to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.0008)

num_epochs = 100
best_val_acc = -1

print('Training CNN model')

for epoch in range(num_epochs):
    n_correct, n_total = 0, 0
    start_time = time.time()
    for idx, train_data in enumerate(train_loader):
        model.train()
        inputs, labels = train_data[0].to(device), train_data[1].to(device)

        optimizer.zero_grad()

        # forward
        preds = model(inputs) # forward pass
        loss = loss_func(preds,labels) # compute loss

        # backward
        loss.backward()  # compute grads
        optimizer.step() # update params w/ Adam update rule        

        # include batch accuracy to epoch training accuracy calculation
        _, prediction = torch.max(preds, dim=1) # idx w/ max val is
                                                # most confident class
        n_correct += (prediction == labels).sum().item()
        n_total += prediction.size(0)
        train_acc = n_correct/n_total

        # periodically evaluate performance on validation set 
        if idx % 500 == 0:            
            model.eval();           # switch model to evaluation mode
            correct, total = 0, 0   

            # calculate accuracy on validation set
            with torch.no_grad():
                for val_inputs, val_labels in test_loader:
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device)
                    answer = model(val_inputs)
                    val_loss = loss_func(answer, val_labels)
                    _, val_preds = torch.max(answer, dim=1)
                    total += val_labels.size(0)
                    correct += (val_preds == val_labels).sum().item()
                val_acc = correct/total

            # print progress
            elapsed_time = time.time() - start_time
            str_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
            print('Epoch [{}/{}] Step [{}/{}] | Loss: {:.4f} Acc: {:.4f} Val_Loss: {:.4f} Val_Acc: {:.4f} Time: {}'
            .format(epoch+1, num_epochs, idx+1, len(train_loader), 
                    loss.item(), train_acc, 
                    val_loss.item(), val_acc,
                    str_time))

            # update best valiation set accuracy
            if val_acc >= best_val_acc:
                # found a model with better validation set accuracy
                snapshot_prefix = os.path.join(os.getcwd(), 'best_snapshot')
                snapshot_path = snapshot_prefix + '.pt'# '_val_acc_{}_val_loss_{}__epoch-iter_{}-{}_model.pt'.format(val_acc[-1], val_loss.item(), epoch, idx)

                # save model, delete previous 'best_snapshot' files
                torch.save(model, snapshot_path)
                best_val_acc = val_acc
                print("Model saved @ val acc =", best_val_acc)
            
print('Training Done')

Training CNN model
Epoch [1/100] Step [1/782] | Loss: 2.3001 Acc: 0.1094 Val_Loss: 2.3023 Val_Acc: 0.1132 Time: 00:00:02
Model saved @ val acc = 0.1132
Epoch [1/100] Step [501/782] | Loss: 2.0512 Acc: 0.3792 Val_Loss: 1.9369 Val_Acc: 0.4469 Time: 00:00:17
Model saved @ val acc = 0.4469
Epoch [2/100] Step [1/782] | Loss: 1.9666 Acc: 0.4844 Val_Loss: 1.9841 Val_Acc: 0.4759 Time: 00:00:02
Model saved @ val acc = 0.4759
Epoch [2/100] Step [501/782] | Loss: 1.9649 Acc: 0.4799 Val_Loss: 2.0285 Val_Acc: 0.5009 Time: 00:00:17
Model saved @ val acc = 0.5009
Epoch [3/100] Step [1/782] | Loss: 1.8429 Acc: 0.6094 Val_Loss: 1.7965 Val_Acc: 0.5169 Time: 00:00:02
Model saved @ val acc = 0.5169
Epoch [3/100] Step [501/782] | Loss: 1.9614 Acc: 0.5187 Val_Loss: 2.0381 Val_Acc: 0.5294 Time: 00:00:16
Model saved @ val acc = 0.5294
Epoch [4/100] Step [1/782] | Loss: 1.9089 Acc: 0.5469 Val_Loss: 1.9606 Val_Acc: 0.5394 Time: 00:00:02
Model saved @ val acc = 0.5394
Epoch [4/100] Step [501/782] | Loss: 1.9077 

Epoch [34/100] Step [1/782] | Loss: 1.6395 Acc: 0.8125 Val_Loss: 1.8322 Val_Acc: 0.7357 Time: 00:00:02
Epoch [34/100] Step [501/782] | Loss: 1.6952 Acc: 0.7822 Val_Loss: 1.6522 Val_Acc: 0.7378 Time: 00:00:17
Epoch [35/100] Step [1/782] | Loss: 1.6066 Acc: 0.8438 Val_Loss: 1.6488 Val_Acc: 0.7528 Time: 00:00:02
Model saved @ val acc = 0.7528
Epoch [35/100] Step [501/782] | Loss: 1.5844 Acc: 0.7800 Val_Loss: 1.7113 Val_Acc: 0.7452 Time: 00:00:16
Epoch [36/100] Step [1/782] | Loss: 1.6029 Acc: 0.8594 Val_Loss: 1.5954 Val_Acc: 0.7488 Time: 00:00:02
Epoch [36/100] Step [501/782] | Loss: 1.6830 Acc: 0.7871 Val_Loss: 1.7664 Val_Acc: 0.7564 Time: 00:00:16
Model saved @ val acc = 0.7564
Epoch [37/100] Step [1/782] | Loss: 1.6568 Acc: 0.8125 Val_Loss: 1.7732 Val_Acc: 0.7533 Time: 00:00:02
Epoch [37/100] Step [501/782] | Loss: 1.7504 Acc: 0.7914 Val_Loss: 1.7150 Val_Acc: 0.7501 Time: 00:00:16
Epoch [38/100] Step [1/782] | Loss: 1.6102 Acc: 0.8594 Val_Loss: 1.6714 Val_Acc: 0.7450 Time: 00:00:02
Epo

Epoch [71/100] Step [501/782] | Loss: 1.6857 Acc: 0.8343 Val_Loss: 1.7111 Val_Acc: 0.7684 Time: 00:00:25
Epoch [72/100] Step [1/782] | Loss: 1.6344 Acc: 0.8281 Val_Loss: 1.6225 Val_Acc: 0.7745 Time: 00:00:03
Model saved @ val acc = 0.7745
Epoch [72/100] Step [501/782] | Loss: 1.5393 Acc: 0.8367 Val_Loss: 1.6448 Val_Acc: 0.7723 Time: 00:00:37
Epoch [73/100] Step [1/782] | Loss: 1.5881 Acc: 0.8750 Val_Loss: 1.7109 Val_Acc: 0.7731 Time: 00:00:02
Epoch [73/100] Step [501/782] | Loss: 1.6174 Acc: 0.8389 Val_Loss: 1.5549 Val_Acc: 0.7744 Time: 00:00:22
Epoch [74/100] Step [1/782] | Loss: 1.5689 Acc: 0.8906 Val_Loss: 1.6432 Val_Acc: 0.7778 Time: 00:00:03
Model saved @ val acc = 0.7778
Epoch [74/100] Step [501/782] | Loss: 1.6380 Acc: 0.8393 Val_Loss: 1.6906 Val_Acc: 0.7743 Time: 00:00:29
Epoch [75/100] Step [1/782] | Loss: 1.6600 Acc: 0.7969 Val_Loss: 1.7024 Val_Acc: 0.7771 Time: 00:00:03
Epoch [75/100] Step [501/782] | Loss: 1.5475 Acc: 0.8411 Val_Loss: 1.7106 Val_Acc: 0.7725 Time: 00:00:40
E

In [16]:
# Load Best Model
model = torch.load("best_snapshot.pt")
model.eval()

correct = 0
total = 0
# Test Model
with torch.no_grad():
    for val_inputs, val_labels in test_loader:
        val_inputs = val_inputs.to(device)
        val_labels = val_labels.to(device)
        answer = model(val_inputs)
        _, val_preds = torch.max(answer, dim=1)
        total += val_labels.size(0)
        correct += (val_preds == val_labels).sum().item()
    
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))    

Accuracy of the network on the 10000 test images: 78 %


## All blocks below are just test codes

In [None]:
cnn = BasicCNN()
for name, layer in cnn.named_modules():
    if 'conv' in name:
        filters = layer.weight.data.clone()
        print(name,':',filters.size())
        pooled_filter = torch.flatten(F.avg_pool2d(filters,
                                                   filters.size()[-1]))
        print("pooled :",pooled_filter.size())

In [None]:
64*8

In [43]:
print(os.getcwd())

/content
