In [4]:
"""
=================================================================================
CYCLIC VOLTAMMETRY SIMULATION: FINITE DIFFERENCE AND PINN
=================================================================================
This code simulates reversible redox reactions (O + e- ⇌ R) at an electrode
using both finite difference (ground truth) and physics-informed neural networks.

KEY FEATURE: Learnable penetration depth ℓp for each species network

Three systems with different bulk concentration ratios:
  System 1: c_O/c_R = 100  (oxidized-rich)
  System 2: c_O/c_R = 1    (balanced)
  System 3: c_O/c_R = 0.01 (reduced-rich)
=================================================================================
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# =============================================================================
# SECTION 1: PHYSICAL CONSTANTS AND PARAMETERS
# =============================================================================

print("="*80)
print("CYCLIC VOLTAMMETRY SIMULATION WITH LEARNABLE PENETRATION DEPTH")
print("="*80)

# Physical constants
F = 96485.0          # Faraday constant [C/mol]
R = 8.314            # Gas constant [J/(mol·K)]
T = 298.0            # Temperature [K]

# Potential waveform (triangular wave)
E_start = 0.2        # Initial potential [V]
E_vertex = -0.8      # Switching potential [V]
E_end = 0.2          # Final potential [V]
E_0 = -0.4           # Standard potential [V]
nu = 10.0            # Scan rate [mV/s]

# Transport and kinetics
D = 1e-6             # Diffusion coefficient [cm²/s]
n = 1                # Number of electrons
alpha = 0.5          # Transfer coefficient
k_0 = 1e-1           # Standard rate constant [cm/s]
A = 1.0              # Electrode area [cm²]

# Domain parameters
L = 0.1              # Diffusion layer thickness [cm]
npts_x = 100         # Spatial grid points
npts_t = 500         # Temporal grid points

# Three systems with different concentration ratios
systems = [
    {'name': 'System 1', 'c_O': 1.0, 'c_R': 1e-2, 'ratio': 100},
    {'name': 'System 2', 'c_O': 1.0, 'c_R': 1.0,  'ratio': 1},
    {'name': 'System 3', 'c_O': 1e-2, 'c_R': 1.0, 'ratio': 0.01},
]

# Derived parameters
sweep_rate_V_per_s = nu / 1000.0
t_forward = abs(E_vertex - E_start) / sweep_rate_V_per_s
t_return = abs(E_end - E_vertex) / sweep_rate_V_per_s
T_final = t_forward + t_return
t_switch = t_forward

print(f"\nSimulation Parameters:")
print(f"  Scan rate: {nu} mV/s")
print(f"  Total time: {T_final:.2f} s")
print(f"  Switching time: {t_switch:.2f} s")

# =============================================================================
# SECTION 2: FINITE DIFFERENCE SIMULATION
# =============================================================================

print("\n" + "="*80)
print("FINITE DIFFERENCE SIMULATION")
print("="*80)

# Grid setup
del_x = L / (npts_x - 1)
del_t = T_final / (npts_t - 1)
x_grid = np.linspace(0, L, npts_x)
t_grid = np.linspace(0, T_final, npts_t)

# Stability check
lambda_fd = D * del_t / del_x**2
print(f"\nStability parameter λ = {lambda_fd:.4f} (should be ≤ 0.5)")
if lambda_fd > 0.5:
    print("  WARNING: Unstable!")
else:
    print("  ✓ Stable")

def E_of_t_np(tt):
    """Potential waveform"""
    if tt <= t_switch:
        return E_start - sweep_rate_V_per_s * tt
    else:
        return E_vertex + sweep_rate_V_per_s * (tt - t_switch)

def run_finite_difference(c_O_bulk, c_R_bulk, system_name):
    """Run FD simulation"""
    print(f"\n  {system_name}...")
    
    c_R = np.full((npts_x, npts_t), c_R_bulk, dtype=float)
    c_O = np.full((npts_x, npts_t), c_O_bulk, dtype=float)
    E_t = np.zeros(npts_t)
    k_red_t = np.zeros(npts_t)
    k_ox_t = np.zeros(npts_t)
    i_t = np.zeros(npts_t)
    
    E_eq = E_0 + (R * T / (n * F)) * np.log(c_O_bulk / c_R_bulk)
    M_to_mol_cm3 = 1e-3
    
    for j in range(npts_t):
        E_curr = E_of_t_np(t_grid[j])
        E_t[j] = E_curr
        eta = E_curr - E_eq
        k_red = k_0 * np.exp((-alpha * n * F * eta) / (R * T))
        k_ox = k_0 * np.exp(((1.0 - alpha) * n * F * eta) / (R * T))
        k_red_t[j] = k_red
        k_ox_t[j] = k_ox
        
        if j > 0:
            c_R_prev = c_R[:, j-1]
            c_O_prev = c_O[:, j-1]
            
            c_R[1:-1, j] = c_R_prev[1:-1] + lambda_fd * (c_R_prev[2:] - 2*c_R_prev[1:-1] + c_R_prev[:-2])
            c_O[1:-1, j] = c_O_prev[1:-1] + lambda_fd * (c_O_prev[2:] - 2*c_O_prev[1:-1] + c_O_prev[:-2])
            
            c_R[-1, j] = c_R_bulk
            c_O[-1, j] = c_O_bulk
            
            c_R_pred = c_R_prev[0] + lambda_fd * (c_R_prev[1] - c_R_prev[0])
            c_O_pred = c_O_prev[0] + lambda_fd * (c_O_prev[1] - c_O_prev[0])
            
            S = c_R_pred + c_O_pred
            k_sum = k_red + k_ox
            if k_sum > 0:
                C_R_eq = (k_red / k_sum) * S
                c_R[0, j] = C_R_eq + (c_R_pred - C_R_eq) * np.exp(-k_sum * del_t)
                c_O[0, j] = S - c_R[0, j]
            else:
                c_R[0, j] = c_R_pred
                c_O[0, j] = c_O_pred
            
            c_R[0, j] = max(c_R[0, j], 0.0)
            c_O[0, j] = max(c_O[0, j], 0.0)
            
            r_net = k_red * c_O[0, j] - k_ox * c_R[0, j]
            i_t[j] = n * F * A * M_to_mol_cm3 * r_net
    
    return {
        'c_R': c_R, 'c_O': c_O,
        'E_t': E_t, 'k_red_t': k_red_t, 'k_ox_t': k_ox_t, 'i_t': i_t,
        'E_eq': E_eq
    }

fd_results = []
for sys in systems:
    result = run_finite_difference(sys['c_O'], sys['c_R'], sys['name'])
    result['system'] = sys
    fd_results.append(result)

# =============================================================================
# SECTION 3: NEURAL NETWORK WITH LEARNABLE PENETRATION DEPTH
# =============================================================================

print("\n" + "="*80)
print("PINN WITH LEARNABLE PENETRATION DEPTH")
print("="*80)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {DEVICE}")

class PINN_Oxidized(nn.Module):
    """NN for oxidized species with LEARNABLE penetration depth"""
    def __init__(self, width=50, depth=5, l_p_init=3.5e-3):
        super().__init__()
        layers = [nn.Linear(2, width), nn.SiLU()]
        for _ in range(depth - 2):
            layers += [nn.Linear(width, width), nn.SiLU()]
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)
        
        # LEARNABLE penetration depth (initialized with reasonable value)
        self.log_l_p = nn.Parameter(torch.tensor(np.log(l_p_init), dtype=torch.float32))
    
    def forward(self, xt):
        x = xt[:, 0:1]
        l_p = torch.exp(self.log_l_p)  # Ensure positivity via exp
        recovery = 1.0 - torch.exp(-x / l_p)
        C = torch.nn.functional.softplus(self.net(xt))
        return C * recovery
    
    def get_l_p(self):
        return torch.exp(self.log_l_p).item()

class PINN_Reduced(nn.Module):
    """NN for reduced species with LEARNABLE penetration depth"""
    def __init__(self, width=50, depth=5, l_p_init=3.5e-3):
        super().__init__()
        layers = [nn.Linear(2, width), nn.SiLU()]
        for _ in range(depth - 2):
            layers += [nn.Linear(width, width), nn.SiLU()]
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)
        
        # LEARNABLE penetration depth
        self.log_l_p = nn.Parameter(torch.tensor(np.log(l_p_init), dtype=torch.float32))
    
    def forward(self, xt):
        x = xt[:, 0:1]
        l_p = torch.exp(self.log_l_p)
        decay = torch.exp(-x / l_p)
        C = torch.nn.functional.softplus(self.net(xt))
        return C * decay
    
    def get_l_p(self):
        return torch.exp(self.log_l_p).item()

# Initialize networks
torch.manual_seed(42)
pinn_models = []
for i, sys in enumerate(systems):
    torch.manual_seed(42 + i)
    model_O = PINN_Oxidized().to(DEVICE)
    model_R = PINN_Reduced().to(DEVICE)
    pinn_models.append({'O': model_O, 'R': model_R, 'system': sys})

print(f"Parameters per system: {sum(p.numel() for p in pinn_models[0]['O'].parameters()):,}")
print(f"  Including 2 learnable penetration depths (ℓp_O and ℓp_R)")

# =============================================================================
# SECTION 4: HELPER FUNCTIONS
# =============================================================================

def create_interpolator(t_data, y_data, device):
    t_torch = torch.tensor(t_data, dtype=torch.float32, device=device)
    y_torch = torch.tensor(y_data, dtype=torch.float32, device=device)
    
    def interpolate(t_query):
        t_clamped = t_query.view(-1).clamp(min=t_torch[0], max=t_torch[-1])
        idx = torch.bucketize(t_clamped, t_torch, right=False)
        idx = torch.clamp(idx, 1, len(t_torch)-1)
        i0, i1 = idx - 1, idx
        t0, t1 = t_torch[i0], t_torch[i1]
        y0, y1 = y_torch[i0], y_torch[i1]
        w = (t_clamped - t0) / (t1 - t0 + 1e-12)
        return (y0 + w * (y1 - y0)).view_as(t_query)
    
    return interpolate

for i, fd_res in enumerate(fd_results):
    pinn_models[i]['c_O_surf_fn'] = create_interpolator(t_grid, fd_res['c_O'][0, :], DEVICE)
    pinn_models[i]['c_R_surf_fn'] = create_interpolator(t_grid, fd_res['c_R'][0, :], DEVICE)

def E_of_t_torch(t):
    t_switch_t = torch.tensor(t_switch, dtype=t.dtype, device=t.device)
    forward = E_start - sweep_rate_V_per_s * t
    backward = E_vertex + sweep_rate_V_per_s * (t - t_switch_t)
    return torch.where(t <= t_switch_t, forward, backward)

def k_red_torch(t, E_eq):
    E_t = E_of_t_torch(t)
    eta = E_t - E_eq
    return k_0 * torch.exp((-alpha * n * F * eta) / (R * T))

def k_ox_torch(t, E_eq):
    E_t = E_of_t_torch(t)
    eta = E_t - E_eq
    return k_0 * torch.exp(((1.0 - alpha) * n * F * eta) / (R * T))

# =============================================================================
# SECTION 5: PINN PHYSICS FUNCTIONS
# =============================================================================

def compute_c_with_hard_BC(model, x, t, c_bulk, c_surf_fn, species_type):
    H = (t > 0).float()
    c_surf = c_surf_fn(t)
    c_far = torch.tensor(c_bulk, dtype=t.dtype, device=t.device)
    
    xt = torch.cat([x, t], dim=1)
    c_raw = model(xt)
    
    x0 = torch.zeros_like(t)
    xL = torch.full_like(t, L)
    c_at_0 = model(torch.cat([x0, t], dim=1))
    c_at_L = model(torch.cat([xL, t], dim=1))
    
    if species_type == 'O':
        c_norm = ((c_raw - c_at_0) / (c_at_L - c_at_0 + 1e-8)).clamp(0, 1)
        c_t = c_surf + (c_far - c_surf) * c_norm
    else:
        c_norm = ((c_raw - c_at_L) / (c_at_0 - c_at_L + 1e-8)).clamp(0, 1)
        c_t = c_far + (c_surf - c_far) * c_norm
    
    return c_far * (1 - H) + c_t * H

def compute_derivatives(model_O, model_R, x, t, c_O_bulk, c_R_bulk, c_O_surf_fn, c_R_surf_fn):
    x_req = x.detach().clone().requires_grad_(True)
    t_req = t.detach().clone().requires_grad_(True)
    
    c_O = compute_c_with_hard_BC(model_O, x_req, t_req, c_O_bulk, c_O_surf_fn, 'O')
    c_R = compute_c_with_hard_BC(model_R, x_req, t_req, c_R_bulk, c_R_surf_fn, 'R')
    
    ones = torch.ones_like(c_O)
    dc_O_dt = torch.autograd.grad(c_O, t_req, ones, create_graph=True)[0]
    dc_O_dx = torch.autograd.grad(c_O, x_req, ones, create_graph=True)[0]
    dc_R_dt = torch.autograd.grad(c_R, t_req, ones, create_graph=True)[0]
    dc_R_dx = torch.autograd.grad(c_R, x_req, ones, create_graph=True)[0]
    d2c_O_dx2 = torch.autograd.grad(dc_O_dx, x_req, ones, create_graph=True)[0]
    d2c_R_dx2 = torch.autograd.grad(dc_R_dx, x_req, ones, create_graph=True)[0]
    
    return {
        'c_O': c_O, 'dc_O_dt': dc_O_dt, 'dc_O_dx': dc_O_dx, 'd2c_O_dx2': d2c_O_dx2,
        'c_R': c_R, 'dc_R_dt': dc_R_dt, 'dc_R_dx': dc_R_dx, 'd2c_R_dx2': d2c_R_dx2
    }

def compute_losses(model_O, model_R, x_int, t_int, x_bc, t_bc,
                  c_O_bulk, c_R_bulk, E_eq, c_O_surf_fn, c_R_surf_fn):
    # Physics loss
    derivs = compute_derivatives(model_O, model_R, x_int, t_int, 
                                c_O_bulk, c_R_bulk, c_O_surf_fn, c_R_surf_fn)
    res_O = derivs['dc_O_dt'] - D * derivs['d2c_O_dx2']
    res_R = derivs['dc_R_dt'] - D * derivs['d2c_R_dx2']
    loss_physics = (res_O.pow(2).mean() + res_R.pow(2).mean()) / 2
    
    # Flux BC loss
    derivs_bc = compute_derivatives(model_O, model_R, x_bc, t_bc,
                                   c_O_bulk, c_R_bulk, c_O_surf_fn, c_R_surf_fn)
    M_to_mol_cm3 = 1e-3
    reaction_flux = (k_red_torch(t_bc, E_eq) * derivs_bc['c_O'] - 
                     k_ox_torch(t_bc, E_eq) * derivs_bc['c_R']) * M_to_mol_cm3
    
    flux_res_O = (-D * derivs_bc['dc_O_dx']) - reaction_flux
    flux_res_R = (-D * derivs_bc['dc_R_dx']) + reaction_flux
    loss_flux = (flux_res_O.pow(2).mean() + flux_res_R.pow(2).mean()) / 2
    
    # Conservation loss
    total_bulk = c_O_bulk + c_R_bulk
    loss_conservation = (derivs['c_O'] + derivs['c_R'] - total_bulk).pow(2).mean()
    
    return loss_physics, loss_flux, loss_conservation

# =============================================================================
# SECTION 6: TRAINING
# =============================================================================

print("\n" + "="*80)
print("TRAINING")
print("="*80)

N_interior = 4096
N_boundary = 1024
epochs = 9000
lr = 1e-3
w_physics = 1.0
w_flux = 10.0
w_conservation = 1.0

for model_dict in pinn_models:
    params = list(model_dict['O'].parameters()) + list(model_dict['R'].parameters())
    model_dict['optimizer'] = torch.optim.Adam(params, lr=lr)
    model_dict['scheduler'] = torch.optim.lr_scheduler.StepLR(
        model_dict['optimizer'], step_size=200, gamma=0.5
    )
    model_dict['history'] = []

def sample_points(N, device):
    t_rand = torch.rand(N, 1, device=device) * T_final
    x_rand = torch.rand(N, 1, device=device) * L
    return x_rand, t_rand

print(f"\nEpochs: {epochs}, Interior: {N_interior}, Boundary: {N_boundary}")
print("Training...")

for step in range(epochs + 1):
    for i, (model_dict, fd_res) in enumerate(zip(pinn_models, fd_results)):
        x_int, t_int = sample_points(N_interior, DEVICE)
        x_bc = torch.zeros(N_boundary, 1, device=DEVICE)
        t_bc = torch.rand(N_boundary, 1, device=DEVICE) * T_final
        
        sys = model_dict['system']
        loss_phys, loss_flux, loss_cons = compute_losses(
            model_dict['O'], model_dict['R'],
            x_int, t_int, x_bc, t_bc,
            sys['c_O'], sys['c_R'], fd_res['E_eq'],
            model_dict['c_O_surf_fn'], model_dict['c_R_surf_fn']
        )
        
        total_loss = w_physics * loss_phys + w_flux * loss_flux + w_conservation * loss_cons
        
        model_dict['optimizer'].zero_grad()
        total_loss.backward()
        model_dict['optimizer'].step()
        
        if step % 200 == 0 and step > 0:
            model_dict['scheduler'].step()
        
        if step % 500 == 0:
            model_dict['history'].append({
                'step': step,
                'total': total_loss.item(),
                'physics': loss_phys.item(),
                'flux': loss_flux.item(),
                'conservation': loss_cons.item(),
                'l_p_O': model_dict['O'].get_l_p(),
                'l_p_R': model_dict['R'].get_l_p()
            })
    
    if step % 1000 == 0:
        print(f"  Step {step:5d}")

print("✓ Training complete!")

# Print learned penetration depths
print("\nLearned Penetration Depths (ℓp):")
for i, model_dict in enumerate(pinn_models):
    sys = model_dict['system']
    print(f"  {sys['name']}:")
    print(f"    ℓp_O = {model_dict['O'].get_l_p()*1000:.3f} mm")
    print(f"    ℓp_R = {model_dict['R'].get_l_p()*1000:.3f} mm")

# =============================================================================
# SECTION 7: EVALUATION AND PLOTTING
# =============================================================================

print("\n" + "="*80)
print("GENERATING PLOTS")
print("="*80)

for model_dict in pinn_models:
    model_dict['O'].eval()
    model_dict['R'].eval()

Nt_eval = 500
Nx_eval = 120
t_eval = torch.linspace(0, T_final, Nt_eval, device=DEVICE).view(-1, 1)
x_eval = torch.linspace(0, L, Nx_eval, device=DEVICE).view(-1, 1)

def evaluate_pinn(model_dict, fd_res):
    sys = model_dict['system']
    
    with torch.no_grad():
        x0 = torch.zeros_like(t_eval)
        c_O_surf = compute_c_with_hard_BC(model_dict['O'], x0, t_eval,
                                         sys['c_O'], model_dict['c_O_surf_fn'], 'O')
        c_R_surf = compute_c_with_hard_BC(model_dict['R'], x0, t_eval,
                                         sys['c_R'], model_dict['c_R_surf_fn'], 'R')
        
        M_to_mol_cm3 = 1e-3
        k_red_vals = k_red_torch(t_eval, fd_res['E_eq'])
        k_ox_vals = k_ox_torch(t_eval, fd_res['E_eq'])
        r_net = k_red_vals * c_O_surf - k_ox_vals * c_R_surf
        i_pinn = n * F * A * M_to_mol_cm3 * r_net
        E_pinn = E_of_t_torch(t_eval)
        
        mask = (E_pinn >= -0.8) & (E_pinn < 0.1)
        i_masked = i_pinn.clone()
        i_masked[~mask] = float('nan')
        
        # if sys['ratio'] >= 1.0:
        #     idx_peak = torch.nanargmax(i_masked).item()
        # else:
        #     idx_peak = torch.nanargmin(i_masked).item()
        
         # Robust Peak Finding Logic
        if sys['ratio'] >= 1.0:  # Anodic
            # Replace NaN with -Infinity so they are ignored by max
            tmp = torch.where(torch.isnan(i_masked), torch.tensor(float('-inf'), device=i_masked.device), i_masked)
            idx_peak = torch.argmax(tmp).item()
        else:  # Cathodic
            # Replace NaN with +Infinity so they are ignored by min
            tmp = torch.where(torch.isnan(i_masked), torch.tensor(float('inf'), device=i_masked.device), i_masked)
            idx_peak = torch.argmin(tmp).item()
        
        idx_before = max(0, int(0.8 * idx_peak))
        idx_after = min(Nt_eval - 1, int(1.2 * idx_peak))
        
        profiles = {}
        for label, idx in zip(['t1', 't2', 't3'], [idx_before, idx_peak, idx_after]):
            t_profile = t_eval[idx].repeat(Nx_eval, 1)
            c_O_profile = compute_c_with_hard_BC(model_dict['O'], x_eval, t_profile,
                                                sys['c_O'], model_dict['c_O_surf_fn'], 'O')
            c_R_profile = compute_c_with_hard_BC(model_dict['R'], x_eval, t_profile,
                                                sys['c_R'], model_dict['c_R_surf_fn'], 'R')
            
            profiles[label] = {
                'c_O': c_O_profile.cpu().numpy(),
                'c_R': c_R_profile.cpu().numpy(),
                't': t_eval[idx].item(),
                'E': E_pinn[idx].item()
            }
        
        return {
            'E': E_pinn.cpu().numpy().flatten(),
            'i': i_pinn.cpu().numpy().flatten(),
            't': t_eval.cpu().numpy().flatten(),
            'k_red': k_red_vals.cpu().numpy().flatten(),
            'k_ox': k_ox_vals.cpu().numpy().flatten(),
            'profiles': profiles,
            'indices': [idx_before, idx_peak, idx_after]
        }

pinn_results = []
for model_dict, fd_res in zip(pinn_models, fd_results):
    result = evaluate_pinn(model_dict, fd_res)
    result['system'] = model_dict['system']
    pinn_results.append(result)

plt.rcParams.update({
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
    'figure.dpi': 150,
    'lines.linewidth': 2.5,
})

x_mm = x_eval.cpu().numpy().flatten() * 10
x_mm_fd = x_grid * 10
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
profile_colors = ['#d62728', '#2ca02c', '#9467bd']

# =============================================================================
# PLOT 1: VOLTAMMOGRAM COMPARISON (FD vs PINN)
# =============================================================================

print("\n  Creating voltammogram comparison...")
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for i, (ax, fd_res, pinn_res) in enumerate(zip(axes, fd_results, pinn_results)):
    sys = fd_res['system']
    
    mask_fd = (fd_res['E_t'] >= -0.8) & (fd_res['E_t'] < 0.1)
    mask_pinn = (pinn_res['E'] >= -0.8) & (pinn_res['E'] < 0.1)
    
    ax.plot(fd_res['E_t'][mask_fd], fd_res['i_t'][mask_fd]*1e6, 
            '--', linewidth=3, label='FD', color='gray', alpha=0.6)
    ax.plot(pinn_res['E'][mask_pinn], pinn_res['i'][mask_pinn]*1e6,
            linewidth=2.5, label='PINN', color=colors[i])
    
    for idx, color in zip(pinn_res['indices'], profile_colors):
        if -0.8 <= pinn_res['E'][idx] < 0.1:
            ax.axvline(pinn_res['E'][idx], color=color, linewidth=1.5, alpha=0.5)
    
    ax.set_xlabel('Potential (V)', fontweight='bold')
    ax.set_ylabel('Current (μA)', fontweight='bold')
    ax.set_title(f"{sys['name']}: c_O/c_R = {sys['ratio']}", fontweight='bold')
    ax.set_xlim(-0.8, 0.1)
    ax.axhline(0, color='k', linestyle='--', alpha=0.3, linewidth=1)
    ax.legend()
    ax.grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig('01_voltammogram_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('01_voltammogram_comparison.pdf', bbox_inches='tight')
print("    ✓ 01_voltammogram_comparison")
plt.close()

# =============================================================================
# PLOT 2: CONCENTRATION PROFILES WITH CORRECTED RATE CONSTANTS
# =============================================================================

print("\n  Creating concentration profiles...")
for i, (fd_res, pinn_res) in enumerate(zip(fd_results, pinn_results)):
    sys = fd_res['system']
    
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f"{sys['name']}: Concentration Profiles and Dynamics", 
                 fontsize=15, fontweight='bold')
    
    # Potential waveform
    ax = axes[0, 0]
    ax.plot(pinn_res['t'], pinn_res['E'], linewidth=2, color=colors[i])
    ax.axhline(E_0, linestyle='--', color='k', linewidth=1.2, label='E₀')
    ax.axhline(fd_res['E_eq'], linestyle=':', color='red', linewidth=1.5, label='E_eq')
    ax.axvline(t_switch, linestyle='--', color='gray', linewidth=1, label='Switch')
    for label, prof in pinn_res['profiles'].items():
        ax.axvline(prof['t'], linestyle=':', alpha=0.5, linewidth=1)
    ax.set_xlabel('Time (s)', fontweight='bold')
    ax.set_ylabel('Potential (V)', fontweight='bold')
    ax.set_title('Potential Waveform', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.2)
    
    # Rate constants - CORRECTED EXPONENTIAL PLOT
    ax = axes[0, 1]
    ax.plot(pinn_res['t'], pinn_res['k_red'], linewidth=2, 
                label='k_red', color='#1f77b4')
    ax.plot(pinn_res['t'], pinn_res['k_ox'], linewidth=2,
                label='k_ox', color='#ff7f0e')
    ax.axvline(t_switch, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.set_xlabel('Time (s)', fontweight='bold')
    ax.set_ylabel('Rate Constant (cm/s)', fontweight='bold')
    ax.set_title('Rate Constants (Butler-Volmer)', fontweight='bold')
    ax.set_ylim(1e-6, 1e4)
    ax.legend()
    ax.grid(True, alpha=0.2, which='both')
    
    # Reduced species
    ax = axes[1, 0]
    for j, (label, prof) in enumerate(pinn_res['profiles'].items()):
        idx_fd = np.argmin(np.abs(t_grid - prof['t']))
        ax.plot(x_mm_fd, fd_res['c_R'][:, idx_fd] / sys['c_R'],
                '--', linewidth=2, color=profile_colors[j], alpha=0.5)
        ax.plot(x_mm, prof['c_R'] / sys['c_R'],
                linewidth=2.5, color=profile_colors[j],
                label=f"t={prof['t']:.1f}s (E={prof['E']:.2f}V)")
    ax.axhline(1.0, linestyle='--', color='k', linewidth=1.2)
    ax.set_xlabel('Distance from Electrode (mm)', fontweight='bold')
    ax.set_ylabel('c_R / c_R,bulk', fontweight='bold')
    ax.set_title('Reduced Species', fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)
    
    # Oxidized species
    ax = axes[1, 1]
    for j, (label, prof) in enumerate(pinn_res['profiles'].items()):
        idx_fd = np.argmin(np.abs(t_grid - prof['t']))
        ax.plot(x_mm_fd, fd_res['c_O'][:, idx_fd] / sys['c_O'],
                '--', linewidth=2, color=profile_colors[j], alpha=0.5)
        ax.plot(x_mm, prof['c_O'] / sys['c_O'],
                linewidth=2.5, color=profile_colors[j],
                label=f"t={prof['t']:.1f}s (E={prof['E']:.2f}V)")
    ax.axhline(1.0, linestyle='--', color='k', linewidth=1.2)
    ax.set_xlabel('Distance from Electrode (mm)', fontweight='bold')
    ax.set_ylabel('c_O / c_O,bulk', fontweight='bold')
    ax.set_title('Oxidized Species', fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)
    
    plt.tight_layout()
    filename = f"02_profiles_system{i+1}"
    plt.savefig(f'{filename}.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{filename}.pdf', bbox_inches='tight')
    print(f"    ✓ {filename}")
    plt.close()

# =============================================================================
# PLOT 3: TRAINING LOSS CURVES
# =============================================================================

print("\n  Creating training loss curves...")
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for i, (ax, model_dict) in enumerate(zip(axes, pinn_models)):
    history = model_dict['history']
    if not history:
        continue
    
    steps = [h['step'] for h in history]
    total_loss = [h['total'] for h in history]
    physics_loss = [h['physics'] for h in history]
    flux_loss = [h['flux'] for h in history]
    cons_loss = [h['conservation'] for h in history]
    
    ax.semilogy(steps, total_loss, linewidth=2.5, label='Total', color=colors[i])
    ax.semilogy(steps, physics_loss, linewidth=1.5, label='Physics', alpha=0.7)
    ax.semilogy(steps, flux_loss, linewidth=1.5, label='Flux', alpha=0.7)
    ax.semilogy(steps, cons_loss, linewidth=1.5, label='Conservation', alpha=0.7)
    
    ax.set_xlabel('Training Step', fontweight='bold')
    ax.set_ylabel('Loss', fontweight='bold')
    ax.set_title(f"{model_dict['system']['name']} Training Loss", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('03_training_loss.png', dpi=300, bbox_inches='tight')
plt.savefig('03_training_loss.pdf', bbox_inches='tight')
print("    ✓ 03_training_loss")
plt.close()

# =============================================================================
# PLOT 4: SYSTEM OVERLAY COMPARISON
# =============================================================================

print("\n  Creating system overlay...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# FD overlay
for i, fd_res in enumerate(fd_results):
    sys = fd_res['system']
    mask = (fd_res['E_t'] >= -0.8) & (fd_res['E_t'] < 0.1)
    ax1.plot(fd_res['E_t'][mask], fd_res['i_t'][mask]*1e6,
            linewidth=2.5, label=f"{sys['name']} (c_O/c_R={sys['ratio']})",
            color=colors[i])

ax1.set_xlabel('Potential (V)', fontweight='bold')
ax1.set_ylabel('Current (μA)', fontweight='bold')
ax1.set_title('Finite Difference: All Systems', fontweight='bold', fontsize=14)
ax1.axhline(0, color='k', linestyle='--', alpha=0.3)
ax1.legend()
ax1.grid(True, alpha=0.2)

# PINN overlay
for i, pinn_res in enumerate(pinn_results):
    sys = pinn_res['system']
    mask = (pinn_res['E'] >= -0.8) & (pinn_res['E'] < 0.1)
    ax2.plot(pinn_res['E'][mask], pinn_res['i'][mask]*1e6,
            linewidth=2.5, label=f"{sys['name']} (c_O/c_R={sys['ratio']})",
            color=colors[i])

ax2.set_xlabel('Potential (V)', fontweight='bold')
ax2.set_ylabel('Current (μA)', fontweight='bold')
ax2.set_title('PINN: All Systems', fontweight='bold', fontsize=14)
ax2.axhline(0, color='k', linestyle='--', alpha=0.3)
ax2.legend()
ax2.grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig('04_all_systems_overlay.png', dpi=300, bbox_inches='tight')
plt.savefig('04_all_systems_overlay.pdf', bbox_inches='tight')
print("    ✓ 04_all_systems_overlay")
plt.close()

# =============================================================================
# PLOT 5: LEARNED PENETRATION DEPTHS
# =============================================================================

print("\n  Creating penetration depth evolution...")
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for i, (ax, model_dict) in enumerate(zip(axes, pinn_models)):
    history = model_dict['history']
    if not history:
        continue
    
    steps = [h['step'] for h in history]
    l_p_O = [h['l_p_O'] * 1000 for h in history]  # Convert to mm
    l_p_R = [h['l_p_R'] * 1000 for h in history]
    
    ax.plot(steps, l_p_O, linewidth=2.5, label='ℓp_O (Oxidized)', color='#1f77b4')
    ax.plot(steps, l_p_R, linewidth=2.5, label='ℓp_R (Reduced)', color='#ff7f0e')
    
    ax.set_xlabel('Training Step', fontweight='bold')
    ax.set_ylabel('Penetration Depth ℓp (mm)', fontweight='bold')
    ax.set_title(f"{model_dict['system']['name']}: Learned ℓp", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('05_penetration_depths.png', dpi=300, bbox_inches='tight')
plt.savefig('05_penetration_depths.pdf', bbox_inches='tight')
print("    ✓ 05_penetration_depths")
plt.close()

# =============================================================================
# FINAL SUMMARY
# =============================================================================

print("\n" + "="*80)
print("SIMULATION SUMMARY")
print("="*80)

for i, (fd_res, pinn_res, model_dict) in enumerate(zip(fd_results, pinn_results, pinn_models)):
    sys = fd_res['system']
    print(f"\n{sys['name']} (c_O={sys['c_O']}, c_R={sys['c_R']}, ratio={sys['ratio']}):")
    print(f"  FD Peak Current:   {np.max(np.abs(fd_res['i_t']))*1e6:.2f} μA")
    print(f"  PINN Peak Current: {np.max(np.abs(pinn_res['i']))*1e6:.2f} μA")
    error = np.abs(np.max(np.abs(pinn_res['i'])) - np.max(np.abs(fd_res['i_t']))) / np.max(np.abs(fd_res['i_t'])) * 100
    print(f"  Relative Error:    {error:.2f}%")
    print(f"  Learned ℓp_O:      {model_dict['O'].get_l_p()*1000:.3f} mm")
    print(f"  Learned ℓp_R:      {model_dict['R'].get_l_p()*1000:.3f} mm")

print("\n" + "="*80)
print("✓ ALL PLOTS GENERATED!")
print("="*80)
print("\nGenerated files:")
print("  01_voltammogram_comparison.png/pdf - FD vs PINN for all 3 systems")
print("  02_profiles_system1.png/pdf - System 1 profiles + rate constants")
print("  02_profiles_system2.png/pdf - System 2 profiles + rate constants")
print("  02_profiles_system3.png/pdf - System 3 profiles + rate constants")
print("  03_training_loss.png/pdf - Training loss curves")
print("  04_all_systems_overlay.png/pdf - Comparison overlay")
print("  05_penetration_depths.png/pdf - Evolution of learned ℓp")
print("\n" + "="*80)

CYCLIC VOLTAMMETRY SIMULATION WITH LEARNABLE PENETRATION DEPTH

Simulation Parameters:
  Scan rate: 10.0 mV/s
  Total time: 200.00 s
  Switching time: 100.00 s

FINITE DIFFERENCE SIMULATION

Stability parameter λ = 0.3928 (should be ≤ 0.5)
  ✓ Stable

  System 1...

  System 2...

  System 3...

PINN WITH LEARNABLE PENETRATION DEPTH

Device: cpu
Parameters per system: 7,852
  Including 2 learnable penetration depths (ℓp_O and ℓp_R)

TRAINING

Epochs: 9000, Interior: 4096, Boundary: 1024
Training...
  Step     0
  Step  1000
  Step  2000
  Step  3000
  Step  4000
  Step  5000
  Step  6000
  Step  7000
  Step  8000
  Step  9000
✓ Training complete!

Learned Penetration Depths (ℓp):
  System 1:
    ℓp_O = 7.603 mm
    ℓp_R = 7.635 mm
  System 2:
    ℓp_O = 5.317 mm
    ℓp_R = 5.309 mm
  System 3:
    ℓp_O = 6.248 mm
    ℓp_R = 6.238 mm

GENERATING PLOTS

  Creating voltammogram comparison...
    ✓ 01_voltammogram_comparison

  Creating concentration profiles...
    ✓ 02_profiles_system1
 