# Momentum Contrast (MoCo)

In [None]:
import os
import warnings
warnings.filterwarnings("ignore")
import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Subset
import random
import copy
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
from tqdm.auto import tqdm

# Sanity check: print a one-level listing of every training shard so the
# on-disk layout can be inspected before any loading code runs.
for shard_dir in (
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X1",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X2",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X3",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X4",
):
    print(f"\nContents of {shard_dir}:")
    for entry in sorted(os.listdir(shard_dir)):
        entry_path = os.path.join(shard_dir, entry)
        # Plain files are reported by name only.
        if not os.path.isdir(entry_path):
            print(f"  [FILE] {entry}")
            continue
        # Sub-directories additionally report how many entries they hold.
        n_children = len(os.listdir(entry_path))
        print(f"  [DIR]  {entry}  →  contains {n_children} entries")


## Hyperparameters and Data Paths

In [None]:
# Root folders that are scanned recursively for training images.
TRAIN_DIRS = [
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X1",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X2",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X3",
    "/kaggle/input/ssl-dataset/ssl_dataset/train.X4",
]

# --- data ---
IMG_SIZE = 80        # output side length of the random resized crop
BATCH_SIZE = 64

# --- model ---
EMB_DIM = 128        # width of the projected embedding (and of each queue entry)
HEAD_DIM = 512       # hidden width of the projection MLP
QUEUE_SIZE = 4096    # number of keys kept in the negative queue
MOMENTUM = 0.999     # EMA coefficient for the key branch
TEMPERATURE = 0.07   # divisor applied to the contrastive logits

# --- optimisation ---
# Base learning rate 0.03 is defined for a batch of 256 and scaled linearly.
LR = 0.03 * (BATCH_SIZE / 256)
EPOCHS = 100

# Prefer the GPU whenever one is visible to PyTorch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


## Transformations

