In [None]:
import numpy as np
import torch
import json
import os
import matplotlib.pyplot as plt
import cv2
import random
import torch.nn as nn

# Generate all prompts (masks, bounding boxes, foreground/background points)

In [None]:
# Calculate bounding box
def generate_mask_calculate_bounding_box(points, mask_image_resized):
    """Return a polygon's bounding box and a 256x256 prompt mask for it.

    Args:
        points: Sequence of ``(x, y)`` polygon vertices, in the pixel
            coordinates of ``mask_image_resized``. Float coordinates are
            truncated to integers.
        mask_image_resized: Array with the coarse mask; pixels greater than
            zero count as foreground.

    Returns:
        ``([min_x, min_y, max_x, max_y], mask)``. ``mask`` is an ``int16``
        array resized to 256x256 with nearest-neighbour interpolation. It
        holds 50 where the coarse mask is positive inside the bounding box
        and -50 everywhere else.
    """
    # Slicing below needs integers; annotation tools may store float vertices
    x_coords = [int(p[0]) for p in points]
    y_coords = [int(p[1]) for p in points]
    min_x, max_x = min(x_coords), max(x_coords)
    min_y, max_y = min(y_coords), max(y_coords)

    source = np.asarray(mask_image_resized)
    # Use a signed dtype explicitly: -50 cannot be stored in uint8, and
    # relying on implicit promotion is not portable across NumPy versions
    mask = np.full(source.shape, -50, dtype=np.int16)
    inside = source[min_y:max_y, min_x:max_x] > 0
    # The slice is a view, so the boolean assignment writes into ``mask``
    mask[min_y:max_y, min_x:max_x][inside] = 50
    mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
    return [min_x, min_y, max_x, max_y], mask

def calculate_weighted_centroid(mask):
    """Return the weighted centroid of the largest foreground component.

    Args:
        mask: 2-D ``uint8`` array; nonzero pixels are foreground.

    Returns:
        The result of ``calculate_single_mask_centroid`` for the largest
        4-connected component (``([x, y], area)`` or ``None``), or ``None``
        when the mask is empty or has no foreground at all. The component is
        always passed on as a 0/1 mask, so ``area`` is a pixel count no matter
        which nonzero value the input used for foreground.
    """
    if mask.size == 0:
        return None

    # Use connected components analysis to find labels for all masks
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=4)

    # Label 0 is the background; with fewer than two labels there is no
    # foreground and the argmax below would run on an empty array
    if num_labels < 2:
        return None

    # Pick the largest component (with a single component this is label 1)
    max_label = int(np.argmax(stats[1:, cv2.CC_STAT_AREA])) + 1
    largest_mask = (labels == max_label).astype(np.uint8)

    return calculate_single_mask_centroid(largest_mask)

def calculate_single_mask_centroid(mask):
    """Compute a centroid weighted by squared distance to the mask border.

    Args:
        mask: 2-D ``uint8`` array; nonzero pixels are foreground.

    Returns:
        ``([x, y], area)`` where ``x``/``y`` are the truncated weighted
        centroid coordinates and ``area`` is ``mask.sum()``, or ``None`` when
        the total weight is zero, infinite or NaN.
    """
    area = mask.sum()
    # Per-pixel weight: squared L2 distance reported by the distance transform
    weights = np.square(cv2.distanceTransform(mask, cv2.DIST_L2, 5))

    # Select the foreground once and reuse it for coordinates and weights
    foreground = mask > 0
    rows, cols = np.where(foreground)
    fg_weights = weights[foreground]

    total_weight = np.sum(fg_weights)
    unusable = total_weight == 0 or np.isinf(total_weight) or np.isnan(total_weight)
    if unusable:
        return None

    centroid_x = np.sum(cols * fg_weights) / total_weight
    centroid_y = np.sum(rows * fg_weights) / total_weight
    return [int(centroid_x), int(centroid_y)], area

