In [1]:
import torch

In [2]:
import h5py
import numpy as np
import time

from k_filters.estkf import *
from k_filters.estkf_torch import *
from k_filters.ensrf import *
from k_filters.ensrf_torch import *
from k_filters.senkf import *
from k_filters.senkf_torch import *

In [3]:
# Report which accelerator backends this PyTorch install can use, one boolean per line:
#   1. torch.backends.mps.is_available() - MPS usable right now (needs macOS 12.3+)
#   2. torch.backends.mps.is_built()     - this PyTorch build was compiled with MPS support
#   3. torch.cuda.is_available()         - a CUDA device is usable
# These calls only check and print; they do not enable or guarantee anything.
for backend_check in (
    torch.backends.mps.is_available,
    torch.backends.mps.is_built,
    torch.cuda.is_available,
):
    print(backend_check())

True
True
False


In [4]:
def stats(x, n_rows=1000):
    """
    Summarize a 2-D analysis array for side-by-side comparison.

    Parameters
    ----------
    x : np.ndarray
        2-D array (the filter outputs in this notebook have shape (3352598, 2)).
    n_rows : int, optional
        Number of leading rows whose covariance is computed (default 1000).
        ``np.cov`` treats each row as a variable, so the covariance of all rows
        would be an (n_rows_total x n_rows_total) matrix; only a leading block
        is computed.

    Returns
    -------
    tuple
        ``(x.shape, mean of each row over axis 1, sample covariance (ddof=1)
        of the rows of x[:n_rows, :])``.
    """
    return x.shape, np.mean(x, axis=1), np.cov(x[:n_rows, :], ddof=1)

def close(xa, xb):
    """
    Return True when two 2-D arrays have element-wise close row means.

    Each input is reduced over axis 1 (one mean per row) and the two mean
    vectors are compared with ``np.allclose`` at its default tolerances.
    """
    row_means_a = xa.mean(axis=1)
    row_means_b = xb.mean(axis=1)
    return np.allclose(row_means_a, row_means_b)

