### Imports

In [None]:
import scipy.io as scio
import numpy as np    
import matplotlib.pyplot as plt
import sys
import os
import math
import pprint
import cv2
from scipy.misc import imsave
from helper import *

# Render matplotlib figures inline in the notebook.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
# 'nearest' draws each pixel as a solid block (no smoothing between pixels),
# so raw intensities and discrete label maps are shown without blending.
plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules (e.g. `helper`) so edits to them are
# picked up without restarting the kernel
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
import skimage.restoration as sr
import numpy as np
import glob
# import h5py
import os
import scipy.io as scio
from skimage import exposure
from skimage.io import imsave, imread
from scipy.misc import imresize
from scipy.io import savemat
from scipy import ndimage, misc
import matplotlib.pyplot as plt
%matplotlib inline
import re

from helper import *

In [None]:
import keras
from keras.layers import Activation
from keras.layers import Conv2D, MaxPooling2D
from keras.models import Model
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import UpSampling2D
from keras.layers import Concatenate
from keras.layers import Lambda 
from keras.utils import to_categorical
import tensorflow as tf

from keras.layers import Reshape

from keras import backend as K
from keras import regularizers, optimizers
# %matplotlib inline

In [None]:
def get_info(filenames, root, ext, target_shape=(512, 64)):
    """Load every file in `filenames` (relative to `root`) as a 2-D array.

    Parameters
    ----------
    filenames : iterable of str
        File names inside `root`.
    root : str
        Directory containing the files.
    ext : str
        Selects the reader used for *all* files, regardless of each file's
        own suffix: '.npy' loads with ``np.load``; '.JPG', '.jpg', '.tif' or
        '.png' reads a greyscale image with ``ndimage.imread(mode="L")``.
    target_shape : (int, int), optional
        '.npy' arrays smaller than this are zero-padded at the bottom and on
        the right up to this shape (dtype is preserved). Default (512, 64).

    Returns
    -------
    list of numpy.ndarray, one per filename, in the given order.

    Raises
    ------
    ValueError
        If `ext` is not a supported value, or a '.npy' array is larger than
        `target_shape` in either dimension.
    """
    image_exts = ('.JPG', '.jpg', '.tif', '.png')
    if ext != '.npy' and ext not in image_exts:
        # Fail loudly: without a matching reader the loop would otherwise have
        # no image to append for this file.
        raise ValueError("get_info: unsupported ext {!r}".format(ext))

    target_h, target_w = target_shape
    images = []
    for filename in filenames:
        filepath = os.path.join(root, filename)
        if ext == '.npy':
            image = np.load(filepath)
            h, w = image.shape
            if h > target_h or w > target_w:
                raise ValueError(
                    "get_info: {} has shape {} larger than target {}".format(
                        filepath, (h, w), (target_h, target_w)))
            if (h, w) != (target_h, target_w):
                # Zero-pad bottom/right so every array has the same shape;
                # np.pad keeps the array's own dtype.
                image = np.pad(image,
                               ((0, target_h - h), (0, target_w - w)),
                               mode='constant', constant_values=0)
        else:
            image = ndimage.imread(filepath, mode="L")
        images.append(image)
    return images

### Folder For Training Files

In [None]:
# Setting the directories
import os

# Sub-folder of datasets/OCTData/ to load; swap in one of the commented
# alternatives to use a different subset. Keep the trailing '/': later cells
# append sub-folder names directly to `cwd`.
wanted_folder = 'alldata/'
# wanted_folder = 'pruned/'
# wanted_folder = 'Atrium/'
# wanted_folder = 'Ventricle/'

cwd = '{}/datasets/OCTData/{}'.format(os.getcwd(), wanted_folder)
print(cwd)

#### Raw Files

In [None]:
whole_raw_image_folder = cwd + 'whole_raw_image/'
print(whole_raw_image_folder)

root_path = ""
# Keep only image / array files. '.DS_Store' (macOS Finder metadata) is
# deliberately excluded: it is not an image, and since it sorts first it would
# become raw_images[0] and make get_info() fail trying to decode it.
filenames = sorted(
    name for name in os.listdir(whole_raw_image_folder)
    if name.endswith(('.tif', '.jpg', '.JPG', '.png', '.npy'))
)
print(len(filenames))

