In [None]:
!git clone https://github.com/facebookresearch/segment-anything-2 /kaggle/working/segment-anything-2
%cd /kaggle/working/segment-anything-2
!pip install -q -e .

### **To download all the checkpoints** (uncomment the commands in the next cell)

In [None]:
# !wget -O /kaggle/working/segment-anything-2/sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
# !wget -O /kaggle/working/segment-anything-2/sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
# !wget -O /kaggle/working/segment-anything-2/sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
# !wget -O /kaggle/working/segment-anything-2/sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"

In [None]:
import os
import random
import pandas as pd
import cv2
import torch
import torch.nn.utils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

In [None]:
def set_seeds(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global generator and PyTorch
    (CPU, plus all CUDA devices when CUDA is available), and configures cuDNN
    for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode lets cuDNN auto-tune and pick convolution algorithms
        # at run time, which can differ between runs; it has to be disabled
        # for the deterministic setting above to give reproducible results.
        torch.backends.cudnn.benchmark = False

set_seeds()

# **Run this for Non-Augmented Data**

In [None]:
import os
from sklearn.model_selection import train_test_split

# Dataset directories
DATASET_PATH = "/kaggle/input/sample/NOT-AUGMENTED/DATASET_FINAL"
images_dir = os.path.join(DATASET_PATH, "JPEGImages")
masks_dir = os.path.join(DATASET_PATH, "Annotations")


# List all image and mask files
image_files = sorted([f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
mask_files = sorted([f for f in os.listdir(masks_dir) if f.endswith('.png')])

# Map base names (file name without extension) to file names for matching
image_basenames = {os.path.splitext(f)[0]: f for f in image_files}
mask_basenames = {os.path.splitext(f)[0]: f for f in mask_files}

# Pair every image "<name>.<ext>" with the mask "<name>_mask.png"
data = []
for base_name, image_file in image_basenames.items():
    mask_name = base_name + '_mask'  # Add '_mask' to match the mask file
    mask_file = mask_basenames.get(mask_name)
    if mask_file is None:
        print(f"No matching mask for {image_file}")
        continue
    data.append({
        "image": os.path.join(images_dir, image_file),
        "annotation": os.path.join(masks_dir, mask_file)
    })


# Check if matching worked
print(f"Total Pairs Found: {len(data)}")

# Fail early with a clear message instead of letting train_test_split raise an
# obscure error on an empty list.
if len(data) == 0:
    raise ValueError("No matching image-mask pairs found. Please check the filenames and directory structure.")

# Split into train and test sets (80% train, 20% test)
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Display sample data
print(f"Training Samples: {len(train_data)}")
print(f"Testing Samples: {len(test_data)}")

# Draw 20% (at least one pair) from each split.
sampled_train_data = random.sample(train_data, max(1, int(0.2 * len(train_data))))
sampled_test_data = random.sample(test_data, max(1, int(0.2 * len(test_data))))

# Combine into one sample list.
# NOTE(review): data_sample mixes training and test pairs, so any metric
# computed on it is not a pure held-out score.
data_sample = sampled_train_data + sampled_test_data

print(f"Total Combined Sample Size: {len(data_sample)}")

# **Run this for Augmented Data**

In [None]:
# import os
# import re
# from sklearn.model_selection import train_test_split

# # Dataset directories
# DATASET_PATH = "/kaggle/input/sample/AUGMENTED/DATASET_FINAL"
# images_dir = os.path.join(DATASET_PATH, "JPEGImages")
# masks_dir = os.path.join(DATASET_PATH, "Annotations")

# # List all image and mask files
# image_files = sorted([f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
# mask_files = sorted([f for f in os.listdir(masks_dir) if f.endswith('.png')])

# # Function to extract base names for matching (keeping augmentation number)
# def extract_base_name(filename, is_mask=False):
#     """
#     Extracts base name while preserving augmentation numbers (_augX).
#     Ensures correct matching of original and augmented files.
#     """
#     filename = filename.replace(".jpg", "").replace(".png", "").replace(".jpeg", "")

#     if is_mask:
#         filename = filename.replace("_orig_mask", "_orig")  # Convert _orig_mask to _orig for matching
#         filename = re.sub(r"_aug_mask(\d+)", r"_aug\1", filename)  # Convert _aug_maskX to _augX
#     else:
#         filename = filename  # No extra processing needed

#     return filename

# # Create dictionaries mapping base names to file paths
# image_basenames = {extract_base_name(f): f for f in image_files}
# mask_basenames = {extract_base_name(f, is_mask=True): f for f in mask_files}

# # Debug: Show extracted base names
# print("Sample extracted image base names:", list(image_basenames.keys())[:5])
# print("Sample extracted mask base names:", list(mask_basenames.keys())[:5])

# # Match images with corresponding masks
# data = []
# for base in image_basenames:
#     if base in mask_basenames:
#         data.append({
#             "image": os.path.join(images_dir, image_basenames[base]),
#             "annotation": os.path.join(masks_dir, mask_basenames[base])
#         })

# print(f"Total Pairs Found: {len(data)}")

# if len(data) == 0:
#     raise ValueError("No matching image-mask pairs found. Please check the filenames and directory structure.")

# # Split into train and test sets (80% train, 20% test)
# train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# print(f"Training Samples: {len(train_data)}")
# print(f"Testing Samples: {len(test_data)}")

# **Fine Tuning**

In [None]:
def read_batch(data, visualize_data=True):
    """Load one random image/mask pair and sample point prompts from the mask.

    The image and mask are rescaled by a common factor so that the longer side
    becomes 1024 pixels. All non-zero mask labels are merged into one binary
    mask, which is eroded with a 5x5 kernel before points are sampled so the
    prompts stay away from the mask border.

    Args:
        data: Sequence of dicts with "image" and "annotation" file paths.
        visualize_data: When True, plot the image, the binary mask and the
            sampled points.

    Returns:
        Tuple `(Img, binary_mask, points, num_labels)`:
        - Img: RGB image array (resized).
        - binary_mask: uint8 array of shape (1, H, W).
        - points: array of shape (N, 1, 2) holding (x, y) coordinates; empty
          when erosion leaves no foreground pixel.
        - num_labels: number of distinct non-zero values in the mask.
        `(None, None, None, 0)` is returned when either file cannot be read.
    """
    ent = data[np.random.randint(len(data))]
    Img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None on failure, so this check must run before the
    # image is indexed; indexing None would raise TypeError instead.
    if Img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    Img = Img[..., ::-1]  # OpenCV loads BGR; reverse the channel axis to get RGB

    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    # Nearest-neighbour keeps the label values intact while resizing.
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    # Foreground labels are the non-zero values. Filtering on the value (rather
    # than dropping the first unique entry) stays correct when the mask has no
    # background pixel at all.
    inds = np.unique(ann_map)
    inds = inds[inds != 0]
    binary_mask = (ann_map > 0).astype(np.uint8)

    points = []
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)  # rows of (y, x)
    if len(coords) > 0:
        # One random point per label, all drawn from the merged eroded mask.
        for _ in inds:
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([yx[1], yx[0]])
    points = np.array(points)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=25)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    binary_mask = np.expand_dims(binary_mask, axis=-1)
    binary_mask = binary_mask.transpose((2, 0, 1))  # (H, W, 1) -> (1, H, W)
    points = np.expand_dims(points, axis=1)         # (N, 2) -> (N, 1, 2)
    return Img, binary_mask, points, len(inds)

# Draw and display one training sample as a sanity check.
Img1, masks1, points1, num_masks = read_batch(train_data, visualize_data=True)

In [None]:
# Checkpoint file and config name passed to build_sam2 (the file names suggest
# the "tiny" Hiera variant of SAM 2 — confirm against the dataset contents).
sam2_checkpoint = "/kaggle/input/sample/sam2_hiera_tiny.pt"
model_cfg = "sam2_hiera_t.yaml"

# Build the model on the GPU and wrap it in the image predictor used by the
# training, validation and inference cells.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Switch only the mask decoder and the prompt encoder to training mode; no
# other sub-module has its mode changed here.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

In [None]:
# Gradient scaler for mixed-precision (autocast) training.
scaler = torch.amp.GradScaler()
# Total number of training steps; the logging/checkpoint cadence in the
# training and validation functions is derived from it.
NO_OF_STEPS = 3000
# Prefix of the checkpoint file names ("<prefix>_<step>.pt").
FINE_TUNED_MODEL_NAME = "SAM2_FT_Kidney"

# AdamW over every parameter of the wrapped model.
optimizer = torch.optim.AdamW(params=predictor.model.parameters(),
                              lr=0.0005,
                              weight_decay=1e-4)

# Multiply the learning rate by 0.6 after every 2000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.6)
# NOTE(review): no use of accumulation_steps is visible in the training cells
# of this notebook — confirm whether gradient accumulation is still intended.
accumulation_steps = 8

In [None]:
def read_batch(data, visualize_data=True):
    """Pick a random entry of `data`, load it, and build point prompts.

    This cell defines `read_batch` again; whichever definition is executed
    last is the one the training and validation functions call.

    Steps: read image and grayscale mask, rescale both so the longer image
    side is 1024 px, merge all non-zero labels into a single binary mask, erode
    it (5x5 kernel) and sample one random (x, y) point per label from the
    eroded foreground.

    Args:
        data: Sequence of dicts with "image" and "annotation" file paths.
        visualize_data: Plot the image, mask and sampled points when True.

    Returns:
        `(Img, binary_mask, points, num_labels)` where `Img` is the resized RGB
        image, `binary_mask` has shape (1, H, W), `points` has shape (N, 1, 2)
        (empty if the eroded mask has no foreground) and `num_labels` is the
        count of distinct non-zero mask values. If a file cannot be read the
        result is `(None, None, None, 0)`.
    """
    ent = data[np.random.randint(len(data))]
    Img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)

    # Validate the reads before touching the arrays: a failed cv2.imread yields
    # None, and subscripting None would raise TypeError rather than reach this
    # error path.
    if Img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    # BGR (OpenCV order) -> RGB
    Img = Img[..., ::-1]

    # Common scale factor for image and mask.
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    # Select labels by value so that a mask without any zero pixel does not
    # lose its smallest label.
    labels_present = np.unique(ann_map)
    inds = labels_present[labels_present != 0]
    binary_mask = (ann_map > 0).astype(np.uint8)

    # Erode so that sampled prompts do not sit on the mask boundary.
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)  # (y, x) rows

    points = []
    if len(coords) > 0:
        for _ in inds:
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([yx[1], yx[0]])  # store as (x, y)
    points = np.array(points)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=25)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    # (H, W) -> (H, W, 1) -> (1, H, W)
    binary_mask = np.expand_dims(binary_mask, axis=-1)
    binary_mask = binary_mask.transpose((2, 0, 1))
    # (N, 2) -> (N, 1, 2): one point per prompt
    points = np.expand_dims(points, axis=1)
    return Img, binary_mask, points, len(inds)

