# MobileNet

In [None]:
import argparse
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler

from mobilenets import mobilenetV1

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-channel normalization applied after ToTensor() in every pipeline.
# NOTE(review): these look like the commonly used ImageNet channel
# statistics rather than values computed from CIFAR-10 -- confirm intended.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Fraction of the training split held out for validation.
valid_size = 0.1

In [None]:
# Preprocessing pipelines.
#
# Training images are augmented (random 32x32 crop after 4-pixel padding,
# random horizontal flip) before tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# Validation images are only converted to tensors and normalized.
valid_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
    ]
)

# Two views of the CIFAR-10 training split: the same images, but the training
# view applies augmentation while the validation view does not.
train_dataset = datasets.CIFAR10(
    root="../../datasets/cifar-data",
    train=True,
    download=True,
    transform=train_transform,
)
valid_dataset = datasets.CIFAR10(
    root="../../datasets/cifar-data",
    train=True,
    download=True,
    transform=valid_transform,
)

# Number of samples to hold out, as a `valid_size` fraction of the split.
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))

# Shuffle with a fixed seed so the train/validation partition is repeatable.
np.random.seed(42)
np.random.shuffle(indices)

# The first `split` shuffled indices go to validation, the rest to training.
valid_idx = indices[:split]
train_idx = indices[split:]

# Each sampler draws only from its own index subset, so the two loaders
# never serve the same image even though both wrap the full training split.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    sampler=train_sampler,
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=128,
    sampler=valid_sampler,
)

# Held-out CIFAR-10 test split: no augmentation, only tensor conversion and
# normalization, iterated in a fixed order.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
    ]
)

test_dataset = datasets.CIFAR10(
    root="../../datasets/cifar-data",
    train=False,
    download=True,
    transform=test_transform,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
)

In [None]:
# Network with a 10-way classification head, placed on the selected device.
model = mobilenetV1(num_classes=10)
model = model.to(device)

# Mean cross-entropy over the batch (the default reduction).
criterion = nn.CrossEntropyLoss()

# Adam with a step schedule: the learning rate is multiplied by 0.1 after
# every 20 calls to scheduler.step().
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
scheduler = StepLR(optimizer=optimizer, step_size=20, gamma=0.1)

In [None]:
# Training routine for one epoch (validation is implemented in validate() below)

def train(epoch):
    """Run one training epoch over ``train_loader`` and print its metrics.

    Puts the global ``model`` in training mode, takes one optimizer step per
    mini-batch, then advances the learning-rate ``scheduler`` once.

    Args:
        epoch: Epoch index; used only in the printed progress line.

    Returns:
        None. The sample-weighted mean loss and the accuracy for the epoch
        are printed, together with the learning rate the epoch trained with.
    """
    model.train()
    train_loss = 0
    train_acc = 0
    train_n = 0
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        output = model(data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The criterion yields a per-batch mean; weight it by the batch size
        # so the epoch average stays correct when the last batch is smaller.
        train_loss += loss.item() * label.size(0)
        train_acc += (output.max(1)[1] == label).sum().item()
        train_n += label.size(0)
    # Read the learning rate *before* stepping the scheduler so the log shows
    # the rate this epoch actually trained with, not the next epoch's rate.
    epoch_lr = scheduler.get_last_lr()
    scheduler.step()
    print('Epoch: {}, lr:{}, Training loss: {:.4f}, Training acc: {:.4f}'.format(
            epoch, epoch_lr, train_loss/train_n, train_acc/train_n))

In [None]:
def validate():
    """Evaluate the global ``model`` on ``valid_loader`` and print metrics.

    Switches the model to evaluation mode and runs every validation batch
    with gradient tracking disabled.

    Returns:
        None. The sample-weighted mean loss and the accuracy over the
        validation subset are printed.
    """
    model.eval()
    valid_loss = 0
    valid_acc = 0
    valid_n = 0
    # No backward pass happens here, so skip building autograd graphs; this
    # saves memory and time without changing the computed values.
    with torch.no_grad():
        for data, label in valid_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            loss = criterion(output, label)

            # Weight the per-batch mean loss by the batch size so the final
            # average is per sample.
            valid_loss += loss.item() * label.size(0)
            valid_acc += (output.max(1)[1] == label).sum().item()
            valid_n += label.size(0)

    print('Validation -- Validate loss: {:.4f}, Validate acc: {:.4f}'.format(
            valid_loss/valid_n,valid_acc/valid_n))

In [None]:
def test():
    """Evaluate the global ``model`` on ``test_loader`` and print metrics.

    Switches the model to evaluation mode and runs every test batch with
    gradient tracking disabled.

    Returns:
        None. The sample-weighted mean loss and the accuracy over the test
        split are printed.
    """
    model.eval()
    test_loss = 0
    test_acc = 0
    test_n = 0
    # No backward pass happens here, so skip building autograd graphs; this
    # saves memory and time without changing the computed values.
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            loss = criterion(output, label)

            # Weight the per-batch mean loss by the batch size so the final
            # average is per sample.
            test_loss += loss.item() * label.size(0)
            test_acc += (output.max(1)[1] == label).sum().item()
            test_n += label.size(0)

    print('Testing -- Test loss: {:.4f}, Test acc: {:.4f}'.format(
            test_loss/test_n,test_acc/test_n))

In [None]:
# Train for 40 epochs, checking validation metrics on every fifth epoch
# (0, 5, ..., 35), then report test-split performance once at the end.
for epoch in range(40):
    train(epoch)
    if not epoch % 5:
        validate()

test()