In [107]:
import numpy as np
from matplotlib import pyplot as plt
import tensorflow.compat.v1 as tf

plt.style.use('dark_background')


## visualization of the latent space

In [90]:
def check_latent_space(model_dir):
    # load the model with graph
    sess = tf.Session(graph=tf.Graph())
    MODEL_LOADED=tf.saved_model.loader.load(sess, ["serve"], model_dir)
    graph = sess.graph

    INPUT_S = graph.get_tensor_by_name('input_SENSOR:0')
    INPUT_Y = graph.get_tensor_by_name('input_Y:0')
    INPUT_X = graph.get_tensor_by_name('input_X:0')
    output_u = graph.get_tensor_by_name('output_u:0')
    output_para_net = graph.get_tensor_by_name('output_para_net:0')
    latent_rep = graph.get_tensor_by_name('latent_rep:0')

    # load training data
    tmp = np.load('./DATA/rt-full-amr-siren-with-sensor-train-n.npz')  # 35 columns (sensor-1, ..., sensor-32, x,y,u)
    TRAIN_DATA = tmp['data']
    std_tr = tmp['std']
    mean_tr = tmp['mean'] 

    # load testing data

    data = np.load('./DATA/rt-256x512-cnn.npz')
    TRAIN_SENSOR_DATA = data['train_sensor_data']
    TEST_SENSOR_DATA = data['test_sensor_data']
    TRAIN_CART_DATA = data['train_data']
    TEST_CART_DATA = data['test_data']

    xx, yy = np.meshgrid(np.linspace(2.5e-7,0.5,256,endpoint=True), np.linspace(5e-7, 1, 512,endpoint=True))  # size (512,256)
    xx = xx.reshape(-1,1)
    yy = yy.reshape(-1,1)

    # normalize xx,yy,SENSOR_DATA
    xx_ = (xx - mean_tr[32])/std_tr[32]
    yy_ = (yy - mean_tr[33])/std_tr[33]
    TRAIN_SENSOR_DATA_ = (TRAIN_SENSOR_DATA - mean_tr[:32])/std_tr[:32]
    TEST_SENSOR_DATA_  = (TEST_SENSOR_DATA  - mean_tr[:32])/std_tr[:32]

#     print('TRAIN DATA SHAPE = ',TRAIN_CART_DATA.shape)
#     print('TEST DATA SHAPE = ',TEST_CART_DATA.shape)

    # evaluating all training
    ## compute the latent space
    lat_rep = sess.run(latent_rep, feed_dict={INPUT_S: TRAIN_SENSOR_DATA_})
    return lat_rep

In [91]:
## DIM = 1
model_dir = 'NIF/new-umich/formal_weight_1e-2' + '/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_1_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_600'
lat_rep_1 = check_latent_space(model_dir)

## DIM = 2
model_dir = 'NIF/new-umich/formal_weight_1e-2' + '/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_2_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_600'
lat_rep_2 = check_latent_space(model_dir)
# plt.plot(lat_rep[:,0],lat_rep[:,1],'k-o')

## DIM = 4
model_dir = 'NIF/new-umich/formal_weight_1e-2' + '/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_4_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_400'
lat_rep_4 = check_latent_space(model_dir)

## DIM = 8
model_dir = 'NIF/new-umich/formal_weight_1e-2' + '/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_8_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_200'
lat_rep_8 = check_latent_space(model_dir)

INFO:tensorflow:Restoring parameters from NIF/new-umich/formal_weight_1e-2/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_1_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_600/variables/variables
INFO:tensorflow:Restoring parameters from NIF/new-umich/formal_weight_1e-2/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_2_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_600/variables/variables
INFO:tensorflow:Restoring parameters from NIF/new-umich/formal_weight_1e-2/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_4_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_400/variables/variables
INFO:tensorflow:Restoring parameters from NIF/new-umich/formal_weight_1e-2/RT_NIF_SIREN_NSX_128_LSX_2_NST_64_LST_2_NP_8_NSENSOR_32_ACTREG_1.0/saved_model_ckpt_200/variables/variables


