In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

from train import *
import utils
from model import CLR

from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

In [2]:
# ---- Shared settings -------------------------------------------------------
feature_dim = 128   # second argument to CLR(...) when the model is built
batch_size = 1024   # used by all three DataLoaders
epochs = 20         # used for BOTH the contrastive loop and the classifier loop

# CLR train
CLR_lr = 0.6        # SGD learning rate; the LinearLR warmup ramps up to this value
CLR_min_lr = 1e-3   # eta_min of the cosine annealing stage
CLR_momentum = 0.9
CLR_wd = 1e-6       # SGD weight decay
temperature = 0.4   # forwarded to train_epoch together with the NT-Xent loss

# Classify train
Class_lr = 1e-3     # SGD learning rate for model.classifier parameters
Class_wd = 1e-4     # SGD weight decay for model.classifier parameters

# "CIFAR10" selects the CIFAR-10 datasets; any other value takes the MNIST
# branch of the data-loading cell. (Previously misspelled "MINIST", which only
# worked because of that fall-through.)
dataset = "MNIST"

In [3]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Select dataset wrappers and the input channel count.
# Only the exact string "CIFAR10" takes the first branch; every other value of
# `dataset` falls through to MNIST.
# In both branches the *Pair dataset gets the train transform, while the
# classifier-train and test datasets get the test transform.
# NOTE(review): the *Pair wrappers presumably return two augmented views per
# sample -- confirm in utils.
if dataset == "CIFAR10":
    in_ch = 3
    data_root = '/datasets/cv_datasets/data'
    CLR_train_data = utils.CIFAR10Pair(
        root=data_root, train=True, transform=utils.CIFAR10_train_transform, download=True
    )
    Class_train_data = utils.CIFAR10(
        root=data_root, train=True, transform=utils.CIFAR10_test_transform, download=True
    )
    test_data = utils.CIFAR10(
        root=data_root, train=False, transform=utils.CIFAR10_test_transform, download=True
    )
else:
    in_ch = 1
    data_root = './data'
    CLR_train_data = utils.MNISTPair(
        root=data_root, train=True, transform=utils.MNIST_train_transform, download=True
    )
    Class_train_data = utils.MNIST(
        root=data_root, train=True, transform=utils.MNIST_test_transform, download=True
    )
    test_data = utils.MNIST(
        root=data_root, train=False, transform=utils.MNIST_test_transform, download=True
    )

