In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from albumentations import Compose, HorizontalFlip, Rotate, RandomBrightnessContrast
from albumentations.pytorch import ToTensorV2

# Path configuration: input split plus the folders that receive augmented output
base_path = "/content/drive/MyDrive/Colab Notebooks/2024_CV_project/dataset"
image_dir = os.path.join(base_path, "images/train")
label_dir = os.path.join(base_path, "labels/train")
augmented_image_dir = os.path.join(base_path, "images/augmentation/train")
augmented_label_dir = os.path.join(base_path, "labels/augmentation/train")

# Create both output folders up front; folders that already exist are accepted
for _output_dir in (augmented_image_dir, augmented_label_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Define Albumentations transforms (each one is applied with probability 0.5)
transform = Compose(
    [
        # Horizontal flip
        HorizontalFlip(p=0.5),
        # Rotate within ±10 degrees
        Rotate(limit=10, p=0.5),
        # Adjust brightness/contrast, both limits set to 0.2
        RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    ],
    # Keypoints are given as (x, y) pairs; remove_invisible=False asks the
    # pipeline not to drop keypoints (handles bbox corner keypoints)
    keypoint_params=dict(format="xy", remove_invisible=False),
)

# Read a YOLO label file
def read_yolo_label(label_path):
    """Parse a YOLO-style label file into per-object point arrays.

    Every non-blank line is expected to look like ``class x1 y1 x2 y2 ...``:
    an integer class id followed by an even number of coordinate values.

    Args:
        label_path: Path of the label text file to read.

    Returns:
        A list with one ``(class_id, points)`` tuple per non-blank line, in
        file order. ``class_id`` is an ``int`` and ``points`` is a float32
        array of shape ``(N, 2)`` holding the ``(x, y)`` pairs of that line.

    Raises:
        ValueError: If a class id or coordinate cannot be parsed as a number,
            or a line carries an odd number of coordinate values.
    """
    bboxes = []
    with open(label_path, "r") as file:
        for line in file:
            data = line.split()
            # Blank or whitespace-only lines (for example a trailing newline
            # at the end of the file) hold no annotation; skip them instead
            # of failing on data[0].
            if not data:
                continue
            class_id = int(data[0])
            # Convert flat list into (N, 2) array of (x, y) keypoints
            points = np.array(data[1:], dtype=np.float32).reshape(-1, 2)
            bboxes.append((class_id, points))
    return bboxes

# Write a YOLO label file
def write_yolo_label(label_path, bboxes):
    """Serialize ``(class_id, points)`` pairs to a label text file.

    The file at ``label_path`` is created or overwritten. Each pair becomes
    one line: the class id, a space, then every value of the flattened
    ``points`` array separated by single spaces.

    Args:
        label_path: Destination path of the label file.
        bboxes: Iterable of ``(class_id, points)`` tuples, where ``points`` is
            an array exposing ``flatten()``.
    """
    with open(label_path, "w") as out:
        for class_id, points in bboxes:
            coords = " ".join(str(value) for value in points.flatten())
            out.write(f"{class_id} {coords}\n")

# Map an image file name to the name of its label file
def _label_filename(image_filename):
    """Return ``<stem>.txt`` for an image file name, whatever its extension."""
    stem, _ = os.path.splitext(image_filename)
    return stem + ".txt"

# Apply augmentation and save augmented image/labels
def augment_and_save(image_path, label_path, output_image_dir, output_label_dir):
    """Augment one image together with its label points and save both.

    The image and all label points go through the module-level ``transform``
    pipeline in a single call, so both receive the same random flip/rotation.
    The augmented image keeps the original file name; the label is written as
    ``<stem>.txt`` next to it in ``output_label_dir``.

    NOTE(review): label coordinates are assumed to be normalized to [0, 1]
    relative to image width/height (the usual YOLO convention) — confirm for
    this dataset. They are scaled to pixels before the transform, because the
    pipeline's "xy" keypoint format works in pixel coordinates, and scaled
    back to normalized values afterwards.

    Args:
        image_path: Path of the source image.
        label_path: Path of the label file belonging to that image.
        output_image_dir: Existing directory that receives the augmented image.
        output_label_dir: Existing directory that receives the augmented label.

    Raises:
        OSError: If the image cannot be read or the augmented image cannot be
            written.
        RuntimeError: If the transform returns a different number of keypoints
            than it was given, which would make regrouping them impossible.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width = image.shape[:2]

    # Read original labels
    bboxes = read_yolo_label(label_path)
    class_ids = [class_id for class_id, _ in bboxes]
    point_counts = [len(points) for _, points in bboxes]

    # The pipeline expects one flat list of (x, y) points, not one array per
    # object, so flatten all objects and remember how many points each owns.
    flat_keypoints = [
        (float(x) * width, float(y) * height)
        for _, points in bboxes
        for x, y in points
    ]

    # Apply Albumentations transform
    transformed = transform(image=image, keypoints=flat_keypoints)
    aug_image = transformed["image"]
    aug_keypoints = transformed["keypoints"]
    if len(aug_keypoints) != len(flat_keypoints):
        raise RuntimeError(
            f"Transform changed the number of keypoints for {image_path}: "
            f"{len(flat_keypoints)} -> {len(aug_keypoints)}"
        )

    # Back to normalized coordinates, using the augmented image's own size
    aug_height, aug_width = aug_image.shape[:2]
    aug_points = np.array(
        [(kp[0] / aug_width, kp[1] / aug_height) for kp in aug_keypoints],
        dtype=np.float32,
    ).reshape(-1, 2)
    # Rotation can push points outside the frame; clamp them so the stored
    # coordinates stay inside the normalized [0, 1] range.
    aug_points = np.clip(aug_points, 0.0, 1.0)

    # Save augmented image
    filename = os.path.basename(image_path)
    output_image_path = os.path.join(output_image_dir, filename)
    if not cv2.imwrite(output_image_path, cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write image: {output_image_path}")

    # Save augmented labels, handing each object its own slice of points back
    aug_bboxes = []
    start = 0
    for class_id, count in zip(class_ids, point_counts):
        aug_bboxes.append((class_id, aug_points[start:start + count]))
        start += count
    write_yolo_label(
        os.path.join(output_label_dir, _label_filename(filename)),
        aug_bboxes
    )

# Process the entire dataset for augmentation
def process_dataset(image_dir, label_dir, output_image_dir, output_label_dir):
    """Augment every labelled ``.jpg``/``.png`` image found in ``image_dir``.

    Files with other extensions and images without a matching ``<stem>.txt``
    in ``label_dir`` are skipped. Extension matching ignores case.

    Args:
        image_dir: Directory containing the source images.
        label_dir: Directory containing the label files.
        output_image_dir: Existing directory for augmented images.
        output_label_dir: Existing directory for augmented labels.
    """
    for image_file in tqdm(os.listdir(image_dir), desc="Processing images"):
        if not image_file.lower().endswith((".jpg", ".png")):
            continue

        image_path = os.path.join(image_dir, image_file)
        # Derive the label name from the stem so .png images find their label
        label_path = os.path.join(label_dir, _label_filename(image_file))

        if not os.path.exists(label_path):
            continue

        augment_and_save(image_path, label_path, output_image_dir, output_label_dir)

# Run augmentation over the training split, announcing start and completion
print("Processing and augmenting dataset...")
process_dataset(
    image_dir=image_dir,
    label_dir=label_dir,
    output_image_dir=augmented_image_dir,
    output_label_dir=augmented_label_dir,
)
print("Data augmentation completed!")
