<a href="https://colab.research.google.com/github/yasohasakii/unet-segmentation/blob/master/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
# Fetch the repository into the current working directory. The later cells read
# /content/raw and /content/test, so the clone presumably carries the dataset too.
# WARNING: `rm -rf *` first deletes everything in the current directory
# (presumably /content on Colab), including any checkpoint or results from a
# previous run (e.g. unet_membrane.h5, unet-result/).
!rm -rf *
!git clone https://github.com/yasohasakii/unet-segmentation.git
# Flatten the clone into the working directory, then drop the clone folder.
!cp -r unet-segmentation/* ./
!rm -rf unet-segmentation/

Cloning into 'unet-segmentation'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (33/33), done.[K
remote: Compressing objects: 100% (33/33), done.[K
remote: Total 233 (delta 8), reused 0 (delta 0), pack-reused 200
Receiving objects: 100% (233/233), 513.91 MiB | 32.54 MiB/s, done.
Resolving deltas: 100% (12/12), done.
Checking out files: 100% (220/220), done.


In [0]:
import os
import sys
import random

import numpy as np
import cv2
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint,EarlyStopping
from keras import backend as K
from PIL import Image

In [68]:
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """U-Net encoder stage: two ReLU convolutions followed by a 2x2 max-pool.

    Returns:
        (features, pooled): ``features`` is the pre-pool tensor (used later as
        the skip connection), ``pooled`` is the 2x-downsampled input to the
        next stage.
    """
    features = x
    for _ in range(2):
        features = keras.layers.Conv2D(
            filters, kernel_size,
            padding=padding, strides=strides, activation="relu",
        )(features)
    pooled = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(features)
    return features, pooled

def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    """U-Net decoder stage: 2x upsample ``x``, concatenate ``skip``, two ReLU convs.

    ``skip`` must have the same spatial size as the upsampled ``x`` (it is the
    ``features`` output of the matching ``down_block``).
    """
    conv_kwargs = {"padding": padding, "strides": strides, "activation": "relu"}
    upsampled = keras.layers.UpSampling2D((2, 2))(x)
    merged = keras.layers.Concatenate()([upsampled, skip])
    for _ in range(2):
        merged = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(merged)
    return merged

def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Bottom of the U: two ReLU convolutions, no pooling and no upsampling."""
    conv_kwargs = dict(padding=padding, strides=strides, activation="relu")
    first = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(x)
    second = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(first)
    return second
def UNet(image_size=512, filters=(16, 32, 64, 128, 256)):
    """Build a 4-level U-Net producing a per-pixel sigmoid probability map.

    Args:
        image_size: side length S of the square RGB input. Must be divisible
            by 16 because the encoder max-pools 2x four times. Default 512.
        filters: five conv widths: encoder levels 1-4, then the bottleneck.
            The decoder reuses levels 4..1 in reverse.

    Returns:
        A keras Model mapping (batch, S, S, 3) inputs to (batch, S, S)
        sigmoid outputs.

    Raises:
        ValueError: if ``image_size`` is not divisible by 16 or ``filters``
            does not have exactly five entries.
    """
    if image_size % 16 != 0:
        raise ValueError(
            "image_size must be divisible by 16 (4 pooling stages), got {}".format(image_size))
    if len(filters) != 5:
        raise ValueError(
            "filters must have 5 entries (4 encoder levels + bottleneck), got {}".format(len(filters)))
    f = list(filters)
    inputs = keras.layers.Input((image_size, image_size, 3))

    # Encoder: c_k are pre-pool skip features, p_k are pooled (S = image_size).
    c1, p1 = down_block(inputs, f[0])  # c1: S,    p1: S/2
    c2, p2 = down_block(p1, f[1])      # c2: S/2,  p2: S/4
    c3, p3 = down_block(p2, f[2])      # c3: S/4,  p3: S/8
    c4, p4 = down_block(p3, f[3])      # c4: S/8,  p4: S/16

    bn = bottleneck(p4, f[4])          # S/16

    # Decoder: upsample 2x and fuse the same-size skip features.
    u1 = up_block(bn, c4, f[3])        # S/16 -> S/8
    u2 = up_block(u1, c3, f[2])        # S/8  -> S/4
    u3 = up_block(u2, c2, f[1])        # S/4  -> S/2
    u4 = up_block(u3, c1, f[0])        # S/2  -> S

    # 1x1 conv to a single sigmoid channel, then drop the channel axis.
    out = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(u4)
    outputs = keras.layers.Reshape((image_size, image_size))(out)
    model = keras.models.Model(inputs, outputs)
    return model
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between target and predicted masks.

    Computes (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth) over all
    elements of the batch. The tensors are flattened, so this is one scalar
    per batch, not a per-image mean. ``smooth`` keeps the ratio defined
    (equal to 1) when both masks are empty.

    Uses ``tf`` ops rather than the standalone ``keras.backend`` so the metric
    matches the ``tf.keras`` model it is compiled into.
    """
    y_pred_f = tf.reshape(y_pred, [-1])
    # Labels from the generator are float64 numpy arrays while the sigmoid
    # output is float32; align dtypes so the product is well-typed.
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), y_pred_f.dtype)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
# Build the default 512x512 U-Net and train it with binary cross-entropy,
# tracking the Dice coefficient as the metric.
model = UNet()
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[dice_coef],
)
model.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_95 (Conv2D)              (None, 512, 512, 16) 448         input_6[0][0]                    
__________________________________________________________________________________________________
conv2d_96 (Conv2D)              (None, 512, 512, 16) 2320        conv2d_95[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_20 (MaxPooling2D) (None, 256, 256, 16) 0           conv2d_96[0][0]                  
____________________________________________________________________________________________

In [0]:
class DataGen(keras.utils.Sequence):
    """Endless random-batch generator of (image, mask) pairs for U-Net training.

    Every file in ``path`` is treated as an input image. Its mask is the file
    at the same path with every ``'raw'`` replaced by ``'label'``. The file
    list is split 90/10 with a fixed seed into ``self.trains`` and
    ``self.vals``. Call ``generate`` on either list to get a generator that
    ``Model.fit`` can consume.

    NOTE(review): only ``generate`` is used. The ``keras.utils.Sequence``
    ``__len__``/``__getitem__`` protocol is not implemented, so instances are
    not themselves usable as a Sequence.
    """

    def __init__(self, path, batch_size=1, image_size=512):
        self.path = path
        self.batch_size = batch_size
        self.image_size = image_size
        # os.listdir order is filesystem-dependent; sort so the seeded split
        # is reproducible across machines.
        files = [os.path.join(self.path, name) for name in sorted(os.listdir(self.path))]
        self.trains, self.vals = train_test_split(files, test_size=0.1, random_state=42)

    def _load_pair(self, img):
        """Load one image and its mask, resized to image_size x image_size.

        Returns:
            (image, mask): image has shape (S, S, 3) and mask (S, S); both
            are scaled from 0-255 to [0, 1].
        """
        size = (self.image_size, self.image_size)
        # convert('RGB') so grayscale/RGBA/palette files still fit the
        # 3-channel batch slot; `with` closes the file handle promptly.
        with Image.open(img) as src:
            image = np.array(src.convert('RGB').resize(size))
        # NOTE(review): replaces every 'raw' in the path, including the
        # filename itself; assumes mask files share the image's filename.
        with Image.open(img.replace('raw', 'label')) as lbl:
            mask = np.array(lbl.convert('L').resize(size))
        return image / 255.0, mask / 255.0

    def generate(self, files):
        """Yield (image_batch, label_batch) forever.

        Each batch is ``batch_size`` consecutive files (from a shuffled copy
        of ``files``) starting at a random offset, so files are sampled with
        replacement across batches.

        Args:
            files: list of image paths, e.g. ``self.trains``. It is not modified.

        Yields:
            float64 arrays of shape (batch_size, S, S, 3) and (batch_size, S, S).

        Raises:
            ValueError: on first iteration, if ``files`` has fewer than
                ``batch_size`` entries.
        """
        if len(files) < self.batch_size:
            raise ValueError(
                "need at least batch_size={} files, got {}".format(self.batch_size, len(files)))
        files = list(files)  # shuffle a copy; the caller's list stays intact
        random.shuffle(files)
        while True:
            image_batch = np.zeros([self.batch_size, self.image_size, self.image_size, 3])
            label_batch = np.zeros([self.batch_size, self.image_size, self.image_size])
            index = random.randint(0, len(files) - self.batch_size)
            # Fill every slot of the batch inside the loop.
            for i, img in enumerate(files[index:index + self.batch_size]):
                image_batch[i], label_batch[i] = self._load_pair(img)
            yield image_batch, label_batch

In [0]:
# Training data location and batch configuration.
train_path = '/content/raw'
batch_size = 1

# One DataGen owns the 90/10 train/validation split; each split gets its own
# endless random-batch generator.
gen = DataGen(train_path, batch_size=batch_size, image_size=512)
train_gen, val_gen = gen.generate(gen.trains), gen.generate(gen.vals)

# Batches per "epoch": the number of whole batches the split contains
# (batches themselves are sampled at random by the generator).
train_steps, valid_steps = (len(split) // batch_size for split in (gen.trains, gen.vals))

In [70]:
# Callbacks come from tf.keras (`keras` here is `tensorflow.keras`): callbacks
# from the standalone `keras` package can fail on a tf.keras model in newer
# TensorFlow releases.
# Keep only the checkpoint with the best validation Dice; stop after 5 epochs
# without improvement.
model_checkpoint = keras.callbacks.ModelCheckpoint(
    'unet_membrane.h5', monitor='val_dice_coef', mode='max', verbose=1, save_best_only=True)
earlystop = keras.callbacks.EarlyStopping(monitor='val_dice_coef', patience=5, mode='max')
# Model.fit accepts Python generators directly; fit_generator is deprecated
# and removed in newer Keras versions.
h = model.fit(train_gen, steps_per_epoch=train_steps, epochs=100,
              callbacks=[model_checkpoint, earlystop],
              validation_data=val_gen, validation_steps=valid_steps)

Epoch 1/100
10/90 [==>...........................] - ETA: 22s - loss: 0.3299 - dice_coef: 0.1223
Epoch 00001: val_dice_coef improved from -inf to 0.12229, saving model to unet_membrane.h5
Epoch 2/100
10/90 [==>...........................] - ETA: 21s - loss: 0.3392 - dice_coef: 0.1313
Epoch 00002: val_dice_coef improved from 0.12229 to 0.13131, saving model to unet_membrane.h5
Epoch 3/100
10/90 [==>...........................] - ETA: 20s - loss: 0.3481 - dice_coef: 0.1346
Epoch 00003: val_dice_coef improved from 0.13131 to 0.13464, saving model to unet_membrane.h5
Epoch 4/100
10/90 [==>...........................] - ETA: 21s - loss: 0.2259 - dice_coef: 0.1051
Epoch 00004: val_dice_coef did not improve from 0.13464
Epoch 5/100
10/90 [==>...........................] - ETA: 21s - loss: 0.2065 - dice_coef: 0.1039
Epoch 00005: val_dice_coef did not improve from 0.13464
Epoch 6/100
10/90 [==>...........................] - ETA: 22s - loss: 0.3687 - dice_coef: 0.1726
Epoch 00006: val_dice_coef 

In [72]:
from PIL import Image
import glob, cv2
import matplotlib
# Select the non-interactive Agg backend. NOTE(review): pyplot was already
# imported in the imports cell and nothing in this cell draws a figure, so this
# call appears unnecessary — confirm before relying on inline plots afterwards.
matplotlib.use('Agg')

# Load the best weights saved by the training cell's ModelCheckpoint, which
# wrote the relative path 'unet_membrane.h5'. NOTE(review): the absolute path
# below assumes the working directory was /content (Colab default) — confirm.
model.load_weights('/content/unet_membrane.h5')

def predict(image):
    """Run the global ``model`` on one image and return a 3-channel mask.

    Args:
        image: array-like of 0-255 pixel values shaped like the model input
            ((512, 512, 3) for the default ``UNet()``).

    Returns:
        Float array (H, W, 3): the single-channel prediction min-max stretched
        to [0, 1] and replicated across 3 channels. The stretch always maps the
        largest value to 1.0, regardless of the model's confidence. If the
        prediction is constant, it is returned unstretched, because stretching
        would divide 0 by 0 and yield NaN.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    batch = np.expand_dims(np.asarray(image, dtype=np.float64) / 255.0, axis=0)
    pred = model.predict(batch)[0]
    lo, hi = np.min(pred), np.max(pred)
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    return cv2.merge([pred, pred, pred])

def plot_result(img):
    """Mask one image with the model prediction and save it under ``unet-result/``.

    Despite the name, nothing is plotted. The function:
    1. Prints the file name and the maximum of the prediction.
    2. Multiplies the original-resolution image pixel-wise by the
       prediction, resized back to the image size.
    3. Writes the result to ``unet-result/<basename>``.

    Args:
        img: path to an image file.
    """
    imgname = os.path.basename(img)
    print(imgname)
    # Force 3 channels: RGBA/grayscale/palette PNGs would otherwise mismatch
    # both the model's 3-channel input and the 3-channel mask below.
    with Image.open(img) as src:
        image = src.convert('RGB')
    width, height = image.size  # PIL reports (width, height)
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    resized = np.array(image.resize((512, 512)), dtype=np.float64)
    pred = predict(resized)
    print(np.max(pred))
    # cv2.resize takes dsize as (width, height) and returns (height, width, 3).
    pred = cv2.resize(pred, (width, height))
    blend = np.asarray(np.array(image) * pred, np.uint8)
    savedir = 'unet-result'
    os.makedirs(savedir, exist_ok=True)
    Image.fromarray(blend).save(os.path.join(savedir, imgname))


if __name__ == '__main__':
    # Inside a notebook kernel __name__ is '__main__', so this always runs:
    # mask and save every PNG in the test folder.
    for image_path in glob.glob('/content/test/*.png'):
        plot_result(image_path)

671a56ce44a141acb59d6e10b28ddb3f.png
1.0
8396aaabe1ab42439cb2c8838cd3d783.png
1.0
ca967714d16b464aa8bad0bebd07687a.png
1.0
5221cf979fe645959c6e45e523092145.png
1.0
a34393704d624e0c9430e012a73b6b02.png
1.0
09816413ea8f42d88479f300d689fb51.png
1.0
566e58ce8e874cca80a2cee472361529.png
1.0
e5f90522a6084d3c9b9f52117e53ac4d.png
1.0
1135c7a4d9a84e1fbf60f60a34030267.png
1.0
f9a297618ecb44c0b9a74d4c863653b3.png
1.0
ca708ea4d3124568b9df41610be9f001.png
1.0
dbb4b5b5679441f4bbc643de599dc221.png
1.0
1fa17f9a553b4c95bf92442e6dc65340.png
1.0
f90f388214f24b3d944710de566c0705.png
1.0
bf269b76ec6c477fa30d07a6f61dce4b.png
1.0
