In [1]:
import argparse
import datetime
import os
import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from timm.models import create_model

from engine import train_one_epoch, evaluate
from utils import get_training_dataloader, get_test_dataloader
import models

In [5]:
# Configuration: normalization statistics, model choice and optimisation
# hyper-parameters. Edit these constants to tune the run.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'  # NOTE(review): not used by any visible cell
# Alternatives: 'vit_small_patch16_224', 'vit_base_patch16_224'
MODEL_NAME = 'vit_tiny_patch16_224'
EPOCHS = 50
LR = 0.0001
WD = 0.0001
BATCH_SIZE = 256
NUM_WORKERS = 4
SEED = 0

# Seed the RNGs this notebook relies on (data shuffling, new classifier head)
# so that Restart & Run All reproduces the same run.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Creating model: {MODEL_NAME}")
# num_classes=100 replaces the pretrained classification head for CIFAR-100.
# NOTE(review): img_size=224 assumes the utils loaders resize CIFAR images to
# 224x224 — confirm in utils.get_training_dataloader / get_test_dataloader.
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=100,
        img_size=224)

# Fall back to CPU when no GPU is visible instead of failing in model.to().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Both loaders normalise with the *training* statistics.
cifar100_training_loader = get_training_dataloader(
    CIFAR100_TRAIN_MEAN,
    CIFAR100_TRAIN_STD,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=True
)

cifar100_test_loader = get_test_dataloader(
    CIFAR100_TRAIN_MEAN,
    CIFAR100_TRAIN_STD,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

# Count only trainable parameters.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)


Creating model: vit_base_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/soroush/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


  0%|          | 0.00/330M [00:00<?, ?B/s]

Files already downloaded and verified
Files already downloaded and verified
number of params: 85875556


In [None]:
# Train for EPOCHS epochs, evaluating on the held-out test set after each one.
print(f"Start training for {EPOCHS} epochs")

# len(DataLoader) counts batches; the number of images is the dataset length.
num_test_images = len(cifar100_test_loader.dataset)

for epoch in range(1, EPOCHS + 1):
    # NOTE(review): the stats structures come from engine.py; this cell only
    # relies on evaluate() returning a mapping with an 'acc1' entry.
    train_stats = train_one_epoch(
        model, criterion, cifar100_training_loader,
        optimizer, device, epoch)
    test_stats = evaluate(cifar100_test_loader, model, criterion, device)
    print(f"Accuracy of the network on the {num_test_images} test images: {test_stats['acc1']:.1f}%")

In [None]:
# Measure end-to-end evaluation throughput (images per second) on the test set.
# The timed window covers everything evaluate() does (data loading, forward
# pass, loss/metric computation), not just the model forward pass.
# CUDA kernels launch asynchronously, so synchronise before reading the clock
# on both sides; otherwise queued GPU work can fall outside the timed window.
use_cuda_sync = torch.cuda.is_available() and str(device).startswith('cuda')
if use_cuda_sync:
    torch.cuda.synchronize(device)
start_time = time.perf_counter()  # monotonic, high-resolution interval clock
test_stats = evaluate(cifar100_test_loader, model, criterion, device)
if use_cuda_sync:
    torch.cuda.synchronize(device)
end_time = time.perf_counter()

num_samples = len(cifar100_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print(f"Throughput: {throughput:.1f} images/s")