In [None]:
raw_images = get_info(filenames, root=whole_raw_image_folder, ext='.tif')  # every listed file read as a greyscale image

In [None]:
# Sanity check: number of raw images loaded, plus a preview of the first one.
print(len(raw_images))
plt.imshow(raw_images[0], cmap='gray')

### Labels

In [None]:
manual_label_folder = cwd + 'manual_label/'
# Report the folder actually being listed.
print(manual_label_folder)

root_path = ""
# Keep only image / array files. '.DS_Store' (macOS Finder metadata) is
# deliberately excluded: it is not an image and get_info() cannot decode it.
filenames = sorted(
    name for name in os.listdir(manual_label_folder)
    if name.endswith(('.tif', '.jpg', '.JPG', '.png', '.npy'))
)
print(len(filenames))

In [None]:
# NOTE(review): labels go through get_info's greyscale reader (mode "L"), so any
# colour coding in the label images is collapsed to intensities -- confirm intended.
labels = get_info(filenames, root=manual_label_folder, ext='.JPG')

In [None]:
# Sanity check: number of manual labels loaded, plus a preview of the first one.
print(len(labels))
plt.imshow(labels[0])

### Ids

In [None]:
ids_folder = cwd + 'png_labels_method/'
print(ids_folder)

root_path = ""
# Keep only image / array files. '.DS_Store' (macOS Finder metadata) is
# deliberately excluded: it is not an image and get_info() cannot decode it.
filenames = sorted(
    name for name in os.listdir(ids_folder)
    if name.endswith(('.tif', '.jpg', '.JPG', '.png', '.npy'))
)
print(len(filenames))

In [None]:
ids = get_info(filenames, ids_folder, '.png')
# Number of id maps just loaded (presumably equal to len(labels) when every
# label has a matching id map -- worth checking).
print(len(ids))
plt.imshow(ids[0])

In [None]:
h,w = 512, 600
# Pixel count of a full h x w image. Kept for reference; the Reshape layer
# below infers the pixel count, so the model is not tied to this width.
data_shape = h*w
weight_decay = 0.0001
# Defines the input tensor.
# The network is fully convolutional (three 2x2 pool / upsample stages joined by
# skip connections), so only the height is fixed. Leaving the width as None lets
# the same weights run on the 64-pixel-wide strips the inference cells feed in,
# which would not match a fixed input width of `w`. Any width used must be
# divisible by 8 so the skip connections line up after three 2x2 poolings.
inputs = Input(shape=(h, None, 1))

# Encoder: three conv -> BN -> ReLU -> 2x2 max-pool stages, 64 filters each.
L1 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(inputs)
L2 = BatchNormalization()(L1)
L2 = Activation('relu')(L2)
L3 = MaxPooling2D(pool_size=(2,2))(L2)
L4 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L3)
L5 = BatchNormalization()(L4)
L5 = Activation('relu')(L5)
L6 = MaxPooling2D(pool_size=(2,2))(L5)
L7 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L6)
L8 = BatchNormalization()(L7)
L8 = Activation('relu')(L8)
L9 = MaxPooling2D(pool_size=(2,2))(L8)

# Bottleneck.
L10 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L9)
L11 = BatchNormalization()(L10)
L11 = Activation('relu')(L11)

# Decoder: upsample, concatenate the matching encoder activation, conv -> BN -> ReLU.
L12 = UpSampling2D(size = (2,2))(L11)
# NOTE: the skip tensor comes first here but second in the two later merges.
# Concatenation order fixes the channel layout the trained weights expect, so
# keep it as is.
L13 = Concatenate(axis = 3)([L8,L12])
L14 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L13)
L15 = BatchNormalization()(L14)
L15 = Activation('relu')(L15)
L16 = UpSampling2D(size= (2,2))(L15)
L17 = Concatenate(axis = 3)([L16,L5])
L18 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L17)
L19 = BatchNormalization()(L18)
L19 = Activation('relu')(L19)
L20 = UpSampling2D(size=(2,2),name = "Layer19")(L19)
L21 = Concatenate(axis=3)([L20,L2])
L22 = Conv2D(64,kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L21)
L23 = BatchNormalization()(L22)
L23 = Activation('relu')(L23)

