In [None]:
# ==========================================
# 1. MOUNT DRIVE
# ==========================================
# Mount Google Drive so the trained checkpoint (MODEL_PATH in the download
# cell points under /content/drive/MyDrive/...) can be copied into /content.
# Colab-only: `google.colab` is not importable outside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# -*- coding: utf-8 -*-
"""scene_segmentation_test_complete

Automatically generated by Colab.
"""

# ==========================================
# 2. INSTALL & IMPORTS
# ==========================================
!pip install -q segmentation-models-pytorch gdown

import json
import os
import shutil
import time
import zipfile
from datetime import datetime
from pathlib import Path

import cv2
import gdown
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from tqdm import tqdm

print("="*70)
print("DUALITY AI CHALLENGE - COMPLETE DIAGNOSTIC TEST (WITH METRICS)")
print("="*70)

# ==========================================
# 3. DOWNLOAD DATA & MODEL
# ==========================================
# --- Download Test Data ---
# Download and extraction are guarded independently: an existing zip does not
# mean the data was extracted (interrupted unzip, deleted /content/data), and a
# partially downloaded zip must not be reused.

print("\nDownloading Dataset...")
url = "https://storage.googleapis.com/duality-public-share/Hackathons/Duality%20Hackathon/Offroad_Segmentation_testImages.zip"
output_zip = "/content/test.zip"
data_dir = "/content/data"
# Root folder expected inside the archive (same root as TestConfig.TEST_IMG_DIR).
extracted_dir = os.path.join(data_dir, "Offroad_Segmentation_testImages")

if os.path.exists(output_zip) and not zipfile.is_zipfile(output_zip):
    print("⚠️  Existing zip is incomplete or corrupt, re-downloading.")
    os.remove(output_zip)

if not os.path.exists(output_zip):
    downloaded = gdown.download(url, output_zip, quiet=False)
    # NOTE(review): some gdown versions return None instead of raising on failure.
    if downloaded is None or not zipfile.is_zipfile(output_zip):
        raise RuntimeError(f"Dataset download failed or produced an invalid zip: {url}")
else:
    print("✓ Dataset zip already exists, skipping download.")

if not os.path.isdir(extracted_dir):
    print("Unzipping dataset...")
    # Pure-Python extraction instead of `!unzip`: no shell dependency and a bad
    # archive raises instead of being silently ignored.
    with zipfile.ZipFile(output_zip) as archive:
        archive.extractall(data_dir)
    print(f"✓ Data extracted to {data_dir}")
else:
    print(f"✓ Data already extracted at {extracted_dir}, skipping unzip.")

# --- Download Model ---
print("\nDownloading Model...")
# REPLACE THIS WITH YOUR PATH IF ON DRIVE
MODEL_PATH = "/content/drive/MyDrive/Duality_Project/checkpoints/unet-resnet34-colab-20260204-0618/best_model.pth"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH} (is Google Drive mounted?)")
shutil.copy(MODEL_PATH, "/content/model.pth")


