#### This script converts object-detection annotation files from the YOLO bounding-box format to the YOLO mask/polygon (segmentation) format.
#### The conversion is done via Segment Anything Model (SAM).

### 1. Import Libraries

In [1]:
import os
import cv2
import torch
import numpy as np
import sys
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt

# Append path to Segment Anything Model (SAM)
sys.path.append("..")  # NOTE: Adjust this path based on your SAM installation
from segment_anything import sam_model_registry, SamPredictor
# from segment_anything.utils.transforms import ResizeLongestSide


In [2]:
# Define Global Hyperparameters
# EPS is multiplied by each contour's perimeter in mask_to_polygons to obtain the
# tolerance passed to cv2.approxPolyDP.
EPS = 0.01     # a larger value results in fewer vertices per polygon

### 2. Define Utility Functions

In [4]:
def query_valid_annotation(input_dir_path, output_dir_path):
    """
    Collects the non-empty '.txt' label files found directly inside a directory.

    Empty '.txt' files need no conversion, so they are copied unchanged into
    output_dir_path instead of being returned.

    input_dir_path (str): directory holding the '.txt' label files
    output_dir_path (str): directory receiving the copies of empty label files;
                            it is created on demand if it does not exist yet

    Returns a list with the paths of all non-empty '.txt' files, sorted by file name.
    """
    non_empty_files = []
    # os.listdir returns entries in arbitrary order; sorting keeps runs reproducible
    for filename in sorted(os.listdir(input_dir_path)):
        file_path = os.path.join(input_dir_path, filename)
        if not (filename.endswith('.txt') and os.path.isfile(file_path)):
            continue
        if os.path.getsize(file_path) > 0:
            non_empty_files.append(file_path)
        else:
            # shutil.copy fails when the destination directory is missing
            os.makedirs(output_dir_path, exist_ok=True)
            shutil.copy(file_path, os.path.join(output_dir_path, filename))
    return non_empty_files

def query_bboxes(image_height, image_width, label_path):
    """
    Reads a YOLO bounding-box label file and returns the boxes as corner coordinates.

    Every non-blank line must hold exactly five numbers,
    'class x_center y_center width height'. The last four are multiplied by the
    image size given here. The class value is parsed but not returned.

    image_height (int): value the y coordinates and box heights are multiplied by
    image_width (int): value the x coordinates and box widths are multiplied by
    label_path (str): path of the label text file

    Returns a list of [top_left_x, top_left_y, bottom_right_x, bottom_right_y] lists,
    one per non-blank line, in file order.
    """
    corner_boxes = []
    with open(label_path, 'r') as label_file:
        for raw_line in label_file:
            fields = raw_line.split()
            if not fields:
                # blank or whitespace-only line
                continue
            _obj_class, cx, cy, box_w, box_h = (float(value) for value in fields)
            # Scale the normalized values to the image size
            cx = cx * image_width
            cy = cy * image_height
            half_w = (box_w * image_width) / 2
            half_h = (box_h * image_height) / 2
            corner_boxes.append([cx - half_w, cy - half_h, cx + half_w, cy + half_h])
    return corner_boxes

def prepare_image(image_path, transform, device):
    """
    Loads an image from disk and returns it as a channel-first tensor for SAM.

    image_path (str): path of the image file
    transform: object providing apply_image(image); presumably SAM's
                ResizeLongestSide -- confirm against the caller
    device: device the resulting tensor is created on

    Returns the transformed RGB image as a contiguous tensor with its axes
    permuted from (0, 1, 2) to (2, 0, 1).

    Raises OSError if OpenCV cannot read the image.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError(f"Could not read image: {image_path}")
    # OpenCV loads images as BGR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device)
    return image.permute(2, 0, 1).contiguous()

def mask_to_polygons(mask, image_h, image_w):
    """
    Converts a mask (numpy array) of multiple object instances to a list of polygons
        with normalized vertex values.

    mask (np.ndarray): mask image; it is cast to uint8 before processing
    image_h (int): value the y coordinates are divided by
    image_w (int): value the x coordinates are divided by

    Returns a list of polygons, one per external contour, each a list of [x, y]
    vertices. Contours whose approximation has fewer than three vertices are
    dropped because they do not describe a polygon.
    """
    # Cast first so that the morphological operations always receive uint8 data
    mask = mask.astype(np.uint8)
    # Perform erosion followed by dilation (opening)
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = cv2.dilate(mask, kernel, iterations=1)
    # Find contours; [-2] selects the contour list for both the 2-value and the
    # 3-value return signature of cv2.findContours
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    polygons = []
    for contour in contours:
        # Approximate contour to polygon; tolerance is a fraction of the perimeter
        epsilon = EPS * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        # One row per vertex, as float so the in-place division below is exact
        polygon = approx.reshape(-1, 2).astype(float)
        if len(polygon) < 3:
            # A point or a line segment is not a valid polygon annotation
            continue

        # Normalize the coordinate values with respect to image width and height
        polygon[:, 0] /= image_w
        polygon[:, 1] /= image_h

        polygons.append(polygon.tolist())
    return polygons

def save_polygons(class_idx, polygons, filename):
    """
    Saves polygons to a text file, one polygon per line.

    Each line has the form 'class_idx x1 y1 x2 y2 ... xM yM' with all values
    separated by single spaces. An existing file is overwritten.

    class_idx (str or int): class index written in front of every polygon
    polygons (list): shape is N x M x 2, N is number of polygons (objects), M is number of vertices,
                        and vertex coordinates in [x, y]
    filename (str): text file path to save
    """
    # str() lets callers pass the class index as an int as well as a str
    label = str(class_idx)
    lines = [
        label + ' ' + ' '.join(f'{coord[0]} {coord[1]}' for coord in polygon) + '\n'
        for polygon in polygons
    ]
    with open(filename, 'w') as file:
        file.writelines(lines)

def show_mask(mask, ax, random_color=False):
    """
    Draws a mask on an axis as a translucent colored overlay.

    mask (np.ndarray): mask whose last two dimensions are height and width
    ax: object providing imshow(image), e.g. a matplotlib axis
    random_color (bool): draw with a random RGB color instead of the fixed blue;
                          the alpha value is 0.6 in both cases
    """
    alpha = np.array([0.6])
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), alpha], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast the mask against the color to obtain an H x W x 4 image
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """
    Draws the outline of a box on a matplotlib axis.

    box: indexable whose first four entries are [x0, y0, x1, y1]
    ax: axis the green, unfilled rectangle is added to
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)

