**Cell-1: Imports & Global Config**

In [1]:
import os
import json
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, models
from PIL import Image

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score


**Cell-2: Reproducibility (Mandatory)**

In [2]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

**Cell-3: Dataset Paths**

In [3]:
DATASET_DIR = "/kaggle/input/fishes/Fishes"
BATCH_SIZE = 32
IMG_SIZE = 160
EPOCHS = 100
TEMPERATURE = 0.5


**Cell-4: SSL Augmentation (SimCLR style)**

In [4]:
ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor()
])


**Cell-5: SimCLR Dataset (Two Views)**

In [5]:
class SimCLRDataset(Dataset):
    def __init__(self, root, transform):
        self.samples = []
        for cls in os.listdir(root):
            cls_path = os.path.join(root, cls)
            if os.path.isdir(cls_path):
                for img in os.listdir(cls_path):
                    if img.lower().endswith((".jpg",".png",".jpeg")):
                        self.samples.append(os.path.join(cls_path, img))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert("RGB")
        return self.transform(img), self.transform(img)


**Cell-6: DataLoader**

In [6]:
ssl_dataset = SimCLRDataset(DATASET_DIR, ssl_transform)
ssl_loader = DataLoader(
    ssl_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    drop_last=True
)

len(ssl_dataset)


583

**Cell-7: SimCLR Model (Encoder + Projection Head)**

In [7]:
class SimCLR(nn.Module):
    def __init__(self, base_model="resnet50", proj_dim=128):
        super().__init__()

        self.encoder = models.resnet50(weights=None)
        self.encoder.fc = nn.Identity()   # output: [B, 2048]

        self.projector = nn.Sequential(
            nn.Linear(2048, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, proj_dim)
        )

    def forward(self, x):
        h = self.encoder(x)          # [B, 2048]
        z = self.projector(h)        # [B, proj_dim]
        return F.normalize(z, dim=1)


**Cell-8: NT-Xent Loss (Contrastive Loss)**

In [8]:
def nt_xent_loss(z1, z2, temperature):
    N = z1.size(0)
    z = torch.cat([z1, z2], dim=0)
    sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)

    sim /= temperature
    labels = torch.arange(N).to(device)
    labels = torch.cat([labels + N, labels])

    mask = torch.eye(2*N, dtype=torch.bool).to(device)
    sim.masked_fill_(mask, -9e15)

    return F.cross_entropy(sim, labels)


**Cell-9: Initialize Model & Optimizer**

In [9]:
model = SimCLR().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)


**Cell-10: Checkpoint Utils (Mandatory)**

In [10]:
CKPT_PATH = "/kaggle/working/simclr_ckpt.pt"

def save_ckpt(epoch, best_loss):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_loss": best_loss
    }, CKPT_PATH)

def load_ckpt():
    if os.path.exists(CKPT_PATH):
        ckpt = torch.load(CKPT_PATH)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        return ckpt["epoch"], ckpt["best_loss"]
    return 0, float("inf")


**Cell-11: SimCLR Pretraining Loop (Label-Free)**

In [11]:
start_epoch, best_loss = load_ckpt()
print("Resuming from epoch:", start_epoch)

scaler = torch.amp.GradScaler("cuda")

