In [1]:
import os
from collections import defaultdict

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50
from tqdm import tqdm

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
WEIGHT_PATH = "weights"
RECORD_PATH = "records"
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(RECORD_PATH, exist_ok=True)

In [2]:
transform_train = T.Compose(
    [
        T.RandomResizedCrop((224, 224)),
        T.RandAugment(),
        T.ToTensor(),
    ]
)

transform_eval = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
    ]
)

trainset = ImageFolder("birds-400/train", transform=transform_train)
validset = ImageFolder("birds-400/valid", transform=transform_eval)
testset = ImageFolder("birds-400/test", transform=transform_eval)

model = resnet50()
model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 512), torch.nn.ReLU(), torch.nn.Linear(512, 128)
)

model.load_state_dict(torch.load(os.path.join(WEIGHT_PATH, "r50_simclr_pretrain.pt")))

for p in model.parameters():
    p.requires_grad = False

model.fc = torch.nn.Linear(2048, len(trainset.class_to_idx))

model = model.to(device)

trainloader = torch.utils.data.DataLoader(trainset, 32, True, pin_memory=True)
validloader = torch.utils.data.DataLoader(validset, 32, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, 32, pin_memory=True)


optimizer = torch.optim.SGD(model.parameters(), 5e-3, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader) * 50)
criterion = torch.nn.CrossEntropyLoss()

In [3]:
history = []
for i in range(1, 51):
    print(f"Epoch {i}/50")
    metric = defaultdict(list)
    pbar = tqdm(total=len(trainloader))

    model.train()
    for i, (inputs, targets) in enumerate(trainloader, 1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        metric["train_loss"].append(loss.item())
        metric["train_accuracy"].append(
            (outputs.argmax(-1) == targets).float().mean().item()
        )
        pbar.set_postfix({k: np.mean(v) for k, v in metric.items()})
        pbar.update()

    model.eval()
    for i, (inputs, targets) in enumerate(validloader, 1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.inference_mode():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        metric["valid_loss"].append(loss.item())
        metric["valid_accuracy"].append(
            (outputs.argmax(-1) == targets).float().mean().item()
        )
        pbar.set_postfix({k: np.mean(v) for k, v in metric.items()})

    history.append({k: np.mean(v) for k, v in metric.items()})
    pbar.close()
torch.save(model.state_dict(), os.path.join(WEIGHT_PATH, "r50_simclr_finetune.pt"))
torch.save(history, os.path.join(RECORD_PATH, "r50_simclr_finetune.pt"))

Epoch 1/50


100%|██████████| 1825/1825 [03:34<00:00,  8.49it/s, train_loss=4.91, train_accuracy=0.112, valid_loss=4.12, valid_accuracy=0.21]  


Epoch 2/50


100%|██████████| 1825/1825 [03:35<00:00,  8.49it/s, train_loss=4.4, train_accuracy=0.169, valid_loss=3.9, valid_accuracy=0.222]  


Epoch 3/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=4.24, train_accuracy=0.19, valid_loss=3.67, valid_accuracy=0.256]


Epoch 4/50


100%|██████████| 1825/1825 [03:34<00:00,  8.51it/s, train_loss=4.13, train_accuracy=0.205, valid_loss=3.45, valid_accuracy=0.283]


Epoch 5/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=4.04, train_accuracy=0.217, valid_loss=3.4, valid_accuracy=0.291] 


Epoch 6/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=3.95, train_accuracy=0.231, valid_loss=3.23, valid_accuracy=0.312] 


Epoch 7/50


100%|██████████| 1825/1825 [03:34<00:00,  8.51it/s, train_loss=3.9, train_accuracy=0.237, valid_loss=3.31, valid_accuracy=0.305]


Epoch 8/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=3.83, train_accuracy=0.246, valid_loss=3.18, valid_accuracy=0.334]


Epoch 9/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=3.77, train_accuracy=0.252, valid_loss=3.17, valid_accuracy=0.329]


Epoch 10/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=3.69, train_accuracy=0.26, valid_loss=3.03, valid_accuracy=0.339]


Epoch 11/50


100%|██████████| 1825/1825 [03:35<00:00,  8.45it/s, train_loss=3.65, train_accuracy=0.266, valid_loss=3.08, valid_accuracy=0.344] 


Epoch 12/50


100%|██████████| 1825/1825 [03:36<00:00,  8.44it/s, train_loss=3.59, train_accuracy=0.272, valid_loss=2.99, valid_accuracy=0.359]


Epoch 13/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=3.54, train_accuracy=0.28, valid_loss=2.81, valid_accuracy=0.381]


Epoch 14/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=3.5, train_accuracy=0.284, valid_loss=2.95, valid_accuracy=0.36] 


Epoch 15/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=3.43, train_accuracy=0.292, valid_loss=2.93, valid_accuracy=0.359]


Epoch 16/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=3.37, train_accuracy=0.297, valid_loss=2.78, valid_accuracy=0.39] 


Epoch 17/50


100%|██████████| 1825/1825 [03:34<00:00,  8.51it/s, train_loss=3.32, train_accuracy=0.304, valid_loss=2.72, valid_accuracy=0.393] 


