In [149]:
import os
import sys
import math

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import cv2
import numpy as np
import pandas as pd
import torch

from tqdm import tqdm
import torch
from torchsummary import summary

from utils import affinity_utils
from utils.loss import mIoULoss, CrossEntropyLoss2d
from models.hourglas_spin import HourglassNet

In [150]:
weights = torch.ones(2)
miou = mIoULoss(weight=weights)
weights_angles = torch.ones(37)
ce = CrossEntropyLoss2d(weight=weights_angles, size_average=True, ignore_index=255, reduce=True)

In [151]:
model = HourglassNet()
summary(model, (3, 400, 400))

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 200, 200]        9,472
├─BatchNorm2d: 1-2                       [-1, 64, 200, 200]        128
├─ReLU: 1-3                              [-1, 64, 200, 200]        --
├─ModuleList: 1                          []                        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-1                    [-1, 64, 200, 200]        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-2                    [-1, 64, 200, 200]        --
├─ModuleList: 1                          []                        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-3                    [-1, 64, 200, 200]        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-4                    [-1, 64, 200, 200]        --
├─Sequentia

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 200, 200]        9,472
├─BatchNorm2d: 1-2                       [-1, 64, 200, 200]        128
├─ReLU: 1-3                              [-1, 64, 200, 200]        --
├─ModuleList: 1                          []                        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-1                    [-1, 64, 200, 200]        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-2                    [-1, 64, 200, 200]        --
├─ModuleList: 1                          []                        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-3                    [-1, 64, 200, 200]        --
|    └─Sequential: 2                     []                        --
|    |    └─ReLU: 3-4                    [-1, 64, 200, 200]        --
├─Sequentia

# Dataset

In [102]:
class RoadDataset(torch.utils.data.Dataset):
    def __init__(
        self, dataframe, target_size, args, multi_scale_pred=True, is_train=True
    ):
        
        self.dataframe = dataframe
        self.args = args
        # paths
        self.dir = self.args.data_path

        # list of all images
        # self.images = [line.rstrip("\n") for line in open(self.image_list)]

        # augmentations
        self.augmentation = is_train
        self.crop_size = [target_size[0], target_size[1]]
        self.multi_scale_pred = multi_scale_pred

        # preprocess
        self.angle_theta = 10 # TODO: change to args

        # to avoid Deadloack  between CV Threads and Pytorch Threads caused in resizing
        cv2.setNumThreads(0)

    def __len__(self):
        return len(self.dataframe)

    def getRoadData(self, index):
        image_data = self.dataframe.iloc[index, :]

        # load image
        img_path = os.path.join(self.dir, image_data["fpath"])
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f'Image not found: {img_path}')

        image = cv2.imread(img_path).astype(float)

        # load mask
        mask_path = os.path.join(self.dir, image_data["mpath"])
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(f'Image not found: {mask_path}')

        gt = cv2.imread(mask_path, 0).astype(float)

        h, w, c = image.shape
        if self.augmentation: # TODO: change to troch transforms
            flip = np.random.choice(2) * 2 - 1
            image = np.ascontiguousarray(image[:, ::flip, :])
            gt = np.ascontiguousarray(gt[:, ::flip])
            rotation = np.random.randint(4) * 90
            M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1)
            image = cv2.warpAffine(image, M, (w, h))
            gt = cv2.warpAffine(gt, M, (w, h))

        image = image / 255.0
        image = torch.tensor(image, dtype=torch.float32)
        image = image.permute(2, 0, 1)

        return image, gt

    def getOrientationGT(self, keypoints, height, width):
        vecmap, vecmap_angles = affinity_utils.getVectorMapsAngles(
            (height, width), keypoints, theta=self.angle_theta, bin_size=10
        )
        vecmap_angles = torch.from_numpy(vecmap_angles)

        return vecmap_angles


    def __getitem__(self, index):

        image, gt = self.getRoadData(index)
        c, h, w = image.shape

        labels = []
        vecmap_angles = []
        if self.multi_scale_pred:
            smoothness = [1, 2, 4]
            scale = [4, 2, 1]
        else:
            smoothness = [4]
            scale = [1]

        for i, val in enumerate(scale):
            if val != 1:
                gt_ = cv2.resize(
                    gt,
                    (int(math.ceil(h / (val * 1.0))), int(math.ceil(w / (val * 1.0)))),
                    interpolation=cv2.INTER_NEAREST,
                )
            else:
                gt_ = gt

            gt_orig = np.copy(gt_)
            gt_orig /= 255.0
            
            gt_orig_tens = torch.tensor(gt_orig, dtype=torch.float32)
            labels.append(gt_orig_tens)

            # Create Orientation Ground Truth
            keypoints = affinity_utils.getKeypoints(
                gt_orig, is_gaussian=False, smooth_dist=smoothness[i]
            )
            vecmap_angle = self.getOrientationGT(
                keypoints,
                height=int(math.ceil(h / (val * 1.0))),
                width=int(math.ceil(w / (val * 1.0))),
            )
            vecmap_angles.append(vecmap_angle)


        return image, labels, vecmap_angles

