In [1]:
import os
from collections import defaultdict

import helper
import numpy as np
import torch
from torchvision.models import resnet50
from tqdm import tqdm

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
WEIGHT_PATH = "weights"
RECORD_PATH = "records"
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(RECORD_PATH, exist_ok=True)

In [2]:
trainset = helper.SimCLRDataset("birds-400/train")
validset = helper.SimCLRDataset("birds-400/valid")
testset = helper.SimCLRDataset("birds-400/test")

model = resnet50()
model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 512), torch.nn.ReLU(), torch.nn.Linear(512, 128)
)

model = model.to(device)

trainloader = torch.utils.data.DataLoader(trainset, 32, True, pin_memory=True)
validloader = torch.utils.data.DataLoader(validset, 32, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, 32, pin_memory=True)


optimizer = torch.optim.SGD(model.parameters(), 1e-3, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader) * 50)
criterion = helper.NTXentLoss()

In [3]:
history = []
for i in range(1, 101):
    print(f"Epoch {i}/100")
    metric = defaultdict(list)
    pbar = tqdm(total=len(trainloader))

    model.train()
    for i, (inputs_1, inputs_2) in enumerate(trainloader, 1):
        inputs_1 = inputs_1.to(device)
        inputs_2 = inputs_2.to(device)

        outputs_1 = model(inputs_1)
        outputs_2 = model(inputs_2)
        loss = criterion(outputs_1, outputs_2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        metric["train_loss"].append(loss.item())
        pbar.set_postfix({k: np.mean(v) for k, v in metric.items()})
        pbar.update()

    model.eval()
    for i, (inputs_1, inputs_2) in enumerate(validloader, 1):
        inputs_1 = inputs_1.to(device)
        inputs_2 = inputs_2.to(device)

        with torch.inference_mode():
            outputs_1 = model(inputs_1)
            outputs_2 = model(inputs_2)
            loss = criterion(outputs_1, outputs_2)

        metric["valid_loss"].append(loss.item())
        pbar.set_postfix({k: np.mean(v) for k, v in metric.items()})

    history.append({k: np.mean(v) for k, v in metric.items()})
    pbar.close()
torch.save(model.state_dict(), os.path.join(WEIGHT_PATH, "r50_simclr_pretrain.pt"))
torch.save(history, os.path.join(RECORD_PATH, "r50_simclr_pretrain.pt"))

Epoch 1/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=2.55, valid_loss=2.12]


Epoch 2/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=1.69, valid_loss=1.87]


Epoch 3/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=1.4, valid_loss=1.67]


Epoch 4/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=1.25, valid_loss=1.61]


Epoch 5/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=1.13, valid_loss=1.51]


Epoch 6/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=1.04, valid_loss=1.49]


Epoch 7/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.952, valid_loss=1.31]


Epoch 8/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.889, valid_loss=1.31]


Epoch 9/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.844, valid_loss=1.25]


Epoch 10/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.802, valid_loss=1.21]


Epoch 11/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.756, valid_loss=1.17]


Epoch 12/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.732, valid_loss=1.11]


Epoch 13/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.707, valid_loss=1.11] 


Epoch 14/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.667, valid_loss=1.13]


Epoch 15/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.657, valid_loss=1.06] 


Epoch 16/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.642, valid_loss=1.06]


Epoch 17/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.61, valid_loss=1.01] 


Epoch 18/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.604, valid_loss=0.944]


Epoch 19/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.583, valid_loss=0.991]


Epoch 20/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.565, valid_loss=0.991]


Epoch 21/100


100%|██████████| 1825/1825 [16:23<00:00,  1.86it/s, train_loss=0.556, valid_loss=0.956]


Epoch 22/100


100%|██████████| 1825/1825 [16:23<00:00,  1.85it/s, train_loss=0.55, valid_loss=0.935]


Epoch 23/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.536, valid_loss=0.963]


Epoch 24/100


100%|██████████| 1825/1825 [16:23<00:00,  1.86it/s, train_loss=0.53, valid_loss=0.933]


Epoch 25/100


100%|██████████| 1825/1825 [16:22<00:00,  1.86it/s, train_loss=0.521, valid_loss=0.907]


Epoch 26/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.51, valid_loss=0.878]


Epoch 27/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.494, valid_loss=0.835]


Epoch 28/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.483, valid_loss=0.866]


Epoch 29/100


100%|██████████| 1825/1825 [16:23<00:00,  1.86it/s, train_loss=0.481, valid_loss=0.858]


Epoch 30/100


100%|██████████| 1825/1825 [16:22<00:00,  1.86it/s, train_loss=0.476, valid_loss=0.863]


Epoch 31/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.469, valid_loss=0.836]


