# Task: Self-Supervised Learning (SimCLR)

## Loading Data

In [None]:
import os
from glob import glob
from PIL import Image
import torch
# Root of the mounted dataset.
base_path = "/kaggle/input/ssl-dataset/ssl_dataset"
# The training split is spread over four shard folders: train.X1 .. train.X4.
train_dirs = [os.path.join(base_path, f"train.X{shard}") for shard in (1, 2, 3, 4)]

def get_all_image_paths(dirs=None, pattern="*.JPEG"):
    """Collect image file paths from every class folder of the training shards.

    Each directory in ``dirs`` is scanned for sub-folders (one per class), and
    the files directly inside those sub-folders that match ``pattern`` are
    returned. Entries of a shard directory that are not folders are ignored.

    Args:
        dirs: Shard directories to scan. Defaults to the module-level
            ``train_dirs``.
        pattern: Glob pattern matched inside each class folder. Matching is
            case-sensitive on case-sensitive file systems.

    Returns:
        A list of paths in a deterministic (sorted) order. Any later seeded
        shuffle or subset is therefore reproducible across runs and machines.
    """
    if dirs is None:
        dirs = train_dirs
    image_paths = []
    for train_dir in dirs:
        # os.listdir and glob return entries in arbitrary order; sort both.
        for class_folder in sorted(os.listdir(train_dir)):
            class_path = os.path.join(train_dir, class_folder)
            if not os.path.isdir(class_path):
                # Skip stray files (e.g. label lists) next to the class folders.
                continue
            image_paths.extend(sorted(glob(os.path.join(class_path, pattern))))
    return image_paths

# Gather every training image once; later cells shuffle and subset this list.
image_paths = get_all_image_paths()
print("Total Training Images:", len(image_paths))


## Augmentations and Dataset

In [None]:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import default_loader
import random