## plot latent space

In [92]:
index_arr = np.arange(84)
data_index_arr = index_arr*2

data_index_str_0 = '0000'

for i in index_arr:
    data_index_str = data_index_str_0 + str(i)
    data_index_str = data_index_str[-4:]
    # print(data_index_str)

0000
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083


In [108]:
from matplotlib.offsetbox import TextArea, DrawingArea, OffsetImage, AnnotationBbox
import matplotlib.image as mpimg


# index_arr = np.arange(3) #
index_arr =  np.arange(84)
data_index_arr = index_arr*2

data_index_str_0 = '0000'

for i,index in enumerate(data_index_arr):
    print('i = {}, index = {}'.format(i,index))
    tmp = data_index_str_0 + str(index)
    tmp = tmp[-4:]
    
    fig,axs = plt.subplots(1,3,figsize=(8,4))
    
    # fig 1 - original AMR
    arr_amr = mpimg.imread('DATA/animation/a.'+tmp+'.png')
    imagebox = OffsetImage(arr_amr,zoom=0.45)
    ab = AnnotationBbox(imagebox, (0.5, 0.5),pad=0.0,frameon=False)
    axs[0].add_artist(ab)

    # fig 2 - latent representation
    lat_rep = lat_rep_1 
    axs[1].set_xlim([-1,85])
    axs[1].set_ylim([-0.06,0.06])
    for jj in range(lat_rep.shape[1]):
        axs[1].plot(lat_rep[:i+1,jj],'ro-',label=r'$x_'+str(jj+1)+'$',markersize=3)

    # fig 3 - reconstructed AMR (since almost no difference, so I just to training data to illustrate)
    arr_amr = mpimg.imread('DATA/animation/b.'+tmp+'.png')
    imagebox = OffsetImage(arr_amr,zoom=0.45)
    ab = AnnotationBbox(imagebox, (0.5, 0.5),pad=0.0,frameon=False)
    axs[2].add_artist(ab)

    axs[0].axis('off')
    axs[1].axis('off')
    axs[2].axis('off')
    
    axs[0].set_title('spatio-temporal field',c='w',pad=30)
    axs[1].set_title('latent space',c='w',pad=30)
    axs[2].set_title('reconstructed field',c='w',pad=30)

    fig.tight_layout()
    plt.savefig('pngs/animation/'+str(i)+'.png',bbox_inches='tight')
    plt.close()
    


# plt.xlabel('time index')
# plt.ylabel('latent variables')
# plt.legend(bbox_to_anchor=(1.05,1))
# plt.savefig('./pngs/latent.png',bbox_inches='tight')

i = 0, index = 0
i = 1, index = 2
i = 2, index = 4
i = 3, index = 6
i = 4, index = 8
i = 5, index = 10
i = 6, index = 12
i = 7, index = 14
i = 8, index = 16
i = 9, index = 18
i = 10, index = 20
i = 11, index = 22
i = 12, index = 24
i = 13, index = 26
i = 14, index = 28
i = 15, index = 30
i = 16, index = 32
i = 17, index = 34
i = 18, index = 36
i = 19, index = 38
i = 20, index = 40
i = 21, index = 42
i = 22, index = 44
i = 23, index = 46
i = 24, index = 48
i = 25, index = 50
i = 26, index = 52
i = 27, index = 54
i = 28, index = 56
i = 29, index = 58
i = 30, index = 60
i = 31, index = 62
i = 32, index = 64
i = 33, index = 66
i = 34, index = 68
i = 35, index = 70
i = 36, index = 72
i = 37, index = 74
i = 38, index = 76
i = 39, index = 78
i = 40, index = 80
i = 41, index = 82
i = 42, index = 84
i = 43, index = 86
i = 44, index = 88
i = 45, index = 90
i = 46, index = 92
i = 47, index = 94
i = 48, index = 96
i = 49, index = 98
i = 50, index = 100
i = 51, index = 102
i = 52, index = 104
i = 5