# SimSiam Model

In [None]:
import os
import torch
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision import transforms, datasets
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import random
from torch.utils.data import Subset
from torch.cuda.amp import autocast, GradScaler

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dataset layout: four training shards (train.X1 .. train.X4) and one
# validation folder, all located under DATA_ROOT.
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
TRAIN_FOLDERS = [f"train.X{shard}" for shard in (1, 2, 3, 4)]
VAL_FOLDER = "val.X"

# Loader and optimisation hyper-parameters.
BATCH_SIZE = 64
NUM_WORKERS = 2
EPOCHS = 50
BASE_LR = 0.05
WEIGHT_DECAY = 1e-4
MOMENTUM = 0.9

# Per-epoch checkpoints are written here; make sure the directory exists.
CHECKPOINT_DIR = "/kaggle/working/"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


In [None]:
class ProjectionMLP(nn.Module):
    """Two-layer projection head: Linear-BN-ReLU followed by Linear-BN.

    Neither linear layer has a bias, because each is immediately followed by
    a BatchNorm1d whose mean subtraction and learnable shift make a bias
    redundant.

    Args:
        in_dim: Width of the incoming feature vector.
        hidden_dim: Width of the intermediate layer.
        out_dim: Width of the projected embedding.
    """

    def __init__(self, in_dim=512, hidden_dim=512, out_dim=512):
        super().__init__()
        hidden_stage = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.layer1 = nn.Sequential(*hidden_stage)

        output_stage = [
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim),
        ]
        self.layer2 = nn.Sequential(*output_stage)

    def forward(self, x):
        """Project a (batch, in_dim) feature matrix to (batch, out_dim)."""
        return self.layer2(self.layer1(x))
    
