In [19]:
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset

import os
import pandas as pd
from skimage import io

# Load data

In [23]:
class CatsAndDogsDataset(Dataset):
    """Map-style dataset backed by a CSV annotation file and an image folder.

    Column 0 of the CSV is used as the image file name (joined onto
    ``root_dir``) and column 1 as the label, which is converted with ``int()``.

    Args:
        csv_file: Path of the annotation CSV, read with ``pd.read_csv``.
        root_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to the loaded image before it is
            returned. ``None`` (the default) returns the image as loaded.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Row position in the annotation table. May be a Python int
                or a single-element tensor.

        Returns:
            A tuple ``(image, y_label)`` where ``image`` is the result of
            ``io.imread`` (after ``transform``, if one was given) and
            ``y_label`` is a 0-dim integer tensor.

        Raises:
            IndexError: Propagated from ``DataFrame.iloc`` when ``index`` is
                out of range. Plain ``for sample in dataset`` iteration relies
                on this to stop.
        """
        # Samplers and hand-written loops sometimes hand over tensor indices;
        # pandas positional indexing needs a plain Python scalar.
        if torch.is_tensor(index):
            index = index.item()

        row = self.annotations.iloc[index]
        img_path = os.path.join(self.root_dir, row.iloc[0])
        image = io.imread(img_path)
        y_label = torch.tensor(int(row.iloc[1]))

        # Explicit None check so a falsy-but-valid callable is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return (image, y_label)

# Augmentation

In [25]:
# Augmentation pipeline applied to every image returned by the dataset.
# The dataset hands over the raw result of io.imread, so the image is first
# converted to PIL, then passed through the random augmentations, and finally
# turned back into a tensor.
my_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    # Resize slightly larger than the crop so the random crop position varies.
    transforms.RandomCrop((224, 224)),
    transforms.ColorJitter(brightness=0.5),
    transforms.RandomRotation(degrees=45),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.05),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # mean 0 / std 1 leaves the values unchanged: (x - 0) / 1 == x.
    # Placeholder -- TODO replace with real per-channel dataset statistics
    # if the tensors are fed to a model.
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
])

dataset = CatsAndDogsDataset(
    csv_file='cats_dogs.csv',
    root_dir='cats_dogs_resized',
    transform=my_transforms
)

In [30]:
# Write augmented copies of the whole dataset to disk as PNG files.
OUTPUT_DIR = 'dataset'
NUM_AUGMENTATION_PASSES = 5

# save_image does not create missing directories; make sure the target exists
# so the cell works on a fresh checkout.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Running counter used in the file names; after the loop it holds the total
# number of images written.
img_num = 0

# Each pass re-reads every sample, and the dataset applies its transform
# on every access, so each pass produces a fresh set of augmented images.
for _ in range(NUM_AUGMENTATION_PASSES):
    for img, label in dataset:
        img_path = os.path.join(OUTPUT_DIR, f'img{img_num}.png')
        save_image(img, img_path)
        img_num += 1