# In [1]:
import numpy as np
from enkf import eakf
from numba import jit


def eakf(ensemble_size, nobs, zens, Hk, obs_error_var, localize, CMat, zobs):
    """
    Ensemble Adjustment Kalman Filter (EAKF) with serial observation processing.

    Observations are assimilated one at a time; each update starts from the
    ensemble produced by the previous observation. For every observation the
    ensemble mean is shifted by the Kalman gain times the innovation, and the
    deviations are contracted by the factor 1 / (1 + sqrt(R / (HPH^T + R))).

    Parameters:
        ensemble_size (int): Number of ensemble members (rows of ``zens``).
        nobs (int): Number of observations to assimilate (rows of ``Hk`` used).
        zens (array_like): Ensemble of shape (ensemble_size, nmod). ndarray and
            np.matrix inputs are both accepted; the input is not modified.
        Hk (array_like): Observation operator of shape (nobs, nmod).
        obs_error_var (float): Observation error variance R.
        localize (int): 1 to multiply the gain element-wise by row ``iobs`` of
            ``CMat``; any other value disables localization.
        CMat (array_like): Localization weights of shape (nobs, nmod). Only read
            when ``localize == 1``.
        zobs (array_like): Observations; any array with ``nobs`` elements, e.g.
            shape (nobs,) or (1, nobs).

    Returns:
        np.ndarray: Updated ensemble of shape (ensemble_size, nmod), float dtype.
    """
    # np.matrix keeps results 2-D under row slicing and `@`, which breaks the
    # 1-D algebra below (e.g. `xprime.T @ hxprime`), so work on plain ndarrays.
    zens = np.asarray(zens, dtype=float)
    Hk = np.asarray(Hk)
    zobs = np.asarray(zobs).ravel()
    if localize == 1:
        CMat = np.asarray(CMat)

    rn = 1.0 / (ensemble_size - 1)

    for iobs in range(nobs):
        xprime = zens - zens.mean(axis=0)
        hxens = zens @ Hk[iobs]            # observation-space ensemble, (ensemble_size,)
        hxmean = hxens.mean()
        hxprime = hxens - hxmean
        hpbht = (hxprime @ hxprime) * rn   # ensemble variance of H x

        if hpbht <= 0.0:
            # No spread in observation space: the covariance with the state is
            # zero too, so this observation carries no update. Skipping avoids
            # 0/0 when obs_error_var is also zero.
            continue

        innov_var = hpbht + obs_error_var
        # Square-root adjustment factor. Algebraically equal to
        # innov_var / hpbht * (1 - sqrt(R / innov_var)), but it never divides
        # by hpbht.
        gainfact = 1.0 / (1.0 + np.sqrt(obs_error_var / innov_var))

        pbht = (xprime.T @ hxprime) * rn   # state/obs covariance, (nmod,)
        kfgain = pbht / innov_var
        if localize == 1:
            kfgain = CMat[iobs] * kfgain

        mean_inc = kfgain * (zobs[iobs] - hxmean)
        # Rank-1 contraction of deviations: member k, variable j moves by
        # -gainfact * hxprime[k] * kfgain[j].
        prime_inc = -gainfact * np.outer(hxprime, kfgain)

        zens = zens + mean_inc + prime_inc

    return zens

def eakf_reference(ensemble_size, nobs, zens, Hk, obs_error_var, localize, CMat, zobs):
    """
    Loop-based serial EAKF used only to cross-check ``eakf``.

    Applies the same per-observation update as ``eakf``:
    - shift the mean by the Kalman gain times the innovation;
    - contract the deviations by 1 / (1 + sqrt(R / (HPH^T + R))).

    Each state variable's gain is built with explicit loops instead of matrix
    products, so the two implementations share no vectorised code.

    Takes the same arguments as ``eakf`` (plain ndarrays, ``zobs`` 1-D) and
    returns a new float array of shape (ensemble_size, nmod). ``zens`` is not
    modified.
    """
    ens = np.array(zens, dtype=float)  # private copy, updated column by column
    rn = 1.0 / (ensemble_size - 1)
    for iobs in range(nobs):
        # Observation-space ensemble, computed before any column is updated.
        hx = np.array([np.dot(Hk[iobs], ens[k]) for k in range(ensemble_size)])
        hx_mean = hx.mean()
        hx_dev = hx - hx_mean
        innov_var = np.sum(hx_dev ** 2) * rn + obs_error_var
        alpha = 1.0 / (1.0 + np.sqrt(obs_error_var / innov_var))
        for j in range(ens.shape[1]):
            col = ens[:, j].copy()
            gain = np.sum((col - col.mean()) * hx_dev) * rn / innov_var
            if localize == 1:
                gain *= CMat[iobs, j]
            ens[:, j] = col + gain * (zobs[iobs] - hx_mean) - alpha * gain * hx_dev
    return ens


ensemble_size = 10
nobsgrid = 3
nmod = 5
obs_error_var = 0.5
localize = 1

# Generate test data as plain ndarrays (np.mat was removed in NumPy 2.0).
# Same seed and draw order as before: zens, Hk, CMat, zobs.
np.random.seed(42)
zens_array = np.random.rand(ensemble_size, nmod)
Hk_array = np.random.rand(nobsgrid, nmod)
CMat_array = np.random.rand(nobsgrid, nmod)
zobs_array = np.random.rand(nobsgrid)

# Independent loop-based reference (the previously referenced `eakf1` was
# never defined, so the comparison raised NameError).
zens_reference_result = eakf_reference(
    ensemble_size, nobsgrid, zens_array, Hk_array, obs_error_var, localize, CMat_array, zobs_array
)

# Vectorised implementation under test.
zens_updated_result = eakf(
    ensemble_size, nobsgrid, zens_array, Hk_array, obs_error_var, localize, CMat_array, zobs_array
)

# Check whether the results are approximately equal.
difference_corrected = float(np.linalg.norm(zens_reference_result - zens_updated_result))
is_correct_corrected = bool(np.allclose(zens_reference_result, zens_updated_result, atol=1e-6))

# Display results; print so the summary is also visible when run as a script.
summary = {"Difference Norm": difference_corrected, "Are Results Matching": is_correct_corrected}
print(summary)

# Out[1]: {'Difference Norm': 0.0, 'Are Results Matching': True}