[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/owenip/LNP-MOD/blob/dev/YoloV8Training.ipynb)

## Setup environment

### Make sure the notebook is using GPU via `nvidia-smi` command.
In case of any problems navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`.

*This tool can also work without GPU but the processing time will be significantly longer.*

In [None]:
# @title ##Run this cell to check if you have GPU access { display-mode: "form" }
# Reports whether TensorFlow can see a GPU in this runtime; if so, the `!nvidia-smi`
# shell command below prints the device summary.
import tensorflow as tf
from IPython import display


# Clear any output left over from a previous run of this cell.
display.clear_output()

# The first GPU is expected to be named '/device:GPU:0'; any other return value
# (presumably an empty string when no GPU is attached) is treated as "no GPU".
if tf.test.gpu_device_name() != '/device:GPU:0':
    print('\nYou do not have GPU access.')
    print('\nDid you change your runtime ?')
    print('\nIf the runtime setting is correct then Google did not allocate a GPU for your session')
    print('\nExpect slow performance. To access GPU try reconnecting later')

else:
    print('You have GPU access')
    !nvidia-smi

### **Mount your Google Drive**

In [None]:
# @title ##Play the cell to connect your Google Drive to Colab { display-mode: "form" }

# @markdown 1. Click 'Connect to Google Drive' at the pop up window

# @markdown 2. Sign in your Google Account.

# @markdown 3. Click 'Allow' to give this notebook access to the data on the drive

# @markdown * Once this is done, your data are available in the Files tab on the top left of notebook.

# Mount user's Google Drive to Google Colab; its contents appear under /content/drive
# (the dataset paths used in later cells start with this prefix).
from google.colab import drive
drive.mount('/content/drive')

# Force session restart: exit(0) terminates the kernel, so every Python variable defined
# so far is lost and later cells start from a fresh interpreter.
# NOTE(review): the reason for the restart is not stated here — confirm it is still needed.
exit(0)

### Install required library

In [None]:
# Install ultralytics library (provides the YOLO class used for training below)
# NOTE(review): `-U` without a version pin installs whatever release is current, so
# behaviour may differ between sessions — consider pinning a version.
from IPython import display
%pip install -U ultralytics 

# Hide the pip log, then run ultralytics' built-in setup checks.
display.clear_output()
import ultralytics
ultralytics.checks()

### Setup Training Dataset and Model

In [None]:
# @title ### Setup Dataset and Model { display-mode: "form" }
# @markdown ### Provide the path to the dataset folder

import os
import glob

def check_file_exist(file_path):
    """Raise if ``file_path`` (a file or a directory) does not exist.

    Args:
        file_path: path to check.

    Raises:
        FileNotFoundError: if nothing exists at ``file_path``. It is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError('File does not exist: ' + file_path)

def get_supported_images_path_list(images_folder):
    """Return the sorted paths of supported images in a dataset's split folders.

    Searches ``<images_folder>/training/images/`` and ``<images_folder>/validation/images/``.

    Args:
        images_folder: dataset root folder; a trailing path separator is optional.

    Returns:
        Sorted list of matching file paths (empty if nothing matches).
    """
    # Exclude '*.jpg', '*.jpeg' as this format are always used for the overview image
    supported_images = ('*.tif', '*.tiff', '*.png', '*.bmp',
                        '*.dng', '*.webp', '*.pfm', '*.mpo')
    # Escape the user-supplied folder so glob metacharacters in it ('[', '*', '?') are
    # matched literally; os.path.join makes the trailing separator optional.
    search_folders = [os.path.join(glob.escape(images_folder), split, 'images')
                      for split in ('training', 'validation')]

    images = []
    for image_type in supported_images:
        for search_folder in search_folders:
            images.extend(glob.glob(os.path.join(search_folder, image_type)))

    return sorted(images)

path_to_dataset_folder = "/content/drive/Shareddrives/Bleb_Counting/dataset/"  # @param {type:"string"}
path_to_dataset_config_file = "/content/drive/Shareddrives/Bleb_Counting/dataset/data.yaml" # @param {type:"string"}
epochs = 200  # @param {type:"integer"}
batch_size = 16  # @param {type:"integer"}
workers = 8  # @param {type:"integer"}
image_width = 4096  # @param {type:"integer"}
image_height = 4224  # @param {type:"integer"}

# @markdown ### Image Augmentation
translate = 0.1  # @param {type:"number"}
scale = 0.5  # @param {type:"number"}
shear = 0.1  # @param {type:"number"}
flipud = 0.0  # @param {type:"number"}
fliplr = 0.5  # @param {type:"number"}
mosaic = 1.0  # @param {type:"number"}

# @markdown ### Resume Training
resume_training = False # @param {type:"boolean"}
path_to_model = ""  # @param {type:"string"}


# Validate every path the training cell uses, so a typo fails here rather than mid-run.
check_file_exist(path_to_dataset_folder)
check_file_exist(path_to_dataset_config_file)
if resume_training:
    check_file_exist(path_to_model)

# Normalise to a trailing separator so later string concatenation on these folders
# yields well-formed paths whether or not the user typed one.
dataset_folder = os.path.join(path_to_dataset_folder, '')
predict_result_folder = os.path.join(dataset_folder, 'predict', '')

num_of_supported_images = len(get_supported_images_path_list(dataset_folder))
if num_of_supported_images == 0:
    raise Exception('No supported images found in the dataset folder')
print('Number of supported images found in the dataset folder: ' + str(num_of_supported_images))


# @markdown #*Play the cell to ensure the dataset folder contains supported images*

#### Tensorboard

In [None]:
# Launch TensorBoard on the folder that the training cell passes as `project`.
# NOTE(review): ultralytics only writes TensorBoard event files when its TensorBoard
# integration is enabled — confirm in this environment.
%load_ext tensorboard
%tensorboard --logdir {path_to_dataset_folder}

In [None]:
from ultralytics import YOLO

import torch

# Resume from the checkpoint given in the form, or start from COCO-pretrained YOLOv8x weights.
model = YOLO(path_to_model) if resume_training else YOLO('yolov8x.pt')

# Ultralytics only accepts a single integer imgsz for training (a [h, w] list is reduced
# to its maximum with a warning), so pass that value explicitly.
train_image_size = max(image_height, image_width)

# Train on the first GPU when available; the notebook also supports CPU-only runtimes.
train_device = 0 if torch.cuda.is_available() else 'cpu'

model.train(
    device=train_device,
    epochs=epochs,
    data=path_to_dataset_config_file,
    # Ultralytics saves training runs (weights, plots) under `project`.
    project=path_to_dataset_folder,
    imgsz=train_image_size,
    plots=True,
    batch=batch_size,
    save=True,
    resume=resume_training,
    cache='disk',
    workers=workers,
    translate=translate,
    scale=scale,
    shear=shear,
    flipud=flipud,
    fliplr=fliplr,
    mosaic=mosaic,
)


# Utilities

## Crop images with Albumentations

In [None]:
# Albumentations provides the Crop/Flip/Rotate transforms used by the COCO patching cells below.
%pip install -U albumentations

## Patchify Images and annotations

### Install required functions and libraries

In [None]:
# patchify splits full images into fixed-size tiles for the YOLO-label patching cells below.
%pip install patchify

In [51]:
import patchify
import numpy as np
import cv2
import matplotlib.pyplot as plt
from decimal import Decimal, getcontext
from PIL import Image
import os
import pandas as pd

# All Decimal bbox arithmetic below is rounded to 6 significant digits.
getcontext().prec = 6

# Integer class ids, listed in numeric order.
# NOTE(review): presumably these match the class order in the dataset's data.yaml — confirm.
BLEB_WITH_MRNA = 0
ICE_CRYSTAL = 1
OIL_CORE = 2
NOT_FULLY_VISIBLE_LNP = 3
OTHER_LNP = 4

def create_folder_if_not_exist(folder_path):
    """Create ``folder_path``, including any missing parent folders, unless it already exists.

    Raises:
        FileExistsError: if ``folder_path`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) also creates intermediate folders and avoids the
    # isdir()-then-mkdir() race.
    os.makedirs(folder_path, exist_ok=True)


