In [14]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
import h5py
import glob
import math

import numpy             as np
import matplotlib.pyplot as plt
import tensorflow        as tf

from matplotlib.colors     import LogNorm
from sklearn.preprocessing import MinMaxScaler

In [15]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Conv2D, Conv2DTranspose, Dropout, Flatten, Reshape, Lambda, Layer, LeakyReLU, PReLU
from tensorflow.keras.layers import AveragePooling2D, BatchNormalization, Activation
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
#tf.keras.backend.set_floatx('float16')                                                                                                                                                                                    

In [16]:
%matplotlib inline

In [17]:
! ls -lrth input_qcd_v2/*.h5
! ls -lrth input_suep_v2/*.h5

-rw-r--r--  1 simranjitsinghchhibra  staff   7.1M May  1 17:50 input_qcd_v2/qcd_gensim_101_0.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.3M May  1 17:50 input_qcd_v2/qcd_gensim_101_1.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.1M May  1 17:50 input_qcd_v2/qcd_gensim_101_2.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.4M May  1 17:50 input_qcd_v2/qcd_gensim_101_3.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.2M May  1 17:50 input_qcd_v2/qcd_gensim_101_4.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.2M May  1 17:50 input_qcd_v2/qcd_gensim_101_5.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.2M May  1 17:50 input_qcd_v2/qcd_gensim_101_6.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.3M May  1 17:50 input_qcd_v2/qcd_gensim_101_7.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.2M May  1 17:50 input_qcd_v2/qcd_gensim_101_8.h5
-rw-r--r--  1 simranjitsinghchhibra  staff   7.4M May  1 17:50 input_qcd_v2/qcd_gensim_101_9.h5
-rw-r--r--  1 simranjitsinghchhibra  sta

In [18]:
#GENERATOR 

def get_file_list(path):
    flist = []
    flist += glob.glob(path + '/' + '*101_0*.h5')
    flist.sort()
    print("flist: ", flist)
    return flist

def read_images_from_file(fname):
    print("Appending %s" %fname)
    with h5py.File(fname,'r') as f:
        ImageTrk  = np.array(f.get("ImageTrk_PUcorr")[:10], dtype=np.float16)
        ImageTrk  = ImageTrk.reshape(ImageTrk.shape[0], ImageTrk.shape[1], ImageTrk.shape[2], 1)

        ImageECAL = np.array(f.get("ImageECAL")[:10], dtype=np.float16)
        ImageECAL = ImageECAL.reshape(ImageECAL.shape[0], ImageECAL.shape[1], ImageECAL.shape[2], 1)

        ImageHCAL = np.array(f.get("ImageHCAL")[:10], dtype=np.float16)
        ImageHCAL = ImageHCAL.reshape(ImageHCAL.shape[0], ImageHCAL.shape[1], ImageHCAL.shape[2], 1)

        Image3D = np.concatenate([ImageTrk, ImageECAL, ImageHCAL], axis=-1)

        Image3D_zero = np.zeros((Image3D.shape[0], 288, 360, 3), dtype=np.float16)
        Image3D_zero[:, 1:287, :, :] += Image3D
        Image3D_zero = np.divide(Image3D_zero, 2000., dtype=np.float16)
        return Image3D_zero

def concatenate_by_file_content(Image3D, fname):
    Image3D_tmp = read_images_from_file(fname)
    Image3D = np.concatenate([Image3D, Image3D_tmp], axis=0) if Image3D.size else Image3D_tmp
    return Image3D

def gen(parts_n, pathindex):

    if   (pathindex == 0): flist = get_file_list("input_qcd_v2")
    elif (pathindex == 1): flist = get_file_list("input_qcd_v2"  )

    Image3D_conc = np.array([])
    
    for i_file, fname in enumerate(flist):
        Image3D_conc = concatenate_by_file_content(Image3D_conc, fname)

        while (len(Image3D_conc) >= parts_n):
            Image3D_part, Image3D_conc = Image3D_conc[:parts_n] , Image3D_conc[parts_n:]
            print (" K.sum ", np.sum(Image3D_part[0]))                                                                                                                                                                     
            #print (" ",Image3D_part.shape, Image3D_conc.shape)
            
            Image3D_part_tf = tf.convert_to_tensor(Image3D_part, dtype=tf.float16)
            yield (Image3D_part_tf, Image3D_part_tf)


parts = 1
dataset_train = tf.data.Dataset.from_generator(
    gen,
    args=[parts, 0],
    output_types=(tf.float16, tf.float16))

dataset_val = tf.data.Dataset.from_generator(
    gen,
    args=[parts, 1],
    output_types=(tf.float16, tf.float16))

In [19]:
#MODEL

img_rows = 288
img_cols = 360
img_chns = 3

############ENCODER                                                                                                                                                                                                        
inputImage = Input(shape=(img_rows, img_cols, img_chns))
x1 = Conv2D(128, (3,3), strides=(3, 3), padding="same")(inputImage)
x2 = BatchNormalization()(x1)
x3 = PReLU()(x2)
x4 = Conv2D(64, (3,3), strides=(2, 2), padding="same")(x3)
x5 = BatchNormalization()(x4)
x6 = PReLU()(x5)
x7 = Conv2D(32, (3,3), strides=(2, 2), padding="same")(x6)
x8 = BatchNormalization()(x7)
x9 = PReLU()(x8)
x10 = Conv2D(16, (3,3), strides=(2, 2), padding="same")(x9)
x11 = BatchNormalization()(x10)
x12 = PReLU()(x11)
x13 = Conv2D(8, (3,3), strides=(2, 3), padding="same")(x12)
x14 = BatchNormalization()(x13)
encoder_output = PReLU()(x14)

############DECODER                                                                                                                                                                                                       \
                                                                                                                                                                                                                           
x15 = Conv2DTranspose(16, (3,3), strides=(2, 3), padding="same")(encoder_output)
x16 = BatchNormalization()(x15)
x17 = PReLU()(x16)
x18 = Conv2DTranspose(32, (3,3), strides=(2, 2), padding="same")(x17)
x19 = BatchNormalization()(x18)
x20 = PReLU()(x19)
x21 = Conv2DTranspose(64, (3,3), strides=(2, 2), padding="same")(x20)
x22 = BatchNormalization()(x21)
x23 = PReLU()(x22)
x24 = Conv2DTranspose(128, (3,3), strides=(2, 2), padding="same")(x23)
x25 = BatchNormalization()(x24)
x26 = PReLU()(x25)
x27 = Conv2DTranspose(img_chns, (3,3), strides=(3, 3), padding="same")(x26)
x28 = BatchNormalization()(x27)
output = Activation('relu')(x28)

model = Model(inputs=inputImage, outputs=output)
encoder = Model(inputs=inputImage, outputs=encoder_output)

In [20]:
model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 288, 360, 3)]     0         
                                                                 
 conv2d_5 (Conv2D)           (None, 96, 120, 128)      3584      
                                                                 
 batch_normalization_10 (Bat  (None, 96, 120, 128)     512       
 chNormalization)                                                
                                                                 
 p_re_lu_9 (PReLU)           (None, 96, 120, 128)      1474560   
                                                                 
 conv2d_6 (Conv2D)           (None, 48, 60, 64)        73792     
                                                                 
 batch_normalization_11 (Bat  (None, 48, 60, 64)       256       
 chNormalization)                                          

In [21]:
LEARNING_RATE = 0.0005 #2 times smaller than default 0.001                                                                                                                                                                 
n_epochs = 1
batch_size = parts
opt = Adam(learning_rate = LEARNING_RATE)

#tf.config.run_functions_eagerly(True)

def DiceLoss(y_true, y_pred, smooth=1e-6):
    dice = []
    for i in range(parts):
        y_true_tmp = tf.reshape(y_true[i], shape=(1, (288*360*3)))
        y_pred_tmp = tf.reshape(y_pred[i], shape=(1, (288*360*3)))

        idx_keep_in = tf.where(y_true_tmp[0,:]!=0)[:,-1]
        y_true_tmp  = tf.gather(y_true_tmp[0,:], idx_keep_in)
        y_pred_tmp  = tf.gather(y_pred_tmp[0,:], idx_keep_in)

        y_true_tmp = tf.reshape(y_true_tmp, shape=(1, y_true_tmp.shape[0]))
        y_pred_tmp = tf.reshape(y_pred_tmp, shape=(y_pred_tmp.shape[0], 1))

        intersection = K.sum(K.dot(y_true_tmp, y_pred_tmp))
        dice.append(1 - (2 * intersection + smooth) / (K.sum(y_true[i]) + K.sum(y_pred[i]) + smooth))
    return dice

def intersection(targets, inputs):

    inputs = tf.reshape(inputs, shape=((parts*288*360*3), 1))
    targets = tf.reshape(targets, shape=(1, (parts*288*360*3)))

    return K.sum(K.dot(targets, inputs))

def sum_y_true(y_true, y_pred):
#    return K.sum(y_true, axis=[1,2,3])
    return K.sum(y_true)

def sum_y_pred(y_true, y_pred):
    return K.sum(y_pred, axis=[1,2,3])

In [22]:
model.compile(optimizer=opt, loss = DiceLoss, metrics=[intersection, sum_y_true, sum_y_pred]) 

In [None]:
history = model.fit(dataset_train, epochs=n_epochs, initial_epoch = 0, batch_size=batch_size,# steps_per_epoch=100,                                                                                                        
                    validation_data=dataset_val,
                    callbacks = [
                        EarlyStopping(monitor='val_loss', patience=5, verbose=1, min_delta=0.0001),
                        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1),
                        TerminateOnNaN()])