<a href="https://colab.research.google.com/github/psychemistz/MultiDimGCNR/blob/main/Denoising_super_res_UNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import keras.backend as K
import cv2
from sklearn.model_selection import train_test_split

import keras
from os import path
import h5py
import numpy as np
import matplotlib.pyplot as plt
import gc
import glob
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
import scipy

from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers import *
from keras.models import Model
from keras.models import load_model
from keras.utils.np_utils import to_categorical
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.metrics import confusion_matrix
from tensorflow.keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from sklearn.datasets import load_sample_image
from sklearn.feature_extraction import image
from scipy.io import savemat

In [None]:
# Root of every dataset folder on the mounted Google Drive. Change this single
# value to relocate the data; all paths below are derived from it.
DATASET_ROOT = '/content/drive/MyDrive/BUS_PROJECT/datasets/'

# XPIE label (target) images.
label_obj_path = DATASET_ROOT + 'XPIE_object/'
label_mask_path = DATASET_ROOT + 'XPIE_mask/'

# XPIE degraded inputs paired with the labels above by file name.
noisy_obj_path = DATASET_ROOT + 'XPIE_obj_2_64/'
noisy_mask_path = DATASET_ROOT + 'XPIE_mask_2_64/'

# Ultrasound (BUS) inputs and their label images.
US_data_path = DATASET_ROOT + 'BUS_TCB_64/'
US_label_path = DATASET_ROOT + 'filter_imgs_256/'  # NLLR

FWHM_path = DATASET_ROOT + 'FWHM_TCB_64/'

In [None]:
# File lists for every dataset. glob.glob returns paths in whatever order the
# filesystem enumerates them, which is not guaranteed to be stable. Sorting
# makes the unshuffled KFold splits in the cross-validation cell reproducible
# across runs and machines (and, if the XPIE object and mask folders share
# file names, keeps their folds aligned).
label_obj = sorted(glob.glob(label_obj_path + '*.png'))
label_mask = sorted(glob.glob(label_mask_path + '*.png'))
noisy_obj = sorted(glob.glob(noisy_obj_path + '*.png'))
noisy_mask = sorted(glob.glob(noisy_mask_path + '*.png'))
label_US = sorted(glob.glob(US_label_path + '*.png'))
noisy_US = sorted(glob.glob(US_data_path + '*.png'))

FWHM_list = sorted(glob.glob(FWHM_path + '*.png'))

In [None]:
# Markers used by load_images / load_test_and_pred: the part of a label path
# after the marker is the file name shared with its paired input image.
# The directory-based markers must equal the folder each label list was
# globbed from, so they are derived from the path constants rather than
# repeating the literals (which could silently drift apart).
split_name_obj = label_obj_path
split_name_mask = label_mask_path
# NOTE(review): expected to occur in every US label file name — confirm.
split_name_US = 'beta_10_H_5_'
split_name_FWHM = FWHM_path

In [None]:
def _suffix_after(path, marker):
  """Return everything in *path* after the first occurrence of *marker*.

  Raises:
    ValueError: if *marker* does not occur in *path*.
  """
  _, sep, tail = path.partition(marker)
  if not sep:
    raise ValueError('%r does not contain split marker %r' % (path, marker))
  return tail


def load_images(data_path, label_list, split_name):
  """Load paired (label, input) grayscale images into two arrays.

  For each path in *label_list*, the paired input image is read from
  ``data_path + <part of the label path after split_name>``.

  Args:
    data_path: prefix (including trailing separator) of the input images.
    label_list: paths of the label images.
    split_name: marker inside each label path; everything after its first
      occurrence is the file name shared with the input image.

  Returns:
    (y_train, x_train): label and input images stacked along a new leading
    axis, each divided by 255.0 (assumes 8-bit images).

  Raises:
    ValueError: if a label path does not contain *split_name*.
  """
  train_labels = []
  train_imgs = []
  for label in label_list:
    # Resolve the paired file first so a bad marker fails before any I/O.
    data_filepath = data_path + _suffix_after(label, split_name)
    train_labels.append(img_to_array(load_img(label, color_mode='grayscale')))
    train_imgs.append(img_to_array(load_img(data_filepath, color_mode='grayscale')))
  y_train = np.asarray(train_labels) / 255.0
  x_train = np.asarray(train_imgs) / 255.0
  return y_train, x_train

