# Importing libraries and setting up the grids

In [1]:
from diff_weighted_fields import Grid1D, GaussianFieldGenerator1D, Zeldovich1D
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.random import PRNGKey, split
import jax
from tqdm import tqdm
import jax.scipy.linalg as jsp_lin

# Grid setup. L is the physical box length, N the number of cells; each
# wavenumber option is given in the units noted next to it.
L = 4000
N = 1024
dk = 10       # units of kf
kmin = 2      # units of kf
R_clip = 2    # units of cell size
kmax = 0.5    # units of kNyq

# Smoothing / k-binning options shared by both grids.
_grid_opts = dict(R_clip=R_clip, dk=dk, kmax=kmax, kmin=kmin)

# Field grid with N cells, plus a mesh with N//2 cells over the same box length.
grid = Grid1D((N,), L, **_grid_opts)
mesh = Grid1D((N // 2,), L, **_grid_opts)

def PK(k, theta, threshold = 0.001):
    """Power-spectrum model A (kR)^n exp(-(kR)^2) plus a small constant floor.

    Args:
      k: wavenumber(s) at which to evaluate the spectrum.
      theta: sequence (A, R, n) -- amplitude, scale and slope.
      threshold: floor level as a fraction of the peak value over ``k``.

    Returns:
      The model evaluated at ``k`` with ``threshold * max(model)`` added.

    NOTE(review): the floor is relative to the maximum over the ``k`` values
    passed in this call, so evaluating on a different k array (or a scalar k)
    yields a different floor.
    """
    amplitude, scale, slope = theta
    kr = k * scale
    shape = amplitude * kr**slope * jnp.exp(-kr**2)
    floor = threshold * jnp.max(shape)
    return shape + floor

Using clipping smoothing with R_clip = 2 cells (physical = 7.8125); k_clip = 8.042e-01
kmin: 0.0031415926535897933
kmax: 0.4021238684654236
dk: 0.015707963267948967
Using clipping smoothing with R_clip = 2 cells (physical = 15.625); k_clip = 4.021e-01
kmin: 0.0031415926535897933
kmax: 0.2010619342327118
dk: 0.015707963267948967


In [2]:
# White noise used to generate the field: N_REALIZATIONS Hermitian noise draws
# on `grid` (one per PRNG key), stacked along a new leading axis.
NOISE_SEED = 2
N_REALIZATIONS = 300
keys = split(PRNGKey(NOISE_SEED), N_REALIZATIONS)
noise = jnp.array([grid.generate_hermitian_noise(key) for key in keys])

# Defining the field generators and the Fisher-determinant functions

In [3]:
# Gaussian and Zel'dovich field generators.
# `gen` is built from `grid` and the power-spectrum model PK; `zel` wraps `gen`
# and works on the coarser `mesh` with scheme 'cic' (presumably cloud-in-cell
# mass assignment -- confirm against Zeldovich1D).
gen = GaussianFieldGenerator1D(grid, PK)
zel = Zeldovich1D(gen, mesh, scheme = 'cic')

In [4]:
# Grid of C vectors to scan: Nx points in x and Ny points in y.
Nx = 10000
Ny = 1

x_min, x_max = -1, 1
y_min, y_max = 0, 0

x_vec = jnp.linspace(x_min, x_max, Nx)
y_vec = jnp.linspace(y_min, y_max, Ny)

# Build a meshgrid and flatten it.
# With indexing='xy', X and Y each have shape (Ny, Nx).
X, Y = jnp.meshgrid(x_vec, y_vec, indexing='xy')

# xs_flat, ys_flat each have shape (Ny*Nx,).
xs_flat = X.ravel()
ys_flat = Y.ravel()

# Stack into a 2D array of C vectors, shape (Ny*Nx, 4).
# Each row = [1/5, x_i, y_i, 0]: the first component is fixed at 1/5 for every row.
ones_flat = jnp.ones_like(xs_flat)
zeros_flat = jnp.zeros_like(xs_flat)
C_flat = jnp.stack([ones_flat / 5, xs_flat, ys_flat, zeros_flat], axis=1)

In [6]:
# Fiducial parameters theta = [A, R, n] of the power-spectrum model PK.
A = 2.
n = 2.
R = 4 * grid.H[0]  # 4 x grid.H[0] (presumably the cell size -- confirm)
theta = jnp.array([A, R, n])

D = 1  # growth factor

# Number of C vectors handled per call of the batched Fisher function.
BATCH_SIZE = 100

def compute_inv_cov(x):
    """Inverse of the unbiased sample covariance of ``x``.

    Rows of ``x`` are samples and columns are variables. The covariance is
    normalised by (n_samples - 1).

    NOTE(review): the inverse of a sample covariance is a biased estimate of the
    precision matrix; a Hartlap-style factor (n_s - n_d - 2) / (n_s - 1) may be
    wanted when the number of samples is not much larger than the number of bins.
    """
    n_samples = x.shape[0]
    residuals = x - x.mean(axis=0)
    sample_cov = residuals.T @ residuals / (n_samples - 1)
    return jnp.linalg.inv(sample_cov)


# Batched version: maps compute_inv_cov over the leading axis.
compute_inv_cov_batch = jax.vmap(compute_inv_cov)

def compute_batch(theta,D,C):
    """Determinant of the Fisher matrix for each C vector in a batch.

    For the batch ``C`` this runs ``zel.make_realization_batch`` on the global
    ``noise`` realizations, inverts the sample covariance of the resulting
    statistic (per leading-axis entry), differentiates the mean statistic with
    respect to ``theta``, and returns det(J^T Cov^{-1} J).

    Args:
      theta: parameters [A, R, n]; the Jacobian is taken with respect to these.
      D: growth factor forwarded to ``zel.make_realization_batch``.
      C: batch of C vectors; one determinant is returned per leading-axis entry.

    Returns:
      Array of det F, one per entry of ``C``.
    """
    def mean_and_realizations(theta_):
        pk_batch = zel.make_realization_batch(D, theta_, 0, C, noise)
        # Mean over axis 1 (presumably the realization axis -- confirm).
        # pk_batch is returned as aux so the forward simulation, which the
        # covariance also needs, is run only once instead of twice.
        return jnp.mean(pk_batch, axis=1), pk_batch

    J, pk_batch = jax.jacrev(mean_and_realizations, has_aux=True)(theta)
    inv_covs = compute_inv_cov_batch(pk_batch)
    # F[n, p, q] = sum_ij J[n, i, p] * Cinv[n, i, j] * J[n, j, q]
    F = jnp.einsum('nij,nip,njq->npq', inv_covs, J, J)
    return jnp.linalg.det(F)

results_batch = jax.jit(compute_batch)

In [7]:
# First call traces and compiles `results_batch` for a BATCH_SIZE slice; blocking
# makes the cell finish only once the determinants have actually been computed.
F = results_batch(theta, D, C_flat[:BATCH_SIZE]).block_until_ready()

: 

In [7]:
# Number of BATCH_SIZE slices needed to cover every row of C_flat. Ceiling
# division keeps a final partial batch instead of silently dropping it.
STEPS = (C_flat.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE

In [None]:
# Evaluate det F for every C vector, one BATCH_SIZE slice at a time.
# Stepping by BATCH_SIZE also covers a final partial batch, so no trailing rows
# are left unset when C_flat.shape[0] is not a multiple of BATCH_SIZE
# (a shorter tail slice triggers one extra compilation of results_batch).
det_chunks = []
for start in tqdm(range(0, C_flat.shape[0], BATCH_SIZE)):
    det_chunks.append(results_batch(theta, D, C_flat[start:start + BATCH_SIZE]))
# Concatenate once rather than copying the full array on every iteration.
Fs = jnp.concatenate(det_chunks) if det_chunks else jnp.zeros(C_flat.shape[0])

In [None]:
# det F across the scanned C vectors, plotted against their x component
# (C[:, 1]); when Ny = 1 this is a single sweep in x.
fig, ax = plt.subplots()
ax.plot(xs_flat, Fs)
ax.set(xlabel="x  (C[:, 1])", ylabel="det F", title="Fisher determinant vs x");

In [13]:
def improved_compute_all_dets(theta, D, C_flat, batch_size, compute_batch_fn):
    """
    Compute determinants for all of C_flat by scanning over fixed-size batches.

    Full batches of ``batch_size`` rows go through ``jax.lax.scan``, so
    ``compute_batch_fn`` is traced once for the whole sweep. Leftover rows
    (when ``C_flat.shape[0]`` is not a multiple of ``batch_size``) are handled
    by one extra call on the shorter tail batch instead of failing in reshape.

    Args:
      theta:      array of shape (3,)   (fiducial [A,R,n])
      D:          scalar               (growth factor)
      C_flat:     array (N_total, ...) (all C rows stacked; any trailing row shape)
      batch_size: positive int (size of each minibatch); must be static under jit
      compute_batch_fn:  function
         signature: (theta, D, C_batch) -> shape (len(C_batch),)
         It must also accept a tail batch with fewer than batch_size rows.

    Returns:
      dets_all:   array (N_total,) containing det F for each row of C_flat.

    Raises:
      ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    N_total = C_flat.shape[0]
    row_shape = C_flat.shape[1:]
    num_batches, remainder = divmod(N_total, batch_size)
    n_full = num_batches * batch_size

    parts = []
    if num_batches > 0:
        # (num_batches, batch_size, *row_shape) view of the full batches only.
        C_batches = C_flat[:n_full].reshape((num_batches, batch_size) + row_shape)

        # No state is carried between batches; carry stays None.
        def scan_body(carry, C_batch):
            return carry, compute_batch_fn(theta, D, C_batch)

        # dets_stacked has shape (num_batches, batch_size).
        _, dets_stacked = jax.lax.scan(scan_body, None, C_batches)
        parts.append(dets_stacked.reshape((n_full,)))

    if remainder > 0:
        parts.append(compute_batch_fn(theta, D, C_flat[n_full:]))

    if not parts:
        return jnp.zeros((0,))
    if len(parts) == 1:
        return parts[0]
    return jnp.concatenate(parts)

In [14]:
# jit the scan driver. batch_size fixes array shapes and compute_batch_fn is a
# Python callable, so both are marked static.
improved_compute_all_dets_jit = jax.jit(
    improved_compute_all_dets,
    static_argnames=("batch_size", "compute_batch_fn"),
)

# Fs_all has shape (N_total,): one det F per row of C_flat.
Fs_all = improved_compute_all_dets_jit(
    theta, D, C_flat, BATCH_SIZE, results_batch
)

In [None]:
# Compact view of the result instead of echoing all N_total determinants.
Fs_all.shape, Fs_all[:5]

In [16]:
# Re-run the jitted sweep and wait for the device computation to finish
# (same shapes and static args as the earlier call, so no new compilation is expected).
fs_pending = improved_compute_all_dets_jit(theta, D, C_flat, BATCH_SIZE, results_batch)
Fs_all = fs_pending.block_until_ready()

In [17]:
# Same sweep without the outer jax.jit wrapper (lax.scan still compiles its body);
# presumably kept for a timing comparison with the jitted call.
fs_pending_eager = improved_compute_all_dets(theta, D, C_flat, BATCH_SIZE, results_batch)
Fs_all = fs_pending_eager.block_until_ready()