# U-net type encoder-decoder convolutional neural network

### Import required Python libraries



In [None]:
from google.colab import drive
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Activation
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import concatenate, UpSampling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
import scipy.io as sio

# Mount Google Drive under /content/drive so that the .mat files referenced by the
# cells below can be read from (and results written to) the drive.
# Colab prompts for authorization when this runs.
drive.mount('/content/drive')

In [None]:
# !pip install scikit-learn   # the PyPI distribution is "scikit-learn"; it is imported as "sklearn"
from sklearn.model_selection import train_test_split

In [None]:
from scipy.ndimage import interpolation

### Import, process, and shape the dataset

In [None]:
# Read the two .mat files from Google Drive: the BPM output stack (network
# input) and the original digit images (network target).

# -- network input: output of the BPM simulation --
filename = '/content/drive/MyDrive/Giulia waveguides/Paper/outputs BPM/1cm/combined_normalized/output_stack_3e-5.mat'
mat_contents = sio.loadmat(filename)
NN_input = mat_contents['output_stack']
print("data loaded...")

# rescale by the global maximum of the stack so that its largest value becomes 1
maxNN_input = NN_input.max()
NN_input = NN_input / maxNN_input
print(NN_input.shape)

# -- network target: the original digits --
filename = '/content/drive/MyDrive/Giulia waveguides/network/dataset/d_all_norm256.mat'
mat_contents = sio.loadmat(filename)
NN_originals = mat_contents['d_all_norm256']
print("data loaded...")

print(NN_originals.shape)

In [None]:
# Quick visual check of the loaded stack: show sample 5 with a colour bar
# and with the axis ticks removed.
plt.imshow(NN_input[5], cmap='viridis')
plt.colorbar()
plt.xticks([])
plt.yticks([])
plt.grid(None)
plt.show()

In [None]:
### RUN THIS SECTION IF YOU WANT TO ADD NOISE ###
# Adds smooth random noise to every input sample: a coarse uniform random grid
# is interpolated up to the image size, scaled so that its maximum is 1, and
# added to the sample with a fixed amplitude.
# NOTE: NN_input is modified in place, so re-running this cell adds noise again.

# `zoom` is imported from scipy.ndimage directly; the scipy.ndimage.interpolation
# namespace used previously is deprecated.
from scipy.ndimage import zoom

noise_grid = 5           # side length of the coarse random grid
noise_amplitude = 0.25   # weight of the (max-normalised) noise added to each sample

# Sizes are taken from the data instead of being hard-coded (previously 1170 samples
# and a zoom factor of 256/5), so the cell also works for other dataset sizes.
n_samples, n_rows, n_cols = NN_input.shape[:3]

for ii in range(n_samples):
  noise = np.random.rand(noise_grid, noise_grid)
  noise_up = zoom(noise, (n_rows / noise_grid, n_cols / noise_grid))
  noise_up = noise_up / np.max(noise_up)
  NN_input[ii, :, :] = noise_up * noise_amplitude + NN_input[ii, :, :]

# show the same sample as in the preview cell, now with noise added
plt.imshow(NN_input[5, :, :], cmap='viridis')
plt.colorbar()
plt.grid(None)
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
### RUN THIS SECTION IF YOU WANT TO REMOVE THE CLADDING LIGHT  ###
# Builds a binary mask that is 1 inside a square array of rectangular waveguide
# cores and 0 everywhere else, then multiplies every input image by it.
# NOTE: NN_input is modified in place.

# Construct the binary mask
Nx       = 256                                   # number of grid points along x
Ny       = Nx                                    # number of grid points along y (the computation domain is square)
Lx_ = 140e-6                                     # width of the computation window [m]
Ly_ = Lx_                                        # height of the computation window [m]
dx_ = Lx_/Nx                                     # discretization step in x [m]
x_  = dx_ * np.arange(-Nx/2,Nx/2,1)              # x coordinate vector, centred on 0
dy_ = Ly_/Ny                                     # discretization step in y [m]
y_  = dy_ * np.arange(-Ny/2,Ny/2,1)              # y coordinate vector, centred on 0
[X_, Y_]    = np.meshgrid(x_, y_)

wgds = np.zeros((Nx,Ny))
width_x  = (2e-6)  *1                            # core width along x [m]
width_y  = (2.5e-6)*1                            # core width along y [m]
seperation = 30e-6                               # waveguide centre-to-centre separation (pitch) [m]
wgd_num=int(120e-6/seperation+1)                 # number of waveguides per row and per column

