In [2]:
import numpy as np
import cv2
import tensorflow as tf
import os
import random

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [11]:
import matplotlib.pyplot as plt

TFRecord_path = './TFRecord'
# TFRecordWriter fails if the output directory is missing, so create it first.
os.makedirs(TFRecord_path, exist_ok=True)
save_path = os.path.join(TFRecord_path, 'train.tfrecord')
writer = tf.python_io.TFRecordWriter(save_path)

def load_image(path, index, load_label=True):
    """Load the ``index``-th source image (and optionally its label) from ``path``.

    The image is read from ``<path>/src/<index>.png`` with ``cv2.IMREAD_UNCHANGED``.
    Only the first three channels are kept, which drops an alpha channel if present;
    OpenCV returns channels in BGR order. A single-channel image is replicated to
    three channels. The label is read from ``<path>/label/<index>.png`` with OpenCV's
    default (colour) flag, and only its first channel is kept.

    Args:
        path: dataset root directory.
        index: integer file index.
        load_label: when False, only the image is loaded and returned.

    Returns:
        ``image`` of shape (H, W, 3), uint8, or ``(image, label)`` where ``label`` has
        shape (H, W), uint8.

    Raises:
        FileNotFoundError: if OpenCV cannot read either file (``cv2.imread`` signals
            this by returning None rather than raising).
    """
    img_path = os.path.join(path, 'src/%d.png' % index)
    image = cv2.imread(img_path, flags=cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError('could not read image: %s' % img_path)
    if image.ndim == 2:
        # Grayscale PNG: replicate so callers always get three channels.
        image = np.stack([image] * 3, axis=-1)
    # NOTE(review): np.uint8 wraps 16-bit pixel values modulo 256 rather than
    # rescaling them; this assumes 8-bit source PNGs -- confirm.
    image = np.uint8(image[:, :, 0:3])

    if not load_label:
        return image

    label_path = os.path.join(path, 'label/%d.png' % index)
    label = cv2.imread(label_path)
    if label is None:
        raise FileNotFoundError('could not read label: %s' % label_path)
    label = np.uint8(label[:, :, 0])
    return image, label


def random_patch(image, label, patch_size):
    """Crop the same random ``patch_size`` x ``patch_size`` window from image and label.

    The top-left corner is drawn with ``random.randint`` (row first, then column), so
    results are reproducible under ``random.seed``.

    Args:
        image: array whose first two axes are (rows, cols).
        label: array with the same first two axes as ``image``.
        patch_size: side length of the square crop, in pixels.

    Returns:
        ``(sub_image, sub_label)``, aligned crops (views into the inputs).

    Raises:
        ValueError: if ``patch_size`` is not positive, does not fit inside the image,
            or the image and label spatial sizes differ.
    """
    rows, cols = image.shape[:2]
    if tuple(label.shape[:2]) != (rows, cols):
        raise ValueError('image and label sizes differ: %s vs %s'
                         % ((rows, cols), tuple(label.shape[:2])))
    if patch_size < 1 or patch_size > rows or patch_size > cols:
        raise ValueError('patch_size %d does not fit in image of size %dx%d'
                         % (patch_size, rows, cols))
    r = random.randint(0, rows - patch_size)
    c = random.randint(0, cols - patch_size)
    sub_image = image[r:r + patch_size, c:c + patch_size]
    sub_label = label[r:r + patch_size, c:c + patch_size]
    return sub_image, sub_label


def _int64_feature(value):
    """Wrap a single integer ``value`` in a ``tf.train.Feature`` (int64 list of length 1)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """Wrap a single bytes ``value`` in a ``tf.train.Feature`` (bytes list of length 1)."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def save_image(sub_image, sub_label, writer, augment=False):
    """Write an image/label pair to ``writer`` as ``tf.train.Example`` record(s).

    Each record stores ``row``, ``col`` (the first two axes of the stored image),
    ``image_raw`` and ``label_raw`` (C-order raw bytes of the arrays).

    Args:
        sub_image: image array, (rows, cols[, channels]).
        sub_label: label array aligned with ``sub_image``.
        writer: object with a ``write(bytes)`` method, e.g. a ``TFRecordWriter``.
        augment: when True, write all 8 flip/rotation variants (4 rotations by 90
            degrees of the original, then 4 of its left-right mirror) instead of
            just the input pair.
    """
    if not augment:
        _write_example(writer, sub_image, sub_label)
        return

    image, label = sub_image, sub_label
    for _ in range(2):
        for _ in range(4):
            # Serialize the *current* (rotated / flipped) arrays -- not the input.
            _write_example(writer, image, label)
            image = np.rot90(image)
            label = np.rot90(label)
        # After four 90-degree rotations we are back at the start orientation.
        image = np.fliplr(image)
        label = np.fliplr(label)


def _write_example(writer, image, label):
    """Serialize one image/label pair as a ``tf.train.Example`` and write it."""
    # row/col come from the array actually stored, so they stay correct when a
    # rotation swaps the axes of a non-square patch. tobytes() copies views such
    # as np.rot90 output into C order.
    example = tf.train.Example(features=tf.train.Features(feature={
        'row': _int64_feature(image.shape[0]),
        'col': _int64_feature(image.shape[1]),
        'image_raw': _bytes_feature(image.tobytes()),
        'label_raw': _bytes_feature(label.tobytes()),
    }))
    writer.write(example.SerializeToString())


patch_size = 256    # side length of the patches written to train.tfrecord
data_size = 2       # NOTE(review): not used in this cell -- confirm before removing
sample_size = 1024  # side length of the random crop taken before resizing to patch_size
path = './data/'

# Training set: random sample_size crops of each source image, resized to
# patch_size and written with 8-fold flip/rotation augmentation.
try:
    for i in range(1, 6):
        image, label = load_image(path, i)
        row, col = image.shape[:2]
        # Crop count grows with the image area measured in sample_size tiles.
        num = int((row / sample_size + 1) * (col / sample_size + 1)) * 4
        for j in range(num):
            sub_image, sub_label = random_patch(image, label, sample_size)
            sub_image = cv2.resize(sub_image, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
            # Nearest-neighbour resizing never invents new label values.
            sub_label = cv2.resize(sub_label, (patch_size, patch_size), interpolation=cv2.INTER_NEAREST)
            save_image(sub_image, sub_label, writer, augment=True)
            print('NO.%d patch in src-%d is saving...' % (j, i))
finally:
    writer.close()

# Validation set: full-size images go to their own file, through a writer that
# is still open (the training writer has been closed above).
# NOTE(review): these are the same five images the training patches were cut
# from, so this is not a held-out set -- confirm the intended split.
valid_writer = tf.python_io.TFRecordWriter(os.path.join(TFRecord_path, 'valid.tfrecord'))
try:
    for i in range(1, 6):
        image, label = load_image(path, i)
        save_image(image, label, valid_writer)
        print('NO.%d valid sample is saving...' % i)
finally:
    valid_writer.close()

NO.0 patch in src-1 is saving...
NO.1 patch in src-1 is saving...
NO.2 patch in src-1 is saving...
NO.3 patch in src-1 is saving...
NO.4 patch in src-1 is saving...
NO.5 patch in src-1 is saving...
NO.6 patch in src-1 is saving...
NO.7 patch in src-1 is saving...
NO.8 patch in src-1 is saving...
NO.9 patch in src-1 is saving...
NO.10 patch in src-1 is saving...
NO.11 patch in src-1 is saving...
NO.12 patch in src-1 is saving...
NO.13 patch in src-1 is saving...
NO.14 patch in src-1 is saving...
NO.15 patch in src-1 is saving...
NO.16 patch in src-1 is saving...
NO.17 patch in src-1 is saving...
NO.18 patch in src-1 is saving...
NO.19 patch in src-1 is saving...
NO.20 patch in src-1 is saving...
NO.21 patch in src-1 is saving...
NO.22 patch in src-1 is saving...
NO.23 patch in src-1 is saving...
NO.24 patch in src-1 is saving...
NO.25 patch in src-1 is saving...
NO.26 patch in src-1 is saving...
NO.27 patch in src-1 is saving...
NO.28 patch in src-1 is saving...
NO.29 patch in src-1 is 

NO.42 patch in src-3 is saving...
NO.43 patch in src-3 is saving...
NO.44 patch in src-3 is saving...
NO.45 patch in src-3 is saving...
NO.46 patch in src-3 is saving...
NO.47 patch in src-3 is saving...
NO.48 patch in src-3 is saving...
NO.49 patch in src-3 is saving...
NO.50 patch in src-3 is saving...
NO.51 patch in src-3 is saving...
NO.52 patch in src-3 is saving...
NO.53 patch in src-3 is saving...
NO.54 patch in src-3 is saving...
NO.55 patch in src-3 is saving...
NO.56 patch in src-3 is saving...
NO.57 patch in src-3 is saving...
NO.58 patch in src-3 is saving...
NO.59 patch in src-3 is saving...
NO.60 patch in src-3 is saving...
NO.61 patch in src-3 is saving...
NO.62 patch in src-3 is saving...
NO.63 patch in src-3 is saving...
NO.64 patch in src-3 is saving...
NO.65 patch in src-3 is saving...
NO.66 patch in src-3 is saving...
NO.67 patch in src-3 is saving...
NO.68 patch in src-3 is saving...
NO.69 patch in src-3 is saving...
NO.70 patch in src-3 is saving...
NO.71 patch in

NO.166 patch in src-4 is saving...
NO.167 patch in src-4 is saving...
NO.168 patch in src-4 is saving...
NO.169 patch in src-4 is saving...
NO.170 patch in src-4 is saving...
NO.171 patch in src-4 is saving...
NO.172 patch in src-4 is saving...
NO.173 patch in src-4 is saving...
NO.174 patch in src-4 is saving...
NO.175 patch in src-4 is saving...
NO.176 patch in src-4 is saving...
NO.177 patch in src-4 is saving...
NO.178 patch in src-4 is saving...
NO.179 patch in src-4 is saving...
NO.180 patch in src-4 is saving...
NO.181 patch in src-4 is saving...
NO.182 patch in src-4 is saving...
NO.183 patch in src-4 is saving...
NO.184 patch in src-4 is saving...
NO.185 patch in src-4 is saving...
NO.186 patch in src-4 is saving...
NO.187 patch in src-4 is saving...
NO.188 patch in src-4 is saving...
NO.189 patch in src-4 is saving...
NO.190 patch in src-4 is saving...
NO.191 patch in src-4 is saving...
NO.192 patch in src-4 is saving...
NO.193 patch in src-4 is saving...
NO.194 patch in src-

NO.105 patch in src-5 is saving...
NO.106 patch in src-5 is saving...
NO.107 patch in src-5 is saving...
NO.108 patch in src-5 is saving...
NO.109 patch in src-5 is saving...
NO.110 patch in src-5 is saving...
NO.111 patch in src-5 is saving...
NO.112 patch in src-5 is saving...
NO.113 patch in src-5 is saving...
NO.114 patch in src-5 is saving...
NO.115 patch in src-5 is saving...
NO.116 patch in src-5 is saving...
NO.117 patch in src-5 is saving...
NO.118 patch in src-5 is saving...
NO.119 patch in src-5 is saving...
NO.120 patch in src-5 is saving...
NO.121 patch in src-5 is saving...
NO.122 patch in src-5 is saving...
NO.123 patch in src-5 is saving...
NO.124 patch in src-5 is saving...
NO.125 patch in src-5 is saving...
NO.126 patch in src-5 is saving...
NO.127 patch in src-5 is saving...
NO.128 patch in src-5 is saving...
NO.129 patch in src-5 is saving...
NO.130 patch in src-5 is saving...
NO.131 patch in src-5 is saving...
NO.132 patch in src-5 is saving...
NO.133 patch in src-

In [None]:
# prepare res_unet
def weight_variable(shape, name):
    """Create a trainable weight tensor of ``shape`` called ``name``.

    Initial values come from a truncated normal distribution with stddev 0.1.
    """
    init_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(name=name, initial_value=init_value)


def bias_variable(shape, name):
    """Create a trainable bias tensor of ``shape`` called ``name``, filled with 0.1."""
    init_value = tf.constant(0.1, shape=shape)
    return tf.Variable(name=name, initial_value=init_value)

def conv2d(x, w, s=1):
    """Convolve ``x`` with kernel ``w`` using stride ``s`` on both spatial axes.

    Uses 'SAME' padding; the stride list is in [batch, height, width, channel] order.
    """
    stride_spec = [1, s, s, 1]
    return tf.nn.conv2d(x, w, padding='SAME', strides=stride_spec)

def deconv2d(x, w, s=1):
    """Transposed convolution of NHWC tensor ``x`` with filter ``w``.

    ``tf.nn.conv2d_transpose`` requires an explicit ``output_shape``; it is derived
    here from ``x`` so that, with 'SAME' padding, the spatial size is multiplied by
    the stride ``s``.

    Args:
        x: input tensor, assumed NHWC (matching the other helpers in this cell).
        w: filter of shape [height, width, output_channels, input_channels], the
            layout conv2d_transpose expects.
        s: spatial stride (default 1 keeps the original stride-1 behaviour).

    Returns:
        Tensor of shape [batch, H * s, W * s, output_channels].
    """
    in_shape = tf.shape(x)
    out_channels = tf.shape(w)[2]
    output_shape = tf.stack([in_shape[0], in_shape[1] * s, in_shape[2] * s, out_channels])
    return tf.nn.conv2d_transpose(x, w, output_shape=output_shape,
                                  strides=[1, s, s, 1], padding='SAME')

def max_pool_2x2(x):
    """2x2 max pooling with stride 2 and 'SAME' padding (halves H and W, rounding up)."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

def batchnorm(x, axes=(0, 1, 2, 3), epsilon=1e-6):
    """Standardize ``x`` with mean/variance computed from ``x`` itself.

    With the default ``axes`` the statistics are reduced over all four axes, i.e. a
    single scalar mean and variance for the whole tensor. This is *not* per-channel
    batch normalization; pass ``axes=(0, 1, 2)`` for per-channel statistics on an
    NHWC tensor. There is no learned offset/scale (fixed at 0 and 1) and no moving
    average, so the current batch's statistics are used at inference time as well.

    Args:
        x: input tensor.
        axes: axes to reduce over when computing mean and variance.
        epsilon: small constant added to the variance for numerical stability.
    """
    mean, variance = tf.nn.moments(x, list(axes))
    return tf.nn.batch_normalization(x,
                                     mean=mean,
                                     variance=variance,
                                     offset=0,
                                     scale=1,
                                     variance_epsilon=epsilon)

def conv_layer(x, input_channel, output_channel, k_size=3, relu=True, stride=1, bn=True, name='conv_layer'):
    """Convolution + bias, optionally followed by ``batchnorm`` and ReLU.

    Creates a [k_size, k_size, input_channel, output_channel] kernel ('weight') and
    an output_channel-long bias ('bias') under ``tf.name_scope(name)``.
    """
    with tf.name_scope(name):
        kernel = weight_variable([k_size, k_size, input_channel, output_channel], 'weight')
        bias = bias_variable([output_channel], 'bias')
        out = conv2d(x, kernel, s=stride) + bias
        out = batchnorm(out) if bn else out
        return tf.nn.relu(out) if relu else out

def res_conv_layer(x, input_channel, output_channel, relu=True, stride=1, name='res_conv_layer'):
    """Residual block: three 3x3 ``conv_layer``s plus a shortcut, optional final ReLU.

    The shortcut is the identity when channel count and resolution are unchanged
    (``input_channel == output_channel`` and ``stride == 1``). Otherwise it is a 3x3
    projection ``conv_layer`` ('conv1_') with the same stride.

    The last conv of the residual branch ('conv3') has no ReLU in both cases, so the
    residual can be negative before the addition. The identity path previously
    applied a ReLU there, unlike the projection path.

    Args:
        x: input tensor with ``input_channel`` channels.
        input_channel, output_channel: channel counts in and out.
        relu: apply ReLU after the addition.
        stride: stride of the first conv (and of the projection shortcut).
        name: name scope for the block.
    """
    with tf.name_scope(name):
        conv1 = conv_layer(x, input_channel, output_channel, name='conv1', stride=stride)
        conv2 = conv_layer(conv1, output_channel, output_channel, name='conv2')
        conv3 = conv_layer(conv2, output_channel, output_channel, name='conv3', relu=False)
        if input_channel == output_channel and stride == 1:
            shortcut = x
        else:
            shortcut = conv_layer(x, input_channel, output_channel, name='conv1_',
                                  relu=False, stride=stride)
        answer = shortcut + conv3
        return tf.nn.relu(answer) if relu else answer

def Fully_ResNet(x, class_num, repeat=16):
    """Fully-convolutional ResNet segmentation head at full input resolution.

    Architecture:
    1. ``batchnorm`` of the input.
    2. A 3->64 residual block ('res1').
    3. ``repeat`` 64->64 residual blocks ('res1_<i>').
    4. A 64->class_num residual block without final ReLU ('res2').

    All strides are 1, so the spatial size is preserved.

    Args:
        x: input tensor; its static shape must be (N, 256, 256, 3).
        class_num: number of output channels.
        repeat: number of intermediate 64->64 residual blocks (default 16).

    Returns:
        Tensor of shape (N, 256, 256, class_num).

    Raises:
        ValueError: if ``x`` does not have static shape (N, 256, 256, 3).
    """
    if x.shape.as_list()[1:] != [256, 256, 3]:
        raise ValueError('Fully_ResNet input error: expected input shape '
                         '(N, 256, 256, 3), got %s' % (x.shape,))
    net = batchnorm(x)
    net = res_conv_layer(net, 3, 64, name='res1')
    for i in range(repeat):
        net = res_conv_layer(net, 64, 64, name='res1_' + str(i))
    return res_conv_layer(net, 64, class_num, name='res2', relu=False)

def UNet_ResNet(x, class_num):
    """U-Net-style network built from residual blocks, fusing skips by addition.

    Structure:
    - Encoder: 256 -> 128 -> 64 -> 32 resolution, via stride-2 residual blocks.
    - Decoder: bilinear upsampling back to 256.
    - Skips: each decoder stage adds a 512-channel projection of the matching
      encoder feature map.
    - Head: a 1x1 conv without batchnorm or ReLU.

    Args:
        x: input tensor; its static shape must be (N, 256, 256, 3).
        class_num: number of output channels.

    Returns:
        Tensor of shape (N, 256, 256, class_num).

    Raises:
        ValueError: if ``x`` does not have static shape (N, 256, 256, 3).
    """
    if x.shape.as_list()[1:] != [256, 256, 3]:
        raise ValueError('UNet_ResNet input error: expected input shape '
                         '(N, 256, 256, 3), got %s' % (x.shape,))
    # Encoder
    net_res_conv1 = res_conv_layer(x, 3, 64, name='res_conv1', relu=True, stride=1)                  # 256x256x64
    net_res_conv2 = res_conv_layer(net_res_conv1, 64, 128, name='res_conv2', relu=True, stride=2)    # 128x128x128
    net_res_conv3 = res_conv_layer(net_res_conv2, 128, 256, name='res_conv3', relu=True, stride=2)   # 64x64x256
    net_res_conv4 = res_conv_layer(net_res_conv3, 256, 512, name='res_conv4', relu=True, stride=2)   # 32x32x512
    net_res_conv5 = res_conv_layer(net_res_conv4, 512, 512, name='res_conv5', relu=True, stride=1)   # 32x32x512

    # Decoder stage 1: 32 -> 64
    net_up6 = tf.image.resize_bilinear(net_res_conv5, [64, 64], name='upsample1')                     # 64x64x512
    net_res_conv3_cut = res_conv_layer(net_res_conv3, 256, 512, name='res_conv3_cut', relu=True, stride=1)  # 64x64x512
    net_fp6 = net_up6 + net_res_conv3_cut                                                             # 64x64x512
    net_res_conv6 = res_conv_layer(net_fp6, 512, 512, name='res_conv6', relu=True, stride=1)          # 64x64x512

    # Decoder stage 2: 64 -> 128
    net_up7 = tf.image.resize_bilinear(net_res_conv6, [128, 128], name='upsample2')                   # 128x128x512
    net_res_conv2_cut = res_conv_layer(net_res_conv2, 128, 512, name='res_conv2_cut', relu=True, stride=1)  # 128x128x512
    net_fp7 = net_up7 + net_res_conv2_cut                                                             # 128x128x512
    net_res_conv7 = res_conv_layer(net_fp7, 512, 512, name='res_conv7', relu=True, stride=1)          # 128x128x512

    # Decoder stage 3: 128 -> 256
    net_up8 = tf.image.resize_bilinear(net_res_conv7, [256, 256], name='upsample3')                   # 256x256x512
    net_res_conv1_cut = res_conv_layer(net_res_conv1, 64, 512, name='res_conv1_cut', relu=True, stride=1)   # 256x256x512
    net_fp8 = net_up8 + net_res_conv1_cut                                                             # 256x256x512
    net_res_conv8 = res_conv_layer(net_fp8, 512, 512, name='res_conv8', relu=True, stride=1)          # 256x256x512

    # 1x1 classifier head: raw logits, no batchnorm / ReLU
    net_fc = conv_layer(net_res_conv8, 512, class_num, k_size=1, name='fc', relu=False, bn=False, stride=1)

    return net_fc


def U_Net(x, class_num):
    """Classic U-Net with concatenation skips.

    Structure:
    - Encoder: pairs of 3x3 convs with 2x2 max pooling (256 -> 128 -> 64 -> 32).
    - Decoder: bilinear upsampling back to 256, concatenating the matching encoder
      features at each stage.
    - Head: a 1x1 conv producing raw logits.

    Args:
        x: input tensor; its static shape must be (N, 256, 256, 3).
        class_num: number of output channels.

    Returns:
        Tensor of shape (N, 256, 256, class_num).

    Raises:
        ValueError: if ``x`` does not have static shape (N, 256, 256, 3).
    """
    if x.shape.as_list()[1:] != [256, 256, 3]:
        raise ValueError('U_Net input error: expected input shape '
                         '(N, 256, 256, 3), got %s' % (x.shape,))
    norm = batchnorm(x)
    # Encoder
    net_conv1 = conv_layer(norm, 3, 64, name='conv1')  # 256x256
    net_conv2 = conv_layer(net_conv1, 64, 64, name='conv2')
    net_pool1 = max_pool_2x2(net_conv2)

    net_conv3 = conv_layer(net_pool1, 64, 128, name='conv3')  # 128x128
    net_conv4 = conv_layer(net_conv3, 128, 128, name='conv4')
    net_pool2 = max_pool_2x2(net_conv4)

    net_conv5 = conv_layer(net_pool2, 128, 256, name='conv5')  # 64x64
    net_conv6 = conv_layer(net_conv5, 256, 256, name='conv6')
    net_pool3 = max_pool_2x2(net_conv6)

    net_conv7 = conv_layer(net_pool3, 256, 512, name='conv7')  # 32x32
    net_conv8 = conv_layer(net_conv7, 512, 512, name='conv8')

    # Decoder
    net_conv9 = conv_layer(net_conv8, 512, 256, name='conv9')
    net_up1 = tf.image.resize_bilinear(net_conv9, [64, 64], name='upsample1')  # 64x64
    net_concat1 = tf.concat([net_up1, net_conv6], axis=-1, name='concat1')
    net_conv10 = conv_layer(net_concat1, 512, 256, name='conv10')
    net_conv11 = conv_layer(net_conv10, 256, 256, name='conv11')

    net_conv12 = conv_layer(net_conv11, 256, 128, name='conv12')
    net_up2 = tf.image.resize_bilinear(net_conv12, [128, 128], name='upsample2')  # 128x128
    net_concat2 = tf.concat([net_up2, net_conv4], axis=-1, name='concat2')
    net_conv13 = conv_layer(net_concat2, 256, 128, name='conv13')
    net_conv14 = conv_layer(net_conv13, 128, 128, name='conv14')

    net_conv15 = conv_layer(net_conv14, 128, 64, name='conv15')
    net_up3 = tf.image.resize_bilinear(net_conv15, [256, 256], name='upsample3')  # 256x256
    net_concat3 = tf.concat([net_up3, net_conv2], axis=-1, name='concat3')
    net_conv16 = conv_layer(net_concat3, 128, 64, name='conv16')
    net_conv17 = conv_layer(net_conv16, 64, 64, name='conv17')

    # Output logits: no batchnorm, which would otherwise pin the logits to zero
    # mean / unit variance (consistent with UNet_ResNet's 'fc' head).
    net_conv18 = conv_layer(net_conv17, 64, class_num, k_size=1, name='conv18', relu=False, bn=False)

    return net_conv18