In [None]:
import os
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image

# Define paths to your original and augmented datasets
original_data_path = r"path to dataset folder"
augmented_data_path = r"path to augmented dataset folder"

# Augmentations that operate on PIL images. These are the ones baked into the
# saved files: the result stays a PIL image, so it can be written to disk
# directly without a lossy tensor -> image round trip.
pil_augmentation_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(3 / 4, 4 / 3)),
    transforms.RandomHorizontalFlip(),
    # Convert to grayscale, replicated into 3 identical channels so models
    # that expect RGB input still accept the image.
    transforms.Grayscale(num_output_channels=3),
])

# Full training-time pipeline: the same augmentations followed by tensor
# conversion and ImageNet mean/std normalization. It is deliberately NOT used
# for saving: a normalized tensor holds values outside [0, 1], which
# ToPILImage cannot map to valid 8-bit pixels.
augmentation_transform = transforms.Compose([
    pil_augmentation_transform,
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load original dataset (used only to enumerate image paths; no transform needed)
original_dataset = ImageFolder(root=original_data_path, transform=None)

# Create the output root (no-op if it already exists)
os.makedirs(augmented_data_path, exist_ok=True)

# Write one augmented copy of every image, mirroring the source layout
# (root/<class>/<file>) so the output is itself loadable with ImageFolder and
# same-named files from different classes do not overwrite each other.
for image_path, _ in original_dataset.samples:
    # The context manager closes the file handle; convert() returns an
    # independent in-memory image that outlives the `with` block.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    augmented_image = pil_augmentation_transform(image)

    # Save augmented image under the same relative path as the original.
    # original_dataset.root is used (not original_data_path) because the
    # sample paths are built from it.
    relative_path = os.path.relpath(image_path, original_dataset.root)
    augmented_image_path = os.path.join(augmented_data_path, relative_path)
    os.makedirs(os.path.dirname(augmented_image_path), exist_ok=True)
    augmented_image.save(augmented_image_path)


print("Augmentation complete.")