def calculate_foreground_centroid(bbox, polygon_points, original_height, original_width):
    """Pick a point inside the polygon to use as the foreground prompt.

    The polygon is rasterised, cropped to ``bbox``, and its weighted centroid
    is returned if that pixel lies on the polygon. Otherwise the crop is
    intersected with four corner triangles of the box, and the valid centroid
    with the largest reported area is returned.

    Args:
        bbox: ``(x1, y1, x2, y2)`` top-left / bottom-right corners.
        polygon_points: Polygon vertices in original image coordinates.
        original_height: Height of the original image.
        original_width: Width of the original image.

    Returns:
        A point in original image coordinates (a tuple from the direct path,
        a list from the triangle path), or ``None`` if no centroid qualifies.
    """
    canvas = np.zeros((original_height, original_width), dtype=np.uint8)
    cv2.fillPoly(canvas, np.array([polygon_points], dtype=np.int32), 255)
    x1, y1, x2, y2 = (int(v) for v in bbox)
    w, h = x2 - x1, y2 - y1
    cropped_mask = canvas[y1:y2, x1:x2]

    # First choice: centroid of the whole cropped polygon, if it is on the polygon
    whole = calculate_weighted_centroid(cropped_mask)
    if whole and cropped_mask[whole[0][1], whole[0][0]] > 0:
        (cx, cy), _ = whole
        return cx + x1, cy + y1

    # Fallback: look for a centroid inside each corner triangle of the box
    corner_sets = (
        [[0, 0], [w, 0], [0, h]],
        [[w, 0], [w, h], [0, h]],
        [[0, 0], [w, 0], [w, h]],
        [[0, 0], [0, h], [w, h]],
    )
    candidates = []
    for corners in corner_sets:
        tri_mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillConvexPoly(tri_mask, np.array(corners, dtype=np.int32), 1)
        region = cropped_mask * tri_mask
        found = calculate_weighted_centroid(region)
        if found and region[found[0][1], found[0][0]] > 0:
            (cx, cy), region_area = found
            # Shift from crop coordinates back to original image coordinates
            candidates.append(([cx + x1, cy + y1], region_area))

    if not candidates:
        # If no suitable centroid is found, return None
        return None
    # Keep the candidate whose region reported the largest area
    return max(candidates, key=lambda item: item[1])[0]

def calculate_background_centroid(bbox, polygon_points, original_height, original_width):
    """Pick a point inside ``bbox`` but outside the polygon (background prompt).

    Background regions of the cropped box that touch one of the four
    (slightly inset) corners are considered first, and the centroid of the
    one with the largest area wins. If none qualifies, a random pixel of the
    largest background region is used instead.

    Args:
        bbox: ``(x1, y1, x2, y2)`` top-left / bottom-right corners.
        polygon_points: Polygon vertices in original image coordinates.
        original_height: Height of the original image.
        original_width: Width of the original image.

    Returns:
        ``(x, y)`` in original image coordinates, or ``None`` when the crop is
        empty or contains no background pixel.
    """
    mask = np.zeros((original_height, original_width), dtype=np.uint8)
    cv2.fillPoly(mask, np.array([polygon_points], dtype=np.int32), 1)
    x1, y1, x2, y2 = (int(v) for v in bbox)
    cropped_mask = mask[y1:y2, x1:x2]
    inverted_mask = 1 - cropped_mask

    crop_h, crop_w = inverted_mask.shape
    if crop_h == 0 or crop_w == 0:
        # Degenerate box: there is nothing to sample from
        return None

    def _to_crop(cx, cy):
        # Convert to crop coordinates and clamp, so thin boxes or boxes
        # clipped by the image border never index outside the crop
        return (min(max(cx - x1, 0), crop_w - 1), min(max(cy - y1, 0), crop_h - 1))

    # Corners moved one pixel inwards and kept inside the image
    adjusted_corners = [
        _to_crop(min(x1 + 1, original_width - 1), min(y1 + 1, original_height - 1)),  # Top left
        _to_crop(max(x2 - 1, 0), min(y1 + 1, original_height - 1)),                   # Top right
        _to_crop(min(x1 + 1, original_width - 1), max(y2 - 1, 0)),                    # Bottom left
        _to_crop(max(x2 - 1, 0), max(y2 - 1, 0)),                                     # Bottom right
    ]

    # Use connected components analysis to find labels for all backgrounds
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(inverted_mask, connectivity=4)

    # Centroids and areas of background regions that contain a corner
    corner_centroids = []
    for cx, cy in adjusted_corners:
        label_at_corner = labels[cy, cx]
        if label_at_corner == 0:
            # Label 0 means this corner lies on the polygon, not on background
            continue
        corner_mask = (labels == label_at_corner).astype(np.uint8)
        result = calculate_single_mask_centroid(corner_mask)
        if result is None:
            continue
        centroid, area = result
        if centroid and inverted_mask[centroid[1], centroid[0]] > 0:
            corner_centroids.append((centroid, area))

    # Choose the centroid of the region with the largest area
    if corner_centroids:
        best = max(corner_centroids, key=lambda item: item[1])[0]
        # Shift from crop coordinates back to original image coordinates
        return (best[0] + x1, best[1] + y1)

    print("Unable to find a background centroid, randomly selecting a point")
    if num_labels < 2:
        # Only label 0 exists: the polygon fills the entire crop
        return None

    # Choose the largest connected background region (label 0 is skipped)
    max_area_label = int(np.argmax(stats[1:, cv2.CC_STAT_AREA])) + 1
    possible_points = np.argwhere(labels == max_area_label)
    if len(possible_points) == 0:
        return None
    # Index explicitly instead of random.choice(ndarray), whose emptiness
    # check is not reliable for arrays on every Python version
    row, col = possible_points[random.randrange(len(possible_points))]
    return (int(col) + x1, int(row) + y1)

