In [1]:
import argparse
import os
import warnings

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import Resnet50 as res
from utils import progress_bar

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy seen so far; a checkpoint is written only when it improves
start_epoch = 0  # first epoch index for the training loop (0 for a fresh run; overwritten when resuming)

# Seed torch's global RNG (CPU and all CUDA devices) so torch-driven randomness
# such as weight initialisation and data-loader shuffling is repeatable.
# NOTE: cudnn.benchmark, enabled later on CUDA, may still introduce some
# run-to-run nondeterminism.
SEED = 0
torch.manual_seed(SEED)


In [3]:
print('==> Preparing data..')

# Per-channel statistics used to normalise every image; train and test share
# them so both splits see identically scaled inputs.
# NOTE(review): the std tuple differs from the commonly quoted per-pixel
# CIFAR-10 std (~0.247, 0.243, 0.261). Confirm before changing it, because any
# saved checkpoint was trained with these exact values.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
DATA_ROOT = './data'

# Training pipeline: random 32x32 crop from a 4-pixel padded image and a random
# horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

# Class names, in the order of CIFAR-10's integer labels 0-9.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [4]:
print('==> Building model..')
net = res.resnet50().to(device)
if device == 'cuda':
    # Replicate the model across all visible GPUs. DataParallel keeps the
    # wrapped model under `.module`, so saved state_dict keys carry a
    # 'module.' prefix and must be loaded into a DataParallel-wrapped model.
    net = torch.nn.DataParallel(net)
    # Let cuDNN benchmark and cache the fastest convolution algorithms; this
    # pays off because input sizes stay the same from batch to batch.
    cudnn.benchmark = True

==> Building model..


In [5]:
# Set RESUME = True to continue training from the checkpoint written by `test`.
RESUME = False
CHECKPOINT_PATH = './checkpoint/ckpt.pth'


def resume_from_checkpoint(model, path=CHECKPOINT_PATH, map_location=None):
    """Load a checkpoint saved by `test` into `model`.

    Args:
        model: Module whose parameters are overwritten with the saved 'net'
            state_dict (must be wrapped the same way, e.g. DataParallel, as
            the model that was saved).
        path: Checkpoint file to read.
        map_location: Forwarded to `torch.load`, e.g. the current `device`,
            so a GPU-saved checkpoint can be remapped.

    Returns:
        (best_acc, start_epoch): the stored accuracy and the epoch to resume
        at. The stored 'epoch' is the one that already finished, so training
        resumes at the next one.

    Raises:
        FileNotFoundError: if `path` does not exist.
    """
    print('==> Resuming from checkpoint..')
    if not os.path.isfile(path):
        raise FileNotFoundError('Error: no checkpoint found at %s!' % path)
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['net'])
    return checkpoint['acc'], checkpoint['epoch'] + 1


if RESUME:
    best_acc, start_epoch = resume_from_checkpoint(net, map_location=device)

In [6]:
# Classification loss on the network's raw logits.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)

# Cosine-anneal the learning rate from 0.1 over 200 scheduler steps; the
# training loop calls scheduler.step() once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


In [7]:
def train(epoch):
    """Train ``net`` for a single epoch over ``trainloader``.

    Uses the notebook globals ``net``, ``trainloader``, ``optimizer``,
    ``criterion``, ``device`` and ``progress_bar``. After every batch it
    reports the running mean loss and accuracy through ``progress_bar``.

    Args:
        epoch: Epoch index; only used in the header line printed at the start.
    """
    print('\nEpoch: %d' % epoch)
    net.train()  # put layers into training mode (e.g. batch-norm stat updates)

    running_loss = 0
    n_correct = 0
    n_seen = 0
    n_batches = len(trainloader)

    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Clear stale gradients, forward, backpropagate, update weights.
        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        mean_loss = running_loss / (step + 1)
        accuracy = 100. * n_correct / n_seen
        progress_bar(step, n_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (mean_loss, accuracy, n_correct, n_seen))

In [8]:
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    # Save checkpoint.    
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [9]:
# Train until the cosine schedule is exhausted. Reading T_max from the
# scheduler keeps the epoch count and the LR schedule in sync, also when
# resuming with start_epoch > 0.
num_epochs = scheduler.T_max

# A resumed run starts with a freshly built scheduler at step 0; fast-forward
# it so the learning rate matches the resumed epoch instead of restarting at
# the initial LR. PyTorch warns when the scheduler steps before any
# optimizer.step(); that is expected here, so the warning is silenced.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', message='Detected call of', category=UserWarning)
    for _ in range(start_epoch):
        scheduler.step()