for epoch in range(start_epoch, EPOCHS):
    model.train()
    epoch_loss = 0.0

    torch.cuda.empty_cache()  # once per epoch

    for x1, x2 in tqdm(ssl_loader):
        x1 = x1.to(device, non_blocking=True)
        x2 = x2.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda"):
            z1 = model(x1)
            z2 = model(x2)
            loss = nt_xent_loss(z1, z2, TEMPERATURE)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    epoch_loss /= len(ssl_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {epoch_loss:.4f}")

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        save_ckpt(epoch+1, best_loss)

torch.cuda.empty_cache()


Resuming from epoch: 0


100%|██████████| 18/18 [01:14<00:00,  4.15s/it]


Epoch [1/100] Loss: 4.1305


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [2/100] Loss: 3.9944


100%|██████████| 18/18 [01:05<00:00,  3.61s/it]


Epoch [3/100] Loss: 3.8509


100%|██████████| 18/18 [01:03<00:00,  3.53s/it]


Epoch [4/100] Loss: 3.7749


100%|██████████| 18/18 [01:00<00:00,  3.36s/it]


Epoch [5/100] Loss: 3.7198


100%|██████████| 18/18 [01:04<00:00,  3.60s/it]


Epoch [6/100] Loss: 3.5823


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [7/100] Loss: 3.5754


100%|██████████| 18/18 [01:00<00:00,  3.34s/it]


Epoch [8/100] Loss: 3.5181


100%|██████████| 18/18 [01:02<00:00,  3.46s/it]


Epoch [9/100] Loss: 3.5117


100%|██████████| 18/18 [01:00<00:00,  3.37s/it]


Epoch [10/100] Loss: 3.4359


100%|██████████| 18/18 [01:01<00:00,  3.39s/it]


Epoch [11/100] Loss: 3.3695


100%|██████████| 18/18 [01:03<00:00,  3.50s/it]


Epoch [12/100] Loss: 3.2975


100%|██████████| 18/18 [01:01<00:00,  3.44s/it]


Epoch [13/100] Loss: 3.2759


100%|██████████| 18/18 [01:00<00:00,  3.39s/it]


Epoch [14/100] Loss: 3.2642


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [15/100] Loss: 3.1877


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [16/100] Loss: 3.2097


100%|██████████| 18/18 [01:00<00:00,  3.36s/it]


Epoch [17/100] Loss: 3.1870


100%|██████████| 18/18 [01:04<00:00,  3.58s/it]


Epoch [18/100] Loss: 3.1980


100%|██████████| 18/18 [01:01<00:00,  3.43s/it]


Epoch [19/100] Loss: 3.2231


100%|██████████| 18/18 [00:59<00:00,  3.33s/it]


Epoch [20/100] Loss: 3.1885


100%|██████████| 18/18 [00:59<00:00,  3.32s/it]


Epoch [21/100] Loss: 3.2217


100%|██████████| 18/18 [01:00<00:00,  3.34s/it]


Epoch [22/100] Loss: 3.1664


100%|██████████| 18/18 [00:59<00:00,  3.32s/it]


Epoch [23/100] Loss: 3.1482


100%|██████████| 18/18 [01:00<00:00,  3.33s/it]


Epoch [24/100] Loss: 3.2008


100%|██████████| 18/18 [00:59<00:00,  3.29s/it]


Epoch [25/100] Loss: 3.1700


100%|██████████| 18/18 [01:00<00:00,  3.38s/it]


Epoch [26/100] Loss: 3.1762


100%|██████████| 18/18 [01:00<00:00,  3.36s/it]


Epoch [27/100] Loss: 3.1087


100%|██████████| 18/18 [01:00<00:00,  3.34s/it]


Epoch [28/100] Loss: 3.1248


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [29/100] Loss: 3.1265


100%|██████████| 18/18 [01:05<00:00,  3.62s/it]


Epoch [30/100] Loss: 3.1109


100%|██████████| 18/18 [01:02<00:00,  3.46s/it]


Epoch [31/100] Loss: 3.1624


100%|██████████| 18/18 [01:04<00:00,  3.57s/it]


Epoch [32/100] Loss: 3.1003


100%|██████████| 18/18 [01:00<00:00,  3.38s/it]


Epoch [33/100] Loss: 3.0994


100%|██████████| 18/18 [00:59<00:00,  3.33s/it]


Epoch [34/100] Loss: 3.1293


100%|██████████| 18/18 [01:00<00:00,  3.36s/it]


Epoch [35/100] Loss: 3.0781


100%|██████████| 18/18 [00:59<00:00,  3.31s/it]


Epoch [36/100] Loss: 3.0476


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [37/100] Loss: 3.0710


100%|██████████| 18/18 [01:03<00:00,  3.55s/it]


Epoch [38/100] Loss: 3.0558


100%|██████████| 18/18 [01:00<00:00,  3.35s/it]


Epoch [39/100] Loss: 3.0274


100%|██████████| 18/18 [01:04<00:00,  3.60s/it]


Epoch [40/100] Loss: 3.0409


100%|██████████| 18/18 [00:59<00:00,  3.30s/it]


Epoch [41/100] Loss: 3.0408


100%|██████████| 18/18 [01:00<00:00,  3.38s/it]


Epoch [42/100] Loss: 3.0510


100%|██████████| 18/18 [01:01<00:00,  3.40s/it]


Epoch [43/100] Loss: 3.0665


100%|██████████| 18/18 [01:03<00:00,  3.52s/it]


Epoch [44/100] Loss: 3.0445


100%|██████████| 18/18 [01:02<00:00,  3.49s/it]


Epoch [45/100] Loss: 3.0524


100%|██████████| 18/18 [01:00<00:00,  3.37s/it]


Epoch [46/100] Loss: 3.0816


100%|██████████| 18/18 [00:59<00:00,  3.31s/it]


Epoch [47/100] Loss: 3.0848


100%|██████████| 18/18 [01:04<00:00,  3.56s/it]


Epoch [48/100] Loss: 3.0222


100%|██████████| 18/18 [01:04<00:00,  3.60s/it]


Epoch [49/100] Loss: 3.0453


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [50/100] Loss: 3.0610


100%|██████████| 18/18 [00:59<00:00,  3.32s/it]


Epoch [51/100] Loss: 3.0114


100%|██████████| 18/18 [00:59<00:00,  3.33s/it]


Epoch [52/100] Loss: 3.0202


100%|██████████| 18/18 [01:02<00:00,  3.46s/it]


Epoch [53/100] Loss: 2.9654


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [54/100] Loss: 3.0278


100%|██████████| 18/18 [01:06<00:00,  3.69s/it]


Epoch [55/100] Loss: 2.9470


100%|██████████| 18/18 [01:06<00:00,  3.70s/it]


Epoch [56/100] Loss: 2.9840


100%|██████████| 18/18 [00:59<00:00,  3.31s/it]


Epoch [57/100] Loss: 2.9448


100%|██████████| 18/18 [01:04<00:00,  3.56s/it]


Epoch [58/100] Loss: 2.9800


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [59/100] Loss: 2.9497


100%|██████████| 18/18 [01:08<00:00,  3.80s/it]


Epoch [60/100] Loss: 2.9710


100%|██████████| 18/18 [01:01<00:00,  3.41s/it]


Epoch [61/100] Loss: 2.9822


100%|██████████| 18/18 [01:02<00:00,  3.50s/it]


Epoch [62/100] Loss: 2.9623


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [63/100] Loss: 2.9913


100%|██████████| 18/18 [01:04<00:00,  3.57s/it]


Epoch [64/100] Loss: 2.9820


100%|██████████| 18/18 [01:00<00:00,  3.37s/it]


Epoch [65/100] Loss: 3.0035


100%|██████████| 18/18 [01:01<00:00,  3.39s/it]


Epoch [66/100] Loss: 3.0453


100%|██████████| 18/18 [01:02<00:00,  3.46s/it]


Epoch [67/100] Loss: 2.9646


100%|██████████| 18/18 [01:02<00:00,  3.49s/it]


Epoch [68/100] Loss: 2.9329


100%|██████████| 18/18 [01:04<00:00,  3.61s/it]


Epoch [69/100] Loss: 2.9456


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [70/100] Loss: 2.8914


100%|██████████| 18/18 [01:04<00:00,  3.56s/it]


Epoch [71/100] Loss: 2.9492


100%|██████████| 18/18 [01:04<00:00,  3.56s/it]


Epoch [72/100] Loss: 2.9476


100%|██████████| 18/18 [01:04<00:00,  3.59s/it]


Epoch [73/100] Loss: 2.9414


100%|██████████| 18/18 [01:02<00:00,  3.45s/it]


Epoch [74/100] Loss: 2.9649


100%|██████████| 18/18 [01:00<00:00,  3.39s/it]


Epoch [75/100] Loss: 2.9409


100%|██████████| 18/18 [01:00<00:00,  3.39s/it]


Epoch [76/100] Loss: 2.8876


100%|██████████| 18/18 [00:59<00:00,  3.31s/it]


Epoch [77/100] Loss: 2.9466


100%|██████████| 18/18 [01:00<00:00,  3.34s/it]


Epoch [78/100] Loss: 2.9242


100%|██████████| 18/18 [00:59<00:00,  3.33s/it]


Epoch [79/100] Loss: 2.8943


100%|██████████| 18/18 [01:01<00:00,  3.42s/it]


Epoch [80/100] Loss: 2.9197


100%|██████████| 18/18 [01:02<00:00,  3.47s/it]


Epoch [81/100] Loss: 2.9288


100%|██████████| 18/18 [01:00<00:00,  3.39s/it]


Epoch [82/100] Loss: 2.8724


100%|██████████| 18/18 [01:01<00:00,  3.44s/it]


Epoch [83/100] Loss: 2.9029


100%|██████████| 18/18 [01:00<00:00,  3.36s/it]


Epoch [84/100] Loss: 2.9060


100%|██████████| 18/18 [01:01<00:00,  3.43s/it]


Epoch [85/100] Loss: 2.9105


100%|██████████| 18/18 [01:03<00:00,  3.51s/it]


Epoch [86/100] Loss: 2.9001


100%|██████████| 18/18 [01:03<00:00,  3.55s/it]


Epoch [87/100] Loss: 2.8801


100%|██████████| 18/18 [01:01<00:00,  3.43s/it]


Epoch [88/100] Loss: 2.8673


100%|██████████| 18/18 [01:01<00:00,  3.39s/it]


Epoch [89/100] Loss: 2.9014


100%|██████████| 18/18 [01:01<00:00,  3.44s/it]


Epoch [90/100] Loss: 2.8975


100%|██████████| 18/18 [01:05<00:00,  3.64s/it]


Epoch [91/100] Loss: 2.8711


100%|██████████| 18/18 [01:02<00:00,  3.45s/it]


Epoch [92/100] Loss: 2.8678


100%|██████████| 18/18 [01:00<00:00,  3.35s/it]


Epoch [93/100] Loss: 2.8519


100%|██████████| 18/18 [01:02<00:00,  3.47s/it]


Epoch [94/100] Loss: 2.9134


100%|██████████| 18/18 [01:01<00:00,  3.43s/it]


Epoch [95/100] Loss: 2.8754


100%|██████████| 18/18 [01:00<00:00,  3.39s/it]


Epoch [96/100] Loss: 2.8640


100%|██████████| 18/18 [01:02<00:00,  3.45s/it]


Epoch [97/100] Loss: 2.8869


100%|██████████| 18/18 [01:04<00:00,  3.61s/it]


Epoch [98/100] Loss: 2.8876


100%|██████████| 18/18 [01:03<00:00,  3.51s/it]


Epoch [99/100] Loss: 2.8900


100%|██████████| 18/18 [01:04<00:00,  3.59s/it]

Epoch [100/100] Loss: 2.8744





**Cell-12: Save Frozen Encoder (Deliverable)**

In [12]:
torch.save(model.encoder.state_dict(), "encoder_simclr.pt")
print("SimCLR encoder saved")


SimCLR encoder saved


**Cell-13: Feature Extraction Dataset**

In [13]:
class FeatureDataset(Dataset):
    def __init__(self, root, transform):
        self.samples = []
        self.labels = []
        self.classes = sorted(os.listdir(root))
        cls2idx = {c:i for i,c in enumerate(self.classes)}

        for c in self.classes:
            for img in os.listdir(os.path.join(root,c)):
                if img.lower().endswith((".jpg",".png",".jpeg")):
                    self.samples.append(os.path.join(root,c,img))
                    self.labels.append(cls2idx[c])

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert("RGB")
        return self.transform(img), self.labels[idx]


**Cell-14: Linear Probe Transform**

In [14]:
eval_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])


**Cell-15: Extract Features**

In [15]:
model.encoder.eval()

features, labels = [], []

dataset = FeatureDataset(DATASET_DIR, eval_tf)
loader = DataLoader(dataset, batch_size=64, shuffle=False)

with torch.no_grad():
    for x, y in tqdm(loader):
        x = x.to(device)
        h = model.encoder(x)
        features.append(h.cpu())
        labels.append(y)

X = torch.cat(features).numpy()
y = torch.cat(labels).numpy()


100%|██████████| 10/10 [00:53<00:00,  5.35s/it]


**Cell-16: k-NN Evaluation (Guideline Required)**

In [16]:
for k in [1,5,20]:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X, y)
    pred = knn.predict(X)
    acc = accuracy_score(y, pred)
    print(f"k-NN (k={k}) accuracy: {acc:.4f}")


k-NN (k=1) accuracy: 0.9949
k-NN (k=5) accuracy: 0.8113
k-NN (k=20) accuracy: 0.6278
