In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from models import *
from trainers import *
from eval_tools.lossfuncs import NTXentLoss
from data_utils import *

In [2]:
# --- Configuration -----------------------------------------------------------
# Seed torch's global RNG *before* building the model so weight initialisation
# and torch-based random augmentations repeat across Restart & Run All.
# NOTE(review): if data_utils draws randomness from `random`/`numpy`, seed those too.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU rather than crashing when no GPU is visible (training will be slow).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Values shared by several objects below, defined once so they cannot drift apart.
BATCH_SIZE = 128   # passed to both the loader and NTXentLoss; presumably the loss
                   # sizes its positive/negative mask from it, so the two must agree
NUM_EPOCHS = 80    # cosine-annealing horizon; should equal the number of training epochs

# --- Model & data --------------------------------------------------------------
model = ResNetSimCLR(base_model='resnet10', out_dim=128, from_small=True)
dataloaders, dataset_sizes = data_loader('cifar100',
                                         './data/cifar100',
                                         rep_augment='simclr',
                                         batch_size=BATCH_SIZE)

# --- Optimisation --------------------------------------------------------------
# weight_decay is 1e-5. NOTE(review): the SimCLR paper uses 1e-6 — confirm which is intended.
optimizer = optim.SGD(model.parameters(), lr=0.3, momentum=0.9, weight_decay=1e-5)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

# NT-Xent contrastive loss over cosine similarities with temperature 0.5.
criterion = NTXentLoss(device,
                       batch_size=BATCH_SIZE,
                       temperature=0.5,
                       use_cosine_similarity=True)

Feature extractor: resnet10
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Wire the model, data, contrastive loss and optimisation schedule into the
# SimCLR training loop (arguments are positional, in the trainer's own order).
trainer = SimCLRTrainer(
    model, dataloaders, dataset_sizes,
    criterion, optimizer, scheduler,
    device,
)

In [4]:
# Separate CIFAR-100 loaders for linear-probe evaluation.
# rep_augment=None presumably disables the SimCLR two-view augmentation — confirm in data_utils.
probe_loaders, probe_sizes = data_loader(
    'cifar100',
    './data/cifar100',
    rep_augment=None,
    batch_size=128,
)

# Bundle handed to trainer.train(...) for the periodic linear probe.
probe_setup = dict(
    dataloaders=probe_loaders,
    dataset_sizes=probe_sizes,
    num_classes=100,  # CIFAR-100 label count
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Train for exactly the cosine-annealing horizon (scheduler.T_max) so the epoch
# count and the LR schedule cannot drift apart; a linear probe runs every
# `probe_freq` epochs using the un-augmented probe loaders.
trainer.train(num_epochs=scheduler.T_max, probe_freq=5, probe_setup=probe_setup)

Train start on device: TITAN RTX

[Epoch 1/80] Elapsed 56.343s/it
[Train] Loss - 10.6922, Loss - 10.69 Learning Rate - 0.299884
[Valid] Loss - 10.6304, Loss - 10.63 
[Epoch 2/80] Elapsed 64.575s/it
[Train] Loss - 10.4531, Loss - 10.45 Learning Rate - 0.299538
[Valid] Loss - 10.3227, Loss - 10.32 
[Epoch 3/80] Elapsed 52.839s/it
[Train] Loss - 10.1052, Loss - 10.11 Learning Rate - 0.298960
[Valid] Loss - 9.8763, Loss - 9.88 
[Epoch 4/80] Elapsed 52.619s/it
[Train] Loss - 9.7869, Loss - 9.79 Learning Rate - 0.298153
[Valid] Loss - 9.7355, Loss - 9.74 
[Epoch 5/80] Elapsed 53.638s/it
[Train] Loss - 9.6262, Loss - 9.63 Learning Rate - 0.297118
[Valid] Loss - 9.7876, Loss - 9.79 
[Linear Probe Epoch 1/20]
[Train] Loss - 2231.5042, Top1 Acc - 3.11% 
[Valid] Loss - 2060.8778, Top1 Acc - 4.00% 
[Valid] Top5 Acc - 13.12% 
---------------------------------------------
[Linear Probe Epoch 2/20]
[Train] Loss - 1930.8866, Top1 Acc - 3.96% 
[Valid] Loss - 1308.3941, Top1 Acc - 5.50% 
[Valid] Top5 Ac