In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.functional as F

from torchvision import datasets, transforms

from tensorboardX import SummaryWriter
from tqdm import tqdm_notebook

import time

from nets.MobileNet_CIFAR_LowRank import MobileNet_CIFAR_LowRank
from utils import build_dataset, flops_to_string, params_to_string, accuracy, file_size

In [2]:
class Solver(object):
    def __init__(self, model, model_params, dataset='MNIST', n_epochs=500, lr=0.001, batch_size=32):
        self.n_epochs = n_epochs
        
        self.d, self.K, self.pi_size = model_params['d'], model_params['K'], model_params['pi_size']
        
        self.train_loader, self.test_loader = build_dataset(dataset, './data', batch_size = batch_size)
        
        self.writer = SummaryWriter(f'runs/MobileNet_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size})-{int(time.time())}')

        self.image_dim = {'MNIST': 28*28, 'CIFAR10': 3*32*32}[dataset]
                
        self.net = model(d=self.d, K=self.K, pi_size=self.pi_size).cuda()
        
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
                
    def train(self):
        self.net.train()
        self.best_loss = 999
        for epoch_i in tqdm_notebook(range(self.n_epochs)):
            epoch_loss = 0
            for batch_idx, (images, labels) in enumerate(self.train_loader):
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()
                
                logits = self.net(images)
                
                loss = self.loss_fn(logits, labels)
                
                total_loss = loss

                self.optimizer.zero_grad()
                total_loss.backward()

                self.optimizer.step()
                
                epoch_loss += float(loss.data)    
            epoch_loss /= len(self.train_loader.dataset)
            
            # save best model
            if self.best_loss > epoch_loss:
                self.best_loss = epoch_loss
                torch.save(self.net.state_dict(), f'models/MobileNet_CIFAR_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size})_best_loss.pth')
                
            print(f'Epoch {epoch_i} | loss: {epoch_loss:.8f}')
            
            niter = epoch_i*len(self.train_loader)+batch_idx
            self.writer.add_scalar('Train/Loss', loss.data[0], niter)
            self.evaluate(epoch_i)
            
        # load and test
        self.net.load_state_dict(torch.load(f'models/MobileNet_CIFAR_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size})_best_loss.pth'))
        print (f'Best Loss {self.best_loss} | Its Acc:')
        self.evaluate(epoch_i+1)
        print (f'Best Top-1 acc over all {self.best_acc1}')
        print (f'Best Top-5 acc over all {self.best_acc5}')
        
        
        # Print info
        pytorch_total_params = sum(p.numel() for p in self.net.parameters())
        print (f'Model MobileNet_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size}) | Params = ', params_to_string(pytorch_total_params))
        print('Model size: ', file_size(f'models/MobileNet_CIFAR_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size})_best_loss.pth'))
        
        
    def evaluate(self, epoch_i = None):
        total = 0
        correct = 0
        self.net.eval()
        
        self.best_acc1 = 0
        self.best_acc5 = 0
        acc1_list, acc5_list = [], []
        for batch_idx, (images, labels) in enumerate(self.test_loader):
            images = Variable(images).cuda()

            logits = self.net(images)
            _, predicted = torch.max(logits.data, 1)

            acc1, acc5 = accuracy(logits.cpu(), labels.cpu(), topk=(1, 5))
            
            acc1_list.append(acc1)
            acc5_list.append(acc5)
        
        acc1 = float(sum(acc1_list)/len(self.test_loader))
        acc5 = float(sum(acc5_list)/len(self.test_loader))
        
        if self.best_acc1 < acc1:
            self.best_acc1 = acc1
            # Et tu, Overfit?
            torch.save(self.net.state_dict(), f'models/MobileNet_CIFAR_LowRank_(d={self.d}, K={self.K}, pi_size={self.pi_size})_best_acc.pth') 
        
        if self.best_acc5 < acc5:
            self.best_acc5 = acc5        
        
            total += labels.size(0)
            correct += (predicted.cpu() == labels).sum()
        if epoch_i is not None:
            self.writer.add_scalar('Test/Acc@1', acc1, epoch_i)
            self.writer.add_scalar('Test/Acc@5', acc5, epoch_i)
        print(f'Top-1 Accuracy: {acc1}')
        print(f'Top-5 Accuracy: {acc5}')

In [3]:
standard_solver = Solver(MobileNet_CIFAR_LowRank, {'d': 1, 'K':1, 'pi_size': 8},  'CIFAR10', n_epochs=500, lr=0.001, batch_size=128)

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Epoch 0 | loss: 0.03426212




Top-1 Accuracy: 19.03861427307129
Top-5 Accuracy: 73.97492218017578

Best Loss 0.03426211676120758 | Its Acc:
Top-1 Accuracy: 19.03861427307129
Top-5 Accuracy: 73.97492218017578
Best Top-1 acc over all 19.03861427307129
Best Top-5 acc over all 73.97492218017578
Model MobileNet_LowRank_(d=1, K=1, pi_size=8) | Params =  17.15k
Model size:  108.0 KB


In [4]:
standard_solver = Solver(MobileNet_CIFAR_LowRank, {'d': 4, 'K':4, 'pi_size': 8},  'CIFAR10', n_epochs=500, lr=0.001, batch_size=128)

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Epoch 0 | loss: 0.03018690




Top-1 Accuracy: 36.68391799926758
Top-5 Accuracy: 89.07244873046875

Best Loss 0.030186895580291747 | Its Acc:
Top-1 Accuracy: 36.68391799926758
Top-5 Accuracy: 89.07244873046875
Best Top-1 acc over all 36.68391799926758
Best Top-5 acc over all 89.07244873046875
Model MobileNet_LowRank_(d=4, K=4, pi_size=8) | Params =  46.7k
Model size:  232.8 KB


In [5]:
standard_solver = Solver(MobileNet_CIFAR_LowRank, {'d': 2, 'K':8, 'pi_size': 8},  'CIFAR10', n_epochs=500, lr=0.001, batch_size=128)

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Epoch 0 | loss: 0.03076314




Top-1 Accuracy: 36.83320236206055
Top-5 Accuracy: 88.56488800048828

Best Loss 0.030763139152526854 | Its Acc:
Top-1 Accuracy: 36.83320236206055
Top-5 Accuracy: 88.56488800048828
Best Top-1 acc over all 36.83320236206055
Best Top-5 acc over all 88.56488800048828
Model MobileNet_LowRank_(d=2, K=8, pi_size=8) | Params =  47.05k
Model size:  246.8 KB


In [6]:
standard_solver = Solver(MobileNet_CIFAR_LowRank, {'d': 8, 'K':2, 'pi_size': 8},  'CIFAR10', n_epochs=500, lr=0.001, batch_size=128)

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Epoch 0 | loss: 0.02982646




Top-1 Accuracy: 39.33121109008789
Top-5 Accuracy: 89.1222152709961

Best Loss 0.02982646359682083 | Its Acc:
Top-1 Accuracy: 39.33121109008789
Top-5 Accuracy: 89.1222152709961
Best Top-1 acc over all 39.33121109008789
Best Top-5 acc over all 89.1222152709961
Model MobileNet_LowRank_(d=8, K=2, pi_size=8) | Params =  46.52k
Model size:  225.8 KB