# Adjust coordinates
def adjust_coordinates(coordinates, original_width, original_height, new_width, new_height):
    """Rescale a flat ``[x, y, x, y, ...]`` sequence to a new image size.

    Even positions are scaled by ``new_width / original_width`` and odd
    positions by ``new_height / original_height``; every result is truncated
    with ``int``.
    """
    # Index 0 holds the x scale and index 1 the y scale
    scales = (new_width / original_width, new_height / original_height)
    return [int(value * scales[position % 2]) for position, value in enumerate(coordinates)]

def enlarge_bbox_and_adjust_coordinates(bbox, original_width, original_height, new_width, new_height, scale_factor=1.2):
    """Grow a box around its centre, rescale it, and clamp it to the new image.

    Args:
        bbox: ``[x1, y1, x2, y2]`` in original image coordinates.
        original_width: Width the box coordinates refer to.
        original_height: Height the box coordinates refer to.
        new_width: Target image width.
        new_height: Target image height.
        scale_factor: Multiplier applied to each coordinate's offset from the
            box centre.

    Returns:
        The enlarged box in target coordinates, each value clamped to
        ``[0, new_width - 1]`` (x) or ``[0, new_height - 1]`` (y).
    """
    # Centre of the box: index 0 is x, index 1 is y
    centre = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)

    # Push every coordinate away from the centre and truncate to int
    enlarged_bbox = [
        int((coord - centre[i % 2]) * scale_factor + centre[i % 2])
        for i, coord in enumerate(bbox)
    ]

    # Scale coordinates proportionally
    adjusted_bbox = adjust_coordinates(enlarged_bbox, original_width, original_height, new_width, new_height)

    # Ensure the bounding box does not exceed the new image boundaries
    upper = (new_width - 1, new_height - 1)
    for i in range(4):
        adjusted_bbox[i] = max(0, min(adjusted_bbox[i], upper[i % 2]))

    return adjusted_bbox

def process_json_file(json_path, mask_image):
    """Build the per-shape prompts for one annotation file.

    Args:
        json_path: Path to a JSON file providing ``imageWidth``,
            ``imageHeight`` and a ``shapes`` list whose items have ``points``.
        mask_image: Coarse mask image; it is resized to the annotated image
            size before use.

    Returns:
        ``(masks, bounding_boxes, foreground_points, background_points)``,
        four lists aligned by index. Boxes and points are expressed in a
        1024x1024 coordinate frame. Shapes for which no foreground or
        background point could be found are left out of all four lists.
    """
    with open(json_path, 'r') as json_file:
        data = json.load(json_file)

    original_width, original_height = data['imageWidth'], data['imageHeight']
    new_width, new_height = 1024, 1024
    mask_image_resized = cv2.resize(mask_image, (original_width, original_height), interpolation=cv2.INTER_NEAREST)
    bounding_boxes = []
    sampled_foreground_points = []
    sampled_background_points = []
    masks = []

    for shape in data['shapes']:
        points = shape['points']
        bbox, mask = generate_mask_calculate_bounding_box(points, mask_image_resized)
        foreground_point = calculate_foreground_centroid(bbox, points, original_height, original_width)
        background_point = calculate_background_centroid(bbox, points, original_height, original_width)
        if foreground_point is None or background_point is None:
            # The centroid helpers may return None; rescaling None would
            # raise, so drop the shape and keep the four lists aligned
            print(f"Skipping a shape in {json_path}: no usable foreground/background point")
            continue
        adjusted_bbox = enlarge_bbox_and_adjust_coordinates(bbox, original_width, original_height, new_width, new_height)
        adjusted_foreground_point = adjust_coordinates(foreground_point, original_width, original_height, new_width, new_height)
        adjusted_background_point = adjust_coordinates(background_point, original_width, original_height, new_width, new_height)
        masks.append(mask)
        bounding_boxes.append(adjusted_bbox)
        sampled_foreground_points.append(adjusted_foreground_point)
        sampled_background_points.append(adjusted_background_point)

    return masks, bounding_boxes, sampled_foreground_points, sampled_background_points

