# Task 5: CILP Assessment Performance

## 1. Objective
Train a complete Contrastive Image-LiDAR Pretraining (CILP) model to achieve specific performance thresholds:
* **CILP Loss:** < 3.5
* **Projector MSE:** < 2.5
* **Classifier Accuracy:** > 95%

In [None]:
!pip install wandb -q

import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import numpy as np
import getpass
import shutil
import glob
import random
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.colab import drive

# REPRODUCIBILITY
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and, when
    available, all CUDA devices), and configures cuDNN for deterministic
    kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # These are plain flags and are ignored when cuDNN is not in use, so they
    # are set unconditionally. The autotuner (benchmark mode) may pick
    # different kernels from run to run, so it must be off for the
    # deterministic flag to be meaningful.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")

set_seed(42)

# Mount Google Drive only if it is not already mounted in this runtime.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = '/content/drive/MyDrive/CILP_Assignment'
# Make the project's `src` package importable; avoid piling up duplicate
# entries when the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

EXTRACT_DIR = '/content/data_local'
SEARCH_DIR = os.path.join(PROJECT_ROOT, 'data')

# Sorted so the same archive is picked on every run when several exist.
found_zips = sorted(glob.glob(os.path.join(SEARCH_DIR, "*.zip")))
if found_zips and not os.path.exists(EXTRACT_DIR):
    try:
        # Extract to the local disk without going through a shell.
        shutil.unpack_archive(found_zips[0], EXTRACT_DIR, format="zip")
    except Exception:
        # A half-extracted folder would be mistaken for a finished one on
        # the next run (extraction is skipped when EXTRACT_DIR exists).
        shutil.rmtree(EXTRACT_DIR, ignore_errors=True)
        raise
elif not os.path.exists(EXTRACT_DIR):
    # No archive available: read the data straight from Drive.
    EXTRACT_DIR = SEARCH_DIR

# The dataset root is the first directory that contains a `cubes` folder.
DATA_PATH = next(
    (root for root, dirs, _ in os.walk(EXTRACT_DIR) if 'cubes' in dirs),
    None,
)
if DATA_PATH is None:
    # Fail here with a clear message instead of a TypeError from
    # os.path.join(None, ...) when the dataset is built later.
    raise FileNotFoundError(
        f"No directory containing a 'cubes' folder was found under {EXTRACT_DIR!r}"
    )