# Keyword arguments shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Only the contrastive loader drops the final incomplete batch.
CLR_train_loader = DataLoader(CLR_train_data, shuffle=True, drop_last=True, **loader_kwargs)
Class_train_loader = DataLoader(Class_train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

In [5]:
# Model and loss functions for the two training phases.
model = CLR(in_ch, feature_dim).to(device)
CLR_loss_fn = utils.NTXentLoss
Class_loss_fn = nn.CrossEntropyLoss()

# One SGD optimizer per phase: the contrastive one sees every model parameter,
# the classifier one only the parameters of model.classifier.
CLR_optimizer = optim.SGD(model.parameters(), lr=CLR_lr, momentum=CLR_momentum, weight_decay=CLR_wd)
Class_optimizer = optim.SGD(
    model.classifier.parameters(), lr=Class_lr, momentum=0.9, weight_decay=Class_wd
)

# Schedule lengths are counted in batches of CLR_train_loader: 5 epochs of
# linear warmup, then cosine annealing over the remaining epochs.
# NOTE(review): this presumes train_epoch steps the scheduler once per batch --
# confirm in train.py.
steps_per_epoch = len(CLR_train_loader)
warmup_steps = 5 * steps_per_epoch
cosine_steps = (epochs - 5) * steps_per_epoch

warmup = LinearLR(CLR_optimizer, start_factor=1e-4, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(CLR_optimizer, T_max=cosine_steps, eta_min=CLR_min_lr)

# Hand over from warmup to cosine once the warmup steps have elapsed.
scheduler = SequentialLR(CLR_optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

In [6]:
# Train encoder
# Runs `epochs` passes of train_epoch over the contrastive loader and prints
# the value it returns after each pass.
print("----------Train------------")
epoch_report = 'Epoch: {}, Loss: {}'
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(
        model,
        CLR_train_loader,
        CLR_optimizer,
        CLR_loss_fn,
        temperature,
        device,
        scheduler,
    )
    print(epoch_report.format(epoch, train_loss))

----------Train------------
Epoch: 1, Loss: 6.276703916746994
Epoch: 2, Loss: 5.674489473474437
Epoch: 3, Loss: 5.40928107294543
Epoch: 4, Loss: 5.259128586999301
Epoch: 5, Loss: 5.131481326859573




Epoch: 6, Loss: 5.040282117909398
Epoch: 7, Loss: 4.971941446435863
Epoch: 8, Loss: 4.936141696469537
Epoch: 9, Loss: 4.903209982247188
Epoch: 10, Loss: 4.87939661124657
Epoch: 11, Loss: 4.859683513641357
Epoch: 12, Loss: 4.848694308050748
Epoch: 13, Loss: 4.830592821384299
Epoch: 14, Loss: 4.817731577774574
Epoch: 15, Loss: 4.809844715841885
Epoch: 16, Loss: 4.803453675631819
Epoch: 17, Loss: 4.798894725996872
Epoch: 18, Loss: 4.787950910370926
Epoch: 19, Loss: 4.7876068230332995
Epoch: 20, Loss: 4.785810972082204


In [7]:
# Freeze every encoder parameter so that no gradients are computed for them
# from here on.
for param in model.encoder.parameters():
    param.requires_grad_(False)

# Train classifier
# Each epoch: one pass of train_classifier over the training loader, then one
# pass of test_epoch over the test loader, printing the returned loss and
# accuracy for both.
print("------Train Classifier------------")
train_report = 'Classifier - Epoch: {}, Loss: {}, Accuracy: {}'
test_report = 'Test - Epoch: {}, Loss: {}, Accuracy: {}'
for epoch in range(1, epochs + 1):
    class_loss, class_acc = train_classifier(
        model, Class_train_loader, Class_optimizer, Class_loss_fn, device
    )
    print(train_report.format(epoch, class_loss, class_acc))

    test_loss, test_acc = test_epoch(model, test_loader, Class_loss_fn, device)
    print(test_report.format(epoch, test_loss, test_acc))

# Plot a t-SNE view using the test loader once training has finished.
utils.plot_tsne(model, test_loader, device)

------Train Classifier------------
Classifier - Epoch: 1, Loss: 0.7018762108564377, Accuracy: 83.07666666666667
Test - Epoch: 1, Loss: 0.22243525574207307, Accuracy: 95.15
Classifier - Epoch: 2, Loss: 0.1990634534597397, Accuracy: 95.57833333333333
Test - Epoch: 2, Loss: 0.1575004495024681, Accuracy: 96.58
Classifier - Epoch: 3, Loss: 0.15882994296948116, Accuracy: 96.25833333333334
Test - Epoch: 3, Loss: 0.1326299097776413, Accuracy: 96.99
Classifier - Epoch: 4, Loss: 0.13904957903226217, Accuracy: 96.66666666666667
Test - Epoch: 4, Loss: 0.1184900804400444, Accuracy: 97.2
Classifier - Epoch: 5, Loss: 0.12611327072381973, Accuracy: 96.87833333333333
Test - Epoch: 5, Loss: 0.10897275111675263, Accuracy: 97.26
Classifier - Epoch: 6, Loss: 0.11680334549744924, Accuracy: 97.09333333333333
Test - Epoch: 6, Loss: 0.10182454489469528, Accuracy: 97.35000000000001
Classifier - Epoch: 7, Loss: 0.11040541930596033, Accuracy: 97.09333333333333
Test - Epoch: 7, Loss: 0.09694757107496262, Accuracy: