In [61]:
%load_ext blackcellmagic

The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic


In [197]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os
import pathlib
from tqdm import tqdm_notebook as tqdm

In [198]:
# Library versions, recorded so the outputs in this notebook can be tied to an environment.
torch.__version__, np.__version__

('1.3.1', '1.16.3')

In [205]:
class SegmentationDataset(Dataset):
    """In-memory dataset of ``(image, one-hot label)`` tensor pairs for 3-class segmentation.

    Every regular file in ``data_path`` whose name contains ``"gt"`` (case-insensitive)
    is a label image; every other file is an input image. Each input is paired with the
    shortest label name that starts with the input's stem (``img1.png`` -> ``img1_gt.png``);
    a label whose name continues the stem's trailing number (``img10_gt.png`` for
    ``img1``) is never matched. Inputs without a label are skipped. All pairs are resized
    and preprocessed once in ``__init__``, so ``__getitem__`` is a plain list lookup.

    Args:
        data_path: Directory holding both input and label images (str or path-like).
        dataset_mean: Per-channel RGB mean on the 0-255 scale, subtracted before the
            input is divided by 255.
        out_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Each item is ``(input, label)``: ``input`` is a float32 ``(3, height, width)`` tensor;
    ``label`` is a float32 ``(3, height, width)`` one-hot tensor whose channels mark label
    pixels with value 0, 128 and 255 respectively (see ``_encode_label``).
    """

    # Label pixel values (first channel) mapped to output channels 0, 1 and 2.
    LABEL_VALUES = (0, 128, 255)

    def __init__(self, data_path, dataset_mean=(69.2614, 55.9220, 32.6043), out_size=(1024, 768)):
        self.out_size = out_size
        # float32 explicitly: an integer mean would otherwise produce an int64 tensor.
        self.mean = torch.tensor(dataset_mean, dtype=torch.float32)
        self.all_pairs = self._find_pairs(data_path)
        self.all_data = [
            self._load_pair(input_path, label_path)
            for input_path, label_path in tqdm(self.all_pairs, desc="Loading images to RAM")
        ]

    @classmethod
    def _find_pairs(cls, data_path):
        """Return ``(input_path, label_path)`` string pairs in ``data_path``, sorted by input name."""
        # Sorting makes dataset indices independent of the filesystem's listing order;
        # directories (e.g. .ipynb_checkpoints) are not images and are skipped.
        names = sorted(
            entry.name for entry in pathlib.Path(data_path).resolve().iterdir() if entry.is_file()
        )
        label_names = [name for name in names if "gt" in name.lower()]
        pairs = []
        for name in names:
            if "gt" in name.lower():
                continue
            label_name = cls._match_label(name, label_names)
            if label_name is not None:
                pairs.append((os.path.join(data_path, name), os.path.join(data_path, label_name)))
        return pairs

    @staticmethod
    def _match_label(input_name, label_names):
        """Return the label file name that belongs to ``input_name``, or ``None`` if there is none."""
        stem = str(pathlib.Path(input_name).with_suffix(""))
        # A plain prefix test would also pair "img1" with "img10_gt.png": reject labels
        # where a digit directly continues a stem that itself ends in a digit.
        ends_in_digit = stem[-1:].isdigit()
        candidates = [
            label
            for label in label_names
            if label.startswith(stem)
            and not (ends_in_digit and label[len(stem):len(stem) + 1].isdigit())
        ]
        # Shortest (then alphabetical) name is the closest match and is deterministic.
        return min(candidates, key=lambda label: (len(label), label), default=None)

    @classmethod
    def _encode_label(cls, label_array):
        """One-hot encode a label image array into a float32 ``(len(LABEL_VALUES), H, W)`` tensor.

        Only the first channel of an ``(H, W, C)`` array is compared (a 2-D array is used
        as-is); pixels whose value is not in ``LABEL_VALUES`` are zero in every channel.
        """
        values = label_array[..., 0] if label_array.ndim == 3 else label_array
        masks = np.stack([values == value for value in cls.LABEL_VALUES])
        return torch.from_numpy(masks.astype(np.float32))

    def _load_pair(self, input_path, label_path):
        """Read, resize and preprocess one ``(input, label)`` pair of image files."""
        # ``with`` closes the files; resize() returns an independent, fully loaded image.
        with Image.open(input_path) as input_file:
            # RGB guarantees three channels to line up with the 3-value mean.
            input_image = input_file.convert("RGB").resize(self.out_size, Image.LANCZOS)
        with Image.open(label_path) as label_file:
            # NEAREST keeps label values exact (no interpolated in-between classes).
            label_image = label_file.resize(self.out_size, Image.NEAREST)
        if label_image.mode == "P":
            # Palette images store indices, not colours; expand them to real pixel values.
            label_image = label_image.convert("RGBA")
        input_array = np.array(input_image, dtype=np.float32)
        input_tensor = ((torch.from_numpy(input_array) - self.mean) / 255.0).permute(2, 0, 1)
        return input_tensor, self._encode_label(np.array(label_image))

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx):
        return self.all_data[idx]

In [206]:
# Builds the in-memory dataset: every image/label pair in this folder is loaded and
# preprocessed up front. The relative path resolves against the kernel's working directory.
dataset = SegmentationDataset("./data/training-resized")

