In [3]:
import torch
from torch import nn
import torch.nn.functional as F

In [5]:
class UNet_check(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The contracting path has five stages (64, 128, 256, 512, 1024 channels),
    with 2x2 average pooling after each of the first four. The expanding path
    upsamples with stride-2 transposed convolutions, concatenates the
    centre-cropped output of the matching contracting stage along the channel
    axis, and refines the result with another pair of 3x3 convolutions.
    A final 1x1 convolution maps the 64 remaining channels to ``n_classes``.

    Because no convolution is padded, the output is spatially smaller than the
    input (the recorded run in this notebook maps 572x572 to 388x388).

    Args:
        in_channels: number of channels of the input tensor.
        n_classes: number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels=3, n_classes=2):
        super(UNet_check, self).__init__()

        # Contracting path. Sub-modules are created and registered in a fixed
        # order (down1..down5, then up4_x/up4 .. up1_x/up1, then last) so that
        # parameter names and seeded initialisation stay stable.
        self.down1 = self._double_conv(in_channels, 64)
        self.down2 = self._double_conv(64, 128)
        self.down3 = self._double_conv(128, 256)
        self.down4 = self._double_conv(256, 512)
        self.down5 = self._double_conv(512, 1024)

        # Expanding path. Each refinement stack takes twice the channels of
        # its upsampler output because the skip connection is concatenated.
        self.up4_x = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up4 = self._double_conv(1024, 512)

        self.up3_x = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up3 = self._double_conv(512, 256)

        self.up2_x = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up2 = self._double_conv(256, 128)

        self.up1_x = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up1 = self._double_conv(128, 64)

        self.last = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return two (unpadded 3x3 conv -> ReLU -> BatchNorm) units in a row."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def center_crop(self, layer, target_size):
        """Slice the central ``target_size`` (height, width) window out of ``layer``.

        ``layer`` must be 4-dimensional; the first two dimensions are kept
        untouched. When the size difference is odd, the extra row/column is
        left on the bottom/right side.
        """
        layer_h, layer_w = layer.size()[2:]
        crop_h, crop_w = target_size[0], target_size[1]
        top = (layer_h - crop_h) // 2
        left = (layer_w - crop_w) // 2
        return layer[:, :, top:top + crop_h, left:left + crop_w]

    def forward(self, x):
        """Run the contracting path, then the expanding path with skip connections."""
        # Contracting path: remember each stage's pre-pooling output for the
        # skip connections.
        skips = []
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            features = encoder(x)
            skips.append(features)
            x = F.avg_pool2d(features, 2)

        x = self.down5(x)

        # Expanding path: the most recent skip is consumed first (down4 with
        # up4, ..., down1 with up1).
        decoder_stages = (
            (self.up4_x, self.up4),
            (self.up3_x, self.up3),
            (self.up2_x, self.up2),
            (self.up1_x, self.up1),
        )
        for upsample, refine in decoder_stages:
            x = upsample(x)
            skip = self.center_crop(skips.pop(), x.shape[2:])
            x = torch.cat([x, skip], 1)
            x = refine(x)

        return self.last(x)


In [6]:
# Build the network with its defaults (in_channels=3, n_classes=2).
model = UNet_check()

In [13]:
# Forward pass on the batched 572x572 crop.
# NOTE(review): `image_572` is created in the cell labelled In [12], which
# appears further down in this notebook, so this cell fails on a fresh
# top-to-bottom run. No model.eval() call is visible in this notebook, so the
# BatchNorm layers presumably run in training mode here - confirm intent.
output = model(image_572)

In [14]:
# Shape check: the recorded result is (1, 2, 388, 388) for the 572x572 input.
output.size()

torch.Size([1, 2, 388, 388])

In [15]:
# Echo the normalised input tensor.
# NOTE(review): this prints the whole tensor repr; image_572.size() would be
# a more compact sanity check.
image_572

tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          ...,
          [-0.8678, -0.9192, -0.9534,  ..., -0.9020, -0.9363, -1.1247],
          [-0.8849, -0.9363, -0.9705,  ..., -0.8849, -0.9192, -1.0562],
          [-0.9020, -0.9534, -0.9705,  ..., -0.9020, -0.9534, -1.0048]],

         [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          ...,
          [-0.7752, -0.8277, -0.8627,  ..., -0.7927, -0.8277, -1.0203],
          [-0.7927, -0.8452, -0.8803,  ..., -0.7752, -0.8102, -0.9503],
          [-0.8102, -0.8627, -0.8803,  ..., -0.7927, -0.8452, -0.8978]],

         [[-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
          [-1.8044, -1.8044, -

In [16]:
# Display the module tree (repr of every registered sub-module).
model

UNet_check(
  (down1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (down2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (down3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5)

In [None]:
# Assemble a U-Net block by block so that intermediate shapes can be inspected.
# NOTE(review): UNetConvBlock and UNetUpBlock are not defined in the visible
# part of this notebook; they must be imported/defined before this cell runs.
in_channels = 3       # channel count fed to the first down block (RGB input)
n_classes = 2         # channels produced by the final 1x1 convolution
batch_norm = False    # forwarded unchanged to every block
padding = False       # forwarded unchanged to every block
up_mode = 'upconv'    # forwarded unchanged to every up block
wf = 6                # stage i uses 2**(wf+i) channels, i.e. 64 for i = 0
depth = 5             # number of down blocks; there are depth - 1 up blocks

# Contracting path: channel count doubles at every stage.
# The first block is wired to `in_channels` rather than a separate literal so
# that the declared input width and the one actually used cannot drift apart.
prev_channels = in_channels
down_path = nn.ModuleList()
for i in range(depth):
    down_path.append(UNetConvBlock(prev_channels, 2**(wf+i), padding, batch_norm))
    prev_channels = 2**(wf+i)

# Expanding path: walk the stages in reverse, skipping the deepest one,
# halving the channel count each time.
up_path = nn.ModuleList()
for i in reversed(range(depth - 1)):
    up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode, padding, batch_norm))
    prev_channels = 2**(wf+i)

# After the loop prev_channels equals 2**wf; map it to the class scores.
last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

In [None]:
# Step through the contracting path by hand, printing the feature-map size
# after every block. Outputs of all blocks except the last are kept in
# `blocks` (for the skip connections) and then average-pooled by 2.
# NOTE(review): other cells in this notebook bind `image` to a PIL image;
# confirm that a batched tensor is what is bound to `image` when this runs.
blocks = []
x = image

for i, down in enumerate(down_path):
    x = down(x)
    print(x.size())
    if i == len(down_path) - 1:
        # Deepest block: neither stored as a skip nor pooled.
        continue
    blocks.append(x)
    x = F.avg_pool2d(x, 2)

In [None]:
# Size of `x` after the loop above, i.e. the un-pooled output of the last
# block in down_path.
x.size()

In [1]:
import numpy as np
import torch
import os
import pickle
import matplotlib.pyplot as plt
# %matplotlib inline
# plt.rcParams['figure.figsize'] = (20, 20)
# plt.rcParams['image.interpolation'] = 'bilinear'

from argparse import ArgumentParser

from torch.optim import SGD, Adam
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor, ToPILImage
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import lr_scheduler

import collections
import numbers
import random
import math
from PIL import Image, ImageOps, ImageEnhance

In [2]:
# Directories (relative to the notebook) that image and mask file names are
# joined onto when samples are opened.
image_path = '../../data/images_flip/train/'
mask_path = '../../data/images_flip/train_masks/'

In [3]:
# Load the pickled collection of training image file names (indexed by
# position elsewhere in this notebook; presumably a pre-shuffled list).
# NOTE(review): pickle.load can execute arbitrary code - only load files from
# a trusted source.
with open('../../data/train_shuffle_names.pk', 'rb') as f:
    filenames = pickle.load(f)

In [4]:
class RandomCrop(object):
    """Cut the same randomly placed square window out of an image and its label.

    Args:
        crop_size: side length of the square window, in pixels.
    """

    def __init__(self, crop_size=512):
        self.crop_size = crop_size

    def __call__(self, img_and_label):
        """Return ``(img_crop, label_crop)`` for an ``(img, label)`` pair.

        Both inputs must expose ``crop(box)``; the window position is drawn
        from the size of ``img`` only. ``random.randint`` raises ``ValueError``
        when the image is smaller than ``crop_size``.
        """
        img, label = img_and_label
        side = self.crop_size
        width, height = img.size

        # Horizontal offset is drawn before the vertical one.
        left = random.randint(0, width - side)
        top = random.randint(0, height - side)
        box = (left, top, left + side, top + side)

        return img.crop(box), label.crop(box)

class RandomCrop_different_size_for_image_and_label(object):
    """Crop a large image window and a smaller label window around one centre.

    A random centre is chosen so that the label window lies completely inside
    the image. The image window shares that centre but is larger, so it may
    reach past the image border; what is returned for those pixels is decided
    by ``img.crop`` (the recorded tensor in this notebook shows constant
    border rows, presumably zero fill - confirm against the imaging library).

    Args:
        image_size: side length of the square image window.
        label_size: side length of the square label window.

    Attributes:
        bound: ``(image_size - label_size) // 2``, the margin between the two
            windows on each side.
    """

    def __init__(self, image_size=572, label_size=388):
        self.image_size = image_size
        self.label_size = label_size
        self.bound = (self.image_size - self.label_size) // 2

    @staticmethod
    def _window(center, size):
        """Return ``(start, stop)`` of a ``size``-wide span placed around ``center``.

        ``stop`` is derived from ``start`` so the span is exactly ``size`` wide
        for odd sizes too (``center + size // 2`` would lose one pixel).
        """
        start = center - size // 2
        return start, start + size

    def __call__(self, img_and_label):
        """Return ``(img_crop, label_crop)`` for an ``(img, label)`` pair.

        Raises:
            ValueError: from ``random.randint`` when the image is smaller
                than ``label_size`` in either dimension.
        """
        img, label = img_and_label
        w, h = img.size

        # Pixels of the label window before / from the centre. Equal for even
        # sizes, so the sampled range is unchanged for them.
        before = self.label_size // 2
        after = self.label_size - before

        # x is drawn before y.
        xcenter = random.randint(before, w - after)
        ycenter = random.randint(before, h - after)

        img_left, img_right = self._window(xcenter, self.image_size)
        img_top, img_bottom = self._window(ycenter, self.image_size)
        lab_left, lab_right = self._window(xcenter, self.label_size)
        lab_top, lab_bottom = self._window(ycenter, self.label_size)

        img = img.crop((img_left, img_top, img_right, img_bottom))
        label = label.crop((lab_left, lab_top, lab_right, lab_bottom))

        return img, label
class ToTensor_Label(object):
    """Turn an ``(img, label)`` pair into a pair of tensors.

    The image goes through ``ToTensor()``. The label is converted to a numpy
    array, wrapped as a tensor, cast to ``long`` and given a new leading axis;
    its values are not rescaled.
    """

    def __call__(self, img_and_label):
        img, label = img_and_label

        to_tensor = ToTensor()
        img_tensor = to_tensor(img)

        label_array = np.array(label)
        label_tensor = torch.from_numpy(label_array).long()
        label_tensor = label_tensor.unsqueeze(0)

        return img_tensor, label_tensor

class ImageNormalize(object):
    """Normalise the image of an ``(img_tensor, label_tensor)`` pair in place.

    Args:
        mean: per-channel values subtracted from the image.
        std: per-channel values the image is then divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_and_label):
        """Return the same two tensor objects; only the image is modified.

        Slices along the first axis of ``img_tensor`` are paired with ``mean``
        and ``std`` by position; pairing stops at the shortest of the three.
        """
        img_tensor, label_tensor = img_and_label
        for channel, channel_mean, channel_std in zip(img_tensor, self.mean, self.std):
            # In-place ops on the slice write through to img_tensor itself.
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return img_tensor, label_tensor
    
class Compose(object):
    """Chain callables: each one receives the previous one's return value.

    Args:
        transforms: iterable of one-argument callables, applied in order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        """Feed ``x`` through every transform and return the final result."""
        result = x
        for step in self.transforms:
            result = step(result)
        return result

In [11]:
# Shared steps: tensor conversion followed by per-channel normalisation
# (the constants are presumably the ImageNet channel statistics - confirm).
to_tensor_label = ToTensor_Label()
normalize = ImageNormalize(mean=[.485, .456, .406], std=[.229, .224, .225])

# Pipeline producing a 572-pixel image window with a 388-pixel label window.
crop_572 = RandomCrop_different_size_for_image_and_label(image_size=572, label_size=388)
transforms = Compose(transforms=[crop_572, to_tensor_label, normalize])

# Pipeline producing equally sized 512-pixel image and label windows.
crop_512 = RandomCrop(crop_size=512)
transforms_512 = Compose(transforms=[crop_512, to_tensor_label, normalize])

In [12]:
# Pick one sample; its mask has the same stem with a .png extension.
filename_img = filenames[5]
filename_mask = os.path.splitext(filename_img)[0] + '.png'

# Read the picture as RGB and the mask in palette mode.
with open(os.path.join(image_path, filename_img), 'rb') as f:
    image = Image.open(f)
    image = image.convert('RGB')
with open(os.path.join(mask_path, filename_mask), 'rb') as f:
    label = Image.open(f)
    label = label.convert('P')

# Each pipeline draws its own random crop from the same source pair.
# Only the image tensors receive a leading batch axis; labels are left as is.
image_572, label_388 = transforms((image, label))
image_572 = image_572.unsqueeze(0)

image_512, label_512 = transforms_512((image, label))
image_512 = image_512.unsqueeze(0)

In [5]:
# Re-open sample 5 as an RGB image only (no transforms applied).
# NOTE(review): this repeats the first half of the cell labelled In [12];
# `filename_mask` is computed here but not used within this cell.
filename_img = filenames[5]
filename_mask = os.path.splitext(filename_img)[0]+'.png'

with open(os.path.join(image_path, filename_img), 'rb') as f:
    image = Image.open(f).convert('RGB')

In [7]:
# Size of the source picture as (width, height); recorded as (1918, 1280).
image.size

(1918, 1280)