# Mark the rectangular core of every waveguide; the array is centred on the origin.
for ii in range(wgd_num):
  for jj in range(wgd_num):
    ii_=ii-(wgd_num-1)/2
    jj_=jj-(wgd_num-1)/2
    wgds[np.logical_and(np.abs(X_-ii_*seperation)<=width_x/2,  np.abs(Y_-jj_*seperation)<=width_y/2)]=1

plt.imshow((wgds),extent=[-Lx_/2*1e6,Lx_/2*1e6,-Ly_/2*1e6,Ly_/2*1e6])
plt.colorbar()
plt.xlabel('x axis [um]')
plt.ylabel('y axis [um]')
plt.show()

# Remove the cladding light.
# The 2-D mask broadcasts over the leading sample axis, so every sample is
# masked regardless of the dataset size (the count was previously hard-coded to 1170).
NN_input *= wgds

plt.imshow(NN_input[5, :, :], cmap='viridis')
plt.colorbar()
plt.grid(None)
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Hold out 10% of the dataset for testing. Inputs and originals are split
# together so the pairs stay aligned; the fixed seed makes the split repeatable.
X_train, X_test, Y_train, Y_test = train_test_split(
    NN_input,
    NN_originals,
    test_size=0.1,
    random_state=1994,
)

# report the shape of each subset
for _label, _subset in (
    ('X_train : ', X_train),
    ('X_test : ', X_test),
    ('Y_train : ', Y_train),
    ('Y_test : ', Y_test),
):
  print(_label, _subset.shape)

In [None]:
def _with_channel_axis(stack):
  """Return `stack` reshaped to (n_samples, 256, 256, 1), i.e. with a trailing channel axis."""
  return stack.reshape(stack.shape[0], 256, 256, 1)


def _preview(images, index, name):
  """Show image `index` of a 4-D stack, titled '<name> [<index>]', without ticks."""
  plt.imshow(images[index, :, :, 0], cmap='viridis')
  plt.title(name + ' [' + str(index) + ']')
  plt.grid(None)
  plt.xticks([])
  plt.yticks([])
  plt.show()


# X_* hold the BPM images (network input); Y_* hold the matching original images (network target)
input_train = _with_channel_axis(X_train)
output_train = _with_channel_axis(Y_train)
input_test = _with_channel_axis(X_test)
output_test = _with_channel_axis(Y_test)

print('* processing and shaping data')
print()
for _name, _array in (
    ('input_train : ', input_train),
    ('output_train : ', output_train),
    ('input_test : ', input_test),
    ('output_test : ', output_test),
):
  print(_name, _array.shape)
print()

# plot one training pair as an example
item_id = 15

_preview(input_train, item_id, 'input_train')
_preview(output_train, item_id, 'output_train')

### Define the network hyperparameters

In [None]:
# --- what is optimised ---
loss = 'mean_squared_error'          # cost function minimised by the optimiser
metrics = ['mean_absolute_error']    # additional metric reported after each epoch
optimizer_type = Adam(learning_rate=1e-5)  # Adam optimiser with a fixed learning rate

# --- how the training data is fed ---
batch_size = 20                # number of samples per training batch
batch_shuffle = True           # reshuffle the training data before each epoch
validtrain_split_ratio = 0.2   # fraction (0.2 = 20%) of the seen data held back for validation
max_epochs = 150               # maximum number of epochs to run

### Define the network architecture

* using the Keras *functional* API


In [None]:
# Encoder-decoder (U-Net type) graph built with the Keras functional API.
#
# Encoder: five stages of [double convolution -> 2x2 max pooling] with
#          8, 16, 32, 64 and 128 filters; each pooling halves the spatial size
#          (256 -> 128 -> 64 -> 32 -> 16 -> 8).
# Center:  one double convolution with 128 filters at the coarsest resolution.
# Decoder: five stages of [2x upsampling -> double convolution] with
#          128, 64, 32, 16 and 8 filters, back to 256 x 256.
# Output:  1x1 convolution to a single channel with ReLU (non-negative output).

# Skip connections are disabled by default, which reproduces the network as it
# was defined before. Set to True to concatenate each encoder stage's features
# onto the decoder stage of the same resolution (this changes the number of
# trainable parameters).
use_skip_connections = False

