# Self-Supervised Monocular Depth Estimation

In [1]:
from dataloader import KittiDataset
from resnet import MyNet

import torch
from torch.utils.data import DataLoader

import os
import random
import torchvision.transforms as transforms
import torchvision.transforms.functional as tF

## data loading

In [2]:
train_path = os.path.join(os.getcwd(), "data/train")
# test_path = os.path.join(os.getcwd(), 'data/test')

In [8]:
class ToTensor(object):
    def __init__(self):
        self.transform = transforms.ToTensor()
        
    def __call__(self, sample):
        left_img = sample['left_image']
        sample['left_image'] = self.transform(left_img)
        
        if 'right_image' in sample:
            right_img = sample['right_image']
            sample['right_image'] = self.transform(right_img)
            
        return sample

In [4]:
class ToRandomFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob
        
    def __call__(self, sample):
        if random.random() < self.prob:
            return sample
        
        left_img = sample['left_image']
        sample['left_image'] = tF.hflip(left_img)
        
        if 'right_image' in sample:
            right_img = sample['right_image']
            sample['right_image'] = tF.hflip(right_img)

        return sample

In [5]:
class ToResizeImage(object):
    def __init__(self, size=(256, 512)):
        self.size = size
        
    def __call__(self, sample):
        left_img = sample['left_image']
        sample['left_image'] = tF.resize(left_img, self.size)
        
        if 'right_image' in sample:
            right_img = sample['right_image']
            sample['right_image'] = tF.resize(right_img, self.size)

        return sample

In [13]:
class AugumentImagePair(object):
    def __init__(self, prob=0.5):
        self.prob = prob
        
    def __call__(self, sample):
        if 'right_image' not in sample:
            return sample

        left_img = sample['left_image']
        right_img = sample['right_image']
        if random.random() < self.prob:
            # shift gamma
            random_gamma = random.uniform(0.8, 1.2)
            left_img_aug = left_img ** random_gamma
            right_img_aug = right_img ** random_gamma
            
            # shift brightness
            random_brightness = random.uniform(0.5, 2.0)
            left_img_aug = left_img_aug * random_brightness
            right_img_aug = right_img_aug * random_brightness
            
            # shift color
            random_color = random.uniform(0.8, 1.2)
            for i in range(3):
                left_img_aug[i,:,:] *= random_color
                right_img_aug[i,:,:] *= random_color
                
            # saturate
            left_img_aug = torch.clamp(left_img_aug, 0, 1)
            right_img_aug = torch.clamp(right_img_aug, 0, 1)
            
            sample = {
                'left_image': left_img_aug,
                'right_image': right_img_aug
            }
        
        return sample

In [14]:
train_transform = transforms.Compose([
    ToResizeImage(),
    ToRandomFlip(),
    ToTensor(),
    AugumentImagePair(),
])
test_transform = transforms.Compose([
    ToResizeImage(),
    ToTensor(),
])

In [15]:
train_set = KittiDataset(train_path, 'train', transform = train_transform)
# test_set = KittiLoader(test_path, 'test', transform = test_transform)

train_loader = DataLoader(train_set, batch_size = 5, shuffle = True)

for batch_step, batch_sample in enumerate(train_loader):
    print(batch_step, batch_sample)
    break
    

0 {'left_image': tensor([[[[0.2510, 0.3373, 0.7569,  ..., 0.1098, 0.0784, 0.0745],
          [0.3373, 0.3725, 0.8078,  ..., 0.1098, 0.0745, 0.0745],
          [0.5569, 0.5725, 0.8549,  ..., 0.1137, 0.0784, 0.0745],
          ...,
          [0.5451, 0.5529, 0.5529,  ..., 0.4941, 0.4667, 0.3765],
          [0.5333, 0.5373, 0.5333,  ..., 0.4431, 0.4431, 0.4353],
          [0.5294, 0.5373, 0.5294,  ..., 0.2745, 0.3373, 0.4078]],

         [[0.2353, 0.2667, 0.4392,  ..., 0.1216, 0.0824, 0.0941],
          [0.3412, 0.3255, 0.5843,  ..., 0.1255, 0.0824, 0.0941],
          [0.7294, 0.6431, 0.7765,  ..., 0.1294, 0.0863, 0.0941],
          ...,
          [0.5216, 0.5255, 0.5216,  ..., 0.4784, 0.4588, 0.4157],
          [0.5216, 0.5176, 0.5412,  ..., 0.4314, 0.4549, 0.4471],
          [0.5373, 0.5176, 0.5412,  ..., 0.3294, 0.3725, 0.4196]],

         [[0.3176, 0.3882, 0.6000,  ..., 0.1490, 0.1216, 0.0902],
          [0.4784, 0.5804, 0.7922,  ..., 0.1529, 0.1255, 0.0941],
          [0.9020, 0.8863

## Implement CNN using PyTorch

In [16]:
net = MyNet().eval()