# Folders holding the annotation JSON files and the coarse prompt masks
folder_path = 'json'
prompt_masks_path = 'prompt_mask'
file_list = os.listdir(folder_path)
random.shuffle(file_list)
image_data = {}
# Build one image_data entry per JSON annotation file
for filename in file_list:
    if not filename.endswith('.json'):
        continue
    image_name = filename.split('.')[0]
    file_path = os.path.join(folder_path, filename)
    # The prompt mask shares the annotation's base name, with a .jpg extension
    mask_file_path = os.path.join(prompt_masks_path, image_name + '.jpg')
    mask_image = cv2.imread(mask_file_path, cv2.IMREAD_GRAYSCALE)
    if mask_image is None:
        raise FileNotFoundError("Mask image file not found")
    masks, bounding_boxes, foreground_points, background_points = process_json_file(file_path, mask_image)
    image_data[image_name] = {
        'masks': masks,
        'bounding_boxes': bounding_boxes,
        'foreground_points': foreground_points,
        'background_points': background_points,
    }

# Summarise how many prompts were produced for every image
for image_name, data in image_data.items():
    print(f"Image: {image_name}")
    print("Number of masks:", len(data['masks']))
    print("Number of bounding boxes:", len(data['bounding_boxes']))
    print("Number of foreground points:", len(data['foreground_points']))
    print("Number of background points:", len(data['background_points']))

In [None]:
masks_path = 'mask'
images_path = 'datasets/JPEGImages'
for key in image_data.keys():
    image_data[key]['ground_truth_mask'] = []
    image_data[key]['prompt_mask'] = []


def _read_resized(file_path, color_conversion):
    """Read an image, apply ``color_conversion`` and resize it to 1024x1024.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``file_path``.
    """
    raw = cv2.imread(file_path)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f"Unable to read image file: {file_path}")
    converted = cv2.cvtColor(raw, color_conversion)
    return cv2.resize(converted, (1024, 1024), interpolation=cv2.INTER_NEAREST)


# Ground-truth masks: read as BGR, reduced to one grey channel
for filename in os.listdir(masks_path):
    if filename.endswith('.png'):
        image_name = filename.split('.')[0]
        if image_name not in image_data:
            # No annotation entry exists for this file, so there is nowhere to store it
            continue
        file_path = os.path.join(masks_path, filename)
        mask = _read_resized(file_path, cv2.COLOR_BGR2GRAY)
        image_data[image_name]['ground_truth_mask'] = mask

# Prompt masks: same loading, then binarised to 0/1
for filename in os.listdir(prompt_masks_path):
    if filename.endswith('.png') or filename.endswith('.jpg'):
        image_name = filename.split('.')[0]
        if image_name not in image_data:
            continue
        file_path = os.path.join(prompt_masks_path, filename)
        mask = _read_resized(file_path, cv2.COLOR_BGR2GRAY)
        mask[mask > 0] = 1
        image_data[image_name]['prompt_mask'] = mask

# Input images: converted to RGB
for filename in os.listdir(images_path):
    if filename.endswith('.jpg'):
        image_name = filename.split('.')[0]
        if image_name not in image_data:
            continue
        file_path = os.path.join(images_path, filename)
        image = _read_resized(file_path, cv2.COLOR_BGR2RGB)
        image_data[image_name]['image'] = image

