In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets

import os

# model
from models import ResNet18, ResNet50, to_one_hot

In [None]:
# Data will be downloaded in the ./data directory

# configuration
config = {
    'num_epochs': 50,
    'lr': 0.1,
    'weight_decay': 5e-4,
    'print_freq': 100,          # log the training loss every `print_freq` batches
    'use_mixup': True,
    'mixup_alpha': 4.0,
    'num_classes': 10,
    'dataset': 'FashionMNIST',  # 'cifar10' or 'FashionMNIST'
    'in_channels': 1,           # must match the dataset: 3 for cifar10, 1 for FashionMNIST
    'batch_size': 128,
    'seed': 0,                  # seed for torch's global RNG (set below)
}

# Seed torch's global RNG before any model or DataLoader is created, so shuffling
# and torch-based initialisation repeat across Restart & Run All.
# NOTE(review): if the mixup in `models` draws from numpy/`random`, those RNGs are
# not covered here; GPU kernels may also remain nondeterministic.
torch.manual_seed(config['seed'])

# set the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Dataset transforms (CIFAR-10): random-crop/flip augmentation for training,
# normalisation only for testing; 3-channel (RGB) mean/std.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Dataset transforms (FashionMNIST): no augmentation, single-channel mean/std.
# NOTE(review): (0.1307,) / (0.3081,) are the widely used MNIST statistics, not
# FashionMNIST's own -- confirm this is intended.
transform_train_FMNIST = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

transform_test_FMNIST = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Image channels produced by each supported dataset; used to validate
# config['in_channels'], which the model is built with.
DATASET_CHANNELS = {'cifar10': 3, 'FashionMNIST': 1}

# Load the dataset selected in config. Unknown names raise here, at the source of
# the problem, rather than surfacing later as an undefined `trainset`.
if config['dataset'] == 'cifar10':
    trainset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
elif config['dataset'] == "FashionMNIST":
    trainset = datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform_train_FMNIST)
    testset = datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transform_test_FMNIST)
else:
    raise ValueError('Invalid Dataset: {!r} (expected one of {})'.format(
        config['dataset'], sorted(DATASET_CHANNELS)))

# Catch a config/dataset channel mismatch before it becomes a shape error deep in the model.
if config['in_channels'] != DATASET_CHANNELS[config['dataset']]:
    raise ValueError("config['in_channels']={} but {} images have {} channel(s)".format(
        config['in_channels'], config['dataset'], DATASET_CHANNELS[config['dataset']]))

print('=> Number of images in train set {}'.format(len(trainset)))
print('=> Number of images in test set: {}'.format(len(testset)))

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

In [None]:
def test(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total
    return acc

In [None]:
print('Configuration------------------------------------')
print(config)


def _batch_loss(bce_loss, outputs, targets, targets_weighted, cfg):
    """Return the BCE loss between softmax(outputs) and the batch's (soft) one-hot targets."""
    # NOTE(review): BCE on softmax probabilities (rather than soft-target
    # cross-entropy on logits) is kept as the training objective; confirm intended.
    probs = F.softmax(outputs, dim=1)
    if cfg['use_mixup']:
        # With mixup, the model's second output (presumably the mixed soft labels)
        # is the target.
        return bce_loss(probs, targets_weighted)
    return bce_loss(probs, to_one_hot(targets, cfg['num_classes']))


def train(cfg):
    """Train a ResNet18 on `trainloader`, evaluating on the test set after every epoch.

    Args:
        cfg: configuration dict; reads 'num_classes', 'in_channels', 'lr',
            'weight_decay', 'num_epochs', 'use_mixup', 'mixup_alpha' and
            'print_freq'.

    Returns:
        tuple: (model, best_acc, best_epoch) -- the trained model (weights after
        the final epoch), the best test accuracy seen (percent) and the epoch at
        which it was reached.
    """
    # create model
    print('=> Creating the model')
    model = ResNet18(num_classes=cfg['num_classes'], in_channels=cfg['in_channels'])
    model = model.to(device)

    # loss function, optimizer and scheduler
    print('=> loss function, optimizer and scheduler')
    bce_loss = nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg['lr'],
                                momentum=0.9, weight_decay=cfg['weight_decay'])
    # scheduler.step() runs once per epoch, so T_max = num_epochs makes the cosine
    # schedule reach its minimum at the end of training.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['num_epochs'])

    # training loop
    best_acc = 0
    best_epoch = 0
    model.train()
    print('=> Start Training, max epochs {}'.format(cfg['num_epochs']))
    for epoch in range(1, cfg['num_epochs'] + 1):
        print('-' * 50)
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # forward
            optimizer.zero_grad()
            outputs, targets_weighted = model(inputs, use_mixup=cfg['use_mixup'],
                                              mixup_alpha=cfg['mixup_alpha'],
                                              labels=targets)
            loss = _batch_loss(bce_loss, outputs, targets, targets_weighted, cfg)

            # backward
            loss.backward()
            optimizer.step()

            if batch_idx % cfg['print_freq'] == 0:
                print('epoch: {:03d}, step: {:03d}, train_loss : {:.5f}'.format(epoch, batch_idx, loss.item()))

        # test after each epoch (eval mode for the pass, then back to training)
        model.eval()
        acc = test(model)
        model.train()

        print('epoch: {:03d}, Accuracy : {}'.format(epoch, acc))
        if acc > best_acc:
            best_epoch = epoch
            best_acc = acc

        # one cosine-annealing step per epoch
        scheduler.step()

    print('=> Best Accuracy : {} at epoch {}'.format(best_acc, best_epoch))
    return model, best_acc, best_epoch


model, best_acc, best_epoch = train(config)
