In [1]:
# --- Standard library ---
import time

# --- PyTorch / torchvision ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from torchvision import datasets, transforms

# --- Tooling: TensorBoard logging, FLOP/param profiling, notebook progress bars ---
from tensorboardX import SummaryWriter
from thop import profile
from tqdm import tqdm_notebook

# --- Local project helpers ---
from utils import build_dataset, flops_to_string, params_to_string

# Shared TensorBoard writer; the run directory is named with the current Unix
# timestamp (whole seconds).
run_dir = f'runs/MobileNet_{int(time.time())}'
writer = SummaryWriter(run_dir)

In [2]:
from nets.LowRankLayer import *

class MobileNet_CIFAR_LowRank(nn.Module):
    """MobileNet-style classifier whose per-block second stage is a ``LowRankLayer``.

    Layout: a 3->32 channel 3x3 stem conv, eleven ``conv_dw`` blocks (three with
    stride 1, eight with stride 2), ``AvgPool2d(8)``, then a 256->10 linear head.

    FIXME(review): the model does not currently run. The cell-4 output shows a
    (4, 32, 32, 32) tensor reshaped to (4, 32, 1024) and then to (4, 32, 256)
    (presumably debug prints inside ``LowRankLayer`` -- confirm), followed by
    ``ValueError: expected 4D input (got 3D input)``. That is consistent with the
    3D ``LowRankLayer`` output being passed to ``nn.BatchNorm2d(oup)`` in ``conv_dw``.
    """
    def __init__(self):
        super(MobileNet_CIFAR_LowRank, self).__init__()

        def conv_dw(inp, oup, stride):
            """Build one block: depthwise 3x3 conv -> BN -> ReLU -> LowRankLayer -> Sigmoid -> BN -> ReLU.

            Args:
                inp: input channel count; also the depthwise conv's ``groups``.
                oup: channel count expected by the trailing ``BatchNorm2d``.
                stride: stride of the depthwise conv.

            NOTE(review):
              * The LowRankLayer's first argument is hard-coded to 32*32 in every
                block, although each stride-2 depthwise conv halves the spatial
                size. Confirm what this argument must be per block.
              * ``oup`` only reaches ``BatchNorm2d``. Nothing visible here maps
                ``inp`` channels to ``oup`` channels (the job of the 1x1 pointwise
                conv in a standard depthwise-separable block), and the debug output
                shows the channel dim unchanged (32). TODO confirm whether
                LowRankLayer is meant to do this.
            """
            return nn.Sequential(
                # Depthwise 3x3 conv: one filter per input channel (groups=inp).
                nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
                
                # Replaces the usual 1x1 pointwise conv. Per the captured debug output
                # its result is 3D, which the BatchNorm2d below rejects (see FIXME above).
                LowRankLayer(32*32, output_size=2 ** 8, d=8, K=2, pi_size=28, adaptive=True),
                nn.Sigmoid(),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        # NOTE(review): with kernel 3 / padding 1, the eight stride-2 blocks would shrink
        # a 32x32 input to 1x1 after the fifth one, i.e. well before AvgPool2d(8).
        # Confirm the intended downsampling schedule.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            conv_dw( 32,  32, 1),
            conv_dw( 32,  32, 1),
            conv_dw( 32,  64, 1),
            
            conv_dw(64, 64, 2),
            conv_dw(64, 64, 2),
            conv_dw(64, 64, 2),
            conv_dw(64, 128, 2),

            conv_dw(128, 128, 2),
            conv_dw(128, 128, 2),
            conv_dw(128, 128, 2),
            conv_dw(128, 256, 2),
            nn.AvgPool2d(8, ceil_mode=True, count_include_pad=True),
        )
        # Classification head: 256 pooled features -> 10 logits.
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        """Run the backbone, flatten to 256 features per row, and return 10 logits per row."""
        x = self.model(x)
        # Assumes the pooled tensor holds exactly 256 values per sample, e.g. (B, 256, 1, 1).
        # Otherwise view(-1, 256) silently re-splits the batch dimension. TODO confirm.
        x = x.view(-1, 256)
        x = self.fc(x)
        return x

