In [1]:
import torch

In [2]:
import h5py
import numpy as np
import time

from k_filters.estkf import *
from k_filters.estkf_torch import *

In [28]:
# Report which accelerator backends this PyTorch installation can use,
# one boolean per line, in this order:
print(
    torch.backends.mps.is_available(),  # MPS usable at runtime (requires macOS 12.3+)
    torch.backends.mps.is_built(),      # this PyTorch build was compiled with MPS support
    torch.cuda.is_available(),          # a CUDA device is usable
    sep="\n",
)

True
True
False


In [40]:
def stats(x, max_rows=1000):
    """
    Summarize a 2-D ensemble array.

    Parameters
    ----------
    x : array_like, 2-D
        Rows are treated as variables and columns as samples (``np.cov``'s
        default ``rowvar=True``), e.g. an analysis array of shape
        ``(n_state, Ne)`` such as ``xa``.
    max_rows : int, default 1000
        Number of leading rows included in the covariance. This keeps the
        covariance matrix small when ``x`` has millions of rows.

    Returns
    -------
    tuple
        ``(x.shape, mean of each row over axis 1,
        sample covariance (ddof=1) of the first max_rows rows)``.
    """
    # Previously this returned the global ``xa.shape`` regardless of ``x``.
    return x.shape, np.mean(x, axis=1), np.cov(x[:max_rows, :], ddof=1)

