In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

from train import *
import utils
from model import CLR

from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

In [2]:
# ---- Shared settings ----
feature_dim = 128   # second argument to CLR(in_ch, feature_dim)
batch_size = 1024   # batch size for all three DataLoaders
epochs = 100        # epoch count for both the contrastive and the classifier loops
seed = 42           # RNG seed for reproducible runs

# CLR train (contrastive pre-training, SGD + warm-up/cosine schedule)
CLR_lr = 0.6
CLR_min_lr = 1e-3
CLR_momentum = 0.9
CLR_wd = 1e-6
temperature = 0.4

# Classify train (classifier-head training)
Class_lr = 1e-3
Class_wd = 1e-4

dataset = "CIFAR10"

# Seed torch's generator (all devices) so weight init and DataLoader shuffling
# repeat across runs. NOTE(review): this does not force deterministic cuDNN
# kernels; full GPU determinism would need torch.backends settings as well.
torch.manual_seed(seed)

In [3]:
# Use the GPU when torch can see one; fall back to the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [4]:
# Build the datasets and loaders for the selected `dataset`:
#   CLR_train_data   - pair dataset used for contrastive pre-training
#                      (presumably two augmented views per image - confirm in utils)
#   Class_train_data - training split with the test transform, for the classifier
#   test_data        - held-out split with the test transform
# NOTE(review): cifar_root is an absolute, machine-specific path; change it
# when running outside that environment.
cifar_root = '/datasets/cv_datasets/data'
mnist_root = './data'

if dataset == "CIFAR10":
    in_ch = 3
    CLR_train_data = utils.CIFAR10Pair(root=cifar_root, train=True,
                                       transform=utils.CIFAR10_train_transform, download=True)
    Class_train_data = utils.CIFAR10(root=cifar_root, train=True, transform=utils.CIFAR10_test_transform,
                                     download=True)
    test_data = utils.CIFAR10(root=cifar_root, train=False, transform=utils.CIFAR10_test_transform,
                              download=True)
elif dataset == "MNIST":
    in_ch = 1
    CLR_train_data = utils.MNISTPair(root=mnist_root, train=True,
                                     transform=utils.MNIST_train_transform, download=True)
    Class_train_data = utils.MNIST(root=mnist_root, train=True, transform=utils.MNIST_test_transform,
                                   download=True)
    test_data = utils.MNIST(root=mnist_root, train=False, transform=utils.MNIST_test_transform,
                            download=True)
else:
    # Fail loudly rather than silently training on MNIST after a typo.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'CIFAR10' or 'MNIST'")

# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
# drop_last keeps every contrastive batch full-sized.
# NOTE(review): presumably required by the NT-Xent loss / per-step LR schedule - confirm.
CLR_train_loader = DataLoader(CLR_train_data, shuffle=True, drop_last=True, **loader_kwargs)
Class_train_loader = DataLoader(Class_train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

100%|██████████| 170M/170M [00:16<00:00, 10.6MB/s]


In [5]:
# Model plus separate optimizers for the two training stages.
model = CLR(in_ch, feature_dim).to(device)

# Stage 1 (contrastive pre-training): SGD over every model parameter.
CLR_optimizer = torch.optim.SGD(model.parameters(), CLR_lr, momentum=CLR_momentum, weight_decay=CLR_wd)
CLR_loss_fn = utils.NTXentLoss

# Stage 2 (classifier training): only model.classifier's parameters are optimized.
Class_optimizer = torch.optim.SGD(model.classifier.parameters(), lr=Class_lr, momentum=0.9, weight_decay=Class_wd)
Class_loss_fn = nn.CrossEntropyLoss()

# LR schedule measured in scheduler steps (one per CLR batch, assuming
# train_epoch calls scheduler.step() per batch - TODO confirm): linear warm-up
# over `warmup_epochs`, then cosine decay to CLR_min_lr for the remaining epochs.
warmup_epochs = 5
steps_per_epoch = len(CLR_train_loader)
warmup_steps = warmup_epochs * steps_per_epoch
# CosineAnnealingLR divides / takes a modulo by T_max, so clamp it to >= 1;
# otherwise epochs <= warmup_epochs would break the schedule.
cosine_steps = max(1, (epochs - warmup_epochs) * steps_per_epoch)

warmup = LinearLR(CLR_optimizer, start_factor=1e-4, end_factor=1.0, total_iters=warmup_steps)

cosine = CosineAnnealingLR(CLR_optimizer, T_max=cosine_steps, eta_min=CLR_min_lr)

scheduler = SequentialLR(CLR_optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

In [6]:
# Stage 1: contrastive pre-training. train_epoch receives the LR scheduler and
# returns a loss value that is logged once per epoch.
print("----------Train------------")
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(
        model,
        CLR_train_loader,
        CLR_optimizer,
        CLR_loss_fn,
        temperature,
        device,
        scheduler,
    )
    print(f'Epoch: {epoch}, Loss: {train_loss}')

----------Train------------
Epoch: 1, Loss: 6.32696000734965
Epoch: 2, Loss: 5.863157729307811
Epoch: 3, Loss: 5.741210212310155
Epoch: 4, Loss: 5.669933785994847
Epoch: 5, Loss: 5.571065763632457




Epoch: 6, Loss: 5.504906624555588
Epoch: 7, Loss: 5.454383124907811
Epoch: 8, Loss: 5.410984883705775
Epoch: 9, Loss: 5.386831631263097
Epoch: 10, Loss: 5.3558773795763654
Epoch: 11, Loss: 5.337905645370483
Epoch: 12, Loss: 5.321505566438039
Epoch: 13, Loss: 5.3079849779605865
Epoch: 14, Loss: 5.298394550879796
Epoch: 15, Loss: 5.284863283236821
Epoch: 16, Loss: 5.273321598768234
Epoch: 17, Loss: 5.26464365919431
Epoch: 18, Loss: 5.2531003455321
Epoch: 19, Loss: 5.242337177197139
Epoch: 20, Loss: 5.232055306434631
Epoch: 21, Loss: 5.2245893478393555
Epoch: 22, Loss: 5.218439191579819
Epoch: 23, Loss: 5.215170582135518
Epoch: 24, Loss: 5.2068106432755785
Epoch: 25, Loss: 5.201913644870122
Epoch: 26, Loss: 5.1985509395599365
Epoch: 27, Loss: 5.1916087965170545
Epoch: 28, Loss: 5.183945139249166
Epoch: 29, Loss: 5.184819589058558
Epoch: 30, Loss: 5.175801048676173
Epoch: 31, Loss: 5.170028517643611
Epoch: 32, Loss: 5.169393986463547
Epoch: 33, Loss: 5.164134939511617
Epoch: 34, Loss: 5.15

In [10]:
# Stage 2: freeze the encoder's weights so only the classifier head is
# updated, then train it and report test-set metrics after every epoch.
# NOTE(review): requires_grad=False stops weight updates, but BatchNorm
# running statistics in the encoder would still change if train_classifier
# puts the model in train mode - confirm in train.py.
model.encoder.requires_grad_(False)

# Train classifier
print("------Train Classifier------------")
for epoch in range(1, epochs + 1):
    class_loss, class_acc = train_classifier(model, Class_train_loader, Class_optimizer, Class_loss_fn, device)
    print(f'Classifier - Epoch: {epoch}, Loss: {class_loss}, Accuracy: {class_acc}')

    test_loss, test_acc = test_epoch(model, test_loader, Class_loss_fn, device)
    print(f'Test - Epoch: {epoch}, Loss: {test_loss}, Accuracy: {test_acc}')

# t-SNE plot (via utils.plot_tsne) computed on the test loader.
utils.plot_tsne(model, test_loader, device)

------Train Classifier------------
Classifier - Epoch: 1, Loss: 1.320813939037323, Accuracy: 57.550000000000004
Test - Epoch: 1, Loss: 0.8702166150093079, Accuracy: 71.41999999999999
Classifier - Epoch: 2, Loss: 0.8117879318618775, Accuracy: 72.602
Test - Epoch: 2, Loss: 0.7840889302253723, Accuracy: 73.21
Classifier - Epoch: 3, Loss: 0.753658737449646, Accuracy: 74.19200000000001
Test - Epoch: 3, Loss: 0.7475288650512696, Accuracy: 74.16
Classifier - Epoch: 4, Loss: 0.7224311487007141, Accuracy: 75.136
Test - Epoch: 4, Loss: 0.7246967570304871, Accuracy: 74.96000000000001
Classifier - Epoch: 5, Loss: 0.7025328971672058, Accuracy: 75.652
Test - Epoch: 5, Loss: 0.7098075491905212, Accuracy: 75.32
Classifier - Epoch: 6, Loss: 0.6885229446983337, Accuracy: 75.972
Test - Epoch: 6, Loss: 0.6981242444992065, Accuracy: 75.46000000000001
Classifier - Epoch: 7, Loss: 0.6770308042526245, Accuracy: 76.348
Test - Epoch: 7, Loss: 0.6883173439979553, Accuracy: 75.9
Classifier - Epoch: 8, Loss: 0.666