In [135]:
class arg:
    data_path = '/Users/alexanderveicht/Desktop/Coding/cil-road-segmentation.nosync/data/big-dataset'
    weight_miou = 2
    weight_vec = 1
args = arg()

df = pd.read_csv(os.path.join(args.data_path, "dataset.csv"))
df = df[df.dataset == "CIL"]
df.head()

Unnamed: 0,filename,dataset,fpath,mpath,split
0,satimage_132,CIL,CIL/images/satimage_132.jpg,CIL/groundtruth/satimage_132-mask.png,train
1,satimage_126,CIL,CIL/images/satimage_126.jpg,CIL/groundtruth/satimage_126-mask.png,train
2,satimage_41,CIL,CIL/images/satimage_41.jpg,CIL/groundtruth/satimage_41-mask.png,train
3,satimage_55,CIL,CIL/images/satimage_55.jpg,CIL/groundtruth/satimage_55-mask.png,train
4,satimage_69,CIL,CIL/images/satimage_69.jpg,CIL/groundtruth/satimage_69-mask.png,train


In [136]:
rds = RoadDataset(df, (400, 400), args, is_train=True)

In [137]:
data = rds.__getitem__(2)
len(data)

3

In [138]:
img, tar, o = data
img = img.unsqueeze(0)
label = [t.unsqueeze(0) for t in tar]
vecmap_angles = [t.unsqueeze(0) for t in o]

In [148]:
label[-1].shape

torch.Size([1, 400, 400])

In [140]:
pred_mask, pred_vec = model(img)
len(pred_mask)

4

In [143]:
def criterion(pred_mask, pred_vec, label, vecmap_angles):
    loss1 = miou(pred_mask[0], label[0], False)

    num_stacks = model.num_stacks
    for idx in range(num_stacks - 1):
        loss1 += miou(pred_mask[idx + 1], label[0], False)

    for idx, output in enumerate(pred_mask[-2:]):
        loss1 += miou(output, label[idx + 1], False)

    loss2 = ce(pred_vec[0], vecmap_angles[0])
    for idx in range(num_stacks - 1):
        loss2 += ce(pred_vec[idx + 1], vecmap_angles[0])
    for idx, pred_vecmap in enumerate(pred_vec[-2:]):
        loss2 += ce(pred_vecmap, vecmap_angles[idx + 1])

    loss = args.weight_miou * loss1 + args.weight_vec * loss2
    return loss, loss1, loss2

criterion(pred_mask, pred_vec, label, vecmap_angles)

(tensor(12.4241, grad_fn=<AddBackward0>),
 tensor(-1.0879, grad_fn=<AddBackward0>),
 tensor(14.5998, grad_fn=<AddBackward0>))