def timeit(f):
    """
    Time a single call to ``f._assimilate()``.

    The assimilation is run again just for timing; its result is discarded.

    Parameters
    ----------
    f : object
        Any object exposing a zero-argument ``_assimilate()`` method (here,
        the NumPy and torch filter instances).

    Returns
    -------
    tuple
        ``('Execution time:', elapsed_seconds, 'seconds')``, ready to print.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    f._assimilate()
    elapsed_time = time.perf_counter() - start

    return 'Execution time:', elapsed_time, 'seconds'

In [5]:
# Load Test Data
#
# Datasets are read by name rather than by position in f.keys(): positional
# indexing silently picks a different dataset if the file's layout changes,
# and f[name] raises KeyError instead of returning None like f.get().
# The `with` block closes the HDF5 file once the arrays are copied into memory.

filename = "sample_harvey_state3.hdf5"
with h5py.File(filename, "r") as f:
    print(f.keys())

    obs = np.array(f["observations"])
    zeta = np.array(f["zeta"])

print("wse dim = {} \n ".format(zeta.shape), 
      "obs dim = {}".format(obs.shape)
     )

<KeysViewHDF5 ['coords', 'observations', 'station_coords', 'station_node_inds', 'zeta']>
wse dim = (2, 3352598) 
  obs dim = (11,)


In [6]:
# zeta is stored as (2, 3352598) (see output above); transpose so the long
# axis is rows and the 2-wide axis is columns, matching the (3352598, 2)
# filter outputs below.
xf = zeta.T
Ne = xf.shape[1]  # number of columns; presumably the ensemble size
sigma = 0.5

Xf = xf
# NOTE(review): sigma fills the diagonal of R directly. If the filters expect
# observation-error variances and sigma is a standard deviation, this should
# be sigma**2 - confirm against the k_filters API.
R = sigma * np.ones(obs.size)
Y = obs
H = np.eye(obs.size)
# NOTE(review): HXf is a column slice of the identity, not H applied to Xf,
# so it does not depend on the forecast at all. Taking every 6th column of an
# obs.size-square identity yields exactly Ne columns only for this particular
# obs.size / Ne combination; guard that coincidence so a different input
# fails loudly instead of feeding a mis-shaped HXf to the filters.
HXf = H[:, ::6]
assert HXf.shape == (obs.size, Ne), (
    f"HXf placeholder has shape {HXf.shape}, expected {(obs.size, Ne)}"
)

params = [Xf, HXf, Y, R]

In [7]:
# Run the NumPy ESTKF (k_filters.estkf) and its torch version ESTKFT
# (k_filters.estkf_torch) on the same `params` list.
# The order (construct, assimilate, construct, assimilate) is kept as-is.
# NOTE(review): confirm whether either filter mutates the shared input arrays,
# since that would make the second run depend on the first.
# NOTE(review): `_assimilate` is a private method called from outside the class;
# check whether k_filters exposes a public entry point.
filter_estkf = ESTKF(*params)
x1 = filter_estkf._assimilate()
torch_estkf = ESTKFT(*params)
x2 = torch_estkf._assimilate()

In [8]:
# Compare the NumPy (x1) and torch (x2) ESTKF results: summary statistics,
# whether their row means agree, then the time of one extra assimilation
# per filter (timeit re-runs _assimilate).
for analysis in (x1, x2):
    print(stats(analysis), "\n")
print(close(x1, x2), "\n")
for kf in (filter_estkf, torch_estkf):
    print(timeit(kf))

((3352598, 2), array([0.54154072, 0.54413802, 0.55966858, ..., 0.30275894, 0.30275792,
       0.30275961]), array([[ 1.62669040e-05,  1.33313593e-05,  9.91098353e-06, ...,
        -9.89622554e-06,  1.01836389e-05, -1.56405570e-06],
       [ 1.33313593e-05,  1.09255664e-05,  8.12243573e-06, ...,
        -8.11034098e-06,  8.34588736e-06, -1.28180436e-06],
       [ 9.91098353e-06,  8.12243573e-06,  6.03849354e-06, ...,
        -6.02950189e-06,  6.20461504e-06, -9.52936728e-07],
       ...,
       [-9.89622554e-06, -8.11034098e-06, -6.02950189e-06, ...,
         6.02052363e-06, -6.19537603e-06,  9.51517753e-07],
       [ 1.01836389e-05,  8.34588736e-06,  6.20461504e-06, ...,
        -6.19537603e-06,  6.37530661e-06, -9.79152419e-07],
       [-1.56405570e-06, -1.28180436e-06, -9.52936728e-07, ...,
         9.51517753e-07, -9.79152419e-07,  1.50383271e-07]])) 

((3352598, 2), array([0.54154072, 0.54413802, 0.55966858, ..., 0.30275894, 0.30275792,
       0.30275961]), array([[ 1.62669051e-05,

In [9]:
# Run the NumPy EnSRF (k_filters.ensrf) and its torch version EnSRFT
# (k_filters.ensrf_torch) on the same `params` list.
# The order (construct, assimilate, construct, assimilate) is kept as-is.
# NOTE(review): confirm whether either filter mutates the shared input arrays.
filter_ensrf = EnSRF(*params)
x3 = filter_ensrf._assimilate()
torch_ensrf = EnSRFT(*params)
x4 = torch_ensrf._assimilate()

In [10]:
# Compare the NumPy (x3) and torch (x4) EnSRF results. Every stats dump is
# followed by a blank line, matching the ESTKF comparison cell.
# NOTE(review): the captured NumPy EnSRF output is complex-valued (`+0.j`)
# even though the imaginary parts print as zero; check where the complex
# dtype is introduced inside the filter.
print(stats(x3), "\n")
print(stats(x4), "\n")
print(close(x3, x4), "\n")
# timeit re-runs each assimilation once more to measure it.
print(timeit(filter_ensrf))
print(timeit(torch_ensrf))

((3352598, 2), array([0.54154072+0.j, 0.54413802+0.j, 0.55966858+0.j, ...,
       0.30275894+0.j, 0.30275792+0.j, 0.30275961+0.j]), array([[ 1.62669040e-05+0.j,  1.33313593e-05+0.j,  9.91098353e-06+0.j,
        ..., -9.89622554e-06+0.j,  1.01836389e-05+0.j,
        -1.56405570e-06+0.j],
       [ 1.33313593e-05+0.j,  1.09255664e-05+0.j,  8.12243573e-06+0.j,
        ..., -8.11034098e-06+0.j,  8.34588736e-06+0.j,
        -1.28180436e-06+0.j],
       [ 9.91098353e-06+0.j,  8.12243573e-06+0.j,  6.03849354e-06+0.j,
        ..., -6.02950189e-06+0.j,  6.20461504e-06+0.j,
        -9.52936728e-07+0.j],
       ...,
       [-9.89622554e-06+0.j, -8.11034098e-06+0.j, -6.02950189e-06+0.j,
        ...,  6.02052363e-06+0.j, -6.19537603e-06+0.j,
         9.51517753e-07+0.j],
       [ 1.01836389e-05+0.j,  8.34588736e-06+0.j,  6.20461504e-06+0.j,
        ..., -6.19537603e-06+0.j,  6.37530661e-06+0.j,
        -9.79152419e-07+0.j],
       [-1.56405570e-06+0.j, -1.28180436e-06+0.j, -9.52936728e-07+0.j,
     

In [11]:
# Run the NumPy SEnKF (k_filters.senkf) and its torch version SEnKFT
# (k_filters.senkf_torch) on the same `params` list.
# Unlike ESTKF/EnSRF, the captured NumPy and torch SEnKF results differ, which
# suggests random draws inside the filters. Seed both global RNGs so that a
# Restart & Run All reproduces the same numbers.
# NOTE(review): which RNG each implementation actually draws from is not
# visible here. If it uses a private Generator, these seeds have no effect;
# confirm. The two libraries' streams also differ, so x5 and x6 will still
# not match each other.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

filter_senkf = SEnKF(*params)
x5 = filter_senkf._assimilate()
torch_senkf = SEnKFT(*params)
x6 = torch_senkf._assimilate()

In [12]:
# Compare the NumPy (x5) and torch (x6) SEnKF results. Every stats dump is
# followed by a blank line, matching the ESTKF comparison cell.
# Because the captured outputs show the two SEnKF results differing (their
# row means disagree in the third decimal), `close` is not expected to be
# True here. Compare their statistics instead.
print(stats(x5), "\n")
print(stats(x6), "\n")
print(close(x5, x6), "\n")
# timeit re-runs each assimilation once more to measure it.
print(timeit(filter_senkf))
print(timeit(torch_senkf))

((3352598, 2), array([0.54390396, 0.54607479, 0.56110844, ..., 0.30093394, 0.30093293,
       0.30093462]), array([[ 3.04915212e-08,  2.49889852e-08,  1.85776571e-08, ...,
        -1.85499939e-08,  1.90887363e-08, -2.93174641e-09],
       [ 2.49889852e-08,  2.04794434e-08,  1.52251111e-08, ...,
        -1.52024401e-08,  1.56439604e-08, -2.40267999e-09],
       [ 1.85776571e-08,  1.52251111e-08,  1.13188627e-08, ...,
        -1.13020083e-08,  1.16302494e-08, -1.78623360e-09],
       ...,
       [-1.85499939e-08, -1.52024401e-08, -1.13020083e-08, ...,
         1.12851790e-08, -1.16129314e-08,  1.78357380e-09],
       [ 1.90887363e-08,  1.56439604e-08,  1.16302494e-08, ...,
        -1.16129314e-08,  1.19502025e-08, -1.83537364e-09],
       [-2.93174641e-09, -2.40267999e-09, -1.78623360e-09, ...,
         1.78357380e-09, -1.83537364e-09,  2.81886134e-10]]))
((3352598, 2), array([0.54103649, 0.54372479, 0.55936137, ..., 0.30314833, 0.30314731,
       0.303149  ]), array([[ 6.36153926e-06,  