# Importing Libraries

In [67]:
from diff_weighted_fields import Grid1D, PowerSpectrum, GaussianFieldGenerator1D, Zeldovich1D
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.random import PRNGKey, split
import jax
import numpy as np
from getdist import MCSamples, plots
from tqdm import tqdm

# Grid setup: box length L and number of cells N of the fine grid.
L = 4000
N = 1024
# Binning / smoothing options shared by both grids.
dk     = 10   # k-bin width, in units of the fundamental frequency kf
kmin   = 2    # lowest k, in units of kf
R_clip = 2    # clipping-smoothing radius, in units of the cell size
kmax   = 0.5  # highest k, in units of the Nyquist frequency kNyq

# Fine grid (N cells; used by the Gaussian generator) and a coarser mesh
# (N // 2 cells; used by the Zel'dovich model) over the same box.
grid = Grid1D((N,), L, kmin=kmin, kmax=kmax, dk=dk, R_clip=R_clip)
mesh = Grid1D((N // 2,), L, kmin=kmin, kmax=kmax, dk=dk, R_clip=R_clip)

def PK(k,theta, threshold = 0.001):
    """Model power spectrum A (kR)^n exp(-(kR)^2) with a small positive floor.

    Args:
        k: wavenumbers at which to evaluate the spectrum.
        theta: (A, R, n) amplitude, smoothing scale and slope.
        threshold: floor added everywhere, as a fraction of the spectrum's
            maximum over ``k``.

    Returns:
        The spectrum at ``k`` plus ``threshold * max(spectrum)``.
    """
    amplitude, scale, slope = theta
    kr = k * scale
    shape = amplitude * kr**slope * jnp.exp(-kr**2)
    return shape + threshold * jnp.max(shape)

Using clipping smoothing with R_clip = 2 cells (physical = 7.8125); k_clip = 8.042e-01
kmin: 0.0031415926535897933
kmax: 0.4021238684654236
dk: 0.015707963267948967
Using clipping smoothing with R_clip = 2 cells (physical = 15.625); k_clip = 4.021e-01
kmin: 0.0031415926535897933
kmax: 0.2010619342327118
dk: 0.015707963267948967


In [68]:
# White noise used to generate the fields: 300 draws, one per PRNG key.
keys = split(PRNGKey(2), 300)
noise = jnp.array(list(map(grid.generate_hermitian_noise, keys)))

# Defining main functions 

In [69]:
#gaussian and zeldovich field generators
# `gen` draws Gaussian fields on the fine `grid` with power spectrum PK.
# `zel` wraps `gen` and produces Zel'dovich fields on the coarser `mesh`
# using scheme='cic' (presumably cloud-in-cell assignment — confirm in library).
gen = GaussianFieldGenerator1D(grid, PK)
zel = Zeldovich1D(gen, mesh, scheme = 'cic')

In [71]:
# Scan of weight vectors C over a small (x, y) box around the origin.
Nx = 50
Ny = 50

x_min, x_max = -0.1, 0.1
y_min, y_max = -0.1, 0.1

x_vec = jnp.linspace(x_min, x_max, Nx)
y_vec = jnp.linspace(y_min, y_max, Ny)

# 'xy' indexing: X[j, i] = x_vec[i], Y[j, i] = y_vec[j]; both (Ny, Nx).
X, Y = jnp.meshgrid(x_vec, y_vec, indexing='xy')

# Row-major flatten: flat index j*Nx + i <-> (x_vec[i], y_vec[j]).
xs_flat, ys_flat = X.ravel(), Y.ravel()
ones_flat, zeros_flat = jnp.ones_like(xs_flat), jnp.zeros_like(xs_flat)

# One C vector per row, [1/5, x, y, 0]; shape (Ny*Nx, 4).
C_flat = jnp.column_stack((ones_flat / 5, xs_flat, ys_flat, zeros_flat))

In [73]:
# Fiducial PK parameters (A, R, n) and growth factor.
A, n = 2., 2.
R = 4 * grid.H[0]   # four fine-grid cells
theta = jnp.array([A, R, n])
D = 1  # growth factor

# Number of C vectors evaluated per jit-compiled call.
BATCH_SIZE = 10

def compute_inv_cov(x):
    """Inverse of the unbiased sample covariance of ``x``.

    Rows of ``x`` are samples and columns are variables, so the result is
    (n_vars, n_vars).
    """
    centered = x - x.mean(axis=0)
    sample_cov = (centered.T @ centered) / (centered.shape[0] - 1)
    return jnp.linalg.inv(sample_cov)

# Same, mapped over a leading batch axis: (B, n_samples, n_vars) -> (B, n_vars, n_vars).
compute_inv_cov_batch = jax.vmap(compute_inv_cov, in_axes=0)

def compute_batch(theta,D,C):
    """Return det(F) for every weight vector in a batch ``C``.

    For each row C_i, F_i = J_i^T Cov_i^{-1} J_i, where:
      - J_i is the Jacobian of the realization-averaged spectrum w.r.t. ``theta``;
      - Cov_i^{-1} is the inverse sample covariance of the spectra over the
        realizations in the global ``noise`` array.
    Reads the notebook globals ``zel`` and ``noise``.

    Args:
        theta: parameter vector (A, R, n); the Jacobian is taken w.r.t. it.
        D: growth factor forwarded to ``zel.make_realization_batch``.
        C: batch of weight vectors, one per row.

    Returns:
        One determinant per row of ``C``.
    """
    def compute_pk(theta, D, C):
        # Differentiated output: spectra averaged over realizations (axis 1).
        # Auxiliary output: the raw per-realization spectra. These are reused for
        # the covariance, so the forward simulation runs once instead of twice.
        pk_batch = zel.make_realization_batch(D, theta, 0, C, noise)
        return jnp.mean(pk_batch, axis=1), pk_batch

    J, pk_batch = jax.jacfwd(compute_pk, argnums=0, has_aux=True)(theta, D, C)
    # pk_batch is presumably (n_C, n_realizations, n_bins): vmap over C, and
    # within each C the rows are samples.
    inv_covs = compute_inv_cov_batch(pk_batch)
    F = jnp.einsum('nij,nip,njq->npq', inv_covs, J, J)
    return jnp.linalg.det(F)

results_batch = jax.jit(compute_batch)

In [74]:
# First call compiles `results_batch`; block_until_ready waits for JAX's async result.
F = results_batch(theta, D, C_flat[:BATCH_SIZE]).block_until_ready()

In [76]:
STEPS = len(C_flat) // BATCH_SIZE  # number of full batches in C_flat

In [None]:
# Evaluate det(F) for every C vector, BATCH_SIZE rows at a time, and collect the
# results. (The previous version referenced an undefined `mean` and discarded F.)
# A trailing partial batch is also evaluated if C_flat.shape[0] is not a multiple
# of BATCH_SIZE; that costs one extra jit compile for the smaller shape.
fom_batches = []
for start in tqdm(range(0, C_flat.shape[0], BATCH_SIZE)):
    end = start + BATCH_SIZE
    F = results_batch(theta, D, C_flat[start:end])
    fom_batches.append(F)

# One determinant per row of C_flat.
Fs = jnp.concatenate(fom_batches, axis=0)

In [None]:
# NOTE(review): neither `jac` nor `C_batch` is defined anywhere in this notebook,
# so this cell raises NameError on a fresh kernel. Stale leftover; remove or fix.
jac(theta,C_batch).block_until_ready()

In [None]:
# NOTE(review): exact repeat of the previous cell; `jac` and `C_batch` are not
# defined anywhere in this notebook (NameError on a fresh kernel).
jac(theta,C_batch).block_until_ready()

In [None]:
# Plot the spectra averaged over axis 1 (the realization axis in compute_batch).
# NOTE(review): `pk_batch` is not bound at module level before this cell (it only
# exists inside compute_batch), so this raises NameError on a fresh kernel.
plt.plot(jnp.mean(pk_batch,axis = 1).T)

In [None]:
# Inspect pk_batch's shape. NOTE(review): pk_batch is not defined at module level
# at this point on a fresh kernel.
pk_batch.shape

In [None]:
# Inspect the shape of mesh.k_ctrs (presumably the k-bin centers — confirm in Grid1D).
mesh.k_ctrs.shape

In [None]:
# Plot pk_batch[0] transposed. Assuming pk_batch is (n_C, n_real, n_bins), this
# draws one spectrum per realization for the first C vector.
# NOTE(review): pk_batch is not defined at this point on a fresh kernel.
plt.plot(pk_batch[0].T)

In [19]:
# NOTE(review): this averages over the LAST axis. Elsewhere (compute_batch) the
# realizations are averaged over axis 1 — confirm which axis is intended here.
mu = jnp.mean(pk_batch,axis = -1)

In [None]:
# Plot the averaged array from the previous cell, transposed.
plt.plot(mu.T)

In [None]:
# NOTE(review): jnp.cov treats each ROW as a variable and accepts at most 2-D input.
# The cov/cov_mark helpers transpose a (realizations, bins) array first; check the
# orientation and dimensionality of pk_batch here.
jnp.cov(pk_batch)

In [None]:
# NOTE(review): `one_plus_delta_batch` is not defined anywhere in this notebook.
one_plus_delta_batch.shape

In [None]:
# NOTE(review): `m_batch` is not defined anywhere in this notebook.
m_batch.shape

In [21]:
# Weighted density: add a leading axis to the density batch and multiply by
# m_batch (broadcasting), then divide each field by its mean over the last axis.
# NOTE(review): `one_plus_delta_batch` and `m_batch` are not defined in this notebook.
rho_weighted = one_plus_delta_batch[None,:,:]*m_batch
rho_bar = jnp.mean(rho_weighted, axis = -1)
delta_weighted = rho_weighted/rho_bar[:,:,None] - 1

In [None]:
# Inspect the shape of the batched weighted overdensity.
delta_weighted.shape

In [33]:
# Recompute one slice of delta_weighted directly, as a check on the broadcasting above.
a = rho_weighted[1,0,:]/jnp.mean(rho_weighted[1,0,:])-1

In [None]:
# Element-wise ratio of the direct and batched results. It should be ~1, but is
# ill-conditioned wherever delta_weighted is close to 0.
a/delta_weighted[1,0,:]

In [None]:
# Inspect the shape of the weighted density batch.
rho_weighted.shape

In [None]:
# Shape after averaging over the last axis (should match rho_bar).
jnp.mean(rho_weighted,axis = -1).shape

In [None]:
# NOTE(review): jnp.einsum() with no subscripts/operands raises a TypeError.
# This is an unfinished attempt; if completed it would also overwrite rho_weighted.
# Remove it or supply the intended subscripts and operands.
rho_weighted = jnp.einsum()

In [None]:
def weighted_child_reference(self, C):
    """Reference copy of a marked-field construction that was pasted as a method body.

    The original cell held these lines at top level, which is a SyntaxError
    (`return` outside a function) and references an undefined `self`. Wrapping
    them in a function keeps the code readable and lets "Run All" proceed.
    NOTE(review): presumably mirrors ``Zeldovich1D.WeightedChild`` — confirm.

    Steps:
      - weights m = C · self.m_array;
      - rho_weighted = m * self.one_plus_delta;
      - marked overdensity = rho_weighted / mean(rho_weighted) - 1;
      - returned as a Field1D on self.grid with its FFT computed and W = self.W.
    """
    m = jnp.dot(C,self.m_array)
    rho_weighted = m*self.one_plus_delta
    mean_rho_marked = jnp.mean(rho_weighted)
    delta_marked = rho_weighted / mean_rho_marked - 1.0
    field_marked = Field1D(grid=self.grid)
    field_marked.assign_from_real_space(delta_marked)
    field_marked.compute_fft()
    field_marked.W = self.W
    return field_marked

In [None]:
# Display the first entry of pk_batch (NOTE(review): not defined here on a fresh kernel).
pk_batch[0]

In [6]:
def PowerSpectrum_batch(delta_k_batch: jnp.ndarray, W_batch: jnp.ndarray, grid) -> jnp.ndarray:
    """Bin-averaged power spectrum for a batch of Fourier-space fields.

    Args:
        delta_k_batch: Fourier modes, shape (..., N); any number of leading batch axes.
        W_batch: window to deconvolve, broadcastable against ``delta_k_batch``.
            Entries with |W| < 1e-3 are treated as 1 (left uncompensated) so
            tiny denominators never blow up the estimate.
        grid: object with ``k_edges`` (bin edges), ``k_mapping`` (bin index per
            mode, negative = ignore), ``H`` (cell size) and ``Ndim``.

    Returns:
        Real array of shape (..., nbins), nbins = len(grid.k_edges) - 1.
        Empty bins are 0, and every value is scaled by grid.H ** grid.Ndim.
    """
    nbins = len(grid.k_edges) - 1

    # |δ_k|^2 / W^2 per mode, with the safe-divide guard on W.
    eps = 1e-3
    safe_W = jnp.where(jnp.abs(W_batch) < eps, 1.0, W_batch)
    field_k_abs = delta_k_batch * jnp.conjugate(delta_k_batch)
    field_k_abs = field_k_abs / safe_W / safe_W

    # Collapse all leading batch axes into one axis of length B.
    batch_shape = field_k_abs.shape[:-1]
    n_modes = field_k_abs.shape[-1]
    flat_field = field_k_abs.reshape((-1, n_modes))  # (B, N)

    # Invalid modes (k_mapping < 0) go to bin 0 with zero weight.
    k_mapping = grid.k_mapping
    valid = k_mapping >= 0
    kmap_safe = jnp.where(valid, k_mapping, 0)

    # The number of valid modes per bin is identical for every batch row, so
    # compute it once here rather than inside the vmapped function.
    counts = jnp.bincount(kmap_safe, weights=valid.astype(flat_field.real.dtype), length=nbins)
    nonempty = counts > 0
    # Double-where: keep the denominator non-zero so that reverse-mode gradients
    # of empty bins cannot produce 0/0 = NaN.
    safe_counts = jnp.where(nonempty, counts, 1.0)
    volume = grid.H ** grid.Ndim

    def ps_single(field_row):
        # Sum |δ_k|^2 of the valid modes into their bins, then average per bin.
        masked = jnp.where(valid, field_row, 0.0)
        power = jnp.bincount(kmap_safe, weights=masked, length=nbins)
        pk = jnp.real(jnp.where(nonempty, power / safe_counts, 0.0))
        return pk * volume

    pk_flat = jax.vmap(ps_single)(flat_field)  # (B, nbins)
    return pk_flat.reshape(batch_shape + (nbins,))

In [9]:
# Spectra of a batch of marked fields. With W = ones there is no window compensation.
# NOTE(review): `marked_batch` is not defined anywhere in this notebook.
pk_batch = PowerSpectrum_batch(marked_batch,jnp.ones_like(marked_batch),mesh)

In [5]:
# Gaussian smoothing kernel exp(-(k R)^2 / 2) on the mesh k-grid, R = 3 fine-grid cells.
R_smooth = 3 * grid.H[0]
smooth = jnp.exp(-(mesh.kgrid_abs * R_smooth)**2 / 2)

In [9]:
# Replace delta_k_batch by its modulus. This discards the phases for all later cells.
# NOTE(review): delta_k_batch is not defined anywhere in this notebook.
delta_k_batch = abs(delta_k_batch)

In [13]:
# Multiply each mode by the Gaussian kernel `smooth` (broadcast over batch axes).
res = delta_k_batch*smooth

In [None]:
# Inspect the shape of the smoothed batch.
res.shape

In [None]:
# NOTE(review): repeats the previous cell verbatim; can be removed.
res.shape

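We start by studying the displacement field $\psi(q)$. The Zel'dovich map $x = q + D\,\psi(q)$ shell-crosses once $1 + D\,\partial_q\psi$ reaches zero somewhere. We therefore find, for every noise realization, the largest growth factor $D$ that avoids shell crossing: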

In [None]:
# Fiducial PK parameters and growth factor for the displacement study.
A, n = 2., 2.
R = 4 * grid.H[0]
D = 1  # growth factor

def max_growth_factor_no_shell_crossing(psi: jnp.ndarray, H: float, periodic: bool = False) -> float:
    """Largest growth factor D before the map q -> q + D*psi(q) shell-crosses.

    dx/dq = 1 + D * dpsi/dq first hits zero at D = -1 / min(dpsi/dq). The
    derivative is estimated with forward differences of spacing ``H``.

    Args:
        psi: displacement sampled on a regular 1D lattice.
        H: lattice spacing for the finite difference.
        periodic: if True, also include the wrap-around difference
            psi[0] - psi[-1] (for a periodic box). The default False keeps the
            original non-periodic behavior.

    Returns:
        The critical growth factor, or ``inf`` when dpsi/dq is never negative.
        In that case dx/dq > 0 for every D >= 0, so shell crossing never happens
        (previously a negative value or -inf was returned).
    """
    dpsi = jnp.diff(psi, append=psi[:1]) if periodic else jnp.diff(psi)
    min_grad = jnp.min(dpsi / H)
    compressive = min_grad < 0
    # Double-where keeps the division finite (and gradients NaN-free) in the
    # branch that jnp.where discards.
    safe_min = jnp.where(compressive, min_grad, -1.0)
    return jnp.where(compressive, -1.0 / safe_min, jnp.inf)

# Displacement field for every noise draw, and each draw's shell-crossing D.
# NOTE(review): grid.H[0] is used as the spacing; if psi lives on `mesh` (N//2
# cells) its spacing would be mesh.H[0] — confirm which lattice psi is on.
psi_q = jnp.stack([zel.make_realization(D, [A, R, n], eps, displacement=True) for eps in noise])
max_D = jnp.stack([max_growth_factor_no_shell_crossing(psi, grid.H[0]) for psi in psi_q])
# The most restrictive realization sets the usable growth factor.
min_max_D = max_D.min()
print(min_max_D)

In [None]:
D = 1.6
# NOTE(review): compare with min_max_D printed above; beyond it some realizations shell-cross.

# Gaussian initial field: generated in Fourier space, then inverse-FFT'd.
gauss = gen.make_realization_from_noise((A,R,n),noise[0])
gauss.compute_ifft()

# Zel'dovich field from the same noise draw. It is bound to `zel_field`, not
# `zel`, so that running this cell no longer replaces the Zeldovich1D
# generator `zel` (used by compute_batch and the psi_q cell) with a field.
zel_field = zel.make_realization(D,[A,R,n],noise[0])

# Initial Gaussian field vs evolved field rescaled by 1/D.
plt.figure(figsize = (8,2))
plt.plot(gauss.grid.q, gauss.delta, lw =2, label = 'initial gauss')
plt.plot(zel_field.grid.q, zel_field.delta/D, lw = 2, label = 'zeldovich')
plt.xlim((1500,2200))
plt.ylim((-0.4,0.8))
plt.grid()
plt.legend()
plt.xlabel('x')
plt.ylabel(r'$\delta$')

In [7]:
def _plin(theta,noise):
    """Auto power spectrum of one Gaussian realization drawn from ``noise`` at ``theta``."""
    gauss = gen.make_realization_from_noise(theta,noise)
    gauss.compute_ifft()
    return PowerSpectrum(gauss, gauss)

def plin_samples(theta):
    """Per-realization spectra over the global ``noise`` array, shape (N_real, P_len).

    No averaging happens here; plin and cov build on these samples.
    """
    return jax.vmap(lambda eps: _plin(theta, eps), in_axes=0)(noise)

def plin(theta):
    """Realization-averaged linear power spectrum, shape (P_len,)."""
    return jnp.mean(plin_samples(theta), axis=0)

def cov(theta):
    """Sample covariance of the spectrum across realizations, shape (P_len, P_len).

    jnp.cov treats rows as variables, hence the transpose.
    """
    return jnp.cov(plin_samples(theta).T)

# JIT-compiled entry points.
plin_jit   = jax.jit(plin)
cov_jit    = jax.jit(cov)
jac_plin   = jax.jit(jax.jacfwd(plin))

def ParamCov(theta):
    """Gaussian (Fisher) forecast of the parameter covariance at ``theta``.

    Returns (J^T C^{-1} J)^{-1}, where J = d plin / d theta has shape
    (P_len, n_params) and C = cov(theta) is treated as parameter-independent.
    NOTE(review): the inverse of a sample covariance is biased (see the Hartlap
    correction); with many realizations per bin this may be negligible.
    """
    theta = jnp.asarray(theta)
    J = jac_plin(theta)
    C_inv = jnp.linalg.inv(cov_jit(theta))
    return jnp.linalg.inv(J.T @ (C_inv @ J))

In [8]:
param_cov = ParamCov(jnp.array((A, R, n)))  # (3, 3) forecast covariance of (A, R, n)

In [None]:
def correlation_from_covariance(covariance_matrix):
    """Normalize a covariance matrix into a correlation matrix.

    Entry (i, j) becomes cov_ij / (sigma_i * sigma_j) with sigma = sqrt(diag(cov)).
    Non-finite results (e.g. 0/0 from zero-variance entries) go through
    jnp.nan_to_num.
    """
    sigma = jnp.sqrt(jnp.diag(covariance_matrix))
    return jnp.nan_to_num(covariance_matrix / (sigma[:, None] * sigma[None, :]))

# Correlation between linear power-spectrum bins at the fiducial parameters.
correlation_matrix = correlation_from_covariance(cov_jit((A, R, n)))
plt.imshow(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.colorbar(label='correlation')
plt.title('Linear P(k) bin correlation')
plt.xlabel('k bin')
plt.ylabel('k bin');

In [None]:
#draw samples from the multivariate normal distribution and make triangle plot using getdist
import getdist
import getdist.plots as gplt
import numpy as np
from getdist import MCSamples

# Seeded generator so the triangle plot is identical on every re-run.
rng = np.random.default_rng(0)
# Keep the raw draws and the MCSamples wrapper under separate names
# (previously `samples` was reused for both).
param_draws = rng.multivariate_normal([A, R, n], param_cov, size=1000)
samples = MCSamples(samples=param_draws, names=['A', 'R', 'n'])
g = gplt.getSubplotPlotter()
g.triangle_plot(samples, filled=True, contour_levels=[0.68, 0.95])

In [11]:
def likelihood(A, R, n, data):
    """Gaussian negative log-likelihood (up to a constant) of ``data`` at (A, R, n).

    Returns 0.5 * (theory - data)^T Cov^{-1} (theory - data). Both the theory
    (mean spectrum) and Cov (sample covariance) come from the same set of
    per-realization spectra, which are computed once per call.
    """
    ps = plin_samples((A, R, n))   # (N_real, P_len)
    theory = jnp.mean(ps, axis=0)
    # Cov is held fixed (no gradient): the ensemble-averaged Hessian only matches
    # the Fisher matrix when the covariance does not depend on the parameters.
    inv_cov = jax.lax.stop_gradient(jnp.linalg.inv(jnp.cov(ps.T)))
    diff = theory - data
    return 0.5 * jnp.dot(diff, jnp.dot(inv_cov, diff))

# NOTE(review): each data row is one of the realizations averaged into `theory`,
# so data and theory are not independent — confirm this is intended.
data = plin_samples((A,R,n))

In [12]:
# Hessian of the negative log-likelihood w.r.t. (A, R, n). It is jit-compiled
# (despite its name it previously was not), so the warm-up call below compiles
# once and the loop reuses the compiled function.
N_HESS = 200  # number of data realizations to evaluate (at most len(data))
hess_jit = jax.jit(jax.hessian(likelihood, argnums=(0, 1, 2)))
_ = hess_jit(A, R, n, data[0])  # warm-up: triggers compilation

hess_sample = []
for i in range(N_HESS):
    # The nested 3x3 tuple of scalars becomes a (3, 3) array.
    hess_matrix = jnp.array(hess_jit(A, R, n, data[i]))
    hess_sample.append(hess_matrix)

In [13]:
# NOTE(review): despite the name, inverting the Hessian of the negative
# log-likelihood gives a parameter-covariance estimate, not a Fisher matrix.
fisher_sample = jnp.linalg.inv(jnp.stack(hess_sample))  # (N, 3, 3), inverted per matrix

In [None]:
# Distribution of the per-realization variance of A vs the Fisher forecast.
plt.hist(fisher_sample[:,0,0],bins = 30)
# axvline spans the full y-range whatever the histogram height; the previous
# vlines(..., 0, 100) hard-coded it. legend() is needed for the label to appear.
plt.axvline(param_cov[0,0], color='red', label='Fisher from covariance')
plt.xlabel(r'$\sigma^2_A$')
plt.ylabel('count')
plt.legend()

# Marked Likelihood

In [20]:
gen = GaussianFieldGenerator1D(grid, PK)

def _plin_mark(D,theta,R_smooth,C,noise):
    """Marked power spectrum of one Zel'dovich realization.

    Steps:
      - build a Zeldovich1D from the global ``gen`` on ``mesh`` (scheme 'cic');
      - realize it from ``noise`` at growth ``D`` and parameters ``theta``, then FFT;
      - call ComputeBasis(R_smooth) and form the field weighted by ``C``;
      - return that weighted field's auto power spectrum.
    """
    simulator = Zeldovich1D(gen, mesh, scheme = 'cic')
    field = simulator.make_realization(D, theta, noise)
    field.compute_fft()
    field.ComputeBasis(R_smooth)
    weighted = field.WeightedChild(C)
    return PowerSpectrum(weighted, weighted)

In [21]:
def plin_mark(theta, D, R_smooth,C):
    """Marked power spectrum averaged over the global ``noise`` draws, shape (P_len,)."""
    def one_realization(eps):
        return _plin_mark(D, theta, R_smooth, C, eps)

    per_realization = jax.vmap(one_realization, in_axes=0)(noise)  # (N_realizations, P_len)
    return per_realization.mean(axis=0)

In [22]:
# Parameters for the marked-spectrum cells below.
A, n = 2., 2.
R = 2. * grid.H[0]
theta = jnp.array([A, R, n])
C = jnp.array([1., 0., 0., 0])   # weight vector passed to WeightedChild
D = 1.

In [23]:
def cov_mark(theta, D, R_smooth,C):
    """Sample covariance of the marked power spectrum across noise realizations.

    Computes one marked spectrum per row of the global ``noise`` array at
    (theta, D, R_smooth, C) and returns their (P_len, P_len) sample covariance.
    The spectra are not averaged here; see plin_mark for the mean.
    """
    ps_samples = jax.vmap(
        lambda eps: _plin_mark(D, theta, R_smooth, C, eps),
        in_axes=0,
    )(noise)  # (N_realizations, P_len)

    # jnp.cov treats rows as variables, so transpose to (P_len, N_realizations).
    return jnp.cov(ps_samples.T)

In [None]:
%timeit -n 2 -r 10 cov_mark

In [44]:
# Reverse-mode Jacobian of the averaged marked spectrum w.r.t. theta (argument 0).
jac_plin_mark   = jax.jacrev(plin_mark)

def ParamCov_marked(theta,D, R_smooth,C):
    """Fisher-forecast parameter covariance from the marked power spectrum.

    Returns (J^T Cov^{-1} J)^{-1}, where J = d plin_mark / d theta and Cov is
    cov_mark, both evaluated at (theta, D, R_smooth, C).
    """
    theta = jnp.asarray(theta)
    jacobian = jac_plin_mark(theta, D, R_smooth, C)
    precision = jnp.linalg.inv(cov_mark(theta, D, R_smooth, C))
    fisher = jacobian.T @ (precision @ jacobian)
    return jnp.linalg.inv(fisher)

def FoM_marked(theta,D,R_smooth,C):
    """Determinant of ParamCov_marked; smaller means tighter constraints.

    NOTE(review): conventional figures of merit grow as constraints tighten
    (e.g. 1/sqrt(det)); this returns the raw determinant.
    """
    return jnp.linalg.det(ParamCov_marked(theta, D, R_smooth, C))

# Marked posteriors for varying D
A = 2.0
R = 4*grid.H[0]
n = 2.0

In [None]:
Ds = jnp.linspace(0.1, 1.6, 5)
R_smooth = 2*grid.H[0]
C = jnp.array([1,0,0,0])
rng_marked = np.random.default_rng(1)  # seeded so the contours are reproducible
marked_posteriors = []
FoM_0 = []  # det of the forecast parameter covariance for each D
for d in Ds:
    marked_cov = ParamCov_marked((A,R,n),d, R_smooth,C)
    FoM_0.append(jnp.linalg.det(marked_cov))

    marked_samples = rng_marked.multivariate_normal([A, R, n], marked_cov, size=10000)
    print(jnp.linalg.inv(marked_cov))  # Fisher matrix for this D
    marked_posteriors.append(MCSamples(samples=marked_samples, names=['A', 'R', 'n']))

# Combine all posteriors into a single corner plot. There is exactly one legend
# label per plotted posterior; the former extra 'Linear' label had no matching
# sample set among the roots.
g = plots.getSubplotPlotter()
g.triangle_plot(
    marked_posteriors[::-1],
    filled=True,
    contour_levels=[0.95],
    legend_labels=[f'Marked D={D_val:.2f}' for D_val in Ds][::-1],
)

plt.figure()
plt.plot(Ds,FoM_0)
plt.xlabel('D')
plt.ylabel('FoM')
plt.grid()
del g, marked_posteriors, marked_samples

In [46]:
# NOTE(review): mesh.H is used as a whole array here, not mesh.H[0] as done with
# grid.H[0] elsewhere, so R_smooth is an array rather than a scalar. Mesh cells
# are twice the fine-grid cells (cf. printed R_clip physical sizes) — confirm scale.
R_smooth = 2*mesh.H
D = 1
# FoM_marked vectorized over a batch of C vectors (axis 0 of the 4th argument), then jitted.
FoM_marked_C = jax.jit(jax.vmap(FoM_marked, in_axes=(None,) * 3 + (0,)))
# Warm-up call. NOTE(review): this integer (1, 4) batch differs in dtype and shape
# from the float C_flat batches used later, so those calls compile again.
_ = FoM_marked_C((A, R, n), D, R_smooth, jnp.array([[1, 0, 0, 0]])).block_until_ready()

In [31]:
# Wider scan of weight vectors C over the (x, y) plane.
Nx = 200
Ny = 200

x_min, x_max = -2.0, 2.0
y_min, y_max = -2.0, 2.0

x_vec = jnp.linspace(x_min, x_max, Nx)
y_vec = jnp.linspace(y_min, y_max, Ny)

# 'xy' indexing: X[j, i] = x_vec[i], Y[j, i] = y_vec[j]; both (Ny, Nx).
X, Y = jnp.meshgrid(x_vec, y_vec, indexing='xy')

# Row-major flatten: flat index j*Nx + i <-> (x_vec[i], y_vec[j]).
xs_flat, ys_flat = X.ravel(), Y.ravel()
ones_flat, zeros_flat = jnp.ones_like(xs_flat), jnp.zeros_like(xs_flat)

# One C vector per row, [1/5, x, y, 0]; shape (Ny*Nx, 4).
C_flat = jnp.column_stack((ones_flat / 5, xs_flat, ys_flat, zeros_flat))

In [None]:
# Evaluate the marked FoM for the first BATCH C vectors of the scan.
# The warm-up above used a (1, 4) integer batch, so this call likely compiles
# again for the (BATCH, 4) float input.
BATCH = 500
FoM_marked_C((A,R,n),D,R_smooth,C_flat[0:BATCH]).block_until_ready()

In [None]:
# Identical second call, presumably re-run to observe the already-compiled runtime.
FoM_marked_C((A,R,n),D,R_smooth,C_flat[0:BATCH]).block_until_ready()

In [27]:
def single_power(theta, D, R_smooth, C_i, epsilon_j):
    """Marked power spectrum for one weight vector C_i and one noise draw epsilon_j.

    Thin wrapper around ``_plin_mark`` that takes theta first instead of D.
    """
    return _plin_mark(D=D, theta=theta, R_smooth=R_smooth, C=C_i, noise=epsilon_j)

In [40]:
# Duplicate definition of `single_power` removed: it was byte-for-byte identical
# to the definition in the preceding cell, which remains in effect.

In [None]:
# Quick look at one marked spectrum: the first C vector of the scan and the first
# noise draw, with R_smooth = 0 passed straight through to ComputeBasis.
plt.plot(single_power((A, R, n), D, 0, C_flat[0], noise[0]))