In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
input_file_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install segmentation-models-pytorch albumentations --quiet

import os
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
!pip install -q segmentation-models-pytorch

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# Stand-alone U-Net instance: ResNet-34 encoder with ImageNet weights, RGB input,
# 12 output classes. Built first, then moved to the GPU in a separate step.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=12,
)
model = model.cuda()

In [None]:
# Plain multi-class cross-entropy criterion with default settings.
# NOTE(review): not referenced by the later cells visible here; run_all_configs builds its own instance.
ce_loss = nn.CrossEntropyLoss()

In [None]:
class ARBTLoss(nn.Module):
    """Per-class Tversky loss combined with a focal-style error term.

    For every class the soft true positives, false positives and false negatives
    are summed over the two spatial axes of each sample. The Tversky index weights
    false positives by ``delta`` and false negatives by ``1 - delta``. A second term,
    the spatial mean of ``|p - g| ** gamma`` scaled by ``lam``, is added on top.
    The result is averaged over the batch and then over the classes.

    Args:
        class_weights: optional indexable of per-class multipliers (``class_weights[c]``).
        delta: weight of the false-positive term in the Tversky denominator.
        gamma: exponent applied to the absolute prediction error.
        lam: multiplier of the focal-style term.
        epsilon: constant added to the numerator and denominator of the Tversky index.
    """

    def __init__(self, class_weights=None, delta=0.7, gamma=1.5, lam=0.2, epsilon=1e-6):
        super().__init__()
        self.class_weights = class_weights
        self.delta, self.gamma, self.lam = delta, gamma, lam
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """Return the scalar loss for ``logits`` (N, C, H, W) and integer ``targets`` (N, H, W)."""
        n_classes = logits.shape[1]
        spatial_axes = (1, 2)

        class_probs = F.softmax(logits, dim=1)
        # one_hot appends the class axis last; move it next to the batch axis.
        class_targets = F.one_hot(targets, n_classes).permute(0, 3, 1, 2).float()

        running = 0.0
        planes = zip(class_probs.unbind(dim=1), class_targets.unbind(dim=1))
        for class_idx, (pred, truth) in enumerate(planes):
            true_pos = (pred * truth).sum(dim=spatial_axes)
            false_pos = (pred * (1 - truth)).sum(dim=spatial_axes)
            false_neg = ((1 - pred) * truth).sum(dim=spatial_axes)

            denominator = (
                true_pos
                + self.delta * false_pos
                + (1 - self.delta) * false_neg
                + self.epsilon
            )
            tversky_index = (true_pos + self.epsilon) / denominator

            error_term = torch.mean(torch.abs(pred - truth) ** self.gamma, dim=spatial_axes)
            per_sample = (1 - tversky_index) + self.lam * error_term

            if self.class_weights is not None:
                per_sample = per_sample * self.class_weights[class_idx]

            running += per_sample.mean()

        return running / n_classes

