<a href="https://colab.research.google.com/github/sour4bh/cifar-10/blob/master/cifar10_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/sour4bh/cifar-10
!mv cifar-10/* .
!mkdir logs epochs

Cloning into 'cifar-10'...
remote: Enumerating objects: 121, done.[K
remote: Counting objects: 100% (121/121), done.[K
remote: Compressing objects: 100% (95/95), done.[K
remote: Total 121 (delta 61), reused 66 (delta 21), pack-reused 0[K
Receiving objects: 100% (121/121), 38.64 KiB | 534.00 KiB/s, done.
Resolving deltas: 100% (61/61), done.
mkdir: cannot create directory ‘logs’: File exists


In [0]:
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /gdrive
/gdrive


In [0]:
# imports 
import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision import datasets, transforms

from cutout import Cutout
from wide_resnet import WideResNet

In [0]:
# Control center

num_classes = 10
batch_size = 128
epochs = 200
seed = 0

LR_MILESTONES = [40, 60, 80, 90, 150, 155] # step down lr milestones
gamma = 0.2 #gamma for step lr 0.2 == 5x 
learning_rate = 0.1


data_augmentation = True
# cutout hyperparams
n_holes = 1
length = 16
# model - wideresnet hyperparams
depth = 28
widen_factor = 10
drop_rate = 0.3

# recover training
resume = False
resume_checkpoint =  ''


In [0]:
def save_this(epoch, test_acc, accuracy, model, optimizer, scheduler, on_drive=True):
    checkpoint = {
    'epoch': epoch,
    'test_acc' : test_acc,
    'train_acc' : accuracy,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),    
    'scheduler' : scheduler.state_dict()}
    if on_drive:
        # save checkpoints to google drive 
        torch.save(checkpoint, '/content/drive/My Drive/Colab Notebooks/checkpoints/' + test_id + f'{epoch}.pt')
    else:
        # save checkpoints to local 
        torch.save(checkpoint, 'epochs/' +  test_id +  f'{epoch}.pt')

def test(loader):
    model.eval()   
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = model(images)

        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = correct / total
    model.train()
    return val_acc


In [5]:
cuda = True
cudnn.benchmark = True 
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

test_id = 'cifar10_wideresent'

normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                    std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)

train_transform.transforms.append(Cutout(n_holes=n_holes, length=length)) # cutout augemntation 

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

train_dataset = datasets.CIFAR10(root='data/',
                                    train=True,
                                    transform=train_transform,
                                    download=True)

test_dataset = datasets.CIFAR10(root='data/',
                                train=False,
                                transform=test_transform,
                                download=True)


train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        pin_memory=True,
                                        num_workers=2)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        pin_memory=True,
                                        num_workers=2)

model = WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen_factor, dropRate=drop_rate)
model = model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=0.0005)

scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=gamma)

begin = -1
try:
    checkpoint_fpath =  resume_checkpoint
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    begin = checkpoint['epoch']
    scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=gamma, last_epoch=begin)
    # print('test_acc :', checkpoint['test_acc'], 'train_acc :', checkpoint['train_acc'])
    # print('last_lr :', checkpoint['scheduler']['_last_lr'])
    print('resuming from', checkpoint['epoch'])
except FileNotFoundError:
    print('starting over..')
    pass

best_acc = 0
for epoch in range(epochs):
    if epoch <= begin:
        # scheduler.step()
        continue
    
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.

    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()

        model.zero_grad()
        pred = model(images)

        xentropy_loss = criterion(pred, labels)
        xentropy_loss.backward()
        optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        _lr=optimizer.param_groups[0]["lr"]

        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' %(accuracy),
            lr='%.2E'%(_lr))

    test_acc = test(test_loader)
    tqdm.write('test_acc: %.3f' % (test_acc))

    if test_acc > best_acc:
        best_acc = test_acc
        save_this(epoch, test_acc, accuracy, model, optimizer, scheduler, on_drive=False)
    
    if epoch % 10 == 0:
        save_this(epoch, test_acc, accuracy, model, optimizer, scheduler, on_drive=False)
    
    _lr = optimizer.param_groups[0]["lr"]
    with open('log.csv', 'a') as f:
        f.write(f"epoch: {str(epoch)}, train_acc: {str(accuracy)}, test_acc: {str(test_acc)}, lr:{'%.2E'%(_lr)}" + '\n')
    
    scheduler.step()
    # end loop

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


  0%|          | 0/391 [00:00<?, ?it/s]

starting over..


Epoch 0: 100%|██████████| 391/391 [04:25<00:00,  1.18s/it, acc=0.387, lr=1.00E-01, xentropy=1.642]


test_acc: 0.535


Epoch 1: 100%|██████████| 391/391 [04:25<00:00,  1.64it/s, acc=0.578, lr=1.00E-01, xentropy=1.174]


test_acc: 0.602


Epoch 2: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.653, lr=1.00E-01, xentropy=0.990]


test_acc: 0.608


Epoch 3: 100%|██████████| 391/391 [04:26<00:00,  1.65it/s, acc=0.696, lr=1.00E-01, xentropy=0.863]


test_acc: 0.672


Epoch 4: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.724, lr=1.00E-01, xentropy=0.787]


test_acc: 0.725


Epoch 5: 100%|██████████| 391/391 [04:26<00:00,  1.65it/s, acc=0.747, lr=1.00E-01, xentropy=0.728]


test_acc: 0.743


Epoch 6: 100%|██████████| 391/391 [04:26<00:00,  1.66it/s, acc=0.759, lr=1.00E-01, xentropy=0.696]


test_acc: 0.801


Epoch 7: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.771, lr=1.00E-01, xentropy=0.659]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.749


Epoch 8: 100%|██████████| 391/391 [04:26<00:00,  1.65it/s, acc=0.778, lr=1.00E-01, xentropy=0.641]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.754


Epoch 9: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.785, lr=1.00E-01, xentropy=0.619]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.761


Epoch 10: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.789, lr=1.00E-01, xentropy=0.607]


test_acc: 0.820


Epoch 11: 100%|██████████| 391/391 [04:25<00:00,  1.67it/s, acc=0.796, lr=1.00E-01, xentropy=0.589]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.776


Epoch 12: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.800, lr=1.00E-01, xentropy=0.578]


test_acc: 0.821


Epoch 13: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.804, lr=1.00E-01, xentropy=0.567]


test_acc: 0.831


Epoch 14: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.808, lr=1.00E-01, xentropy=0.553]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.816


Epoch 15: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.812, lr=1.00E-01, xentropy=0.544]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.803


Epoch 16: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.814, lr=1.00E-01, xentropy=0.542]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.794


Epoch 17: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.821, lr=1.00E-01, xentropy=0.526]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.828


Epoch 18: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.821, lr=1.00E-01, xentropy=0.519]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.793


Epoch 19: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.822, lr=1.00E-01, xentropy=0.515]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.808


Epoch 20: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.825, lr=1.00E-01, xentropy=0.507]


test_acc: 0.842


Epoch 21: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.828, lr=1.00E-01, xentropy=0.502]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.837


Epoch 22: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.829, lr=1.00E-01, xentropy=0.495]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.805


Epoch 23: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.829, lr=1.00E-01, xentropy=0.495]


test_acc: 0.870


Epoch 24: 100%|██████████| 391/391 [04:25<00:00,  1.67it/s, acc=0.831, lr=1.00E-01, xentropy=0.494]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.811


Epoch 25: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.832, lr=1.00E-01, xentropy=0.485]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.838


Epoch 26: 100%|██████████| 391/391 [04:24<00:00,  1.64it/s, acc=0.832, lr=1.00E-01, xentropy=0.482]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.861


Epoch 27: 100%|██████████| 391/391 [04:24<00:00,  1.67it/s, acc=0.835, lr=1.00E-01, xentropy=0.478]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.844


Epoch 28: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.835, lr=1.00E-01, xentropy=0.477]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.798


Epoch 29: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.837, lr=1.00E-01, xentropy=0.473]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.833


Epoch 30: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.835, lr=1.00E-01, xentropy=0.473]


test_acc: 0.848


Epoch 31: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.839, lr=1.00E-01, xentropy=0.467]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.849


Epoch 32: 100%|██████████| 391/391 [04:22<00:00,  1.66it/s, acc=0.839, lr=1.00E-01, xentropy=0.467]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.846


Epoch 33: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.840, lr=1.00E-01, xentropy=0.466]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.859


Epoch 34: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.842, lr=1.00E-01, xentropy=0.458]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.862


Epoch 35: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.839, lr=1.00E-01, xentropy=0.463]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.847


Epoch 36: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.843, lr=1.00E-01, xentropy=0.456]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.825


Epoch 37: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.842, lr=1.00E-01, xentropy=0.456]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.852


Epoch 38: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.842, lr=1.00E-01, xentropy=0.460]


test_acc: 0.874


Epoch 39: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.840, lr=1.00E-01, xentropy=0.461]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.842


Epoch 40: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.899, lr=2.00E-02, xentropy=0.293]


test_acc: 0.928


Epoch 41: 100%|██████████| 391/391 [04:24<00:00,  1.67it/s, acc=0.912, lr=2.00E-02, xentropy=0.254]


test_acc: 0.934


Epoch 42: 100%|██████████| 391/391 [04:26<00:00,  1.66it/s, acc=0.917, lr=2.00E-02, xentropy=0.240]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.933


Epoch 43: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.920, lr=2.00E-02, xentropy=0.230]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.934


Epoch 44: 100%|██████████| 391/391 [04:22<00:00,  1.66it/s, acc=0.921, lr=2.00E-02, xentropy=0.226]


test_acc: 0.938


Epoch 45: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.922, lr=2.00E-02, xentropy=0.221]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.936


Epoch 46: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.924, lr=2.00E-02, xentropy=0.219]


test_acc: 0.938


Epoch 47: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.925, lr=2.00E-02, xentropy=0.217]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.936


Epoch 48: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.925, lr=2.00E-02, xentropy=0.218]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.936


Epoch 49: 100%|██████████| 391/391 [04:23<00:00,  1.66it/s, acc=0.925, lr=2.00E-02, xentropy=0.216]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.934


Epoch 50: 100%|██████████| 391/391 [04:24<00:00,  1.67it/s, acc=0.925, lr=2.00E-02, xentropy=0.214]


test_acc: 0.935


Epoch 51: 100%|██████████| 391/391 [04:24<00:00,  1.66it/s, acc=0.924, lr=2.00E-02, xentropy=0.213]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.929


Epoch 52: 100%|██████████| 391/391 [04:24<00:00,  1.65it/s, acc=0.925, lr=2.00E-02, xentropy=0.216]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.922


Epoch 53: 100%|██████████| 391/391 [04:25<00:00,  1.65it/s, acc=0.925, lr=2.00E-02, xentropy=0.218]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.924


Epoch 54: 100%|██████████| 391/391 [04:23<00:00,  1.67it/s, acc=0.925, lr=2.00E-02, xentropy=0.216]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.934


Epoch 55: 100%|██████████| 391/391 [04:24<00:00,  1.67it/s, acc=0.926, lr=2.00E-02, xentropy=0.217]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.927


Epoch 56: 100%|██████████| 391/391 [04:24<00:00,  1.64it/s, acc=0.925, lr=2.00E-02, xentropy=0.215]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.932


Epoch 57: 100%|██████████| 391/391 [04:25<00:00,  1.66it/s, acc=0.925, lr=2.00E-02, xentropy=0.217]
  0%|          | 0/391 [00:00<?, ?it/s]

test_acc: 0.925


Epoch 58:  67%|██████▋   | 261/391 [02:57<01:28,  1.47it/s, acc=0.927, lr=2.00E-02, xentropy=0.211]

Buffered data was truncated after reaching the output size limit.