# DATASET
class RobustAssessmentDataset(Dataset):
    """Paired RGB / LiDAR samples for the cube-vs-sphere task.

    ``root_dir`` is scanned for the class folders ``cubes`` (label 0) and
    ``spheres`` (label 1). Inside each, every ``rgb/<id>.png`` that has a
    matching ``lidar/<id>.npy`` becomes one sample. Per-sample azimuth and
    zenith angles are read from ``azimuth.npy`` / ``zenith.npy`` in the class
    folder, indexed by the integer value of ``<id>``.

    Each item is a tuple ``(rgb_in, lidar_in, label)``:
      * ``rgb_in``   - the image resized to 64x64 with an all-zero fourth
        channel appended (4 x 64 x 64).
      * ``lidar_in`` - stacked ``x, y, z, mask`` planes computed from the
        depth map and the sample's angles (assumes the stored depth map is
        64 x 64 - TODO confirm against the data).
      * ``label``    - class index as a ``torch.long`` tensor.

    Args:
        root_dir: Directory that contains the class folders.
        subset_fraction: Fraction of the samples to keep. Values below 1.0
            shuffle the sample list with the global ``random`` state, keep
            the leading share, and print the resulting class counts.
    """

    def __init__(self, root_dir, subset_fraction=1.0):
        self.samples = []
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ])
        classes = ["cubes", "spheres"]

        for label, shape in enumerate(classes):
            shape_dir = os.path.join(root_dir, shape)
            rgb_dir = os.path.join(shape_dir, "rgb")
            lidar_dir = os.path.join(shape_dir, "lidar")
            if not os.path.exists(rgb_dir):
                continue

            try:
                az = np.load(os.path.join(shape_dir, "azimuth.npy"))
                ze = np.load(os.path.join(shape_dir, "zenith.npy"))
            except (OSError, ValueError, EOFError):
                # Missing or unreadable angle tables: fall back to zero angles.
                az, ze = np.zeros(10000), np.zeros(10000)

            image_files = sorted(f for f in os.listdir(rgb_dir) if f.endswith('.png'))
            for img_name in image_files:
                file_id = img_name.split('.')[0]
                lidar_path = os.path.join(lidar_dir, f"{file_id}.npy")
                if not os.path.exists(lidar_path):
                    continue
                try:
                    idx = int(file_id)
                except ValueError:
                    # Non-numeric file names share the first angle entry.
                    idx = 0
                # Explicit lower bound: a negative id must not wrap around
                # to the end of the angle tables.
                self.samples.append({
                    "rgb": os.path.join(rgb_dir, img_name),
                    "lidar": lidar_path,
                    "az": az[idx] if 0 <= idx < len(az) else 0,
                    "ze": ze[idx] if 0 <= idx < len(ze) else 0,
                    "label": label
                })

        if subset_fraction < 1.0:
            random.shuffle(self.samples)
            count = int(len(self.samples) * subset_fraction)
            self.samples = self.samples[:count]

            labels = [s['label'] for s in self.samples]
            n_cubes = labels.count(0)
            n_spheres = labels.count(1)
            print(f"Subset: {len(self.samples)} | Cubes: {n_cubes}, Spheres: {n_spheres}")

    def __len__(self):
        """Return the number of samples kept after optional sub-sampling."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(rgb_in, lidar_in, label)`` as described on the class. If the
            sample cannot be loaded, a warning is issued and an all-zero
            placeholder with label 0 is returned so that one bad file does
            not abort a whole epoch.
        """
        item = self.samples[idx]
        try:
            # Context manager closes the image file handle deterministically.
            with Image.open(item["rgb"]) as img:
                rgb = self.transform(img.convert("RGB"))
            rgb_in = torch.cat([rgb, torch.zeros(1, 64, 64)], dim=0)
            depth = torch.tensor(np.load(item["lidar"]), dtype=torch.float32)
            # Mask is 1 where the depth reading is below 50.0.
            mask = (depth < 50.0).float()
            # Spherical -> Cartesian using the (negated) per-sample angles.
            x = depth * np.sin(-item["az"]) * np.cos(-item["ze"])
            y = depth * np.cos(-item["az"]) * np.cos(-item["ze"])
            z = depth * np.sin(-item["ze"])
            return rgb_in, torch.stack([x, y, z, mask], dim=0), torch.tensor(item["label"], dtype=torch.long)
        except Exception as exc:
            # Best-effort loading is kept, but the failure is now visible
            # instead of silently feeding zeros into training.
            import warnings
            warnings.warn(
                f"Could not load sample {item['rgb']!r} ({exc!r}); returning a zero placeholder",
                RuntimeWarning,
            )
            return torch.zeros(4, 64, 64), torch.zeros(4, 64, 64), torch.tensor(0)

from src.models import CILPModel

class Projector(nn.Module):
    """Three-layer MLP that maps one embedding space onto another.

    Two hidden layers of width 256 with ReLU activations, followed by a
    linear output layer.

    Args:
        input_dim: Size of the incoming embedding.
        output_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim=128, output_dim=128):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # The attribute must stay named `net` so saved state-dict keys match.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the projection of ``x``."""
        projected = self.net(x)
        return projected

