In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class UNetConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU and optional BatchNorm.

    Args:
        in_size (int): channel count of the incoming feature map.
        out_size (int): channel count produced by both convolutions.
        padding (bool or int): passed through ``int()`` and used as the
            convolution padding. ``False``/0 means unpadded convolutions (each
            one trims 2 pixels per spatial dimension); ``True``/1 keeps the
            spatial size.
        batch_norm (bool): when truthy, a ``BatchNorm2d`` is inserted after
            each ReLU.
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        pad = int(padding)

        layers = []
        channels_in = in_size
        for _ in range(2):
            layers.append(nn.Conv2d(channels_in, out_size, kernel_size=3, padding=pad))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            # The second convolution consumes what the first one produced.
            channels_in = out_size

        # The (conv, relu[, bn]) x 2 ordering fixes the Sequential indices and
        # therefore the parameter names in the state dict.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)

In [3]:
class UNetUpBlock(nn.Module):
    """Decoder stage: 2x upsampling, skip concatenation, then a ``UNetConvBlock``.

    Args:
        in_size (int): channel count of the tensor coming from the level below.
            It is also the channel count the conv block expects after the
            upsampled tensor and the cropped skip tensor are concatenated.
        out_size (int): channel count produced by the upsampling step and by
            the conv block.
        up_mode (str): ``'upconv'`` uses a stride-2 transposed convolution;
            ``'upsample'`` uses bilinear upsampling followed by a 1x1
            convolution that changes the channel count.
        padding (bool or int): forwarded to ``UNetConvBlock``.
        batch_norm (bool): forwarded to ``UNetConvBlock``.

    Raises:
        ValueError: if ``up_mode`` is neither ``'upconv'`` nor ``'upsample'``.
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
                                    nn.Conv2d(in_size, out_size, kernel_size=1))
        else:
            # Fail at construction time; otherwise ``self.up`` would be missing
            # and the error would only surface on the first forward pass.
            raise ValueError(
                "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode))

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` window of a 4-D tensor.

        Args:
            layer: tensor whose last two dimensions are height and width.
            target_size: ``(height, width)`` of the window to keep.

        Returns:
            A slice of ``layer`` over the last two dimensions. When the size
            difference is odd, the extra row/column is dropped at the
            bottom/right.
        """
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

    def forward(self, x, bridge):
        """Upsample ``x``, attach the cropped skip tensor ``bridge`` and convolve.

        Args:
            x: feature map from the level below.
            bridge: encoder feature map of this level (the skip connection).
                It is centre-cropped to the spatial size of the upsampled ``x``.
        """
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        # Concatenate along the channel dimension: upsampled features first.
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

