In [3]:
import os
import random

import torch
from PIL import Image, UnidentifiedImageError
from torchvision import transforms

def get_augmentation_transforms(size=(1024, 1024), crop_scale=(0.3, 0.6), degrees=90):
    """Build the random augmentation pipeline applied to each PIL image.

    Pipeline, in order:
      1. Resize to exactly ``size``. NOTE: this stretches non-square inputs.
      2. Random crop covering a ``crop_scale`` fraction of the area, resized back to ``size``.
      3. Random rotation by an angle drawn from ``[-degrees, degrees]``.
      4. Independent random horizontal and vertical flips.
      5. Conversion to a tensor.

    Args:
        size: Output (height, width) used by both the resize and the crop.
        crop_scale: (min, max) fraction of the resized image's area kept by the crop.
        degrees: Maximum absolute rotation angle, in degrees.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a tensor.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.RandomResizedCrop(size, scale=crop_scale),
        # Uniform angle in [-degrees, degrees]. With torchvision's defaults, areas
        # rotated in from outside the image are filled with 0 (black).
        transforms.RandomRotation(degrees),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

def augment_and_save_images(root_dir, save_dir, augmentations_per_image=5,
                            class_names=('Apoptosis', 'Necroptosis', 'Necrosis')):
    """Write each class's images to ``save_dir`` together with random augmentations.

    Each image file in ``root_dir/<class>`` (for every class in ``class_names``)
    produces the following in ``save_dir/<class>``:

    * the RGB-converted original, saved under its original file name;
    * ``augmentations_per_image`` augmented variants named ``<stem>_aug_<i>.png``,
      produced by ``get_augmentation_transforms()``.

    Sub-directories are skipped. Files Pillow cannot identify as images are also
    skipped, with a printed message, instead of aborting the whole run.

    Args:
        root_dir: Directory holding one sub-directory per class.
        save_dir: Output directory; per-class sub-directories are created if missing.
        augmentations_per_image: Number of augmented variants written per image.
        class_names: Class sub-directories of ``root_dir`` to process.

    Returns:
        None; all output is written to disk.
    """
    transform = get_augmentation_transforms()
    # Same converter for every image, so build it once rather than per augmentation.
    to_pil = transforms.ToPILImage()

    for class_name in class_names:
        class_dir = os.path.join(root_dir, class_name)
        save_class_dir = os.path.join(save_dir, class_name)
        os.makedirs(save_class_dir, exist_ok=True)

        # Sorted so the processing order (and hence which image receives which
        # random draw) does not depend on the filesystem's directory ordering.
        for img_name in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, img_name)
            if not os.path.isfile(img_path):
                continue  # skip sub-directories

            try:
                # The context manager releases the file handle; convert() returns a
                # fully loaded, independent image that stays usable after close.
                with Image.open(img_path) as src:
                    image = src.convert("RGB")
            except UnidentifiedImageError:
                print(f"Skipping non-image file: {img_path}")
                continue

            # Save the original image as well.
            # NOTE(review): this re-encodes the RGB-converted image. For JPEG targets,
            # Pillow's default quality is lossy; consider shutil.copy2 if byte-exact
            # originals are required.
            image.save(os.path.join(save_class_dir, img_name))

            stem = os.path.splitext(img_name)[0]
            for i in range(augmentations_per_image):
                augmented_image = to_pil(transform(image))
                augmented_image.save(os.path.join(save_class_dir, f"{stem}_aug_{i}.png"))


In [5]:
# Seed every RNG the augmentations may draw from so a re-run reproduces the same
# images. Current torchvision uses torch's RNG; older releases used `random`.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Input: one sub-directory per class. Output goes to an `augmented` folder derived
# from root_dir, so the two paths cannot drift apart. Only the named class folders
# are read as input, so this folder is not re-augmented on later runs.
root_dir = r'C:\rkka_Projects\cell_death_v1\Data\pathway\collected\test'
save_dir = os.path.join(root_dir, 'augmented')
augmentations_per_image = 10

augment_and_save_images(root_dir, save_dir, augmentations_per_image)
print(f"Augmented images saved in: {save_dir}")

Augmented images saved in: C:\rkka_Projects\cell_death_v1\Data\pathway\collected/test/augmented
