In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import torchmetrics # Use torchmetrics for mIoU
from torchmetrics.segmentation import MeanIoU

# --- Configuration ---
# Experiment settings looked up by key throughout the rest of the file.
config = dict(
    dataset_path="dataset",   # IMPORTANT: change this to your dataset root
    num_classes=13,
    batch_size=8,             # adjust to fit your GPU memory
    epochs=30,                # increase for better convergence
    learning_rate=1e-4,
    backbone="vgg19",         # or "vgg16"
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    wandb_project="FCN_Semantic_Segmentation",
    seed=45,
    img_size=(224, 224),      # (height, width) every image and label is resized to
)

# --- Reproducibility ---
def set_seed(seed):
    """Seed every random number generator used in this file.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), then asks cuDNN for deterministic kernels and
    turns off its auto-tuner.

    Args:
        seed: Integer seed handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade some speed for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config["seed"])

# --- Class Names and Color Map ---
# Position in this list is the integer class ID.
CLASS_NAMES = [
    "Unlabeled",
    "Building",
    "Fence",
    "Other",
    "Pedestrian",
    "Pole",
    "Roadline",
    "Road",
    "Sidewalk",
    "Vegetation",
    "Car",
    "Wall",
    "Traffic sign",
]

# One colour per class for visualisation: the 'tab20' colormap resampled to
# num_classes entries. COLOR_MAP holds float RGB triples (alpha dropped),
# COLOR_MAP_UINT8 the same colours scaled to 0-255 integers.
cmap = plt.get_cmap('tab20', config["num_classes"])
COLOR_MAP = []
for class_id in range(config["num_classes"]):
    COLOR_MAP.append(cmap(class_id)[:3])
COLOR_MAP_UINT8 = [
    (int(red * 255), int(green * 255), int(blue * 255))
    for red, green, blue in COLOR_MAP
]

class SegmentationDataset(Dataset):
    """Image / label-mask pairs read from ``<root_dir>/images`` and ``<root_dir>/labels``.

    Files in both folders are listed (hidden files skipped), sorted, and paired
    by position; the constructor verifies that both folders hold the same
    number of files and that each pair shares the same base name.

    Args:
        root_dir: Folder containing the ``images`` and ``labels`` sub-folders.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the PIL label image
            before it is converted to a tensor.

    Raises:
        AssertionError: If the image and label folders do not pair up.
    """

    def __init__(self, root_dir, transform=None, target_transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'images')
        self.label_dir = os.path.join(root_dir, 'labels')
        # Filter out potential hidden files like .DS_Store
        self.image_filenames = sorted(f for f in os.listdir(self.image_dir) if not f.startswith('.'))
        self.label_filenames = sorted(f for f in os.listdir(self.label_dir) if not f.startswith('.'))
        self.transform = transform
        self.target_transform = target_transform

        # Raised explicitly (same exception type as before) so the checks are
        # not stripped when Python runs with -O.
        if len(self.image_filenames) != len(self.label_filenames):
            raise AssertionError(
                f"Mismatch in number of images ({len(self.image_filenames)}) and labels ({len(self.label_filenames)}) in {root_dir}"
            )
        for img_fn, lbl_fn in zip(self.image_filenames, self.label_filenames):
            # Only base names are compared, so 'img1.jpg' pairs with 'img1.png'.
            img_base = os.path.splitext(img_fn)[0]
            lbl_base = os.path.splitext(lbl_fn)[0]
            if img_base != lbl_base:
                raise AssertionError(
                    f"Mismatch filenames (bases): {img_base} (from {img_fn}) and {lbl_base} (from {lbl_fn})"
                )

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``image`` is the (optionally transformed)
            RGB image and ``label`` is a 2-D ``torch.long`` tensor of class IDs
            clamped to ``[0, config["num_classes"] - 1]``.

        Raises:
            IOError: If either file cannot be opened or decoded; the original
                exception is attached as ``__cause__``.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        label_path = os.path.join(self.label_dir, self.label_filenames[idx])

        try:
            # Context managers release the file handles as soon as the pixel
            # data has been converted.
            with Image.open(img_path) as img_file:
                image = img_file.convert('RGB')
            # NOTE(review): convert('L') is a luminance-weighted RGB->grey
            # conversion. This assumes the label files already store the class
            # ID as a single grey value; if the ID lives in one colour channel
            # it would be distorted here -- confirm the label encoding.
            with Image.open(label_path) as label_file:
                label_pil = label_file.convert('L')
        except Exception as e:
            raise IOError(
                f"Could not load image/label at index {idx} "
                f"({self.image_filenames[idx]}, {self.label_filenames[idx]})"
            ) from e

        # Image pipeline (resize / ToTensor / normalise, as supplied by the caller)
        if self.transform:
            image = self.transform(image)

        # Label pipeline operates on the PIL image; without one the label keeps
        # its original size.
        if self.target_transform:
            label_pil = self.target_transform(label_pil)

        # PIL -> int64 array -> LongTensor of shape [H, W]
        label = torch.from_numpy(np.array(label_pil, dtype=np.int64))

        # CrossEntropyLoss rejects targets outside [0, num_classes - 1]; such
        # values are clamped into range rather than reported.
        label = torch.clamp(label, 0, config["num_classes"] - 1)

        return image, label
# --- Transformations ---
# Image pipeline: resize to config["img_size"], convert to a float tensor,
# then normalise with the mean/std used for ImageNet-pretrained torchvision models.
img_transform = transforms.Compose([
    transforms.Resize(config["img_size"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Label pipeline: resize only, with nearest-neighbour interpolation so class IDs
# are never blended. The conversion to a LongTensor happens inside
# SegmentationDataset.__getitem__, so there is no ToTensor step here.
label_transform = transforms.Compose([
    transforms.Resize(config["img_size"], interpolation=transforms.InterpolationMode.NEAREST),
])


# --- Dataset Visualization (Task 2.1) ---
def visualize_dataset_sample(dataset, index=0):
    """Plot one dataset sample: the image, its coloured mask and one binary mask per class.

    Args:
        dataset: Indexable object returning ``(image_tensor, label_tensor)``;
            the image is expected to be normalised with the mean/std undone below.
        index: Which sample to show.

    Raises:
        ValueError: If the label is neither ``[H, W]`` nor ``[1, H, W]``.
    """
    def _panel(position, picture, caption, **imshow_kwargs):
        # Draw one titled, axis-free panel at the given 1-based grid position.
        plt.subplot(rows, cols, position)
        plt.imshow(picture, **imshow_kwargs)
        plt.title(caption, **title_kwargs)
        plt.axis('off')

    image, label_tensor = dataset[index]

    # Undo the normalisation so the image displays with natural colours.
    undo_norm = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    image_display = np.clip(undo_norm(image).permute(1, 2, 0).cpu().numpy(), 0, 1)

    # Accept [H, W] directly, or [1, H, W] after dropping the leading axis.
    label_numpy = label_tensor.cpu().numpy()
    if label_numpy.ndim != 2:
        if not (label_numpy.ndim == 3 and label_numpy.shape[0] == 1):
            raise ValueError(f"Unexpected label shape received in visualize_dataset_sample: {label_numpy.shape}")
        label_numpy = label_numpy.squeeze(0)

    num_classes = config["num_classes"]
    rows = 3
    # Enough columns for every class plus the image and the combined mask
    # (ceiling division).
    cols = -(-(num_classes + 2) // rows)

    plt.figure(figsize=(18, 9))

    # Panel 1: the de-normalised input image
    title_kwargs = {}
    _panel(1, image_display, "Original Image")

    # Panel 2: ground truth with one colour per class
    colored_mask = np.zeros((*label_numpy.shape, 3), dtype=np.uint8)
    for class_id in range(num_classes):
        colored_mask[label_numpy == class_id] = COLOR_MAP_UINT8[class_id]
    _panel(2, colored_mask, "Ground Truth Mask")

    # Panels 3..: white-on-black mask for each class
    title_kwargs = {"fontsize": 8}
    for class_id in range(num_classes):
        binary_mask = (label_numpy == class_id).astype(np.uint8) * 255
        plot_index = class_id + 3
        if plot_index > rows * cols:
            print(f"Warning: Not enough subplots calculated ({rows}x{cols}) to display class {class_id+1}/{num_classes}")
            continue
        _panel(plot_index, binary_mask, f"{CLASS_NAMES[class_id]} (ID: {class_id})", cmap='gray')

    plt.tight_layout()
    plt.show()


# --- FCN Model Definition ---
def get_vgg_backbone(name="vgg16", pretrained=True):
    """Return the convolutional part of a torchvision VGG and its pooling-layer positions.

    Args:
        name: ``"vgg16"`` or ``"vgg19"``.
        pretrained: Load the IMAGENET1K_V1 weights when true, otherwise
            build the network without pretrained weights.

    Returns:
        ``(features, pool_indices)`` where ``features`` is the model's
        ``features`` module (classifier dropped) and ``pool_indices`` maps
        ``'pool3'``/``'pool4'``/``'pool5'`` to layer positions inside it.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """
    if name not in ("vgg16", "vgg19"):
        raise ValueError(f"Unsupported backbone: {name}")

    if name == "vgg16":
        weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        network = models.vgg16(weights=weights)
    else:
        weights = models.VGG19_Weights.IMAGENET1K_V1 if pretrained else None
        network = models.vgg19(weights=weights)

    # Positions of the 3rd, 4th and 5th max-pool layers inside `features`
    # (check against print(model.features) if torchvision's layout changes).
    pool_layout = {
        'vgg16': {'pool3': 16, 'pool4': 23, 'pool5': 30},
        'vgg19': {'pool3': 18, 'pool4': 27, 'pool5': 36},
    }
    return network.features, pool_layout[name]


class FCN(nn.Module):
    """Fully Convolutional Network (FCN-32s / FCN-16s / FCN-8s) on a VGG backbone.

    Class scores are predicted with 1x1 convolutions on the backbone's pool5
    output (and pool4 / pool3 for the skip variants), fused by addition and
    upsampled with transposed convolutions that start as bilinear filters.

    Args:
        backbone_name: Backbone identifier passed to ``get_vgg_backbone``.
        num_classes: Number of output channels (one score map per class).
        variant: ``'FCN-32s'``, ``'FCN-16s'`` or ``'FCN-8s'``.
        pretrained: Passed to ``get_vgg_backbone``.
        freeze_backbone: When true, backbone parameters get
            ``requires_grad = False``; score and upsampling layers stay trainable.

    Raises:
        ValueError: If ``variant`` is not one of the three supported names.
    """

    _VARIANTS = ('FCN-32s', 'FCN-16s', 'FCN-8s')

    def __init__(self, backbone_name, num_classes, variant='FCN-32s', pretrained=True, freeze_backbone=False):
        super().__init__()
        # Reject a bad variant before the backbone is built.
        if variant not in self._VARIANTS:
            raise ValueError(f"Unknown FCN variant: {variant}")

        self.variant = variant
        self.num_classes = num_classes

        self.features, self.pool_indices = get_vgg_backbone(backbone_name, pretrained)

        # Freeze backbone if requested
        if freeze_backbone:
            print("Freezing backbone weights.")
            for param in self.features.parameters():
                param.requires_grad = False

        # Channel counts are read from the conv layers themselves, so no dummy
        # forward pass (and no dependency on a particular input size) is needed.
        pool5_out_channels = self._channels_at(len(self.features) - 1)
        pool4_out_channels = self._channels_at(self.pool_indices['pool4'])
        pool3_out_channels = self._channels_at(self.pool_indices['pool3'])

        # 1x1 convolution that scores each class on the deepest feature map
        self.score_pool5 = nn.Conv2d(pool5_out_channels, num_classes, kernel_size=1)

        if variant == 'FCN-32s':
            # Single 32x upsampling; kernel = 2 * stride and padding = stride / 2
            # so that the layer can be initialised as bilinear interpolation.
            self.upsample32 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)
            self._initialize_weights(self.upsample32)

        elif variant == 'FCN-16s':
            self.score_pool4 = nn.Conv2d(pool4_out_channels, num_classes, kernel_size=1)
            self.upsample2_pool5 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
            self.upsample16_combined = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=32, stride=16, padding=8, bias=False)
            self._initialize_weights(self.upsample2_pool5, self.upsample16_combined)

        else:  # 'FCN-8s'
            self.score_pool4 = nn.Conv2d(pool4_out_channels, num_classes, kernel_size=1)
            self.score_pool3 = nn.Conv2d(pool3_out_channels, num_classes, kernel_size=1)

            self.upsample2_pool5 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
            self.upsample2_pool4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
            self.upsample8_combined = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, padding=4, bias=False)
            self._initialize_weights(self.upsample2_pool5, self.upsample2_pool4, self.upsample8_combined)

    def _channels_at(self, layer_idx):
        """Return the channel count of the feature map produced at ``features[layer_idx]``.

        This is the ``out_channels`` of the last ``nn.Conv2d`` at or before
        that position (pooling and activation layers keep the channel count).
        """
        for layer in reversed(list(self.features)[:layer_idx + 1]):
            if isinstance(layer, nn.Conv2d):
                return layer.out_channels
        raise ValueError(f"No Conv2d layer found at or before backbone index {layer_idx}")

    def _initialize_weights(self, *layers):
        """Set transposed convolutions to bilinear filters and zero the score layers.

        Each ``nn.ConvTranspose2d`` in ``layers`` gets a bilinear kernel on its
        channel diagonal (channel i only feeds channel i) and zeros elsewhere.
        Every score layer that already exists on the module is then reset to
        zero weights and bias.
        """
        for layer in layers:
            if isinstance(layer, nn.ConvTranspose2d):
                # Bilinear initialization
                # From: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
                factor = (layer.kernel_size[0] + 1) // 2
                if layer.kernel_size[0] % 2 == 1:
                    center = factor - 1
                else:
                    center = factor - 0.5
                og = np.ogrid[:layer.kernel_size[0], :layer.kernel_size[1]]
                filt = (1 - abs(og[0] - center) / factor) * \
                       (1 - abs(og[1] - center) / factor)
                weight = np.zeros((layer.in_channels, layer.out_channels,
                                   layer.kernel_size[0], layer.kernel_size[1]),
                                  dtype=np.float32)
                # Paired index lists address only the (i, i) channel diagonal.
                weight[range(layer.in_channels), range(layer.out_channels), :, :] = filt
                layer.weight.data.copy_(torch.from_numpy(weight))
                if layer.bias is not None:
                    nn.init.constant_(layer.bias.data, 0)

        # Initialize score layers with zeros (as suggested in paper)
        for score_name in ('score_pool5', 'score_pool4', 'score_pool3'):
            score_layer = getattr(self, score_name, None)
            if score_layer is not None:
                nn.init.zeros_(score_layer.weight)
                nn.init.zeros_(score_layer.bias)

    def forward(self, x):
        """Compute per-pixel class scores.

        Args:
            x: Input batch of shape ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes, H, W)``.
        """
        input_size = x.shape[-2:]  # H, W
        pool3_idx = self.pool_indices['pool3']
        pool4_idx = self.pool_indices['pool4']
        pool5_idx = self.pool_indices['pool5']

        # Walk the backbone layer by layer, keeping only the feature maps the
        # chosen variant needs.
        pool3_feat, pool4_feat, pool5_feat = None, None, None
        current_feat = x
        for i, layer in enumerate(self.features):
            current_feat = layer(current_feat)
            if i == pool3_idx and self.variant in ['FCN-8s']:
                pool3_feat = current_feat
            elif i == pool4_idx and self.variant in ['FCN-16s', 'FCN-8s']:
                pool4_feat = current_feat
            elif i == pool5_idx:
                pool5_feat = current_feat
                if self.variant == 'FCN-32s':  # nothing else is needed
                    break

        # --- Score and Upsample ---
        score5 = self.score_pool5(pool5_feat)

        if self.variant == 'FCN-32s':
            out = self.upsample32(score5)

        elif self.variant == 'FCN-16s':
            score4 = self.score_pool4(pool4_feat)
            upsampled_score5 = self.upsample2_pool5(score5)

            # pool4 can be one pixel larger than 2 * pool5 for odd sizes
            score4_cropped = self._crop(score4, upsampled_score5.shape[-2:])
            out = self.upsample16_combined(score4_cropped + upsampled_score5)

        else:  # 'FCN-8s'
            score4 = self.score_pool4(pool4_feat)
            score3 = self.score_pool3(pool3_feat)

            upsampled_score5 = self.upsample2_pool5(score5)
            score4_cropped = self._crop(score4, upsampled_score5.shape[-2:])
            combined_pool45 = score4_cropped + upsampled_score5

            upsampled_pool45 = self.upsample2_pool4(combined_pool45)
            score3_cropped = self._crop(score3, upsampled_pool45.shape[-2:])
            out = self.upsample8_combined(score3_cropped + upsampled_pool45)

        return self._fit_to_input(out, input_size)

    def _fit_to_input(self, tensor, target_size):
        """Bring the final score map to exactly ``target_size`` (H, W).

        The pooling stages floor odd sizes, so for inputs whose sides are not
        multiples of 32 the upsampled map comes out smaller than the input.
        That case is resized bilinearly; a map at least as large as the target
        is centre-cropped.
        """
        th, tw = target_size
        H, W = tensor.shape[-2:]
        if H < th or W < tw:
            return nn.functional.interpolate(
                tensor, size=(th, tw), mode='bilinear', align_corners=False
            )
        return self._crop(tensor, target_size)

    def _crop(self, tensor, target_size):
        """Center crop a tensor to target size (H, W).

        Raises:
            ValueError: If the tensor is smaller than the target in either
                dimension (a crop cannot enlarge it).
        """
        _, _, H, W = tensor.shape
        th, tw = target_size
        if H == th and W == tw:
            return tensor
        if H < th or W < tw:
            raise ValueError(f"Cannot crop tensor of size ({H}, {W}) to larger size ({th}, {tw})")
        # Calculate crop start indices
        h_start = (H - th) // 2
        w_start = (W - tw) // 2
        return tensor[:, :, h_start:h_start+th, w_start:w_start+tw]


# --- Training and Validation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, variant, freeze_status):
    """Run one optimisation pass over ``dataloader``.

    Each batch loss is shown on the progress bar and logged to wandb.

    Args:
        model: Network to train (switched to training mode here).
        dataloader: Yields ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch number (printed one-based).
        variant, freeze_status: Labels used in progress and log messages.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``.
    """
    model.train()
    num_batches = len(dataloader)
    batch_losses = []

    progress = tqdm(dataloader, desc=f"Epoch {epoch+1} Train ({variant}, {freeze_status})")
    for images, masks in progress:
        images = images.to(device)
        masks = masks.to(device).long()  # loss expects integer class targets

        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        progress.set_postfix(loss=batch_loss)
        wandb.log({f"Train/Batch Loss ({variant}, {freeze_status})": batch_loss})

    avg_loss = sum(batch_losses, 0.0) / num_batches
    print(f"Epoch {epoch+1} Train Loss ({variant}, {freeze_status}): {avg_loss:.4f}")
    return avg_loss

def validate(model, dataloader, criterion, metric, device, epoch, variant, freeze_status):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Network to evaluate (switched to eval mode here).
        dataloader: Yields ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        metric: Object with ``reset`` / ``update(preds, masks)`` / ``compute``;
            it is reset at the start of every call.
        device: Device the batches are moved to.
        epoch: Zero-based epoch number (printed one-based).
        variant, freeze_status: Labels used in progress and log messages.

    Returns:
        ``(avg_loss, miou)``: the batch-averaged loss and ``metric.compute()``.
    """
    model.eval()
    num_batches = len(dataloader)
    metric.reset()
    batch_losses = []

    progress = tqdm(dataloader, desc=f"Epoch {epoch+1} Val ({variant}, {freeze_status})")
    with torch.no_grad():
        for images, masks in progress:
            images = images.to(device)
            masks = masks.to(device).long()

            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()
            batch_losses.append(batch_loss)

            # Most likely class per pixel, then fold the batch into the metric
            metric.update(outputs.argmax(dim=1), masks)

            progress.set_postfix(loss=batch_loss)

    avg_loss = sum(batch_losses, 0.0) / num_batches
    final_miou = metric.compute()
    print(f"Epoch {epoch+1} Val Loss ({variant}, {freeze_status}): {avg_loss:.4f}, mIoU: {final_miou}")
    return avg_loss, final_miou


# --- Evaluation on Test Set ---
def evaluate_test_set(model, dataloader, metric, device, variant, freeze_status):
    """Compute the metric for ``model`` over every batch of ``dataloader``.

    Args:
        model: Network to evaluate (switched to eval mode here).
        dataloader: Yields ``(images, masks)`` batches.
        metric: Object with ``reset`` / ``update(preds, masks)`` / ``compute``;
            it is reset before the first batch.
        device: Device the batches are moved to.
        variant, freeze_status: Labels used in progress and log messages.

    Returns:
        The value of ``metric.compute()`` after all batches.
    """
    model.eval()
    metric.reset()
    progress = tqdm(dataloader, desc=f"Evaluating Test Set ({variant}, {freeze_status})")

    with torch.no_grad():
        for images, masks in progress:
            images = images.to(device)
            masks = masks.to(device).long()
            # Most likely class per pixel
            predictions = model(images).argmax(dim=1)
            metric.update(predictions, masks)

    final_miou = metric.compute()
    print(f"Test mIoU ({variant}, {freeze_status}): {final_miou}")
    return final_miou

# --- Visualization of Predictions ---
def visualize_predictions(model, dataloader, device, num_samples=5, variant=None, freeze_status=None):
    """Save side-by-side figures of input, ground truth and prediction.

    One PNG per sample is written to the current working directory as
    ``prediction_<variant>_<freeze_status>_sample_<k>.png``; figures are
    closed after saving rather than shown.

    Args:
        model: Network used for prediction (switched to eval mode here).
        dataloader: Yields ``(images, masks)`` batches.
        device: Device the batches are moved to.
        num_samples: Maximum number of samples to render.
        variant, freeze_status: Labels used in titles and file names.
    """
    def _colorize(class_map):
        # Paint every pixel with the colour assigned to its class ID.
        canvas = np.zeros((*class_map.shape, 3), dtype=np.uint8)
        for class_id in range(config["num_classes"]):
            canvas[class_map == class_id] = COLOR_MAP_UINT8[class_id]
        return canvas

    def _panel(position, picture, caption):
        # Draw one titled, axis-free panel in the 1x3 figure.
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')

    model.eval()
    # Inverse of the mean/std normalisation, for display only
    undo_norm = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    samples_shown = 0
    with torch.no_grad():
        for images, masks_gt in dataloader:
            if samples_shown >= num_samples:
                break

            images = images.to(device)
            masks_gt = masks_gt.to(device).long()
            preds = model(images).argmax(dim=1)

            for i in range(images.size(0)):
                if samples_shown >= num_samples:
                    break

                img_display = np.clip(
                    undo_norm(images[i]).permute(1, 2, 0).cpu().numpy(), 0, 1
                )
                gt_colored = _colorize(masks_gt[i].cpu().numpy())
                pred_colored = _colorize(preds[i].cpu().numpy())

                plt.figure(figsize=(12, 4))
                _panel(1, img_display, "Input Image")
                _panel(2, gt_colored, "Ground Truth Mask")
                _panel(3, pred_colored, f"Predicted Mask ({variant}, {freeze_status})")

                plt.suptitle(f"Sample {samples_shown + 1}")
                plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

                save_filename = f"prediction_{variant}_{freeze_status}_sample_{samples_shown+1}.png"
                plt.savefig(save_filename)
                print(f"Saved prediction visualization: {save_filename}")
                plt.close()

                samples_shown += 1




In [None]:

# --- Setup ---
device = config["device"]
print(f"Using device: {device}")

# --- Load Data ---
# The folder <dataset_path>/train supplies both the training and the validation samples.
full_train_dataset = SegmentationDataset(
    root_dir=os.path.join(config["dataset_path"], "train"),
    transform=img_transform,
    target_transform=label_transform,
)

# 80 % of those samples train the model, the remainder validates it
train_ratio = 0.8
n_total = len(full_train_dataset)
n_train = int(n_total * train_ratio)
n_val = n_total - n_train

print(f"Splitting the original training data ({n_total} samples) into:")
print(f"  - Training set: {n_train} samples")
print(f"  - Validation set: {n_val} samples")

# A dedicated, seeded generator keeps the split identical between runs
generator = torch.Generator().manual_seed(config["seed"])
train_subset, val_subset = torch.utils.data.random_split(
    full_train_dataset, [n_train, n_val], generator=generator
)

# Held-out test data lives in <dataset_path>/test
test_dataset = SegmentationDataset(
    root_dir=os.path.join(config["dataset_path"], "test"),
    transform=img_transform,
    target_transform=label_transform,
)

# Same loader settings everywhere; only the training loader shuffles
_loader_kwargs = dict(batch_size=config["batch_size"], num_workers=4, pin_memory=True)
train_loader = DataLoader(train_subset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"DataLoaders created: Train ({len(train_subset)} samples), Val ({len(val_subset)} samples), Test ({len(test_dataset)} samples)")


# --- Visualize Dataset Sample (Task 2.1) ---
print("\nVisualizing one dataset sample and its class masks (from original training set)...")
# Indexing the un-split dataset lets any sample be picked by its original position.
# Widen the range to look at more than the first sample.
for i in range(1):
    visualize_dataset_sample(full_train_dataset, index=i)
    plt.show()  # display any figure that is still open

# --- Training Loop ---
# Trains every FCN variant twice (frozen and fine-tuned backbone), each as its
# own wandb run, then evaluates the best checkpoint of that run on the test set.

variants = ['FCN-32s', 'FCN-16s', 'FCN-8s']
freeze_options = [True, False] # True: Freeze backbone, False: Fine-tune all

# CrossEntropyLoss defaults to ignore_index=-100, so no class ID is skipped:
# class 0 contributes to the loss like every other class.
criterion = nn.CrossEntropyLoss()
# mIoU from torchmetrics on integer class maps; include_background=False leaves
# class 0 out of the score.
metric = MeanIoU(num_classes=config["num_classes"], per_class= False, input_format = 'index', include_background=False).to(device)

for freeze_backbone in freeze_options:
    freeze_status = "Frozen" if freeze_backbone else "Finetuned"
    for variant in variants:
        print(f"\n--- Training {variant} with {freeze_status} Backbone ---")

        # --- Initialize wandb ---
        run = wandb.init(
            project=config["wandb_project"],
            config=config,
            name=f"{variant}-{freeze_status}-{config['backbone']}-e{config['epochs']}-lr{config['learning_rate']}",
            reinit=True # Allows multiple init calls in one script
        )
        wandb.config.update({"variant": variant, "freeze_backbone": freeze_backbone})


        # --- Model, Optimizer ---
        model = FCN(
            backbone_name=config["backbone"],
            num_classes=config["num_classes"],
            variant=variant,
            pretrained=True,
            freeze_backbone=freeze_backbone
        ).to(device)

        # With a frozen backbone only the parameters that still require
        # gradients are handed to the optimizer.
        if freeze_backbone:
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config["learning_rate"])
        else:
            optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])


        best_val_miou = -1.0
        model_save_path = f"best_model_{variant}_{freeze_status}.pth"
        # Tracks whether THIS run wrote the checkpoint, so a file left over
        # from an earlier run is never mistaken for the current best model.
        checkpoint_saved = False

        # --- Epoch Loop ---
        for epoch in range(config["epochs"]):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, variant, freeze_status)
            val_loss, val_miou = validate(model, val_loader, criterion, metric, device, epoch, variant, freeze_status)

            # Reduce the metric to a plain float once; the same number is
            # logged and used for model selection.
            val_miou = val_miou.mean().item() if isinstance(val_miou, torch.Tensor) else val_miou

            wandb.log({
                "Val/Epoch Loss": val_loss,
                "Val/Epoch mIoU": val_miou,
                "Train/Epoch Loss": train_loss,
                "epoch": epoch + 1
            })

            # Save best model based on validation mIoU
            if val_miou > best_val_miou:
                best_val_miou = val_miou
                torch.save(model.state_dict(), model_save_path)
                checkpoint_saved = True
                print(f"Saved best model to {model_save_path} (mIoU: {best_val_miou:.4f})")
                wandb.save(model_save_path) # Save model artifact to wandb


        # --- Final Evaluation and Visualization ---
        print(f"\n--- Evaluating {variant} ({freeze_status}) on Test Set ---")
        # Restore the best weights of this run; map_location keeps the load
        # working when the checkpoint was written on a different device.
        if checkpoint_saved and os.path.exists(model_save_path):
            model.load_state_dict(torch.load(model_save_path, map_location=device))
        else:
            print(f"Warning: No best model checkpoint from this run at {model_save_path}. Evaluating with the last state.")

        test_miou = evaluate_test_set(model, test_loader, metric, device, variant, freeze_status)
        test_miou = test_miou.mean().item() if isinstance(test_miou, torch.Tensor) else test_miou
        wandb.log({f"Test/Final mIoU ({variant}, {freeze_status})": test_miou})
        wandb.summary[f"best_val_miou_{variant}_{freeze_status}"] = best_val_miou
        wandb.summary[f"final_test_miou_{variant}_{freeze_status}"] = test_miou

        print(f"\n--- Visualizing Predictions for {variant} ({freeze_status}) ---")
        visualize_predictions(model, test_loader, device, num_samples=5, variant=variant, freeze_status=freeze_status)

        run.finish() # Finish wandb run for this configuration



### 3. Explanation and Comparison

**Summary of Test Set mIoU Results:**

| Name                                | variant | freeze_backbone | backbone | Runtime (s) | final_test_miou   |
| :---------------------------------- | :------ | :-------------- | :------- | :---------- | :---------------- |
| FCN-8s-Finetuned-vgg19-e30-lr0.0001 | FCN-8s  | false           | vgg19    | 973         | 0.2632            |
| FCN-16s-Finetuned-vgg19-e30-lr0.0001 | FCN-16s | false           | vgg19    | 972         | 0.2553            |
| FCN-32s-Finetuned-vgg19-e30-lr0.0001 | FCN-32s | false           | vgg19    | 980         | 0.2465            |
| FCN-8s-Frozen-vgg19-e30-lr0.0001    | FCN-8s  | true            | vgg19    | 406         | 0.2267            |
| FCN-16s-Frozen-vgg19-e30-lr0.0001   | FCN-16s | true            | vgg19    | 414         | 0.2139            |
| FCN-32s-Frozen-vgg19-e30-lr0.0001   | FCN-32s | true            | vgg19    | 421         | 0.1927            |

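*(Note: mIoU values are rounded to four decimal places for clarity. The metric is configured with `include_background=False` in the training script, so class 0 ("Unlabeled") does not count towards the score.)*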

**Differences between FCN-32s, FCN-16s, and FCN-8s:**

The core difference between these FCN variants lies in how they combine coarse, high-level semantic information from deep layers with finer, spatial information from shallower layers using skip connections before upsampling back to the original image resolution:

*   **FCN-32s:** This variant uses only the output from the *final* pooling layer of the VGG19 backbone (output stride 32). It applies a classifier (1x1 conv) and then performs a single, large 32x upsampling to reach the input image size; in this implementation the upsampling is a transposed convolution that is initialised to bilinear interpolation and then trained. It lacks access to finer details from earlier layers, resulting in coarser segmentation maps.
*   **FCN-16s:** This variant improves upon FCN-32s by incorporating information from an earlier layer. It takes the stride 32 prediction, upsamples it 2x, and fuses it (element-wise addition) with predictions made from the `pool4` layer features (output stride 16). This combined map is then upsampled 16x. By adding `pool4` features, it captures more spatial detail than FCN-32s.
*   **FCN-8s:** This variant further refines the process by adding another skip connection. It takes the stride 16 fused map (from the FCN-16s process before final upsampling), upsamples it 2x, and fuses it with predictions made from the `pool3` layer features (output stride 8). This final fused map is then upsampled 8x. By incorporating information from `pool3`, `pool4`, and the final layer, FCN-8s leverages features at multiple spatial resolutions, enabling potentially more accurate localization and finer segmentation boundaries.

**Segmentation Performance Discussion:**

Based on the `final_test_miou` reported in the table:

*   **Trend within Variants:** For both the *frozen* and *finetuned* scenarios, the performance consistently improves as more skip connections are introduced:
    *   Frozen: FCN-8s (0.2267) > FCN-16s (0.2139) > FCN-32s (0.1927)
    *   Finetuned: FCN-8s (0.2632) > FCN-16s (0.2553) > FCN-32s (0.2465)
*   **Interpretation:** The FCN-8s variant achieves the highest mIoU in both settings, indicating that combining feature maps from earlier layers (`pool3`, `pool4`) with the final layer's semantic information leads to the best segmentation results on this dataset, according to the mIoU metric. The progressive addition of finer spatial details helps refine the coarse predictions from the deeper layers. FCN-32s consistently performs the worst, highlighting the limitation of relying solely on the deepest layer's output.

**Comparison: Frozen vs. Finetuned Backbone:**

Comparing the performance of models with a frozen backbone versus a finetuned backbone reveals significant differences:

*   **Performance:** Finetuning the VGG19 backbone yields substantially better results across all FCN variants compared to keeping the backbone frozen.
    *   FCN-32s sees a relative mIoU improvement of approximately **28%** (0.2465 vs 0.1927).
    *   FCN-16s sees a relative mIoU improvement of approximately **19%** (0.2553 vs 0.2139).
    *   FCN-8s sees a relative mIoU improvement of approximately **16%** (0.2632 vs 0.2267).
*   **Segmentation Quality:** The higher mIoU scores strongly suggest that finetuning leads to superior segmentation quality. By allowing the backbone weights to be updated, the model can better adapt the ImageNet-pretrained features to the specific characteristics and classes of the target semantic segmentation dataset. This likely translates to more accurate pixel classifications and better delineation of object boundaries (although visual inspection of predictions would be needed to confirm this qualitatively).
*   **Training Time:** Finetuning comes at the cost of significantly increased training time. As seen in the `Runtime` column, finetuned models took more than double the time to train compared to their frozen counterparts (e.g., ~970s vs ~410s). This is because gradients must be backpropagated through the entire network, including the large VGG19 backbone, requiring more computation per iteration.

**Conclusion:**

The experimental results indicate that **FCN-8s with a finetuned VGG19 backbone** provides the best performance on this dataset, achieving a final test mIoU of **0.2632**. The inclusion of multi-level skip connections (as in FCN-8s) and the adaptation of backbone features via finetuning are both crucial factors for achieving higher accuracy in semantic segmentation.