# Draw and display one training sample as a sanity check.
Img1, masks1, points1, num_masks = read_batch(train_data, visualize_data=True)

# **SGD**

In [None]:
# def train(predictor, train_data, step, mean_iou):
#     global max_mean_iou_in_interval  # Store max IoU across interval
    
#     with torch.amp.autocast(device_type='cuda'):
#         image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)

#         if image is None or mask is None or num_masks == 0:
#             return max_mean_iou_in_interval  # Return max IoU in the interval  

#         input_label = np.ones((num_masks, 1))

#         # Ensure input_point has at least one valid point
#         if input_point is None or input_point.size == 0:
#             print(f"⚠ Step {step}: Skipping due to empty input_point")
#             return max_mean_iou_in_interval  

#         # Ensure correct shape (N,1,2)
#         if input_point.ndim == 2 and input_point.shape[1] == 2:
#             input_point = np.expand_dims(input_point, axis=1)

#         predictor.set_image(image)
#         mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
#             input_point, input_label, box=None, mask_logits=None, normalize_coords=True
#         )

#         sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
#             points=(unnorm_coords, labels), boxes=None, masks=None
#         )

#         batched_mode = unnorm_coords.shape[0] > 1
#         high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
#         low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
#             image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
#             image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
#             sparse_prompt_embeddings=sparse_embeddings,
#             dense_prompt_embeddings=dense_embeddings,
#             multimask_output=True,
#             repeat_image=batched_mode,
#             high_res_features=high_res_features,
#         )