[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.8/154.8 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
DUALITY AI CHALLENGE - COMPLETE DIAGNOSTIC TEST (WITH METRICS)

Downloading Dataset...


Downloading...
From: https://storage.googleapis.com/duality-public-share/Hackathons/Duality%20Hackathon/Offroad_Segmentation_testImages.zip
To: /content/test.zip
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.16G/1.16G [00:27<00:00, 41.8MB/s]


Unzipping dataset...
‚úì Data extracted to /content/data

Downloading Model...


'/content/model.pth'

In [None]:
# ==========================================
# 4. CONFIGURATION
# ==========================================
class TestConfig:
    """Static settings for the test run: paths, model geometry, loader and classes.

    Everything is a class attribute; the notebook reads them as
    ``TestConfig.<NAME>`` and never instantiates the class.
    """

    # ---- Filesystem locations ----
    CHECKPOINT_PATH = "/content/model.pth"
    TEST_IMG_DIR = "/content/data/Offroad_Segmentation_testImages/Color_Images"
    # Ground-truth masks; the metrics section only runs when this folder exists.
    TEST_MASK_DIR = "/content/data/Offroad_Segmentation_testImages/Segmentation"
    OUTPUT_DIR = "/content/test_results"

    # ---- Network and input geometry ----
    ENCODER = "resnet34"
    NUM_CLASSES = 10
    IMG_HEIGHT, IMG_WIDTH = 544, 960
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

    # ---- Evaluation loop ----
    BATCH_SIZE = 8
    NUM_WORKERS = 2
    VISUALIZE_SAMPLES = 10

    # ---- Classes ----
    # Raw label value stored in the mask files -> contiguous index 0..NUM_CLASSES-1.
    CLASS_MAPPING = dict(zip(
        (100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000),
        range(NUM_CLASSES),
    ))
    # Display name per contiguous index.
    CLASS_NAMES = [
        "Trees",
        "Lush Bushes",
        "Dry Grass",
        "Dry Bushes",
        "Ground Clutter",
        "Flowers",
        "Logs",
        "Rocks",
        "Landscape",
        "Sky",
    ]
    # RGB colour per contiguous index (same order as CLASS_NAMES).
    CLASS_COLORS = [
        [34, 139, 34],      # Trees
        [50, 205, 50],      # Lush Bushes
        [154, 205, 50],     # Dry Grass
        [139, 69, 19],      # Dry Bushes
        [160, 82, 45],      # Ground Clutter
        [255, 182, 193],    # Flowers
        [139, 90, 43],      # Logs
        [128, 128, 128],    # Rocks
        [210, 180, 140],    # Landscape
        [135, 206, 235],    # Sky
    ]

# Build the output tree: OUTPUT_DIR itself plus its predictions/ and
# visualizations/ sub-folders (existing folders are left alone).
output_root = Path(TestConfig.OUTPUT_DIR)
for target_dir in (output_root, output_root / "predictions", output_root / "visualizations"):
    target_dir.mkdir(parents=True, exist_ok=True)

print(f"\n‚úì Configuration loaded")
print(f"‚úì Test images: {TestConfig.TEST_IMG_DIR}")
print(f"‚úì Output: {TestConfig.OUTPUT_DIR}")


‚úì Configuration loaded
‚úì Test images: /content/data/Offroad_Segmentation_testImages/Color_Images
‚úì Output: /content/test_results


In [None]:
# ==========================================
# 5. DATA LOADER
# ==========================================
class TestDataset(torch.utils.data.Dataset):
    """Unlabelled test images, preprocessed for the segmentation network.

    Each item is ``(image_tensor, original_image, file_name)``:
      * ``image_tensor`` - float32 CHW tensor resized to
        ``(TestConfig.IMG_HEIGHT, TestConfig.IMG_WIDTH)`` and normalized with
        the ImageNet mean/std;
      * ``original_image`` - the full-resolution RGB uint8 array as read;
      * ``file_name`` - the image's base name.

    Args:
        image_dir: Folder containing the test images.
        extensions: File suffixes to include, matched case-insensitively
            (default ``(".png", ".jpg")``).

    Raises:
        FileNotFoundError: If ``image_dir`` does not exist.
    """

    # ImageNet statistics as float32 so normalization does not upcast to float64.
    MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, image_dir, extensions=(".png", ".jpg")):
        self.image_dir = Path(image_dir)
        if not self.image_dir.is_dir():
            # Fail early: an empty dataset would otherwise only surface later as
            # confusing empty-array errors in the metrics/visualisation cells.
            raise FileNotFoundError(f"Test image directory not found: {self.image_dir}")
        wanted = {ext.lower() for ext in extensions}
        self.image_paths = sorted(
            path for path in self.image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in wanted
        )
        print(f"\n✓ Found {len(self.image_paths)} test images")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        bgr = cv2.imread(str(img_path))
        if bgr is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise OSError(f"Could not read image: {img_path}")
        original_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Resize to the network input (cv2 expects (width, height)), then normalize.
        resized = cv2.resize(original_image, (TestConfig.IMG_WIDTH, TestConfig.IMG_HEIGHT))
        normalized = (resized.astype(np.float32) / 255.0 - self.MEAN) / self.STD
        image = torch.from_numpy(np.ascontiguousarray(normalized.transpose(2, 0, 1)))

        return image, original_image, img_path.name