def is_valid_annotation_after_splitting(standard_bbox, patch_width=2048, patch_height=2048):
    """Return True if a bbox in patch-local coordinates is well-formed and inside the patch.

    Args:
        standard_bbox: ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``.
        patch_width, patch_height: patch size (defaults keep the original 2048 x 2048).

    Returns:
        True when both sides are at least 1 px, no side exceeds the patch, and the
        corners lie within ``[0, patch_width]`` x ``[0, patch_height]``.
    """
    label, x_min, y_min, x_max, y_max, bbox_width, bbox_height = standard_bbox

    has_positive_size = bbox_width >= 1 and bbox_height >= 1
    fits_in_patch = bbox_width <= patch_width and bbox_height <= patch_height
    inside_x = x_min >= 0 and x_max <= patch_width
    inside_y = y_min >= 0 and y_max <= patch_height

    return bool(has_positive_size and fits_in_patch and inside_x and inside_y)


def is_valuable_annotation(standard_bbox, bbox_area_threshold=10, bbox_width_threshold=5, bbox_height_threshold=5):
    """Return True if a bbox is large enough to keep as a training label.

    Args:
        standard_bbox: ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``.
        bbox_area_threshold: minimum ``bbox_width * bbox_height``.
        bbox_width_threshold: minimum ``bbox_width``.
        bbox_height_threshold: minimum ``bbox_height``.

    Returns:
        False if the area or either side is below its threshold, otherwise True.
    """
    label, x_min, y_min, x_max, y_max, bbox_width, bbox_height = standard_bbox

    too_small = (
        (bbox_width * bbox_height) < bbox_area_threshold
        or bbox_width < bbox_width_threshold
        or bbox_height < bbox_height_threshold
    )
    return not too_small


def check_bbox(bbox, x_range=[0, 1], y_range=[0, 1]):
    """Validate the corner coordinates of a standard bbox.

    Args:
        bbox: ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``. The six
            numeric fields are converted to ``Decimal``; a wrong field count raises
            ValueError during unpacking.
        x_range, y_range: inclusive ``[low, high]`` bounds for the x and y corners.
            Values within ``np.isclose`` tolerance of a bound are also accepted.

    Raises:
        ValueError: if a corner lies outside its range, or the box is inverted
            (``x_max < x_min`` or ``y_max < y_min``). Zero-size boxes are allowed.
    """
    x_min, y_min, x_max, y_max, _bbox_width, _bbox_height = map(Decimal, bbox[1:])

    def _within(value, value_range):
        low, high = value_range[0], value_range[1]
        return (low <= value <= high
                or np.isclose(float(value), low)
                or np.isclose(float(value), high))

    for name, value in (('x_min', x_min), ('x_max', x_max)):
        if not _within(value, x_range):
            raise ValueError(
                f"Expected {name} for bbox {bbox} to be in {x_range}, got {value}")

    for name, value in (('y_min', y_min), ('y_max', y_max)):
        if not _within(value, y_range):
            raise ValueError(
                f"Expected {name} for bbox {bbox} to be in {y_range}, got {value}")

    # The comparisons are strict, so the messages say "less than", not "less than or equal".
    if x_max < x_min:
        raise ValueError(f"x_max is less than x_min for bbox {bbox}")
    if y_max < y_min:
        raise ValueError(f"y_max is less than y_min for bbox {bbox}")


def check_yolo_box(yolo_bbox):
    """Validate that a normalised bbox lies within [0, 1].

    Two layouts are accepted:

    * a YOLO box ``[label, x_center, y_center, width, height]``, as produced by
      ``convert_standard_bbox_to_yolo_format``; its four values are range-checked;
    * a normalised standard box ``[label, x_min, y_min, x_max, y_max, bbox_width,
      bbox_height]``; its corners are checked with ``check_bbox``, then its sizes.

    Values within ``np.isclose`` tolerance of 0 or 1 are accepted.

    Raises:
        ValueError: if a value is outside [0, 1], or ``check_bbox`` rejects the corners.
    """
    if len(yolo_bbox) == 5:
        # check_bbox unpacks six numeric fields, so it cannot validate this layout.
        named_values = zip(['x_center', 'y_center', 'bbox_width', 'bbox_height'], yolo_bbox[1:])
    else:
        check_bbox(yolo_bbox, x_range=[0, 1], y_range=[0, 1])
        named_values = zip(['bbox_width', 'bbox_height'], yolo_bbox[5:])

    for name, value in named_values:
        # float() normalises Decimal/str inputs before the numeric comparisons.
        number = float(value)
        if not 0 <= number <= 1 and not np.isclose(number, 0) and not np.isclose(number, 1):
            raise ValueError(
                f"Expected {name} for bbox {yolo_bbox} to be in [0.0, 1.0], got {value}")


def fix_bbox(bbox, x_range=[0, 1], y_range=[0, 1]):
    """Clamp a standard bbox's corner coordinates into ``x_range`` x ``y_range``.

    ``bbox`` is ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``; the six
    numeric fields are converted to ``Decimal``. Each corner that lies outside its range
    is reported with ``print`` before it is clamped.

    Note: ``bbox_width`` and ``bbox_height`` are returned as given (converted to
    ``Decimal``); they are not recomputed from the clamped corners.

    Returns:
        A new list ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``;
        ``bbox`` itself is not modified.
    """
    x_min, y_min, x_max, y_max, bbox_width, bbox_height = (Decimal(value) for value in bbox[1:])
    x_low, x_high = x_range[0], x_range[1]
    y_low, y_high = y_range[0], y_range[1]

    # (message, value, bound, is_lower_bound), reported in this order.
    for message, value, bound, is_lower_bound in (
        ('x_min < x_range[0]', x_min, x_low, True),
        ('x_max > x_range[1]', x_max, x_high, False),
        ('y_min < y_range[0]', y_min, y_low, True),
        ('y_max > y_range[1]', y_max, y_high, False),
    ):
        out_of_range = value < bound if is_lower_bound else value > bound
        if out_of_range:
            print(message, value, bound)

    return [
        bbox[0],
        x_min.max(x_low),
        y_min.max(y_low),
        x_max.min(x_high),
        y_max.min(y_high),
        bbox_width,
        bbox_height,
    ]


def fix_bboxes(bboxes, x_range=[0, 1], y_range=[0, 1]):
    """Clamp every non-empty bbox in ``bboxes`` with ``fix_bbox``.

    Empty entries (``[]``, meaning "no box in this region") are left untouched. The list
    is updated in place and also returned.
    """
    for position in range(len(bboxes)):
        current = bboxes[position]
        if len(current) > 0:
            bboxes[position] = fix_bbox(current, x_range, y_range)
    return bboxes


def update_edge_label(bbox, from_labels=[BLEB_WITH_MRNA, OTHER_LNP], to_label=NOT_FULLY_VISIBLE_LNP):
    """Replace ``bbox[0]`` with ``to_label`` when the current label is in ``from_labels``.

    The bbox list is modified in place and returned.
    """
    if bbox[0] not in from_labels:
        return bbox
    bbox[0] = to_label
    return bbox


def update_edge_labels(bboxes, from_labels=[BLEB_WITH_MRNA, OTHER_LNP], to_label=NOT_FULLY_VISIBLE_LNP):
    """Apply ``update_edge_label`` to every non-empty bbox in ``bboxes``.

    Empty entries are skipped. The list is updated in place and returned.
    """
    for position, current in enumerate(bboxes):
        if len(current) > 0:
            bboxes[position] = update_edge_label(current, from_labels, to_label)
    return bboxes

def update_on_edge_LNP_labels(bboxes, from_labels=[BLEB_WITH_MRNA, OTHER_LNP], to_label=NOT_FULLY_VISIBLE_LNP):
    """Relabel LNP boxes as not fully visible; thin alias for ``update_edge_labels``."""
    relabelled = update_edge_labels(bboxes, from_labels=from_labels, to_label=to_label)
    return relabelled

def check_splitted_annotation(regions):
    """Drop invalid or too-small parts of a bbox that was split across patches.

    Args:
        regions: 4-element list ``[top_left, top_right, bottom_left, bottom_right]``;
            each entry is ``[]`` or a standard bbox in patch-local coordinates.

    Returns:
        ``regions``, updated in place: entries that fail
        ``is_valid_annotation_after_splitting`` or ``is_valuable_annotation`` are
        replaced with ``[]`` (each removal is reported with ``print``).

    Raises:
        ValueError: propagated from ``check_bbox`` for a valid entry whose corners are
            outside the 2048 x 2048 patch.
    """
    for index, bbox in enumerate(regions):
        # Empty slots can appear anywhere (e.g. [[], bbox_1, [], bbox_2]); skip them
        # rather than stopping, otherwise later regions are never checked.
        if len(bbox) == 0:
            continue

        if not is_valid_annotation_after_splitting(bbox):
            print('Invalid annotation at region', index,  bbox)
            regions[index] = []
            # Already discarded; running check_bbox on it could raise needlessly.
            continue

        check_bbox(bbox, x_range=[0, 2048], y_range=[0, 2048])

        # Remove splitted annotation is not valuable
        if not is_valuable_annotation(bbox):
            print('Not valuable annotation at region', index,  bbox)
            regions[index] = []

    return regions


def convert_yolo_bbox_to_standard_bbox(yolo_bbox, img_width=4096, img_height=4224):
    """Convert a normalised YOLO box to pixel corner coordinates.

    Args:
        yolo_bbox: ``[label, x_center, y_center, width, height]`` with the four numeric
            values normalised to the image size (ints, floats, Decimals or numeric strings).
        img_width, img_height: image size in pixels.

    Returns:
        ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``. The numeric
        entries are ``Decimal``; ``label`` is passed through unchanged (still a string
        when the fields come from splitting a label-file line).
    """
    label, rel_cx, rel_cy, rel_w, rel_h = yolo_bbox

    centre_x = Decimal(rel_cx) * img_width
    centre_y = Decimal(rel_cy) * img_height
    width = Decimal(rel_w) * img_width
    height = Decimal(rel_h) * img_height
    half_width = width / 2
    half_height = height / 2

    return [label,
            centre_x - half_width, centre_y - half_height,
            centre_x + half_width, centre_y + half_height,
            width, height]


def convert_standard_bbox_to_yolo_format(standard_bbox, img_width=2048, img_height=2048):
    # Convert normal coordinates to yolo normalized coordinates
    label = standard_bbox[0]
    _standard_bbox = standard_bbox[1:]
    _standard_bbox = map(Decimal, _standard_bbox)
    x_min, y_min, x_max, y_max, bbox_width, bbox_height = _standard_bbox

    x_center = (x_min + x_max) / Decimal(2)
    y_center = (y_min + y_max) / Decimal(2)
    height = y_max - y_min
    width = x_max - x_min

    norm_x_center = x_center / img_width
    norm_y_center = y_center / img_height
    norm_width = width / img_width
    norm_height = height / img_height

    return [label, norm_x_center, norm_y_center, norm_width, norm_height]


def split_annotation_accross_x_axis_left_of_y_axis(standard_bbox, patch_width=2048, patch_height=2048):
    """Split a left-column box that straddles the horizontal boundary ``y == patch_height``.

    The upper part goes to region 0 (top-left) and the lower part to region 2
    (bottom-left), both in their patch's local coordinates. The parts are clamped with
    ``fix_bboxes`` and filtered by ``check_splitted_annotation``, whose result is returned.
    """
    label, x_min, y_min, x_max, y_max, bbox_width, _ = standard_bbox

    lower_y_max = y_max - patch_height
    top_left = [label, x_min, y_min, x_max, patch_height, bbox_width, patch_height - y_min]
    bottom_left = [label, x_min, 0, x_max, lower_y_max, bbox_width, lower_y_max]

    regions = fix_bboxes([top_left, [], bottom_left, []], [0, patch_width - 1], [0, patch_height - 1])
    return check_splitted_annotation(regions)


def split_annotation_accross_x_axis_right_of_y_axis(standard_bbox, patch_width=2048, patch_height=2048):
    """Split a right-column box that straddles the horizontal boundary ``y == patch_height``.

    The upper part goes to region 1 (top-right) and the lower part to region 3
    (bottom-right); x is shifted by ``patch_width`` into the right patches' local frame.
    The parts are clamped with ``fix_bboxes`` and filtered by ``check_splitted_annotation``.
    """
    label, x_min, y_min, x_max, y_max, bbox_width, _ = standard_bbox

    local_x_min = x_min - patch_width
    local_x_max = x_max - patch_width
    lower_y_max = y_max - patch_height
    top_right = [label, local_x_min, y_min, local_x_max, patch_height, bbox_width, patch_height - y_min]
    bottom_right = [label, local_x_min, 0, local_x_max, lower_y_max, bbox_width, lower_y_max]

    regions = fix_bboxes([[], top_right, [], bottom_right], [0, patch_width - 1], [0, patch_height - 1])
    return check_splitted_annotation(regions)


def split_annotation_accross_y_axis_above_x_axis(standard_bbox, patch_width=2048, patch_height=2048):
    """Split a top-row box that straddles the vertical boundary ``x == patch_width``.

    The left part goes to region 0 (top-left) and the right part to region 1
    (top-right), both in their patch's local coordinates. The parts are clamped with
    ``fix_bboxes`` and filtered by ``check_splitted_annotation``.
    """
    label, x_min, y_min, x_max, y_max, _, bbox_height = standard_bbox

    right_x_max = x_max - patch_width
    top_left = [label, x_min, y_min, patch_width, y_max, patch_width - x_min, bbox_height]
    top_right = [label, 0, y_min, right_x_max, y_max, right_x_max, bbox_height]

    regions = fix_bboxes([top_left, top_right, [], []], [0, patch_width - 1], [0, patch_height - 1])
    return check_splitted_annotation(regions)


