# In [None]:
import os
import random
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pycocotools import mask as coco_mask
from transformers import SamModel, SamProcessor, pipeline
from raffm import RaFFM
import torch
import gc

# Load the pretrained "facebook/sam-vit-huge" SAM checkpoint.
model = SamModel.from_pretrained("facebook/sam-vit-huge")
# Configuration handed to RaFFM below.
# NOTE(review): presumably each key lists the candidate widths RaFFM may pick
# for that part of the network — confirm against the RaFFM documentation.
elastic_config = {
    "atten_out_space": [1280],
    "inter_hidden_space": [2048],
    "residual_hidden_space": [2048],
}
# The model is moved to "cuda" before being wrapped, so a CUDA device is required.
raffm_model = RaFFM(model.to("cuda"), elastic_config=elastic_config)
# Only `submodel` is used in the code visible here; `params` and `config` are kept
# for inspection.
submodel, params, config = raffm_model.random_resource_aware_model()

# Build the "mask-generation" pipeline around the sampled sub-model, reusing the
# image processor loaded from the same checkpoint name.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
generator = pipeline("mask-generation", model=submodel, device="cuda", image_processor=processor.image_processor)

# Helper functions
def random_color():
    """Return a random RGB colour as a list of three ints, each in [0, 255]."""
    channels = []
    for _ in ("red", "green", "blue"):
        channels.append(random.randint(0, 255))
    return channels

def display_ground_truth_masks(mask_path, image):
    """Paint every annotated mask in a random colour onto a white canvas.

    Args:
        mask_path: Path to a JSON file with an ``'annotations'`` list; each
            entry's ``'segmentation'`` value is passed to ``coco_mask.decode``.
        image: Array whose shape and dtype the canvas copies (the canvas is
            indexed as ``[:, :, channel]`` for three channels).

    Returns:
        A new array shaped like ``image``, filled with 255 except where a
        decoded mask equals 1; later annotations overwrite earlier ones.
    """
    canvas = np.full_like(image, 255)
    with open(mask_path, 'r') as handle:
        payload = json.load(handle)
    for entry in payload['annotations']:
        decoded = coco_mask.decode(entry['segmentation'])
        rgb = random_color()
        inside = decoded == 1
        for channel in range(3):
            previous = canvas[:, :, channel]
            canvas[:, :, channel] = np.where(inside, rgb[channel], previous)
    return canvas