def visualize_mask_image(image, masks, boxes_tensor):
    """
    Shows the image with masks and boxes drawn on it next to the plain image.

    image: image passed to imshow
    masks: iterable of tensors, each drawn with show_mask in a random color
    boxes_tensor: iterable of tensors, each drawn with show_box
    """
    # One row, two columns: annotated view on the left, plain view on the right
    _, (annotated_ax, plain_ax) = plt.subplots(1, 2, figsize=(20, 10))

    # Left panel: image overlaid with every mask and every box, axis hidden
    annotated_ax.imshow(image)
    for single_mask in masks:
        show_mask(single_mask.cpu().numpy(), annotated_ax, random_color=True)
    for single_box in boxes_tensor:
        show_box(single_box.cpu().numpy(), annotated_ax)
    annotated_ax.axis('off')

    # Right panel: the image alone, axis kept visible
    plain_ax.imshow(image)
    plain_ax.axis('on')

    # Adjust the layout so the two panels do not overlap, then display
    plt.tight_layout()
    plt.show()

### 3. Load Segment Anything Model (SAM)

In [5]:
sam_checkpoint = "/home/psc/Desktop/segment-anything/assets/weights/sam_vit_h_4b8939.pth"  # NOTE: change this to custom path to SAM weight
model_type = "vit_h"
# Prefer the GPU but fall back to the CPU. The former assert on CUDA availability
# made this fallback unreachable.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("WARNING: CUDA is not available, running SAM on the CPU.")
# Build the model selected by model_type from the checkpoint and move it to the device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# The predictor is the object used for mask prediction in the processing loop
predictor = SamPredictor(sam)

### 4. Main Processing Loop

In [7]:
# Specify data root path
data_root = '/home/psc/Desktop/PSC/Custom_Dataset/custom-dataset-4120.v2i.yolov8/train'     # NOTE: change this to your dataset path
image_dir = os.path.join(data_root, 'images')
label_dir = os.path.join(data_root, 'labels')
# Specify polygon label save directory
polygon_label_dir = os.path.join(data_root, 'labels-polygon')
os.makedirs(polygon_label_dir, exist_ok=True)

# Image extensions tried, in this order, when looking up the image of a label file
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

# Empty label files are copied to polygon_label_dir; only non-empty ones are converted
non_empty_files = query_valid_annotation(label_dir, polygon_label_dir)

# Convert the image / label one at a time
for label_path in tqdm(non_empty_files, total=len(non_empty_files)):

    label_name = os.path.basename(label_path)
    polygon_label_path = os.path.join(polygon_label_dir, label_name)

    # Locate the image sharing the label's base name. splitext only touches the
    # final extension, unlike str.replace which rewrites every '.txt' occurrence.
    stem = os.path.splitext(label_name)[0]
    candidates = [os.path.join(image_dir, stem + ext) for ext in image_extensions]
    image_path = next((path for path in candidates if os.path.exists(path)), None)
    if image_path is None:
        # Explicit raise instead of assert: asserts are stripped under python -O
        raise FileNotFoundError(f"{candidates[0]} not found.")
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_h, image_w = image.shape[:2]
    boxes = query_bboxes(image_h, image_w, label_path)
    if not boxes:
        # The file is non-empty but holds only blank lines: there is nothing to
        # segment, so write an empty polygon label and move on.
        save_polygons('0', [], polygon_label_path)
        continue
    # Convert boxes to tensor and transform to current image
    boxes_tensor = torch.tensor(boxes, dtype=torch.float, device=device)
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, image.shape[:2])

    # Process the image to produce an image embedding for mask prediction
    predictor.set_image(image)

    # Make mask prediction, one mask per box
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Optionally visualize the masks with the original image
    # visualize_mask_image(image, masks, boxes_tensor)

    # Assume each mask only contains one instance.
    # Union of the per-box masks along dim 0; torch.any leaves `masks` untouched,
    # whereas accumulating into masks[0] modified it in place.
    mask = torch.any(masks, dim=0)

    mask = torch.squeeze(mask, 0).cpu().numpy().astype(np.uint8)

    # Convert mask image to polygons
    polygons = mask_to_polygons(mask, image_h, image_w)

    # NOTE: every polygon is saved with class '0'; the class column of the
    # original label file is not carried over.
    save_polygons('0', polygons, polygon_label_path)


100%|██████████| 645/645 [07:41<00:00,  1.40it/s]