In [None]:
def enc_conv_block(in_conv, num_filter=32, kernel_size=3):
  """One U-Net encoder stage.

  Two conv-BN-ReLU layers produce the skip tensor. That tensor is 2x2
  max-pooled and passed through one more conv-BN-ReLU.

  Returns:
    (y, SC): y is the pooled-and-convolved tensor for the next stage; SC is
    the pre-pooling tensor used as the skip connection.
  """
  def conv_bn_relu(tensor):
    # Conv2D -> BatchNormalization -> ReLU, all with this stage's settings.
    tensor = Conv2D(num_filter, kernel_size, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation(activation='relu')(tensor)

  skip = conv_bn_relu(conv_bn_relu(in_conv))
  downsampled = MaxPooling2D(pool_size=(2, 2))(skip)
  return conv_bn_relu(downsampled), skip

def dec_deconv_block(encoder_output, SC, num_filter=256, kernel_size=3):
  """One U-Net decoder stage.

  The input is upsampled 2x with a stride-2 transposed conv (plus BN-ReLU)
  and concatenated with the skip tensor *SC* on the last axis. It then goes
  through two conv-BN-ReLU layers.
  """
  def bn_relu(tensor):
    tensor = BatchNormalization()(tensor)
    return Activation(activation='relu')(tensor)

  upsampled = Conv2DTranspose(num_filter, kernel_size, strides=(2, 2),
                              padding='same')(encoder_output)
  features = Concatenate()([bn_relu(upsampled), SC])
  for _ in range(2):
    features = bn_relu(Conv2D(num_filter, kernel_size, padding='same')(features))
  return features

def my_model(num_filter=32, kernel_size=3, learning_rate=7e-4):
  """Build and compile the U-Net used for denoising / super-resolution.

  The network has three parts:
    * an initial ReLU conv named 'green';
    * five encoder stages (enc_conv_block) whose pre-pooling outputs are the
      skip connections;
    * four decoder stages (dec_deconv_block) that upsample back to the input
      resolution, then a 1x1 sigmoid conv named 'output'.

  The input accepts single-channel images of any size. The skip connections
  at 1, 1/2, 1/4, 1/8 and 1/16 resolution are re-joined after stride-2
  transposed convolutions, so height and width must be divisible by 16.

  Args:
    num_filter: filters of the first encoder stage; encoder stage k
      (k = 0..4) uses num_filter * 2**k, and the decoder mirrors this.
    kernel_size: kernel size of every 3x3-style conv layer.
    learning_rate: Adam learning rate.

  Returns:
    A compiled keras Model with MSE loss on the 'output' head.
  """
  input_img = Input(shape=(None, None, 1), name='input_img')
  green = Conv2D(num_filter, kernel_size, activation='relu', padding='same',
                 name='green')(input_img)
  CB_1, SC1 = enc_conv_block(green, num_filter, kernel_size)
  CB_2, SC2 = enc_conv_block(CB_1, 2*num_filter, kernel_size)
  CB_3, SC3 = enc_conv_block(CB_2, 4*num_filter, kernel_size)
  CB_4, SC4 = enc_conv_block(CB_3, 8*num_filter, kernel_size)

  # Bottleneck: only the pre-pooling output SC5 (1/16 resolution) is used.
  # The pooled output is not connected to 'output', so those layers do not
  # end up in the resulting Model.
  _, SC5 = enc_conv_block(CB_4, 16*num_filter, kernel_size)

  DCB1 = dec_deconv_block(SC5, SC4, 8*num_filter, kernel_size)
  DCB2 = dec_deconv_block(DCB1, SC3, 4*num_filter, kernel_size)
  DCB3 = dec_deconv_block(DCB2, SC2, 2*num_filter, kernel_size)
  DCB4 = dec_deconv_block(DCB3, SC1, num_filter, kernel_size)

  # Sigmoid keeps predictions in (0, 1), matching targets divided by 255.
  # (A residual alternative would be a linear 1x1 conv added to input_img.)
  output = Conv2D(1, (1, 1), padding='same', activation='sigmoid',
                  name='output')(DCB4)

  model = Model(inputs=[input_img], outputs=[output])
  model.compile(optimizer=Adam(learning_rate=learning_rate),
                loss={'output': 'mean_squared_error'})
  return model

# Build one instance with the default settings only to print the layer
# summary; the cross-validation loop builds its own model for every fold.
model = my_model()
model.summary() 

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_img (InputLayer)         [(None, None, None,  0           []                               
                                 1)]                                                              
                                                                                                  
 green (Conv2D)                 (None, None, None,   320         ['input_img[0][0]']              
                                32)                                                               
                                                                                                  
 conv2d (Conv2D)                (None, None, None,   9248        ['green[0][0]']                  
                                32)                                                           

In [None]:
### FOR UNET
# Root of the U-Net result folders; change this single value to relocate them.
# The CV loop picks one sub-folder according to `aug` (0 -> BUS only,
# 1 -> XPIE only, otherwise -> BUS + XPIE).
RESULTS_ROOT = '/content/drive/MyDrive/BUS_PROJECT/results/Denoising_SR_64/N2_64/UNET/'
save_path_US = RESULTS_ROOT + 'BUS_only/'
save_path_XPIE = RESULTS_ROOT + 'XPIE_only/'
save_path_US_aug = RESULTS_ROOT + 'BUS_aug_all/'

In [None]:
from scipy.io import savemat
import os


def load_test_and_pred(data_path, test_label_list, save_path, split_name, model):
  """Predict every test image and save each prediction as a .mat file.

  For each label path, the paired input image is read from
  ``data_path + <part of the label path after split_name>``. It is divided
  by 255.0 and predicted one image at a time. The prediction is clipped to
  [0, 1] and saved under the key 'test_pred' in
  ``save_path + <file name without extension> + '.mat'``.

  Args:
    data_path: prefix (including trailing separator) of the input images.
    test_label_list: label paths whose paired inputs should be predicted.
    save_path: output directory prefix (created if missing).
    split_name: marker inside each label path; everything after its first
      occurrence is the file name shared with the input image.
    model: object with a Keras-style ``predict`` method.

  Raises:
    ValueError: if a label path does not contain *split_name*.
  """
  if save_path:
    # savemat cannot create missing directories itself.
    os.makedirs(save_path, exist_ok=True)
  for label_filepath in test_label_list:
    _, sep, file_name = label_filepath.partition(split_name)
    if not sep:
      raise ValueError('%r does not contain split marker %r'
                       % (label_filepath, split_name))
    x_test = img_to_array(load_img(data_path + file_name, color_mode='grayscale'))
    x_test = np.expand_dims(x_test, axis=0) / 255.0
    test_pred = np.clip(model.predict(x_test, batch_size=1), 0.0, 1.0)
    # splitext strips only the extension, so 'a.b.png' -> 'a.b' (not 'a'),
    # avoiding collisions between names that share a first dot-segment.
    save_name = save_path + os.path.splitext(file_name)[0] + '.mat'
    savemat(save_name, {'test_pred': test_pred})

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
import os

# ---------------------------------------------------------------------------
# K-fold cross-validation driver.
#
# Each dataset's file list is split independently with an unshuffled KFold,
# so fold membership follows the order of the lists. For each fold, a fresh
# model is trained on the data selected by `aug`:
#   aug == 0  -> BUS images only                 (results in save_path_US)
#   aug == 1  -> XPIE object + XPIE mask images  (results in save_path_XPIE)
#   otherwise -> BUS + XPIE object + XPIE mask   (results in save_path_US_aug)
# Each training subset is split 80/20 (random_state=42) into train and
# validation; validation drives early stopping and checkpointing. The best
# checkpoint by val_loss is then reloaded and used to predict the test files
# of all four datasets (BUS, XPIE object, XPIE mask, FWHM) whatever `aug` was.
# ---------------------------------------------------------------------------
X_obj = label_obj
X_mask = label_mask
X_US = label_US
X_FWHM = FWHM_list
n_splits = 5
kf = KFold(n_splits)
aug = 2

### FOR XPIE OBJECT
train_obj_list = []
test_obj_list = []
for train_obj, test_obj in kf.split(X_obj):
  train_obj_list.append(train_obj)
  test_obj_list.append(test_obj)

### FOR XPIE MASK/GROUND
train_mask_list = []
test_mask_list = []
for train_mask, test_mask in kf.split(X_mask):
  train_mask_list.append(train_mask)
  test_mask_list.append(test_mask)

### FOR BUS
train_US_list = []
test_US_list = []
for train_US, test_US in kf.split(X_US):
  train_US_list.append(train_US)
  test_US_list.append(test_US)

### FOR FWHM (only the test indices are used below; FWHM is never trained on)
train_FWHM_list = []
test_FWHM_list = []
for train_FWHM, test_FWHM in kf.split(X_FWHM):
  train_FWHM_list.append(train_FWHM)
  test_FWHM_list.append(test_FWHM)

for trial_no in range(0,n_splits):
  obj_train_list = [X_obj[index] for index in train_obj_list[trial_no]]
  obj_test_list = [X_obj[index] for index in test_obj_list[trial_no]]

  mask_train_list = [X_mask[index] for index in train_mask_list[trial_no]]
  mask_test_list = [X_mask[index] for index in test_mask_list[trial_no]]

  US_train_list = [X_US[index] for index in train_US_list[trial_no]]
  US_test_list = [X_US[index] for index in test_US_list[trial_no]]

  FWHM_train_list = [X_FWHM[index] for index in train_FWHM_list[trial_no]]
  FWHM_test_list = [X_FWHM[index] for index in test_FWHM_list[trial_no]]

  if aug == 0:
    y_train_US, x_train_US = load_images(US_data_path, US_train_list, split_name_US)
    x_train, x_val, y_train, y_val = train_test_split(x_train_US, y_train_US, test_size=0.2, random_state=42)
    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
    save_path = save_path_US

  elif aug == 1:
    y_train_obj, x_train_obj = load_images(noisy_obj_path, obj_train_list, split_name_obj)
    x_train_obj, x_val_obj, y_train_obj, y_val_obj = train_test_split(x_train_obj, y_train_obj, test_size=0.2, random_state=42)
    y_train_mask, x_train_mask = load_images(noisy_mask_path, mask_train_list, split_name_mask)
    x_train_mask, x_val_mask, y_train_mask, y_val_mask = train_test_split(x_train_mask, y_train_mask, test_size=0.2, random_state=42)
    y_val = np.concatenate((y_val_obj, y_val_mask), axis = 0)
    x_val = np.concatenate((x_val_obj, x_val_mask), axis = 0)
    y_train = np.concatenate((y_train_obj, y_train_mask), axis = 0)
    x_train = np.concatenate((x_train_obj, x_train_mask), axis = 0)
    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
    save_path = save_path_XPIE

  else:
    y_train_US, x_train_US = load_images(US_data_path, US_train_list, split_name_US)
    x_train_US, x_val_US, y_train_US, y_val_US = train_test_split(x_train_US, y_train_US, test_size=0.2, random_state=42)
    y_train_obj, x_train_obj = load_images(noisy_obj_path, obj_train_list, split_name_obj)
    x_train_obj, x_val_obj, y_train_obj, y_val_obj = train_test_split(x_train_obj, y_train_obj, test_size=0.2, random_state=42)
    y_train_mask, x_train_mask = load_images(noisy_mask_path, mask_train_list, split_name_mask)
    x_train_mask, x_val_mask, y_train_mask, y_val_mask = train_test_split(x_train_mask, y_train_mask, test_size=0.2, random_state=42)
    y_val = np.concatenate((y_val_US, y_val_obj, y_val_mask), axis = 0)
    x_val = np.concatenate((x_val_US, x_val_obj, x_val_mask), axis = 0)
    y_train = np.concatenate((y_train_US, y_train_obj, y_train_mask), axis = 0)
    x_train = np.concatenate((x_train_US, x_train_obj, x_train_mask), axis = 0)
    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
    save_path = save_path_US_aug

  # A new model is built for every fold. Clear Keras' global state first so
  # graphs and memory held by earlier folds' models do not accumulate.
  K.clear_session()
  model = my_model()

  es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=50)

  # The checkpoint is written into save_path; make sure the folder exists.
  os.makedirs(save_path, exist_ok=True)
  ckpt_path = save_path + 'model_trial_' + str(trial_no) + '.h5'
  checkpoint = ModelCheckpoint(ckpt_path,
                               verbose=0, monitor='val_loss', save_best_only=True, mode='auto')

  history = model.fit(x_train, y_train,
                      epochs=200,
                      shuffle = True,
                      validation_data = (x_val, y_val),
                      callbacks = [es, checkpoint],
                      batch_size=8,
                      verbose = 0)

  # Replace the last-epoch weights with the best (lowest val_loss) checkpoint.
  del model
  model = load_model(ckpt_path)

  # Predict the held-out test files of every dataset with this fold's model.
  load_test_and_pred(US_data_path, US_test_list, save_path+'BUS/', split_name_US, model)
  load_test_and_pred(noisy_obj_path, obj_test_list, save_path+'XPIE_obj/', split_name_obj, model)
  load_test_and_pred(noisy_mask_path, mask_test_list, save_path+'XPIE_mask/', split_name_mask, model)
  load_test_and_pred(FWHM_path, FWHM_test_list, save_path+'FWHM/', split_name_FWHM, model)

(1384, 256, 256, 1) (1384, 256, 256, 1) (346, 256, 256, 1) (346, 256, 256, 1)
(1384, 256, 256, 1) (1384, 256, 256, 1) (346, 256, 256, 1) (346, 256, 256, 1)
(1384, 256, 256, 1) (1384, 256, 256, 1) (346, 256, 256, 1) (346, 256, 256, 1)
(1384, 256, 256, 1) (1384, 256, 256, 1) (347, 256, 256, 1) (347, 256, 256, 1)
(1384, 256, 256, 1) (1384, 256, 256, 1) (347, 256, 256, 1) (347, 256, 256, 1)
