# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import cv2
from torch.utils.data import Dataset, DataLoader
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, Normalize
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score
import time

# Dataset class for Cityscapes with recursive file loading
class CityscapesDataset(Dataset):
    """Cityscapes image/mask pairs discovered by recursively walking two directories.

    Images are every ``.png`` found under ``image_dir``. Masks are the files
    under ``mask_dir`` whose names end in ``mask_suffix``. Both lists are
    sorted by full path, and sample ``i`` pairs ``image_files[i]`` with
    ``mask_files[i]``.

    Args:
        image_dir: Root directory of the RGB images (e.g. ``leftImg8bit/val``).
        mask_dir: Root directory of the annotations (e.g. ``gtFine/val``).
        transform: Optional callable applied to the HWC RGB uint8 image.
        target_transform: Optional callable applied to the single-channel
            uint8 mask.
        mask_suffix: Filename suffix that selects ONE annotation type.
            A Cityscapes ``gtFine`` directory holds several PNGs per frame
            (``_gtFine_color.png``, ``_gtFine_instanceIds.png``,
            ``_gtFine_labelIds.png``). Matching every ``.png`` would make
            ``mask_files`` longer than ``image_files`` and misalign the
            pairs. Pass ``".png"`` to reproduce the old match-everything
            discovery.

    Raises:
        ValueError: If the number of images and masks found differ, which
            means index-based pairing would be wrong.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None,
                 mask_suffix="_gtFine_labelIds.png"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        self.mask_suffix = mask_suffix

        self.image_files = self._load_files(self.image_dir)
        self.mask_files = self._load_files(self.mask_dir, suffix=mask_suffix)

        # Pairing is purely positional, so the two lists must line up 1:1.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} images under {image_dir!r} but "
                f"{len(self.mask_files)} masks ending in {mask_suffix!r} under "
                f"{mask_dir!r}; images and masks cannot be paired by index."
            )

    def _load_files(self, dir_path, suffix=".png"):
        """Return the sorted paths of all files below ``dir_path`` ending in ``suffix``."""
        return sorted(
            os.path.join(root, name)
            for root, _dirs, names in os.walk(dir_path)
            for name in names
            if name.endswith(suffix)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, mask)`` after the optional transforms.

        Raises:
            FileNotFoundError: If OpenCV cannot read either file (``cv2.imread``
                signals failure by returning ``None`` rather than raising).
        """
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; reverse the channel axis to get RGB. The copy makes
        # the array contiguous (the reversed view has a negative stride).
        image = image[..., ::-1].copy()

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

# Locations of the Cityscapes validation split: RGB frames and fine annotations.
val_image_dir = "segment-anything-2/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val"
val_mask_dir = "segment-anything-2/cityscapes/gtFine_trainvaltest/gtFine/val"

# Image pipeline: ToTensor (HWC -> CHW float) followed by per-channel
# ImageNet mean/std normalisation. Masks only go through ToTensor.
transform = Compose(
    [
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
target_transform = ToTensor()

# One sample per batch, iterated in the dataset's own order (no shuffling).
val_dataset = CityscapesDataset(
    val_image_dir,
    val_mask_dir,
    transform=transform,
    target_transform=target_transform,
)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Load SAM2 model and fine-tuned checkpoint
sam2_checkpoint = 'sam2_cityscapes_optimized.pth'  # Fine-tuned pruned model checkpoint
model_cfg = 'sam2_hiera_s.yaml'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the architecture without loading weights (ckpt_path=None), then load
# the fine-tuned weights on top of it.
sam2_model = build_sam2(model_cfg, None, device=device)
checkpoint = torch.load(sam2_checkpoint, map_location=device)

# SAM2's own checkpoint loader reads the weights from a top-level "model" key.
# Accept that layout as well as a bare state_dict (keyed by parameter names).
if isinstance(checkpoint, dict) and "model" in checkpoint:
    state_dict = checkpoint["model"]
else:
    state_dict = checkpoint

# strict=False tolerates key mismatches (e.g. from pruning). However, any key
# it skips keeps the value build_sam2 initialised it with, so an unnoticed
# mismatch would evaluate a partly untrained model. Report what was skipped.
incompatible = sam2_model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    print(
        f"Warning: {len(incompatible.missing_keys)} model keys missing from "
        f"{sam2_checkpoint}, e.g. {incompatible.missing_keys[:5]}"
    )
if incompatible.unexpected_keys:
    print(
        f"Warning: {len(incompatible.unexpected_keys)} unexpected keys in "
        f"{sam2_checkpoint}, e.g. {incompatible.unexpected_keys[:5]}"
    )

sam2_model = sam2_model.to(device)
sam2_model.eval()  # Set model to evaluation mode

# Initialize SAM2ImagePredictor
predictor = SAM2ImagePredictor(sam2_model)

# Per-sample accumulators, one entry appended per evaluated batch.
iou_scores, precisions, recalls, f1_scores, latencies = [], [], [], [], []

# (image, predicted mask, ground-truth mask) tuples collected for visualisation.
plot_data = []

# Per-channel ImageNet statistics shaped (1, 3, 1, 1) so they broadcast over
# a batched (N, 3, H, W) image tensor; used to invert the Normalize transform.
mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)

# Evaluation loop


def _denormalize(image):
    """Undo the ImageNet normalisation of a single-image batch.

    Assumes ``image`` is a (1, 3, H, W) tensor on the same device as the
    module-level ``mean``/``std`` tensors (``squeeze(0)`` relies on the batch
    size being 1).

    Returns:
        image_float: (H, W, 3) float array clipped to [0, 1], used for plotting.
        image_uint8: the same image scaled to [0, 255] uint8. This is the pixel
            range ``SAM2ImagePredictor.set_image`` expects.
    """
    image_unnorm = image * std + mean
    image_float = np.clip(image_unnorm.squeeze(0).cpu().numpy().transpose(1, 2, 0), 0, 1)
    image_uint8 = (image_float * 255).round().astype(np.uint8)
    return image_float, image_uint8


def _predict_binary_mask(image_uint8):
    """Embed ``image_uint8`` with SAM2, decode one mask without prompts, and binarise it.

    Returns the first channel of ``sigmoid(logits) > 0.5`` as a uint8 0/1
    array. The mask is at the resolution that ``set_image`` recorded in
    ``predictor._orig_hw``.
    """
    predictor.set_image(image_uint8)

    # No points, boxes or mask prompts: the decoder runs prompt-free.
    sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(None, None, None)

    high_res_features = [
        feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]
    ]

    low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        repeat_image=False,
        multimask_output=False,
        high_res_features=high_res_features,
    )

    pred_mask = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])[0]
    # The blocking .cpu() copy waits for the GPU work that produced pred_mask.
    # A caller timing this function therefore measures the whole forward pass.
    pred_mask_uint8 = (torch.sigmoid(pred_mask).cpu().numpy() > 0.5).astype(np.uint8)
    return pred_mask_uint8[0]


