In [1]:
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim

import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from torch import autocast
from torch.cuda.amp import GradScaler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Utiliy Functions.
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [3]:
# Ensure that we get deterministic results.
SEED=1
random.seed(SEED)
torch.manual_seed(SEED)
cudnn.deterministic = True

In [4]:
# Setup TensorBoard
writer = SummaryWriter(log_dir="/data/logs")

In [5]:
# Check that we have a CUDA enabled device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
MEAN_RGB = [0.47889522, 0.47227842, 0.43047404]
STD_RGB = [0.229, 0.224, 0.225]
IMG_SIZE = 64

In [7]:
# Load train and val dataset. 
TRAINDIR = "/data/train" # ImageNet train.
VALDIR = "/data/val" # ImageNet val.
BATCH_SIZE = 64
TRAIN_WORKERS = 6
VAL_WORKERS = 0
SHUFFLE = False

# Scale to ImageNet mean and STD since we will be using a model pretrained on ImageNet. 
transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN_RGB, STD_RGB),
])

transform_val = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN_RGB, STD_RGB),
])

# Load training data.
train_dataset = datasets.ImageFolder(
    TRAINDIR, transform=transform_train)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, num_workers=TRAIN_WORKERS, shuffle=SHUFFLE, pin_memory=True, sampler=None)

# Load validation data.
val_dataset = datasets.ImageFolder(
    VALDIR, transform=transform_val)

val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, num_workers=VAL_WORKERS, shuffle=SHUFFLE, pin_memory=True, sampler=None)

In [8]:
NUM_CLASSES = 1000
ARCH = 'resnet18'
LR = 1e-4
# Load model from PyTorch.
model = models.__dict__[ARCH]()
inf = model.fc.in_features
# Set fully connected layer to train with 1000 classes.
model.fc = nn.Linear(inf, NUM_CLASSES)
model.to(device)
# Setup optimizer and loss function. 
optimizer = torch.optim.Adam(model.parameters(), LR)
criterion = nn.CrossEntropyLoss().to(device)

In [9]:
PRINT_FREQ = BATCH_SIZE - 1 
global_step = 0

In [10]:
scaler = GradScaler()
# Define train step.
def train(train_loader, model, criterion, optimizer, epoch):
    global global_step
    # Keep progress of metrics.
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # Measure data loading time.
        data_time.update(time.time() - end)
        
        # Move data to device
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
            
        # Use automatic mixed precision (AMP) to increase training speed.
        with autocast("cuda"):
            # Compute output.
            output = model(images)
            loss = criterion(output, target)

        # Measure accuracy and record loss.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        
        # Write tensorboard logs.
        writer.add_scalar("Loss/train", loss, global_step=global_step)
        writer.add_scalar("Acc1/train", top1.avg, global_step=global_step)
        writer.add_scalar("Acc5/train", top5.avg, global_step=global_step)
        global_step += 1

        # Compute gradient.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % PRINT_FREQ == 0:
            progress.display(i)

In [11]:
# Define validation step.
def validate(val_loader, model, criterion):
    global global_step
    # Keep progress of metrics.
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # Switch to evaluate mode.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            
            # Move data to GPU if CUDA device is available.
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # Compute output.
            output = model(images)
            loss = criterion(output, target)

            # Measure accuracy and record loss.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
            

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % PRINT_FREQ == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    
     # Write tensorboard logs.
    writer.add_scalar("Loss/val", loss, global_step=global_step)
    writer.add_scalar("Acc1/val", top1.avg, global_step=global_step)
    writer.add_scalar("Acc5/val", top5.avg, global_step=global_step)
    
    return top1.avg

In [12]:
start_time = time.time()

for epoch in range(1):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    validate(val_loader, model, criterion)
    
execution_time = (time.time() - start_time)
print(f"Execution time in seconds: {execution_time}")