In [4]:
class UNet(nn.Module):
    """U-Net encoder/decoder for image segmentation.

    Based on "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    (Ronneberger et al., 2015), https://arxiv.org/abs/1505.04597

    With the default ``depth``/``wf`` and ``padding=False``, a 572x572 input
    yields a 388x388 output.

    NOTE: the defaults are not an exact copy of the network in the paper. This
    class downsamples with 2x2 average pooling and defaults to
    ``batch_norm=True`` and 3 input channels, whereas the paper describes max
    pooling and no batch normalisation.

    Args:
        in_channels (int): number of input channels
        n_classes (int): number of output channels
        depth (int): number of encoder levels; the decoder has ``depth - 1``
        wf (int): the first level has ``2**wf`` filters; each deeper level
            doubles that
        padding (bool): if True, the 3x3 convolutions are padded so they keep
            the spatial size. This may introduce artifacts
        batch_norm (bool): use BatchNorm after layers with an activation function
        up_mode (str): one of 'upconv' or 'upsample'.
            'upconv' will use transposed convolutions for learned upsampling.
            'upsample' will use bilinear upsampling.
    """

    def __init__(self, in_channels=3, n_classes=2, depth=5, wf=6, padding=False, batch_norm=True, up_mode='upconv'):
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        # Channel width of every encoder level: 2**wf, 2**(wf+1), ...
        widths = [2 ** (wf + level) for level in range(depth)]

        # Encoder: one conv block per level.
        channels = in_channels
        self.down_path = nn.ModuleList()
        for width in widths:
            self.down_path.append(UNetConvBlock(channels, width, padding, batch_norm))
            channels = width

        # Decoder: walk back up through every level except the deepest one.
        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(UNetUpBlock(channels, width, up_mode, padding, batch_norm))
            channels = width

        # 1x1 convolution mapping the final features to per-class scores.
        self.last = nn.Conv2d(channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the class map."""
        skips = []
        bottom = len(self.down_path) - 1
        for level, encode in enumerate(self.down_path):
            x = encode(x)
            if level == bottom:
                # The deepest level is neither stored as a skip nor pooled.
                continue
            skips.append(x)
            x = F.avg_pool2d(x, 2)

        # Skip tensors are consumed deepest-first.
        for decode in self.up_path:
            x = decode(x, skips.pop())

        return self.last(x)

In [5]:
# Instantiate the U-Net defined above for 3-channel input and 2 output classes;
# every other constructor argument keeps its default.
model = UNet(in_channels=3, n_classes=2)

In [6]:
# Display the module tree (repr) of the network.
model

UNet(
  (down_path): ModuleList(
    (0): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (4): ReLU()
        (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (4): ReLU()
        (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): 

In [None]:
# Forward pass through the U-Net.
# NOTE(review): image_572 is only assigned in the sample-loading cell at the
# bottom of this notebook, so that cell has to be run before this one.
output_388 = model(image_572)

In [None]:
# Show the network output (expected to be 388x388 spatially for a 572x572
# input with the unpadded default network — TODO confirm).
output_388

In [None]:
from SegNet_standard import SegNet 

In [None]:
# Build SegNet (imported from the local SegNet_standard module) with its
# default constructor arguments.
seg = SegNet()

In [None]:
# Forward pass on the 512 crop.
# NOTE(review): image_512 is only assigned in the sample-loading cell at the
# bottom of this notebook, so that cell has to be run before this one.
output_seg = seg(image_512)

In [None]:
# Show the SegNet output.
output_seg

In [None]:
# Display the module tree (repr) of SegNet.
seg

In [None]:
# Rebuild the U-Net layers by hand, following the same recipe as UNet.__init__,
# so the encoder can be stepped through stage by stage in the next cell.
# in_channels is 3 and now actually drives the first conv block; previously it
# was set to 1 but ignored in favour of a hard-coded 3.
in_channels = 3
n_classes = 2
batch_norm = False
padding = False
up_mode = 'upconv'
wf = 6
depth = 5

# Encoder: one conv block per level, doubling the channel count each time.
prev_channels = in_channels
down_path = nn.ModuleList()
for i in range(depth):
    down_path.append(UNetConvBlock(prev_channels, 2**(wf+i), padding, batch_norm))
    prev_channels = 2**(wf+i)

# Decoder: one up block per level except the deepest, halving the channels.
up_path = nn.ModuleList()
for i in reversed(range(depth - 1)):
    up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode, padding, batch_norm))
    prev_channels = 2**(wf+i)

# 1x1 convolution mapping the final features to per-class scores.
last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

In [None]:
# NOTE(review): the only assignment to `image` visible in this notebook is the
# PIL image created in the sample-loading cell at the bottom, not a batched
# tensor. Presumably image_572 or image_512 was intended — confirm before running.
x = image

# Walk the encoder by hand, mirroring UNet.forward: every stage except the last
# stores its output for the skip connection and is then average-pooled by 2.
blocks = []
for i, down in enumerate(down_path):
    x = down(x)
    # Feature-map size after this stage.
    print(x.size())
    if i != len(down_path)-1:
        blocks.append(x)
        x = F.avg_pool2d(x, 2)

In [None]:
# Size of the feature map left in x after the last encoder stage.
x.size()

In [None]:
import numpy as np
import torch
import os
import pickle
import matplotlib.pyplot as plt
# %matplotlib inline
# plt.rcParams['figure.figsize'] = (20, 20)
# plt.rcParams['image.interpolation'] = 'bilinear'

from argparse import ArgumentParser

from torch.optim import SGD, Adam
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor, ToPILImage
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import lr_scheduler

import collections
import numbers
import random
import math
from PIL import Image, ImageOps, ImageEnhance

In [None]:
# Directories that the sample-loading cell reads the training image and its
# mask from (paths are relative to the notebook's working directory).
image_path = '../../data/images_flip/train/'
mask_path = '../../data/images_flip/train_masks/'

In [None]:
# Load the pickled object stored in train_shuffle_names.pk; further down it is
# indexed as a sequence of image file names.
# NOTE: pickle.load can execute arbitrary code — only open files from a trusted source.
with open('../../data/train_shuffle_names.pk', 'rb') as f:
    filenames = pickle.load(f)

In [None]:
class RandomCrop(object):
    """Cut the same randomly placed square window out of an image and its label.

    Args:
        crop_size (int): side length of the square window in pixels.
    """

    def __init__(self, crop_size=512):
        self.crop_size = crop_size

    def __call__(self, img_and_label):
        """Crop both members of ``img_and_label`` with one random box.

        Args:
            img_and_label: pair ``(img, label)``. ``img`` must expose a
                ``size`` of ``(width, height)``, and both must offer
                ``crop(box)`` (the notebook passes PIL images).

        Returns:
            Tuple ``(img, label)`` of the two cropped objects.

        Note:
            ``random.randint`` raises ``ValueError`` when the image is smaller
            than ``crop_size`` in either dimension.
        """
        img, label = img_and_label
        width, height = img.size
        side = self.crop_size

        # Top-left corner of the window; x is drawn before y.
        left = random.randint(0, width - side)
        top = random.randint(0, height - side)
        box = (left, top, left + side, top + side)

        return img.crop(box), label.crop(box)

class RandomCrop_different_size_for_image_and_label(object):
    """Random crop with a larger window for the image than for the label.

    Both windows are positioned relative to the same randomly drawn centre, so
    the label window sits in the middle of the image window. This is the
    input/target pairing an unpadded U-Net needs.

    Args:
        image_size (int): side length of the square image window.
        label_size (int): side length of the square label window.

    Attributes:
        bound: ``(image_size - label_size) // 2``, the margin by which the
            image window extends past the label window on the top/left.
    """

    def __init__(self, image_size=572, label_size=388):
        self.image_size = image_size
        self.label_size = label_size
        self.bound = (self.image_size - self.label_size) // 2

    def __call__(self, img_and_label):
        """Crop ``(img, label)`` around one shared random centre.

        Args:
            img_and_label: pair ``(img, label)``. ``img`` must expose a
                ``size`` of ``(width, height)``, and both must offer
                ``crop(box)`` (the notebook passes PIL images).

        Returns:
            Tuple ``(img, label)``: the image cropped to ``image_size`` and the
            label cropped to ``label_size``.

        Note:
            The centre is chosen so the label window lies inside the image.
            The larger image window may extend beyond the border; how that
            region is filled is up to ``img.crop`` (TODO confirm for PIL).
        """
        img, label = img_and_label
        w, h = img.size

        # Pixels the label window needs left of/above the centre, and to the
        # right of/below it. The two differ by one when label_size is odd.
        half_label = self.label_size // 2
        label_rest = self.label_size - half_label

        # x is drawn before y. For even label sizes the range is the same as
        # [label_size // 2, dim - label_size // 2].
        xcenter = random.randint(half_label, w - label_rest)
        ycenter = random.randint(half_label, h - label_rest)

        img = img.crop(self._window(xcenter, ycenter, self.image_size))
        label = label.crop(self._window(xcenter, ycenter, self.label_size))

        return img, label

    @staticmethod
    def _window(xcenter, ycenter, size):
        """Return a ``(left, top, right, bottom)`` box exactly ``size`` wide and high."""
        left = xcenter - size // 2
        top = ycenter - size // 2
        # Derive right/bottom from left/top so odd sizes are not truncated.
        return (left, top, left + size, top + size)
class ToTensor_Label(object):
    """Convert an ``(image, label)`` pair into a pair of tensors.

    The image goes through torchvision's ``ToTensor``. The label is turned into
    a NumPy array and then into a ``long`` tensor with an extra leading
    dimension.
    """

    def __call__(self, img_and_label):
        """Return ``(img_tensor, label_tensor)`` for the given pair."""
        image, mask = img_and_label
        image_tensor = ToTensor()(image)
        mask_tensor = torch.from_numpy(np.array(mask))
        # Cast to integer class indices and prepend a dimension of size 1.
        mask_tensor = mask_tensor.long().unsqueeze(0)
        return image_tensor, mask_tensor

class ImageNormalize(object):
    """Normalise the image tensor channel by channel; the label passes through.

    Args:
        mean: sequence with one mean per channel.
        std: sequence with one standard deviation per channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_and_label):
        """Apply ``(channel - mean) / std`` in place and return the same pair.

        The image tensor is modified in place: iterating over it yields
        per-channel views. Iteration stops at the shortest of the tensor's
        first dimension, ``mean`` and ``std``.
        """
        img_tensor, label_tensor = img_and_label
        for channel, mu, sigma in zip(img_tensor, self.mean, self.std):
            channel.sub_(mu)
            channel.div_(sigma)
        return img_tensor, label_tensor
    