def compute_contrastive_loss(rgb_emb, lidar_emb, device='cuda'):
    """Symmetric cross-entropy contrastive loss over a batch of paired embeddings.

    Both embedding sets are L2-normalised, their pairwise cosine similarities
    are multiplied by a fixed scale of ``exp(2.65)``, and cross-entropy is
    applied along both axes of the similarity matrix with the diagonal
    (row ``i`` paired with column ``i``) as the target.

    Args:
        rgb_emb: ``(batch, dim)`` RGB embeddings.
        lidar_emb: ``(batch, dim)`` LiDAR embeddings, row-aligned with
            ``rgb_emb``.
        device: Kept for backward compatibility and ignored. The loss is
            computed on the device of ``rgb_emb``, so the function also
            works on CPU-only runtimes and cannot disagree with its inputs.

    Returns:
        Scalar tensor: the mean of the RGB->LiDAR and LiDAR->RGB losses.
    """
    target_device = rgb_emb.device
    rgb_norm = F.normalize(rgb_emb, dim=1)
    lidar_norm = F.normalize(lidar_emb, dim=1)
    scale = torch.tensor(2.65, device=target_device).exp()
    logits = scale * (rgb_norm @ lidar_norm.T)
    # Matching pairs sit on the diagonal of the similarity matrix.
    labels = torch.arange(logits.size(0), device=target_device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_t) / 2

def get_similarity_heatmap(model, loader, device):
    """Plot the RGB-vs-LiDAR cosine-similarity matrix for one batch.

    Puts ``model`` into eval mode, runs it on the first batch produced by
    ``loader`` and renders the pairwise cosine similarities between the two
    embedding sets as a heatmap.

    Args:
        model: Callable as ``model(rgb, lidar)``. If it returns a tuple, the
            first two entries are used as the RGB and LiDAR embeddings;
            otherwise the single output is compared with itself.
        loader: Iterable yielding ``(rgb, lidar, label)`` batches.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The Matplotlib figure. It is closed before returning so the notebook
        does not display it inline.
    """
    model.eval()
    rgb_cpu, lidar_cpu, _ = next(iter(loader))
    rgb_batch = rgb_cpu.to(device)
    lidar_batch = lidar_cpu.to(device)

    with torch.no_grad():
        outputs = model(rgb_batch, lidar_batch)
        if isinstance(outputs, tuple):
            v_rgb, v_lidar = outputs[0], outputs[1]
        else:
            v_rgb = v_lidar = outputs
        rgb_unit = F.normalize(v_rgb, dim=1)
        lidar_unit = F.normalize(v_lidar, dim=1)
        sim = torch.matmul(rgb_unit, lidar_unit.T).cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 5))
    heat = ax.imshow(sim, cmap='viridis')
    fig.colorbar(heat)
    plt.close(fig)
    return fig

# EXECUTION
print("W&B LOGIN")
wandb.login(key=getpass.getpass("API Key: "))
BATCH_SIZE = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Draw ONE 40% subset and split it 3:1 into train and validation. Sampling
# the two subsets independently from the same pool would let validation
# samples also appear in the training set and make every validation metric
# optimistic.
full_subset = RobustAssessmentDataset(DATA_PATH, 0.4)
n_val = len(full_subset) // 4
train_set, val_set = torch.utils.data.random_split(
    full_subset, [len(full_subset) - n_val, n_val]
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)
# Loss weights for (cubes, spheres); consumed by the classifier in Phase 3.
class_weights = torch.tensor([1.0, 4.0]).to(DEVICE)

# PHASE 1: CILP
print("\nPhase 1: Contrastive Pretraining")
cilp = CILPModel().to(DEVICE)
opt_cilp = torch.optim.Adam(cilp.parameters(), lr=5e-4)

wandb.init(project="cilp-extended-assessment", name="Task5_CILP_Pretrain", reinit=True)
best_loss = float('inf')