#         prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

#         gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
#         prd_mask = torch.sigmoid(prd_masks[:, 0])

#         seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

#         inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
#         iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

#         score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
#         loss = seg_loss + score_loss * 0.05

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()  # <-- Updated: Adjust learning rate after optimizer step

#         # Update mean IoU
#         mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

#         # Track max mean IoU in the interval
#         if step % (NO_OF_STEPS / 10) == 1:  # Start of interval
#             max_mean_iou_in_interval = mean_iou  # Reset at interval start

#         max_mean_iou_in_interval = max(max_mean_iou_in_interval, mean_iou)  # Update max IoU
        
#         # Save logs for visualization
#         if step % (NO_OF_STEPS / 20) == 1 or step == NO_OF_STEPS:
#             train_logs["step"].append(step)
#             train_logs["loss"].append(seg_loss.item())
#             train_logs["iou"].append(mean_iou)
#             train_logs["lr"].append(optimizer.param_groups[0]["lr"])

#         if step % (NO_OF_STEPS / 20) == 0:
#             current_lr = optimizer.param_groups[0]["lr"]
#             print(f"Step {step}: LR = {current_lr:.6f}, IoU = {mean_iou:.6f}, Loss = {seg_loss:.6f}")

#     return max_mean_iou_in_interval  # Return the maximum IoU in the interval

