In [None]:
import cv2
import numpy as np
import albumentations as A
from network import LaneDataset
from common import pad_gt
from constants import IMG_HEIGHT, IMG_WIDTH

In [79]:
from os import listdir
from torch.utils.data import Dataset
import torch

class LaneDataset(Dataset):
    """Paired image / ground-truth-mask dataset for road/lane segmentation.

    Images are read as RGB and masks as single-channel grayscale. Both are
    zero-padded on the bottom/right and then cropped to
    ``(IMG_HEIGHT, IMG_WIDTH)`` so every sample has the same spatial size.
    They are finally passed jointly through an albumentations transform that
    converts them to tensors (``ToTensorV2``), optionally after a random
    horizontal flip.

    Args:
        img_folder: Directory containing the input images.
        gt_folder: Directory containing the ground-truth masks. Files are
            paired with ``img_folder`` by *sorted* filename order, so both
            folders must hold the same number of files whose sorted orders
            correspond.
        augment: If True, apply a random horizontal flip (p=0.5) to image and
            mask together before tensor conversion.
    """

    def __init__(self, img_folder: str, gt_folder: str, augment: bool = False):
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        # List of (image_filename, mask_filename) tuples, one per sample.
        self.img_gt_list = self._pair_files(img_folder, gt_folder)

        if augment:
            self.transform = A.Compose(
                [
                    A.HorizontalFlip(p=0.5),
                    # A.Rotate(limit=5., p=1.0),
                    # A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=1.0),
                    A.ToTensorV2(),
                ],
                # TODO: receive seed
                # seed=0
            )
        else:
            self.transform = A.ToTensorV2()

    @staticmethod
    def _pair_files(img_folder: str, gt_folder: str) -> list[tuple[str, str]]:
        """Return ``(image_filename, mask_filename)`` pairs matched by sorted order.

        ``os.listdir`` returns entries in arbitrary, filesystem-dependent
        order, so both listings are sorted before pairing; otherwise an image
        could be paired with a different image's mask.

        Raises:
            ValueError: If the folders contain a different number of entries,
                since positional pairing would then be misaligned.
        """
        img_names = sorted(listdir(img_folder))
        gt_names = sorted(listdir(gt_folder))
        if len(img_names) != len(gt_names):
            raise ValueError(
                f"{img_folder!r} has {len(img_names)} files but "
                f"{gt_folder!r} has {len(gt_names)}; cannot pair them"
            )
        return list(zip(img_names, gt_names))

    def __len__(self):
        return len(self.img_gt_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, gt)`` after ``self.transform``.

        Raises:
            FileNotFoundError: If either file cannot be decoded by OpenCV
                (``cv2.imread`` returns None instead of raising).
        """
        img_fn, gt_fn = self.img_gt_list[idx]
        img_path = f"{self.img_folder}/{img_fn}"
        gt_path = f"{self.gt_folder}/{gt_fn}"

        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; convert so tensors are in RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
        if gt is None:
            raise FileNotFoundError(f"Could not read mask: {gt_path}")

        # Pad (bottom/right, with zeros) then crop img and gt to
        # IMG_HEIGHT x IMG_WIDTH so every sample has the same size.
        h, w = img.shape[:2]
        pad_h = max(0, IMG_HEIGHT - h)
        pad_w = max(0, IMG_WIDTH - w)
        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
        gt = cv2.copyMakeBorder(gt, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
        img = img[:IMG_HEIGHT, :IMG_WIDTH]
        gt = gt[:IMG_HEIGHT, :IMG_WIDTH]

        # Without augmentation this only converts both arrays to tensors.
        augmented = self.transform(image=img, mask=gt)
        img = augmented['image']
        gt = augmented['mask']

        return img, gt

In [82]:
def prepare_for_cv(img, gt):
    """Convert an (image, mask) tensor pair into uint8 BGR arrays for OpenCV display.

    Args:
        img: Image tensor in channel-first (C, H, W) layout with 3 RGB
            channels, holding values already in the 0-255 range (as produced
            by ``LaneDataset``, which does not normalise).
        gt: Mask or prediction tensor; singleton dimensions are squeezed
            away. Values are multiplied by 255, so it is expected to lie in
            [0, 1]; anything outside is clipped to [0, 255] instead of
            wrapping around in the uint8 cast.

    Returns:
        ``(img, gt)``: ``img`` as an (H, W, 3) BGR uint8 array, and ``gt`` as a
        3-channel uint8 array with only the green channel populated, ready
        for ``cv2.addWeighted`` overlay. ``gt`` is padded by ``pad_gt``
        (presumably to IMG_HEIGHT x IMG_WIDTH — confirm in ``common``).
    """
    # detach() + cpu() on both tensors: .numpy() fails on tensors that track
    # gradients or live on a GPU (e.g. raw model predictions).
    img = img.detach().cpu().permute(1, 2, 0).numpy().astype('uint8')
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    gt = gt.detach().cpu().squeeze().numpy().astype(float)
    gt = pad_gt(gt, IMG_HEIGHT, IMG_WIDTH)
    # Clip before casting: out-of-range floats (e.g. a 0/255 mask scaled by
    # 255) would otherwise wrap around instead of saturating at 255.
    gt = np.uint8(np.clip(255 * gt, 0, 255))
    gt = cv2.cvtColor(gt, cv2.COLOR_GRAY2BGR)
    # Zero blue and red so the mask renders as a green overlay.
    gt[:, :, 0] = 0
    gt[:, :, 2] = 0
    return img, gt

# Visual sanity check: overlay the ground-truth mask (green) on the same
# sample loaded with and without augmentation, and show both in windows.
img_folder = r"C:\javier\personal_projects\computer_vision\data\KITTI_road_segmentation\data_road\training\image_2"
gt_folder = r"data\labels"

# Plain dataset: tensor conversion only.
data = LaneDataset(img_folder, gt_folder)
img, gt = data[80]
img, gt = prepare_for_cv(img, gt)
original = cv2.addWeighted(img, 1, gt, 0.5, 0)

# Augmented dataset: the horizontal flip fires with p=0.5, so on a given run
# this view may be identical to the original one.
data_aug = LaneDataset(img_folder, gt_folder, augment=True)
img, gt = data_aug[80]
img, gt = prepare_for_cv(img, gt)
augmented = cv2.addWeighted(img, 1, gt, 0.5, 0)

cv2.imshow('original', original)
cv2.imshow('augmented', augmented)

# Block until a key is pressed, then close the preview windows.
cv2.waitKey(0)
cv2.destroyAllWindows()