In [None]:
class CASM:
    """Sharpness-aware wrapper around a base optimizer with a class-adaptive radius.

    Each ``step`` performs two forward/backward passes:

    1. gradients are computed at the current weights;
    2. every parameter is moved along its own normalized gradient by
       ``adaptive_rho``, gradients are recomputed at the perturbed weights,
       the original weights are restored, and the base optimizer applies the
       perturbed-point gradients.

    The radius is ``rho * (1 + kappa * (1 - rare_ratio))``, so it grows as the
    fraction of rare-class pixels in the batch shrinks.

    Args:
        base_optimizer: optimizer whose ``param_groups`` are perturbed and whose
            ``step`` applies the final update.
        model: stored on the instance; not used by ``step`` itself.
        rho: base perturbation radius.
        kappa: strength of the rare-ratio adaptation.
        epsilon: added to each gradient norm to avoid division by zero.
    """

    def __init__(self, base_optimizer, model, rho=0.05, kappa=1.0, epsilon=1e-12):
        self.model = model
        self.rho = rho
        self.kappa = kappa
        self.epsilon = epsilon
        self.base_optimizer = base_optimizer

    def zero_grad(self, *args, **kwargs):
        """Clear gradients on the wrapped optimizer (lets CASM be used like an optimizer)."""
        return self.base_optimizer.zero_grad(*args, **kwargs)

    def step(self, closure, rare_ratio=0.0):
        """Run one sharpness-aware update.

        Args:
            closure: zero-argument callable that runs a forward pass and returns the
                loss tensor. It must NOT call ``backward`` itself; ``step`` does so.
            rare_ratio: fraction of rare-class pixels in the batch.

        Returns:
            The detached loss measured at the unperturbed weights.
        """
        adaptive_rho = self.rho * (1 + self.kappa * (1 - rare_ratio))

        params = [
            p
            for group in self.base_optimizer.param_groups
            for p in group['params']
            if p.requires_grad
        ]

        # Pass 1: gradients at the current weights. Computed here rather than trusted
        # from the caller, who may have just zeroed them.
        self.base_optimizer.zero_grad()
        with torch.enable_grad():
            loss = closure()
        loss.backward()

        # Perturb each parameter along its normalized gradient. In-place edits of leaf
        # tensors that require grad are only legal under no_grad.
        backups = []
        with torch.no_grad():
            for p in params:
                if p.grad is None:
                    continue
                backups.append((p, p.detach().clone()))
                e_w = adaptive_rho * p.grad / (p.grad.norm() + self.epsilon)
                p.add_(e_w)

        # Pass 2: gradients at the perturbed weights; these drive the real update.
        self.base_optimizer.zero_grad()
        with torch.enable_grad():
            perturbed_loss = closure()
        perturbed_loss.backward()

        # Restore the exact original weights before the base optimizer updates them.
        with torch.no_grad():
            for p, original in backups:
                p.copy_(original)

        self.base_optimizer.step()

        return loss.detach()

In [None]:
def train_model(model, dataloader, loss_fn, optimizer, epochs=15, use_casm=False, rare_class=5):
    """Train ``model`` and record the mean loss of every epoch.

    Args:
        model: network to train; batches are moved to the device of its parameters.
        dataloader: iterable yielding ``(images, masks)`` tensor pairs.
        loss_fn: callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        optimizer: a regular optimizer, or (when ``use_casm`` is true) an object
            exposing ``base_optimizer`` and ``step(closure=..., rare_ratio=...)``
            that returns the loss.
        epochs: number of passes over ``dataloader``.
        use_casm: selects the CASM update path.
        rare_class: class index whose pixel fraction is passed as ``rare_ratio``.

    Returns:
        ``{"loss": [...]}`` with one average loss per epoch.
    """
    # Follow the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    history = {"loss": []}

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, masks in tqdm(dataloader):
            images, masks = images.to(device), masks.to(device)

            if use_casm:
                rare_ratio = (masks == rare_class).float().mean().item()

                optimizer.base_optimizer.zero_grad()
                # step() runs the forward/backward passes itself and returns the loss,
                # so no separate forward pass is needed just for logging.
                loss = optimizer.step(
                    closure=lambda: loss_fn(model(images), masks),
                    rare_ratio=rare_ratio,
                )
            else:
                optimizer.zero_grad()
                loss = loss_fn(model(images), masks)
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        avg_loss = total_loss / max(num_batches, 1)
        history["loss"].append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

    return history