Epoch: [0][    0/20019]	Time  1.476 ( 1.476)	Data  0.485 ( 0.485)	Loss 6.3934e+00 (6.3934e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
Epoch: [0][   63/20019]	Time  0.258 ( 0.073)	Data  0.239 ( 0.037)	Loss 7.4560e+00 (4.8045e+00)	Acc@1   0.00 ( 25.32)	Acc@5   0.00 ( 66.48)
Epoch: [0][  126/20019]	Time  0.019 ( 0.065)	Data  0.000 ( 0.037)	Loss 7.1254e+00 (5.0536e+00)	Acc@1   0.00 ( 13.71)	Acc@5   0.00 ( 55.59)
Epoch: [0][  189/20019]	Time  0.242 ( 0.064)	Data  0.221 ( 0.040)	Loss 5.7836e+00 (5.1099e+00)	Acc@1   0.00 ( 10.86)	Acc@5   0.00 ( 43.69)
Epoch: [0][  252/20019]	Time  0.020 ( 0.061)	Data  0.000 ( 0.038)	Loss 5.8351e+00 (5.0707e+00)	Acc@1   0.00 ( 11.12)	Acc@5   0.00 ( 38.40)
Epoch: [0][  315/20019]	Time  0.219 ( 0.061)	Data  0.200 ( 0.039)	Loss 5.5262e+00 (5.2194e+00)	Acc@1   0.00 (  9.48)	Acc@5   0.00 ( 32.44)
Epoch: [0][  378/20019]	Time  0.016 ( 0.060)	Data  0.000 ( 0.039)	Loss 4.1673e+00 (5.2944e+00)	Acc@1   0.00 (  8.09)	Acc@5   0.00 ( 27.92)
Epoch: [0][  441/20019]	Tim

Epoch: [0][ 3717/20019]	Time  0.016 ( 0.060)	Data  0.000 ( 0.041)	Loss 8.7703e+00 (7.8797e+00)	Acc@1   0.00 (  0.94)	Acc@5   0.00 (  3.24)
Epoch: [0][ 3780/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 9.3325e+00 (7.8957e+00)	Acc@1   0.00 (  0.93)	Acc@5   0.00 (  3.19)
Epoch: [0][ 3843/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 8.2479e+00 (7.9102e+00)	Acc@1   0.00 (  0.91)	Acc@5   0.00 (  3.14)
Epoch: [0][ 3906/20019]	Time  0.138 ( 0.060)	Data  0.118 ( 0.041)	Loss 9.1726e+00 (7.9277e+00)	Acc@1   0.00 (  0.90)	Acc@5   0.00 (  3.09)
Epoch: [0][ 3969/20019]	Time  0.020 ( 0.060)	Data  0.000 ( 0.041)	Loss 9.1560e+00 (7.9427e+00)	Acc@1   0.00 (  0.88)	Acc@5   0.00 (  3.04)
Epoch: [0][ 4032/20019]	Time  0.180 ( 0.060)	Data  0.162 ( 0.041)	Loss 9.2007e+00 (7.9536e+00)	Acc@1   0.00 (  0.87)	Acc@5   0.00 (  2.99)
Epoch: [0][ 4095/20019]	Time  0.020 ( 0.060)	Data  0.000 ( 0.040)	Loss 8.5403e+00 (7.9631e+00)	Acc@1   0.00 (  0.86)	Acc@5   0.00 (  2.94)
Epoch: [0][ 4158/20019]	Tim

Epoch: [0][ 7434/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.040)	Loss 7.1094e+00 (7.6436e+00)	Acc@1   0.00 (  0.47)	Acc@5   0.00 (  1.62)
Epoch: [0][ 7497/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.040)	Loss 7.1680e+00 (7.6391e+00)	Acc@1   0.00 (  0.47)	Acc@5   0.00 (  1.61)
Epoch: [0][ 7560/20019]	Time  0.016 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.1289e+00 (7.6349e+00)	Acc@1   0.00 (  0.46)	Acc@5   0.00 (  1.60)
Epoch: [0][ 7623/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.1250e+00 (7.6304e+00)	Acc@1   0.00 (  0.46)	Acc@5   0.00 (  1.58)
Epoch: [0][ 7686/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.0586e+00 (7.6258e+00)	Acc@1   0.00 (  0.46)	Acc@5   0.00 (  1.57)
Epoch: [0][ 7749/20019]	Time  0.147 ( 0.060)	Data  0.129 ( 0.041)	Loss 7.1523e+00 (7.6216e+00)	Acc@1   0.00 (  0.45)	Acc@5   0.00 (  1.56)
Epoch: [0][ 7812/20019]	Time  0.101 ( 0.060)	Data  0.082 ( 0.041)	Loss 7.1289e+00 (7.6172e+00)	Acc@1   0.00 (  0.45)	Acc@5   0.00 (  1.54)
Epoch: [0][ 7875/20019]	Tim

Epoch: [0][11151/20019]	Time  0.020 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.2263e+00 (7.4871e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.08)
Epoch: [0][11214/20019]	Time  0.220 ( 0.060)	Data  0.200 ( 0.041)	Loss 7.3452e+00 (7.4857e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.08)
Epoch: [0][11277/20019]	Time  0.020 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.4189e+00 (7.4848e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.07)
Epoch: [0][11340/20019]	Time  0.024 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.5909e+00 (7.4839e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.06)
Epoch: [0][11403/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.3828e+00 (7.4832e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.06)
Epoch: [0][11466/20019]	Time  0.017 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.2859e+00 (7.4820e+00)	Acc@1   0.00 (  0.31)	Acc@5   0.00 (  1.05)
Epoch: [0][11529/20019]	Time  0.019 ( 0.060)	Data  0.000 ( 0.041)	Loss 7.2728e+00 (7.4812e+00)	Acc@1   0.00 (  0.30)	Acc@5   0.00 (  1.05)
Epoch: [0][11592/20019]	Tim

Epoch: [0][14868/20019]	Time  0.020 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.4531e+00 (7.4484e+00)	Acc@1   0.00 (  0.24)	Acc@5   0.00 (  0.81)
Epoch: [0][14931/20019]	Time  0.016 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.4375e+00 (7.4482e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.81)
Epoch: [0][14994/20019]	Time  0.153 ( 0.059)	Data  0.138 ( 0.040)	Loss 7.4727e+00 (7.4481e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.80)
Epoch: [0][15057/20019]	Time  0.056 ( 0.059)	Data  0.040 ( 0.040)	Loss 7.3945e+00 (7.4478e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.80)
Epoch: [0][15120/20019]	Time  0.020 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.3594e+00 (7.4476e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.80)
Epoch: [0][15183/20019]	Time  0.213 ( 0.059)	Data  0.194 ( 0.040)	Loss 7.4883e+00 (7.4475e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.79)
Epoch: [0][15246/20019]	Time  0.021 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.3750e+00 (7.4473e+00)	Acc@1   0.00 (  0.23)	Acc@5   0.00 (  0.79)
Epoch: [0][15309/20019]	Tim

Epoch: [0][18585/20019]	Time  0.252 ( 0.059)	Data  0.233 ( 0.040)	Loss 7.5312e+00 (7.4505e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.65)
Epoch: [0][18648/20019]	Time  0.019 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.4844e+00 (7.4507e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.65)
Epoch: [0][18711/20019]	Time  0.257 ( 0.059)	Data  0.238 ( 0.040)	Loss 7.6016e+00 (7.4510e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.64)
Epoch: [0][18774/20019]	Time  0.150 ( 0.059)	Data  0.133 ( 0.040)	Loss 7.4961e+00 (7.4513e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.64)
Epoch: [0][18837/20019]	Time  0.271 ( 0.059)	Data  0.252 ( 0.040)	Loss 7.5430e+00 (7.4515e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.64)
Epoch: [0][18900/20019]	Time  0.020 ( 0.059)	Data  0.000 ( 0.040)	Loss 7.4844e+00 (7.4518e+00)	Acc@1   0.00 (  0.19)	Acc@5   0.00 (  0.64)
Epoch: [0][18963/20019]	Time  0.158 ( 0.059)	Data  0.138 ( 0.040)	Loss 7.5602e+00 (7.4522e+00)	Acc@1   0.00 (  0.18)	Acc@5   0.00 (  0.64)
Epoch: [0][19026/20019]	Tim