In [None]:
!pip install timm

In [None]:
import torch
import torchvision
import timm

from tqdm import tqdm

In [None]:
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

In [None]:
# --- Optimisation schedule ---
EPOCHS = 25
BATCH_SIZE = 32
LR = 0.01       # initial learning rate
LR_GAMMA = 0.9  # exponential decay factor: lr = LR * LR_GAMMA ** epoch

# --- Model selection ---
MODEL_NAME = "resnet18"
PRETRAINED = False
DEVICE = "cuda"

# --- Dropout settings forwarded to the model factory ---
DROP_RATE = 0.2
DROP_PATH_RATE = 0.2

In [None]:
# Values shared by the training and validation pipelines.
_IMAGE_SIZE = 224
_NORM_MEAN = [0.485, 0.456, 0.406]  # per-channel mean used for normalization
_NORM_STD = [0.229, 0.224, 0.225]   # per-channel std used for normalization

# Training pipeline: resize, random flip / shear / rescale augmentation,
# then tensor conversion and normalization.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=_IMAGE_SIZE),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), shear=10),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Validation pipeline: same resize and normalization, no random augmentation.
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=_IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

In [None]:
dataset_root = "./data"

# Training split: downloaded into dataset_root when requested, with augmentation.
train_dataset = torchvision.datasets.CIFAR10(
    root=dataset_root,
    train=True,
    transform=train_transforms,
    download=True,
)

# Evaluation split: no download flag is passed here, so it reads whatever is
# already present under dataset_root.
val_dataset = torchvision.datasets.CIFAR10(
    root=dataset_root,
    train=False,
    transform=val_transforms,
)

In [None]:
# Options common to both loaders; only the dataset and the shuffling differ.
_loader_options = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=4)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_options)

In [None]:
# Human-readable class names.
# NOTE(review): presumably ordered to match the dataset's integer labels — confirm.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Resolve the configured device. If CUDA was requested but is not available on
# this machine, fall back to the CPU instead of failing at the first .to(device).
device = torch.device(DEVICE)
if device.type == "cuda" and not torch.cuda.is_available():
    print("CUDA is not available; falling back to CPU.")
    device = torch.device("cpu")

In [None]:
# Turn on cuDNN's benchmark mode before the model is built.
torch.backends.cudnn.benchmark = True

# Keyword options for the timm factory, taken from the configuration constants.
_model_options = {
    "pretrained": PRETRAINED,
    "num_classes": len(classes),
    "drop_rate": DROP_RATE,
    "drop_path_rate": DROP_PATH_RATE,
}
model = timm.create_model(MODEL_NAME, **_model_options)
model = model.to(device)

In [None]:
# Loss function.
criterion = torch.nn.CrossEntropyLoss()

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=0.9,
    nesterov=True,
    weight_decay=1e-4,
)

# Exponential learning-rate decay driven by LR_GAMMA.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

# Gradient scaler used together with autocast in the training step.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def train(model, loader, criterion, optimizer, scaler):
    """Run one training pass over ``loader`` using autocast and a gradient scaler.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to optimise; put into training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch through ``scaler``.
        scaler: Gradient scaler used to scale the loss, step the optimizer
            and update its own scale factor.

    Returns:
        A ``(loss, accuracy)`` tuple: the batch-size-weighted average of the
        per-batch loss values, and the percentage of correctly predicted
        samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for images, labels in tqdm(loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)
        batch_len = labels.size(0)

        optimizer.zero_grad()
        # Forward pass and loss run under autocast; backward and the
        # optimizer step go through the scaler.
        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Index of the largest output along dim 1 is the predicted class.
        predicted = torch.max(outputs.data, 1)[1]

        loss_sum += loss.item() * batch_len
        n_seen += batch_len
        n_correct += (predicted == labels).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen

In [None]:
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to evaluate; put into evaluation mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        A ``(loss, accuracy)`` tuple: the batch-size-weighted average of the
        per-batch loss values, and the percentage of correctly predicted
        samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating"):
            images = images.to(device)
            labels = labels.to(device)
            batch_len = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(outputs.data, 1)[1]

            loss_sum += loss.item() * batch_len
            n_seen += batch_len
            n_correct += (predicted == labels).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen

In [None]:
# Train the model: each epoch runs one training pass and one validation pass,
# logs the metrics, then advances the learning-rate schedule.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, scaler)
    val_loss, val_acc = validate(model, val_dataloader, criterion)

    # Read the learning rate before stepping the scheduler so the log shows
    # the value that was in effect during this epoch.
    epoch_report = (
        epoch + 1,
        EPOCHS,
        train_loss,
        train_acc,
        val_loss,
        val_acc,
        lr_scheduler.get_last_lr()[0],
    )
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Val Loss: {:.4f}, Val Acc: {:.2f}%, lr: {:.6f}'
          .format(*epoch_report))

    lr_scheduler.step()