In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

from train import *
import utils
from model import CLR

In [2]:
# Shared settings
feature_dim = 128   # passed to CLR(feature_dim) when the model is built
batch_size = 512    # used by every DataLoader below
epochs = 10         # used for BOTH the contrastive stage and the classifier stage

# Seed torch's global RNG (this also seeds every CUDA device) so weight init and
# DataLoader shuffling are repeatable across Restart & Run All
# (modulo nondeterministic cuDNN kernels).
seed = 42
torch.manual_seed(seed)

# CLR (contrastive) training
CLR_lr = 1e-3
CLR_wd = 1e-6
temperature = 0.2   # forwarded to train_epoch together with utils.NTXentLoss

# Classifier training
Class_lr = 1e-3
Class_wd = 1e-4

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# NOTE: absolute, machine-specific dataset location; change data_root to run elsewhere.
data_root = '/datasets/cv_datasets/data'
num_workers = 2
# Pinned host memory only speeds up host->GPU copies; it is pointless on a CPU-only run.
pin_memory = device.type == 'cuda'

# Contrastive pre-training set: CIFAR10Pair with the training augmentations
# (presumably yields two augmented views per image). drop_last=True keeps every
# batch at exactly batch_size.
CLR_train_data = utils.CIFAR10Pair(root=data_root, train=True, transform=utils.train_transform, download=True)
CLR_train_loader = DataLoader(CLR_train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

# Classifier training set: the train split, but with test_transform.
Class_train_data = utils.CIFAR10(root=data_root, train=True, transform=utils.test_transform, download=True)
Class_train_loader = DataLoader(Class_train_data, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, pin_memory=pin_memory)

# Held-out test split; not shuffled.
test_data = utils.CIFAR10(root=data_root, train=False, transform=utils.test_transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)

In [5]:
model = CLR(feature_dim).to(device)

# Stage 1 (contrastive): Adam over every model parameter.
CLR_optimizer = optim.Adam(model.parameters(), lr=CLR_lr, weight_decay=CLR_wd)
# NOTE(review): NTXentLoss is referenced, not called/instantiated — confirm it is a
# plain function (train_epoch receives it alongside `temperature`), not an nn.Module class.
CLR_loss_fn = utils.NTXentLoss

# Stage 2 (classification): SGD with momentum over model.fc parameters only.
Class_optimizer = optim.SGD(model.fc.parameters(), lr=Class_lr, momentum=0.9, weight_decay=Class_wd)
Class_loss_fn = nn.CrossEntropyLoss()

In [None]:
# Stage 1: contrastive training of the encoder, one call to train_epoch per epoch.
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(model, CLR_train_loader, CLR_optimizer, CLR_loss_fn, temperature, device)
    print(f'Epoch: {epoch}, Loss: {train_loss}')

# t-SNE plot over the test loader (presumably of the learned embeddings).
print("plotting")
utils.plot_tsne(model, test_loader, device)

Training: 100%|██████████| 97/97 [01:27<00:00,  1.11it/s]


Epoch: 1, Loss: 5.599375090648219


Training: 100%|██████████| 97/97 [01:24<00:00,  1.14it/s]


Epoch: 2, Loss: 5.076444984711323


Training: 100%|██████████| 97/97 [01:25<00:00,  1.13it/s]


Epoch: 3, Loss: 4.859454110725639


Training: 100%|██████████| 97/97 [01:27<00:00,  1.11it/s]


Epoch: 4, Loss: 4.713647970219248


Training: 100%|██████████| 97/97 [01:27<00:00,  1.11it/s]


Epoch: 5, Loss: 4.644251149954255


Training: 100%|██████████| 97/97 [01:27<00:00,  1.11it/s]


Epoch: 6, Loss: 4.571394581155679


Training: 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]


Epoch: 7, Loss: 4.519970308874071


Training: 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]


Epoch: 8, Loss: 4.50058686856142


Training: 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]


Epoch: 9, Loss: 4.471167903585532


Training: 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]

Epoch: 10, Loss: 4.435387409839434
plotting





In [None]:
# Stage 2: freeze the encoder so no gradients are computed for it; only the
# parameters handed to Class_optimizer (model.fc) are updated.
# NOTE(review): if train_classifier calls model.train(), any BatchNorm layers in the
# frozen encoder will still update their running stats — confirm this is intended.
for param in model.encoder.parameters():
    param.requires_grad_(False)

for epoch in range(1, epochs + 1):
    class_loss, class_acc = train_classifier(model, Class_train_loader, Class_optimizer, Class_loss_fn, device)
    print(f'Classifier - Epoch: {epoch}, Loss: {class_loss}, Accuracy: {class_acc}')
    test_loss, test_acc = test_epoch(model, test_loader, Class_loss_fn, device)
    print(f'Test - Epoch: {epoch}, Loss: {test_loss}, Accuracy: {test_acc}')