In [None]:
#%env XLA_PYTHON_CLIENT_ALLOCATOR=platform 
#%env XLA_PYTHON_CLIENT_PREALLOCATE=false
#%env XLA_PYTHON_CLIENT_MEM_FRACTION=.50
import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")
import netCDF4 as nc4
import jax
from jax import lax, random, numpy as jnp
import flax
import numpy as np
from flax import linen as nn
import matplotlib.pyplot as plt
from jax.numpy.fft import fft2, ifft2
from flax import serialization

from CNN.model import PeriodicCNN
from CBAM.model import ResNet_CBAM
from Nonlocal.model import NonLocalCNN

import sys
sys.path.append('..')
from bvex_dl import *
from namelist_dl import *
import time as Time
from typing import Union, Any
# Type alias for array arguments in the metric helpers' type hints.
# `jnp.DeviceArray` was removed in newer JAX releases (superseded by `jax.Array`),
# so prefer `jax.Array` when available and fall back for older JAX versions.
Array = Union[np.ndarray, getattr(jax, 'Array', None) or jnp.DeviceArray]

In [None]:
@jax.jit
def RMSE(
    array1: Array, 
    array2: Array,
) -> Array:
    """Root-mean-square error between two arrays of the same shape.

    Computed as ``||array1 - array2||_2 / sqrt(N)`` where ``N`` is the total
    number of elements, i.e. ``sqrt(mean((array1 - array2)**2))``.

    Works for inputs of any rank. For 2-D inputs this equals the previous
    ``len(a) * len(a[0])`` element count.

    Args:
        array1: First array.
        array2: Second array, same shape as ``array1``.

    Returns:
        Scalar RMSE.
    """
    # Static under jit: the shape (and hence size) is known at trace time.
    n_elements = jnp.size(array1)
    return jnp.linalg.norm(array1 - array2) / jnp.sqrt(n_elements)



def R2(
    array1: Array, 
    array2: Array,
) -> Array:
    """Pearson correlation coefficient between two fields.

    Both inputs are flattened and handed to ``np.corrcoef``. The off-diagonal
    entry of the resulting 2x2 correlation matrix is returned.

    Note: despite the name, this is the correlation coefficient r, not the
    coefficient of determination R^2. It is not squared and can be negative.
    """
    corr_matrix = np.corrcoef(np.ravel(array1), np.ravel(array2))
    return corr_matrix[0, 1]

In [None]:
@jax.vmap
def MSE(x, y):
    """Per-sample mean squared error.

    ``jax.vmap`` maps over axis 0 of both inputs, so the result holds one
    mean-squared-error value per leading-axis entry.
    """
    squared_error = jnp.square(x - y)
    return squared_error.mean()


def get_initial_params(key, model, input_shape=(1, 64, 64, 3), dtype=jnp.float32):
    """Initialise a model and return its ``'params'`` collection.

    Args:
        key: JAX PRNG key used for parameter initialisation.
        model: Model class; it is instantiated here with no arguments and
            its ``init(key, x)`` is called on a dummy input.
        input_shape: Shape of the dummy all-ones input used to trace the model.
            The default is the originally hard-coded ``(1, 64, 64, 3)``. The
            3 channels match the (q, p, f) stack fed to the network in
            ``online_test``.
        dtype: dtype of the dummy input (default ``float32``).

    Returns:
        The ``'params'`` entry of the initialised variables.
    """
    dummy_input = jnp.ones(input_shape, dtype)
    variables = model().init(key, dummy_input)
    return variables['params']


In [None]:
def online_test(ic, params, t_max, model_cls=None):
    """Roll out the NN-corrected solver from an initial condition.

    Each step does the following:
      1. Advance the vorticity with ``vetdrk4``.
      2. Round the time to two decimals.
      3. Build the (q, p, f) input state from ``laplacian`` and ``cal_forcing``.
      4. Add the network's predicted correction to the vorticity.

    The first step is always taken. Stepping then continues while the first
    batch member's time is ``<= t_max``.

    Args:
        ic: dict with ``'zeta'`` (batched initial vorticity) and ``'time'``
            (per-batch-member start times).
        params: Parameter tree for the correction network.
        t_max: End time of the rollout, compared against ``time[0]``.
        model_cls: Model class to instantiate for inference. Defaults to the
            notebook-global ``model`` (the original implicit dependency).

    Returns:
        ``(field, time, nn_opt)``: dicts keyed by step index.
          - ``field[0]`` and ``time[0]`` hold the initial condition.
          - ``nn_opt`` (the raw network output) starts at index 1.
    """
    if model_cls is None:
        model_cls = model  # notebook-global set in the checkpoint-loading cell

    @jax.jit
    def inference(variables, state):
        return model_cls().apply(variables, state)

    variables = {'params': params}

    def _step(q, t):
        """Advance one solver step and apply the NN correction."""
        q_new, t_new = vetdrk4(q, t)
        t_new = jnp.around(t_new, decimals=2)
        p_new, _, _ = laplacian(q_new)
        f_new = jax.vmap(cal_forcing)(p_new, t_new)
        state = jnp.stack((q_new, p_new, f_new), axis=-1)
        correction = inference(variables, state)
        return q_new + correction.squeeze(), t_new, correction

    field = {0: ic['zeta']}
    time = {0: ic['time']}
    nn_opt = {}

    q, t = ic['zeta'], ic['time']
    idx = 0
    # do-while: the first step is unconditional; afterwards keep stepping
    # while the first member's time has not passed t_max.
    while True:
        q, t, correction = _step(q, t)
        idx += 1
        field[idx] = q
        time[idx] = t
        nn_opt[idx] = correction
        if not (t[0] <= t_max):
            break

    return field, time, nn_opt

In [None]:
# Registry of evaluable models. Each entry records three things:
#   'model_path' - checkpoint sub-directory
#   'model_name' - checkpoint file prefix
#   'model'      - model class to instantiate
# The plain periodic-CNN variants all live under 'CNN' and use their key as the prefix.
model_dict = {
    key: {'model_path': 'CNN', 'model_name': key, 'model': PeriodicCNN}
    for key in ('CNN8', 'CNN1', 'CNN16', 'CNN24', 'CNN32')
}
model_dict.update({
    'CBAM': {'model_path': 'CBAM', 'model_name': 'CBAM', 'model': ResNet_CBAM},
    'CNN16_WC': {'model_path': 'CNN', 'model_name': 'CNN16_WC', 'model': PeriodicCNN},
    'Nonlocal': {'model_path': 'Nonlocal', 'model_name': 'NLCNN', 'model': NonLocalCNN},
})

In [None]:
# Which registry entry to evaluate, and which training-epoch checkpoint to load.
model_name, test_epoch = 'CNN16', 100

In [None]:
# Resolve the selected registry entry.
path_dict = model_dict[model_name]
model = path_dict['model']

# Restore the trained weights. NOTE: allow_pickle=True unpickles the file,
# which can execute arbitrary code; only load trusted checkpoints.
loaded = np.load(f'/workspace/yquai/BVEX/DL/DL_Model/{path_dict["model_path"]}/checkpoint/{path_dict["model_name"]}_state_dict_epoch_{test_epoch}.npy', allow_pickle=True).item()

# Use a freshly initialised parameter tree as the structural template, then
# overwrite its leaves with the checkpoint values.
rng, init_rng = jax.random.split(jax.random.PRNGKey(2021))
params = serialization.from_state_dict(get_initial_params(init_rng, model), loaded)

In [None]:
# Evaluate one test batch:
#   1. Roll out the NN-corrected solver from the stored initial conditions.
#   2. Save the vorticity snapshots.
#   3. Score them against the low-resolution truth with RMSE and correlation (R2).
SAVE_EVERY = 10  # keep every 10th solver step
N_SAVED = 61     # snapshots at steps 0, 10, ..., 600

test_idx = 1
print(test_idx)

# Read the initial conditions. Arrays are copied out so the file can be closed.
with nc4.Dataset(f'/workspace/yquai/BVEX/Data/Test/TestIC_R64_Batch_{test_idx}.nc') as icFile:
    # Swap the last two axes; the inverse swap is applied before saving below.
    zeta = np.transpose(np.copy(icFile['zeta'][:]).astype('float32'), [0, 2, 1])
    time = np.copy(icFile['time'][:]).astype('float32')
ic = {'zeta': zeta, 'time': time}
t_max = ic['time'][0] + 30  # integrate 30 time units past the first member's start
batch_size = len(ic['time'])

print("integration...")

field, time, nn_opt = online_test(ic, params, t_max)

# Collect every SAVE_EVERY-th step into (batch, N_SAVED, *grid).
array = np.zeros((batch_size, N_SAVED) + ic['zeta'].shape[1:])
print("saving zeta...")
for k in range(N_SAVED):
    array[:, k, :, :] = field[k * SAVE_EVERY]
# Undo the axis swap applied to the initial condition on load.
array = np.transpose(array, [0, 1, 3, 2])


# save zeta
np.save(f'/workspace/yquai/BVEX/LowRes_HighRes_Test/ZETA_from1024/{model_name}_epoch_{test_epoch}_idx_{test_idx}.npy', array)


print("calculating R2 and RMSE...")
# calculate R2 and RMSE
with nc4.Dataset(f'/workspace/yquai/BVEX/Data/Test/Truth_64_{test_idx}.nc') as truth_LowRes:
    zeta_truth = np.copy(truth_LowRes['zeta'][:]).astype('float32')
n_ensemble = zeta_truth.shape[0]
# Fail with a clear message instead of an IndexError deep in the loop.
assert n_ensemble <= batch_size and zeta_truth.shape[1] >= N_SAVED, (
    f"truth shape {zeta_truth.shape} incompatible with predictions {array.shape}")

rmse = np.zeros((n_ensemble, N_SAVED))
r2 = np.zeros((n_ensemble, N_SAVED))

for j in range(n_ensemble):
    for i in range(N_SAVED):
        rmse[j, i] = RMSE(zeta_truth[j, i], array[j, i])
        r2[j, i] = R2(zeta_truth[j, i], array[j, i])

print("saving R2 and RMSE...")
np.save(f'/workspace/yquai/BVEX/LowRes_HighRes_Test/RMSE/{model_name}_epoch_{test_epoch}_idx_{test_idx}.npy', rmse)
np.save(f'/workspace/yquai/BVEX/LowRes_HighRes_Test/R2/{model_name}_epoch_{test_epoch}_idx_{test_idx}.npy', r2)


In [None]:
# Display the shape of the (axis-swapped) initial vorticity batch.
np.shape(ic['zeta'])