def _binary_metrics(target_binary, pred_binary):
    """Return macro-averaged (IoU, precision, recall, F1) between two 0/1 masks.

    Macro averaging is taken over the labels present in either mask (0 and/or 1).
    """
    y_true = target_binary.ravel()
    y_pred = pred_binary.ravel()
    iou = jaccard_score(y_true, y_pred, average="macro")
    precision = precision_score(y_true, y_pred, average="macro", zero_division=0)
    recall = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    return iou, precision, recall, f1


with torch.no_grad():
    for batch_idx, (image, target) in enumerate(val_dataloader):
        start_time = time.perf_counter()

        # Only the image is needed on the model's device; the target is used
        # on the CPU below.
        image = image.to(device)

        # The dataloader yields ImageNet-normalised tensors. SAM2's
        # set_image, however, expects an RGB HWC image in [0, 255] and applies
        # its own normalisation, so passing the normalised tensor directly
        # would normalise the image twice.
        image_np_plot, image_uint8 = _denormalize(image)
        pred_mask_uint8 = _predict_binary_mask(image_uint8)

        # Latency covers the transfer, de-normalisation and SAM2 inference.
        # It deliberately excludes the sklearn metric computation below.
        latency = time.perf_counter() - start_time

        # INTER_NEAREST keeps the resized prediction strictly 0/1.
        pred_mask_resized = cv2.resize(
            pred_mask_uint8,
            (target.shape[-1], target.shape[-2]),
            interpolation=cv2.INTER_NEAREST,
        )

        # target_transform (ToTensor) scaled the uint8 mask to [0, 1], so this
        # treats raw mask values > 127 as foreground.
        # NOTE(review): confirm this is the intended foreground definition for
        # the mask type the dataset loads.
        target_binary = (target.cpu().numpy() > 0.5).astype(np.uint8)

        iou, precision, recall, f1 = _binary_metrics(target_binary, pred_mask_resized)

        # Store metrics
        iou_scores.append(iou)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        latencies.append(latency)

        print(f"Batch {batch_idx}: IOU={iou:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, Latency={latency:.4f}s")

        # Collect data for plotting for the first 4 images
        if batch_idx < 4:
            target_np = target.squeeze(0).squeeze(0).cpu().numpy()  # H x W
            plot_data.append((image_np_plot, pred_mask_resized, target_np))

# np.mean of an empty list returns NaN with only a RuntimeWarning. Fail loudly
# instead: an empty run almost always means the dataset paths are wrong.
if not iou_scores:
    raise RuntimeError(
        "No validation samples were evaluated; check val_image_dir and val_mask_dir."
    )

# Compute overall metrics
mean_iou = np.mean(iou_scores)
mean_precision = np.mean(precisions)
mean_recall = np.mean(recalls)
mean_f1 = np.mean(f1_scores)
mean_latency = np.mean(latencies)

print(f"Mean IOU: {mean_iou:.4f}")
print(f"Mean Precision: {mean_precision:.4f}")
print(f"Mean Recall: {mean_recall:.4f}")
print(f"Mean F1 Score: {mean_f1:.4f}")
print(f"Mean Latency (s): {mean_latency:.4f}")

# Plot the image, predicted-mask overlay and ground-truth overlay for each
# collected sample.
for i, (image_np_plot, pred_mask_np, target_np) in enumerate(plot_data):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image_np_plot)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Overlay the predicted mask onto the original image
    axes[1].imshow(image_np_plot)
    axes[1].imshow(pred_mask_np, cmap='jet', alpha=0.5)
    axes[1].set_title('Predicted Mask Overlay')
    axes[1].axis('off')

    # Overlay the ground truth mask onto the original image
    axes[2].imshow(image_np_plot)
    axes[2].imshow(target_np, cmap='jet', alpha=0.5)
    axes[2].set_title('Ground Truth Mask Overlay')
    axes[2].axis('off')

    plt.show()