In [1]:
from preprocess import Preprocess
import timm
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
from torchvision import transforms
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Per-sample conversion handed to the Preprocess pair loaders: torchvision's
# ToTensor turns a PIL image / ndarray into a float tensor. Model-specific
# resizing and normalization happen later, inside TimmSiameseNetwork.
TRANSFORM: transforms.ToTensor = transforms.ToTensor()

In [3]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss for pairs of embeddings.

    For each pair with Euclidean distance d:
      * label 1 (similar):    loss = d ** 2
      * label 0 (dissimilar): loss = max(margin - d, 0) ** 2
        (dissimilar pairs already farther apart than `margin` cost nothing)
    The per-pair losses are averaged over the batch.
    """

    def __init__(self, margin: float = 1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, x1, x2, label):
        """Return the mean contrastive loss as a 0-dim tensor.

        Args:
            x1, x2: embedding batches compared row-wise by `pairwise_distance`.
            label: one 0/1 value per pair, any numeric dtype, in any shape with
                one element per pair (e.g. (B,) or (B, 1)).
        """
        dist = nn.functional.pairwise_distance(x1, x2)
        # Align labels with the per-pair distances. A (B, 1) label would
        # otherwise broadcast against the (B,) distances into a (B, B) matrix
        # and silently average cross-pair terms. Integer labels are cast so
        # the arithmetic below stays in the distance dtype.
        label = label.to(dtype=dist.dtype).reshape(dist.shape)
        similar_term = label * torch.pow(dist, 2)
        dissimilar_term = (1 - label) * torch.pow(torch.clamp(self.margin - dist, min=0.0), 2)
        return torch.mean(similar_term + dissimilar_term)

In [4]:
class TimmSiameseNetwork(nn.Module):
    """Siamese encoder: one shared pretrained timm ViT embeds both images of a pair.

    The classifier head is removed (num_classes=0), so the backbone's output is
    used directly as the embedding. Everything is frozen except the last
    transformer block.
    """

    def __init__(self):
        super(TimmSiameseNetwork, self).__init__()
        # https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k
        self.model = timm.create_model(
            'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k',
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )
        # Freeze the whole backbone, then unfreeze only the last transformer block
        for param in self.model.parameters():
            param.requires_grad = False
        for param in self.model.blocks[-1].parameters():
            param.requires_grad = True

        # Model-specific preprocessing (resize/crop, normalization) from timm.
        data_config = timm.data.resolve_model_data_config(self.model)
        # Randomly augmenting pipeline, used only in train mode. It keeps the
        # original attribute name so existing references still work.
        self.transforms = timm.data.create_transform(**data_config, is_training=True)
        # Deterministic pipeline used in eval mode, so validation losses are
        # not perturbed by random augmentation.
        self.eval_transforms = timm.data.create_transform(**data_config, is_training=False)
        # NOTE(review): timm builds these pipelines for PIL input, but the
        # batches arriving here are already tensors (ToTensor in the dataset).
        # Confirm that the installed timm version passes tensors through its
        # to-tensor step.

    def forward_once(self, img) -> torch.Tensor:
        """Embed one batch of images using the mode-appropriate preprocessing."""
        preprocess = self.transforms if self.training else self.eval_transforms
        return self.model(preprocess(img))

    def forward(self, img1, img2):
        """Embed both sides of a pair with the shared backbone."""
        return self.forward_once(img1), self.forward_once(img2)

timm_model = TimmSiameseNetwork()

In [None]:
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 20
SEED = 42
CHECKPOINT_PATH = 'best_siamese_model.pth'

torch.manual_seed(SEED)

train_dataset = Preprocess.load_train_pairs(transform=TRANSFORM)
# TODO(review): this "validation" set is the test split. It drives LR
# scheduling and checkpoint selection, which leaks test information into model
# selection. Hold out validation pairs from the training split instead.
val_dataset = Preprocess.load_test_pairs(transform=TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The weights must live on the same device as the batches moved in train_loop.
# Move them before building the optimizer so it is bound to the on-device
# parameters.
timm_model.to(device)

criterion = ContrastiveLoss()
# Only unfrozen parameters ever receive gradients, so only they go to Adam.
trainable_params = [p for p in timm_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE)
# `verbose` is deprecated in recent PyTorch; the current LR is printed per epoch instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# Training loop
def train_loop(model, save_path=CHECKPOINT_PATH):
    """Train `model` for NUM_EPOCHS epochs, validating after each one.

    Uses the module-level train_loader, val_loader, criterion, optimizer,
    scheduler and device defined in this cell. `optimizer` must have been built
    from `model`'s parameters. The state_dict with the lowest average
    validation loss is written to `save_path`.

    Args:
        model: module mapping (img1, img2) to a pair of embeddings.
        save_path: checkpoint file for the best model.

    Returns:
        The lowest average validation loss observed.

    Raises:
        ValueError: if either loader yields no batches.
    """
    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    if n_train_batches == 0 or n_val_batches == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model.to(device)  # no-op if already there; guards against a model passed in on CPU
    best_val_loss = float('inf')

    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0.0

        for batch_idx, (img1, img2, labels) in enumerate(train_loader):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

            optimizer.zero_grad()
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # one host sync per batch
            train_loss += batch_loss
            print(f'Epoch {epoch + 1}, Batch [{batch_idx+1}/{n_train_batches}], Loss: {batch_loss:.4f}     ', end='\r')

        avg_train_loss = train_loss / n_train_batches

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for img1, img2, labels in val_loader:
                img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
                output1, output2 = model(img1, img2)
                loss = criterion(output1, output2, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / n_val_batches

        # Print epoch results (overwrites the '\r' progress line)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {current_lr:.1e}")

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)

    print("Training completed!")
    return best_val_loss

best_val_loss = train_loop(timm_model)