# AI-Driven Optimization: Chaotic Advection & Stochastic Resonance

**Objective:** Optimize the parameters of a **Chaotic & Stochastic Injection Schedule**.

**Physics Principles:**
1.  **Stochastic Resonance (The "Noise" Effect):** 
    Dislodging salt bridges or overcoming capillary entry pressures often requires crossing an energy barrier. A constant pressure might fail, but adding **noise** (random pressure spikes) allows the system to probabilistically sample higher pressures, "jiggling" particles loose and sustaining transport without high average energy.
    * *Model:* $P(t) = P_{base} + A_{noise} \cdot \xi(t)$, where $\xi(t)$ is a fixed draw of standard Gaussian noise; the resulting $P(t)$ is clipped to $[0, 0.2]$.

2.  **Chaotic Advection (Rotated Potential Mixing):** 
    In laminar flows, mixing is poor (diffusion-limited). By alternating the injection location (e.g., between Top and Bottom perforation zones), we induce **Chaotic Advection**, stretching and folding the fluid to dilute salt concentrations rapidly.
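    * *Model:* The injection center oscillates between $Y=12$ (bottom) and $Y=37$ (top), completing $f_{switch}$ full cycles over the $T$ time steps of a run: $Y_{in}(t) = 24.5 + 12.5\,\tanh\big(10\sin(2\pi f_{switch} t / T)\big)$.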

**Optimization Parameters $\theta$:**
1.  `base`: Baseline Pressure.
2.  `noise_amp`: Amplitude of the stochastic noise (finding the resonance peak).
3.  `switch_freq`: Frequency of zonal alternation, in full Top/Bottom cycles per simulation run (finding the chaotic mixing rate).

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, lax, value_and_grad, checkpoint
import optax
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Enable 64-bit precision. This must run before any array is created so the
# lattice constants below are built as float64 / int64.
jax.config.update("jax_enable_x64", True)

# --- 1. Physical Constants (SAME CONFIG) ---
# Grid size (x, y) and number of simulated time steps, in lattice units.
NX, NY = 100, 50
TIME_STEPS = 1000

# Relaxation times of the two fluids; the local tau is interpolated between
# them according to the local density.
TAU_BRINE = 1.0
TAU_CO2 = 0.9          # Stable Viscosity
G_INT = -1.0           # Stable Interaction
# Reference densities of the resident brine and of the injected CO2.
RHO_BRINE = 1.0
RHO_CO2_INIT = 0.1
D_SALT = 0.05          # salt diffusion coefficient
K_SP = 1.1             # salt concentration at which precipitation switches on

# D2Q9 lattice: one rest direction, four axis-aligned, four diagonal.
_D2Q9_VELOCITIES = (
    (0, 0),
    (1, 0), (0, 1), (-1, 0), (0, -1),
    (1, 1), (-1, 1), (-1, -1), (1, -1),
)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)
CX = jnp.array([cx for cx, _ in _D2Q9_VELOCITIES])
CY = jnp.array([cy for _, cy in _D2Q9_VELOCITIES])

print("Environment Configured: Chaotic & Stochastic Mode")

## 2. Modified Physics Kernels (Zonal Injection)
We modify the LBM step (`lbm_step_chaotic`) to accept a **spatial injection center** $Y_{center}$. A soft Gaussian window (width $\sigma = 8$ cells) around $Y_{center}$ lets us inject fluid into a specific zone (Top vs Bottom) rather than across the entire inlet face, enabling Rotated Potential Mixing.

In [None]:
@jit
def get_equilibrium(rho, u_x, u_y):
    """Return the second-order D2Q9 equilibrium populations.

    A trailing axis is appended to ``rho``, ``u_x`` and ``u_y`` so that they
    broadcast against the 9 lattice weights/velocities. The result therefore
    has the broadcast shape of the inputs plus a final axis of length 9
    (``(9,)`` for scalar inputs).
    """
    ux = jnp.expand_dims(u_x, axis=-1)
    uy = jnp.expand_dims(u_y, axis=-1)
    density = jnp.expand_dims(rho, axis=-1)
    speed_sq = jnp.expand_dims(u_x**2 + u_y**2, axis=-1)
    # Projection of the fluid velocity onto each lattice direction, c_i . u
    c_dot_u = (CX * ux + CY * uy)
    return density * W * (1.0 + 3.0*c_dot_u + 4.5*c_dot_u**2 - 1.5*speed_sq)

@jit
def interaction_force(rho, G):
    """Pseudo-potential interaction force on a periodic grid.

    The pseudo-potential is ``psi = 1 - exp(-rho)``, with ``rho`` first clipped
    to ``[1e-3, 5.0]``. Along each axis the force is
    ``-G * psi * (psi[i+1] - psi[i-1])``, using periodic neighbours.

    Returns:
        ``(fx, fy)``: force components along axis 0 and axis 1, each with the
        same shape as ``rho``.
    """
    psi = 1 - jnp.exp(-jnp.clip(rho, 1e-3, 5.0))

    def neighbour_difference(axis):
        # psi at +1 minus psi at -1 along `axis` (jnp.roll wraps around).
        return jnp.roll(psi, -1, axis=axis) - jnp.roll(psi, 1, axis=axis)

    coupling = -G * psi
    return coupling * neighbour_difference(0), coupling * neighbour_difference(1)

@jit
def collision_stream(f, salt_conc, mask, tau):
    """One BGK collision followed by periodic streaming on the D2Q9 lattice.

    Args:
        f: populations with the 9 lattice directions on the last axis.
        salt_conc: salt concentration; values above ``K_SP`` damp the velocity
            through a sigmoid ``precipitation`` factor.
        mask: solid/precipitate indicator; where it is 1 the pre-collision
            populations ``f`` are kept instead of the streamed ones.
        tau: relaxation time, broadcast against ``f`` after adding a trailing
            axis (presumably per-cell, same shape as ``rho`` -- confirm).

    Returns:
        ``(f_streamed, rho, u_x, u_y)``. ``rho`` and the velocities are the
        pre-streaming macroscopic fields used to build the equilibrium.

    NOTE(review): the mask blend is not a bounce-back wall. Populations that
    stream into a solid cell are overwritten by that cell's old ``f``, so mass
    is not conserved across solid boundaries -- confirm this is intended.
    """
    rho = jnp.sum(f, axis=-1)
    rho_floor = jnp.maximum(rho, 1e-3)

    # Bare velocity (bounded), then the Shan-Chen-style velocity shift and the
    # precipitation damping, applied in that order.
    vel_x = jnp.clip(jnp.sum(f * CX, axis=-1) / rho_floor, -0.4, 0.4)
    vel_y = jnp.clip(jnp.sum(f * CY, axis=-1) / rho_floor, -0.4, 0.4)
    force_x, force_y = interaction_force(rho, G_INT)
    openness = 1.0 - jax.nn.sigmoid(10 * (salt_conc - K_SP))
    vel_x = (vel_x + force_x / rho_floor) * openness
    vel_y = (vel_y + force_y / rho_floor) * openness

    # BGK relaxation towards equilibrium.
    relaxed = f - (f - get_equilibrium(rho, vel_x, vel_y)) / tau[..., None]

    # Periodic streaming: shift each direction by its lattice velocity.
    streamed = jnp.stack(
        [jnp.roll(relaxed[..., k], (CX[k], CY[k]), axis=(0, 1)) for k in range(9)],
        axis=-1,
    )

    solid = mask[..., None]
    blended = streamed * (1 - solid) + f * solid
    return jnp.nan_to_num(blended, nan=0.0), rho, vel_x, vel_y

@checkpoint
def lbm_step_chaotic(carry, inputs):
    """Advance the coupled LBM / salt / precipitation model by one time step.

    Designed as the body of ``lax.scan`` (see ``run_simulation``).

    Args:
        carry: ``(f, salt, mask)`` -- populations indexed ``[x, y, direction]``,
            plus salt concentration and solid/precipitate mask indexed ``[x, y]``.
        inputs: ``(pressure_in, y_center_in)`` -- inlet pressure for this step
            (clipped to [0, 0.5] and added to the CO2 inlet density) and the
            y-coordinate of the centre of the active injection zone.

    Returns:
        ``((f_new, salt_new, mask_new), s_co2)``, where ``s_co2`` is the
        fraction of open pore volume (weighted by ``1 - mask_new``) whose
        density is below the brine/CO2 midpoint.
    """
    pressure_in, y_center_in = inputs
    f, salt, mask = carry

    # --- 1. DYNAMIC INLET BC (Chaotic Advection) ---
    # A soft Gaussian window (sigma = 8 cells) around y_center_in selects the
    # active injection zone on the inlet face x = 0. The soft window keeps the
    # step differentiable with respect to y_center_in.
    y_grid = jnp.arange(NY)
    inj_window = jnp.exp(-((y_grid - y_center_in)**2) / (2 * 8.0**2))
    # Normalize the peak to ~1.0 so the zone centre is fully overwritten.
    inj_window = inj_window / (jnp.max(inj_window) + 1e-6)

    # Inlet equilibrium: CO2 density raised by the clipped pressure, moving at
    # u_x = 0.1 towards +x. Scalar inputs give a (9,) population vector.
    p_safe = jnp.clip(pressure_in, 0.0, 0.5)
    rho_inlet = RHO_CO2_INIT + p_safe * 1.0
    f_eq_inlet = get_equilibrium(rho_inlet, 0.1, 0.0)

    # Blend f[0] = f_eq * window + f_old * (1 - window): where the window is
    # high CO2 is forced in, where it is low the existing populations remain.
    f_slice = f[0, :, :]
    # (NY,) -> (NY, 1); broadcasts over the 9 directions.
    w_exp = jnp.expand_dims(inj_window, -1)
    f_new_slice = f_eq_inlet * w_exp + f_slice * (1.0 - w_exp)
    f = f.at[0, :, :].set(f_new_slice)

    # --- 2. RELAXATION ---
    # tau interpolates linearly from TAU_CO2 (rho = RHO_CO2_INIT) to TAU_BRINE
    # (rho = RHO_BRINE); floored at 0.52 because BGK requires tau > 0.5.
    rho_local = jnp.sum(f, axis=-1)
    tau_eff = TAU_CO2 + (TAU_BRINE - TAU_CO2) * (rho_local - RHO_CO2_INIT)/(RHO_BRINE - RHO_CO2_INIT)
    tau_safe = jnp.maximum(tau_eff, 0.52)

    # --- 3. SOLVE ---
    f_new, rho, ux, uy = collision_stream(f, salt, mask, tau_safe)

    # --- 4. SALT ---
    # Explicit advection-diffusion step with central differences (periodic).
    ux_s = jnp.clip(ux, -0.2, 0.2)
    uy_s = jnp.clip(uy, -0.2, 0.2)
    grad_salt_x = (jnp.roll(salt, -1, axis=0) - jnp.roll(salt, 1, axis=0)) / 2.0
    grad_salt_y = (jnp.roll(salt, -1, axis=1) - jnp.roll(salt, 1, axis=1)) / 2.0
    laplacian = (jnp.roll(salt, -1, axis=0) + jnp.roll(salt, 1, axis=0) +
                 jnp.roll(salt, -1, axis=1) + jnp.roll(salt, 1, axis=1) - 4*salt)
    salt_new = salt + (-(ux_s * grad_salt_x + uy_s * grad_salt_y) + D_SALT * laplacian)

    # BC: fresh CO2 at the inlet carries no salt, but only inside the active
    # injection zone (same window as the flow BC).
    salt_inlet = salt_new[0, :] * (1.0 - inj_window)
    salt_new = salt_new.at[0, :].set(salt_inlet)
    salt_new = jnp.clip(salt_new, 0.0, 5.0)

    # --- 5. PRECIPITATION ---
    # The mask can only grow: precipitated cells never re-dissolve.
    new_precip = jax.nn.sigmoid(20 * (salt_new - K_SP))
    mask_new = jnp.maximum(mask, new_precip)

    # --- 6. CO2 SATURATION (straight-through indicator) ---
    # A hard `rho < threshold` test has zero derivative everywhere, which would
    # leave the optimiser with no gradient through the fluid distribution
    # (only through the precipitation mask). The forward pass uses the hard
    # indicator unchanged; the backward pass flows through a sigmoid surrogate.
    # The surrogate sees a NaN-free copy of rho so it can never poison the sum.
    rho_threshold = (RHO_BRINE + RHO_CO2_INIT) / 2.0
    surrogate_width = 0.05
    hard_co2 = (rho < rho_threshold).astype(rho.dtype)
    soft_co2 = jax.nn.sigmoid((rho_threshold - jnp.nan_to_num(rho)) / surrogate_width)
    is_co2 = soft_co2 + lax.stop_gradient(hard_co2 - soft_co2)

    pore_vol = jnp.sum(1-mask_new) + 1e-6
    s_co2 = jnp.sum(is_co2 * (1-mask_new)) / pore_vol

    return (jnp.nan_to_num(f_new), salt_new, mask_new), s_co2

def run_simulation(inputs, initial_state):
    """Run ``lbm_step_chaotic`` once per entry of the input schedules.

    Args:
        inputs: ``(pressure_schedule, y_schedule)``, as produced by
            ``get_chaotic_schedule``. ``lax.scan`` consumes one element of
            each per time step.
        initial_state: ``(f, salt, mask)`` carry for the first step.

    Returns:
        ``(final_state, s_co2_history)``: the last carry and the per-step
        CO2 saturation stacked along the leading axis.
    """
    final_state, saturation_history = lax.scan(lbm_step_chaotic, initial_state, inputs)
    return final_state, saturation_history

## 3. Stochastic & Chaotic Parametric Optimization
We optimize **3 parameters**:
1. `base`: Baseline pressure.
2. `noise_amp`: Amplitude of the **Stochastic Resonance** (random noise).
3. `switch_freq`: Frequency of the **Chaotic Advection** (switching zones).

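**Note:** Random noise itself cannot be differentiated, so we pre-generate a fixed Gaussian noise pattern $\xi(t)$ and optimize only its scaling factor $A_{noise}$. $P(t)$ is then a deterministic, differentiable function of $A_{noise}$; this is the "Reparameterization Trick". Because $\xi(t)$ is fixed, the optimized parameters are tuned to this particular noise realization.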

In [None]:
# Pre-generate fixed noise for the reparameterization trick: one seeded
# standard-normal sample xi(t) per time step, reused by every loss evaluation
# so the schedule is a deterministic function of the noise amplitude.
NOISE_SEED = 99
noise_key = jax.random.PRNGKey(NOISE_SEED)
FIXED_NOISE = jax.random.normal(key=noise_key, shape=(TIME_STEPS,))

def get_chaotic_schedule(params, noise=None):
    """Generate the pressure schedule P(t) and injection-centre schedule Y(t).

    Args:
        params: ``(base, noise_amp, switch_freq)``. ``switch_freq`` is the
            number of full Top/Bottom switching cycles over the whole schedule.
        noise: optional 1-D array xi(t) used for the stochastic pressure term.
            Defaults to the module-level ``FIXED_NOISE``. Its length sets the
            number of time steps, so other noise realisations (or horizons)
            can be evaluated without touching global state.

    Returns:
        ``(pressure, y_center)``: two 1-D arrays with one entry per time step.
        ``pressure = clip(base + noise_amp * xi, 0, 0.2)``; ``y_center`` lies
        between 12 (bottom) and 37 (top).
    """
    if noise is None:
        noise = FIXED_NOISE
    noise = jnp.asarray(noise)

    base, noise_amp, switch_freq = params
    n_steps = noise.shape[0]
    t = jnp.arange(n_steps)

    # 1. Stochastic pressure: P(t) = base + noise_amp * xi(t), safety-clipped.
    pressure = jnp.clip(base + noise_amp * noise, 0.0, 0.2)

    # 2. Chaotic switching of the injection location. tanh(10 * sin) is a
    # smooth (differentiable) square wave with `switch_freq` cycles per run.
    omega = 2 * jnp.pi * switch_freq / n_steps
    switch_signal = jnp.tanh(10 * jnp.sin(omega * t))

    # Map [-1, 1] onto [12, 37]: centre (12 + 37) / 2 = 24.5, half-range 12.5.
    y_center = 24.5 + 12.5 * switch_signal

    return pressure, y_center

def loss_fn_chaotic(params, initial_state, lambda_energy=0.05):
    """Scalar training loss for the chaotic/stochastic injection schedule.

    ``loss = (1 - S_final)**2 + lambda_energy * mean(P(t)**2)``, where
    ``S_final`` is the CO2 saturation after the last simulated step and
    ``P(t)`` is the (clipped) pressure schedule.

    Args:
        params: ``(base, noise_amp, switch_freq)``; differentiated by
            ``value_and_grad`` (first positional argument).
        initial_state: ``(f, salt, mask)`` initial carry for the simulation.
        lambda_energy: weight of the pressure-energy penalty (default 0.05,
            the previously hard-coded value).
    """
    # 1. Generate inputs
    pressure_sched, y_sched = get_chaotic_schedule(params)

    # 2. Run simulation (the final state is not needed for the loss)
    _, s_hist = run_simulation((pressure_sched, y_sched), initial_state)

    # 3. Performance term plus energy penalty
    perf_loss = (1.0 - s_hist[-1])**2
    energy_loss = jnp.mean(pressure_sched**2)

    return perf_loss + (lambda_energy * energy_loss)

# --- Initialize ---
# Seeded random porous medium: each cell is solid with probability 0.25.
# The first and last five x-columns are forced open (inlet/outlet buffers).
key = jax.random.PRNGKey(55)
solid_draw = jax.random.bernoulli(key, p=0.25, shape=(NX, NY))
mask_init = solid_draw.astype(jnp.float64).at[:5, :].set(0.0).at[-5:, :].set(0.0)

# Start fully brine-saturated and at rest, with a uniform background salt level.
rho_init = RHO_BRINE * jnp.ones((NX, NY))
u_init = jnp.zeros((NX, NY))
f_init = get_equilibrium(rho_init, u_init, u_init)
salt_init = 0.5 * jnp.ones((NX, NY))
state_init = (f_init, salt_init, mask_init)

# --- Optimizer Setup ---
# Initial guess:
#   base        = 0.02   baseline pressure
#   noise_amp   = 0.005  small initial noise
#   switch_freq = 2.0    two full Top/Bottom cycles over the run
init_params = jnp.array([0.02, 0.005, 2.0])

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(init_params)
grad_fn = jit(value_and_grad(loss_fn_chaotic))

print("Starting Parametric Optimization (Chaotic & Stochastic)...")
print(f"{'Epoch':<6} | {'Loss':<10} | {'Base':<8} | {'NoiseAmp':<8} | {'SwitchFreq':<8}")
print("-" * 65)

params = init_params
loss_history = []
param_history = []  # params at which each loss_history entry was evaluated

for epoch in range(51):
    # The loss and gradient are evaluated at the *current* params, i.e.
    # before this epoch's update.
    loss_val, grads = grad_fn(params, state_init)
    loss_history.append(loss_val)
    param_history.append(params)

    if epoch % 5 == 0:
        # Report the parameters the loss was actually evaluated at.
        b, n, freq = params
        print(f"{epoch:<6} | {loss_val:.5f}    | {b:.4f}   | {n:.4f}     | {freq:.4f}")

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Keep parameters non-negative by reflecting them across zero.
    params = jnp.abs(params)

print("Optimization Complete.")

## 4. Results & Comparison
We visualize the noisy pressure schedule and the effect of Chaotic Advection (switching zones). When comparing against earlier full-face baselines, keep in mind that this method injects through one zone at a time: this reduces the effective inlet area but enhances mixing.

In [None]:
# Generate final optimal schedules
opt_p_schedule, opt_y_schedule = get_chaotic_schedule(params)
inputs = (opt_p_schedule, opt_y_schedule)

# Run Validations
final_state_opt, s_hist_opt = run_simulation(inputs, state_init)

# Baseline: the same zonal kernel with Noise=0 and Freq=0. This means a
# constant pressure at the initial base value (0.02), injected through a
# static zone centred at Y = 24.5. Using the same kernel as the optimized run
# keeps the comparison like-for-like (no full-face injection is simulated here).
base_params = jnp.array([0.02, 0.0, 0.0])
base_p, base_y = get_chaotic_schedule(base_params)
final_state_base, s_hist_base = run_simulation((base_p, base_y), state_init)

# Plotting
fig, ax = plt.subplots(1, 3, figsize=(20, 6))

# 1. Strategy Inputs
ax[0].plot(opt_p_schedule, color='magenta', alpha=0.6, label='Stochastic Pressure')
# Map the injection centre from [0, NY] onto the pressure range [0, 0.2]
# (0.2 is the schedule's pressure clip) so both curves share one axis.
ax[0].plot(opt_y_schedule / NY * 0.2, color='blue', label='Injection Zone (Scaled)')
ax[0].set_title("Chaotic/Stochastic Inputs")
ax[0].set_ylabel("Pressure / Location")
ax[0].set_xlabel("Time Step")
ax[0].legend(loc='upper right')
ax[0].grid(True, alpha=0.3)

# 2. Efficiency
ax[1].plot(s_hist_opt, color='magenta', linewidth=2, label='Optimized')
ax[1].plot(s_hist_base, 'k--', label='Static Zonal Baseline')
ax[1].set_title(f"Saturation (Final S={s_hist_opt[-1]:.2f})")
ax[1].set_ylabel("Saturation")
ax[1].set_xlabel("Time Step")
ax[1].grid(True, alpha=0.3)
ax[1].legend()

# 3. Fluid Map: density with solid/precipitated cells (mask > 0.5) hidden.
f_final, _, mask_final = final_state_opt
rho_final = jnp.sum(f_final, axis=-1)
rho_masked = np.ma.masked_where(mask_final > 0.5, rho_final)

# Fields are indexed [x, y]; transpose so x runs horizontally.
ax[2].imshow(jnp.transpose(mask_final), cmap='Greys', origin='lower', alpha=0.4)
im = ax[2].imshow(rho_masked.T, cmap='RdBu', origin='lower', vmin=0.1, vmax=1.1, alpha=0.9)
ax[2].set_title("Final Fluid Distribution")
plt.colorbar(im, ax=ax[2])

plt.tight_layout()
plt.show()

## 5. Appendices for Conference Paper

In [None]:
# Loss at the final parameters. The training loop's last `loss_val` was
# evaluated *before* the last optimizer update, so it does not correspond to
# the `params` reported below.
final_loss = jit(loss_fn_chaotic)(params, state_init)
s_opt = float(s_hist_opt[-1])
s_base = float(s_hist_base[-1])

print("="*50)
print("       APPENDIX B: OPTIMIZED RESULTS SUMMARY")
print("="*50)
print("Optimized Parameter Values:")
print(f"   1. Base Pressure:       {params[0]:.6f} (Lattice Units)")
print(f"   2. Noise Amplitude:     {params[1]:.6f} (Stochastic Resonance)")
# switch_freq counts full Top/Bottom cycles per TIME_STEPS-long run; it is not Hz.
print(f"   3. Switch Frequency:    {params[2]:.6f} cycles/run (Chaotic Advection)")
print("-"*50)
print("Performance Metrics:")
print(f"   Optimized Saturation:   {s_opt:.4f} ({(s_opt*100):.2f}%)")
print(f"   Final Loss Value:       {final_loss:.6f}")
print(f"   Static Zonal Baseline:  {s_base:.4f} ({(s_base*100):.2f}%)")
if s_base > 0:
    gain = ((s_opt - s_base) / s_base) * 100
    # '+' in the format spec prints the correct sign for gains and losses alike.
    print(f"   Efficiency Gain:        {gain:+.2f}%")
else:
    # Relative gain is undefined when the baseline saturation is zero.
    gain = float("nan")
    print("   Efficiency Gain:        n/a (baseline saturation is zero)")
print("="*50)