In [None]:
# Key used to select the model builder from sam_model_registry
model_type = 'vit_b'
# Checkpoint file passed to the selected builder
checkpoint = 'sam_vit_b_01ec64.pth'
# Device that models and tensors are moved to
device = 'cuda:0'

In [None]:
from segment_anything import SamPredictor, sam_model_registry
# Build the model variant selected by model_type from the checkpoint file
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
# Switch to training mode (the trailing semicolon presumably hides the notebook echo)
sam_model.train();

In [None]:
# Preprocess the images
from collections import defaultdict

import torch

from segment_anything.utils.transforms import ResizeLongestSide

# Built once, outside the loop: it does not depend on the image, and the
# training cell also uses ``transform``, so it must exist even when
# image_data is empty
transform = ResizeLongestSide(sam_model.image_encoder.img_size)

# Per image: the preprocessed tensor plus the sizes needed to undo the resize
transformed_data = defaultdict(dict)
for filename in image_data.keys():
    image_path = f'datasets/JPEGImages/{filename}.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None on failure; fail here with the path
        # instead of with an opaque error inside cv2.resize
        raise FileNotFoundError(f"Image file not found: {image_path}")
    image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=device)
    # HWC -> CHW, then add a leading batch dimension
    transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    input_image = sam_model.preprocess(transformed_image)
    original_image_size = image.shape[:2]
    input_size = tuple(transformed_image.shape[-2:])

    transformed_data[filename]['image'] = input_image
    transformed_data[filename]['input_size'] = input_size
    transformed_data[filename]['original_image_size'] = original_image_size

In [None]:
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
# Learning rate and weight decay handed to Adam below
lr = 1e-4
wd = 0
# Get the parameters for mask_decoder and prompt_encoder
# NOTE(review): these are one-shot iterators; only prompt_encoder_params is
# consumed below, the mask decoder and image encoder ones are left unused
mask_decoder_params = sam_model.mask_decoder.parameters()
prompt_encoder_params = sam_model.prompt_encoder.parameters()
image_encoder_params = sam_model.image_encoder.parameters()
# Combine parameters
# all_params = list(mask_decoder_params) + list(prompt_encoder_params) + list(image_encoder_params)
# Only the prompt encoder's parameters are handed to the optimizer
all_params = list(prompt_encoder_params)
# Create optimizer
optimizer = torch.optim.Adam(all_params, lr=lr, weight_decay=wd)
# Cosine learning-rate schedule over the optimizer defined above
# NOTE(review): T_max=100 equals num_epochs in the training cell — keep them in sync
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

class CombinedDiceCrossEntropyFocalLoss(nn.Module):
    """Weighted sum of Dice, focal and MSE losses for binary segmentation.

    Despite the class name, the third term is a mean-squared error between
    the sigmoid probabilities and the target, not a cross-entropy term.

    Args:
        dice_loss_weight: Weight of the Dice term.
        focal_loss_weight: Weight of the (mean) focal term.
        mse_loss_weight: Weight of the MSE term.
        alpha: Constant factor applied to the focal term.
        gamma: Focusing exponent of the focal term.
    """

    def __init__(self, dice_loss_weight=1/3, focal_loss_weight=1/3, mse_loss_weight=1/3, alpha=0.8, gamma=2.0):
        super().__init__()
        self.dice_loss_weight = dice_loss_weight
        self.focal_loss_weight = focal_loss_weight
        self.mse_loss_weight = mse_loss_weight
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, y_true, smooth=1.0):
        """Return the combined scalar loss.

        Args:
            logits: Raw (pre-sigmoid) predictions.
            y_true: Floating-point target with the same number of elements as
                ``logits``, with values in ``[0, 1]``.
            smooth: Additive smoothing for the Dice ratio.
        """
        # reshape (not view) so non-contiguous tensors are accepted as well
        logits_f = logits.reshape(-1)
        y_true_f = y_true.reshape(-1)
        probs_f = torch.sigmoid(logits_f)

        # Dice loss on the soft probabilities
        intersection = torch.sum(y_true_f * probs_f)
        dice_loss = 1 - (2. * intersection + smooth) / (torch.sum(y_true_f) + torch.sum(probs_f) + smooth)

        # Focal loss: per-element BCE scaled by (1 - pt) ** gamma
        bce_loss = F.binary_cross_entropy_with_logits(logits_f, y_true_f, reduction='none')
        pt = torch.exp(-bce_loss)  # Prevents nans when probability 0
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        # MSE between probabilities and target
        mse_loss = F.mse_loss(probs_f, y_true_f)

        combined_loss = (self.dice_loss_weight * dice_loss +
                         self.focal_loss_weight * torch.mean(focal_loss) +
                         self.mse_loss_weight * mse_loss)
        return combined_loss