In [None]:
def evaluate(model, dataloader, num_classes=12):
    """Compute per-class IoU, per-class F1 and the confusion matrix of ``model``.

    The confusion matrix is accumulated batch by batch, so memory use does not
    grow with the number of evaluated pixels.

    Args:
        model: network producing ``(N, C, H, W)`` logits; inputs are moved to the
            device of its parameters.
        dataloader: iterable yielding ``(images, masks)`` pairs.
        num_classes: number of label values ``0 .. num_classes - 1`` to score.
            Pixels whose true or predicted label lies outside that range are ignored.

    Returns:
        ``(iou, f1, cm)`` where ``cm[t, p]`` counts pixels of true class ``t``
        predicted as ``p`` and ``iou`` / ``f1`` are arrays of length ``num_classes``.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for images, masks in tqdm(dataloader):
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1).cpu().numpy().reshape(-1).astype(np.int64)
            truth = torch.as_tensor(masks).cpu().numpy().reshape(-1).astype(np.int64)

            valid = (
                (truth >= 0) & (truth < num_classes)
                & (preds >= 0) & (preds < num_classes)
            )
            # Encode each (true, pred) pair as one flat index and histogram it.
            flat_index = truth[valid] * num_classes + preds[valid]
            counts = np.bincount(flat_index, minlength=num_classes * num_classes)
            cm += counts.reshape(num_classes, num_classes)

    true_pos = np.diag(cm)
    support_plus_predicted = cm.sum(1) + cm.sum(0)  # (TP + FN) + (TP + FP)
    iou = true_pos / (support_plus_predicted - true_pos + 1e-6)
    f1 = 2 * true_pos / (support_plus_predicted + 1e-6)

    return iou, f1, cm

In [None]:
def plot_loss(history, label):
    """Add the loss curve stored in ``history["loss"]`` to the current axes.

    The curve is labelled ``label``; axis labels, legend and grid are (re)applied
    on every call so several curves can be overlaid on one figure.
    """
    axes = plt.gca()
    axes.plot(history["loss"], label=label)
    axes.set_xlabel("Epoch")
    axes.set_ylabel("Loss")
    axes.legend()
    axes.grid(True)

In [None]:
def visualize_prediction(model, dataset, idx=0, mean=0.5, std=0.5):
    """Show input image, ground-truth mask and predicted mask side by side.

    Args:
        model: segmentation network producing ``(1, C, H, W)`` logits.
        dataset: indexable returning ``(image, mask)`` with ``image`` a ``(3, H, W)``
            tensor and ``mask`` an ``(H, W)`` label map.
        idx: index of the sample to display.
        mean, std: normalization that was applied to ``image`` (scalar or one value
            per channel). It is undone before display. The defaults match the
            ``A.Normalize(mean=0.5, std=0.5)`` used by this notebook's transform.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    image, mask = dataset[idx]
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))
        # squeeze(0) removes only the batch axis, even if another axis has size 1.
        pred = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()
    num_classes = output.shape[1]

    # Undo the normalization so imshow receives RGB values inside [0, 1];
    # otherwise negative values are clipped and the image looks too dark.
    image_cpu = image.detach().cpu().float()
    mean_t = torch.as_tensor(mean, dtype=image_cpu.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image_cpu.dtype).view(-1, 1, 1)
    display_image = (image_cpu * std_t + mean_t).clamp(0, 1).permute(1, 2, 0).numpy()
    mask_np = torch.as_tensor(mask).cpu().numpy()

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(display_image)
    axs[0].set_title("Input Image")
    # Shared vmin/vmax so a class gets the same colour in both label panels.
    axs[1].imshow(mask_np, vmin=0, vmax=num_classes - 1)
    axs[1].set_title("Ground Truth")
    axs[2].imshow(pred, vmin=0, vmax=num_classes - 1)
    axs[2].set_title("Prediction")
    for ax in axs:
        ax.axis("off")
    plt.show()

In [None]:
from torch import optim

def run_all_configs(train_loader, val_loader, epochs=15, lr=1e-3, num_classes=12, seed=None):
    """Train and evaluate each optimizer/loss combination on the same data.

    Three configurations are run in order: "CASM + ARBT", "Adam + CE" and
    "Adam + ARBT". Each one gets a freshly built ResNet-34 U-Net.

    Args:
        train_loader: loader passed to ``train_model``.
        val_loader: loader passed to ``evaluate``.
        epochs: training epochs per configuration.
        lr: learning rate of the Adam base optimizer.
        num_classes: number of segmentation classes, used for both the model head
            and the evaluation so the two always agree.
        seed: when not ``None``, ``torch.manual_seed(seed)`` is called before each
            configuration is built so every run starts from the same RNG state.

    Returns:
        Dict mapping configuration name to a dict with keys ``"history"``,
        ``"iou"``, ``"f1"``, ``"cm"`` and ``"model"``.
    """
    configs = {
        "CASM + ARBT": {"loss": ARBTLoss(), "casm": True},
        "Adam + CE": {"loss": nn.CrossEntropyLoss(), "casm": False},
        "Adam + ARBT": {"loss": ARBTLoss(), "casm": False}
    }

    results = {}
    for name, cfg in configs.items():
        print(f"\n🚀 Running: {name}")
        if seed is not None:
            torch.manual_seed(seed)

        model = smp.Unet(
            "resnet34", encoder_weights="imagenet", in_channels=3, classes=num_classes
        ).cuda()

        base_optimizer = optim.Adam(model.parameters(), lr=lr)
        optimizer = CASM(base_optimizer, model) if cfg["casm"] else base_optimizer

        history = train_model(
            model, train_loader, cfg["loss"], optimizer, epochs=epochs, use_casm=cfg["casm"]
        )
        iou, f1, cm = evaluate(model, val_loader, num_classes=num_classes)

        results[name] = {
            "history": history,
            "iou": iou,
            "f1": f1,
            "cm": cm,
            "model": model
        }

    return results