In [None]:
# Stochastic augmentation pipeline. It is handed to ContrastiveDS, which
# applies it twice to every image to produce the two views of a positive pair.
aug = transforms.Compose(
    [
        # Random crop covering 40-100% of the image, resized to IMG_SIZE.
        transforms.RandomResizedCrop(size=IMG_SIZE, scale=(0.4, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
        ),
        transforms.ToTensor(),
        # Per-channel standardisation with the usual ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


class ContrastiveDS(Dataset):
    """Unlabeled image dataset that yields two augmented views per image.

    Every directory in ``roots`` is walked recursively and all files whose
    name ends in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are
    collected.

    Args:
        roots: Iterable of directory paths to scan.
        tfm: Callable applied to the decoded RGB ``PIL.Image``; it is called
            twice per item, so a random transform gives two different views.

    Raises:
        RuntimeError: If no image file is found under any of ``roots``.
    """

    _EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, roots, tfm):
        found = []
        for root in roots:
            for dirpath, _, filenames in os.walk(root):
                found.extend(
                    os.path.join(dirpath, fname)
                    for fname in filenames
                    if fname.lower().endswith(self._EXTENSIONS)
                )
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not found:
            raise RuntimeError("No images found")
        # os.walk order depends on the filesystem; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.paths = sorted(found)
        self.tfm = tfm

    def __len__(self) -> int:
        """Return the number of image files discovered."""
        return len(self.paths)

    def __getitem__(self, i: int):
        """Return ``(tfm(img), tfm(img))`` for the ``i``-th image."""
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read and returns an independent image, so the
        # handle can be released immediately.
        with Image.open(self.paths[i]) as fh:
            img = fh.convert("RGB")
        return self.tfm(img), self.tfm(img)

In [None]:
# Upper bound on how many images are used for pre-training.
NUM_SAMPLES = 90000
ds_full = ContrastiveDS(TRAIN_DIRS, aug)

# random.sample raises ValueError when asked for more items than exist, so
# clamp the request to the number of images that were actually found.
subset_size = min(NUM_SAMPLES, len(ds_full))
subset_indices = random.sample(range(len(ds_full)), subset_size)
ds = Subset(ds_full, subset_indices)

# drop_last keeps every batch at exactly BATCH_SIZE samples.
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

print("Samples:", len(ds))


## MoCo Design

In [None]:
class MoCo(nn.Module):
    """Momentum Contrast model with a ResNet-18 backbone.

    The model holds two branches with identical architecture:

    * the *query* branch (``encoder_q`` + ``projector_q``), trained by
      back-propagation;
    * the *key* branch (``encoder_k`` + ``projector_k``), never trained
      directly but updated as an exponential moving average of the query
      branch with coefficient ``MOMENTUM``.

    Normalised keys are stored in a ring buffer ``queue`` of shape
    ``(EMB_DIM, QUEUE_SIZE)`` and serve as negatives for later batches.
    """

    def __init__(self):
        super().__init__()
        # No weights requested: the backbone starts from random initialisation.
        base = models.resnet18()
        feat_dim = base.fc.in_features  # 512 for ResNet18

        # Everything except the final classification layer.
        self.encoder_q = nn.Sequential(*list(base.children())[:-1])
        # The key encoder must own its parameters. Building a second
        # Sequential from the same `base` children would make both encoders
        # share the very same modules, so freezing the key branch would also
        # freeze the query encoder and the momentum update would be a no-op.
        self.encoder_k = copy.deepcopy(self.encoder_q)

        self.projector_q = nn.Sequential(
            nn.Linear(feat_dim, HEAD_DIM),
            nn.ReLU(inplace=True),
            nn.Linear(HEAD_DIM, EMB_DIM),
        )
        self.projector_k = copy.deepcopy(self.projector_q)
        self.m, self.T = MOMENTUM, TEMPERATURE

        # deepcopy already gave the key branch the query branch's weights;
        # it only has to be excluded from gradient computation.
        for key_net in (self.encoder_k, self.projector_k):
            for p in key_net.parameters():
                p.requires_grad = False

        # Start from random unit-norm columns so the first batches already
        # see plausible negatives instead of all-zero vectors.
        self.register_buffer(
            "queue", F.normalize(torch.randn(EMB_DIM, QUEUE_SIZE), dim=0)
        )
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        # Not used by forward(); kept so existing code accessing
        # `model.criterion` keeps working.
        self.criterion = nn.CrossEntropyLoss()

    @torch.no_grad()
    def _update_key(self):
        """EMA step: ``k <- m * k + (1 - m) * q`` for every key-branch parameter."""
        branches = (
            (self.encoder_q, self.encoder_k),
            (self.projector_q, self.projector_k),
        )
        for q_net, k_net in branches:
            for q, k in zip(q_net.parameters(), k_net.parameters()):
                # In-place update avoids allocating a new tensor per parameter.
                k.data.mul_(self.m).add_(q.data, alpha=1.0 - self.m)

    @torch.no_grad()
    def _enqueue(self, keys):
        """Write ``keys`` into the ring buffer, overwriting the oldest entries.

        Args:
            keys: Tensor of shape ``(B, D)`` where ``D`` equals the number of
                rows of ``queue``; row ``i`` is stored as a queue column.

        Raises:
            ValueError: If ``B`` exceeds the queue capacity.
        """
        B = keys.size(0)
        K = self.queue.size(1)
        if B > K:
            raise ValueError(
                f"batch of {B} keys does not fit into a queue of size {K}"
            )
        ptr = int(self.queue_ptr)
        # Modular indexing handles wrap-around at the end of the buffer.
        slots = (ptr + torch.arange(B, device=self.queue.device)) % K
        # Indexed assignment needs matching dtypes; keys may be half precision
        # when produced under autocast.
        self.queue[:, slots] = keys.detach().T.to(
            device=self.queue.device, dtype=self.queue.dtype
        )
        self.queue_ptr[0] = (ptr + B) % K

    def forward(self, x1, x2):
        """Compute contrastive logits for a batch of view pairs.

        Args:
            x1: Batch of first views, fed to the query branch.
            x2: Batch of second views, fed to the key branch.

        Returns:
            ``(logits, labels)``: ``logits`` has one row per sample with the
            positive similarity in column 0 followed by the similarities to
            all queued keys, divided by the temperature; ``labels`` is an
            all-zero ``long`` vector marking column 0 as the target.

        Side effects: updates the key branch by momentum and pushes the new
        keys into the queue.
        """
        # Query embedding (receives gradients).
        q_feat = self.encoder_q(x1).flatten(1)
        q = F.normalize(self.projector_q(q_feat), dim=1)

        # Key embedding from the freshly momentum-updated key branch.
        with torch.no_grad():
            self._update_key()
            k_feat = self.encoder_k(x2).flatten(1)
            k = F.normalize(self.projector_k(k_feat), dim=1)

        # Snapshot of the negatives, taken before this batch is enqueued.
        negatives = self.queue.clone().detach().to(q.device)

        l_pos = (q * k).sum(1, True)
        l_neg = q @ negatives
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        self._enqueue(k)
        return logits, labels

## Pre-Training Loop

In [None]:
# Directory that receives the best model and the per-epoch checkpoints.
SAVE_DIR = "/kaggle/working/"
os.makedirs(SAVE_DIR, exist_ok=True)

moco = MoCo().to(DEVICE, memory_format=torch.channels_last)
print("Using device:", DEVICE)
print("Model parameters on:", next(moco.parameters()).device)

# Only the query branch is optimised; the key branch follows it by momentum
# inside MoCo.forward.
opt = torch.optim.SGD(
    list(moco.encoder_q.parameters()) + list(moco.projector_q.parameters()),
    lr=LR, momentum=0.9, weight_decay=1e-4
)

# Cosine schedule with warm restarts: first cycle is 10 epochs, each following
# cycle is twice as long. It is stepped once per epoch below.
sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2)

# Mixed precision is only meaningful on CUDA; disabling it explicitly on CPU
# turns autocast/GradScaler into no-ops instead of relying on their fallback.
use_amp = DEVICE == "cuda"
scaler = GradScaler(enabled=use_amp)

# Early stopping on the epoch-average training loss.
best_loss = float('inf')
patience = 5
wait = 0

epoch_losses = []

for ep in range(1, EPOCHS + 1):
    moco.train()
    running_loss = 0.0
    total_samples = 0

    pbar = tqdm(loader, desc=f"Epoch {ep}/{EPOCHS}", leave=False)
    for a, b in pbar:
        a = a.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        batch_size = a.size(0)

        opt.zero_grad()
        with autocast(enabled=use_amp):
            logits, labs = moco(a, b)
            loss = F.cross_entropy(logits, labs)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        # .item() forces a device sync, so read the scalar once per step.
        loss_value = loss.item()
        running_loss += loss_value * batch_size
        total_samples += batch_size
        pbar.set_postfix(loss=f"{loss_value:.4f}")

    # With drop_last=True a dataset smaller than one batch yields no batches;
    # fail with a clear message instead of a ZeroDivisionError.
    if total_samples == 0:
        raise RuntimeError("DataLoader produced no batches; nothing to train on")

    sch.step()
    avg_loss = running_loss / total_samples
    epoch_losses.append(avg_loss)

    # The LR reported here is read after sch.step(), i.e. it is the value the
    # next epoch will use.
    print(f"Epoch {ep}/{EPOCHS}  Avg Loss: {avg_loss:.4f}  LR: {opt.param_groups[0]['lr']:.6f}", flush=True)

    stop_training = False
    if avg_loss < best_loss:
        best_loss = avg_loss
        wait = 0
        torch.save(moco.state_dict(), os.path.join(SAVE_DIR, "best.pth"))
        print(f"🔥 New best model @ epoch {ep}: {best_loss:.4f}")
    else:
        wait += 1
        if wait >= patience:
            print(f"⛔ Early stopping triggered @ epoch {ep}")
            stop_training = True

    # Saved before a possible early-stop break so the final epoch is kept too.
    # Scheduler and scaler state are included so a run can be resumed exactly.
    torch.save({
        'epoch': ep,
        'model_state_dict': moco.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': sch.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': avg_loss,
        'lr': opt.param_groups[0]['lr']
    }, os.path.join(SAVE_DIR, f"checkpoint_epoch_{ep}.pth"))

    if stop_training:
        break