loss_fn = CombinedDiceCrossEntropyFocalLoss()

## Train the model

In [None]:
from statistics import mean

from tqdm import tqdm
# NOTE(review): threshold and normalize are imported but not used in this cell
from torch.nn.functional import threshold, normalize
from torch.cuda.amp import GradScaler, autocast

# Initialize GradScaler
scaler = GradScaler()
# Train the model
num_epochs = 100
# One list of per-image losses per epoch
losses = []
best_loss = float('inf')
# train.txt lists one image name per line; names are used as keys into image_data
with open('train.txt', 'r') as file:
    train_filenames = file.read().splitlines()

for epoch in range(num_epochs):
  epoch_losses = []
  for filename in tqdm(train_filenames, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True):
    # Skip if the number of bounding boxes is greater than 150 or if it's empty
    if len(image_data[filename]['bounding_boxes']) > 150 or len(image_data[filename]['bounding_boxes']) == 0:
        continue
    input_image = transformed_data[filename]['image'].to(device)
    input_size = transformed_data[filename]['input_size']
    original_image_size = transformed_data[filename]['original_image_size']
    # Use autocast context manager for forward pass
    with autocast():
        # NOTE(review): the image encoder runs with gradients enabled although
        # only prompt-encoder parameters are given to the optimizer
        image_embedding = sam_model.image_encoder(input_image)
        # Create two empty lists for storing sparse_embeddings and dense_embeddings
        sparse_embeddings_list = []
        dense_embeddings_list = []
        # Encode every prompt (box + two points + coarse mask) of this image separately
        for i in range(len(image_data[filename]['bounding_boxes'])):
          prompt_box = np.array(image_data[filename]['bounding_boxes'][i])
          # `transform` is the ResizeLongestSide instance created in the preprocessing cell
          box = transform.apply_boxes(prompt_box, original_image_size)
          box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
          box_torch = box_torch[None, :]
          mask_torch = torch.from_numpy(image_data[filename]['masks'][i]).type(torch.float32).to(device)
          mask_torch = mask_torch.unsqueeze(0).unsqueeze(0)  # Add a dimension at 0, becomes [1,1, height, width]
          # One foreground and one background point per prompt, labelled 1 and 0 below
          # NOTE(review): unlike the box, the points are not passed through `transform` — confirm this is intended
          point_torch = torch.tensor([[image_data[filename]['foreground_points'][i],image_data[filename]['background_points'][i]]],dtype=torch.float32).to(device)
          type_torch = torch.tensor([[1,0]],dtype=torch.float32).to(device)
          points_torch = [point_torch,type_torch]
          sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=points_torch,
            boxes=box_torch,
            masks=mask_torch,
            )
          # Add the obtained sparse_embeddings and dense_embeddings to their respective lists
          sparse_embeddings_list.append(sparse_embeddings)
          dense_embeddings_list.append(dense_embeddings)
        # Concatenate all elements in sparse_embeddings_list and dense_embeddings_list
        sparse_embeddings_all = torch.cat(sparse_embeddings_list, dim=0)
        dense_embeddings_all = torch.cat(dense_embeddings_list, dim=0)
        # Decode all prompts of the image in a single call
        low_res_masks, iou_predictions = sam_model.mask_decoder(
          image_embeddings=image_embedding,
          image_pe=sam_model.prompt_encoder.get_dense_pe(),
          sparse_prompt_embeddings=sparse_embeddings_all,
          dense_prompt_embeddings=dense_embeddings_all,
          multimask_output=True,
        )
        upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
        # For each prompt, find the index of the candidate mask with the highest predicted score
        max_score_indices = torch.argmax(iou_predictions, dim=1)
        # Initialize an empty tensor for storing selected masks
        selected_masks = torch.zeros_like(upscaled_masks[:, 0, :, :]).to(device)  # one slot per prompt; presumably [num_prompts, H, W] — TODO confirm
        # Select the highest scoring mask channel
        for i, index in enumerate(max_score_indices):
            selected_masks[i] = upscaled_masks[i, index]
        # Merge the per-prompt masks with an element-wise maximum; the result holds logits, not a thresholded mask
        binary_mask, _ = torch.max(selected_masks, dim=0, keepdim=True)
        gt_mask_resized = torch.from_numpy(image_data[filename]['ground_truth_mask']).unsqueeze(0).to(device)
        gt_binary_mask = torch.as_tensor(gt_mask_resized > 0, dtype=torch.float32)
        # NOTE(review): prompt_mask is built here but not used afterwards in this cell
        prompt_mask = torch.from_numpy(image_data[filename]['prompt_mask']).unsqueeze(0).to(device)
        # Calculate loss
        loss = loss_fn(binary_mask, gt_binary_mask)

    # Clear gradients
    optimizer.zero_grad()
    # Scale the loss and perform backpropagation using GradScaler
    scaler.scale(loss).backward()
    # Unscale the gradients and perform optimizer step using GradScaler
    scaler.step(optimizer)
    scaler.update()
    epoch_losses.append(loss.item())
  losses.append(epoch_losses)
  current_lr = optimizer.param_groups[0]['lr']
  # NOTE(review): statistics.mean raises StatisticsError on an empty list, i.e. when every file was skipped
  print(f'EPOCH: {epoch}, Mean loss: {mean(epoch_losses)}, Current LR: {current_lr}')
  # Check and update best loss
  if mean(epoch_losses) < best_loss:
      best_loss = mean(epoch_losses)
      # Save model weights with lowest loss
      torch.save(sam_model.state_dict(), 'sam_vit_b_best_loss.pth')
      print("Saved model with lower loss:", best_loss)
  # Update learning rate at the end of each epoch
  scheduler.step()