In [None]:
# Mapping from mask colour (R, G, B) to class index for the 12 classes used here.
CITYSCAPES_COLORS = {
    (128, 64,128): 0,  # road
    (244, 35,232): 1,  # sidewalk
    (70, 70, 70): 2,   # building
    (102,102,156): 3,  # wall
    (190,153,153): 4,  # fence
    (153,153,153): 5,  # pole
    (250,170, 30): 6,  # traffic light
    (220,220,  0): 7,  # traffic sign
    (107,142, 35): 8,  # vegetation
    (152,251,152): 9,  # terrain
    (70,130,180):10,   # sky
    (220, 20, 60):11,  # person
}

def rgb_to_class(mask_rgb, tolerance=16.0):
    """Convert an RGB colour mask into a map of class indices.

    Each pixel is assigned the class of the nearest colour in ``CITYSCAPES_COLORS``
    (Euclidean distance in RGB), provided that distance is at most ``tolerance``.
    Exact equality is too strict for masks decoded from lossy ``.jpg`` files, where
    colours drift by a few levels; with exact matching those pixels fall to class 0.

    Args:
        mask_rgb: array of shape ``(H, W, 3)`` holding RGB values.
        tolerance: maximum accepted RGB distance to a palette colour. The default
            of 16 is below half the smallest gap between two palette colours
            (pole vs. fence, 37). ``tolerance=0`` reproduces exact matching.

    Returns:
        ``(H, W)`` ``uint8`` array of class indices. Pixels farther than
        ``tolerance`` from every palette colour are set to 0.
    """
    mask_rgb = np.asarray(mask_rgb)
    h, w, _ = mask_rgb.shape
    # Signed ints so that the subtraction below cannot wrap around as uint8 would.
    pixels = mask_rgb[..., :3].astype(np.int32)

    mask = np.zeros((h, w), dtype=np.uint8)
    best_dist = np.full((h, w), np.iinfo(np.int64).max, dtype=np.int64)

    for rgb, cls in CITYSCAPES_COLORS.items():
        dist = ((pixels - np.asarray(rgb, dtype=np.int32)) ** 2).sum(axis=-1)
        closer = dist < best_dist
        mask[closer] = cls
        best_dist[closer] = dist[closer]

    # Anything not close enough to a known colour falls back to class 0.
    mask[best_dist > tolerance ** 2] = 0

    return mask

In [None]:
from torchvision import transforms as T
from torchvision.transforms import functional as TF