# 1x1 conv -> 8 class scores per pixel.
L24 = Conv2D(8,kernel_size=(1,1),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))(L23)
# Flatten the spatial grid to (pixels, 8) so the softmax and the 'temporal'
# sample weighting act per pixel; -1 lets the pixel count follow the input size.
L = Reshape((-1, 8))(L24)
L = Activation('softmax')(L)
model = Model(inputs = inputs, outputs = L)
# model.summary()

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    '''
    Soft Dice coefficient (2*|A*B| + s) / (|A| + |B| + s).

    Both tensors are flattened first, so the score is pooled over the whole
    batch and all classes rather than computed per sample.

    y_true = label
    y_pred = prediction
    smooth = additive smoothing term s; keeps the ratio defined (it equals 1)
             when both tensors are all zeros. Taken as a parameter, with
             default 1 (the value of the notebook's global `smooth`), so this
             function does not depend on a global defined in a later cell.

    Returns a scalar tensor; higher means more overlap.
    '''
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

In [None]:
def dice_coef_loss(y_true, y_pred):
    """Loss form of the Dice coefficient: its negation, so that minimising the
    loss maximises the overlap between label and prediction."""
    score = dice_coef(y_true, y_pred)
    return -score

In [None]:
def customized_loss(y_true, y_pred):
    """Training loss: 1 x categorical cross-entropy + 0.5 x negated Dice score.

    The cross-entropy is reduced over the last (class) axis only, while the
    Dice term is a single scalar that is broadcast-added to it.
    """
    ce_weight, dice_weight = 1, 0.5
    cross_entropy = K.categorical_crossentropy(y_true, y_pred)
    dice_term = dice_coef_loss(y_true, y_pred)
    return ce_weight * cross_entropy + dice_weight * dice_term

In [None]:
# Additive smoothing term for the Dice coefficient: keeps the score defined
# (instead of 0/0) when both label and prediction are empty.
smooth = 1

In [None]:
# Training hyper-parameters. Only the learning rates are used in this cell;
# bs and epoch (presumably batch size and epoch count) are not referenced by
# the visible cells.
lrs = [0.01]
bs = 40
epoch = 100
# Compile once per candidate learning rate. Nothing is trained in between, so
# the model ends up compiled with the last entry of `lrs`.
for i in lrs:
    optimiser = optimizers.Adam(lr=i)
    model.compile(
        optimizer=optimiser,
        loss=customized_loss,
        metrics=['accuracy', dice_coef],
        # Per-"timestep" (here: per-pixel) sample weights, matching the
        # (batch, pixels, classes) output produced by the Reshape layer.
        sample_weight_mode='temporal',
    )

In [None]:
# Trained checkpoint to load (alternative: 'notnormalised_bs50_ep_500_01').
saved_name = 'notnormalised_bs40_ep_150'
# Directory of .hdf5 checkpoints. Defaults to the original author's absolute
# path; set the MODEL_DIR environment variable to load from elsewhere without
# editing the notebook.
# NOTE(review): the checkpoint name says 'notnormalised' while the directory is
# 'Normalised' -- confirm this pairing is intended.
model_dir = os.environ.get('MODEL_DIR', '/home/sim/notebooks/relaynet_pytorch/models/Normalised/')
model.load_weights(os.path.join(model_dir, saved_name + '.hdf5'))

In [None]:
# Display colour for each of the 8 class ids the model predicts.
SEG_LABELS_LIST = [
    {"id": 0, "name": "void", "rgb_values": [0, 0, 0]},                 # black
    {"id": 1, "name": "Myocardium", "rgb_values": [255, 0, 0]},         # red
    {"id": 2, "name": "Endocardium", "rgb_values": [0, 0, 255]},        # blue
    {"id": 3, "name": "Fibrosis", "rgb_values": [177, 10, 255]},        # purple
    {"id": 4, "name": "Fat", "rgb_values": [0, 255, 0]},                # green
    {"id": 5, "name": "Dense Collagen", "rgb_values": [255, 140, 0]},   # orange
    {"id": 6, "name": "Loose Collagen", "rgb_values": [255, 255, 0]},   # yellow
    {"id": 7, "name": "Smooth Muscle", "rgb_values": [255, 0, 255]},    # magenta/pink
]