def split_annotation_accross_y_axis_below_x_axis(standard_bbox, patch_width=2048, patch_height=2048):
    """Split a bottom-row box that straddles the vertical boundary ``x == patch_width``.

    The left part goes to region 2 (bottom-left) and the right part to region 3
    (bottom-right); y is shifted by ``patch_height`` into the bottom patches' local frame.
    The parts are clamped with ``fix_bboxes`` and filtered by ``check_splitted_annotation``.
    """
    label, x_min, y_min, x_max, y_max, _, bbox_height = standard_bbox

    local_y_min = y_min - patch_height
    local_y_max = y_max - patch_height
    right_x_max = x_max - patch_width
    bottom_left = [label, x_min, local_y_min, patch_width, local_y_max, patch_width - x_min, bbox_height]
    bottom_right = [label, 0, local_y_min, right_x_max, local_y_max, right_x_max, bbox_height]

    regions = fix_bboxes([[], [], bottom_left, bottom_right], [0, patch_width - 1], [0, patch_height - 1])
    return check_splitted_annotation(regions)


def split_annotation_accross_center(standard_bbox, patch_width=2048, patch_height=2048):
    """Split a box that covers the centre point ``(patch_width, patch_height)``.

    One part goes to each of the four regions (top-left, top-right, bottom-left,
    bottom-right), each in its patch's local coordinates. The parts are clamped with
    ``fix_bboxes`` and filtered by ``check_splitted_annotation``.
    """
    label, x_min, y_min, x_max, y_max, _, _ = standard_bbox

    left_width = patch_width - x_min
    right_width = x_max - patch_width
    upper_height = patch_height - y_min
    lower_height = y_max - patch_height

    top_left = [label, x_min, y_min, patch_width, patch_height, left_width, upper_height]
    top_right = [label, 0, y_min, right_width, patch_height, right_width, upper_height]
    bottom_left = [label, x_min, 0, patch_width, lower_height, left_width, lower_height]
    bottom_right = [label, 0, 0, right_width, lower_height, right_width, lower_height]

    regions = fix_bboxes([top_left, top_right, bottom_left, bottom_right],
                         [0, patch_width - 1], [0, patch_height - 1])
    return check_splitted_annotation(regions)


# Regions are divided into 4 parts: top left, top right, bottom left, bottom right

def patchify_bbox(standard_bbox, patch_width=2048, patch_height=2048, update_on_edge_labels=True):
    """Distribute a full-image bbox over the four quadrant patches.

    Args:
        standard_bbox: ``[label, x_min, y_min, x_max, y_max, bbox_width, bbox_height]``
            in full-image pixel coordinates.
        patch_width, patch_height: size of one patch. The quadrant boundaries are the
            lines ``x == patch_width`` and ``y == patch_height``.
        update_on_edge_labels: when True, boxes cut by a quadrant boundary are relabelled
            with ``update_on_edge_LNP_labels``. Boxes that fit entirely inside one patch
            keep their label.

    Returns:
        ``[top_left, top_right, bottom_left, bottom_right]``; each entry is ``[]`` or a
        bbox in that patch's local coordinates.

    NOTE(review): ``update_edge_label`` compares labels with integer class ids, but the
    patch-generation cell passes labels as strings split from label-file lines, so for
    that caller no relabelling happens — confirm whether that is intended.
    """
    class_name, x_min, y_min, x_max, y_max, bbox_width, bbox_height = standard_bbox
    # Only a box that crosses a patch boundary is partially visible in a patch.
    was_split = x_min < patch_width <= x_max or y_min < patch_height <= y_max

    if x_min < patch_width:
        if y_min < patch_height:
            if x_max < patch_width:
                if y_max < patch_height:
                    result = [[class_name, x_min, y_min, x_max, y_max, bbox_width, bbox_height], [], [], []]
                else:
                    result = split_annotation_accross_x_axis_left_of_y_axis(standard_bbox, patch_width, patch_height)
            else:
                if y_max < patch_height:
                    result = split_annotation_accross_y_axis_above_x_axis(standard_bbox, patch_width, patch_height)
                else:
                    result = split_annotation_accross_center(standard_bbox, patch_width, patch_height)
        else:
            if x_max < patch_width:
                result = [[], [], [class_name, x_min, y_min - patch_height, x_max, y_max - patch_height, bbox_width, bbox_height], []]
            else:
                result = split_annotation_accross_y_axis_below_x_axis(standard_bbox, patch_width, patch_height)
    else:
        if y_min < patch_height:
            if y_max < patch_height:
                result = [[], [class_name, x_min - patch_width, y_min, x_max - patch_width, y_max, bbox_width, bbox_height], [], []]
            else:
                result = split_annotation_accross_x_axis_right_of_y_axis(standard_bbox, patch_width, patch_height)
        else:
            result = [[], [], [], [class_name, x_min - patch_width, y_min - patch_height, x_max - patch_width, y_max - patch_height, bbox_width, bbox_height]]

    if update_on_edge_labels and was_split:
        result = update_on_edge_LNP_labels(result)

    return result

In [7]:
# @title ### Patchify Parameters { display-mode: "form" }
# @markdown ### Provide the path to the dataset folder
patch_width = 2048  # @param {type:"integer"}
patch_height = 2048  # @param {type:"integer"}
patch_step = 2048  # @param {type:"integer"}

dataset = ''  # @param {type:"string"}
dataset_folder = os.path.join(dataset, '')
images_folder = os.path.join(dataset_folder, 'images', '')
labels_folder = os.path.join(dataset_folder, 'labels', '')

# Fail fast on a wrong dataset path instead of erroring later in the patch loop.
for required_folder in (images_folder, labels_folder):
    if not os.path.isdir(required_folder):
        raise FileNotFoundError('Folder does not exist: ' + required_folder)

output_folder = '' # @param {type:"string"}
output_folder = os.path.join(output_folder, '')
output_images_folder = os.path.join(output_folder, 'images', '')
output_labels_folder = os.path.join(output_folder, 'labels', '')

create_folder_if_not_exist(output_folder)
create_folder_if_not_exist(output_images_folder)
create_folder_if_not_exist(output_labels_folder)

