In [1]:
!pip install lightning wandb



In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import lightning as L
import torchmetrics
import random

In [3]:
class SimilarityModel(nn.Module):
    """Small CNN that scores whether two stacked grayscale images match.

    The two images of a pair are fed as the two input channels of one tensor
    of shape ``(batch, 2, h, w)``. The input is expected to be sized so that
    three 2x2 / stride-2 convolutions leave a 3x3 map (e.g. 28x28 MNIST).
    Returns a ``(batch,)`` tensor of probabilities in (0, 1).

    Args:
        dropout_p: dropout probability applied after every convolution.
    """

    def __init__(self, dropout_p):
        super().__init__()
        stages = []
        # (in_channels, out_channels) for each conv -> dropout -> ReLU stage.
        for in_ch, out_ch in ((2, 8), (8, 16), (16, 32)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=(2, 2), stride=(2, 2)),
                nn.Dropout(dropout_p),
                nn.ReLU(),
            ]
        self.layers = nn.Sequential(*stages)
        self.output_layer = nn.Linear(32 * 3 * 3, 1)

    def forward(self, x):
        features = self.layers(x)
        flat = features.view(features.shape[0], -1)
        logits = self.output_layer(flat)
        return torch.sigmoid(logits.view(-1))


class SimilarityModelBig(nn.Module):
    """Deeper variant of the two-channel similarity CNN.

    Five conv -> dropout -> ReLU stages alternate stride 2 and stride 1. A
    28x28 input is therefore reduced to a 2x2x128 feature map before the linear
    head. Input shape is ``(batch, 2, h, w)`` with both images of a pair stacked
    as channels. The output is a ``(batch,)`` tensor of probabilities.

    Args:
        dropout_p: dropout probability applied after every convolution.
    """

    # (in_channels, out_channels, stride) for each stage, all with 2x2 kernels.
    _CONV_SPECS = (
        (2, 8, 2),
        (8, 16, 1),
        (16, 32, 2),
        (32, 64, 1),
        (64, 128, 2),
    )

    def __init__(self, dropout_p):
        super().__init__()
        stages = []
        for c_in, c_out, s in self._CONV_SPECS:
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=(2, 2), stride=(s, s)),
                nn.Dropout(dropout_p),
                nn.ReLU(),
            ))
        self.layers = nn.Sequential(*stages)
        self.output_layer = nn.Linear(128 * 2 * 2, 1)

    def forward(self, x):
        feature_map = self.layers(x)
        flat = feature_map.view(feature_map.shape[0], -1)
        return torch.sigmoid(self.output_layer(flat).view(-1))