Epoch 32/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.468, valid_loss=0.854]


Epoch 33/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.46, valid_loss=0.829]


Epoch 34/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.458, valid_loss=0.777]


Epoch 35/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.44, valid_loss=0.824]


Epoch 36/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.441, valid_loss=0.792]


Epoch 37/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.435, valid_loss=0.812]


Epoch 38/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.436, valid_loss=0.812]


Epoch 39/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.434, valid_loss=0.798]


Epoch 40/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.428, valid_loss=0.782]


Epoch 41/100


100%|██████████| 1825/1825 [16:23<00:00,  1.85it/s, train_loss=0.42, valid_loss=0.805]


Epoch 42/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.424, valid_loss=0.788]


Epoch 43/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.419, valid_loss=0.793]


Epoch 44/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.416, valid_loss=0.763]


Epoch 45/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.421, valid_loss=0.798]


Epoch 46/100


100%|██████████| 1825/1825 [16:29<00:00,  1.84it/s, train_loss=0.421, valid_loss=0.771]


Epoch 47/100


100%|██████████| 1825/1825 [16:28<00:00,  1.85it/s, train_loss=0.412, valid_loss=0.744]


Epoch 48/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.411, valid_loss=0.786]


Epoch 49/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.413, valid_loss=0.767]


Epoch 50/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.417, valid_loss=0.765]


Epoch 51/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.41, valid_loss=0.771]


Epoch 52/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.417, valid_loss=0.748]


Epoch 53/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.419, valid_loss=0.741]


Epoch 54/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.418, valid_loss=0.761]


Epoch 55/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.421, valid_loss=0.769]


Epoch 56/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.413, valid_loss=0.81] 


Epoch 57/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.414, valid_loss=0.792]


Epoch 58/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.414, valid_loss=0.738]


Epoch 59/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.412, valid_loss=0.759]


Epoch 60/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.417, valid_loss=0.764]


Epoch 61/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.422, valid_loss=0.769]


Epoch 62/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.416, valid_loss=0.776]


Epoch 63/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.422, valid_loss=0.746]


Epoch 64/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.418, valid_loss=0.766]


Epoch 65/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.425, valid_loss=0.762]


Epoch 66/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.426, valid_loss=0.788]


Epoch 67/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.421, valid_loss=0.798]


Epoch 68/100


100%|██████████| 1825/1825 [16:23<00:00,  1.86it/s, train_loss=0.422, valid_loss=0.812]


Epoch 69/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.424, valid_loss=0.781]


Epoch 70/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.417, valid_loss=0.807]


Epoch 71/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.425, valid_loss=0.72] 


Epoch 72/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.426, valid_loss=0.793]


Epoch 73/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.429, valid_loss=0.771]


Epoch 74/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.42, valid_loss=0.81] 


Epoch 75/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.426, valid_loss=0.796]


Epoch 76/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.422, valid_loss=0.891]


Epoch 77/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.414, valid_loss=0.765]


Epoch 78/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.423, valid_loss=0.792]


Epoch 79/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.421, valid_loss=1]    


Epoch 80/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.42, valid_loss=0.8]  


Epoch 81/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.419, valid_loss=0.773]


Epoch 82/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.418, valid_loss=0.823]


Epoch 83/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.409, valid_loss=0.813]


Epoch 84/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.409, valid_loss=0.873]


Epoch 85/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.402, valid_loss=0.789]


Epoch 86/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.401, valid_loss=0.901]


Epoch 87/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.398, valid_loss=0.754]


Epoch 88/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.399, valid_loss=0.753]


Epoch 89/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.392, valid_loss=0.769]


Epoch 90/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.394, valid_loss=0.825]


Epoch 91/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.388, valid_loss=0.835]


Epoch 92/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.386, valid_loss=0.798]


Epoch 93/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.384, valid_loss=0.796]


Epoch 94/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.38, valid_loss=0.743]


Epoch 95/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.371, valid_loss=0.749]


Epoch 96/100


100%|██████████| 1825/1825 [16:24<00:00,  1.85it/s, train_loss=0.373, valid_loss=0.781]


Epoch 97/100


100%|██████████| 1825/1825 [16:26<00:00,  1.85it/s, train_loss=0.363, valid_loss=0.771]


Epoch 98/100


100%|██████████| 1825/1825 [16:27<00:00,  1.85it/s, train_loss=0.369, valid_loss=0.713]


Epoch 99/100


100%|██████████| 1825/1825 [16:25<00:00,  1.85it/s, train_loss=0.356, valid_loss=0.883]


Epoch 100/100


100%|██████████| 1825/1825 [16:23<00:00,  1.86it/s, train_loss=0.353, valid_loss=0.728]
