In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


class TempDataset(Dataset):
    """Pair ``*.jpg`` images with ``*.png`` masks, matched by sorted order.

    The i-th image (in sorted path order) under ``image_path`` is paired with
    the i-th mask (in sorted path order) under ``mask_path``. This presumes the
    two directories use matching file stems (e.g. ``0001.jpg`` / ``0001.png``)
    — TODO confirm for this dataset.

    Each item is ``(image, mask, image_file)``:
      * ``image`` — float tensor (3, H, W) in [0, 1] (RGB via ``ToTensor``)
      * ``mask``  — float tensor (1, H, W) in [0, 1] (grayscale via ``ToTensor``)
      * ``image_file`` — path of the image the pair was loaded from

    Raises:
        ValueError: if the number of images and masks differ.
    """

    def __init__(self, image_path, mask_path):
        self.image_files = self._sorted_files(image_path, "*.jpg")
        self.mask_files = self._sorted_files(mask_path, "*.png")
        # Index-based pairing is only meaningful when the counts agree; fail
        # early instead of mispairing or hitting IndexError mid-iteration.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"found {len(self.image_files)} images in {image_path!r} but "
                f"{len(self.mask_files)} masks in {mask_path!r}"
            )
        self.transform = transforms.ToTensor()

    @staticmethod
    def _sorted_files(directory, pattern):
        """Return paths in *directory* matching *pattern*, sorted.

        ``glob`` yields matches in arbitrary (filesystem-dependent) order, so
        sorting is required for deterministic image/mask pairing.
        """
        return sorted(glob.glob(os.path.join(directory, pattern)))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        # ``with`` closes the file handle; ``np.array`` (not ``np.asarray``)
        # makes a writable copy so ToTensor's torch.from_numpy does not warn
        # about a non-writable buffer.
        with Image.open(image_file) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(self.mask_files[idx]) as msk:
            mask = np.array(msk.convert("L"))
        return self.transform(image), self.transform(mask), image_file
    
# Load every image/mask pair as a single batch for a visual sanity check.
# NOTE(review): default collation stacks tensors, so this assumes all images
# (and all masks) share the same spatial size — confirm for this dataset.
dataset = TempDataset(image_path="../original", mask_path="../annotated")
if len(dataset) == 0:
    # DataLoader rejects batch_size=0 with an opaque error; fail clearly instead.
    raise FileNotFoundError("no *.jpg images found under '../original'")
loader = DataLoader(dataset, batch_size=len(dataset))

for image, mask, name in loader:
    # Masks must be strictly binary: ToTensor maps 0 -> 0.0 and 255 -> 1.0.
    unique = torch.unique(mask)  # sorted ascending by default
    if unique.tolist() != [0.0, 1.0]:
        raise ValueError(f"expected binary masks with values {{0, 1}}, got {unique.tolist()}")

    for img, msk, title in zip(image.numpy(), mask.numpy(), name):
        # Create the figure *before* drawing so figsize applies to this plot.
        plt.figure(figsize=(10, 10))
        plt.imshow(img.transpose(1, 2, 0))  # (3, H, W) -> (H, W, 3)
        plt.imshow(msk[0], alpha=0.5)       # (1, H, W) -> (H, W)
        plt.title(title)
        # Render and release each figure instead of accumulating open figures.
        plt.show()