for epoch in range(start_epoch, num_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()


Epoch: 0






Saving..

Epoch: 1






Saving..

Epoch: 2






Saving..

Epoch: 3








Saving..

Epoch: 4







Epoch: 5








Saving..

Epoch: 6








Saving..

Epoch: 7






Saving..

Epoch: 8








Saving..

Epoch: 9









Epoch: 10









Epoch: 11








Saving..

Epoch: 12








Saving..

Epoch: 13








Saving..

Epoch: 14









Epoch: 15









Epoch: 16







Epoch: 17







Epoch: 18








Saving..

Epoch: 19








Saving..

Epoch: 20







Epoch: 21









Epoch: 22








Saving..

Epoch: 23









Epoch: 24









Epoch: 25







Epoch: 26









Epoch: 27









Epoch: 28









Epoch: 29









Epoch: 30







Epoch: 31









Epoch: 32









Epoch: 33








Saving..

Epoch: 34








Saving..

Epoch: 35









Epoch: 36







Epoch: 37







Epoch: 38









Epoch: 39









Epoch: 40








Saving..

Epoch: 41









Epoch: 42









Epoch: 43









Epoch: 44









Epoch: 45







Epoch: 46









Epoch: 47









Epoch: 48








Saving..

Epoch: 49









Epoch: 50









Epoch: 51









Epoch: 52








Saving..

Epoch: 53







Epoch: 54







Epoch: 55





Epoch: 56







Epoch: 57







Epoch: 58







Epoch: 59







Epoch: 60





Epoch: 61





Epoch: 62





Epoch: 63







Epoch: 64





Epoch: 65





Epoch: 66





Epoch: 67





Epoch: 68





Epoch: 69







Epoch: 70







Epoch: 71






Saving..

Epoch: 72




Saving..

Epoch: 73







Epoch: 74





Epoch: 75





Epoch: 76







Epoch: 77







Epoch: 78






Saving..

Epoch: 79







Epoch: 80





Epoch: 81





Epoch: 82





Epoch: 83





Epoch: 84



Epoch: 85





Epoch: 86



Epoch: 87





Epoch: 88





Epoch: 89





Epoch: 90





Epoch: 91





Epoch: 92





Epoch: 93





Epoch: 94





Epoch: 95





Epoch: 96





Epoch: 97




Saving..

Epoch: 98





Epoch: 99





Epoch: 100





Epoch: 101





Epoch: 102





Epoch: 103





Epoch: 104





Epoch: 105





Epoch: 106





Epoch: 107




Saving..

Epoch: 108





Epoch: 109





Epoch: 110





Epoch: 111





Epoch: 112





Epoch: 113





Epoch: 114





Epoch: 115





Epoch: 116





Epoch: 117





Epoch: 118





Epoch: 119





Epoch: 120





Epoch: 121





Epoch: 122





Epoch: 123





Epoch: 124





Epoch: 125




Saving..

Epoch: 126





Epoch: 127





Epoch: 128





Epoch: 129





Epoch: 130




Saving..

Epoch: 131





Epoch: 132





Epoch: 133





Epoch: 134





Epoch: 135





Epoch: 136





Epoch: 137





Epoch: 138




Saving..

Epoch: 139





Epoch: 140





Epoch: 141





Epoch: 142




Saving..

Epoch: 143





Epoch: 144






Saving..

Epoch: 145





Epoch: 146





Epoch: 147





Epoch: 148





Epoch: 149





Epoch: 150





Epoch: 151





Epoch: 152





Epoch: 153






Saving..

Epoch: 154





Epoch: 155




Saving..

Epoch: 156





Epoch: 157





Epoch: 158





Epoch: 159





Epoch: 160





Epoch: 161





Epoch: 162




Saving..

Epoch: 163





Epoch: 164





Epoch: 165






Saving..

Epoch: 166






Saving..

Epoch: 167





Epoch: 168




Saving..

Epoch: 169





Epoch: 170





Epoch: 171




Saving..

Epoch: 172





Epoch: 173





Epoch: 174






Saving..

Epoch: 175






Saving..

Epoch: 176




Saving..

Epoch: 177




Saving..

Epoch: 178





Epoch: 179





Epoch: 180




Saving..

Epoch: 181




Saving..

Epoch: 182




Saving..

Epoch: 183





Epoch: 184





Epoch: 185





Epoch: 186





Epoch: 187




Saving..

Epoch: 188





Epoch: 189





Epoch: 190





Epoch: 191





Epoch: 192





Epoch: 193





Epoch: 194





Epoch: 195





Epoch: 196




Saving..

Epoch: 197





Epoch: 198





Epoch: 199