Epoch 18/50


100%|██████████| 1825/1825 [03:34<00:00,  8.51it/s, train_loss=3.28, train_accuracy=0.312, valid_loss=2.77, valid_accuracy=0.388]


Epoch 19/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=3.24, train_accuracy=0.313, valid_loss=2.61, valid_accuracy=0.404]


Epoch 20/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=3.18, train_accuracy=0.324, valid_loss=2.57, valid_accuracy=0.409]


Epoch 21/50


100%|██████████| 1825/1825 [03:34<00:00,  8.49it/s, train_loss=3.13, train_accuracy=0.329, valid_loss=2.6, valid_accuracy=0.42]  


Epoch 22/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=3.1, train_accuracy=0.333, valid_loss=2.54, valid_accuracy=0.425]


Epoch 23/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=3.05, train_accuracy=0.342, valid_loss=2.51, valid_accuracy=0.419]


Epoch 24/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=3, train_accuracy=0.346, valid_loss=2.47, valid_accuracy=0.443]


Epoch 25/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=2.97, train_accuracy=0.353, valid_loss=2.38, valid_accuracy=0.441]


Epoch 26/50


100%|██████████| 1825/1825 [03:35<00:00,  8.49it/s, train_loss=2.91, train_accuracy=0.36, valid_loss=2.35, valid_accuracy=0.453]


Epoch 27/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=2.89, train_accuracy=0.366, valid_loss=2.33, valid_accuracy=0.462]


Epoch 28/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=2.86, train_accuracy=0.37, valid_loss=2.28, valid_accuracy=0.466]


Epoch 29/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=2.82, train_accuracy=0.376, valid_loss=2.26, valid_accuracy=0.471]


Epoch 30/50


100%|██████████| 1825/1825 [03:34<00:00,  8.49it/s, train_loss=2.79, train_accuracy=0.379, valid_loss=2.26, valid_accuracy=0.465]


Epoch 31/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=2.75, train_accuracy=0.387, valid_loss=2.21, valid_accuracy=0.474]


Epoch 32/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=2.71, train_accuracy=0.398, valid_loss=2.18, valid_accuracy=0.494]


Epoch 33/50


100%|██████████| 1825/1825 [03:34<00:00,  8.49it/s, train_loss=2.71, train_accuracy=0.4, valid_loss=2.18, valid_accuracy=0.496]


Epoch 34/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=2.68, train_accuracy=0.405, valid_loss=2.13, valid_accuracy=0.478]


Epoch 35/50


100%|██████████| 1825/1825 [03:34<00:00,  8.49it/s, train_loss=2.65, train_accuracy=0.411, valid_loss=2.13, valid_accuracy=0.498]


Epoch 36/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=2.62, train_accuracy=0.421, valid_loss=2.11, valid_accuracy=0.505]


Epoch 37/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=2.61, train_accuracy=0.419, valid_loss=2.1, valid_accuracy=0.512] 


Epoch 38/50


100%|██████████| 1825/1825 [03:35<00:00,  8.49it/s, train_loss=2.59, train_accuracy=0.424, valid_loss=2.09, valid_accuracy=0.505]


Epoch 39/50


100%|██████████| 1825/1825 [03:33<00:00,  8.53it/s, train_loss=2.56, train_accuracy=0.433, valid_loss=2.06, valid_accuracy=0.517]


Epoch 40/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=2.54, train_accuracy=0.436, valid_loss=2.06, valid_accuracy=0.519]


Epoch 41/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=2.53, train_accuracy=0.44, valid_loss=2.04, valid_accuracy=0.526]


Epoch 42/50


100%|██████████| 1825/1825 [03:35<00:00,  8.49it/s, train_loss=2.51, train_accuracy=0.442, valid_loss=2.04, valid_accuracy=0.525]


Epoch 43/50


100%|██████████| 1825/1825 [03:35<00:00,  8.49it/s, train_loss=2.5, train_accuracy=0.448, valid_loss=2.02, valid_accuracy=0.538]


Epoch 44/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=2.49, train_accuracy=0.45, valid_loss=2, valid_accuracy=0.541]   


Epoch 45/50


100%|██████████| 1825/1825 [03:35<00:00,  8.46it/s, train_loss=2.48, train_accuracy=0.454, valid_loss=2.03, valid_accuracy=0.532]


Epoch 46/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=2.47, train_accuracy=0.456, valid_loss=2.01, valid_accuracy=0.542]


Epoch 47/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=2.46, train_accuracy=0.457, valid_loss=2.01, valid_accuracy=0.538]


Epoch 48/50


100%|██████████| 1825/1825 [03:34<00:00,  8.50it/s, train_loss=2.46, train_accuracy=0.456, valid_loss=2, valid_accuracy=0.541]   


Epoch 49/50


100%|██████████| 1825/1825 [03:35<00:00,  8.48it/s, train_loss=2.45, train_accuracy=0.46, valid_loss=2, valid_accuracy=0.545]   


Epoch 50/50


100%|██████████| 1825/1825 [03:35<00:00,  8.47it/s, train_loss=2.46, train_accuracy=0.46, valid_loss=2, valid_accuracy=0.542]   