### Generate Patches

In [12]:
# @title ##Play the cell to generate patches { display-mode: "form" }

import glob
# Only process images in the images folder
getcontext().prec = 6
SUPPORTED_IMAGE_EXTENSIONS = ('.png', '.tif')
# Sorted for a deterministic order; the total counts exactly the files that are
# processed, so the progress counter is accurate.
images_list = sorted(
    name for name in os.listdir(images_folder) if name.endswith(SUPPORTED_IMAGE_EXTENSIONS))
images_total_count = len(images_list)

for image_number, image in enumerate(images_list, start=1):
    # splitext keeps dotted names intact ("sample.1.tif" -> "sample.1").
    filename = os.path.splitext(image)[0]
    label_file = os.path.join(labels_folder, filename + '.txt')
    print(f"Processing:{image_number}/{images_total_count}: {image}")

    # Check if label file exists
    if not os.path.isfile(label_file):
        print('Label file does not exist: ', label_file)
        continue

    # Patchify image (loaded as single-channel grayscale)
    img = cv2.imread(os.path.join(images_folder, image), cv2.IMREAD_GRAYSCALE)
    if img is None:
        print('Image could not be read: ', image)
        continue
    img_h, img_w = img.shape
    # patchify takes the patch shape as (rows, cols), i.e. (height, width).
    patches = patchify.patchify(img, (patch_height, patch_width), step=patch_step)
    if patches.shape[:2] != (2, 2):
        # patchify_bbox below only distributes labels over the four quadrant patches 0-3.
        print('Warning: expected a 2x2 patch grid, got', patches.shape[:2], 'for', image)
    counter = 0
    for i in range(patches.shape[0]):
        for j in range(patches.shape[1]):
            single_patch = patches[i, j, :, :]
            # cv2.imwrite keeps the raw pixel values; plt.imsave would rescale each patch
            # to its own min/max, changing intensities from patch to patch.
            cv2.imwrite(os.path.join(output_images_folder, f"{filename}_patch_{counter}.png"), single_patch)
            counter += 1

    # Read label file, skipping blank lines
    with open(label_file, 'r') as f:
        label_rows = [line.split() for line in f if line.strip()]

    # Split annotation: converted_bboxes[k] collects the boxes of patch k
    converted_bboxes = [[], [], [], []]
    for row in label_rows:
        standard_bbox = convert_yolo_bbox_to_standard_bbox(row, img_w, img_h)

        # Convert to patch coordinates
        patchify_bboxes = patchify_bbox(standard_bbox, patch_width, patch_height)
        for region_index, region_bbox in enumerate(patchify_bboxes):
            converted_bboxes[region_index].append(region_bbox)

    # Write label file in yolo format
    for region_index, region_bboxes in enumerate(converted_bboxes):
        patch_yolo_bboxes = [
            convert_standard_bbox_to_yolo_format(bbox, patch_width, patch_height)
            for bbox in region_bboxes if len(bbox) > 0
        ]
        patch_label_file = os.path.join(output_labels_folder, f"{filename}_patch_{region_index}.txt")
        pd.DataFrame(patch_yolo_bboxes).to_csv(patch_label_file, sep=' ', header=False, index=False)

Processing image:  Sample7_73kx_0001.tif


Generate Patches with COCO labels

In [56]:
# @title ### Pachify Paramters { display-mode: "form" }
# @markdown ### Provide the path to the dataset folder

import glob

