In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.functional as F

from torchvision import datasets, transforms
from utils import build_dataset
from tqdm import tqdm_notebook

In [2]:
from nets.MobileNet_CIFAR import MobileNet_CIFAR

In [3]:
class Solver(object):
    def __init__(self, model, dataset='MNIST', n_epochs=100, lr=0.001):
        self.n_epochs = n_epochs
        
        self.train_loader, self.test_loader = build_dataset(dataset, './data', batch_size = 256)
        
        self.image_dim = {'MNIST': 28*28, 'CIFAR10': 3*32*32}[dataset]
                
        self.net =model().cuda()
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
                
    def train(self):
        self.net.train()
        
        for epoch_i in tqdm_notebook(range(self.n_epochs)):
            epoch_i += 1
            epoch_loss = 0
            epoch_kl = 0
            for images, labels in self.train_loader:
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()
                
                logits = self.net(images)
                
                loss = self.loss_fn(logits, labels)
                
                total_loss = loss

                self.optimizer.zero_grad()
                total_loss.backward()

                self.optimizer.step()
                
                epoch_loss += float(loss.data)
            
            epoch_loss /= len(self.train_loader.dataset)
            print(f'Epoch {epoch_i} | loss: {epoch_loss:.4f}')
            self.evaluate()
            
    def evaluate(self):
        total = 0
        correct = 0
        self.net.eval()
        for images, labels in self.test_loader:
            images = Variable(images).cuda()

            logits = self.net(images.cuda())
            
            _, predicted = torch.max(logits.data, 1)
            
            total += labels.size(0)
            correct += (predicted.cpu() == labels).sum()
                
        print(f'Accuracy: {100 * correct / total:.2f}%')

In [4]:
standard_solver = Solver(MobileNet_CIFAR, 'CIFAR10')

standard_solver.train()

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch 1 | loss: 0.0071
Accuracy: 43.00%
Epoch 2 | loss: 0.0061
Accuracy: 47.00%
Epoch 3 | loss: 0.0052
Accuracy: 51.00%


KeyboardInterrupt: 

In [None]:
nn.AvgPool2d?

##### 