# SimCLR augmentation pipeline: PIL image -> normalised 3x128x128 tensor.
# Every random step draws fresh parameters on each call, so applying the
# pipeline twice to one image yields two different "views" of it.
simclr_transform = transforms.Compose([
    # Crop straight to the training resolution instead of cropping to 96 and
    # upsampling to 128: same output size, no interpolation blur.
    transforms.RandomResizedCrop(128),
    transforms.RandomHorizontalFlip(),
    # ColorJitter() with no arguments has zero strength everywhere (a no-op).
    # Use the SimCLR colour-distortion parameterisation
    # (0.8s, 0.8s, 0.8s, 0.2s) with strength s = 0.5.
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # ImageNet channel statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class SimCLRDataset(Dataset):
    """Dataset that returns two augmented views of each image.

    Args:
        image_paths: Indexable sequence of image file paths.
        transform: Callable applied separately to build each of the two views.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Decode the file once, then call the transform once per view.
        img = default_loader(self.image_paths[idx])
        return tuple(self.transform(img) for _ in range(2))


**Define the SimCLR model and the NT-Xent loss**

In [None]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class SimCLRModel(nn.Module):
    """Backbone encoder plus MLP projection head for SimCLR pre-training.

    The backbone's ``fc`` layer is replaced by an identity, so
    ``self.encoder(x)`` returns the backbone's penultimate features (the
    representation kept for downstream use). ``forward`` returns the projected
    embeddings consumed by the contrastive loss.

    Args:
        base_model: Name of a torchvision model constructor, e.g.
            ``'resnet18'``. The model must expose a final ``fc`` linear layer
            (as the ResNet family does).
        projection_dim: Output size of the projection head.
        hidden_dim: Width of the projection head's hidden layer.
    """

    def __init__(self, base_model='resnet18', projection_dim=128, hidden_dim=512):
        super().__init__()
        # Randomly initialised backbone. ``weights=None`` is the torchvision
        # (>= 0.13) replacement for the deprecated ``pretrained=False``.
        self.encoder = models.__dict__[base_model](weights=None)
        num_ftrs = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(num_ftrs, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Return projection-head outputs for the image batch ``x``."""
        features = self.encoder(x)
        return self.projector(features)


In [None]:
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss from SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are the two views of one
    sample (the positive pair). Every other row of the combined ``2N`` batch
    is a negative for that anchor. Each anchor's softmax runs over all other
    rows, so the positive appears exactly once in the denominator.

    Args:
        z_i: Embeddings of the first views, shape ``(N, D)``.
        z_j: Embeddings of the second views, shape ``(N, D)``.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2N`` anchors.
    """
    n = z_i.shape[0]
    # Unit-normalise, so the matrix product is exactly the cosine-similarity
    # matrix. This avoids materialising a (2N, 2N, D) broadcast tensor.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    sim = z @ z.T / temperature

    # An anchor is never compared with itself. -inf drops it from the softmax
    # and gives it zero gradient.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))

    # Anchor k (first view) pairs with row k + N (second view), and vice versa.
    targets = torch.cat([
        torch.arange(n, 2 * n, device=z.device),
        torch.arange(0, n, device=z.device),
    ])
    return F.cross_entropy(sim, targets)


**Train the model and save checkpoints**

In [None]:
from tqdm import tqdm
import random 
from torch.optim.lr_scheduler import CosineAnnealingLR
import json

# ------------ Run configuration ------------
SEED = 42                 # fixes the training subset and weight initialisation
NUM_TRAIN_IMAGES = 90000  # number of shuffled images used for pre-training
BATCH_SIZE = 256
CHECKPOINT_EVERY = 10     # keep a numbered checkpoint every N epochs

# Shuffle `image_paths` in place, as before, but with a dedicated seeded RNG.
# The subset is then reproducible without touching the global `random` state.
random.Random(SEED).shuffle(image_paths)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = SimCLRDataset(image_paths[:NUM_TRAIN_IMAGES], simclr_transform)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    # NT-Xent difficulty depends on the number of negatives per anchor
    # (2 * batch - 2). Drop the smaller trailing batch to keep it constant.
    drop_last=True,
    # Page-locked host memory lets the non_blocking `.to(device)` copies below
    # run asynchronously. Only meaningful on CUDA.
    pin_memory=device.type == "cuda",
    # Keep the worker processes alive between epochs instead of re-spawning.
    persistent_workers=True,
)
if len(loader) == 0:
    # Without this the loop would run zero steps and `avg_loss_so_far` would
    # be undefined when the epoch summary is printed.
    raise RuntimeError(
        f"Need at least {BATCH_SIZE} images for one batch, got {len(dataset)}."
    )

model = SimCLRModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 100
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

loss_per_epoch = []
best_loss = float('inf')


def _save_loss_curve(losses, path='simclr_loss.json'):
    """Write the list of per-epoch average losses to ``path`` as JSON."""
    with open(path, 'w') as f:
        json.dump(losses, f)


# ------------ Training Loop with Checkpointing ------------
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    loader_tqdm = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch_idx, (x_i, x_j) in enumerate(loader_tqdm):
        x_i = x_i.to(device, non_blocking=True)
        x_j = x_j.to(device, non_blocking=True)
        z_i, z_j = model(x_i), model(x_j)
        loss = nt_xent_loss(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        avg_loss_so_far = total_loss / (batch_idx + 1)
        loader_tqdm.set_postfix(loss=f"{avg_loss_so_far:.4f}")

    print(f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {avg_loss_so_far:.4f}")
    loss_per_epoch.append(avg_loss_so_far)
    scheduler.step()

    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss_per_epoch': loss_per_epoch
    }
    # A rolling "latest" checkpoint is enough to resume. Each checkpoint holds
    # the model plus Adam state (two extra tensors per parameter), so writing
    # one per epoch for 100 epochs can exhaust the notebook's disk quota.
    # Numbered copies are therefore kept only every CHECKPOINT_EVERY epochs
    # and at the end.
    torch.save(checkpoint, 'checkpoint_latest.pt')
    if (epoch + 1) % CHECKPOINT_EVERY == 0 or epoch + 1 == epochs:
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pt')

    # Save best model
    if avg_loss_so_far < best_loss:
        best_loss = avg_loss_so_far
        torch.save(model.state_dict(), 'simclr_best_model.pt')
        print(f"Best model saved at epoch {epoch+1} with loss {best_loss:.4f}")

    # Persist the curve every epoch so a timeout or crash does not lose it.
    _save_loss_curve(loss_per_epoch)

# ------------ Save Loss Curve to JSON ------------
_save_loss_curve(loss_per_epoch)