test_dataset = TestDataset(TestConfig.TEST_IMG_DIR)
# NOTE: the default collate function stacks the full-resolution original images
# returned by the dataset, so every test image must share one size.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=TestConfig.BATCH_SIZE,
    shuffle=False,  # deterministic order (sorted file names)
    num_workers=TestConfig.NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
    # it is useless and recent PyTorch versions emit a warning about it.
    pin_memory=TestConfig.DEVICE == "cuda",
)

# ==========================================
# 6. LOAD MODEL
# ==========================================
print(f"\n{'='*70}")
print("LOADING MODEL")
print(f"{'='*70}")

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    That prefix is added when a model is saved while wrapped in
    ``nn.DataParallel`` / DDP; only the leading occurrence is stripped.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model_safely(checkpoint_path, device):
    """Build the UNet and load trained weights from ``checkpoint_path``.

    Accepts a raw state dict, a training checkpoint dict holding
    ``model_state_dict`` (or ``state_dict``), or a pickled ``nn.Module``.

    Loading is non-strict so harmless extra entries are tolerated, but missing
    and unexpected keys are reported. A checkpoint that matches none of the
    model's tensors is rejected instead of silently evaluating a randomly
    initialised network.

    Args:
        checkpoint_path: Path to the ``.pth`` file.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        RuntimeError: If no checkpoint tensor matches a model key.
    """
    model = smp.Unet(
        encoder_name=TestConfig.ENCODER,
        encoder_weights=None,
        classes=TestConfig.NUM_CLASSES
    )
    print(f"Loading checkpoint: {checkpoint_path}")

    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as exc:  # e.g. UnpicklingError under torch>=2.6's weights_only=True default
        print(f"Safe load failed ({type(exc).__name__}); retrying with weights_only=False")
        # SECURITY: weights_only=False unpickles arbitrary objects (can execute
        # code). Only do this for checkpoints you produced yourself / trust.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
        print("✓ Extracted state_dict from pickled module")
    elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        print("✓ Extracted model_state_dict")
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        print("✓ Extracted state_dict")
    else:
        state_dict = checkpoint

    result = model.load_state_dict(_strip_module_prefix(state_dict), strict=False)
    n_model_keys = len(model.state_dict())
    n_loaded = n_model_keys - len(result.missing_keys)
    if n_loaded == 0:
        raise RuntimeError(
            f"No checkpoint tensor matched the UNet-{TestConfig.ENCODER} parameters; "
            f"first unexpected keys: {result.unexpected_keys[:5]}"
        )
    if result.missing_keys:
        print(f"⚠️  {len(result.missing_keys)} model keys missing from checkpoint "
              f"(left at initialisation): {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"⚠️  {len(result.unexpected_keys)} checkpoint keys ignored: {result.unexpected_keys[:5]}")
    print(f"✓ Weights loaded (non-strict mode): {n_loaded}/{n_model_keys} tensors matched")

    model = model.to(device)
    model.eval()

    # Model Stats
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n📊 Model Information:")
    print(f"  • Architecture: UNet-{TestConfig.ENCODER}")
    print(f"  • Total parameters: {total_params:,}")
    print(f"  • Model size: {total_params * 4 / (1024**2):.2f} MB (FP32)")

    return model

model = load_model_safely(TestConfig.CHECKPOINT_PATH, TestConfig.DEVICE)


‚úì Found 1002 test images

LOADING MODEL
Loading checkpoint: /content/model.pth
‚úì Weights loaded (non-strict mode)

üìä Model Information:
  ‚Ä¢ Architecture: UNet-resnet34
  ‚Ä¢ Total parameters: 24,437,674
  ‚Ä¢ Model size: 93.22 MB (FP32)


In [None]:
# ==========================================
# 7. LATENCY CHECK (FPS)
# ==========================================
print(f"\n{'='*70}")
print("QUICK LATENCY CHECK")
print(f"{'='*70}\n")

on_gpu = TestConfig.DEVICE == 'cuda'

# Warmup: untimed forward passes absorb one-off start-up costs before measuring.
print("Warming up...")
dummy_input = torch.randn(1, 3, TestConfig.IMG_HEIGHT, TestConfig.IMG_WIDTH).to(TestConfig.DEVICE)
with torch.no_grad():
    for _ in range(10):
        _ = model(dummy_input)

if on_gpu:
    torch.cuda.synchronize()

# Measure: CUDA work is asynchronous, so synchronize on both sides of the
# forward pass to time the GPU computation rather than the kernel launch.
print("Measuring latency (100 runs)...")
times = []
with torch.no_grad():
    for _ in tqdm(range(100)):
        if on_gpu:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        _ = model(dummy_input)
        if on_gpu:
            torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)

