In [1]:
import numpy as np
import tensorflow as tf
import time

In [None]:
#This notebook shows the structure of the network, the loss functions, and the training of our approach

In [2]:
#Load Data

# The .npy files are read from the current working directory. They can be downloaded from:
# https://drive.google.com/drive/folders/1IoxOtAt-8NiFgbtZh1RDY32Jb_Wd5TGa?usp=sharing

# Training split: signals are the network inputs, distributions are the targets.
signals_train, distributions_train = np.load('signals_train.npy'), np.load('distributions_train.npy')

# Validation split, passed to model.fit as validation_data.
signals_valid, distributions_valid = np.load('signals_valid.npy'), np.load('distributions_valid.npy')


In [3]:
import math

#Define Loss functions

# Support grid used by the Wasserstein loss: 60 log-spaced points from 10 to 2000
# (both endpoints included), i.e. one grid point per output unit of the network.
arr = np.logspace(
    math.log10(10.0),
    math.log10(2000.),
    num=60,
    endpoint=True,
    base=10.0,
)

# Replicate the grid row-wise into a (2000, 60) array; 2000 is also the batch_size
# passed to model.fit.
arr = np.repeat(arr[np.newaxis, :], 2000, axis=0)

# float32 TensorFlow constant holding the replicated grid.
arr_tf = tf.constant(arr.astype(np.float32), dtype=tf.float32)


#Implementation of the Wasserstein Distance
def wasserstein_distance(y_actual,y_pred):
    """Batch-averaged Wasserstein distance between target and predicted distributions.

    For every sample the absolute difference of the two cumulative sums along axis 1
    is integrated over the grid stored in ``arr_tf`` with the trapezoidal rule; the
    per-sample integrals are then averaged over the batch.

    Args:
        y_actual: target distributions, one row per sample, one column per grid point.
        y_pred: predicted distributions, same shape as ``y_actual``.

    Returns:
        Scalar tensor with the mean distance over the batch.
    """
    abs_cdf_difference = tf.math.abs(tf.math.cumsum(y_actual - y_pred, axis=1))

    # Use a single grid row (shape (1, 59)) so the spacing broadcasts against any
    # number of samples. Multiplying by the full replicated grid would tie this
    # function to batches whose row count equals the number of replicated rows.
    grid_spacing = arr_tf[:1, 1:] - arr_tf[:1, :-1]

    # Trapezoidal rule: 0.5 * spacing * (left value + right value), summed over intervals.
    endpoint_sums = abs_cdf_difference[:, :-1] + abs_cdf_difference[:, 1:]
    per_sample_distance = 0.5 * tf.reduce_sum(tf.math.multiply(grid_spacing, endpoint_sums), axis=1)

    return tf.reduce_mean(per_sample_distance)

#Combination loss function used in MIML
def MSE_wasserstein_combo(y_actual,y_pred):
    """Wasserstein distance plus a heavily weighted mean-squared error.

    Args:
        y_actual: target distributions, one row per sample.
        y_pred: predicted distributions, same shape as ``y_actual``.

    Returns:
        Scalar tensor ``wasserstein_distance(y_actual, y_pred) + 100000 * MSE``, where
        MSE is the squared difference averaged over axis 1 and then over the batch.
    """
    squared_error = tf.math.squared_difference(y_pred, y_actual)
    per_sample_mse = tf.reduce_mean(squared_error, axis=1)
    batch_mse = tf.math.reduce_mean(per_sample_mse)
    # The factor 100000 is the weight of the MSE term relative to the Wasserstein term.
    return 100000. * batch_mse + wasserstein_distance(y_actual, y_pred)

In [4]:
#Define the network structure
# Fully connected network: 32 inputs -> 6 hidden Dense(256, leaky ReLU) layers -> Dense(60, softmax).
inputs = tf.keras.Input(shape=(32,))

# Build the six identical hidden layers in sequence; each layer gets its own
# bias initializer instance, as a hand-written stack would.
x = inputs
for _hidden_idx in range(6):
    x = tf.keras.layers.Dense(
        256,
        activation=tf.nn.leaky_relu,
        kernel_initializer='he_uniform',
        bias_initializer=tf.keras.initializers.Constant(0.01),
    )(x)

# Softmax head: 60 non-negative outputs that sum to 1 for every sample.
outputs = tf.keras.layers.Dense(
    60,
    activation=tf.keras.activations.softmax,
    kernel_initializer='he_uniform',
    bias_initializer=tf.keras.initializers.Constant(0.01),
)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [5]:
#Define optimizer and train the network

# checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
#     'weights.{epoch:02d}-{val_loss:.2f}.hdf5', monitor='val_loss', verbose=0, save_best_only=False,
#     save_weights_only=True, mode='auto', save_freq='epoch'
# )


opt = tf.keras.optimizers.Adam(learning_rate=0.001)
# Optimize the combined loss; additionally report plain MSE and the Wasserstein term as metrics.
model.compile(optimizer=opt,
              loss=MSE_wasserstein_combo,
              metrics=['mse', wasserstein_distance])

# perf_counter() is a monotonic clock intended for measuring durations; time.time()
# follows the wall clock and can jump if the system time is adjusted mid-run.
start=time.perf_counter()

# Keep the returned History object so the per-epoch loss/metric values remain
# available after training instead of being discarded.
history = model.fit(signals_train, distributions_train,epochs=30, batch_size=2000, validation_data=(signals_valid,distributions_valid))  # starts training

#To save the model for each epoch, uncomment the checkpoint_callback above and use
#model.fit(signals_train, distributions_train,epochs=30, batch_size=2000, validation_data=(signals_valid,distributions_valid),callbacks=[checkpoint_callback])  # starts training


end=time.perf_counter()

print('Time Elapsed:%i seconds'%(end-start))

Train on 1120000 samples, validate on 140000 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Time Elapsed:71 seconds
