In [22]:
import os
import albumentations as A
import numpy as np
from typing import Tuple
from pathlib import Path
from albumentations.pytorch import ToTensorV2
from PIL import Image as Img
from torch.utils.data import Dataset


# Jupyter notebooks have no ``__file__``, so data paths are anchored at the
# kernel's current working directory (typically the notebook's own folder).
CURR_DIR = Path.cwd().resolve()
# The dataset lives in a sibling directory of CURR_DIR.
DATASET_DIR = CURR_DIR.parent / "dataset"
TRAIN_IMG_DIR = DATASET_DIR / "train_images"
TRAIN_MASK_DIR = DATASET_DIR / "train_masks"
VAL_IMG_DIR = DATASET_DIR / "val_images"
VAL_MASK_DIR = DATASET_DIR / "val_masks"

class CarvanaSet(Dataset):
    """PyTorch ``Dataset`` yielding ``(image, mask)`` pairs for Carvana segmentation.

    Every file in ``image_dir`` (e.g. ``abc.jpg``) is expected to have a mask
    in ``mask_dir`` named ``<stem>_mask.gif`` (e.g. ``abc_mask.gif``).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.gif`` mask files.
        transform: Optional Albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and expected to return a dict
            with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary and filesystem-dependent; sort it so a
        # given index always maps to the same sample across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self) -> int:
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _mask_name(image_name: str) -> str:
        """Map an image file name to its mask file name (``abc.jpg`` -> ``abc_mask.gif``)."""
        # Replace only the trailing extension; str.replace(".jpg", ...) would
        # also rewrite any ".jpg" appearing earlier in the name.
        stem, _ext = os.path.splitext(image_name)
        return stem + "_mask.gif"

    def __getitem__(self, index) -> Tuple:
        """Load sample ``index`` as an ``(image, mask)`` tuple.

        Without a transform, ``image`` is the RGB image as a numpy array and
        ``mask`` is a float32 numpy array of the grayscale mask in which
        pixels equal to 255 are set to 1.0. With a transform, both are
        whatever the transform returns under the ``"image"``/``"mask"`` keys.
        """
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(image_name))

        # Albumentations works on numpy arrays. The context managers release
        # the file handles: Pillow keeps multi-frame formats such as GIF open
        # after load() until the image is explicitly closed.
        with Img.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        with Img.open(mask_path) as pil_mask:
            mask = np.array(pil_mask.convert("L"), dtype=np.float32)

        # Normalize white (255) to 1.0; any other value is left unchanged.
        mask[mask == 255] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return (image, mask)


In [23]:
# Create the augmentation pipeline applied jointly to each (image, mask) pair.
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 240
ROTATE_LIMIT = 35  # maximum rotation in degrees; applied with probability 0.5

simple_tx = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=ROTATE_LIMIT, p=0.5),
        # With mean=0, std=1 and max_pixel_value=255 this only divides pixel
        # values by 255 (uint8 input -> [0, 1]); no per-channel standardization.
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        # NOTE(review): ToTensorV2 is imported but not applied here, so samples
        # stay numpy arrays (channel-last image) rather than tensors — confirm
        # whether it should be appended before feeding a model.
    ]
)

dataset = CarvanaSet(
    image_dir=TRAIN_IMG_DIR,
    mask_dir=TRAIN_MASK_DIR,
    transform=simple_tx,
)

In [25]:
# Preview the first few training images inline using PIL.
# `display` is the builtin the IPython/Jupyter kernel provides; unlike
# Image.show() (which launches an external OS viewer), it renders in the notebook.
N_PREVIEW = 3

# Sorted so the preview is deterministic (os.listdir order is arbitrary).
training_images = sorted(os.listdir(TRAIN_IMG_DIR))
training_masks = sorted(os.listdir(TRAIN_MASK_DIR))

# Only images are shown. Zipping with the mask listing would pair files by
# arbitrary directory order and truncate to the shorter list.
for image_name in training_images[:N_PREVIEW]:
    with Img.open(TRAIN_IMG_DIR / image_name) as img:
        display(img)