avg_time_ms = 1000 * np.mean(times)
fps = 1000 / avg_time_ms
verdict = '‚úÖ Yes' if fps >= 30 else '‚ùå No'

print(f"\n‚ö° Results:")
print(f"  ‚Ä¢ Average latency: {avg_time_ms:.2f} ms")
print(f"  ‚Ä¢ FPS: {fps:.2f}")
print(f"  ‚Ä¢ Real-time capable (30 FPS): {verdict}")

if on_gpu:
    mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    print(f"  ‚Ä¢ GPU memory: {mem_mb:.2f} MB")


QUICK LATENCY CHECK

Warming up...
Measuring latency (100 runs)...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:03<00:00, 29.20it/s]


‚ö° Results:
  ‚Ä¢ Average latency: 33.48 ms
  ‚Ä¢ FPS: 29.87
  ‚Ä¢ Real-time capable (30 FPS): ‚ùå No
  ‚Ä¢ GPU memory: 291.38 MB





In [None]:
# ==========================================
# 8. INFERENCE
# ==========================================
print(f"\n{'='*70}")
print("RUNNING INFERENCE")
print(f"{'='*70}\n")

# Per-image results, in loader order: predictions[i] belongs to filenames[i].
predictions = []
original_images = []
filenames = []
# Histogram of predicted class indices over the whole test set.
class_pixel_counts = dict.fromkeys(range(TestConfig.NUM_CLASSES), 0)
total_pixels = 0

with torch.no_grad():
    for images, orig_imgs, names in tqdm(test_loader, desc="Inference"):
        logits = model(images.to(TestConfig.DEVICE))
        batch_preds = logits.argmax(dim=1).cpu().numpy()

        predictions.extend(batch_preds)
        original_images.extend(img.numpy() for img in orig_imgs)
        filenames.extend(names)

        # Update Statistics
        for pred in batch_preds:
            labels, counts = np.unique(pred, return_counts=True)
            for cls, count in zip(labels, counts):
                class_pixel_counts[cls] += count
                total_pixels += count

print(f"\n‚úì Generated {len(predictions)} predictions")

# ==========================================
# 9. CALCULATING METRICS (RESTORED!)
# ==========================================
print(f"\n{'='*70}")
print("CALCULATING ACCURACY & IoU")
print(f"{'='*70}\n")

# Index given to GT pixels whose raw id is not in TestConfig.CLASS_MAPPING.
# Such pixels are excluded from accuracy and IoU instead of silently being
# counted as class 0 ("Trees").
IGNORE_LABEL = -1


def _load_gt_labels(mask_path, class_mapping, ignore_label=IGNORE_LABEL):
    """Read a raw-id GT mask and remap it to contiguous class indices.

    Args:
        mask_path: Path to the mask file. Read with ``cv2.IMREAD_UNCHANGED`` so
            the file's bit depth is kept (raw ids 7100 / 10000 exceed 8 bits).
        class_mapping: ``{raw_id: class_index}``.
        ignore_label: Index written for raw ids absent from ``class_mapping``.

    Returns:
        ``int64`` label array at the mask's own resolution, or ``None`` if
        OpenCV could not read the file.
    """
    raw = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        return None
    labels = np.full(raw.shape, ignore_label, dtype=np.int64)
    for raw_id, class_idx in class_mapping.items():
        labels[raw == raw_id] = class_idx
    return labels


def _image_metrics(pred, gt, num_classes, ignore_label=IGNORE_LABEL):
    """Pixel accuracy and per-class IoU for one image.

    ``pred`` and ``gt`` must have the same shape. GT pixels equal to
    ``ignore_label`` are excluded from both metrics. A class's IoU is NaN when
    the class does not occur in this image's GT.

    Returns:
        ``(accuracy, ious)``: ``accuracy`` is NaN if every pixel is ignored;
        ``ious`` has length ``num_classes``.
    """
    valid = gt != ignore_label
    n_valid = int(valid.sum())
    accuracy = (pred[valid] == gt[valid]).sum() / n_valid if n_valid else np.nan

    ious = np.full(num_classes, np.nan)
    for cls in range(num_classes):
        target_mask = gt == cls
        if target_mask.any():
            # Predictions on ignored pixels must not inflate the union.
            pred_mask = (pred == cls) & valid
            union = np.logical_or(pred_mask, target_mask).sum()
            ious[cls] = np.logical_and(pred_mask, target_mask).sum() / union
    return accuracy, ious