def create_folder_if_not_exist(folder_path):
    """Create ``folder_path``, including any missing parent folders, unless it already exists.

    Raises:
        FileExistsError: if ``folder_path`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) also creates intermediate folders and avoids the
    # isdir()-then-mkdir() race.
    os.makedirs(folder_path, exist_ok=True)


patch_width = 2048  # @param {type:"integer"}
patch_height = 2048  # @param {type:"integer"}
patch_step = 2048  # @param {type:"integer"}
# NOTE(review): the albumentations crop transforms in the next cell hard-code their
# 2048 px quadrant bounds rather than reading these values.

images_folder = ''  # @param {type:"string"}
if not os.path.isdir(images_folder):
    print('Images folder does not exist: ', images_folder)

# glob.glob1 is an undocumented helper (deprecated in Python 3.13); count PNGs with the
# public API, escaping the folder so glob metacharacters in its name are literal.
num_of_supported_images = len(glob.glob(os.path.join(glob.escape(images_folder), '*.png')))
if num_of_supported_images == 0:
    print('No supported images found in the images folder')
else:
    print(num_of_supported_images, 'supported images found in the images folder')

# Despite its name, this field holds the path to the COCO JSON file itself.
coco_labels_folder = '' # @param {type:"string"}
coco_labels_file = coco_labels_folder
if not os.path.isfile(coco_labels_file):
    print('Label file does not exist: ', coco_labels_file)

output_images_folder = '' # @param {type:"string"}
create_folder_if_not_exist(output_images_folder)

output_labels_folder = '' # @param {type:"string"}
create_folder_if_not_exist(output_labels_folder)

15 supported images found in the images folder


In [59]:
import glob
import os
import cv2
import json
import albumentations as A

def _coco_compose(*transforms):
    """Return an ``A.Compose`` of ``transforms`` that also transforms COCO bboxes.

    Bboxes use the COCO ``[x_min, y_min, width, height]`` format; their class ids are
    passed alongside as the ``category_ids`` label field.
    """
    return A.Compose(
        list(transforms),
        bbox_params=A.BboxParams(format='coco', label_fields=['category_ids']),
    )


# Quadrant crops. A.Crop keeps pixels x_min <= x < x_max and y_min <= y < y_max, so the
# right/bottom quadrants must start at 2048 (not 2049) to give four 2048x2048 patches
# without dropping column/row 2048.
# NOTE(review): pixels beyond 4096 in either direction are not covered by any crop.
transform_1 = _coco_compose(A.Crop(x_min=0, y_min=0, x_max=2048, y_max=2048, p=1.0))
transform_2 = _coco_compose(A.Crop(x_min=2048, y_min=0, x_max=4096, y_max=2048, p=1.0))
transform_3 = _coco_compose(A.Crop(x_min=0, y_min=2048, x_max=2048, y_max=4096, p=1.0))
transform_4 = _coco_compose(A.Crop(x_min=2048, y_min=2048, x_max=4096, y_max=4096, p=1.0))

# Horizontal flip
transform_horizontal_flip = _coco_compose(A.HorizontalFlip(p=1))

# Vertical flip
transform_vertical_flip = _coco_compose(A.VerticalFlip(p=1))

# Fixed-angle rotations: limit=[a, a] always uses angle a; negative angles are used
# here for clockwise rotation.
# 90 degree rotation clockwise
transform_rotate90cw = _coco_compose(A.Rotate(p=1, limit=[-90, -90]))

# 180 degree rotation clockwise
transform_rotate180cw = _coco_compose(A.Rotate(p=1, limit=[-180, -180]))

# 270 degree rotation clockwise
transform_rotate270cw = _coco_compose(A.Rotate(p=1, limit=[-270, -270]))

def transform_image(image, bboxes, category_ids, transform_function):
    """Run an albumentations-style transform and return its result dict unchanged.

    ``transform_function`` is called with the keyword arguments ``image``, ``bboxes``
    and ``category_ids``.
    """
    return transform_function(image=image, bboxes=bboxes, category_ids=category_ids)

def save_transform_image(transformed, output_annotations_data, new_image_name, output_image_folder):
    """Write an augmented image to disk and append its COCO image/annotation records.

    Args:
        transformed: result dict of an albumentations ``Compose`` call; the ``'image'``,
            ``'bboxes'`` and ``'category_ids'`` keys are read.
        output_annotations_data: COCO-style dict with ``"images"`` and ``"annotations"``
            lists. It is mutated in place and also returned.
        new_image_name: file-name stem; the image is saved as ``<stem>.png``.
        output_image_folder: destination folder for the PNG.

    Returns:
        ``output_annotations_data`` (the same object).
    """
    transformed_image = transformed['image']
    full_image_name = f"{new_image_name}.png"

    # The source is loaded with cv2.imread (BGR) and never converted to RGB, so it is
    # written back unchanged; an RGB->BGR conversion here would swap the colour channels
    # (and fail for single-channel images).
    cv2.imwrite(os.path.join(output_image_folder, full_image_name), transformed_image)

    image_height, image_width = transformed_image.shape[:2]
    new_image_id = len(output_annotations_data["images"])
    output_annotations_data["images"].append({
        "id": new_image_id,
        "height": image_height,
        "width": image_width,
        "file_name": full_image_name,
    })

    # Append transformed annotations
    for bbox, category_id in zip(transformed['bboxes'], transformed['category_ids']):
        # COCO layout [x_min, y_min, width, height]; plain floats keep json.dump happy
        # when values come back as numpy scalars.
        coco_bbox = [float(value) for value in bbox[:4]]
        # Boxes touching the patch border become category 3 unless they are category 1.
        # NOTE(review): presumably 3 == "not fully visible LNP" and 1 == "ice crystal" as
        # in the YOLO patchify constants — confirm against the COCO categories.
        touches_edge = is_bbox_touch_edge(coco_bbox, image_width, image_height)
        output_annotations_data["annotations"].append({
            'id': len(output_annotations_data["annotations"]),
            'image_id': new_image_id,
            'category_id': 3 if touches_edge and category_id != 1 else category_id,
            'bbox': coco_bbox,
            'iscrowd': 0,
            'area': coco_bbox[2] * coco_bbox[3],
        })

    return output_annotations_data

def is_bbox_touch_edge(bbox, patch_width=2048, patch_height=2048):
    """Return True if a COCO ``[x_min, y_min, width, height]`` box is within 5 px of a patch border.

    A border counts as touched when a box side is at most 5 px from it (inclusive).
    """
    left, top, width, height = bbox
    margin_x = 5
    margin_y = 5

    near_vertical_border = left <= margin_x or left + width >= patch_width - margin_x
    near_horizontal_border = top <= margin_y or top + height >= patch_height - margin_y
    return bool(near_vertical_border or near_horizontal_border)

def _group_annotations_by_image(annotations):
    """Return ``(bboxes_by_image, category_ids_by_image)`` dicts keyed by COCO ``image_id``."""
    bboxes_by_image = {}
    category_ids_by_image = {}
    for ann in annotations:
        bboxes_by_image.setdefault(ann['image_id'], []).append(ann['bbox'])
        category_ids_by_image.setdefault(ann['image_id'], []).append(ann['category_id'])
    return bboxes_by_image, category_ids_by_image


def process_images(coco_labels_file, images_folder, output_image_folder, output_label_folder,
                   save_unaugmented_patches=False):
    """Crop every COCO-annotated image into quadrants, augment them and write a new COCO file.

    For each image in ``coco_labels_file``, each of the four quadrant crops
    (``transform_1`` .. ``transform_4``) is augmented with a horizontal flip, a vertical
    flip and 90/180/270 degree clockwise rotations. Every augmented patch is saved as
    ``<image>_patch_<k>_<augmentation>.png`` in ``output_image_folder``, and all boxes
    are collected into ``<output_label_folder>/coco.json``.

    Args:
        coco_labels_file: path to the input COCO JSON file.
        images_folder: folder holding the images named in the COCO file.
        output_image_folder: destination folder for the patch images.
        output_label_folder: destination folder for ``coco.json``.
        save_unaugmented_patches: also save each plain crop as ``<image>_patch_<k>.png``
            (off by default).
    """
    with open(coco_labels_file, 'r') as f:
        annotations_data = json.load(f)

    output_annotations_data = {
        "images": [],
        "annotations": [],
        "categories": annotations_data["categories"]
    }
    images_total_count = len(annotations_data['images'])
    bboxes_by_image, category_ids_by_image = _group_annotations_by_image(annotations_data['annotations'])

    # Patch index k corresponds to transform_{k+1}: top-left, top-right, bottom-left, bottom-right.
    crop_functions = [transform_1, transform_2, transform_3, transform_4]
    augmentation_functions = {
        'horizontal_flip': transform_horizontal_flip,
        'vertical_flip': transform_vertical_flip,
        'rotate90cw': transform_rotate90cw,
        'rotate180cw': transform_rotate180cw,
        'rotate270cw': transform_rotate270cw
    }

    # enumerate() gives a correct progress count even when COCO ids are not 0-based.
    for image_number, img_data in enumerate(annotations_data['images'], start=1):
        print(f"Process image {img_data['file_name']} ({image_number}/{images_total_count})")
        image = cv2.imread(os.path.join(images_folder, img_data['file_name']))
        if image is None:
            print('Image could not be read: ', img_data['file_name'])
            continue
        image_name = os.path.splitext(img_data['file_name'])[0]

        # Images without any annotation still produce (label-free) patches.
        bboxes = bboxes_by_image.get(img_data['id'], [])
        category_ids = category_ids_by_image.get(img_data['id'], [])

        for patch_index, crop_function in enumerate(crop_functions):
            crop_result = transform_image(image, bboxes, category_ids, crop_function)
            patch_file_name = f"{image_name}_patch_{patch_index}"
            if save_unaugmented_patches:
                output_annotations_data = save_transform_image(
                    crop_result, output_annotations_data, patch_file_name, output_image_folder)

            # Every augmentation starts from the same crop. Feeding in the previous
            # augmentation's output would compound them (e.g. "vertical_flip" would
            # really be horizontal + vertical flip).
            for augmentation_name, augmentation_function in augmentation_functions.items():
                augmented = transform_image(
                    crop_result["image"], crop_result["bboxes"], crop_result['category_ids'], augmentation_function)
                output_annotations_data = save_transform_image(
                    augmented, output_annotations_data, f"{patch_file_name}_{augmentation_name}", output_image_folder)

    print('Saving annotations file')
    output_annotations_file = os.path.join(output_label_folder, "coco.json")
    with open(output_annotations_file, 'w') as f:
        json.dump(output_annotations_data, f)
    print('Done')

In [None]:
# Crop each annotated image into quadrant patches, augment every patch, and write the
# combined COCO annotations (coco.json) into output_labels_folder.
process_images(coco_labels_file, images_folder, output_images_folder, output_labels_folder)


Convert COCO labels to YOLO

In [None]:
# @title ### Covert COCO labels to YOLO { display-mode: "form" }

import glob
import os
import json
import pandas as pd

def convert_coco_bbox_to_yolo_format(category_id, coco_bbox, img_width=2048, img_height=2048):
    """Convert a COCO ``[x_min, y_min, width, height]`` box to a normalised YOLO row.

    Returns:
        ``[category_id, x_center, y_center, width, height]``, with the four geometry
        values divided by the image size. ``category_id`` is passed through unchanged.
    """
    left, top, box_width, box_height = coco_bbox
    return [
        category_id,
        (left + (box_width / 2)) / img_width,
        (top + (box_height / 2)) / img_height,
        box_width / img_width,
        box_height / img_height,
    ]


coco_labels_file = ''  # @param {type:"string"}
# Stop with a clear message instead of failing later inside open().
if not os.path.isfile(coco_labels_file):
    raise FileNotFoundError(f'Label file does not exist: {coco_labels_file}')

output_yolo_labels_folder = '' # @param {type:"string"}
os.makedirs(output_yolo_labels_folder, exist_ok=True)

with open(coco_labels_file, 'r') as f:
    annotations_data = json.load(f)

images_total_count = len(annotations_data['images'])
# Group annotation boxes and category ids by the image they belong to.
bboxes_by_image = {}
category_ids_by_image = {}
for ann in annotations_data['annotations']:
    bboxes_by_image.setdefault(ann['image_id'], []).append(ann['bbox'])
    category_ids_by_image.setdefault(ann['image_id'], []).append(ann['category_id'])

# Write one YOLO label file per image; images without annotations get an empty file
# instead of raising KeyError.
# NOTE(review): COCO category ids are written unchanged as YOLO class indices — confirm
# they are 0-based and match data.yaml.
for image_number, img_data in enumerate(annotations_data['images'], start=1):
    print(f"Process image {img_data['file_name']} ({image_number}/{images_total_count})")
    image_name = os.path.splitext(img_data['file_name'])[0]
    label_file = os.path.join(output_yolo_labels_folder, image_name + '.txt')

    bboxes = bboxes_by_image.get(img_data['id'], [])
    category_ids = category_ids_by_image.get(img_data['id'], [])

    yolo_bboxes = [
        convert_coco_bbox_to_yolo_format(category_id, bbox, img_data['width'], img_data['height'])
        for bbox, category_id in zip(bboxes, category_ids)
    ]
    pd.DataFrame(yolo_bboxes).to_csv(label_file, sep=' ', header=False, index=False)