In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torchvision


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using:', device)

Using: cuda


In [3]:
normalize_transform = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

# define transforms
train_transform = transforms.Compose([
    transforms.Resize((227)), #images are 32 by 32 so actually UPSIZING here
    transforms.CenterCrop(227),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize_transform,
])

test_transform = transforms.Compose([
    transforms.Resize((227)),
    transforms.CenterCrop(227),
    transforms.ToTensor(),
    normalize_transform,
])

In [4]:
# load datasets
data_dir = './data'
train_dataset = datasets.CIFAR100(
    root=data_dir, train=True,
    download=True, transform=train_transform)

valid_dataset = datasets.CIFAR100(
    root=data_dir, train=False,
    download=True, transform=test_transform,
)
print('train length:', len(train_dataset), 'val length:', len(valid_dataset))

Files already downloaded and verified
Files already downloaded and verified
train length: 50000 val length: 10000


In [5]:
# define loaders
batch_size = 64
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
 
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [6]:
from nets.AlexNet import AlexNet
# grab model and define training parameters
num_classes = 100
learning_rate = 0.005

model = AlexNet(num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

# Decay LR by a factor of 0.1 every 2 epochs
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

In [7]:
from utils.train import train_model

model = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=5,
    device='cuda',
    train_loader=train_loader,
    val_loader=valid_loader)

Using: cuda
Epoch 0/4
----------


100%|██████████| 782/782 [01:08<00:00, 11.37it/s]


train Loss: 3.7195 Acc: 6374/50000 = 0.1275


100%|██████████| 157/157 [00:06<00:00, 25.99it/s]


val Loss: 3.3020 Acc: 1964/10000 = 0.1964
New best accuracy: 0.196

Epoch 1/4
----------


100%|██████████| 782/782 [01:10<00:00, 11.09it/s]


train Loss: 2.9171 Acc: 12994/50000 = 0.2599


100%|██████████| 157/157 [00:06<00:00, 22.95it/s]


val Loss: 2.8898 Acc: 2757/10000 = 0.2757
New best accuracy: 0.276

Epoch 2/4
----------


100%|██████████| 782/782 [01:11<00:00, 10.94it/s]


train Loss: 2.2982 Acc: 19566/50000 = 0.3913


100%|██████████| 157/157 [00:06<00:00, 24.93it/s]


val Loss: 2.1959 Acc: 4133/10000 = 0.4133
New best accuracy: 0.413

Epoch 3/4
----------


100%|██████████| 782/782 [01:10<00:00, 11.14it/s]


train Loss: 2.1688 Acc: 20982/50000 = 0.4196


100%|██████████| 157/157 [00:06<00:00, 24.97it/s]


val Loss: 2.1063 Acc: 4334/10000 = 0.4334
New best accuracy: 0.433

Epoch 4/4
----------


100%|██████████| 782/782 [01:10<00:00, 11.17it/s]


train Loss: 2.0454 Acc: 22540/50000 = 0.4508


100%|██████████| 157/157 [00:06<00:00, 23.96it/s]


val Loss: 2.0185 Acc: 4507/10000 = 0.4507
New best accuracy: 0.451

Training complete in 6m 25s
Best val Acc: 0.450700


In [8]:
inputs, labels = next(iter(train_loader))
model = model.eval()
preds = model.to(device)(inputs.to(device))
_, preds = torch.max(preds, 1)
preds, labels

(tensor([91, 87, 61, 25, 37, 62, 27, 24, 24, 43, 28, 62, 64, 68, 56, 12, 57, 54,
         95,  8, 74, 87, 46, 77, 63, 35, 93, 70, 23, 17, 27, 73, 48, 22,  5, 20,
         39, 62, 18, 45,  6, 83, 48, 70,  1, 27, 81, 62, 94,  7, 82, 53, 62, 81,
         31, 73, 97, 13, 45, 35, 38, 57, 51, 27], device='cuda:0'),
 tensor([27, 94, 61, 25, 90, 62, 27,  7, 24, 66, 28, 62, 64, 68, 56, 12, 55, 54,
         95, 61, 74, 87, 35, 55, 63, 64, 73, 70, 30, 37, 80, 73, 48, 22, 72, 20,
         39, 53, 44, 26,  6, 83, 48, 45, 32, 12, 17, 70, 94, 14, 82, 83, 62, 39,
         43, 99, 66, 13, 45, 11, 80, 57, 25, 73]))