for epoch in range(20):
    cilp.train()
    total_loss = 0
    for rgb, lidar, _ in train_loader:
        rgb, lidar = rgb.to(DEVICE), lidar.to(DEVICE)
        out = cilp(rgb, lidar)
        # Accept either an (rgb_emb, lidar_emb, ...) tuple or a single tensor.
        rgb_emb, lidar_emb = (out[0], out[1]) if isinstance(out, tuple) else (out, out)
        loss = compute_contrastive_loss(rgb_emb, lidar_emb, DEVICE)
        opt_cilp.zero_grad()
        loss.backward()
        opt_cilp.step()
        total_loss += loss.item()

    cilp.eval()
    val_loss = 0
    with torch.no_grad():
        for rgb, lidar, _ in val_loader:
            out = cilp(rgb.to(DEVICE), lidar.to(DEVICE))
            rgb_emb, lidar_emb = (out[0], out[1]) if isinstance(out, tuple) else (out, out)
            val_loss += compute_contrastive_loss(rgb_emb, lidar_emb, DEVICE).item()

    avg_train = total_loss / len(train_loader)
    avg_val = val_loss/len(val_loader)
    print(f"Ep {epoch+1} | Val Loss: {avg_val:.4f}")
    # The training loss was accumulated but never reported; log it next to
    # the validation loss so over-fitting is visible in W&B.
    wandb.log({"train_loss": avg_train, "val_loss": avg_val})
    if avg_val < best_loss:
        # Keep only the checkpoint with the lowest validation loss.
        best_loss = avg_val
        torch.save(cilp.state_dict(), os.path.join(CHECKPOINT_DIR, "best_cilp.pt"))

# The heatmap is optional: a failure must not abort the run, but it should
# be reported instead of being swallowed.
try:
    wandb.log({"Similarity Matrix": wandb.Image(get_similarity_heatmap(cilp, val_loader, DEVICE))})
except Exception as exc:
    print(f"Skipping similarity heatmap: {exc!r}")
wandb.finish()

Random seed set to 42
W&B LOGIN
API Key: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Subset: 3225 | Cubes: 2978, Spheres: 247
Subset: 1075 | Cubes: 1001, Spheres: 74

=== Phase 1: Contrastive Pretraining ===


Ep 1 | Val Loss: 0.7228
Ep 2 | Val Loss: 0.5244
Ep 3 | Val Loss: 0.3677
Ep 4 | Val Loss: 0.2865
Ep 5 | Val Loss: 0.2539
Ep 6 | Val Loss: 0.2172
Ep 7 | Val Loss: 0.2170
Ep 8 | Val Loss: 0.1854
Ep 9 | Val Loss: 0.2021
Ep 10 | Val Loss: 0.1819
Ep 11 | Val Loss: 0.1792
Ep 12 | Val Loss: 0.1791
Ep 13 | Val Loss: 0.1677
Ep 14 | Val Loss: 0.1676
Ep 15 | Val Loss: 0.1626
Ep 16 | Val Loss: 0.1623
Ep 17 | Val Loss: 0.1797
Ep 18 | Val Loss: 0.1747
Ep 19 | Val Loss: 0.1737
Ep 20 | Val Loss: 0.1621


0,1
val_loss,█▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
val_loss,0.16207


## 5.2 Cross-Modal Projector (Phase 2)
We freeze the CILP encoders and train a Projector (MLP) to map RGB embeddings into the LiDAR embedding space.
* **Goal:** Enable the RGB encoder to "hallucinate" LiDAR features.
* **Constraint:** MSE Loss < 2.5.

In [None]:
print("\nPhase 2: Projector Training")
# Restore the best Phase 1 encoder weights; map_location keeps this working
# if the checkpoint was written on a different device than the current one.
cilp.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "best_cilp.pt"), map_location=DEVICE))
# The encoders stay frozen: eval mode here, no_grad around every forward pass.
cilp.eval()
projector = Projector().to(DEVICE)
opt_proj = torch.optim.Adam(projector.parameters(), lr=1e-3)
crit_mse = nn.MSELoss()
projector_ckpt = os.path.join(CHECKPOINT_DIR, "best_projector.pt")
best_mse = float('inf')

