In [1]:
#@ Downloading necessary data:
import os
if not os.path.exists('dataset1'):
  !wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
  !unzip -q dataset1.zip
  !rm dataset1.zip
  !pip install -q torch_snippets pytorch_model_summary

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.7/82.7 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.4/119.4 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m218.7/218.7 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m15.2 MB/s[0m eta [36m

In [10]:
#@ Importing necessary dependencies:
import torch
from torch_snippets import *
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch import nn
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
from torch.utils.data import Dataset, DataLoader
import cv2

In [3]:
#@ Image transformation pipeline applied to every input image:
# 1. ToTensor converts the image array into a torch tensor.
# 2. Normalize standardises each of the three channels with the per-channel
#    mean / std commonly used for ImageNet-pretrained models.
tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
#@ Fetching input and output images for training:
class SegData(Dataset):
  """Dataset of (image, mask) pairs for semantic segmentation.

  Images are read from ``dataset1/images_prepped_<split>`` and the matching
  masks from ``dataset1/annotations_prepped_<split>``; an image and its mask
  share the same file stem. Both are resized to 224x224.

  Args:
    split: name of the split, used as the folder suffix (the notebook uses
      'train' and 'test').
  """

  def __init__(self, split):
    # `stems` comes from torch_snippets; presumably it lists the file names in
    # the folder without their extension -- TODO confirm.
    self.items=stems(f'dataset1/images_prepped_{split}')
    self.split=split

  def __len__(self):
    """Number of samples in this split."""
    return len(self.items)

  def __getitem__(self, ix):
    """Return the resized (image, mask) pair at position ``ix``."""
    stem=self.items[ix]

    # No space after the directory separator: the original paths contained
    # '/ <stem>.png', which never matches a file on disk.
    image=read(f'dataset1/images_prepped_{self.split}/{stem}.png', 1)
    image=cv2.resize(image, (224, 224))

    # The mask lives in the annotations folder, not next to the input images.
    mask=read(f'dataset1/annotations_prepped_{self.split}/{stem}.png')
    # Keep a single channel, but only when the reader actually returned a
    # multi-channel array (a 2-D array is already a single-channel label map).
    if mask.ndim == 3:
      mask=mask[:, :, 0]
    # Nearest-neighbour keeps every pixel a valid class id; the default
    # bilinear interpolation would blend neighbouring labels into new values.
    mask=cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
    return image, mask

  # random image index for debugging purpose:
  def choose(self):
    """Return a randomly selected sample."""
    return self[randint(len(self))]

  def collate_fn(self, batch):
    """Stack a list of (image, mask) samples into batched tensors on `device`.

    Returns:
      ims: float tensor of images, each divided by 255 and passed through `tfms`.
      ce_mask: long tensor of masks (integer labels, as needed by a
        cross-entropy style loss).
    """
    ims, masks=list(zip(*batch))
    # `[None]` adds a leading batch axis so the samples can be concatenated.
    ims=torch.cat([tfms(im.copy()/255.)[None] for im in ims]).float().to(device)
    ce_mask=torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(device)
    return ims, ce_mask

In [11]:
#@ Training and validation datasets with their data loaders:
train_ds = SegData('train')
valid_ds = SegData('test')

# Each loader uses its own dataset's collate_fn to build the batch tensors.
trn_dl = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=train_ds.collate_fn,
)
val_dl = DataLoader(
    valid_ds,
    batch_size=1,
    shuffle=True,
    collate_fn=valid_ds.collate_fn,
)


##### Architecture for image segmentation

In [4]:
#@ defining convolution blocks:
def conv(in_channels, out_channels):
  """Build a 3x3 convolution -> batch norm -> ReLU block.

  The convolution uses stride 1 and padding 1, so height and width are
  unchanged; only the channel count goes from `in_channels` to `out_channels`.

  Returns:
    nn.Sequential holding the three layers in that order.
  """
  layers = [
      nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

In [5]:
#@ defining Up-Convolution:
def up_conv(in_channels, out_channels):
  """Build an upsampling block: 2x2 transposed convolution followed by ReLU.

  With kernel size 2 and stride 2 the transposed convolution doubles the
  height and width of its input while mapping `in_channels` to `out_channels`.

  Returns:
    nn.Sequential holding the two layers in that order.
  """
  upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
  activation = nn.ReLU(inplace=True)
  return nn.Sequential(upsample, activation)

In [9]:
#@ Defining Network Class:
from torchvision.models import vgg16_bn # for large scale
class UNet(nn.Module):
  """U-Net style segmentation network with a VGG16-BN encoder.

  The convolutional part of `vgg16_bn` is cut into five encoder blocks plus a
  bottleneck. The decoder then alternates `up_conv` (upsampling), concatenation
  with the matching encoder output (skip connection) and `conv`, and ends with
  a 1x1 convolution that produces `out_channels` score maps.

  Args:
    pretrained: forwarded to `vgg16_bn` to load pretrained encoder weights.
    out_channels: number of output channels (one per segmentation class).

  NOTE(review): the input presumably needs a height and width divisible by 32
  for the skip connections to line up spatially (the notebook uses 224x224)
  -- confirm.
  """

  def __init__(self, pretrained=True, out_channels=12):
    super().__init__()
    # NOTE(review): recent torchvision releases prefer `weights=` over
    # `pretrained=`; kept unchanged for compatibility -- confirm against the
    # installed torchvision version.
    self.encoder=vgg16_bn(pretrained=pretrained).features # excluding FC at end

    # encoder blocks: consecutive slices of the VGG feature extractor
    self.block1=nn.Sequential(*self.encoder[:6])
    self.block2=nn.Sequential(*self.encoder[6:13])
    self.block3=nn.Sequential(*self.encoder[13:20])
    self.block4=nn.Sequential(*self.encoder[20:27])
    self.block5=nn.Sequential(*self.encoder[27:34])

    self.bottleneck=nn.Sequential(*self.encoder[34:]) #acts between encoder and decoder
    self.conv_bottleneck=conv(512, 1024)

    # decoder: each conv's input width is (upsampled channels + skip channels)
    self.up_conv6=up_conv(1024, 512)
    self.conv6=conv(512 + 512, 512)
    self.up_conv7=up_conv(512, 256)
    self.conv7=conv(512 + 256, 256)
    self.up_conv8=up_conv(256, 128)
    self.conv8=conv(128 + 256, 128)
    self.up_conv9=up_conv(128, 64)
    self.conv9=conv(128 + 64, 64)
    self.up_conv10=up_conv(64, 32)
    self.conv10=conv(32 + 64, 32)

    # 1x1 convolution mapping the last decoder features to per-class scores
    self.conv11=nn.Conv2d(32, out_channels, kernel_size=1)

  def forward(self, x):
    """Run the encoder, then decode with skip connections.

    Args:
      x: batch of input images.

    Returns:
      Tensor with `out_channels` channels (raw, un-normalised scores).
    """
    # Encoder: every stage must go through its OWN block. The original code
    # called self.block2 for stages 3-5, which left block3/4/5 unused and
    # produced skip tensors the decoder convs were not sized for.
    block1=self.block1(x)
    block2=self.block2(block1)
    block3=self.block3(block2)
    block4=self.block4(block3)
    block5=self.block5(block4)

    bottleneck=self.bottleneck(block5)
    x=self.conv_bottleneck(bottleneck)

    # Decoder: upsample, concatenate the encoder output along the channel
    # axis (dim=1), then fuse with a conv block.
    x=self.up_conv6(x)
    x=torch.cat([x, block5], dim=1)
    x=self.conv6(x)

    x=self.up_conv7(x)
    x=torch.cat([x, block4], dim=1)
    x=self.conv7(x)

    x=self.up_conv8(x)
    x=torch.cat([x, block3], dim=1)
    x=self.conv8(x)

    x=self.up_conv9(x)
    x=torch.cat([x, block2], dim=1)
    x=self.conv9(x)

    x=self.up_conv10(x)
    x=torch.cat([x, block1], dim=1)
    x=self.conv10(x)

    x=self.conv11(x)

    return x
