In [None]:
import albumentations as A
import cv2
import os
import numpy as np

In [None]:
# Input/output directories (adjust paths).
# Source images/masks live under SOURCE_ROOT; augmented copies are written under
# AUGMENTED_ROOT, mirroring the same three sub-directories.
SOURCE_ROOT = '../data/coral_bleaching/reef_support/UNAL_BLEACHING_TAYRONA'
AUGMENTED_ROOT = '../data/coral_bleaching/augmented'

input_img_dir = f"{SOURCE_ROOT}/images"
input_bleached_mask_dir = f"{SOURCE_ROOT}/masks_bleached"
input_nonbleached_mask_dir = f"{SOURCE_ROOT}/masks_non_bleached"

output_img_dir = f"{AUGMENTED_ROOT}/images"
output_bleached_mask_dir = f"{AUGMENTED_ROOT}/masks_bleached"
output_nonbleached_mask_dir = f"{AUGMENTED_ROOT}/masks_non_bleached"

# Create every output directory up front; exist_ok makes re-running the cell safe.
for _out_dir in (output_img_dir, output_bleached_mask_dir, output_nonbleached_mask_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Number of augmented copies generated per source image.
num_augmentations = 3

In [None]:
# Joint augmentation pipeline for an image and its two binary masks.
# The two mask names are registered as extra 'mask' targets, so the geometric
# transforms (flip / rotate / scale) move them in lock-step with the image.
# The photometric transforms (brightness/contrast, HSV) are image-only in albumentations.
#
# Interpolation: `interpolation` controls how the *image* is resampled. Mask targets
# are resampled with nearest-neighbour by albumentations itself, so label values are
# never blended. Using INTER_NEAREST here would only make the training images
# blocky/aliased, so bilinear is used instead.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=20, p=0.5, interpolation=cv2.INTER_LINEAR),
    A.RandomScale(scale_limit=0.2, p=0.5, interpolation=cv2.INTER_LINEAR),
    A.RandomBrightnessContrast(p=0.3),
    A.HueSaturationValue(p=0.3),
], additional_targets={'bleached_mask': 'mask', 'nonbleached_mask': 'mask'})

In [None]:
def _binarize(mask):
    """Snap a single-channel mask to uint8 values {0, 255}.

    Pixels > 127 become 255, all others 0, so any grey values in the stored
    masks (or introduced by resampling) are removed before saving.
    """
    return np.where(mask > 127, 255, 0).astype(np.uint8)


def _imwrite_checked(path, array):
    """Write `array` to `path` with cv2.imwrite, raising OSError if OpenCV reports failure.

    cv2.imwrite signals errors by returning False rather than raising; without this
    check an unwritable path or bad extension would leave outputs silently missing.
    """
    if not cv2.imwrite(path, array):
        raise OSError(f"cv2.imwrite failed for {path}")


# For every source JPEG with both masks present, write `num_augmentations` randomly
# augmented copies of the image and its bleached / non-bleached masks.
skipped_images = []
# sorted() makes the processing order (and thus output naming/logging) reproducible.
for img_name in sorted(os.listdir(input_img_dir)):
    base_name, img_ext = os.path.splitext(img_name)
    # Case-insensitive match so '.jpeg' / '.Jpg' files are not silently ignored.
    if img_ext.lower() not in ('.jpg', '.jpeg'):
        continue

    bleached_mask_path = os.path.join(input_bleached_mask_dir, f"{base_name}_bleached.png")
    nonbleached_mask_path = os.path.join(input_nonbleached_mask_dir, f"{base_name}_non_bleached.png")
    if not (os.path.exists(bleached_mask_path) and os.path.exists(nonbleached_mask_path)):
        skipped_images.append(img_name)
        continue

    img = cv2.imread(os.path.join(input_img_dir, img_name))
    bleached_mask = cv2.imread(bleached_mask_path, cv2.IMREAD_GRAYSCALE)
    nonbleached_mask = cv2.imread(nonbleached_mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (it does not raise) for unreadable or corrupt files.
    if img is None or bleached_mask is None or nonbleached_mask is None:
        skipped_images.append(img_name)
        continue

    # OpenCV loads BGR; albumentations colour transforms (e.g. HueSaturationValue) expect RGB.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    for aug_idx in range(num_augmentations):
        augmented = transform(image=img_rgb, bleached_mask=bleached_mask, nonbleached_mask=nonbleached_mask)

        # Convert back to BGR because cv2.imwrite expects BGR channel order.
        aug_img = cv2.cvtColor(augmented['image'], cv2.COLOR_RGB2BGR)
        aug_bleached_mask = _binarize(augmented['bleached_mask'])
        aug_nonbleached_mask = _binarize(augmented['nonbleached_mask'])

        aug_base_name = f"aug_{aug_idx}_{base_name}"
        _imwrite_checked(os.path.join(output_img_dir, f"{aug_base_name}{img_ext}"), aug_img)
        _imwrite_checked(os.path.join(output_bleached_mask_dir, f"{aug_base_name}_bleached.png"), aug_bleached_mask)
        # Same '_non_bleached' suffix as the input masks so one loader can read both datasets.
        _imwrite_checked(os.path.join(output_nonbleached_mask_dir, f"{aug_base_name}_non_bleached.png"), aug_nonbleached_mask)

print(f"Skipped {len(skipped_images)} image(s) with a missing/unreadable image or mask: {skipped_images}")