HBox(children=(IntProgress(value=0, description='Loading images to RAM', max=245, style=ProgressStyle(descript…

In [210]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` units.

    The padding keeps height and width unchanged; only the channel count changes, from
    ``in_channels`` to ``out_channels`` (the second convolution maps out -> out).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.extend(
                [
                    nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        # A single Sequential named ``double_conv`` keeps the state_dict keys stable.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

In [211]:
class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves height and width), DoubleConv, then Dropout2d.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
        dropout: Dropout2d probability (drops whole channels in training mode).
    """

    def __init__(self, in_channels, out_channels, dropout):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
            # Not in-place: DoubleConv ends with an in-place ReLU whose output autograd
            # saves for its backward; overwriting it again in place makes backward() fail
            # with "modified by an inplace operation" whenever dropout > 0 in training.
            nn.Dropout2d(p=dropout),
        )

    def forward(self, x):
        return self.maxpool_conv(x)

In [212]:
class Up(nn.Module):
    """Decoder stage: upsample ``x1`` by 2, concatenate with skip ``x2``, then DoubleConv + Dropout2d.

    Args:
        in_channels: Channel count after concatenation (skip channels + upsampled
            channels), i.e. the input width of the DoubleConv.
        out_channels: Output channels of the stage.
        dropout: Dropout2d probability applied to the stage output.
        bilinear: If True, upsample with parameter-free bilinear interpolation;
            otherwise use a learned 2x2 transposed convolution that maps
            ``in_channels // 2`` channels to ``in_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, dropout, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = nn.Sequential(
            DoubleConv(in_channels, out_channels),
            # Not in-place: DoubleConv ends with an in-place ReLU whose output autograd
            # saves for its backward; overwriting it again in place makes backward() fail
            # with "modified by an inplace operation" whenever dropout > 0 in training.
            nn.Dropout2d(p=dropout),
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, concatenate ``[x2, x1]`` on channels, convolve.

        Both inputs are NCHW tensors; ``x2`` is the encoder skip connection.
        """
        x1 = self.up(x1)
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


In [213]:
class OutConv(nn.Module):
    """Final 1x1 convolution that maps feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

In [214]:
class UNet(nn.Module):
    """U-Net with 4 max-pool encoder stages and 4 upsampling decoder stages joined by skips.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output logits (one per class).
        bilinear: Passed to every ``Up`` stage (bilinear upsampling vs transposed conv).

    Encoder stages use Dropout2d 0.2 (0.5 at the bottleneck); decoder stages use 0.2
    except the last, which uses none. ``forward`` returns raw logits with
    ``n_classes`` channels and the input's height and width.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder. The bottleneck stays at 512 channels so that, concatenated with the
        # 512-channel skip from down3, up1 receives 512 + 512 = 1024 channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128, 0.2)
        self.down2 = Down(128, 256, 0.2)
        self.down3 = Down(256, 512, 0.2)
        self.down4 = Down(512, 512, 0.5)

        # Decoder. Each first argument is (skip channels + upsampled channels).
        self.up1 = Up(1024, 256, 0.2, bilinear)
        self.up2 = Up(512, 128, 0.2, bilinear)
        self.up3 = Up(256, 64, 0.2, bilinear)
        self.up4 = Up(128, 64, 0.0, bilinear)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        features = self.down4(skip4)
        # Decode from the bottleneck, pairing each stage with the matching encoder output.
        decoder = ((self.up1, skip4), (self.up2, skip3), (self.up3, skip2), (self.up4, skip1))
        for up_stage, skip in decoder:
            features = up_stage(features, skip)
        return self.outc(features)

In [215]:
# Seed before building the model so its random weight initialisation is reproducible
# across Restart & Run All.
torch.manual_seed(0)
unet = UNet(n_channels=3, n_classes=3)

In [216]:
# Fall back to CPU so the notebook still runs on a kernel without a GPU.
unet = unet.to("cuda" if torch.cuda.is_available() else "cpu")

In [217]:
# Input tensor of the first dataset item, with a leading batch dimension: (1, C, H, W).
input_data = dataset[0][0].unsqueeze(0)

In [218]:
# Move the batch to whatever device the model's parameters live on (GPU or CPU).
input_data = input_data.to(next(unet.parameters()).device)

In [219]:
%%time
# NOTE(review): this forward pass records autograd history (output1 carries a grad_fn), and
# nothing above calls unet.eval(), so dropout is presumably active and BatchNorm running
# stats get updated. For a pure inference check consider unet.eval() plus torch.no_grad().
# The first CUDA call likely also pays one-off warm-up cost, inflating the %%time wall time.
output1 = unet(input_data)

CPU times: user 140 ms, sys: 40 ms, total: 180 ms
Wall time: 1.06 s


In [223]:
# Raw class scores (logits, no softmax) of the first image at row 500, column 500.
output1[0, :, 500, 500]

tensor([ 0.2494,  0.4982, -0.1582], device='cuda:0', grad_fn=<SelectBackward>)

In [None]:
# NOTE(review): second pass on the same input. Unless unet.eval() was called, Dropout2d is
# active, so output2 will generally differ from output1.
output2 = unet(input_data)

In [None]:
# Expected (1, 3, 768, 1024) with the dataset's default out_size=(1024, 768); PIL sizes are (width, height).
output1.shape