In [29]:
import os
import random

import imageio
import torch
import torchvision
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms

## Dataset

In [36]:
class ITS520Dataset(Dataset):
    """
    Image-classification dataset built from a folder-per-class directory tree.

    The raw images are converted once into a single .pt file (convert=True) and
    afterwards loaded from that file (convert=False). Every instance then
    performs a train/test split and keeps only the requested half.

        dataset_save: path of the .pt file the converted data is written to / read from
        raw_data: path to the images folder (only used when convert=True).
            It should contain one sub-folder per class:
            - images
                - cats
                    - cat1.jpg
                - dogs
                    - some-dog.jpg
                - ...
            Hidden entries (e.g. .DS_Store) and stray top-level files are ignored.
        train: keep the train split (True) or the test split (False)
        shuffle: use a random split seed instead of the fixed default of 42
            - This should only be used in special cases
            - The train and test instances must use the same seed, otherwise
              their splits can overlap; prefer passing an explicit `seed`
        transform: transform applied to each image in __getitem__
        target_transform: transform applied to each target in __getitem__
        convert: build the dataset from raw_data and save it (True), or load dataset_save (False)
        size: images are resized and center-cropped to size x size. This should match
            the size used when the pt file was created
        seed: explicit random_state for the train/test split; overrides `shuffle` when given
        test_size: fraction of samples assigned to the test split
    """

    def __init__(self, dataset_save="data.pt", raw_data=None, train=True, shuffle=False,
                 transform=None, target_transform=None, convert=False, size=32,
                 seed=None, test_size=0.2):
        self.targets = []
        self.labels = []
        self.data = []

        self.transform = transform
        self.target_transform = target_transform

        if convert:
            self.convert(dataset_save, raw_data, size)
        else:
            self.load(dataset_save)

        if seed is None:
            # The fixed default makes separately built train/test instances split
            # the same data identically, so they stay complementary. A random
            # seed only does that if it is shared, hence the explicit `seed` option.
            seed = int(random.random() * 100) if shuffle else 42

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.data, self.targets, test_size=test_size, random_state=seed)

        if train:
            self.data, self.targets = self.X_train, self.y_train
        else:
            self.data, self.targets = self.X_test, self.y_test

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return (image, target) at `index`, with the optional transforms applied."""
        image, target = self.data[index], self.targets[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    @staticmethod
    def _list_samples(raw_data):
        """Return (image_path, class_name) pairs for raw_data in sorted, reproducible order.

        Each non-hidden sub-directory of raw_data is a class. Hidden entries
        (names starting with '.') and anything that is not a regular file inside
        a class folder are skipped.
        """
        samples = []
        for class_name in sorted(os.listdir(raw_data)):
            class_dir = os.path.join(raw_data, class_name)
            if class_name.startswith(".") or not os.path.isdir(class_dir):
                continue
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                if file_name.startswith(".") or not os.path.isfile(file_path):
                    continue
                samples.append((file_path, class_name))
        return samples

    def convert(self, dataset_save, raw_data, size):
        """Read every image under raw_data, preprocess it, and save the result.

        Images are loaded as RGB, resized with torchvision Resize(size),
        center-cropped to size x size, and scaled to [0, 1]. Class labels are
        assigned in sorted folder order, counting only folders that contain
        images. (data, targets, labels) is written to dataset_save with torch.save.

        Raises:
            ValueError: if raw_data is None or not a directory, or if it contains no images.
        """
        if raw_data is None or not os.path.isdir(raw_data):
            raise ValueError('Raw image directory does not exist.')

        samples = self._list_samples(raw_data)
        if not samples:
            raise ValueError('No images found in raw image directory.')

        # Build the transforms once, not once per image.
        resize = torchvision.transforms.Resize(size)
        crop_center = torchvision.transforms.CenterCrop(size)

        self.labels = []
        label_index = {}
        images = []
        targets = []
        for file_path, class_name in samples:
            if class_name not in label_index:
                label_index[class_name] = len(self.labels)
                self.labels.append(class_name)
            targets.append(label_index[class_name])

            img_arr = imageio.imread(file_path, pilmode="RGB")
            # (H, W, C) uint8 array -> (C, H, W) float tensor for torchvision.
            img = torch.from_numpy(img_arr).permute(2, 0, 1).float()
            img = crop_center(resize(img))
            images.append(img / 255)

        self.data = torch.stack(images)
        self.targets = torch.tensor(targets, dtype=torch.long)

        torch.save((self.data, self.targets, self.labels), dataset_save)

    def load(self, dataset_save):
        """Load the (data, targets, labels) tuple previously written by convert().

        Raises:
            ValueError: if dataset_save does not exist.
        """
        if not os.path.exists(dataset_save):
            raise ValueError('Dataset file does not exist. Try creating the dataset by running with convert=True first.')
        # torch.load unpickles the file, so only load dataset files you trust.
        self.data, self.targets, self.labels = torch.load(dataset_save)

In [37]:
# Sanity marker: reaching this line means the class-definition cell above ran cleanly.
loaded_message = "Dataset class loaded"
print(loaded_message)

Dataset class loaded
