<a href="https://colab.research.google.com/github/yasohasakii/unet-segmentation/blob/master/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
# WARNING: this deletes everything in the current working directory before cloning.
# Only run it in a throwaway environment (e.g. a fresh Colab session).
!rm -rf *
# Fetch the project repository.
!git clone https://github.com/yasohasakii/unet-segmentation.git
# Copy the repository contents into the working directory, then remove the
# now-redundant clone folder.
!cp -r unet-segmentation/* ./
!rm -rf unet-segmentation/

Cloning into 'unet-segmentation'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (33/33), done.[K
remote: Compressing objects: 100% (33/33), done.[K
remote: Total 233 (delta 8), reused 0 (delta 0), pack-reused 200
Receiving objects: 100% (233/233), 513.91 MiB | 32.54 MiB/s, done.
Resolving deltas: 100% (12/12), done.
Checking out files: 100% (220/220), done.


In [0]:
import os
import sys
import random

import numpy as np
import cv2
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint

In [27]:
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Encoder stage: two ReLU convolutions followed by 2x2 max pooling.

    Args:
        x: Input tensor.
        filters: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size.
        padding: Convolution padding mode.
        strides: Convolution strides.

    Returns:
        A tuple ``(features, pooled)``: the output of the second convolution
        and that same tensor after 2x2 max pooling with stride 2.
    """
    conv_options = dict(padding=padding, strides=strides, activation="relu")
    features = x
    for _ in range(2):
        features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(features)
    pooled = keras.layers.MaxPool2D((2, 2), (2, 2))(features)
    return features, pooled

def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Decoder stage: 2x upsampling, concatenation with ``skip``, two ReLU convolutions.

    Args:
        x: Tensor to upsample.
        skip: Tensor concatenated with the upsampled ``x``.
        filters: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size.
        padding: Convolution padding mode.
        strides: Convolution strides.

    Returns:
        The output tensor of the second convolution.
    """
    upsampled = keras.layers.UpSampling2D((2, 2))(x)
    merged = keras.layers.Concatenate()([upsampled, skip])
    conv_options = dict(padding=padding, strides=strides, activation="relu")
    out = merged
    for _ in range(2):
        out = keras.layers.Conv2D(filters, kernel_size, **conv_options)(out)
    return out

def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Apply two consecutive ReLU convolutions to ``x`` without any resampling.

    Args:
        x: Input tensor.
        filters: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size.
        padding: Convolution padding mode.
        strides: Convolution strides.

    Returns:
        The output tensor of the second convolution.
    """
    conv_options = dict(padding=padding, strides=strides, activation="relu")
    out = x
    for _ in range(2):
        out = keras.layers.Conv2D(filters, kernel_size, **conv_options)(out)
    return out
def UNet(input_shape=(512, 512, 3), filters=(16, 32, 64, 128, 256)):
    """Build a U-Net that maps an image to a single-channel sigmoid mask.

    The encoder applies one ``down_block`` per entry of ``filters[:-1]``, the
    last entry of ``filters`` is used for the bottleneck, and the decoder
    mirrors the encoder with ``up_block`` stages fed by the matching skip
    tensors. Calling ``UNet()`` with the defaults builds the same network as
    the original hard-coded version (512x512x3 input, 4 down / 4 up stages).

    Args:
        input_shape: ``(height, width, channels)`` of the input image.
        filters: Channel widths; all but the last are encoder/decoder widths,
            the last one is the bottleneck width.

    Returns:
        An uncompiled ``keras.models.Model`` whose output has one channel and
        the same spatial size as the input.

    Raises:
        ValueError: If ``filters`` has fewer than two entries, or if a known
            spatial dimension is not divisible by ``2 ** (len(filters) - 1)``
            (the skip concatenations would otherwise have mismatched shapes).
    """
    filters = list(filters)
    if len(filters) < 2:
        raise ValueError("filters needs at least one encoder width and one bottleneck width")

    # Every encoder stage halves the resolution and every decoder stage doubles
    # it again, so sizes only line up when they are divisible by 2**depth.
    depth = len(filters) - 1
    for dim in input_shape[:2]:
        if dim is not None and dim % (2 ** depth) != 0:
            raise ValueError(
                "spatial size %r is not divisible by %d" % (dim, 2 ** depth)
            )

    inputs = keras.layers.Input(input_shape)

    # Encoder: with the defaults 512 -> 256 -> 128 -> 64 -> 32.
    skips = []
    x = inputs
    for width in filters[:-1]:
        skip, x = down_block(x, width)
        skips.append(skip)

    x = bottleneck(x, filters[-1])

    # Decoder: with the defaults 32 -> 64 -> 128 -> 256 -> 512.
    for width, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = up_block(x, skip, width)

    outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    model = keras.models.Model(inputs, outputs)
    return model
# Build the network and train it as a per-pixel regression (MSE loss, MAE metric).
model = UNet()
model.compile(
    optimizer="adam",
    loss="mse",
    metrics=["mae"],
)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 512, 512, 16) 448         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 512, 512, 16) 2320        conv2d_19[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 256, 256, 16) 0           conv2d_20[0][0]                  
____________________________________________________________________________________________

In [0]:
class DataGen(keras.utils.Sequence):
    """Endless random batch generator of (image, mask) pairs.

    Every entry of ``path`` is treated as an input image. The path of its mask
    is derived by replacing every occurrence of ``'raw'`` in the image path
    with ``'label'``. The file list is split 90/10 into ``self.trains`` and
    ``self.vals`` with a fixed seed.

    NOTE(review): the class derives from ``keras.utils.Sequence`` but defines
    neither ``__len__`` nor ``__getitem__``; in this notebook it is only used
    through :meth:`generate`.
    """

    def __init__(self, path, batch_size=1, image_size=512):
        """Collect the files under ``path`` and split them into train/validation lists.

        Args:
            path: Directory containing the input images.
            batch_size: Number of samples per yielded batch.
            image_size: Images and masks are resized to ``image_size x image_size``.
        """
        self.path = path
        self.batch_size = batch_size
        self.image_size = image_size
        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded split below reproducible across runs and machines.
        files = [os.path.join(self.path, x) for x in sorted(os.listdir(self.path))]
        self.trains, self.vals = train_test_split(files, test_size=0.1, random_state=42)

    def generate(self, files):
        """Yield ``(image_batch, label_batch)`` tuples forever.

        Each batch is a random contiguous window of ``batch_size`` files taken
        from a shuffled copy of ``files``.

        Args:
            files: List of image paths to sample from.

        Yields:
            ``image_batch`` of shape ``(batch_size, image_size, image_size, 3)``
            and ``label_batch`` of shape ``(batch_size, image_size, image_size)``,
            both scaled to ``[0, 1]``.

        Raises:
            ValueError: If ``files`` holds fewer than ``batch_size`` entries.
            IOError: If an image or its mask cannot be read.
        """
        # Work on a copy so the caller's list (e.g. self.trains) is not reordered.
        files = list(files)
        if len(files) < self.batch_size:
            raise ValueError(
                "need at least batch_size=%d files, got %d" % (self.batch_size, len(files))
            )
        random.shuffle(files)
        target_size = (self.image_size, self.image_size)
        while True:
            image_batch = np.zeros([self.batch_size, self.image_size, self.image_size, 3])
            label_batch = np.zeros([self.batch_size, self.image_size, self.image_size])
            index = random.randint(0, len(files) - self.batch_size)
            for i, img in enumerate(files[index:index + self.batch_size]):
                ## Reading image (cv2.imread returns None instead of raising)
                image = cv2.imread(img)
                if image is None:
                    raise IOError("could not read image: %s" % img)
                image = cv2.resize(image, target_size)

                ## Reading the matching mask as a single-channel image
                mask_path = img.replace('raw', 'label')
                mask = cv2.imread(mask_path)
                if mask is None:
                    raise IOError("could not read mask: %s" % mask_path)
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
                mask = cv2.resize(mask, target_size)

                ## Normalizing and storing the sample. This must happen inside
                ## the loop so that every slot of the batch is filled.
                image_batch[i] = image / 255.0
                label_batch[i] = mask / 255.0

            yield image_batch, label_batch

In [0]:
# Data location and batch size shared by the training and validation generators.
batch_size = 1
train_path = '/content/raw'
gen = DataGen(train_path, batch_size=batch_size, image_size=512)

# Training stream and the number of batches that make up one epoch.
train_gen = gen.generate(gen.trains)
train_steps = len(gen.trains) // batch_size

# Validation stream and the number of batches per validation pass.
val_gen = gen.generate(gen.vals)
valid_steps = len(gen.vals) // batch_size

In [0]:
# Write a checkpoint only when the monitored validation MAE improves.
model_checkpoint = ModelCheckpoint(
    'unet_membrane.hdf5',
    monitor='val_mean_absolute_error',
    mode='min',
    verbose=1,
    save_best_only=True,
)
model.fit_generator(
    train_gen,
    steps_per_epoch=train_steps,
    epochs=60,
    callbacks=[model_checkpoint],
    validation_data=val_gen,
    validation_steps=valid_steps,
)

Epoch 1/60
10/90 [==>...........................] - ETA: 18s - loss: 0.0385 - mean_absolute_error: 0.0387
Epoch 00001: val_mean_absolute_error improved from inf to 0.03874, saving model to unet_membrane.hdf5
Epoch 2/60
10/90 [==>...........................] - ETA: 19s - loss: 0.0980 - mean_absolute_error: 0.0984
Epoch 00002: val_mean_absolute_error did not improve from 0.03874
Epoch 3/60
10/90 [==>...........................] - ETA: 17s - loss: 0.0792 - mean_absolute_error: 0.0796
Epoch 00003: val_mean_absolute_error did not improve from 0.03874
Epoch 4/60
10/90 [==>...........................] - ETA: 17s - loss: 0.0442 - mean_absolute_error: 0.0444
Epoch 00004: val_mean_absolute_error did not improve from 0.03874
Epoch 5/60
10/90 [==>...........................] - ETA: 17s - loss: 0.0387 - mean_absolute_error: 0.0388
Epoch 00005: val_mean_absolute_error did not improve from 0.03874
Epoch 6/60
10/90 [==>...........................] - ETA: 19s - loss: 0.1471 - mean_absolute_error: 0.147

In [0]:
from PIL import Image
import glob, cv2
import matplotlib
matplotlib.use('Agg')

# Load previously saved weights into the existing `model` object.
# NOTE(review): presumably this is the checkpoint written by the training cell
# ('unet_membrane.hdf5' with the working directory being /content) - confirm.
model.load_weights('/content/unet_membrane.hdf5')

def predict(image):
    """Run the global ``model`` on one image and return a 3-channel prediction.

    Args:
        image: Array-like image with values in ``[0, 255]``; it is scaled to
            ``[0, 1]`` and given a leading batch axis before prediction.

    Returns:
        The prediction for that single image with its channel replicated three
        times via ``cv2.merge``.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
    image = np.asarray(image, dtype=np.float64) / 255.0
    image = np.expand_dims(image, axis=0)
    pred = model.predict(image)[0]
    pred = cv2.merge([pred, pred, pred])
    return pred

def plot_result(img):
    """Predict a mask for the image file ``img`` and save the masked image.

    The image is resized to the network's input size, passed through
    ``predict``, and the prediction is resized back to the original resolution
    and multiplied into the original image. The result is written to
    ``unet-result/<basename of img>``.

    Args:
        img: Path of the image file to process.
    """
    imgname = os.path.basename(img)
    print(imgname)
    # Force 3 channels so RGBA / palette / grayscale files blend correctly
    # against the 3-channel prediction.
    image = Image.open(img).convert('RGB')
    # PIL reports size as (width, height); cv2.resize expects the same order.
    width, height = image.size

    # Resize to what the network was built for rather than a hard-coded size.
    # Fall back to the previous 1024x1024 if the model leaves its spatial
    # dimensions unspecified.
    net_h, net_w = model.input_shape[1:3]
    if net_h is None or net_w is None:
        net_h = net_w = 1024
    resized = image.resize((net_w, net_h))

    # The training data was loaded with cv2.imread, i.e. in BGR channel order,
    # whereas PIL yields RGB; flip the channel axis to match training.
    # np.float was removed in NumPy 1.24, hence np.float64.
    bgr = np.array(resized, np.float64)[..., ::-1]
    pred = predict(bgr)
    pred = cv2.resize(pred, (width, height))

    blend = np.array(image) * pred
    blend = np.asarray(blend, np.uint8)
    savedir = 'unet-result'
    os.makedirs(savedir, exist_ok=True)
    blend = Image.fromarray(blend)
    blend.save(os.path.join(savedir, imgname))


if __name__ == '__main__':
    # Run the prediction/blend pipeline on every PNG in the test folder.
    images = glob.glob('/content/test/*.png')
    for image_path in images:
        plot_result(image_path)