In [None]:
# def validate(predictor, test_data, step, mean_iou, optimizer, scheduler):
#     global max_mean_iou_in_interval  # Store max IoU across interval
    
#     predictor.model.eval()
#     with torch.amp.autocast(device_type='cuda'):
#         with torch.no_grad():
#             image, mask, input_point, num_masks = read_batch(test_data, visualize_data=False)

#             if image is None or mask is None or num_masks == 0:
#                 print(f"⚠ Step {step}: Skipping due to missing or empty test data")
#                 return max_mean_iou_in_interval  # Return max IoU in the interval  

#             input_label = np.ones((num_masks, 1))

#             if input_point is None or input_point.size == 0:
#                 print(f"⚠ Step {step}: Skipping due to empty input_point")
#                 return max_mean_iou_in_interval  # Return max IoU in the interval  

#             if input_point.ndim == 2 and input_point.shape[1] == 2:
#                 input_point = np.expand_dims(input_point, axis=1)

#             predictor.set_image(image)
#             mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
#                 input_point, input_label, box=None, mask_logits=None, normalize_coords=True
#             )

#             sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
#                 points=(unnorm_coords, labels), boxes=None, masks=None
#             )

#             batched_mode = unnorm_coords.shape[0] > 1
#             high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
#             low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
#                 image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
#                 image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
#                 sparse_prompt_embeddings=sparse_embeddings,
#                 dense_prompt_embeddings=dense_embeddings,
#                 multimask_output=True,
#                 repeat_image=batched_mode,
#                 high_res_features=high_res_features,
#             )

#             prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

#             gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
#             prd_mask = torch.sigmoid(prd_masks[:, 0])

#             seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6)
#                         - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

#             inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
#             iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

#             if step % (NO_OF_STEPS / 2) == 0:
#                 FINE_TUNED_MODEL = FINE_TUNED_MODEL_NAME + "_" + str(step) + ".pt"
#                 torch.save(predictor.model.state_dict(), FINE_TUNED_MODEL)

#             mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

#             # Track max mean IoU in the interval
#             if step % (NO_OF_STEPS / 10) == 1:  # Start of interval
#                 max_mean_iou_in_interval = mean_iou  # Reset at interval start

#             max_mean_iou_in_interval = max(max_mean_iou_in_interval, mean_iou)  # Update max IoU
            
#             writer.add_scalar("Loss/Validation", seg_loss.item(), step)
#             writer.add_scalar("IoU/Validation", mean_iou, step)

#             if step % (NO_OF_STEPS / 10) == 1 or step == NO_OF_STEPS:
#                 valid_logs["step"].append(step)
#                 valid_logs["loss"].append(seg_loss.item())
#                 valid_logs["iou"].append(mean_iou)

#             if step == NO_OF_STEPS:
#                 valid_logs["step"].append(step)
#                 valid_logs["loss"].append(seg_loss.item())
#                 valid_logs["iou"].append(mean_iou)

#             if step % (NO_OF_STEPS / 20) == 0:
#                 print(f"Step {step}: Validation IoU = {mean_iou:.6f}, Validation Loss = {seg_loss:.6f}")