input_shape = (256, 256, 1)
inputs = Input(shape=input_shape)


def _double_conv(tensor, n_filters):
  """Apply twice: 3x3 'same' convolution -> batch normalisation -> ReLU."""
  for _ in range(2):
    tensor = Conv2D(n_filters, (3, 3), padding='same')(tensor)
    # normalises the activations over the batch (this is not a thresholding step)
    tensor = BatchNormalization()(tensor)
    # element-wise ReLU: negative values are set to zero
    tensor = Activation('relu')(tensor)
  return tensor


def _downsample(tensor):
  """Halve the spatial size by taking the maximum over non-overlapping 2x2 windows."""
  return MaxPooling2D((2, 2), strides=(2, 2))(tensor)


def _decoder_stage(tensor, encoder_features, n_filters):
  """Upsample by 2, optionally concatenate the encoder features, then apply the double convolution."""
  tensor = UpSampling2D((2, 2))(tensor)
  if use_skip_connections:
    tensor = concatenate([encoder_features, tensor], axis=3)
  return _double_conv(tensor, n_filters)


# encoder section

down0 = _double_conv(inputs, 8)
down0_pool = _downsample(down0)

down1 = _double_conv(down0_pool, 16)
down1_pool = _downsample(down1)

down2 = _double_conv(down1_pool, 32)
down2_pool = _downsample(down2)

down3 = _double_conv(down2_pool, 64)
down3_pool = _downsample(down3)

down4 = _double_conv(down3_pool, 128)
down4_pool = _downsample(down4)

# center section

center = _double_conv(down4_pool, 128)

# decoder section (each stage is paired with the encoder stage of the same resolution)

up4 = _decoder_stage(center, down4, 128)
up3 = _decoder_stage(up4, down3, 64)
up2 = _decoder_stage(up3, down2, 32)
up1 = _decoder_stage(up2, down1, 16)
up0 = _decoder_stage(up1, down0, 8)

outputs = Conv2D(1, (1, 1), activation='relu')(up0)

### Compile the network

In [None]:
print()
print('* Compiling the network model *')
print()

# Assemble the layer graph into a trainable model and attach the optimiser,
# loss and metric chosen in the hyperparameter cell.
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=optimizer_type, loss=loss, metrics=metrics)

# display a summary of the compiled neural network
# (summary() prints the table itself and returns None, so wrapping it in
# print() would add a stray "None" line to the output)

model.summary()
print()

### Train the neural network with the training dataset

In [None]:
print('* Training the compiled network *')
print()

# Training options come from the hyperparameter cell; part of the training
# data is held back for validation according to validtrain_split_ratio.
_fit_options = dict(
    batch_size=batch_size,
    epochs=max_epochs,
    validation_split=validtrain_split_ratio,
    shuffle=batch_shuffle,
)

# `history` keeps the per-epoch loss/metric values used by the plotting cell
history = model.fit(input_train, output_train, **_fit_options)

print()
print('Training completed')
print()

### Plot the training history of the network

In [None]:
from google.colab import files


def _plot_history(training_curve, validation_curve, title, y_label):
  """Plot the per-epoch training and validation curves on one figure, show it, then close it."""
  plt.plot(training_curve)
  plt.plot(validation_curve)
  plt.title(title)
  plt.ylabel(y_label)
  plt.xlabel('Epoch')
  plt.legend(['Training', 'Validation'], loc='best')
  plt.show()
  # To export a figure: plt.savefig("MSE.png") / files.download("MSE.png")
  # (use "MAE.png" for the metric plot).
  # NOTE(review): savefig probably needs to run before plt.show() to capture the figure - confirm.
  plt.close()


# model loss

_plot_history(
    history.history['loss'],
    history.history['val_loss'],
    'Model loss : ' + loss,
    'Loss',
)

# model accuracy metric

_plot_history(
    np.array(history.history[metrics[0]]),
    np.array(history.history['val_' + metrics[0]]),
    'Model accuracy metric : ' + metrics[0],
    'Accuracy metric',
)

### Evaluate the trained network performance on the unseen test dataset

In [None]:
print('* Evaluating the performance of the trained network on the unseen test dataset *')
print()

# The first returned value is the loss, the second is the metric
# configured in `metrics`.
evaluate_model = model.evaluate(x=input_test, y=output_test)
loss_metric, accuracy_metric = evaluate_model[:2]

