In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

from train import *
import utils
from model import Self_Supervised
from torchvision.datasets import CIFAR10, MNIST

In [2]:
# --- Configuration ---------------------------------------------------------
feature_dim = 128   # second constructor argument of Self_Supervised (presumably the embedding size)
batch_size = 256    # shared by all three DataLoaders
epochs = 50         # used for BOTH the reconstruction stage and the classifier stage

# Reconstruction train
Rec_lr = 0.0001
betas = (0.9, 0.99)  # shared by the reconstruction AdamW and the classifier Adam

# Classify train
Class_lr = 0.0005
Class_wd = 1e-6

dataset = "CIFAR10"  # supported: "CIFAR10" or "MNIST"

# Seed torch's RNG (on all devices) so that weight initialisation, DataLoader
# shuffling and any random transforms drawing from torch's RNG are reproducible
# under Restart Kernel -> Run All. Note: cuDNN kernels may still be nondeterministic.
seed = 42
torch.manual_seed(seed)

In [3]:
# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Build three views of the selected dataset:
#   train_data       - train split with the utils.*_train_transform (used for encoder pre-training;
#                      presumably the augmenting transform - confirm in utils)
#   Class_train_data - train split with the plain utils.*_transform (used for classifier training)
#   test_data        - test split with the plain utils.*_transform
# An unknown `dataset` value now fails loudly instead of silently falling back to MNIST.
if dataset == "CIFAR10":
    in_ch = 3
    dataset_cls = CIFAR10
    data_root = '/datasets/cv_datasets/data'
    train_transform, eval_transform = utils.CIFAR10_train_transform, utils.CIFAR10_transform
elif dataset == "MNIST":
    in_ch = 1
    dataset_cls = MNIST
    data_root = './data'
    train_transform, eval_transform = utils.MNIST_train_transform, utils.MNIST_transform
else:
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'CIFAR10' or 'MNIST'")

train_data = dataset_cls(root=data_root, train=True, transform=train_transform, download=True)
Class_train_data = dataset_cls(root=data_root, train=True, transform=eval_transform, download=True)
test_data = dataset_cls(root=data_root, train=False, transform=eval_transform, download=True)

In [5]:
# All loaders share batch size and worker/pinning settings; they differ only in
# shuffling and whether the last incomplete batch is dropped.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)
Class_train_loader = DataLoader(Class_train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

model = Self_Supervised(in_ch, feature_dim).to(device)

# Stage 1 (reconstruction): AdamW over every model parameter, cosine-annealed
# over `epochs` scheduler steps, with an MSE objective.
Rec_optimizer = optim.AdamW(model.parameters(), lr=Rec_lr, betas=betas, weight_decay=1e-4)
Rec_scheduler = optim.lr_scheduler.CosineAnnealingLR(Rec_optimizer, T_max=epochs)
Rec_loss_fn = nn.MSELoss()

# Stage 2 (classification): Adam over the classifier head's parameters only,
# with a cross-entropy objective.
Class_optimizer = optim.Adam(model.classifier.parameters(), lr=Class_lr, betas=betas, weight_decay=Class_wd)
Class_loss_fn = nn.CrossEntropyLoss()

In [None]:
# Stage 1: train the model on the reconstruction objective, stepping the
# cosine LR schedule once at the end of every epoch.
print("----------Train------------")
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(model, train_loader, Rec_optimizer, Rec_loss_fn, device)
    print(f'Epoch: {epoch}, Loss: {train_loss}')
    Rec_scheduler.step()  # one step per epoch, matching T_max=epochs

# Visualise what the trained model produces on the held-out split
# (presumably a t-SNE of its embeddings - see utils.plot_tsne).
print("plotting")
utils.plot_tsne(model, test_loader, device)

----------Train------------


In [None]:
# Stage 2: freeze every encoder parameter so that only the classifier head
# (the sole parameters held by Class_optimizer) is trained.
# NOTE(review): requires_grad=False does not stop BatchNorm running-stat updates
# or dropout if train_classifier puts the model in train mode - confirm in train.py.
model.encoder.requires_grad_(False)

# Train classifier, reporting train and test metrics after every epoch.
print("------Train Classifier------------")
for epoch in range(1, epochs + 1):
    class_loss, class_acc = train_classifier(model, Class_train_loader, Class_optimizer, Class_loss_fn, device)
    print(f'Classifier - Epoch: {epoch}, Loss: {class_loss}, Accuracy: {class_acc}')
    test_loss, test_acc = test_epoch(model, test_loader, Class_loss_fn, device)
    print(f'Test - Epoch: {epoch}, Loss: {test_loss}, Accuracy: {test_acc}')