In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import sys
sys.path.append('..')
from models.common_autoencoder_blocks import Encoder  
from models.mnist_supconv import MNISTSupCon
from trainers.supcon_trainer import SupConTrainer
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class TwoCropsTransform:
    """
    Given one PIL image, apply base_transform twice
    to create two 'views' of the same image.
    """
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        im1 = self.base_transform(x)
        im2 = self.base_transform(x)
        return im1, im2

base_transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),  # Random crop
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), shear=5),  # Random translation
    transforms.ToTensor(),
])

train_transform = TwoCropsTransform(base_transform)
val_base_transform = transforms.Compose([transforms.ToTensor()])
val_transform = TwoCropsTransform(val_base_transform)

data_path = "/datasets/cv_datasets/data"
train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=train_transform)
val_dataset   = datasets.MNIST(root=data_path, train=False, download=True, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=True)

In [None]:
for (im1, im2), labels in train_loader:
    # im1, im2: shape [B, C, H, W]
    # labels: shape [B]
    batch_size = labels.size(0)
    big_batch = torch.cat([im1, im2], dim=0)
    labels_rep = torch.cat([labels, labels], dim=0)

    print (f"big_batch: {big_batch.shape}, labels_rep: {labels_rep.shape}, batch_size: {batch_size}")
    break

big_batch: torch.Size([128, 1, 28, 28]), labels_rep: torch.Size([128]), batch_size: 64


In [4]:

model = MNISTSupCon(
    input_shape=(1,28,28),
    channels=[32,64],
    kernel_sizes=[3,3],
    strides=[2,2],
    paddings=[1,1],
    latent_dim=128,
    batch_norm_conv=True
)

In [5]:
[print(model)]

MNISTSupCon(
  (encoder): Encoder(
    (conv): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout2d(p=0.2, inplace=False)
      (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Dropout2d(p=0.2, inplace=False)
    )
    (fc): Sequential(
      (0): Linear(in_features=3136, out_features=1024, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=1024, out_features=128, bias=True)
    )
  )
)


[None]

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

trainer = SupConTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=50,
    patience=5,
    save_path='mnist_supcon.pth'  
)

In [7]:

# (5) Train
trainer.train()

Epoch [1] Train Loss: 3.7111
Epoch [1] Val   Loss: 2.7990
Epoch [2] Train Loss: 2.9478
Epoch [2] Val   Loss: 2.7230
Epoch [3] Train Loss: 2.8712
Epoch [3] Val   Loss: 2.6999
Epoch [4] Train Loss: 2.8400
Epoch [4] Val   Loss: 2.6979
Epoch [5] Train Loss: 2.8229
Epoch [5] Val   Loss: 2.6914
Epoch [6] Train Loss: 2.8108
Epoch [6] Val   Loss: 2.6918
Epoch [7] Train Loss: 2.8045
Epoch [7] Val   Loss: 2.6802
Epoch [8] Train Loss: 2.8021
Epoch [8] Val   Loss: 2.6849
Epoch [9] Train Loss: 2.7949
Epoch [9] Val   Loss: 2.6832
Epoch [10] Train Loss: 2.7969
Epoch [10] Val   Loss: 2.6814
Epoch [11] Train Loss: 2.7865
Epoch [11] Val   Loss: 2.6833
Epoch [12] Train Loss: 2.7888
Epoch [12] Val   Loss: 2.6766
Epoch [13] Train Loss: 2.7858
Epoch [13] Val   Loss: 2.6779
Epoch [14] Train Loss: 2.7860
Epoch [14] Val   Loss: 2.6758
Epoch [15] Train Loss: 2.7832
Epoch [15] Val   Loss: 2.6735
Epoch [16] Train Loss: 2.7790
Epoch [16] Val   Loss: 2.6722
Epoch [17] Train Loss: 2.7801
Epoch [17] Val   Loss: 2.668

In [8]:
#  Results after 10 epochs
eval_transform = transforms.ToTensor()
eval_dataset = datasets.MNIST(root=data_path, train=False, download=False, transform=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

acc = trainer.classify_evaluation(eval_loader, epochs=5)
print(f'Classification Accuracy: {acc:.2f}%')

Starting linear evaluation...
  [LinearEval] epoch 1/5, loss=2.3634
  [LinearEval] epoch 2/5, loss=2.1465
  [LinearEval] epoch 3/5, loss=1.9610
  [LinearEval] epoch 4/5, loss=1.7908
  [LinearEval] epoch 5/5, loss=1.6333
Linear Evaluation Accuracy: 98.49%
Classification Accuracy: 98.49%


In [12]:
trainer = SupConTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=30,
    patience=3,
    resume_path='../checkpoints/mnist_supcon.pth',
    save_path='../checkpoints/mnist_supcon_40_epochs.pth'
)
trainer.train()

TypeError: __init__() got an unexpected keyword argument 'resume_path'