print()
for _prefix, _name, _value in (
    ('Accuracy - ', metrics[0], accuracy_metric),
    ('Loss - ', loss, loss_metric),
):
  print(_prefix + _name + ': %0.3f' % _value)

### Predict the output of a given input

In [None]:
def _show_image(image, title):
  """Display a single 2-D image with the given title and without axis ticks."""
  plt.imshow(image, cmap='viridis')
  plt.title(title)
  plt.grid(None)
  plt.xticks([])
  plt.yticks([])
  plt.show()


print('* Predicting the output of a given input from test set *')
print()

test_id = 21

# The network expects a batch, so the chosen test image is copied into a
# batch of size one with shape (1, 256, 256, 1).
input_predict = np.zeros(shape=(1, 256, 256, 1))
input_predict[0, :, :, 0] = input_test[test_id, :, :, 0]

output_predict = model.predict(input_predict)

print('test_id : ', test_id)
print()

# plot prediction example from test set: network input, network output, ground truth
_panels = (
    (input_test[test_id, :, :, 0], 'input_test [' + str(test_id) + ']'),
    (output_predict[0, :, :, 0], 'output_predict'),
    (output_test[test_id, :, :, 0], 'output_test [' + str(test_id) + ']'),
)
for _position, (_image, _title) in enumerate(_panels):
  if _position > 0:
    print()  # blank line between consecutive figures
  _show_image(_image, _title)



In [None]:
# Save some example reconstructions from the test set.
# For each selected test index the network input, the prediction and the ground
# truth are plotted and collected along the last axis of three
# (1, 256, 256, n_examples) arrays, which are then written to a .mat file.

test_idss=[3, 8, 21, 42, 31, 66, 100, 94, 52, 77, 111, 69, 6] # 31 is the choice for the paper

# The number of examples follows the index list (it used to be hard-coded to 13
# in several places, which had to be kept in sync with the list by hand).
n_examples = len(test_idss)

input_predict_all=np.zeros((1,256,256,n_examples))
output_predict_all=np.zeros((1,256,256,n_examples))
output_test_all=np.zeros((1,256,256,n_examples))

# Per-example mean absolute error and mean squared error.
# These stay in memory only; they are not part of the saved .mat file.
maee=np.zeros(shape=(1, n_examples))
msee=np.zeros(shape=(1, n_examples))

print('* Predicting the output of a given input from test set *')
print()

for ii, test_id in enumerate(test_idss):

  input_predict = np.zeros(shape=(1, 256, 256, 1))  # batch of size one, as required for the network input

  input_predict[0, :, :, 0] = input_test[test_id, :, :, 0]  # copy the selected test image into the batch

  output_predict = model.predict(input_predict)


  print('test_id : ', test_id)
  print()

  # plot prediction example from test set

  plt.imshow(input_predict[0, :, :, 0], cmap='viridis')
  plt.title('input_test [' + str(test_id) + ']')
  plt.grid(None)
  plt.xticks([])
  plt.yticks([])
  plt.show()

  print()

  plt.imshow(output_predict[0, :, :, 0], cmap='viridis')
  plt.title('output_predict')
  plt.grid(None)
  plt.xticks([])
  plt.yticks([])
  plt.show()

  print()

  plt.imshow(output_test[test_id, :, :, 0], cmap='viridis')
  plt.title('output_test [' + str(test_id) + ']')
  plt.grid(None)
  plt.xticks([])
  plt.yticks([])
  plt.show()

  # Individual MAE and MSE values for the given example.
  # np.mean averages over all pixels, so the image size is not hard-coded.
  error = output_predict[0, :, :, 0] - output_test[test_id, :, :, 0]
  maee[0,ii]=np.mean(np.abs(error))
  msee[0,ii]=np.mean(error**2)

  input_predict_all[0,:,:,ii]=input_predict[0, :, :, 0]
  output_predict_all[0,:,:,ii]=output_predict[0, :, :, 0]
  output_test_all[0,:,:,ii]=output_test[test_id, :, :, 0]


sio.savemat('/content/drive/MyDrive/Giulia waveguides/Paper Rev/figs_30e-6.mat', {'input_predict_all':input_predict_all,'output_predict_all':output_predict_all,'output_test_all':output_test_all})

  