In [8]:
import cv2 as cv
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms

In [9]:
class ImageDataset(Dataset):
    """Map-style dataset of cat and dog images stored in per-class sub-folders.

    ``folder`` must contain a ``cats`` and a ``dogs`` sub-directory. Image
    files found under ``cats`` get label 0, those under ``dogs`` get label 1.

    Args:
        folder: Root directory holding the ``cats`` and ``dogs`` sub-directories.
        transform: Optional callable applied to each loaded PIL image before it
            is returned.

    Raises:
        FileNotFoundError: If a class sub-directory is missing (propagated
            from ``os.listdir``).
    """

    # Lower-case file extensions that are treated as images while scanning.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')

    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        self.labels = []
        self.image_paths = []
        self.label_map = {'cats':0, 'dogs':1}

        # Drive the scan from label_map so the class list lives in one place.
        for label, class_index in self.label_map.items():
            label_path = os.path.join(folder, label)
            # os.listdir returns entries in arbitrary order; sorting keeps
            # index -> sample stable across runs and machines.
            for name in sorted(os.listdir(label_path)):
                path = os.path.join(label_path, name)
                # Skip sub-directories and non-image files (e.g. Thumbs.db),
                # which would otherwise make Image.open fail in __getitem__.
                if not os.path.isfile(path):
                    continue
                if not name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                self.image_paths.append(path)
                self.labels.append(class_index)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label)``.

        The image is converted to RGB so grayscale or RGBA files yield the same
        number of channels as the rest and can be collated into one batch.
        """
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy that stays valid afterwards.
        with Image.open(self.image_paths[index]) as img:
            image = img.convert('RGB')
        label = self.labels[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [10]:
# Root folder of the dataset. Set the CUTIES_DATA_DIR environment variable to
# point elsewhere instead of editing the notebook; the default is the original
# local location.
DATA_DIR = os.environ.get("CUTIES_DATA_DIR", "C:/large_files/cuties")
train = os.path.join(DATA_DIR, "train")
test = os.path.join(DATA_DIR, "test")

# Size every image is resized to before being turned into a tensor.
img_size = (64, 64)

# Same preprocessing for both splits: resize, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
])

train_dataset = ImageDataset(train, transform)
test_dataset = ImageDataset(test, transform)

In [11]:
# Sanity-check one sample. Subscripting is the idiomatic way to invoke
# __getitem__ and behaves identically to calling the dunder directly.
image, label = train_dataset[600]

image.shape, label

(torch.Size([3, 64, 64]), 0)

In [12]:
# Number of samples per batch, shared by both loaders.
BATCH_SIZE = 32

# Shuffle only the training split; the test split keeps a fixed order so
# evaluation runs are reproducible and predictions line up with file order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [15]:
# Pull a single batch to confirm the shapes produced by the loader. When the
# loader yields nothing, nothing is printed and image/label are left untouched.
_first_batch = next(iter(train_dataloader), None)
if _first_batch is not None:
    image, label = _first_batch
    print(image.shape)
    print(label.shape)

torch.Size([32, 3, 64, 64])
torch.Size([32])