In [3]:
class Solver(object):
    """Train and evaluate a classifier with Adam and cross-entropy loss.

    Data loaders come from the project helper ``build_dataset``. Per-epoch metrics
    are printed and logged to the module-level TensorBoard ``writer``.
    """

    def __init__(self, model, dataset='MNIST', n_epochs=100, lr=0.001,
                 batch_size=4, checkpoint_path='models/MobileNet_CIFAR_test.pth'):
        """Build the data loaders, network, loss and optimiser.

        Args:
            model: zero-argument callable returning an ``nn.Module`` (e.g. the class itself).
            dataset: ``'MNIST'`` or ``'CIFAR10'``; any other value raises ``KeyError``.
            n_epochs: number of training epochs.
            lr: Adam learning rate.
            batch_size: mini-batch size forwarded to ``build_dataset``.
            checkpoint_path: file the best weights are written to and reloaded from.
        """
        import os

        self.n_epochs = n_epochs
        self.checkpoint_path = checkpoint_path
        # torch.save does not create missing directories.
        ckpt_dir = os.path.dirname(checkpoint_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        self.train_loader, self.test_loader = build_dataset(dataset, './data', batch_size=batch_size)

        self.image_dim = {'MNIST': 28*28, 'CIFAR10': 3*32*32}[dataset]

        # Use the GPU when available; fall back to CPU instead of crashing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = model().to(self.device)

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)

    def train(self):
        """Train for ``n_epochs`` epochs, keeping the weights with the lowest training loss.

        Each epoch prints and logs the mean per-batch training loss, then calls
        :meth:`evaluate`. Afterwards the best checkpoint (if one was saved) is
        reloaded and evaluated once more at step ``n_epochs``.
        """
        best_loss = float('inf')
        saved = False
        for epoch_i in tqdm_notebook(range(self.n_epochs)):
            # evaluate() may leave the net in eval mode. Switch back every epoch so
            # layers such as BatchNorm train in training mode.
            self.net.train()
            epoch_loss = 0.0
            n_batches = 0
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                logits = self.net(images)
                loss = self.loss_fn(logits, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1
            # CrossEntropyLoss already averages within each batch, so average over
            # batches (not samples) to get the mean per-sample loss.
            epoch_loss /= max(n_batches, 1)

            # Save the best model (selected by training loss).
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(self.net.state_dict(), self.checkpoint_path)
                saved = True

            print(f'Epoch {epoch_i} | loss: {epoch_loss:.8f}')

            # Global step of this epoch's last batch. Log the epoch mean so the
            # curve matches the printed value.
            niter = (epoch_i + 1) * n_batches - 1
            writer.add_scalar('Train/Loss', epoch_loss, niter)
            self.evaluate(epoch_i)

        # Only reload if a checkpoint was written; with 0 epochs or NaN losses none is.
        if saved:
            self.net.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device))
        print(f'Best Loss {best_loss} | Best')
        self.evaluate(self.n_epochs)

    def evaluate(self, epoch_i=None):
        """Compute top-1 and top-5 accuracy (in percent) over the whole test set.

        The network's train/eval mode is restored on return.

        Args:
            epoch_i: when given, both accuracies are logged to TensorBoard at this step.

        Returns:
            ``(top1, top5)`` as floats in ``[0, 100]``; ``(0.0, 0.0)`` for an empty loader.
        """
        was_training = self.net.training
        self.net.eval()
        total = 0
        correct1 = 0
        correct5 = 0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                logits = self.net(images)

                # Top-k class indices per row; k is capped by the number of classes.
                k = min(5, logits.size(1))
                _, topk = logits.topk(k, 1)
                hits = topk.eq(labels.view(-1, 1))
                correct1 += int(hits[:, 0].sum().item())
                # topk indices are distinct, so each row contributes at most one hit.
                correct5 += int(hits.sum().item())
                total += labels.size(0)
        self.net.train(was_training)

        acc1 = 100.0 * correct1 / total if total else 0.0
        acc5 = 100.0 * correct5 / total if total else 0.0
        if epoch_i is not None:
            writer.add_scalar('Test/Acc@1', acc1, epoch_i)
            writer.add_scalar('Test/Acc@5', acc5, epoch_i)
        print(f'Top-1 Accuracy: {acc1}')
        print(f'Top-5 Accuracy: {acc5}')
        return acc1, acc5

In [4]:
# Train the low-rank MobileNet on CIFAR10 using the Solver's default
# epochs / learning rate, then run its final evaluation.
standard_solver = Solver(model=MobileNet_CIFAR_LowRank, dataset='CIFAR10')
standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0), HTML(value='')))

Input shape:  torch.Size([4, 32, 32, 32])
Input shape after smth:  torch.Size([4, 32, 1024])
Output shape:  torch.Size([4, 32, 256])



ValueError: expected 4D input (got 3D input)

In [None]:
# Removed a leftover `nn.BatchNorm2d?` help lookup (IPython-only syntax, debugging residue).
# Finding it was used for: nn.BatchNorm2d requires 4D input (N, C, H, W); a 3D tensor
# raises "expected 4D input (got 3D input)", as in the training cell's output above.