#             scheduler.step()  # Update the learning rate using the scheduler
#             current_lr = optimizer.param_groups[0]["lr"]
#             writer.add_scalar("LR/Validation", current_lr, step)

#     return max_mean_iou_in_interval  # Return the maximum IoU in the interval

# **AdamW**

In [None]:
def train(predictor, train_data, step, mean_iou):
    """Run one fine-tuning step on a single randomly drawn training sample.

    Uses the notebook-level globals `optimizer`, `scheduler`, `scaler`,
    `train_logs` and `NO_OF_STEPS`, and keeps the global
    `max_mean_iou_in_interval` up to date with the highest running-mean IoU
    seen since the start of the current NO_OF_STEPS/10 interval.

    Args:
        predictor: SAM2ImagePredictor wrapping the model being fine-tuned.
        train_data: Sequence of {"image", "annotation"} dicts for `read_batch`.
        step: 1-based step index; controls when logs are recorded and printed.
        mean_iou: Running (exponentially smoothed) IoU from the previous step.

    Returns:
        The updated running-mean IoU. When the sample is skipped (unreadable
        files, empty mask or no point prompt) `mean_iou` is returned unchanged.
        Returning the running mean - not the interval maximum - matters because
        the caller feeds the return value back in as `mean_iou`.
    """
    global max_mean_iou_in_interval  # Store max IoU across interval

    # Re-assert training mode on the tuned heads every step, so the step is not
    # affected by an earlier `model.eval()` call (e.g. during validation).
    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)

    image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)

    if image is None or mask is None or num_masks == 0:
        return mean_iou

    # Ensure input_point has at least one valid point
    if input_point is None or input_point.size == 0:
        print(f"⚠ Step {step}: Skipping due to empty input_point")
        return mean_iou

    input_label = np.ones((num_masks, 1))

    # Ensure correct shape (N,1,2)
    if input_point.ndim == 2 and input_point.shape[1] == 2:
        input_point = np.expand_dims(input_point, axis=1)

    # Only the forward pass and the loss run under autocast.
    with torch.amp.autocast(device_type='cuda'):
        predictor.set_image(image)
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs

        # Pixel-wise binary cross-entropy (epsilon guards log(0)).
        seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

        # IoU between the ground truth and the thresholded prediction.
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

        # Train the decoder's score output to predict the actual IoU.
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05

    # Backward and optimizer step outside autocast, with loss scaling so that
    # small half-precision gradients do not underflow to zero.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

    # Update mean IoU (exponential moving average)
    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

    # Track max mean IoU in the interval; also initialise the tracker the first
    # time it is needed so it can never be read before assignment.
    if step % (NO_OF_STEPS / 10) == 1 or "max_mean_iou_in_interval" not in globals():
        max_mean_iou_in_interval = mean_iou  # Reset at interval start

    max_mean_iou_in_interval = max(max_mean_iou_in_interval, mean_iou)  # Update max IoU

    # Save logs for visualization
    if step % (NO_OF_STEPS / 20) == 1 or step == NO_OF_STEPS:
        train_logs["step"].append(step)
        train_logs["loss"].append(seg_loss.item())
        train_logs["iou"].append(mean_iou)
        train_logs["lr"].append(optimizer.param_groups[0]["lr"])

    if step % (NO_OF_STEPS / 20) == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR = {current_lr:.6f}, IoU = {mean_iou:.6f}, Loss = {seg_loss:.6f}")

    return mean_iou

