In [14]:
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, ConcatDataset

# Root folder of the training images; ImageFolder expects one sub-folder per class
data_dir = "dataset_strw_vs_pear/train/"

# Augmentation pipeline, listed in the order the steps are applied.
# (A ColorJitter step with brightness/contrast/saturation/hue = 0.2 is currently disabled.)
augmentation_steps = [
    transforms.Resize(200),
    transforms.CenterCrop(200),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomRotation(45),
    transforms.ToTensor(),
]
transform = transforms.Compose(augmentation_steps)

# Every image read through `dataset` is passed through `transform`
dataset = ImageFolder(data_dir, transform=transform)

# Shuffled batches of 32 augmented images
loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)




In [15]:
# Show the class names ImageFolder found; a label is used as an index into this list
# (see `dataset.classes[label]` in the saving cell).
dataset.classes

['pear', 'strawberry']

In [13]:
# Write one randomly augmented copy of every original image into its class folder.
#
# Iterating over `dataset.samples` (instead of shuffled batches from `loader`) lets each
# output be named after its source file, so re-running this cell overwrites the previous
# copy rather than piling up new files, and an original can never be overwritten by a
# batch-position name such as "0_0.png".
AUG_PREFIX = "aug_"
to_pil = transforms.ToPILImage()  # build the converter once, not once per image

for index, (src_path, label) in enumerate(dataset.samples):
    src_name = os.path.basename(src_path)
    if src_name.startswith(AUG_PREFIX):
        # Produced by an earlier run: skip it so augmentations are never stacked
        # on top of augmentations and the dataset does not grow on every re-run.
        continue
    image, _ = dataset[index]  # indexing applies `transform`, i.e. the random augmentation
    class_dir = os.path.join(data_dir, dataset.classes[label])
    # Keep the full source name (extension included) so "a.jpg" and "a.png" cannot collide.
    path = os.path.join(class_dir, f"{AUG_PREFIX}{src_name}.png")
    to_pil(image).save(path)



In [1]:
# Build the merged (original + augmented) training set.
#
# The augmented PNGs are written into the same per-class folders as the originals, so the
# data is already merged on disk: one ImageFolder rooted at `data_dir` sees both.
# Pointing ImageFolder at a single class folder does not work (it needs class sub-folders
# under its root), and re-saving images here would only overwrite training files.
#
# The transform is deterministic (no random flips/rotations): it only brings every image
# to the same 200x200 tensor shape so the DataLoader can stack them into batches.
merged_transform = transforms.Compose([
    transforms.Resize(200),
    transforms.CenterCrop(200),
    transforms.ToTensor(),
])
merged_dataset = ImageFolder(root=data_dir, transform=merged_transform)
merged_loader = DataLoader(merged_dataset, batch_size=32, shuffle=True)

# Sanity check: number of images per class in the merged set
{
    class_name: merged_dataset.targets.count(class_index)
    for class_index, class_name in enumerate(merged_dataset.classes)
}