test_mask_dir = Path(TestConfig.TEST_MASK_DIR)

if test_mask_dir.exists():
    print(f"✓ Ground truth found at: {test_mask_dir}")

    valid_pairs = []      # indices into `predictions` that had a readable GT mask
    all_ious = []         # one per-class IoU vector per evaluated image
    all_accuracies = []
    class_totals = np.zeros(TestConfig.NUM_CLASSES)  # images whose GT contains each class
    ignored_pixels = 0

    # Each mask is scored right after loading and then dropped: memory stays flat
    # and every prediction is compared with the GT file of the same name.
    print("Loading Ground Truth & computing metrics...")
    for idx, filename in enumerate(tqdm(filenames)):
        mask_path = test_mask_dir / filename
        if not mask_path.exists():
            continue
        gt = _load_gt_labels(mask_path, TestConfig.CLASS_MAPPING)
        if gt is None:
            print(f"⚠️  Unreadable mask skipped: {mask_path}")
            continue

        pred = predictions[idx]
        if pred.shape != gt.shape:
            # Score at the native GT resolution: nearest-neighbour resizing of the
            # prediction keeps labels discrete and leaves the GT labels untouched.
            # uint8 cast (as in the visualisation cell): class ids are < 256.
            pred = cv2.resize(pred.astype(np.uint8), (gt.shape[1], gt.shape[0]),
                              interpolation=cv2.INTER_NEAREST)

        accuracy, ious = _image_metrics(pred, gt, TestConfig.NUM_CLASSES)
        all_accuracies.append(accuracy)
        all_ious.append(ious)
        class_totals += ~np.isnan(ious)
        ignored_pixels += int((gt == IGNORE_LABEL).sum())
        valid_pairs.append(idx)

    print(f"✓ Found masks for {len(valid_pairs)}/{len(predictions)} images")
    if ignored_pixels:
        print(f"⚠️  {ignored_pixels:,} GT pixels had ids outside CLASS_MAPPING and were ignored")

    if not valid_pairs:
        print("⚠️  No usable ground-truth masks; skipping metric calculation.")
    else:
        # --- Aggregation ---
        all_ious = np.array(all_ious)
        # Per-class mean over the images whose GT contains that class. Dividing by
        # class_totals (instead of np.nanmean) gives NaN for classes never seen in
        # the GT without a "Mean of empty slice" RuntimeWarning.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean_ious = np.nansum(all_ious, axis=0) / class_totals

        # Mean IoU only over classes present in the test GT
        valid_class_ious = mean_ious[~np.isnan(mean_ious)]
        mean_iou = np.mean(valid_class_ious) if len(valid_class_ious) > 0 else 0.0
        mean_accuracy = np.nanmean(all_accuracies)

        # --- PRINT RESULTS ---
        print("\n" + "="*70)
        print("TEST SET METRICS")
        print("="*70)
        print(f"\n📊 Overall Metrics:")
        print(f"  • Mean IoU:       {mean_iou:.4f}")
        print(f"  • Pixel Accuracy: {mean_accuracy:.4f}")

        print(f"\n📈 Per-Class Metrics:")
        print("-" * 95)
        print(f"{'Class':<15} {'IoU':>8} {'Count':>10} {'Status':>15}")
        print("-" * 95)

        for i, class_name in enumerate(TestConfig.CLASS_NAMES):
            iou = mean_ious[i]
            count = int(class_totals[i])

            if np.isnan(iou) or count == 0:
                iou_str = "N/A"
                status = "Not in test"
            else:
                iou_str = f"{iou:.4f}"
                status = "Present"

            print(f"{class_name:<15} {iou_str:>8} {count:>10} {status:>15}")
        print("-" * 95)

        # Save Metrics JSON
        test_metrics = {
            "mean_iou": float(mean_iou),
            "pixel_accuracy": float(mean_accuracy),
            "per_class": {name: float(mean_ious[i]) if not np.isnan(mean_ious[i]) else None for i, name in enumerate(TestConfig.CLASS_NAMES)},
            "num_images_evaluated": len(valid_pairs),
            "ignored_gt_pixels": ignored_pixels,
        }
        with open(f"{TestConfig.OUTPUT_DIR}/test_metrics.json", 'w') as f:
            json.dump(test_metrics, f, indent=2)

