<a href="https://colab.research.google.com/github/sujitojha1/EVA4/blob/rev8/S15/EVA4_S15_Solution_DenseDepth_step3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# EVA4 Session 15 Assignment - DepthMap & Mask Prediction

In [0]:
# Load the TensorBoard IPython extension into this kernel.
%load_ext tensorboard

In [1]:
!pip install kornia

Collecting kornia
[?25l  Downloading https://files.pythonhosted.org/packages/c2/60/f0c174c4a2a40b10b04b37c43f5afee3701cc145b48441a2dc5cf9286c3c/kornia-0.3.1-py2.py3-none-any.whl (158kB)
[K     |██                              | 10kB 32.7MB/s eta 0:00:01[K     |████▏                           | 20kB 1.5MB/s eta 0:00:01[K     |██████▏                         | 30kB 1.8MB/s eta 0:00:01[K     |████████▎                       | 40kB 2.1MB/s eta 0:00:01[K     |██████████▎                     | 51kB 1.9MB/s eta 0:00:01[K     |████████████▍                   | 61kB 2.1MB/s eta 0:00:01[K     |██████████████▍                 | 71kB 2.3MB/s eta 0:00:01[K     |████████████████▌               | 81kB 2.5MB/s eta 0:00:01[K     |██████████████████▋             | 92kB 2.4MB/s eta 0:00:01[K     |████████████████████▋           | 102kB 2.6MB/s eta 0:00:01[K     |██████████████████████▊         | 112kB 2.6MB/s eta 0:00:01[K     |████████████████████████▊       | 122kB 2.6MB/s eta 

In [0]:
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import torchvision
from torch import nn
import torch
from kornia.losses import SSIM

%matplotlib inline

In [0]:
# Dataset root and the three image sub-folders that the rest of the notebook reads.
data_root = Path('./data/')
f1 = data_root / 'bg'
f2 = data_root / 'fg_bg'
f3 = data_root / 'mask'

# Sanity check: number of directory entries found in each sub-folder.
print(sum(1 for _ in f1.iterdir()))
print(sum(1 for _ in f2.iterdir()))
print(sum(1 for _ in f3.iterdir()))

In [0]:
# Deterministic preprocessing (resize + tensor conversion) used whenever a dataset
# is built without an explicit transform.
scale_transform = transforms.Compose([
  transforms.Resize((256, 256)),
  transforms.ToTensor(),
])


class MasterDataset(Dataset):
  '''Index-aligned image triplets read from three sub-folders of a data root.

  Item ``i`` is a dict with the keys ``'f1'``, ``'f2'`` and ``'f3'``; each value is
  the transformed ``i``-th ``*.jpg`` file (in sorted filename order) of the
  ``bg``, ``fg_bg`` and ``mask`` folders respectively.

  Args:
    data_root: Directory (``str`` or ``Path``) containing the ``bg``, ``fg_bg``
      and ``mask`` folders.
    transform: Callable applied to every opened PIL image. When ``None`` the
      module-level ``scale_transform`` is used instead.

  Note:
    The transform is called once per image, so a random transform draws fresh
    parameters for each of the three images of an item.
  '''

  def __init__(self, data_root, transform=None):
    root = Path(data_root)
    # glob() yields files in filesystem order; sorting makes the index -> file
    # mapping reproducible and keeps the three lists aligned by filename order.
    self.f1_files = sorted((root / 'bg').glob('*.jpg'))
    self.f2_files = sorted((root / 'fg_bg').glob('*.jpg'))
    self.f3_files = sorted((root / 'mask').glob('*.jpg'))
    self.transform = transform

  def __len__(self):
    # Only indices present in all three folders can be served without an IndexError.
    return min(len(self.f1_files), len(self.f2_files), len(self.f3_files))

  def _load(self, path):
    '''Open the image at ``path`` and return it after applying the transform.'''
    transform = self.transform if self.transform is not None else scale_transform
    with Image.open(path) as image:
      return transform(image)

  def __getitem__(self, index):
    return {
      'f1': self._load(self.f1_files[index]),
      'f2': self._load(self.f2_files[index]),
      'f3': self._load(self.f3_files[index]),
    }



In [0]:
# Per-channel statistics rescaled to the 0-255 range; read by show_pred.
# NOTE(review): the values look like the usual ImageNet mean/std — confirm.
mean = torch.tensor([0.485, 0.456, 0.406]) * 255
std = torch.tensor([0.229, 0.224, 0.225]) * 255

# Training pipeline: resize, mild colour jitter, then conversion to a tensor.
train_transforms = transforms.Compose([
  transforms.Resize((256, 256)),
  transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
  transforms.ToTensor(),
])

train_ds = MasterDataset(data_root, train_transforms)

In [0]:
# Peek at the first dataset item: one (key, shape) pair per entry.
[(name, tensor.shape) for name, tensor in train_ds[0].items()]

In [0]:
# Shuffled mini-batches of 16 items; pinned host memory is requested for the batches.
train_dl = DataLoader(
  train_ds,
  batch_size=16,
  shuffle=True,
  pin_memory=True,
)

In [0]:
# Pull a single batch from the loader for inspection.
sample = next(iter(train_dl))

In [0]:
# Shapes of the batched entries: one (key, shape) pair per entry of the batch dict.
[(name, batch.shape) for name, batch in sample.items()]

In [0]:
# Keep just the 'f1' entry of the batch; it is what gets plotted below.
imgs = sample['f1']

In [0]:
# Tile the sampled batch two images per row, then move the first axis to the end
# (the axis order matplotlib expects for image data).
grid_tensor = torchvision.utils.make_grid(imgs, nrow=2)
grid_image = grid_tensor.permute(1, 2, 0)

def show(tensors, figsize = (10,10), *args, **kwargs):
  '''Draw a batch of images as one grid in a matplotlib figure.

  Args:
    tensors: A batch tensor, or a list of image tensors, in whatever form
      ``torchvision.utils.make_grid`` accepts.
    figsize: Figure size handed to ``plt.figure``.
    *args, **kwargs: Forwarded unchanged to ``make_grid`` (e.g. ``nrow``).
  '''
  # matplotlib needs host-side data without autograd history. Handle the two input
  # forms explicitly instead of swallowing every exception with a bare `except`.
  if isinstance(tensors, torch.Tensor):
    tensors = tensors.detach().cpu()
  elif isinstance(tensors, list):
    tensors = [t.detach().cpu() if isinstance(t, torch.Tensor) else t for t in tensors]

  grid_tensor = torchvision.utils.make_grid(tensors, *args, **kwargs)
  # Move the first axis last: the layout imshow expects for image arrays.
  grid_image = grid_tensor.permute(1,2,0)

  plt.figure(figsize=figsize)
  plt.imshow(grid_image)

  # Hide the axis ticks; pixel coordinates add nothing to an image grid.
  plt.xticks([])
  plt.yticks([])

  plt.show()

def show_pred(tensors, *args, **kwargs):
  '''Undo mean/std normalisation on a batch and display it with ``show``.

  Args:
    tensors: Batch tensor whose second axis matches the length of the
      module-level ``mean`` / ``std`` vectors.
    *args, **kwargs: Forwarded unchanged to ``show``.
  '''
  # `mean` and `std` are created on the CPU; bring the batch there (and drop any
  # autograd history) first so the arithmetic below cannot hit a device mismatch.
  tensors = tensors.detach().cpu()
  # Indexing with None reshapes the statistics to (1, C, 1, 1) for broadcasting.
  # NOTE(review): mean/std are multiplied by 255 where they are defined, so the
  # result is on a 0-255 scale — confirm this is the range the plot should receive.
  tensors = (tensors * std[None,:,None,None]) + mean[None,:,None,None]
  show(tensors, *args, **kwargs)

In [0]:
# Display the sampled batch as a grid with four images per row.
show(imgs, nrow=4)

In [0]:
class ConvGen(nn.Module):
  '''Generator: encodes two input images with shared weights and fuses them.

  ``forward`` takes a dict with the entries ``'f1'`` and ``'f2'``. Both are passed
  through the same ``convblock1`` -> ``convblock2`` stack (3 -> 32 -> 64 channels),
  concatenated along the channel axis (128 channels) and reduced to a 3-channel
  output by ``convblock3`` and ``convblock4``. Every convolution keeps the
  spatial size (stride 1 with matching padding), and no activation is applied
  after the last layer.
  '''
  def __init__(self):
    super().__init__()

    # 3 -> 32 channels, plain 3x3 convolution.
    self.convblock1 = nn.Sequential(
        nn.Conv2d(3,32,3,stride=1,padding=1,bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU()
    )

    # Depthwise-separable convolution: a per-channel 3x3 (groups == channels)
    # followed by a 1x1 pointwise projection to 64 channels.
    # The keyword is `groups`; the former `group=32` raised a TypeError on construction.
    self.convblock2 = nn.Sequential(
        nn.Conv2d(32,32,3,stride=1,padding=1,bias=False,groups=32),
        nn.Conv2d(32,64,1,stride=1,padding=0,bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU()
    )

    # Takes the 64 + 64 concatenated channels of the two encoded inputs.
    self.convblock3 = nn.Sequential(
        nn.Conv2d(128,256,3,stride=1,padding=1,bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU()
    )

    # Final projection to 3 output channels, left without normalisation or activation.
    self.convblock4 = nn.Sequential(
        nn.Conv2d(256,3,3,stride=1,padding=1,bias=False),
    )

  def forward(self,sample):
    '''Return the 3-channel output for the ``'f1'`` / ``'f2'`` pair in ``sample``.'''
    f1=sample['f1']
    f2=sample['f2']

    # The same two blocks (shared weights) encode both inputs.
    f1 = self.convblock2(self.convblock1(f1))
    f2 = self.convblock2(self.convblock1(f2))

    # Fuse along the channel axis, then decode.
    f = torch.cat([f1,f2],dim=1)
    f = self.convblock4(self.convblock3(f))

    return f

In [0]:
# Binary cross-entropy that takes raw logits (the sigmoid is applied inside the loss).
# The class is spelled `BCEWithLogitsLoss`; `BCEWithLogitLoss` raised an AttributeError.
criterion = nn.BCEWithLogitsLoss()

In [0]:
# Instantiate the generator and print its layer structure.
model = ConvGen()
print(model)