class PredictionMLP(nn.Module):
    """Bottleneck predictor head: Linear-BN-ReLU down to ``hidden_dim``, then Linear up.

    Args:
        in_dim: Width of the incoming projection.
        hidden_dim: Width of the bottleneck layer.
        out_dim: Width of the prediction.
    """

    def __init__(self, in_dim=512, hidden_dim=256, out_dim=512):
        super().__init__()
        bottleneck = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.layer1 = nn.Sequential(*bottleneck)
        # The output layer keeps its bias and is not followed by BatchNorm.
        self.layer2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map a (batch, in_dim) matrix through the bottleneck to (batch, out_dim)."""
        hidden = self.layer1(x)
        return self.layer2(hidden)


# SimSiam with a ResNet-18 backbone

class SimSiam(nn.Module):
    """SimSiam network: backbone encoder, projection MLP and prediction MLP.

    Args:
        backbone: Name of the torchvision backbone; only ``"resnet18"`` is
            supported.
        pretrained_backbone: Forwarded as ``pretrained=`` to
            ``models.resnet18``.

    Raises:
        NotImplementedError: if ``backbone`` is not ``"resnet18"``.
    """

    def __init__(self, backbone="resnet18", pretrained_backbone=False):
        super().__init__()
        if backbone != "resnet18":
            raise NotImplementedError("Only resnet18 is supported here.")

        resnet = models.resnet18(pretrained=pretrained_backbone)
        # Reuse every child except the last one (the ``fc`` classifier); the
        # classifier's input width is the feature width of the encoder.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        feat_dim = resnet.fc.in_features

        self.projector = ProjectionMLP(
            in_dim=feat_dim, hidden_dim=feat_dim, out_dim=feat_dim
        )
        self.predictor = PredictionMLP(
            in_dim=feat_dim, hidden_dim=256, out_dim=feat_dim
        )

    def forward_backbone(self, x):
        """Encode images and flatten everything after the batch dimension."""
        return torch.flatten(self.encoder(x), start_dim=1)

    def forward(self, view1, view2):
        """Return ``(p1, p2, z1, z2)`` for two augmented views.

        ``p*`` are predictor outputs; ``z*`` are projector outputs returned
        detached, so they act as stop-gradient targets.
        """
        features = [self.forward_backbone(v) for v in (view1, view2)]
        projections = [self.projector(f) for f in features]
        predictions = [self.predictor(z) for z in projections]
        return (
            predictions[0],
            predictions[1],
            projections[0].detach(),
            projections[1].detach(),
        )

def simsiam_loss(p, z):
    """Negative cosine similarity between predictions and stop-gradient targets.

    Computes ``-mean_i cos(p_i, stopgrad(z_i))`` over the batch dimension.

    Args:
        p: Predictor outputs, shape (batch, dim). Gradients flow through these.
        z: Projector outputs used as targets, same shape as ``p``. They are
            detached here, so the stop-gradient holds even if the caller
            forgets to detach them.

    Returns:
        Scalar float32 tensor in [-1, 1]; -1 means every pair is perfectly
        aligned.
    """
    # The stop-gradient on the target branch is essential to SimSiam; detach()
    # is a no-op for an already-detached tensor, so enforcing it here is free.
    z = z.detach()
    # Compute the similarity in float32 explicitly rather than relying on
    # autocast's op lists; this also covers half-precision inputs passed in
    # outside an autocast region.
    p = F.normalize(p.float(), dim=1)
    z = F.normalize(z.float(), dim=1)
    return -(p * z).sum(dim=1).mean()


## Data Transformations

In [None]:
# Augmentation applied independently to produce each of the two views.
# Follows the SimSiam / MoCo v2 recipe. Without colour distortion, two crops of
# one image share colour statistics, and the network can match views by colour
# alone instead of learning semantic features.
simsiam_transform = transforms.Compose([
    # scale=(0.2, 1.0) is the SimSiam/MoCo v2 setting; torchvision's default
    # (0.08, 1.0) yields very small crops at a 96 px output size.
    transforms.RandomResizedCrop(96, scale=(0.2, 1.0)),
    transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
    ),
    transforms.RandomGrayscale(p=0.2),
    # Kernel roughly 10% of the 96 px crop, rounded to the odd size that
    # GaussianBlur requires.
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))], p=0.5
    ),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    # ImageNet channel mean / std.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


class SimSiamDataset(Dataset):
    """Wrap a labelled dataset so each item yields two augmented views.

    The label from ``base_dataset`` is discarded. ``transform`` is called
    twice per item, so a random transform produces two different views of
    the same image.

    Args:
        base_dataset: Indexable dataset whose items unpack as
            ``(image, label)``.
        transform: Callable applied to the image to produce each view.
    """

    def __init__(self, base_dataset, transform):
        self.base = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, _label = self.base[idx]
        first_view, second_view = (self.transform(image) for _ in range(2))
        return first_view, second_view


# One ImageFolder per training shard. transform=None keeps the decoded images
# untouched so SimSiamDataset can apply its random augmentation twice.
train_datasets = [
    datasets.ImageFolder(os.path.join(DATA_ROOT, folder), transform=None)
    for folder in TRAIN_FOLDERS
]

combined_train = ConcatDataset(train_datasets)
total_samples = len(combined_train)

# Fixed random subset for pretraining. A dedicated RNG seeded with 0 draws
# exactly the indices that random.seed(0) + random.sample would, without
# resetting the interpreter-wide random state. The size is capped so a
# smaller dataset uses all of its samples instead of raising ValueError.
all_indices = list(range(total_samples))
subsample_size = min(50_000, total_samples)
subsample_indices = random.Random(0).sample(all_indices, subsample_size)

subsampled_train = Subset(combined_train, subsample_indices)

simsiam_train = SimSiamDataset(subsampled_train, simsiam_transform)

# drop_last=True keeps every batch at BATCH_SIZE. This avoids a trailing batch
# of size 1, which BatchNorm1d in the MLP heads rejects in training mode.
train_loader = DataLoader(
    simsiam_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True
)


# Deterministic evaluation preprocessing: resize the shorter side to 128 px,
# take the central 96x96 crop, then normalise with ImageNet channel mean/std.
val_transform = transforms.Compose(
    [
        transforms.Resize(128),
        transforms.CenterCrop(96),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

val_dataset = datasets.ImageFolder(
    root=os.path.join(DATA_ROOT, VAL_FOLDER), transform=val_transform
)

# Evaluation order is fixed (no shuffling) and the last partial batch is kept.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)


## Model, Optimizer, and LR Schedule

In [None]:
# ResNet-18 SimSiam trained from scratch (no ImageNet weights).
model = SimSiam(backbone="resnet18", pretrained_backbone=False)
model = model.to(DEVICE)

# SGD with momentum and weight decay over every parameter (encoder, projector
# and predictor alike).
# NOTE(review): the SimSiam recipe scales the LR linearly as
# BASE_LR * BATCH_SIZE / 256 (0.0125 for batch 64); BASE_LR is used unscaled
# here. Confirm that is intentional.
optimizer = SGD(
    model.parameters(),
    lr=BASE_LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

# Cosine decay from BASE_LR down to 0 over EPOCHS scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)


In [None]:
# Module-level mixed-precision scaler; the training-loop cell creates a fresh
# one and passes it to train_one_epoch explicitly.
scaler = GradScaler()


def train_one_epoch(model, loader, optimizer, scaler, device=None):
    """Run one SimSiam pretraining epoch and return the mean per-sample loss.

    Args:
        model: Module whose ``forward(view1, view2)`` returns
            ``(p1, p2, z1, z2)``.
        loader: Iterable of ``(view1, view2)`` tensor batches.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: ``GradScaler`` used for mixed-precision loss scaling.
        device: Where each batch is moved; defaults to the module-level
            ``DEVICE``.

    Returns:
        Sample-weighted average of the symmetric SimSiam loss over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches (for example, a dataset
            smaller than ``batch_size`` with ``drop_last=True``). Previously
            this surfaced as a bare ZeroDivisionError.
    """
    if device is None:
        device = DEVICE
    model.train()
    total_loss = 0.0
    num_samples = 0

    for view1, view2 in tqdm(loader, desc="Pretrain SimSiam"):
        view1 = view1.to(device, non_blocking=True)
        view2 = view2.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast():
            p1, p2, z1, z2 = model(view1, view2)
            # Symmetric loss: each view's prediction is matched against the
            # other view's projection target.
            loss = 0.5 * (simsiam_loss(p1, z2) + simsiam_loss(p2, z1))

        # Backprop the scaled loss. step() unscales the gradients and skips the
        # update if they contain inf/NaN; update() then adjusts the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_sz = view1.size(0)
        total_loss += loss.item() * batch_sz
        num_samples += batch_sz

    if num_samples == 0:
        raise ValueError(
            "train_one_epoch: loader yielded no batches "
            "(dataset smaller than batch_size with drop_last=True?)"
        )
    return total_loss / num_samples


## Training Loop

In [None]:
best_loss = float("inf")
# Fresh loss scaler for this run; its state is checkpointed each epoch.
scaler = GradScaler()

for epoch in range(1, EPOCHS + 1):
    epoch_loss = train_one_epoch(model, train_loader, optimizer, scaler)
    # The cosine schedule (T_max=EPOCHS) is advanced once per epoch.
    scheduler.step()

    print(f"Epoch [{epoch}/{EPOCHS}]  Pretrain Loss = {epoch_loss:.4f}")

    # Save scheduler and scaler state too, so a resumed run continues with the
    # correct learning rate and loss scale rather than restarting both.
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"simsiam_r18_epoch{epoch}.pth")
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "scaler_state": scaler.state_dict(),
        "loss": epoch_loss,
    }, ckpt_path)

    # Track the lowest loss for reporting only. It is not a model-selection
    # criterion for SimSiam: a collapsed (constant) representation reaches the
    # minimum possible loss of -1.
    if epoch_loss < best_loss:
        best_loss = epoch_loss

print(f"Lowest pretrain loss: {best_loss:.4f}")
print("→ Pretraining complete.")