wandb.init(project="cilp-extended-assessment", name="Task5_Projector", reinit=True)

for epoch in range(15):
    projector.train()
    for rgb, lidar, _ in train_loader:
        rgb, lidar = rgb.to(DEVICE), lidar.to(DEVICE)
        with torch.no_grad():
            out = cilp(rgb, lidar)
            e_rgb, e_lidar = (out[0], out[1]) if isinstance(out, tuple) else (out, out)
        # Regress the LiDAR embedding from the RGB embedding.
        loss = crit_mse(projector(e_rgb), e_lidar)
        opt_proj.zero_grad()
        loss.backward()
        opt_proj.step()

    projector.eval()
    val_mse = 0
    with torch.no_grad():
        for rgb, lidar, _ in val_loader:
            out = cilp(rgb.to(DEVICE), lidar.to(DEVICE))
            e_rgb, e_lidar = (out[0], out[1]) if isinstance(out, tuple) else (out, out)
            val_mse += crit_mse(projector(e_rgb), e_lidar).item()

    avg_mse = val_mse / len(val_loader)
    print(f"Ep {epoch+1} | Val MSE: {avg_mse:.4f}")
    wandb.log({"val_mse": avg_mse})

    if avg_mse < best_mse:
        # "best_projector.pt" should hold the best epoch, not just the last.
        best_mse = avg_mse
        torch.save(projector.state_dict(), projector_ckpt)

if best_mse == float('inf'):
    # No epoch produced a finite improvement (e.g. NaN losses). Still write
    # a checkpoint so downstream code always finds the file.
    torch.save(projector.state_dict(), projector_ckpt)
wandb.finish()


Phase 2: Projector Training


Ep 1 | Val MSE: 0.0035
Ep 2 | Val MSE: 0.0029
Ep 3 | Val MSE: 0.0029
Ep 4 | Val MSE: 0.0028
Ep 5 | Val MSE: 0.0027
Ep 6 | Val MSE: 0.0027
Ep 7 | Val MSE: 0.0026
Ep 8 | Val MSE: 0.0026
Ep 9 | Val MSE: 0.0026
Ep 10 | Val MSE: 0.0026
Ep 11 | Val MSE: 0.0026
Ep 12 | Val MSE: 0.0025
Ep 13 | Val MSE: 0.0025
Ep 14 | Val MSE: 0.0025
Ep 15 | Val MSE: 0.0025


0,1
val_mse,█▄▃▃▂▂▂▁▁▁▁▁▁▁▁

0,1
val_mse,0.00254


## 5.3 Final Classifier (Phase 3)
We train a small MLP classifier (one hidden layer of 64 units) on top of the projected features to distinguish between Cubes and Spheres.
* **Goal:** High accuracy on the imbalanced dataset.
* **Strategy:** Use class weights (`[1.0, 4.0]`) in `CrossEntropyLoss` to counter the class imbalance (roughly 12:1 cubes to spheres in the training subset logged above).
* **Constraint:** Validation Accuracy > 95%.

In [None]:
# Phase 3: train a small classification head on the projected RGB embeddings.
# `cilp` and `projector` are only used for feature extraction under no_grad;
# the optimizer updates the classifier parameters alone.
print("\nPhase 3: Final Classifier")
classifier = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2)).to(DEVICE)
opt_cls = torch.optim.Adam(classifier.parameters(), lr=1e-3)
# Halve the learning rate when validation accuracy has not improved for 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_cls, mode='max', factor=0.5, patience=3)

# `class_weights` comes from the Phase 1 cell.
crit_cls = nn.CrossEntropyLoss(weight=class_weights)

wandb.init(project="cilp-extended-assessment", name="Task5_Classifier", reinit=True)
best_acc = 0.0