In [None]:
def validate(predictor, test_data, step, mean_iou):
    """Evaluate the model on one randomly drawn sample without gradients.

    Uses the notebook-level globals `writer`, `valid_logs`, `NO_OF_STEPS` and
    `FINE_TUNED_MODEL_NAME`. A checkpoint "<FINE_TUNED_MODEL_NAME>_<step>.pt"
    is written (relative to the current working directory) whenever `step` is
    a multiple of NO_OF_STEPS/2. The highest running-mean validation IoU of the
    current NO_OF_STEPS/10 interval is kept in the global
    `max_valid_mean_iou_in_interval`, separate from the training tracker.

    Args:
        predictor: SAM2ImagePredictor wrapping the model being evaluated.
        test_data: Sequence of {"image", "annotation"} dicts for `read_batch`.
        step: 1-based step index; controls checkpointing and logging.
        mean_iou: Running (exponentially smoothed) validation IoU so far.

    Returns:
        The updated running-mean IoU, or `mean_iou` unchanged when the sample
        is skipped. The caller feeds the return value back in as `mean_iou`,
        so it must be the running mean and not the interval maximum.
    """
    global max_valid_mean_iou_in_interval  # Max validation IoU in the interval

    # Checkpoint before sampling, so a skipped validation sample can never
    # suppress a scheduled save.
    if step % (NO_OF_STEPS / 2) == 0:
        FINE_TUNED_MODEL = FINE_TUNED_MODEL_NAME + "_" + str(step) + ".pt"
        torch.save(predictor.model.state_dict(), FINE_TUNED_MODEL)

    # Remember each module's mode so it can be restored afterwards; otherwise
    # every later training step would run with the model left in eval mode.
    module_modes = [(module, module.training) for module in predictor.model.modules()]
    predictor.model.eval()
    try:
        with torch.amp.autocast(device_type='cuda'), torch.no_grad():
            image, mask, input_point, num_masks = read_batch(test_data, visualize_data=False)

            if image is None or mask is None or num_masks == 0:
                print(f"⚠ Step {step}: Skipping due to missing or empty test data")
                return mean_iou

            input_label = np.ones((num_masks, 1))

            if input_point is None or input_point.size == 0:
                print(f"⚠ Step {step}: Skipping due to empty input_point")
                return mean_iou

            # Ensure correct shape (N,1,2)
            if input_point.ndim == 2 and input_point.shape[1] == 2:
                input_point = np.expand_dims(input_point, axis=1)

            predictor.set_image(image)
            _, unnorm_coords, labels, _ = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None
            )

            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )

            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs

            # Pixel-wise binary cross-entropy (epsilon guards log(0)).
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6)
                        - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

            # IoU between the ground truth and the thresholded prediction.
            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
            iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

            # Update mean IoU (exponential moving average)
            mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

            # Track max mean IoU in the interval; initialise on first use so
            # the tracker is never read before assignment.
            if step % (NO_OF_STEPS / 10) == 1 or "max_valid_mean_iou_in_interval" not in globals():
                max_valid_mean_iou_in_interval = mean_iou  # Reset at interval start

            max_valid_mean_iou_in_interval = max(max_valid_mean_iou_in_interval, mean_iou)

            writer.add_scalar("Loss/Validation", seg_loss.item(), step)
            writer.add_scalar("IoU/Validation", mean_iou, step)

            # Record one row at the start of each interval and at the final
            # step (the final step is recorded once, not twice).
            if step % (NO_OF_STEPS / 10) == 1 or step == NO_OF_STEPS:
                valid_logs["step"].append(step)
                valid_logs["loss"].append(seg_loss.item())
                valid_logs["iou"].append(mean_iou)

            if step % (NO_OF_STEPS / 20) == 0:
                print(f"Step {step}: Validation IoU = {mean_iou:.6f}, Validation Loss = {seg_loss:.6f}")
    finally:
        # Restore the exact train/eval mode every module had on entry.
        for module, was_training in module_modes:
            module.training = was_training

    return mean_iou


In [None]:
from torch.utils.tensorboard import SummaryWriter
import pandas as pd

# Create a directory for logs and graphs
log_dir = "/kaggle/working/logs"
writer = SummaryWriter(log_dir=log_dir)

# Store values for plotting manually
train_logs = {"step": [], "loss": [], "iou": [], "lr": []}
valid_logs = {"step": [], "loss": [], "iou": []}

# Initialise the interval-max tracker that the step functions access via
# `global`, so a step that is skipped before the tracker's first assignment
# cannot fail with NameError.
max_mean_iou_in_interval = 0

# Running IoU values; each call receives the previous value and returns the
# new one.
train_mean_iou = 0
valid_mean_iou = 0

try:
    # One training step followed by one validation step per iteration.
    for step in range(1, NO_OF_STEPS + 1):
        train_mean_iou = train(predictor, train_data, step, train_mean_iou)
        valid_mean_iou = validate(predictor, test_data, step, valid_mean_iou)