def label_img_to_rgb(label_img):
    """Colour a 2-D map of class ids using SEG_LABELS_LIST.

    Singleton dimensions are squeezed away first, so e.g. (1, H, W, 1) input
    is accepted. Pixels whose id is not in SEG_LABELS_LIST keep their raw
    value in all three channels. Returns an (H, W, 3) uint8 array.
    """
    id_map = np.squeeze(label_img)
    rgb = np.array([id_map] * 3).transpose(1, 2, 0)
    for entry in SEG_LABELS_LIST:
        # An id that does not occur gives an all-False mask, i.e. a no-op.
        rgb[id_map == entry['id']] = entry['rgb_values']
    return rgb.astype(np.uint8)

In [None]:
# Scratch arithmetic left over from exploration (evaluates to 150); its
# purpose is not evident from the notebook. TODO(review): explain or remove.
(15*600)/60

In [None]:
# Segment the first 64 columns of one raw image and show input and prediction side by side.
ind = 0

# Raw test image, cropped to a 64-pixel-wide strip.
testing_image = raw_images[ind]
test_labels = labels
testing_image = testing_image[:, :64]
h, w = testing_image.shape

# Separate axes for input and prediction: drawing the colour map into the same
# axes would cover the raw image completely.
fig, (ax_raw, ax_pred) = plt.subplots(1, 2, figsize=(20, 10))
ax_raw.imshow(testing_image, cmap=plt.cm.gray)
ax_raw.set_title('Raw image {} (first {} columns)'.format(ind, w))

# Add the batch and channel axes the model expects: (1, h, w, 1).
# NOTE(review): pixels are fed unscaled; confirm this matches the preprocessing
# used when the loaded weights were trained.
testing_image = testing_image.reshape((1, h, w, 1))
prediction = model.predict(testing_image)
prediction = np.squeeze(prediction, axis=0)

# The model returns per-pixel class probabilities flattened to (h*w, 8):
# restore the grid and pick the most likely class for every pixel at once.
prediction = np.reshape(prediction, (h, w, 8))
print(prediction.shape)
output = np.argmax(prediction, axis=-1)

# Id map for the same image (only its shape is reported here).
idx = np.asarray(ids[ind])
print(idx.shape)

color = label_img_to_rgb(output)
ax_pred.imshow(color)
ax_pred.set_title('Predicted segmentation')
plt.show()


In [None]:
# Segment the first 64 columns of one raw image and show input and prediction side by side.
# NOTE(review): this cell repeats the segmentation cell above for the same `ind`;
# change `ind` to inspect another image, or delete the cell.
ind = 0

# Raw test image, cropped to a 64-pixel-wide strip.
testing_image = raw_images[ind]
test_labels = labels
testing_image = testing_image[:, :64]
h, w = testing_image.shape

# Separate axes for input and prediction: drawing the colour map into the same
# axes would cover the raw image completely.
fig, (ax_raw, ax_pred) = plt.subplots(1, 2, figsize=(20, 10))
ax_raw.imshow(testing_image, cmap=plt.cm.gray)
ax_raw.set_title('Raw image {} (first {} columns)'.format(ind, w))

# Add the batch and channel axes the model expects: (1, h, w, 1).
# NOTE(review): pixels are fed unscaled; confirm this matches the preprocessing
# used when the loaded weights were trained.
testing_image = testing_image.reshape((1, h, w, 1))
prediction = model.predict(testing_image)
prediction = np.squeeze(prediction, axis=0)

# The model returns per-pixel class probabilities flattened to (h*w, 8):
# restore the grid and pick the most likely class for every pixel at once.
prediction = np.reshape(prediction, (h, w, 8))
print(prediction.shape)
output = np.argmax(prediction, axis=-1)

# Id map for the same image (only its shape is reported here).
idx = np.asarray(ids[ind])
print(idx.shape)

color = label_img_to_rgb(output)
ax_pred.imshow(color)
ax_pred.set_title('Predicted segmentation')
plt.show()