for epoch in range(50):
    classifier.train()
    for rgb, _, label in train_loader:
        rgb, label = rgb.to(DEVICE), label.to(DEVICE)
        with torch.no_grad():
            # Only the RGB input is available here, so it is passed for both
            # arguments and just the first output is used.
            # NOTE(review): presumably out[0] is the RGB embedding - confirm
            # against CILPModel.forward.
            out = cilp(rgb, rgb)
            e_rgb = out[0] if isinstance(out, tuple) else out
            proj_feat = projector(e_rgb)

        loss = crit_cls(classifier(proj_feat), label)
        opt_cls.zero_grad(); loss.backward(); opt_cls.step()

    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for rgb, _, label in val_loader:
            rgb, label = rgb.to(DEVICE), label.to(DEVICE)
            out = cilp(rgb, rgb)
            e_rgb = out[0] if isinstance(out, tuple) else out
            logits = classifier(projector(e_rgb))
            preds = torch.argmax(logits, dim=1)
            correct += (preds == label).sum().item()
            total += label.size(0)

    # Plain accuracy in percent over the whole validation loader.
    # NOTE(review): with heavily imbalanced classes, a model that always
    # predicts the majority class already scores close to the majority share;
    # consider tracking per-class recall alongside this number.
    acc = correct / total * 100
    scheduler.step(acc)
    print(f"Ep {epoch+1} | Val Acc: {acc:.2f}%")
    wandb.log({"val_accuracy": acc})

    # Checkpoint only on a strict improvement in validation accuracy.
    if acc > best_acc:
        best_acc = acc
        torch.save(classifier.state_dict(), os.path.join(CHECKPOINT_DIR, "best_classifier.pt"))

print(f"\nTask 5 Complete. Best Acc: {best_acc:.2f}%")
wandb.finish()


Phase 3: Final Classifier


Ep 1 | Val Acc: 93.12%
Ep 2 | Val Acc: 93.12%
Ep 3 | Val Acc: 93.12%
Ep 4 | Val Acc: 93.12%
Ep 5 | Val Acc: 93.12%
Ep 6 | Val Acc: 93.12%
Ep 7 | Val Acc: 93.40%
Ep 8 | Val Acc: 93.40%
Ep 9 | Val Acc: 93.12%
Ep 10 | Val Acc: 92.47%
Ep 11 | Val Acc: 92.47%
Ep 12 | Val Acc: 92.56%
Ep 13 | Val Acc: 92.65%
Ep 14 | Val Acc: 92.74%
Ep 15 | Val Acc: 92.47%
Ep 16 | Val Acc: 92.47%
Ep 17 | Val Acc: 92.47%
Ep 18 | Val Acc: 92.37%
Ep 19 | Val Acc: 92.37%
Ep 20 | Val Acc: 92.37%
Ep 21 | Val Acc: 92.28%
Ep 22 | Val Acc: 92.37%
Ep 23 | Val Acc: 92.28%
Ep 24 | Val Acc: 92.28%
Ep 25 | Val Acc: 92.28%
Ep 26 | Val Acc: 92.19%
Ep 27 | Val Acc: 92.19%
Ep 28 | Val Acc: 92.19%
Ep 29 | Val Acc: 92.19%
Ep 30 | Val Acc: 92.19%
Ep 31 | Val Acc: 92.19%
Ep 32 | Val Acc: 92.19%
Ep 33 | Val Acc: 92.19%
Ep 34 | Val Acc: 92.19%
Ep 35 | Val Acc: 92.19%
Ep 36 | Val Acc: 92.19%
Ep 37 | Val Acc: 92.19%
Ep 38 | Val Acc: 92.19%
Ep 39 | Val Acc: 92.19%
Ep 40 | Val Acc: 92.19%
Ep 41 | Val Acc: 92.19%
Ep 42 | Val Acc: 92.19%
E

0,1
val_accuracy,▆▆▆▆▆██▆▃▃▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
val_accuracy,92.18605