finally:
    # Flush TensorBoard events and keep whatever was logged, even when the
    # loop is interrupted or fails part-way through.
    writer.close()

    # Save logs as CSV files
    pd.DataFrame(train_logs).to_csv("/kaggle/working/train_logs.csv", index=False)
    pd.DataFrame(valid_logs).to_csv("/kaggle/working/valid_logs.csv", index=False)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Read back the logs written by the training cell
train_df = pd.read_csv("/kaggle/working/train_logs.csv")
valid_df = pd.read_csv("/kaggle/working/valid_logs.csv")

# Directory prefix for the saved figures
save_path = "/kaggle/working/"


def _plot_curve(column, label, file_name, y_limits=None):
    """Plot train vs. validation `column` over steps, save it and show it."""
    plt.figure(figsize=(10,5))
    plt.plot(train_df["step"], train_df[column], label="Train " + label)
    plt.plot(valid_df["step"], valid_df[column], label="Validation " + label)
    plt.xlabel("Steps")
    plt.ylabel(label)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.legend()
    plt.title(label + " Curve")
    plt.savefig(save_path + file_name)  # Save the image
    plt.show()


# Loss curve, with the y-axis clipped to [0, 0.1]
_plot_curve("loss", "Loss", "loss_curve.png", y_limits=(0, 0.1))

# IoU curve, default y-axis
_plot_curve("iou", "IoU", "iou_curve.png")

In [None]:
def read_image(image_path, mask_path):
    """Load an image and its grayscale mask, rescaled by a common factor.

    The scale factor maps the longer image side to 1024 pixels; the mask is
    resized with nearest-neighbour interpolation.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the mask file (read as single-channel).

    Returns:
        `(img, mask)`: the RGB image and the resized mask.

    Raises:
        FileNotFoundError: If either file cannot be read by OpenCV.
    """
    bgr = cv2.imread(image_path)
    gray = cv2.imread(mask_path, 0)

    if bgr is None or gray is None:
        raise FileNotFoundError(f"Error reading image/mask at {image_path} or {mask_path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    img_h, img_w = rgb.shape[0], rgb.shape[1]
    scale = min(1024 / img_w, 1024 / img_h)
    resized_img = cv2.resize(rgb, (int(img_w * scale), int(img_h * scale)))

    mask_h, mask_w = gray.shape[0], gray.shape[1]
    resized_mask = cv2.resize(
        gray,
        (int(mask_w * scale), int(mask_h * scale)),
        interpolation=cv2.INTER_NEAREST,
    )

    return resized_img, resized_mask

def get_points(mask, num_points=30):
    """Sample random foreground pixels of `mask` as (x, y) point prompts.

    Args:
        mask: Array whose positive entries mark the foreground.
        num_points: Number of points to draw (with replacement).

    Returns:
        Array of shape (num_points, 1, 2); each point is (column, row).

    Raises:
        ValueError: If the mask has no positive entry.
    """
    foreground = np.argwhere(mask > 0)  # rows of (row, column)
    if len(foreground) == 0:
        raise ValueError("No valid points found in the mask.")

    sampled = []
    for _ in range(num_points):
        row_col = foreground[np.random.randint(len(foreground))]
        sampled.append([row_col[::-1]])  # reverse to (column, row)
    return np.array(sampled)

# Load the fine-tuned model
FINE_TUNED_MODEL_WEIGHTS = "/kaggle/working/segment-anything-2/SAM2_FT_Kidney_3000.pt"
if not os.path.exists(FINE_TUNED_MODEL_WEIGHTS):
    raise FileNotFoundError(f"Model weights not found at {FINE_TUNED_MODEL_WEIGHTS}")

sam2_checkpoint = "/kaggle/input/sample/sam2_hiera_tiny.pt"
model_cfg = "sam2_hiera_t.yaml"

# Rebuild the base model, then overwrite its weights with the fine-tuned ones.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
# The file holds a plain state dict, so restrict unpickling to tensors and
# primitive containers instead of allowing arbitrary pickled objects.
predictor.model.load_state_dict(
    torch.load(FINE_TUNED_MODEL_WEIGHTS, map_location="cuda", weights_only=True)
)

# Randomly select a test image
selected_entry = random.choice(test_data)
image_path, mask_path = selected_entry['image'], selected_entry['annotation']
print(f"Selected Image: {image_path}\nMask Path: {mask_path}")

# Load the image and mask
image, target_mask = read_image(image_path, mask_path)

# Generate random points for input (sampled inside the ground-truth mask)
input_points = get_points(target_mask, num_points=30)

# Perform inference: one positive point per prompt
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )

# Keep the first mask of every prompt and order them by descending score
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]

# Initialize segmentation map
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

# Paint masks from best to worst score; a mask is dropped when more than 15%
# of it overlaps the area that is already painted.
for i, mask in enumerate(sorted_masks):
    mask_area = mask.sum()
    if mask_area == 0:
        # Empty prediction: nothing to paint, and the ratio below would be 0/0.
        continue
    if (mask * occupancy_mask).sum() / mask_area > 0.15:
        continue
    mask_bool = mask.astype(bool)
    mask_bool[occupancy_mask] = False
    seg_map[mask_bool] = i + 1
    occupancy_mask[mask_bool] = True

# Visualization
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Ground Truth Mask')
plt.imshow(target_mask, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Predicted Segmentation Map')
plt.imshow(seg_map, cmap='gray')
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import jaccard_score

def test_model_on_sample(data_sample, predictor, num_points=30):
    """Compute the average binary IoU of the predictor over `data_sample`.

    For every entry, point prompts are sampled inside the ground-truth mask,
    the per-prompt predictions are merged (best score first, skipping masks
    that overlap the merged area by more than 15%) into one binary map, and
    that map is compared with the binarised ground truth.

    Args:
        data_sample: Iterable of dicts with "image" and "annotation" paths.
        predictor: Predictor exposing `set_image` and `predict`.
        num_points: Number of point prompts sampled per image.

    Returns:
        The average IoU over the evaluated samples, or None when every sample
        was skipped. The result is also printed.
    """
    total_iou = 0.0
    valid_samples = 0

    for entry in data_sample:
        image_path, mask_path = entry['image'], entry['annotation']

        # Load image and mask
        image, target_mask = read_image(image_path, mask_path)

        # Skip if mask is empty (get_points raises ValueError in that case)
        try:
            input_points = get_points(target_mask, num_points=num_points)
        except ValueError:
            print("Skipped (empty mask)")
            continue

        # Predict
        with torch.no_grad():
            predictor.set_image(image)
            masks, scores, logits = predictor.predict(
                point_coords=input_points,
                point_labels=np.ones([input_points.shape[0], 1])
            )

        # First mask of every prompt, ordered by descending score
        np_masks = np.array(masks[:, 0])
        np_scores = scores[:, 0]
        sorted_masks = np_masks[np.argsort(np_scores)][::-1]

        seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
        occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

        for mask in sorted_masks:
            mask_area = mask.sum()
            if mask_area == 0:
                # Empty prediction contributes nothing; avoids a 0/0 ratio.
                continue
            if (mask * occupancy_mask).sum() / mask_area > 0.15:
                continue
            mask_bool = mask.astype(bool)
            mask_bool[occupancy_mask] = False
            seg_map[mask_bool] = 1  # For binary IoU
            occupancy_mask[mask_bool] = True

        # Resize target_mask to match seg_map if needed
        if seg_map.shape != target_mask.shape:
            target_mask = cv2.resize(target_mask, (seg_map.shape[1], seg_map.shape[0]), interpolation=cv2.INTER_NEAREST)

        # Binarize the target mask (mask values > 0 are treated as foreground)
        target_mask_bin = (target_mask > 0).astype(np.uint8)

        # Flatten both masks to compute IoU
        iou = jaccard_score(target_mask_bin.flatten(), seg_map.flatten(), zero_division=0)
        total_iou += iou
        valid_samples += 1

    # Report average IoU
    if valid_samples == 0:
        print("No valid samples to evaluate IoU.")
        return None

    avg_iou = total_iou / valid_samples
    print(f"\n✅ Average IoU over {valid_samples} samples: {avg_iou:.4f}")
    return avg_iou

In [None]:
# Evaluate the predictor on data_sample with the default 30 point prompts per
# image. NOTE(review): data_sample is built from 20% of the training pairs plus
# 20% of the test pairs, so this IoU is not a purely held-out metric.
test_model_on_sample(data_sample, predictor)