else:
    print(f"⚠️  Ground truth folder not found: {test_mask_dir}")
    print("Skipping metric calculation (Visualizations only).")

# ==========================================
# 10. SAVING PREDICTIONS
# ==========================================
print(f"\n{'='*70}")
print("SAVING PREDICTIONS")
print(f"{'='*70}\n")

REVERSE_MAPPING = {v: k for k, v in TestConfig.CLASS_MAPPING.items()}

# Lookup table class index -> raw label id, so remapping a whole mask is a
# single vectorised indexing operation. Indices without a raw id stay 0.
id_lut = np.zeros(TestConfig.NUM_CLASSES, dtype=np.uint16)
for class_idx, original_id in REVERSE_MAPPING.items():
    id_lut[class_idx] = original_id

pred_dir = Path(TestConfig.OUTPUT_DIR) / "predictions"
pred_dir.mkdir(parents=True, exist_ok=True)

for pred, orig_img, filename in tqdm(zip(predictions, original_images, filenames),
                                     total=len(predictions), desc="Saving PNGs"):
    # Predictions come out at the network resolution; write them at the source
    # image's resolution so they align pixel-for-pixel with the input and GT.
    height, width = orig_img.shape[:2]
    if pred.shape != (height, width):
        pred = cv2.resize(pred.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)
    pred_mask = id_lut[pred]

    # Always PNG: cv2.imwrite only stores uint16 data for PNG/JPEG-2000/TIFF, so a
    # .jpg input name would otherwise produce a lossy 8-bit mask.
    save_path = pred_dir / Path(filename).with_suffix(".png").name
    if not cv2.imwrite(str(save_path), pred_mask):
        raise OSError(f"Failed to write prediction mask: {save_path}")


RUNNING INFERENCE



Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 126/126 [00:58<00:00,  2.16it/s]



‚úì Generated 1002 predictions

CALCULATING ACCURACY & IoU

‚úì Ground truth found at: /content/data/Offroad_Segmentation_testImages/Segmentation
Loading Ground Truth...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1002/1002 [00:13<00:00, 75.69it/s]


‚úì Found masks for 1002/1002 images
Computing metrics...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1002/1002 [00:16<00:00, 59.26it/s]
  mean_ious = np.nanmean(all_ious, axis=0)



TEST SET METRICS

üìä Overall Metrics:
  ‚Ä¢ Mean IoU:       0.4172
  ‚Ä¢ Pixel Accuracy: 0.6168

üìà Per-Class Metrics:
-----------------------------------------------------------------------------------------------
Class                IoU      Count          Status
-----------------------------------------------------------------------------------------------
Trees             0.4808        986         Present
Lush Bushes       0.0035        668         Present
Dry Grass         0.3596       1002         Present
Dry Bushes        0.3806       1002         Present
Ground Clutter       N/A          0     Not in test
Flowers              N/A          0     Not in test
Logs                 N/A          0     Not in test
Rocks             0.0768       1002         Present
Landscape         0.6403       1002         Present
Sky               0.9786       1002         Present
-----------------------------------------------------------------------------------------------

SAVING PREDICTI

Saving PNGs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1002/1002 [00:12<00:00, 79.07it/s]


In [None]:
# ==========================================
# 11. VISUALIZATIONS & CHARTS
# ==========================================
print(f"\n{'='*70}")
print("GENERATING VISUALIZATIONS")
print(f"{'='*70}\n")

