In [89]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets, models
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

In [90]:
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
import os

In [91]:
import utils
from models import ResNet

In [92]:
data_loaders, test_loader = utils.get_data_loaders()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [93]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [94]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#net = ResNet.ResNet50()
net = ResNet.ResNet18()

print(torch.cuda.current_device())
print(torch.cuda.device(0))
print(torch.cuda.device_count())

0
<torch.cuda.device object at 0x7f8ec7390320>
1


In [95]:
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
else:
    print('CPU')
    criterion = nn.CrossEntropyLoss()

Tesla K80


In [96]:
#optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer = optim.Adam(net.parameters())
#scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
#scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 32)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[81, 122], gamma=0.1)

Size of CIFAR-10

In [97]:
dataset_size = {'train': 4000,'val': 1000,'test': 1000}

Load Trained Model

In [98]:
old = False

In [99]:
SAVE_PATH = './trained-models/sgd-net.pth'

if old:
    old_epochs = utils.load_checkpoint(net, optimizer, scheduler, SAVE_PATH)

Implement SWATS

In [100]:
def train_model(model, criterion, optimizer, scheduler, num_epochs):
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_accuracy = 0.0
    
    for epoch in range(num_epochs):
        
        print(str(epoch) + "/" + str(num_epochs))
        
        if type(scheduler) is torch.optim.lr_scheduler.MultiStepLR:
            scheduler.step()
        
        for phase in ['train', 'val']:
            
            print(phase)
            
            if phase == 'train':
                model.train()
                
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            total = 0
            
            start = time.time()
            
            for index, (inputs, targets) in enumerate(data_loaders[phase]):
                
                inputs = inputs.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = outputs.max(1)
                    loss = criterion(outputs, targets)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item()
                running_corrects += preds.eq(targets).sum().item()
                total += targets.size(0)
                
            epoch_loss = running_loss / total
            epoch_acc = running_corrects / total
                
            print('Loss: ' + str(epoch_loss) + ", Epoch Accuracy: " + str(epoch_acc))
            
            print('Time: ' + str((time.time() - start) / 60))
            
            if phase == 'val' and type(scheduler) is torch.optim.lr_scheduler.ReduceLROnPlateau:
                scheduler.step(epoch_loss)
            
            if phase == 'val' and epoch_acc > best_accuracy:
                best_accuracy = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
                if not os.path.isdir('trained-models'):
                    os.mkdir('trained-models')
                
                state = {
                    
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                    
                }

                if os.path.exists(SAVE_PATH):
                    os.remove(SAVE_PATH)
                
                torch.save(state, SAVE_PATH)
    
    print('Best Accuracy: ' + best_accuracy)
    
    model.load_state_dict(best_model_wts)
    return model

In [101]:
if old:
    epochs = 200 - old_epochs
else:
    epochs = 200

print(epochs)

net = train_model(net, criterion, optimizer, scheduler, epochs)

200
0/200
train
Loss: 0.16504867784380914, Epoch Accuracy: 0.391775
Time: 3.1166834632555642
val
Loss: 0.12403615301251411, Epoch Accuracy: 0.5578
Time: 0.21518206199010212
1/200
train
Loss: 0.1131749937608838, Epoch Accuracy: 0.5988
Time: 3.118698986371358
val
Loss: 0.09366305884420872, Epoch Accuracy: 0.6756
Time: 0.2164017915725708
2/200
train
Loss: 0.08715990147776902, Epoch Accuracy: 0.694075
Time: 3.1044236024220786
val
Loss: 0.08025051166508347, Epoch Accuracy: 0.7225
Time: 0.21326533555984498
3/200
train
Loss: 0.07126511678844691, Epoch Accuracy: 0.75225
Time: 3.12467391093572
val
Loss: 0.06717916249297559, Epoch Accuracy: 0.7731
Time: 0.2156707763671875
4/200
train
Loss: 0.06039436905700713, Epoch Accuracy: 0.79155
Time: 3.1185857971509297
val
Loss: 0.05276236581970006, Epoch Accuracy: 0.8202
Time: 0.2138242204984029
5/200
train
Loss: 0.05365123755014502, Epoch Accuracy: 0.8125
Time: 3.121328337987264
val
Loss: 0.04892759549571201, Epoch Accuracy: 0.8355
Time: 0.21665351390838

Process Process-557:
Process Process-558:
Process Process-559:
Process Process-560:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/src/anaconda3/envs/f

KeyboardInterrupt: 

KeyboardInterrupt
KeyboardInterrupt


In [None]:
sgd_epochs = 10
utils.load_checkpoint(net, optimizer, scheduler, SAVE_PATH)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5, weight_decay=5e-4, nesterov=True)
net = train_model(net, criterion, optimizer, scheduler, sgd_epochs)

0/10
train
Loss: 0.006820293080124247, Epoch Accuracy: 0.978225
Time: 2.950368916988373
val
Loss: 0.024997299911032313, Epoch Accuracy: 0.9199
Time: 0.2076003114382426
1/10
train
Loss: 0.006576670150706923, Epoch Accuracy: 0.98075
Time: 2.95213219722112
val
Loss: 0.024645183473995348, Epoch Accuracy: 0.9219
Time: 0.20718626181284586
2/10
train
Loss: 0.006388789913225628, Epoch Accuracy: 0.9814
Time: 2.947488232453664
val
Loss: 0.024614500575418786, Epoch Accuracy: 0.9209
Time: 0.2071125308672587
3/10
train
Loss: 0.006517260630999954, Epoch Accuracy: 0.98055
Time: 2.9477431774139404
val
Loss: 0.024773272204412205, Epoch Accuracy: 0.9182
Time: 0.2057653824488322
4/10
train