class SiameseModel(nn.Module):
    """Weight-sharing (siamese) similarity model for image pairs.

    Each image of the pair is encoded separately by the same single-channel
    CNN. The two feature maps are concatenated and scored by a linear head.
    Input shape is ``(batch, 2, h, w)``. The encoder output must be 3x3x32 per
    image (e.g. 28x28 MNIST). Returns a ``(batch,)`` tensor of probabilities.

    Args:
        dropout_p: dropout probability applied after every convolution.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch, dropout_p):
        # One 2x2 / stride-2 conv followed by dropout and ReLU.
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=(2, 2), stride=(2, 2)),
            nn.Dropout(dropout_p),
            nn.ReLU(),
        ]

    def __init__(self, dropout_p):
        super().__init__()
        self.layers = nn.Sequential(
            *self._conv_stage(1, 8, dropout_p),
            *self._conv_stage(8, 16, dropout_p),
            *self._conv_stage(16, 32, dropout_p),
        )
        # Features from both images are concatenated, hence the factor 2.
        self.output_layer = nn.Linear(2 * 32 * 3 * 3, 1)

    def forward(self, x):
        # Encode image 0 then image 1 with the shared encoder, adding back the
        # channel dimension that indexing removed.
        encoded = [self.layers(x[:, slot].unsqueeze(1)) for slot in (0, 1)]
        merged = torch.cat(encoded, dim=1).view(x.shape[0], -1)
        return torch.sigmoid(self.output_layer(merged).view(-1))

In [4]:
class SimilarityModelTrainingModule(L.LightningModule):
    """Lightning wrapper that trains a pairwise image-similarity model.

    Consecutive samples of every batch are grouped into pairs. A pair's
    target is 1.0 when both images share a label and 0.0 otherwise.

    During validation the module also measures one-shot classification
    accuracy. Every image is paired with one reference image per digit, and
    the prediction is the digit whose reference scores highest.

    Args:
        model: module expected to map a ``(pairs, 2, h, w)`` tensor to
            ``(pairs,)`` similarity probabilities.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler attached to ``optimizer``. Lightning steps it
            once per optimizer step (see ``configure_optimizers``).
        loss_fn: loss called as ``loss_fn(yhat, targets)``, e.g. ``nn.BCELoss``.
        tensor_dict: mapping from each digit 0-9 to one reference image of
            shape ``(h, w)``.
    """

    def __init__(self, model, optimizer, scheduler, loss_fn, tensor_dict):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Stack the references into a (10, h, w) tensor. A non-persistent
        # buffer follows the module across devices but is not saved in
        # checkpoints.
        self.register_buffer(
            "tensor_dict",
            torch.stack([tensor_dict[i] for i in range(10)], dim=0),
            persistent=False,
        )

    @staticmethod
    def _make_pairs(x, y):
        """Group consecutive samples of a batch into image pairs.

        Assumes ``x`` is ``(batch, 1, h, w)`` and ``y`` is ``(batch,)``. An
        odd trailing sample is dropped so every batch can be paired.

        Returns:
            ``(pairs, targets)``. ``pairs`` has shape ``(batch // 2, 2, h, w)``.
            ``targets`` is float32: 1.0 where both labels match, else 0.0.
        """
        images = x.squeeze(1)
        n_pairs = images.shape[0] // 2
        h, w = images.shape[-2:]
        pairs = images[: 2 * n_pairs].reshape(n_pairs, 2, h, w)
        label_pairs = y[: 2 * n_pairs].view(n_pairs, 2)
        targets = (label_pairs[:, 0] == label_pairs[:, 1]).to(torch.float32)
        return pairs, targets

    def training_step(self, batch, batch_idx):
        x, y = batch
        pairs, targets = self._make_pairs(x, y)

        yhat = self.model(pairs)
        loss = self.loss_fn(yhat, targets)

        self.log("train-loss", loss)
        self.log("lr", self.scheduler.get_last_lr()[0])
        # No manual scheduler.step() here. Lightning steps it after each
        # optimizer step (interval="step" in configure_optimizers). Stepping it
        # here as well would advance it twice, and before optimizer.step().
        return loss

    def validation_step(self, batch, batch_idx):
        x, labels = batch

        # Pairwise (same / different) metrics.
        pairs, targets = self._make_pairs(x, labels)
        yhat = self.model(pairs)
        loss = self.loss_fn(yhat, targets)
        self.log("binary-classification-loss", loss, on_step=False, on_epoch=True)

        bin_class_acc = (yhat.round() == targets).float().mean()
        self.log("binary-classification-acc", bin_class_acc, on_step=False, on_epoch=True)

        # One-shot classification: score every image against each reference.
        images = x.squeeze(1)                          # (batch, h, w)
        batch_size, h, w = images.shape
        refs = self.tensor_dict.to(images.device)      # (digits, h, w)
        num_digits = refs.shape[0]
        queries = images.unsqueeze(1).expand(batch_size, num_digits, h, w)
        references = refs.unsqueeze(0).expand(batch_size, num_digits, h, w)
        # Channel 0 = query image, channel 1 = reference image.
        model_input = torch.stack([queries, references], dim=2).reshape(-1, 2, h, w)
        scores = self.model(model_input).view(batch_size, num_digits)
        preds = scores.argmax(dim=1)
        acc = (preds == labels).float().mean()
        self.log("classification-acc", acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        # interval="step": Lightning steps the scheduler once per optimizer
        # step, which is the only place it is stepped.
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step"},
        }


In [5]:
class MNISTOneShotDataset(Dataset):
    """Synthetic MNIST pair dataset that alternates positive and negative pairs.

    Consecutive items ``(2k, 2k + 1)`` are meant to be read as one pair. Use an
    unshuffled DataLoader with an even batch size. Within each group of four
    indices:

    - the first pair shares a digit label (positive);
    - the second pair starts with the anchor digit and ends with a different,
      randomly chosen digit (negative).

    The anchor digit cycles through 0-9 as the index grows.

    Args:
        dataset_size: value reported by ``__len__``.
        mnist_transform: transform applied to every MNIST image.
        root: directory where MNIST is downloaded / cached.
    """

    def __init__(self, dataset_size, mnist_transform, root="./data"):
        mnist_dataset = datasets.MNIST(root=root, download=True, train=True, transform=mnist_transform)
        # Bucket every training image by label. Every sample is kept,
        # including the first one seen for each label.
        self.digit_dict = {}
        for image, label in mnist_dataset:
            self.digit_dict.setdefault(label, []).append(image)
        self.dataset_size = dataset_size

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # The anchor digit is derived from idx alone. This reproduces the label
        # sequence the old counter-based version produced under sequential
        # access, and stays correct with random access and with multiple
        # DataLoader workers (each worker holds its own copy of the dataset).
        anchor = ((idx + 2) // 4) % 10
        if idx % 4 == 3:
            # Second element of a negative pair: any digit except the anchor.
            label = random.choice([d for d in range(10) if d != anchor])
        else:
            label = anchor
        return random.choice(self.digit_dict[label]), label

In [6]:
# Model / optimisation hyperparameters.
DROPOUT_P = 0.4
TRAIN_BATCH_SIZE = 64   # consecutive samples are paired, so keep this even
VAL_BATCH_SIZE = 128    # consecutive samples are paired, so keep this even
NUM_WORKERS = 2
LR = 0.00001
EPOCHS = 150
DATASET_SIZE = 100000   # length of MNISTOneShotDataset, when that dataset is used
LR_STEP_SIZE = 18000    # StepLR: multiply LR by LR_STEP_GAMMA every LR_STEP_SIZE scheduler steps
LR_STEP_GAMMA = 0.7
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs the notebook uses (Python `random` and torch) for reproducible runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Alternative balanced pair dataset. It relies on consecutive indices forming
# pairs, so it must be used with shuffle=False:
# train_dataset = MNISTOneShotDataset(DATASET_SIZE, mnist_transform)
train_dataset = datasets.MNIST(root="./data", download=True, train=True, transform=mnist_transform)
test_dataset = datasets.MNIST(root="./data", download=True, train=False, transform=mnist_transform)

# The training module pairs consecutive samples within each batch. Without
# shuffling, every epoch would see exactly the same pairs.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS)

# One reference image per digit (its first occurrence in the training set),
# used for one-shot classification during validation.
tensor_dict = {}
for image, label in train_dataset:
    if label not in tensor_dict:
        tensor_dict[label] = image[0]  # (1, h, w) -> (h, w)
    if len(tensor_dict) == 10:  # one reference per MNIST digit 0-9
        break

In [None]:
# Model, loss, optimizer and step-based LR schedule.
model = SimilarityModelBig(DROPOUT_P)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_SIZE,
    gamma=LR_STEP_GAMMA,
)

module = SimilarityModelTrainingModule(model, optimizer, scheduler, loss_fn, tensor_dict)

# Weights & Biases run, with the notebook's config recorded as hyperparameters.
logger = L.pytorch.loggers.WandbLogger(
    project="similarity-model-test2",
    name="v2.10-scheduler+big_model+unbalanced",
)
run_hyperparams = dict(
    dropout_p=DROPOUT_P,
    train_batch_size=TRAIN_BATCH_SIZE,
    val_batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    lr=LR,
    lr_step_size=LR_STEP_SIZE,
    lr_step_gamma=LR_STEP_GAMMA,
    epochs=EPOCHS,
    device=DEVICE,
)
logger.log_hyperparams(run_hyperparams)

trainer = L.Trainer(max_epochs=EPOCHS, accelerator=DEVICE, logger=logger)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=test_loader)

[34m[1mwandb[0m: Currently logged in as: [33mgursi26[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name    | Type               | Params
-----------------------------------------------
0 | model   | SimilarityModelBig | 44.3 K
1 | loss_fn | BCELoss            | 0     
-----------------------------------------------
44.3 K    Trainable params
0         Non-trainable params
44.3 K    Total params
0.177     Total estimated model params size (MB)
INFO:light

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]