def timeit(f):
    """
    Time a single call to ``f._assimilate()``.

    Parameters
    ----------
    f : object
        Filter instance exposing an ``_assimilate()`` method
        (in this notebook ``ESTKF`` or ``ESTKFT``).

    Returns
    -------
    tuple
        ``('Execution time:', elapsed_seconds, 'seconds')``. The analysis
        returned by ``_assimilate()`` is discarded.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock, which can be adjusted while the call runs.
    start = time.perf_counter()
    f._assimilate()
    # NOTE(review): if the filter computes on an asynchronous device (e.g. MPS
    # or CUDA), synchronize here before reading the end time, or the
    # measurement may stop before the work finishes.
    elapsed_time = time.perf_counter() - start

    return 'Execution time:', elapsed_time, 'seconds'

In [4]:
# Load Test Data
# Datasets are read by name rather than by position in f.keys(). Positional
# indices silently pick a different dataset if keys are added or renamed.
# f.get() returns None for a missing key (np.array(None) is a useless object
# array), whereas f[...] raises KeyError.
filename = "sample_harvey_state3.hdf5"
with h5py.File(filename, "r") as f:  # file is closed once the arrays are copied into memory
    print(f.keys())

    obs = np.array(f["observations"])
    zeta = np.array(f["zeta"])

print("wse dim = {} \n ".format(zeta.shape),
      "obs dim = {}".format(obs.shape)
     )

<KeysViewHDF5 ['coords', 'observations', 'station_coords', 'station_node_inds', 'zeta']>
wse dim = (2, 3352598) 
  obs dim = (11,)


In [5]:
# Forecast ensemble: transpose the stored zeta, which is (2, n_nodes), so that
# columns index the ensemble members.
xf = zeta.T
Xf = xf
Ne = Xf.shape[1]  # ensemble size = number of columns

# Observation-error vector: one entry per observation.
# NOTE(review): `sigma` reads like a standard deviation; if ESTKF expects
# error variances in R, this should probably be sigma**2 -- confirm.
sigma = 0.5
R = np.full(obs.size, sigma)

Y = obs
H = np.eye(obs.size)
# NOTE(review): H[:, ::6] only keeps identity columns 0, 6, ... so that, with
# 11 observations, HXf has 2 columns to match the ensemble size. This is a
# placeholder, not H applied to the forecast. TODO: replace it with the
# forecast ensemble sampled at the observation stations.
HXf = H[:, ::6]

# Positional arguments shared by ESTKF and ESTKFT.
params = [Xf, HXf, Y, R]

In [6]:
# Run the NumPy ESTKF analysis on (Xf, HXf, Y, R) and keep the result as `xa`.
# NOTE(review): _assimilate is underscore-prefixed (private by convention);
# confirm k_filters has no public entry point to use instead.
# NOTE(review): the same Xf array object is also passed to ESTKFT below; if
# ESTKF modifies it in place, the PyTorch run sees altered input -- confirm.
filter_estkf = ESTKF(*params)
xa = filter_estkf._assimilate()

In [7]:
# Run the PyTorch ESTKF implementation on the same inputs, keeping the result
# as `xatorch` for comparison with the NumPy analysis `xa`.
torch_estkf = ESTKFT(*params)
xatorch = torch_estkf._assimilate()

In [41]:
# Summary of the NumPy analysis: (shape, row means, covariance of the first 1000 rows).
stats(xa)

((3352598, 2),
 array([0.54154072, 0.54413802, 0.55966858, ..., 0.30275894, 0.30275792,
        0.30275961]),
 array([[ 1.62669040e-05,  1.33313593e-05,  9.91098353e-06, ...,
         -9.89622554e-06,  1.01836389e-05, -1.56405570e-06],
        [ 1.33313593e-05,  1.09255664e-05,  8.12243573e-06, ...,
         -8.11034098e-06,  8.34588736e-06, -1.28180436e-06],
        [ 9.91098353e-06,  8.12243573e-06,  6.03849354e-06, ...,
         -6.02950189e-06,  6.20461504e-06, -9.52936728e-07],
        ...,
        [-9.89622554e-06, -8.11034098e-06, -6.02950189e-06, ...,
          6.02052363e-06, -6.19537603e-06,  9.51517753e-07],
        [ 1.01836389e-05,  8.34588736e-06,  6.20461504e-06, ...,
         -6.19537603e-06,  6.37530661e-06, -9.79152419e-07],
        [-1.56405570e-06, -1.28180436e-06, -9.52936728e-07, ...,
          9.51517753e-07, -9.79152419e-07,  1.50383271e-07]]))

In [42]:
# Same summary for the PyTorch analysis; the values should match the NumPy ones
# to floating-point tolerance.
stats(xatorch)

((3352598, 2),
 array([0.54154072, 0.54413802, 0.55966858, ..., 0.30275894, 0.30275792,
        0.30275961]),
 array([[ 1.62669051e-05,  1.33313602e-05,  9.91098419e-06, ...,
         -9.89622620e-06,  1.01836395e-05, -1.56405580e-06],
        [ 1.33313602e-05,  1.09255672e-05,  8.12243628e-06, ...,
         -8.11034153e-06,  8.34588792e-06, -1.28180444e-06],
        [ 9.91098419e-06,  8.12243628e-06,  6.03849394e-06, ...,
         -6.02950229e-06,  6.20461545e-06, -9.52936791e-07],
        ...,
        [-9.89622620e-06, -8.11034153e-06, -6.02950229e-06, ...,
          6.02052404e-06, -6.19537644e-06,  9.51517817e-07],
        [ 1.01836395e-05,  8.34588792e-06,  6.20461545e-06, ...,
         -6.19537644e-06,  6.37530704e-06, -9.79152484e-07],
        [-1.56405580e-06, -1.28180444e-06, -9.52936791e-07, ...,
          9.51517817e-07, -9.79152484e-07,  1.50383281e-07]]))

In [17]:
# Check that the NumPy and PyTorch analyses agree on the ensemble mean.
# The previous version compared `mean2` and `meant`, which no cell shown here
# defines (leftover kernel state), so it failed on Restart & Run All.
np.allclose(np.mean(xa, axis=1), np.mean(xatorch, axis=1))

True

In [35]:
# Time one extra _assimilate() call of the NumPy filter (a single run, so the figure is noisy).
timeit(filter_estkf)

('Execution time:', 0.05569720268249512, 'seconds')

In [34]:
# Time one extra _assimilate() call of the PyTorch filter.
# NOTE(review): if ESTKFT computes on an asynchronous device (e.g. MPS), this
# may measure only kernel launch time -- confirm it synchronizes before returning.
timeit(torch_estkf)

('Execution time:', 0.03415107727050781, 'seconds')