def mask_to_rgb(mask, colors=None):
    """Colour a class-index mask for display.

    Args:
        mask: Integer array of class indices (typically 2-D ``(H, W)``).
        colors: Sequence of RGB triplets indexed by class; defaults to
            ``TestConfig.CLASS_COLORS``.

    Returns:
        ``uint8`` array of shape ``mask.shape + (3,)``. Pixels whose index is
        outside ``[0, len(colors))`` stay black.
    """
    palette = np.asarray(TestConfig.CLASS_COLORS if colors is None else colors,
                         dtype=np.uint8).reshape(-1, 3)
    rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    # Vectorised palette lookup; out-of-range indices are masked out rather than
    # raising, matching the per-class loop this replaces.
    in_range = (mask >= 0) & (mask < len(palette))
    rgb[in_range] = palette[mask[in_range]]
    return rgb

# 1. Image Visualizations
# Fixed seed so the same subset of images is drawn on every run.
np.random.seed(42)
n_viz = min(TestConfig.VISUALIZE_SAMPLES, len(predictions))
viz_indices = np.random.choice(len(predictions), n_viz, replace=False)

# The legend is identical for every figure, so its handles are built once.
patches = [
    mpatches.Patch(color=np.array(rgb) / 255., label=label)
    for rgb, label in zip(TestConfig.CLASS_COLORS, TestConfig.CLASS_NAMES)
]

for idx in tqdm(viz_indices, desc="Creating Overlays"):
    orig_img = original_images[idx]
    height, width = orig_img.shape[:2]
    # Nearest-neighbour keeps class ids intact when scaling to the source size.
    pred_resized = cv2.resize(predictions[idx].astype(np.uint8), (width, height),
                              interpolation=cv2.INTER_NEAREST)
    mask_rgb = mask_to_rgb(pred_resized)
    blended = cv2.addWeighted(orig_img, 0.6, mask_rgb, 0.4, 0)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (orig_img, mask_rgb, blended)
    titles = ("Original", "Prediction", "Overlay")
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(panel)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

    fig.legend(handles=patches, loc='lower center', ncol=5, fontsize=10)
    plt.savefig(f"{TestConfig.OUTPUT_DIR}/visualizations/viz_{filenames[idx]}", bbox_inches='tight')
    plt.close()

# 2. Summary Bar Chart
print("\nCreating Summary Chart...")

# Predicted-pixel histogram over the test set, recomputed here so this cell
# does not depend on state left behind by the inference cell.
class_pixel_counts = {i: 0 for i in range(TestConfig.NUM_CLASSES)}
total_pixels = 0
for pred in predictions:
    counts = np.bincount(pred.ravel(), minlength=TestConfig.NUM_CLASSES)
    for cls in range(TestConfig.NUM_CLASSES):
        class_pixel_counts[cls] += int(counts[cls])
    total_pixels += int(pred.size)

# Guard against an empty prediction set instead of dividing by zero.
if total_pixels:
    percentages = [(class_pixel_counts[i] / total_pixels) * 100 for i in range(TestConfig.NUM_CLASSES)]
else:
    percentages = [0.0] * TestConfig.NUM_CLASSES
colors_normalized = [np.array(color) / 255. for color in TestConfig.CLASS_COLORS]

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(TestConfig.CLASS_NAMES, percentages, color=colors_normalized, edgecolor='black')
ax.set_ylabel('Percentage of Pixels (%)')
ax.set_title('Class Distribution in Test Set Predictions')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

for bar, pct in zip(bars, percentages):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(), f'{pct:.1f}%', ha='center', va='bottom')

fig.tight_layout()
fig.savefig(f"{TestConfig.OUTPUT_DIR}/class_distribution_summary.png", dpi=150)
plt.close(fig)


GENERATING VISUALIZATIONS



Creating Overlays: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10/10 [00:05<00:00,  1.80it/s]



Creating Summary Chart...


In [None]:
# ==========================================
# 12. ZIP & DOWNLOAD
# ==========================================
banner = '=' * 70
print(f"\n{banner}")
print("ZIPPING RESULTS")
print(f"{banner}\n")
# Writes /content/test_results.zip containing everything under OUTPUT_DIR
# (make_archive appends the .zip extension to base_name).
shutil.make_archive(base_name="/content/test_results", format='zip', root_dir=TestConfig.OUTPUT_DIR)
print(f"‚úì Done! Download 'test_results.zip' from the files tab.")


ZIPPING RESULTS

‚úì Done! Download 'test_results.zip' from the files tab.