class Compose(object):
    """Chain several callables into a single transform.

    Args:
        transforms: iterable of callables; each receives the previous one's
            return value.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        """Feed ``x`` through every transform in order and return the final value."""
        result = x
        for step in self.transforms:
            result = step(result)
        return result

In [None]:
# Two cropping strategies: one 512 window shared by image and label, and a
# paired crop giving the image a 572 window and the label a 388 window around
# the same centre.
crop_512 = RandomCrop(512)
crop_572 = RandomCrop_different_size_for_image_and_label(image_size=572, label_size=388)
to_tensor_label = ToTensor_Label()
# Per-channel mean / std for the image tensor (these look like the commonly
# used ImageNet statistics — confirm).
normalize = ImageNormalize([.485, .456, .406], [.229, .224, .225])
# Full pipelines: crop -> convert to tensors -> normalise the image.
transforms = Compose([crop_572, to_tensor_label, normalize])
transforms_512 = Compose([crop_512, to_tensor_label, normalize])

In [None]:
# Pick one training image and derive its mask file name: same stem, '.png' extension.
filename_img = filenames[5]
filename_mask = os.path.splitext(filename_img)[0]+'.png'

# convert() is called inside each with-block, while the file is still open.
with open(os.path.join(image_path, filename_img), 'rb') as f:
    image = Image.open(f).convert('RGB')
with open(os.path.join(mask_path, filename_mask), 'rb') as f:
    label = Image.open(f).convert('P')

# Each pipeline draws its own random crop, so the 572/388 sample and the 512
# sample are taken independently from the same image.
[image_572, label_388] = transforms([image, label])
[image_512, label_512] = transforms_512([image, label])
# Add a leading dimension of size 1 (the batch) before feeding the networks.
image_572 = image_572.unsqueeze(0)
image_512 = image_512.unsqueeze(0)