# In [None]:
import time
import torch
import numpy as np
import cv2
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import matplotlib.pyplot as plt
import os

# Load the model configuration and checkpoint
sam2_checkpoint = "model_mvd.torch"  # Path to the fine-tuned model weights
model_cfg = "sam2_hiera_l.yaml"  # Path to the model configuration file
# Use the GPU when available, but fall back to CPU instead of crashing on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)  # Load the fine-tuned model
predictor = SAM2ImagePredictor(sam2_model)

# Set the model to evaluation mode
predictor.model.eval()

def inference_without_annotation(image_path):
    """Segment an image using a single center-point prompt.

    The image is converted to RGB and rescaled so its longer side is 1024 px.
    It is then prompted with one point (label 1) at its center, and the
    module-level ``predictor`` decodes the masks.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A tuple ``(masks, inference_time)``:

        - ``masks`` is a NumPy array holding the sigmoid of the first mask
          output, one mask per prompt (a single prompt here). It is sized to
          ``predictor._orig_hw[-1]``, i.e. the resized input image.
        - ``inference_time`` is the wall-clock time in seconds spent in the
          model.
    """
    # Load the image. convert("RGB") handles RGB, RGBA (alpha dropped),
    # grayscale and palette images uniformly. Slicing [..., :3] on a 2-D
    # grayscale array would instead crop the image width.
    with Image.open(image_path) as pil_img:
        Img = np.array(pil_img.convert("RGB"))

    # Resize so the longer side becomes 1024 px
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # Scaling factor
    Img_resized = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))

    # Prepare dummy points for the inference (since no annotations are provided)
    points = [[[Img_resized.shape[1] // 2, Img_resized.shape[0] // 2]]]  # Center point as a simple prompt

    start_time = time.perf_counter()  # Monotonic, high-resolution timer
    # Without no_grad the decoder output would require grad, making
    # .numpy() raise and wasting memory on an autograd graph.
    with torch.no_grad():
        predictor.set_image(Img_resized)  # Compute the image embeddings

        # Prepare inputs for the prompt encoder (mask input and box are unused)
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            np.array(points), np.ones([len(points), 1]), box=None, mask_logits=None, normalize_coords=True)

        # Get embeddings from the prompt encoder
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None)

        # Decode masks using the SAM2 mask decoder
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=unnorm_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        # Upscale the logits to the input resolution, keep the first mask
        # output, and convert logits to probabilities.
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
        prd_masks = torch.sigmoid(prd_masks[:, 0])

    # CUDA kernels run asynchronously; wait for them before stopping the clock
    # so the measured time covers the actual computation.
    if prd_masks.is_cuda:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time

    return prd_masks.cpu().numpy(), inference_time

def visualize_masks(image, masks):
    """Display ``image`` with each entry of ``masks`` drawn over it.

    Args:
        image: Background image array passed to ``imshow``.
        masks: Iterable of arrays, each drawn on top at 50% opacity.
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    # Stack the overlays in order; alpha keeps the underlying image visible.
    for overlay in masks:
        ax.imshow(overlay, alpha=0.5)
    ax.axis('off')
    plt.show()

def test_multiple_images(test_dir, *, extensions=(".jpg", ".jpeg", ".png")):
    """Run center-point inference on every image in ``test_dir`` and plot it.

    Files are processed in sorted order so runs are reproducible. Extension
    matching is case-insensitive, so ``photo.JPG`` is picked up.

    Args:
        test_dir: Directory containing the test images.
        extensions: File suffixes treated as images (compared case-insensitively).

    Returns:
        A list with one dict per processed image, holding the keys
        ``image_path``, ``masks``, ``inference_time`` and ``resized_image``.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    image_paths = [
        os.path.join(test_dir, name)
        for name in sorted(os.listdir(test_dir))
        if name.lower().endswith(suffixes)
    ]

    results = []
    for image_path in image_paths:
        print(f"Processing {image_path}...")
        predicted_masks, inference_time = inference_without_annotation(image_path)

        # Load the image as RGB and resize it (longer side 1024 px) so the
        # overlay matches the mask resolution. The file handle is closed
        # right after decoding.
        with Image.open(image_path) as pil_img:
            image = np.array(pil_img.convert("RGB"))
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
        image_resized = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))

        # Store results for later analysis or visualization
        results.append({
            "image_path": image_path,
            "masks": predicted_masks,
            "inference_time": inference_time,
            "resized_image": image_resized
        })

        # Visualize the results for each image
        visualize_masks(image_resized, predicted_masks)
        print(f"Inference Time: {inference_time:.4f} seconds\n")

    return results

# Example usage. The guard keeps importing this module from triggering a full
# test run; it still executes in Jupyter and when run as a script, because
# __name__ is "__main__" there.
if __name__ == "__main__":
    test_dir = r"segment-anything-2/Dataset/testing/images"  # Replace with your test images directory

    # Perform testing on multiple images
    results = test_multiple_images(test_dir)