In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

In [5]:
class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Images and masks are matched purely by position after sorting each
    directory's file names. The two directories must therefore contain the
    same number of files, and their sorted orders must line up (e.g. identical
    names, or a consistent naming scheme for the masks).

    Args:
        img_dir: Directory with the input images; each is loaded as RGB.
        mask_dir: Directory with the masks; each is loaded as 8-bit
            grayscale (PIL mode "L").
        transform_img: Optional callable applied to the RGB PIL image.
        transform_mask: Optional callable applied to the grayscale PIL mask.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, img_dir, mask_dir, transform_img = None, transform_mask = None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform_img = transform_img
        self.transform_mask = transform_mask

        # Sort so pairing is deterministic (os.listdir order is arbitrary).
        self.img_filenames = self._list_files(img_dir)
        self.mask_filenames = self._list_files(mask_dir)

        # Positional pairing silently misaligns every sample after the first
        # missing/extra file, so refuse to build an inconsistent dataset.
        if len(self.img_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Found {len(self.img_filenames)} images in {img_dir!r} but "
                f"{len(self.mask_filenames)} masks in {mask_dir!r}; "
                "image and mask directories must pair up one-to-one."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular, non-hidden files in `directory`."""
        # Skip hidden entries and sub-directories (e.g. `.ipynb_checkpoints`,
        # `.DS_Store`) that would otherwise be paired up and passed to Image.open.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.img_filenames)

    def __getitem__(self, idx):
        """Load the `idx`-th (image, mask) pair and apply the optional transforms.

        Without transforms, both values are PIL images (RGB and "L" respectively).
        """
        img_path = os.path.join(self.img_dir, self.img_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # `convert` returns a new, fully loaded image, so each source file can
        # be closed right away by the context manager.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform_img is not None:
            image = self.transform_img(image)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)

        return image, mask

In [14]:
# Image pipeline: PIL -> float tensor in [0, 1] -> per-channel standardization
# with the usual ImageNet statistics (presumably matching an ImageNet-pretrained
# backbone — TODO confirm against the model used downstream).
transform_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def normalize(x):
    """Divide `x` by 255.0, mapping 8-bit intensities [0, 255] onto [0, 1].

    Intended for un-scaled 8-bit data. Note that torchvision's ToTensor()
    already rescales uint8 PIL images to [0, 1], so applying this after it
    scales twice.
    """
    max_intensity = 255.0
    return x / max_intensity

# Mask pipeline: PILToTensor keeps the raw uint8 values (no rescaling), and
# `normalize` then maps them to [0, 1] exactly once. Do not use ToTensor() here:
# it already rescales mode-"L" images to [0, 1], so combining it with
# `normalize` would divide by 255 twice and squash masks into [0, 1/255].
# `normalize` is a named function rather than a lambda — presumably so the
# transform stays picklable for DataLoader worker processes.
transform_mask = transforms.Compose([
    transforms.PILToTensor(),      # uint8 tensor (1, H, W), values unchanged
    transforms.Lambda(normalize),  # -> float32 in [0, 1]
])

In [15]:
from torch.utils.data import DataLoader

# Both splits share the same image/mask transforms; only the directories differ.
train_dataset = SegmentationDataset(
    img_dir="./archive/train/resized_train_img",
    mask_dir="./archive/train/resized_train_masks",
    transform_img=transform_img,
    transform_mask=transform_mask,
)

val_dataset = SegmentationDataset(
    img_dir="./archive/val/resized_val_img",
    mask_dir="./archive/val/resized_val_masks",
    transform_img=transform_img,
    transform_mask=transform_mask,
)

# Shuffle only the training split; validation keeps a fixed order.
# NOTE(review): with num_workers > 0 under the "spawn" start method (default on
# macOS/Windows), workers must unpickle SegmentationDataset and the transforms
# from this notebook's __main__, which can fail — drop to num_workers=0 if so.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_sample(image, mask, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot an image tensor and its mask side by side.

    Args:
        image: Channel-first tensor (C, H, W). It is assumed to have been
            normalized with ``mean``/``std`` (the defaults match the
            ImageNet statistics used by ``transform_img``); that normalization
            is undone before display. Pass ``mean=None`` or ``std=None`` for
            an image that was not normalized.
        mask: Tensor that squeezes to a 2-D (H, W) array. Assumed to hold
            values in [0, 1] — TODO adjust vmin/vmax for multi-class label masks.
        mean: Per-channel mean that was subtracted from ``image``.
        std: Per-channel standard deviation ``image`` was divided by.
    """
    _, (ax_img, ax_mask) = plt.subplots(1, 2, figsize = (8, 4))

    # detach/cpu so tensors that require grad or live on a GPU can be plotted.
    image = image.detach().cpu().permute(1, 2, 0).numpy()
    if mean is not None and std is not None:
        # Invert Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
        image = image * np.asarray(std) + np.asarray(mean)
    # imshow expects float RGB in [0, 1]; clip explicitly rather than relying
    # on matplotlib's clip-with-warning behaviour.
    image = np.clip(image, 0.0, 1.0)

    mask = mask.detach().cpu().squeeze().numpy()

    ax_img.imshow(image)
    ax_img.set_title("Image")
    ax_img.axis("off")

    # Fixed display range: with auto-scaling, a mask scaled to the wrong range
    # (e.g. [0, 1/255]) would still look correct and hide the error.
    ax_mask.imshow(mask, cmap = "gray", vmin = 0.0, vmax = 1.0)
    ax_mask.set_title("Mask")
    ax_mask.axis("off")

    plt.show()

# Pull one (shuffled) training batch and display its first image/mask pair.
data_iter = iter(train_loader)
images, masks = next(data_iter)
sample_image, sample_mask = images[0], masks[0]
visualize_sample(sample_image, sample_mask)