class CityscapesImagePairsDataset(Dataset):
    """Dataset of side-by-side ``.jpg`` files: left half photo, right half colour mask.

    Args:
        root_dir: directory whose ``.jpg`` files are used, in sorted order.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, root_dir, transform=None):
        self.image_paths = sorted([
            os.path.join(root_dir, fname)
            for fname in os.listdir(root_dir)
            if fname.endswith('.jpg')
        ])
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        The mask is always converted to class indices with ``rgb_to_class`` and
        returned as a ``torch.long`` tensor. Without a transform the image is
        returned as a ``(3, H, W)`` float tensor scaled to ``[0, 1]``.
        """
        img_path = self.image_paths[idx]
        # Context manager closes the file handle once the pixels are loaded.
        with Image.open(img_path) as handle:
            full_image = handle.convert("RGB")
        w, h = full_image.size
        mid = w // 2

        image = np.array(full_image.crop((0, 0, mid, h)))  # Left side: RGB
        mask = np.array(full_image.crop((mid, 0, w, h)))   # Right side: Mask

        # Class-encode regardless of whether a transform is configured.
        mask = rgb_to_class(mask)

        # "is not None" rather than truthiness: a transform object may define __len__.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        # as_tensor accepts both numpy arrays and tensors coming out of the transform.
        return image, torch.as_tensor(mask).long()

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Preprocessing shared by both splits: resize to 256x512, normalize each channel
# with mean 0.5 / std 0.5, then convert to tensors.
common_transform = A.Compose(
    [
        A.Resize(256, 512),
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
)

# Training split.
train_dataset = CityscapesImagePairsDataset(
    '/kaggle/input/cityscapes-image-pairs/cityscapes_data/train',
    transform=common_transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=2)

# Validation split: same preprocessing, evaluated in a fixed order.
val_dataset = CityscapesImagePairsDataset(
    '/kaggle/input/cityscapes-image-pairs/cityscapes_data/val',
    transform=common_transform,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False, num_workers=2)

In [None]:
# Train and evaluate every configuration defined in run_all_configs. `results` maps each
# configuration name to its history, per-class IoU/F1, confusion matrix and trained model.
results = run_all_configs(train_loader, val_loader)

In [None]:
# Overlay the per-epoch training loss of every configuration on one figure.
for name in results:
    r = results[name]
    plot_loss(r["history"], label=name)
plt.gca().set_title("Training Loss Comparison")
plt.show()

In [None]:
def plot_metric(results, metric='iou'):
    """Plot one per-class metric for every configuration in ``results``.

    ``results`` maps a configuration name to a dict; ``metric`` is the key of the
    per-class sequence to draw (one line per configuration, x axis = class index).
    """
    metric_label = metric.upper()

    plt.figure(figsize=(10,5))
    axes = plt.gca()
    for name, entry in results.items():
        axes.plot(entry[metric], label=name)

    axes.set_title(f"Per-Class {metric_label} Comparison")
    axes.set_xlabel("Class Index")
    axes.set_ylabel(metric_label)
    axes.legend()
    axes.grid(True)
    plt.show()

In [None]:
# Show the first training sample as predicted by each trained configuration.
for name, entry in results.items():
    print(f"Example Prediction from: {name}")
    visualize_prediction(entry["model"], train_dataset)

In [None]:
# Per-class IoU of every configuration, one line each.
plt.figure(figsize=(10, 5))
iou_axes = plt.gca()
for name in results:
    r = results[name]
    iou_axes.plot(r["iou"], label=name)
iou_axes.set_title("Per-Class IoU Comparison")
iou_axes.set_xlabel("Class Index")
iou_axes.set_ylabel("IoU")
iou_axes.grid(True)
iou_axes.legend()
plt.show()

In [None]:
# Per-class F1 score of every configuration, one line each.
plt.figure(figsize=(10, 5))
f1_axes = plt.gca()
for name in results:
    r = results[name]
    f1_axes.plot(r["f1"], label=name)
f1_axes.set_title("Per-Class F1 Score Comparison")
f1_axes.set_xlabel("Class Index")
f1_axes.set_ylabel("F1 Score")
f1_axes.grid(True)
f1_axes.legend()
plt.show()

In [None]:
# Index of the class reported as "rare"; change based on your definition.
rare_class_index = 5

# Print that class's IoU for every configuration.
for name in results:
    r = results[name]
    rare_class_iou = r['iou'][rare_class_index]
    print(f"{name}: Rare-Class-{rare_class_index} IoU = {rare_class_iou:.4f}")

In [None]:
import seaborn as sns

# Heatmap of the confusion matrix for the "CASM + ARBT" run (rows: true, columns: predicted).
cm = results["CASM + ARBT"]["cm"]
plt.figure(figsize=(10, 8))
heatmap_axes = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
heatmap_axes.set_title("Confusion Matrix - CASM + ARBT")
heatmap_axes.set_xlabel("Predicted")
heatmap_axes.set_ylabel("True")
plt.show()

In [None]:
import pandas as pd

# One row per configuration: mean IoU / F1 over all classes plus the rare-class values.
summary_columns = [
    "Model",
    "Mean IoU",
    "Mean F1",
    f"Rare IoU (Class {rare_class_index})",
    f"Rare F1 (Class {rare_class_index})",
]
summary_data = [
    [
        name,
        np.mean(r["iou"]),
        np.mean(r["f1"]),
        r["iou"][rare_class_index],
        r["f1"][rare_class_index],
    ]
    for name, r in results.items()
]

df = pd.DataFrame(summary_data, columns=summary_columns)
display(df)

In [None]:
# Write the summary table to a CSV file in the current directory, without the index column.
df.to_csv("optimizer_comparison_summary.csv", index=False)