In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.functional as F

from torchvision import datasets, transforms
from utils import build_dataset
from tqdm import tqdm_notebook

In [2]:
class MobileNet_CIFAR(nn.Module):
    def __init__(self):
        super(MobileNet_CIFAR, self).__init__()

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
    
                nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            conv_dw( 32,  32, 1),
            conv_dw( 32,  32, 1),
            conv_dw( 32,  64, 1),
            
            conv_dw(64, 64, 2),
            conv_dw(64, 64, 2),
            conv_dw(64, 64, 2),
            conv_dw(64, 128, 2),

            conv_dw(128, 128, 2),
            conv_dw(128, 128, 2),
            conv_dw(128, 128, 2),
            conv_dw(128, 256, 2),
            nn.AvgPool2d(8, ceil_mode=True, count_include_pad=True),
        )
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x

In [3]:
class Solver(object):
    def __init__(self, dataset='MNIST', n_epochs=100, lr=0.001):
        self.n_epochs = n_epochs
        
        self.train_loader, self.test_loader = build_dataset(dataset, './data', batch_size = 256)
        
        self.image_dim = {'MNIST': 28*28, 'CIFAR10': 3*32*32}[dataset]
                
        self.net = MobileNet_CIFAR().cuda()
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
                
    def train(self):
        self.net.train()
        
        for epoch_i in tqdm_notebook(range(self.n_epochs)):
            epoch_i += 1
            epoch_loss = 0
            epoch_kl = 0
            for images, labels in self.train_loader:
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()
                
                logits = self.net(images)
                
                loss = self.loss_fn(logits, labels)
                
                total_loss = loss

                self.optimizer.zero_grad()
                total_loss.backward()

                self.optimizer.step()
                
                epoch_loss += float(loss.data)
            
            epoch_loss /= len(self.train_loader.dataset)
            print(f'Epoch {epoch_i} | loss: {epoch_loss:.4f}')
            self.evaluate()
            
    def evaluate(self):
        total = 0
        correct = 0
        self.net.eval()
        for images, labels in self.test_loader:
            images = Variable(images).cuda()

            logits = self.net(images.cuda())
            
            _, predicted = torch.max(logits.data, 1)
            
            total += labels.size(0)
            correct += (predicted.cpu() == labels).sum()
                
        print(f'Accuracy: {100 * correct / total:.2f}%')

In [None]:
standard_solver = Solver('CIFAR10')

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch 1 | loss: 0.0070
Accuracy: 41.00%
Epoch 2 | loss: 0.0060
Accuracy: 49.00%
Epoch 3 | loss: 0.0052
Accuracy: 53.00%
Epoch 4 | loss: 0.0047
Accuracy: 55.00%
Epoch 5 | loss: 0.0044
Accuracy: 56.00%
Epoch 6 | loss: 0.0042
Accuracy: 60.00%
Epoch 7 | loss: 0.0040
Accuracy: 62.00%
Epoch 8 | loss: 0.0039
Accuracy: 63.00%
Epoch 9 | loss: 0.0037
Accuracy: 63.00%
Epoch 10 | loss: 0.0036
Accuracy: 64.00%
Epoch 11 | loss: 0.0035
Accuracy: 64.00%
Epoch 12 | loss: 0.0033
Accuracy: 64.00%
Epoch 13 | loss: 0.0033
Accuracy: 67.00%
Epoch 14 | loss: 0.0032
Accuracy: 66.00%
Epoch 15 | loss: 0.0031
Accuracy: 67.00%
Epoch 16 | loss: 0.0031
Accuracy: 68.00%
Epoch 17 | loss: 0.0030
Accuracy: 68.00%
Epoch 18 | loss: 0.0029
Accuracy: 68.00%
Epoch 19 | loss: 0.0029
Accuracy: 69.00%
Epoch 20 | loss: 0.0028
Accuracy: 68.00%
Epoch 21 | loss: 0.0027
Accuracy: 69.00%
Epoch 22 | loss: 0.0027
Accuracy: 68.00%
Epoch 23 | loss: 0.0026
Accuracy: 69.00%
Epoch 24 | loss: 0.0025
Accuracy: 70.00%
Epoch 25 | loss: 0.0025
A

In [None]:
nn.AvgPool2d?

##### 