def show_mask_on_white_background(mask, ax):
    """Draw a mask in a random colour on a white canvas on the given axes.

    Args:
        mask: 2-D array (bool or numeric) where values greater than zero mark
            foreground pixels, or ``None`` when there is nothing to draw.
        ax: Axes-like object providing ``imshow``, ``text``, ``set_xticks``,
            ``set_yticks`` and ``transAxes``.

    When ``mask`` is ``None`` a centred 'No mask detected' label is drawn
    instead. Axis ticks are removed in both cases.
    """
    if mask is not None:
        mask = np.asarray(mask)
        h, w = mask.shape
        white_background = np.full((h, w, 3), 255, dtype=np.uint8)
        # Random RGB colour in [0, 255), truncated to uint8.
        color = (np.random.random(3) * 255).astype(np.uint8)
        # Select the foreground from the mask itself. Testing the coloured
        # image's red channel would make the result depend on the random
        # colour, and multiplying mask values into the colour breaks masks
        # stored as 0/255.
        foreground = mask > 0
        white_background[foreground] = color
        ax.imshow(white_background)
    else:
        ax.text(0.5, 0.5, 'No mask detected', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
    ax.set_xticks([])
    ax.set_yticks([])

def get_image_info(dataset_directory, num_images=3):
    """Pick a random sample of (image, mask) path pairs from a directory.

    A pair is formed for every ``*.jpg`` file that has a sibling file with the
    same name but a ``.json`` extension.

    Args:
        dataset_directory: Directory to scan (not recursive).
        num_images: Maximum number of pairs to return.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, at most ``num_images``
        long, drawn without replacement via ``random.sample``.
    """
    image_suffix = ".jpg"
    image_mask_pairs = []
    # Sorted because os.listdir order is arbitrary; this makes the sample
    # reproducible for a given random seed.
    for filename in sorted(os.listdir(dataset_directory)):
        if not filename.endswith(image_suffix):
            continue
        # Swap only the trailing extension. str.replace would also rewrite
        # any ".jpg" inside the stem (e.g. "a.jpg.b.jpg").
        mask_filename = filename[:-len(image_suffix)] + ".json"
        mask_path = os.path.join(dataset_directory, mask_filename)
        if os.path.exists(mask_path):
            image_path = os.path.join(dataset_directory, filename)
            image_mask_pairs.append((image_path, mask_path))
    sample_size = min(num_images, len(image_mask_pairs))
    return random.sample(image_mask_pairs, sample_size)

def get_ground_truth_masks(mask_path):
    """Decode every annotation in a mask JSON file.

    Args:
        mask_path: Path to a JSON file with an ``'annotations'`` list; each
            entry's ``'segmentation'`` value is passed to ``coco_mask.decode``.

    Returns:
        A list with one decoded mask per annotation, in file order.
    """
    with open(mask_path, 'r') as handle:
        payload = json.load(handle)
    return [
        coco_mask.decode(entry['segmentation'])
        for entry in payload['annotations']
    ]

def calculate_metrics(pred_mask, gt_mask):
    """Compute IoU, recall, precision and F1 between two masks.

    Any non-zero value counts as foreground, so bool, 0/1 and 0/255 masks all
    give the same result.

    Args:
        pred_mask: Predicted mask (array-like).
        gt_mask: Ground-truth mask (array-like), same shape as ``pred_mask``.

    Returns:
        Tuple ``(iou, recall, precision, f1)``. A metric whose denominator is
        zero is reported as 0.
    """
    # Binarise once so every count below uses the same notion of foreground.
    # Summing raw values would inflate FP/FN for masks stored as 0/255.
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)

    tp = int(np.count_nonzero(np.logical_and(pred, gt)))  # True Positives
    union = int(np.count_nonzero(np.logical_or(pred, gt)))
    fp = int(np.count_nonzero(pred)) - tp  # False Positives
    fn = int(np.count_nonzero(gt)) - tp  # False Negatives

    iou = tp / union if union != 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) != 0 else 0.0
    if (precision + recall) != 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    return iou, recall, precision, f1

# Main script: for a few randomly chosen images, match every ground-truth mask
# with its best-overlapping predicted mask and report the resulting IoUs.
dataset_directory = "SA1B"
selected_images = get_image_info(dataset_directory, 3)

all_ious = []  # one mean best-match IoU per evaluated image

for image_path, mask_path in selected_images:
    image_name = os.path.basename(image_path)
    # Context manager so the underlying file handle is closed promptly;
    # convert() returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    raw_image = np.array(image)

    ground_truth_masks = get_ground_truth_masks(mask_path)
    if not ground_truth_masks:
        # Nothing to score; averaging an empty list would divide by zero.
        print(f'Image: {image_name}, no ground truth masks found, skipping')
        continue

    outputs = generator(image, points_per_batch=64)
    predicted_masks = outputs["masks"]
    max_ious = []

    for gt_mask in ground_truth_masks:
        max_iou = 0
        max_pred = None

        # Keep the prediction with the highest strictly positive IoU.
        for pred_mask in predicted_masks:
            iou = calculate_metrics(pred_mask, gt_mask)[0]
            if iou > max_iou:
                max_pred = pred_mask
                max_iou = iou

        max_ious.append(max_iou)

        # Only plot ground-truth masks that overlap at least one prediction.
        if max_iou:

            print(f'Image: {image_name}, IoU: {max_iou:.4f}')

            # Display the image, current ground truth mask, and corresponding predicted mask
            plt.figure(figsize=(18, 6))

            plt.subplot(1, 3, 1)
            plt.imshow(raw_image)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            show_mask_on_white_background(gt_mask, plt.gca())
            plt.title('Current Ground Truth Mask')
            plt.axis('off')

            plt.subplot(1, 3, 3)
            show_mask_on_white_background(max_pred, plt.gca())
            plt.title('Corresponding Predicted Mask')
            plt.axis('off')

            plt.show()
    image_iou = sum(max_ious)/len(max_ious)
    print(f'Image: {image_name}, Ave IoU: {image_iou:.4f}')
    all_ious.append(image_iou)

# Mean of the per-image averages; guarded because the dataset directory may
# yield no usable image/mask pairs.
if all_ious:
    total_iou = sum(all_ious)/len(all_ious)
    print(f'Total IoU: {total_iou:.4f}')
else:
    total_iou = None
    print('Total IoU: n/a (no images were evaluated)')