In [None]:
# Reduce each epoch's per-image losses to a single average value
mean_losses = list(map(mean, losses))
mean_losses

# Plot the averages against the zero-based epoch index
plt.plot(list(range(len(mean_losses))), mean_losses)
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.title('Mean epoch loss')
plt.show()

## Predict

In [None]:
# val.txt lists one image name per line; names are used as keys into image_data
with open('val.txt', 'r') as file:
    val_filenames = file.read().splitlines()
predict_checkpoint = 'sam_vit_b_best_loss.pth'
# Load the model for prediction with the specified checkpoint
sam_model_predict = sam_model_registry[model_type](checkpoint=predict_checkpoint)
sam_model_predict.to(device)
# Initialize the predictor
predictor = SamPredictor(sam_model_predict)
# cv2.imwrite does not create missing directories; it only reports failure
# through its return value, so make sure the output folder exists
os.makedirs('predict', exist_ok=True)
for filename in val_filenames:
    entry = image_data[filename]
    # Skip if bounding boxes are empty
    if len(entry['bounding_boxes']) == 0:
        continue
    predicted_masks = []
    predictor.set_image(entry['image'])
    # One prediction per prompt: box + foreground/background point + coarse mask
    prompts = zip(entry['bounding_boxes'], entry['masks'],
                  entry['foreground_points'], entry['background_points'])
    for prompt_box, prompt_mask, foreground_point, background_point in prompts:
        masks, score, logits = predictor.predict(
            point_coords=np.array([foreground_point, background_point], dtype=float),
            point_labels=np.array([1, 0]),
            box=np.array(prompt_box),
            mask_input=np.array(prompt_mask, dtype=float)[None, :, :],
            multimask_output=True,
        )
        # Keep the candidate mask with the highest score
        predicted_masks.append(masks[np.argmax(score)])
    # Union of all per-prompt masks. Converting to uint8 explicitly avoids the
    # platform-dependent integer dtype that summing boolean masks produces,
    # which OpenCV may not accept
    predicted_mask = np.any(predicted_masks, axis=0).astype(np.uint8)
    # Save as a binary mask
    # NOTE(review): (1360, 1024) is presumably the dataset's original width x height — confirm
    predicted_mask = cv2.resize(predicted_mask, (1360, 1024), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite(f'predict/{filename}.png', predicted_mask * 255)

# Metric calculation

In [None]:
from PIL import Image
import numpy as np
import os

# Define a function to calculate metrics for a single image
def calculate_metrics(gt_image, pred_image):
    """Compare a predicted binary mask with its ground truth.

    Pixels equal to 255 count as positive and pixels equal to 0 as negative;
    any other value is ignored by every count.

    Args:
        gt_image: Ground-truth mask as a PIL image.
        pred_image: Predicted mask as a PIL image.

    Returns:
        ``(iou, dice, precision, fpr, pa)``. When a ratio has a zero
        denominator a defined value is returned instead of NaN: IoU, Dice and
        precision become 1 (nothing to find and nothing predicted), FPR and
        PA become 0.

    Raises:
        ValueError: If the two arrays still differ in shape after resizing.
    """
    # Resize images to match the smaller one's dimensions. Nearest-neighbour
    # keeps the masks strictly binary; a smoothing filter would create grey
    # values that count as neither positive nor negative.
    if gt_image.size != pred_image.size:
        new_size = (min(gt_image.size[0], pred_image.size[0]), min(gt_image.size[1], pred_image.size[1]))
        gt_image = gt_image.resize(new_size, Image.Resampling.NEAREST)
        pred_image = pred_image.resize(new_size, Image.Resampling.NEAREST)

    # Convert images to grayscale if they are not
    if gt_image.mode != 'L':
        gt_image = gt_image.convert('L')
    if pred_image.mode != 'L':
        pred_image = pred_image.convert('L')

    gt = np.array(gt_image)
    pred = np.array(pred_image)

    # Ensure the dimensions of the images are the same
    if gt.shape != pred.shape:
        raise ValueError("Image dimensions do not match")

    gt_pos, gt_neg = gt == 255, gt == 0
    pred_pos, pred_neg = pred == 255, pred == 0

    # Plain Python ints, so the divisions below never yield NaN silently
    true_positives = int(np.count_nonzero(gt_pos & pred_pos))
    false_positives = int(np.count_nonzero(gt_neg & pred_pos))
    false_negatives = int(np.count_nonzero(gt_pos & pred_neg))
    true_negatives = int(np.count_nonzero(gt_neg & pred_neg))

    def _ratio(numerator, denominator, when_empty):
        # Guarded division: return ``when_empty`` for a zero denominator
        return numerator / denominator if denominator > 0 else when_empty

    # IoU (Intersection over Union)
    iou = _ratio(true_positives, true_positives + false_positives + false_negatives, 1.0)

    # Dice coefficient
    dice = _ratio(2 * true_positives, 2 * true_positives + false_positives + false_negatives, 1.0)

    # Precision
    precision = _ratio(true_positives, true_positives + false_positives, 1)

    # FPR (False Positive Rate)
    fpr = _ratio(false_positives, false_positives + true_negatives, 0.0)

    # PA (pixel accuracy)
    total = true_positives + false_positives + false_negatives + true_negatives
    pa = _ratio(true_positives + true_negatives, total, 0.0)

    return iou, dice, precision, fpr, pa

# Define folder paths
gt_folder = "mask"
pred_folder = "predict"

# Get all image filenames in the folder
gt_files = os.listdir(gt_folder)

# Initialize lists to store metrics
ious = []
dices = []
precisions = []
fprs = []
pas = []

# Iterate over each image file and calculate metrics
for filename in gt_files:
    gt_path = os.path.join(gt_folder, filename)
    pred_path = os.path.join(pred_folder, filename)

    # Skip pairs that cannot be opened (e.g. no prediction exists for this
    # ground-truth file). Only I/O errors are swallowed; anything else
    # propagates.
    try:
        with Image.open(gt_path) as gt_image, Image.open(pred_path) as pred_image:
            iou, dice, precision, fpr, pa = calculate_metrics(gt_image, pred_image)
    except OSError:
        continue

    ious.append(iou)
    dices.append(dice)
    precisions.append(precision)
    fprs.append(fpr)
    pas.append(pa)

# Calculate averages
mean_iou = np.mean(ious)
mean_dice = np.mean(dices)
mean_precision = np.mean(precisions)
mean_fpr = np.mean(fprs)
# Average over all images (previously this used only the last image's value)
mean_pa = np.mean(pas)


# Print the results
print(f"IoU: {mean_iou:.8f}")
print(f"Dice Coefficient: {mean_dice:.8f}")
print(f"Precision: {mean_precision:.8f}")
print(f"False Positive Rate: {mean_fpr:.8f}")
print(f"PA: {mean_pa:.8f}")