In [None]:
"""
=================================================================================
PHYSICS-INFORMED NEURAL NETWORKS FOR ELECTROCHEMICAL IMPEDANCE SPECTROSCOPY
=================================================================================
Solves the frequency-domain diffusion equation with Faradaic reactions and 
double-layer capacitance.

Key features:
- Complex-valued concentration perturbations
- Butler-Volmer kinetics at electrode
- Double-layer capacitance
- Validation set for smooth loss curves
- Publication-quality plots
=================================================================================
"""

import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

# =============================================================================
# SECTION 1: SETUP AND UTILITIES
# =============================================================================

print("="*80)
print("ELECTROCHEMICAL IMPEDANCE SPECTROSCOPY - PINN")
print("="*80)

# Reproducibility: seed both RNGs the script draws from (NumPy for the
# coordinate/frequency samplers, torch for weight init and amplitude draws).
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"\nDevice: {device}")

TWO_PI = math.tau  # identical float to 2.0 * math.pi


def to_t(x):
    """Return ``x`` as a float32 tensor on the module-level ``device``.

    Accepts anything ``torch.as_tensor`` accepts (Python scalars, lists,
    NumPy arrays, tensors).
    """
    return torch.as_tensor(x, device=device, dtype=torch.float32)


def grad(outputs, inputs):
    """Differentiate ``outputs`` w.r.t. ``inputs``, keeping the graph.

    ``grad_outputs`` of ones yields the gradient of ``outputs.sum()``; this
    equals the per-row derivative when each output row depends only on its
    own input row (assumed for the pointwise network calls in this file).
    ``create_graph=True`` allows the result to be differentiated again
    (second derivatives in the PDE residual).
    """
    ones = torch.ones_like(outputs)
    derivatives = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=ones,
        retain_graph=True,
        create_graph=True,
        only_inputs=True,
    )
    return derivatives[0]

# Complex number utilities (using real-imaginary pairs)
class Complex:
    """Minimal complex number stored as separate real and imaginary parts.

    Parts may be Python floats or torch tensors; arithmetic is component-wise,
    so tensor parts broadcast as usual. ``abs`` calls ``torch.sqrt`` and
    therefore needs tensor parts.
    """

    def __init__(self, re, im):
        self.re, self.im = re, im

    def __add__(self, other):
        # ``other`` is expected to be a Complex.
        return Complex(self.re + other.re, self.im + other.im)

    def __sub__(self, other):
        return Complex(self.re - other.re, self.im - other.im)

    def __mul__(self, other):
        if not isinstance(other, Complex):
            # Real scalar / tensor: scale both parts.
            return Complex(self.re*other, self.im*other)
        ar, ai, br, bi = self.re, self.im, other.re, other.im
        return Complex(ar*br - ai*bi, ar*bi + ai*br)

    def __rmul__(self, other):
        # Multiplication is commutative, so ``other * self`` == ``self * other``.
        return self * other

    def __truediv__(self, other):
        if not isinstance(other, Complex):
            return Complex(self.re/other, self.im/other)
        ar, ai, br, bi = self.re, self.im, other.re, other.im
        norm_sq = br*br + bi*bi  # |other|^2; no guard against zero
        return Complex((ar*br + ai*bi)/norm_sq, (ai*br - ar*bi)/norm_sq)

    def conj(self):
        """Return the complex conjugate."""
        return Complex(self.re, -self.im)

    def abs(self):
        """Return the modulus sqrt(re^2 + im^2) (requires tensor parts)."""
        return torch.sqrt(self.re*self.re + self.im*self.im)

# =============================================================================
# SECTION 2: PHYSICAL PARAMETERS
# =============================================================================

print("\nPhysical Parameters:")

# Physical constants (SI)
F = to_t(96485.33212)    # Faraday constant [C/mol]
R_g = to_t(8.314462618)  # Molar gas constant [J/(mol·K)]

# System parameters (SI units)
params = dict(
    T=298.15,       # temperature [K]
    D=1e-10,        # diffusion coefficient [m²/s]
    d=6e-4,         # electrode thickness [m]
    c0=1.0,         # bulk concentration [mol/m³] (= 1 mM)
    alpha=0.5,      # transfer coefficient [-]
    k=1e-5,         # rate constant [m/s]
    C_dl=0.2,       # double-layer capacitance [F/m²]
    deltaV=5e-3,    # voltage perturbation amplitude [V]
)

# Tensor copies of every parameter, in the same order as ``params``.
T, D, d, c0, alpha_tf, k, C_dl, deltaV = (
    to_t(params[name])
    for name in ('T', 'D', 'd', 'c0', 'alpha', 'k', 'C_dl', 'deltaV')
)

# Frequency range covered by the training samplers
f_min, f_max = 1e-2, 1e2
print(f"  Temperature: {params['T']} K")
print(f"  Diffusion coef: {params['D']} m²/s")
print(f"  Thickness: {params['d']*1e3:.2f} mm")
print(f"  Frequency range: {f_min} - {f_max} Hz")

# D, k and C_dl are stored as logarithms and exponentiated in
# current_params(). They are plain tensors (no requires_grad) and are not
# handed to the optimizer, so they stay fixed during training.
log_D, log_k, log_Cdl = (torch.log(p.clone()) for p in (D, k, C_dl))


def current_params():
    """Return the physical parameters as a tuple of tensors.

    Order: (D, d, c0, alpha, k, C_dl, T, deltaV). D, k and C_dl are
    recovered from their stored log-values; the others are returned as-is.
    """
    D_now = torch.exp(log_D)
    k_now = torch.exp(log_k)
    Cdl_now = torch.exp(log_Cdl)
    return (D_now, d, c0, alpha_tf, k_now, Cdl_now, T, deltaV)

# =============================================================================
# SECTION 3: NEURAL NETWORK ARCHITECTURE
# =============================================================================

print("\n" + "="*80)
print("NEURAL NETWORK ARCHITECTURE")
print("="*80)

class EIS_PINN(nn.Module):
    """
    PINN for EIS with a physics-informed ansatz.

    Inputs: (x, log10f, dV)
    Outputs: (c_real, c_imag) - complex concentration perturbation [mol/m³]

    The complex output is

        c(x) = ((d - x) / d) * A(x, f, dV) * exp(-(1 + i) x / delta),
        delta = sqrt(2 D / omega),

    where A = A_r + i*A_i is produced by the MLP. exp(-(1+i)x/delta) solves
    i*omega*c = D*c'' on a semi-infinite domain; the linear prefactor forces
    c = 0 at the far boundary x = d.
    """

    def __init__(self, layers=(3, 128, 128, 128, 128, 2)):
        """Build a tanh MLP with layer widths ``layers`` (input 3, output 2)."""
        super().__init__()
        layers = list(layers)  # accept any sequence; tuple default avoids a shared mutable list
        net = []
        for n_in, n_out in zip(layers[:-2], layers[1:-1]):
            net.append(nn.Linear(n_in, n_out))
            net.append(nn.Tanh())
        net.append(nn.Linear(layers[-2], layers[-1]))
        self.net = nn.Sequential(*net)

    def forward(self, x, log10f, dV):
        """
        Evaluate the complex concentration perturbation.

        Args:
            x: spatial coordinate [m], expected shape (N, 1)
            log10f: log10 of frequency [Hz], expected shape (N, 1)
            dV: voltage amplitude [V], expected shape (N, 1)

        Returns:
            c_r, c_i: real and imaginary parts of concentration [mol/m³],
            each of shape (N, 1)
        """
        Dv = current_params()[0]

        # Frequency and diffusion length delta = sqrt(2D/omega)
        f = 10.0 ** log10f
        omega = to_t(TWO_PI) * f
        delta = torch.sqrt(2.0 * Dv / (omega + 1e-12))

        # Network input rescaled to O(1) features: raw x (~1e-4 m) and
        # dV (~1e-3 V) are orders of magnitude smaller than log10f and would
        # be nearly invisible to a default-initialised tanh network.
        xin = torch.stack([(x / d).squeeze(-1),
                           log10f.squeeze(-1),
                           (dV / deltaV).squeeze(-1)], dim=-1)
        out = self.net(xin)

        A_r = out[..., 0:1]  # real part of the amplitude A
        A_i = out[..., 1:2]  # imaginary part of the amplitude A

        # A * exp(-(1+i) x/delta) expanded into real/imaginary parts
        x_norm = x / (delta + 1e-12)
        exp_decay = torch.exp(-x_norm)
        cos_term = torch.cos(x_norm)
        sin_term = torch.sin(x_norm)
        c_r = exp_decay * (A_r * cos_term + A_i * sin_term)
        c_i = exp_decay * (A_i * cos_term - A_r * sin_term)

        # Enforce c -> 0 at x = d (far boundary); the factor is 1 at x = 0.
        factor = (d - x) / d
        return factor * c_r, factor * c_i

model = EIS_PINN().to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

# =============================================================================
# SECTION 4: PHYSICS LOSS FUNCTIONS
# =============================================================================

def physics_loss(model, x_c, log10f_c, dV_c, x_b0, log10f_b, dV_b,
                 w_pde=1.0, w_bc=1.0, w_reg=1e-4):
    """
    Compute the physics-informed loss.

    Components:
    1. PDE residual of the frequency-domain diffusion equation
       i*omega*c = D * d²c/dx²  (time dependence e^{+i omega t}).
    2. Flux balance at the electrode (x = 0): -F*D*dc/dx = i_faradaic, with
       linearised Butler-Volmer kinetics i_faradaic = gk*(dV - phi*c).
    3. A small penalty on the total (Faradaic + double-layer) current.

    Args:
        model: network mapping (x, log10f, dV) -> (c_r, c_i)
        x_c, log10f_c, dV_c: interior collocation points
        x_b0, log10f_b, dV_b: electrode points (x = 0)
        w_pde, w_bc: weights of the PDE and boundary terms
        w_reg: weight of the total-current penalty inside the boundary term
            (default 1e-4, the previously hard-coded value)

    Returns:
        total_loss (scalar tensor with graph), dict of detached components

    Side effect:
        Sets requires_grad=True on x_c and x_b0 in place, because spatial
        derivatives are taken with respect to them.
    """
    Dv, _, c0v, _, kv, Cdlv, Tv, _ = current_params()

    # Linearised Butler-Volmer: i_ct = gk*(V - phi*c); note gk*phi = F*k.
    gk = (F*F * kv * c0v) / (R_g * Tv)   # kinetic conductance
    phi = (R_g * Tv) / (F * c0v)         # potential per unit concentration

    # =========================================================================
    # 1. PDE LOSS (Interior)
    # =========================================================================
    # requires_grad must be enabled BEFORE the forward pass, otherwise the
    # graph from x_c to c is never recorded and grad() fails.
    x_c.requires_grad_(True)
    c_r, c_i = model(x_c, log10f_c, dV_c)
    c_r_xx = grad(grad(c_r, x_c), x_c)
    c_i_xx = grad(grad(c_i, x_c), x_c)

    # Frequency-domain PDE: i*omega*c = D*c_xx, split into real/imag parts.
    omega_c = to_t(TWO_PI) * 10.0 ** log10f_c
    pde_r = -omega_c * c_i - Dv * c_r_xx
    pde_i = omega_c * c_r - Dv * c_i_xx
    loss_pde = pde_r.pow(2).mean() + pde_i.pow(2).mean()

    # =========================================================================
    # 2. BOUNDARY CONDITION LOSS (x=0, electrode)
    # =========================================================================
    x_b0.requires_grad_(True)
    c_r0, c_i0 = model(x_b0, log10f_b, dV_b)
    c_r0_x = grad(c_r0, x_b0)
    c_i0_x = grad(c_i0, x_b0)

    # Use the per-sample amplitude dV_b (the value the network was given),
    # not the nominal deltaV, so the solution can scale with its dV input.
    i_ct_r = gk * (dV_b - phi * c_r0)
    i_ct_i = -gk * (phi * c_i0)

    # Flux balance -F*D*dc/dx = i_ct (real and imaginary parts)
    bc_flux_r = (-F * Dv * c_r0_x - i_ct_r).pow(2).mean()
    bc_flux_i = (-F * Dv * c_i0_x - i_ct_i).pow(2).mean()

    # =========================================================================
    # 3. CAPACITIVE CURRENT (Double layer): i_dl = i*omega*C_dl*dV
    # =========================================================================
    omega_b = to_t(TWO_PI) * 10.0 ** log10f_b
    i_tot_r = i_ct_r                              # i_dl has no real part
    i_tot_i = i_ct_i + omega_b * Cdlv * dV_b

    # NOTE(review): penalising |i_total| has no physical basis and biases the
    # solution; w_reg keeps the original strength by default.
    loss_bc = (bc_flux_r + bc_flux_i
               + w_reg * (i_tot_r.pow(2).mean() + i_tot_i.pow(2).mean()))

    total_loss = w_pde * loss_pde + w_bc * loss_bc

    return total_loss, {
        'loss_pde': loss_pde.detach(),
        'loss_bc': loss_bc.detach(),
        'loss_flux_r': bc_flux_r.detach(),
        'loss_flux_i': bc_flux_i.detach()
    }

# =============================================================================
# SECTION 5: SAMPLING STRATEGIES
# =============================================================================

def sample_collocation(n_x=128, n_f=64, bias_power=2.0):
    """
    Sample interior collocation points on an (x, log10 f) grid.

    x is drawn as d * u**bias_power with u ~ U(0, 1). For bias_power > 1 the
    density is highest at the electrode (x = 0), where the perturbation
    decays over the diffusion length sqrt(2D/omega), typically much smaller
    than d. (u**0.5 would do the opposite: its density 2x/d² piles points up
    at x = d.) bias_power = 1.0 gives uniform sampling.

    Frequencies are log-uniform in [f_min, f_max].

    Args:
        n_x: number of spatial samples
        n_f: number of frequency samples
        bias_power: exponent applied to u

    Returns:
        X, LOG10F: tensors of shape (n_x*n_f, 1) with every (x, log10 f)
        combination, x varying fastest.
    """
    u = np.random.rand(n_x, 1)
    x = to_t(u ** bias_power) * d

    log_lo, log_hi = math.log10(f_min), math.log10(f_max)
    log10f = to_t(np.random.rand(n_f, 1) * (log_hi - log_lo) + log_lo)

    # Cartesian product: tile x once per frequency, repeat each frequency n_x times.
    X = x.repeat(n_f, 1)
    LOG10F = log10f.repeat_interleave(n_x, dim=0)
    return X, LOG10F

def sample_boundary(n_f=128):
    """Sample ``n_f`` electrode points (x = 0) with log-uniform frequencies.

    Returns:
        x0: zeros of shape (n_f, 1)
        log10f: shape (n_f, 1), uniform in [log10 f_min, log10 f_max)
    """
    lo = math.log10(f_min)
    span = math.log10(f_max) - lo
    log10f = to_t(np.random.rand(n_f, 1) * span + lo)
    return to_t(np.zeros((n_f, 1))), log10f


# =============================================================================
# SECTION 6: TRAINING (ADJUSTED FOR PINN VALIDATION)
# =============================================================================

print("\n" + "="*80)
print("TRAINING")
print("="*80)

# Training configuration
max_steps, log_every, lr = 10000, 500, 1e-3

# Loss weights (PDE residual vs. electrode boundary terms)
w_pde, w_bc = 10.0, 5.0

# Only the network weights are optimised; the physical parameters stay fixed.
trainable_params = list(model.parameters())
optimizer = optim.Adam(trainable_params, lr=lr)
# Halve the learning rate every 2000 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.5)


def _sample_amplitudes(n):
    """Draw n voltage amplitudes uniformly from [0.1, 2.0) * deltaV, shape (n, 1)."""
    return (torch.rand(n, 1, device=device) * 1.9 + 0.1) * deltaV


# Create FIXED validation set (built once, reused at every log step)
print(f"\nCreating fixed validation set...")
# NOTE: this reseeds torch only (amplitude draws and all later torch RNG use);
# the x/frequency samples come from NumPy's RNG, which is not reseeded here.
torch.manual_seed(123)
n_val_x, n_val_f = 50, 20
Xc_val, Fc_val = sample_collocation(n_x=n_val_x, n_f=n_val_f)
x0_val, Fb_val = sample_boundary(n_f=40)
dV_c_val = _sample_amplitudes(len(Xc_val))
dV_b_val = _sample_amplitudes(len(Fb_val))

# History tracking: one entry per logged step
history = {name: [] for name in (
    'step', 'train_loss', 'val_loss', 'val_pde', 'val_bc',
    'val_flux_r', 'val_flux_i',
)}

print(f"\nTraining Configuration:")
print(f"  Max steps: {max_steps}")
print(f"  Learning rate: {lr}")
print(f"  Loss weights: PDE={w_pde}, BC={w_bc}")
print(f"\nTraining...")

# Main training loop. Each step draws fresh random points (NumPy RNG for x and
# log10 f, torch RNG for the voltage amplitudes), then takes one Adam step and
# one StepLR step. Every `log_every` steps the loss on the fixed validation set
# is appended to `history` and printed.
for step in range(1, max_steps+1):
    optimizer.zero_grad()
    
    # Sample training points
    Xc, Fc = sample_collocation(n_x=100, n_f=40)
    x0, Fb = sample_boundary(n_f=80)
    # One amplitude per point, uniform in [0.1, 2.0) * deltaV
    dV_c = (torch.rand(len(Xc), 1, device=device) * 1.9 + 0.1) * deltaV
    dV_b = (torch.rand(len(Fb), 1, device=device) * 1.9 + 0.1) * deltaV
    
    # Compute loss and update.
    # requires_grad_ on the x tensors lets physics_loss differentiate w.r.t. x.
    loss, parts = physics_loss(model, Xc.requires_grad_(True), Fc, dV_c, 
                              x0.requires_grad_(True), Fb, dV_b, 
                              w_pde=w_pde, w_bc=w_bc)
    loss.backward()
    optimizer.step()
    scheduler.step()
    
    # Validation and logging
    if step % log_every == 0:
        # EIS_PINN has no dropout/batch-norm layers, so eval()/train() only
        # toggle the module's mode flag here.
        model.eval()
        # IMPORTANT: Do NOT use with torch.no_grad() here. 
        # The physics_loss function needs to compute gradients for the PDE.
        val_loss, val_parts = physics_loss(
            model, 
            Xc_val.requires_grad_(True), Fc_val, dV_c_val,
            x0_val.requires_grad_(True), Fb_val, dV_b_val,
            w_pde=w_pde, w_bc=w_bc
        )
        model.train()
        
        # `loss` is the training loss of the current step only (one batch).
        history['step'].append(step)
        history['train_loss'].append(loss.item())
        history['val_loss'].append(val_loss.item())
        history['val_pde'].append(val_parts['loss_pde'].item())
        history['val_bc'].append(val_parts['loss_bc'].item())
        history['val_flux_r'].append(val_parts['loss_flux_r'].item())
        history['val_flux_i'].append(val_parts['loss_flux_i'].item())
        
        print(f"  Step {step:5d} | Train: {loss:.2e} | Val: {val_loss:.2e} | "
              f"PDE: {val_parts['loss_pde']:.2e} | BC: {val_parts['loss_bc']:.2e}")

print("\n✓ Training complete!")

# =============================================================================
# SECTION 7: IMPEDANCE PREDICTION
# =============================================================================

print("\n" + "="*80)
print("IMPEDANCE CALCULATION")
print("="*80)

model.eval()

def predict_impedance(f_eval=None):
    """
    Predict the impedance Z(omega) = dV / i_total from the trained PINN.

    The electrode concentration c(0) from the network enters the linearised
    Butler-Volmer current i_ct = gk*(dV - phi*c(0)); the double-layer current
    i*omega*C_dl*dV is added analytically.

    Args:
        f_eval: frequencies [Hz] (array-like or scalar). Defaults to 200
            log-spaced points in [1e-2, 1e4] Hz.
            NOTE(review): training samples only [f_min, f_max], so the
            default range extrapolates above f_max.

    Returns:
        f_eval, Z_real, Z_imag: frequencies and impedance parts [Ω·m²] as
        NumPy arrays (0-d arrays when a single frequency is given)
    """
    if f_eval is None:
        f_eval = np.logspace(-2, 4, 200)
    f_eval = np.asarray(f_eval)

    _, _, c0v, _, kv, Cdlv, Tv, dV = current_params()

    f_t = to_t(f_eval.reshape(-1, 1))
    log10f = torch.log10(f_t)
    x0 = torch.zeros_like(f_t)
    dV_t = dV * torch.ones_like(log10f)

    # Only c(0) is needed for Z; no spatial gradient, so skip autograd.
    with torch.no_grad():
        c_r0, c_i0 = model(x0, log10f, dV_t)

    phi = (R_g * Tv) / (F * c0v)
    gk = (F*F * kv * c0v) / (R_g * Tv)

    # Faradaic (linearised Butler-Volmer) current
    i_ct_r = gk * (dV - phi * c_r0)
    i_ct_i = -gk * (phi * c_i0)

    # Double-layer current i*omega*C_dl*dV is purely imaginary
    omega = to_t(TWO_PI) * f_t
    i_tot_r = i_ct_r
    i_tot_i = i_ct_i + omega * Cdlv * dV

    # Z = dV / (i_r + i*i_i) = dV*(i_r - i*i_i)/|i|^2; 1e-18 guards against /0
    den = i_tot_r**2 + i_tot_i**2 + 1e-18
    Z_r = (dV * i_tot_r) / den
    Z_i = (-dV * i_tot_i) / den

    return (f_eval,
            Z_r.detach().cpu().numpy().squeeze(),
            Z_i.detach().cpu().numpy().squeeze())

f_eval, Zr, Zi = predict_impedance()
print(f"\nImpedance calculated at {len(f_eval)} frequencies")
print(f"  |Z| range: [{np.sqrt(Zr**2 + Zi**2).min():.2e}, {np.sqrt(Zr**2 + Zi**2).max():.2e}] Ω·m²")

# =============================================================================
# SECTION 8: PUBLICATION-QUALITY PLOTS
# =============================================================================

print("\n" + "="*80)
print("GENERATING PLOTS")
print("="*80)

plt.rcParams.update({
    'font.size': 11,
    'font.family': 'sans-serif',
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
    'figure.dpi': 150,
    'lines.linewidth': 2.5,
})


def _label_history_axis(ax, ylabel, title, legend=False):
    """Apply the shared step-axis labels, optional legend and grid of a history panel."""
    ax.set_xlabel('Step', fontweight='bold')
    ax.set_ylabel(ylabel, fontweight='bold')
    ax.set_title(title, fontweight='bold')
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3)


# PLOT 1: Training History
print("\n  Creating training history...")
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Training History', fontsize=15, fontweight='bold')

steps = np.array(history['step'])
ax_total, ax_pde, ax_bc, ax_flux = axes.flat  # row-major panel order

ax_total.semilogy(steps, history['train_loss'], 'o-', markersize=3,
                  label='Train', alpha=0.7, color='#1f77b4')
ax_total.semilogy(steps, history['val_loss'], 's-', markersize=4,
                  label='Validation', color='#ff7f0e')
_label_history_axis(ax_total, 'Total Loss', 'Total Loss', legend=True)

ax_pde.semilogy(steps, history['val_pde'], color='#2ca02c')
_label_history_axis(ax_pde, 'PDE Loss', 'PDE Residual')

ax_bc.semilogy(steps, history['val_bc'], color='#d62728')
_label_history_axis(ax_bc, 'BC Loss', 'Boundary Condition')

ax_flux.semilogy(steps, history['val_flux_r'], label='Real', color='#9467bd')
ax_flux.semilogy(steps, history['val_flux_i'], label='Imag', color='#8c564b')
_label_history_axis(ax_flux, 'Flux Loss', 'Flux Components', legend=True)

fig.tight_layout()
fig.savefig("01_training_history.png", dpi=300, bbox_inches='tight')
fig.savefig("01_training_history.pdf", bbox_inches='tight')
print("    ✓ 01_training_history")
plt.close(fig)

# PLOT 2: Nyquist Plot
print("\n  Creating Nyquist plot...")
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
# EIS convention: -Im(Z) on the vertical axis, so capacitive responses
# (Im(Z) < 0) appear above the real axis.
ax.plot(Zr, -Zi, 'o-', ms=4, linewidth=2, color='#1f77b4')
ax.set_xlabel('Re(Z) [Ω·m²]', fontsize=13, fontweight='bold')
ax.set_ylabel('-Im(Z) [Ω·m²]', fontsize=13, fontweight='bold')
ax.set_title('Nyquist Plot', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')  # equal scaling keeps arcs circular

fig.tight_layout()
fig.savefig("02_nyquist.png", dpi=300, bbox_inches='tight')
fig.savefig("02_nyquist.pdf", bbox_inches='tight')
print("    ✓ 02_nyquist")
plt.close(fig)

# PLOT 3: Bode Plots
print("\n  Creating Bode plots...")
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
mag_ax, phase_ax = axes

# Magnitude and phase (degrees) of Z; Z_mag is reused in the summary.
Z_mag = np.sqrt(Zr**2 + Zi**2)
Z_phase = np.degrees(np.arctan2(Zi, Zr))

mag_ax.loglog(f_eval, Z_mag, linewidth=2.5, color='#1f77b4')
mag_ax.set_xlabel('Frequency [Hz]', fontsize=12, fontweight='bold')
mag_ax.set_ylabel('|Z| [Ω·m²]', fontsize=12, fontweight='bold')
mag_ax.set_title('Bode Magnitude', fontsize=13, fontweight='bold')
mag_ax.grid(True, alpha=0.3, which='both')

phase_ax.semilogx(f_eval, Z_phase, linewidth=2.5, color='#ff7f0e')
phase_ax.set_xlabel('Frequency [Hz]', fontsize=12, fontweight='bold')
phase_ax.set_ylabel('Phase [°]', fontsize=12, fontweight='bold')
phase_ax.set_title('Bode Phase', fontsize=13, fontweight='bold')
phase_ax.grid(True, alpha=0.3, which='both')

fig.tight_layout()
fig.savefig("03_bode.png", dpi=300, bbox_inches='tight')
fig.savefig("03_bode.pdf", bbox_inches='tight')
print("    ✓ 03_bode")
plt.close(fig)

# PLOT 4: Time Domain Signals
print("\n  Creating time domain signals...")

def time_domain_signals(f_hz=1e-2, n_cycles=3, points_per_cycle=200):
    """
    Reconstruct the steady-state V(t) and i(t) at a single frequency.

    V(t) = dV*cos(omega t). Since I = V/Z, the current is
    i(t) = (dV/|Z|)*cos(omega t - arg Z): its phase is the NEGATIVE of the
    impedance phase, so a capacitive Z (arg Z < 0) gives a leading current.

    Args:
        f_hz: frequency [Hz]
        n_cycles: number of periods to generate
        points_per_cycle: samples per period

    Returns:
        t [s], V_t [V], i_t [A/m²] as 1-D NumPy arrays
    """
    _, Zr_val, Zi_val = predict_impedance(f_eval=np.array([f_hz]))
    Zr_val, Zi_val = float(Zr_val), float(Zi_val)

    omega = 2*np.pi*f_hz
    period = 1.0 / f_hz
    t = np.linspace(0.0, n_cycles*period, int(n_cycles*points_per_cycle))

    # Voltage (input): amplitude is the nominal deltaV
    dV_val = float(current_params()[-1].item())
    V_t = dV_val * np.cos(omega * t)

    # Current (response): I = V/Z  =>  |I| = dV/|Z|, arg(I) = -arg(Z)
    Z_abs = np.hypot(Zr_val, Zi_val)
    Z_phase = np.arctan2(Zi_val, Zr_val)
    I_amp = dV_val / (Z_abs + 1e-18)
    i_t = I_amp * np.cos(omega * t - Z_phase)

    return t, V_t, i_t

f_td = 1e-2  # frequency shown in the time-domain figure [Hz]
t, V_t, i_t = time_domain_signals(f_hz=f_td, n_cycles=3)

fig, (ax_v, ax_i) = plt.subplots(1, 2, figsize=(14, 5))

# Left: applied voltage in mV
ax_v.plot(t, V_t*1e3, linewidth=2.5, color='#1f77b4')
ax_v.set_xlabel('Time [s]', fontsize=12, fontweight='bold')
ax_v.set_ylabel('Voltage [mV]', fontsize=12, fontweight='bold')
ax_v.set_title(f'Applied Voltage (f={f_td} Hz)', fontsize=13, fontweight='bold')
ax_v.grid(True, alpha=0.3)

# Right: current density response in mA/m²
ax_i.plot(t, i_t*1e3, linewidth=2.5, color='#d62728')
ax_i.set_xlabel('Time [s]', fontsize=12, fontweight='bold')
ax_i.set_ylabel('Current Density [mA/m²]', fontsize=12, fontweight='bold')
ax_i.set_title(f'Current Response (f={f_td} Hz)', fontsize=13, fontweight='bold')
ax_i.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig("04_time_domain.png", dpi=300, bbox_inches='tight')
fig.savefig("04_time_domain.pdf", bbox_inches='tight')
print("    ✓ 04_time_domain")
plt.close(fig)


print("\n  Creating Real vs Imaginary comparison plots...")

def concentration_profile_both(f_hz=1e-2, n_x=300):
    """
    Concentration-perturbation profiles of both redox species at one frequency.

    Species A (Ox) is evaluated by the PINN on n_x evenly spaced points in
    [0, d]. Species B (Red) is taken as the mirror image c_B = -c_A
    (presumably assuming equal diffusivities and 1:1 stoichiometry; confirm).

    Returns:
        x [m], Re(c_A), Im(c_A), Re(c_B), Im(c_B) as 1-D NumPy arrays
    """
    _, dv, _, _, _, _, _, dV = current_params()
    x = np.linspace(0.0, float(dv.item()), n_x).reshape(-1, 1)

    x_t = to_t(x)
    log10f = torch.log10(to_t(np.full_like(x, float(f_hz))))

    with torch.no_grad():
        dV_t = to_t(dV * torch.ones_like(x_t))
        ca_r, ca_i = model(x_t, log10f, dV_t)
        cb_r, cb_i = -ca_r, -ca_i

    def _to_numpy(tensor):
        return tensor.cpu().numpy().squeeze()

    return (x.squeeze(),
            _to_numpy(ca_r), _to_numpy(ca_i),
            _to_numpy(cb_r), _to_numpy(cb_i))

x, ca_r, ca_i, cb_r, cb_i = concentration_profile_both(f_hz=1e-2)
x_mm = x * 1e3

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per component: (axis, species A, species B, y-label, title).
# Concentrations are plotted in mmol/m³ (factor 1e3).
panels = (
    (axes[0], ca_r, cb_r, r'Re(C) [mmol/m³]', 'Real Part Comparison'),
    (axes[1], ca_i, cb_i, r'Im(C) [mmol/m³]', 'Imaginary Part Comparison'),
)
for ax, c_a, c_b, ylabel, title in panels:
    ax.plot(x_mm, c_a*1e3, linewidth=2.5, color='#1f77b4', label='Species A (Ox)')
    ax.plot(x_mm, c_b*1e3, linewidth=2.5, color='#d62728', linestyle='--', label='Species B (Red)')
    ax.set_xlabel('Distance from Electrode [mm]', fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))

fig.tight_layout()
fig.savefig("05_concentration_comparison.png", dpi=300, bbox_inches='tight')
fig.savefig("05_concentration_comparison.pdf", bbox_inches='tight')
print("    ✓ 05_concentration_comparison")
plt.close(fig)

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"\nTraining:")
# history is only filled every `log_every` steps, so it can be empty.
if history['train_loss']:
    print(f"  Final train loss: {history['train_loss'][-1]:.2e}")
    print(f"  Final val loss:   {history['val_loss'][-1]:.2e}")
else:
    print("  (no validation checkpoints were logged)")
print(f"\nImpedance:")
print(f"  Frequency range: {f_eval.min():.2e} - {f_eval.max():.2e} Hz")
print(f"  |Z| range: {Z_mag.min():.2e} - {Z_mag.max():.2e} Ω·m²")
print(f"\nGenerated plots:")
print("  01_training_history.png/pdf")
print("  02_nyquist.png/pdf")
print("  03_bode.png/pdf")
print("  04_time_domain.png/pdf")
# Must match the file names actually written by the concentration figure.
print("  05_concentration_comparison.png/pdf")
print("\n" + "="*80)

ELECTROCHEMICAL IMPEDANCE SPECTROSCOPY - PINN

Device: cpu

Physical Parameters:
  Temperature: 298.15 K
  Diffusion coef: 1e-10 m²/s
  Thickness: 0.60 mm
  Frequency range: 0.01 - 100.0 Hz

NEURAL NETWORK ARCHITECTURE

Total parameters: 50,306

TRAINING

Creating fixed validation set...

Training Configuration:
  Max steps: 10000
  Learning rate: 0.001
  Loss weights: PDE=10.0, BC=5.0

Training...
  Step   500 | Train: 6.33e-04 | Val: 7.23e-04 | PDE: 2.07e-10 | BC: 1.45e-04
  Step  1000 | Train: 1.79e-04 | Val: 1.69e-04 | PDE: 2.25e-10 | BC: 3.38e-05
  Step  1500 | Train: 8.26e-05 | Val: 3.37e-04 | PDE: 2.35e-10 | BC: 6.74e-05
  Step  2000 | Train: 9.62e-05 | Val: 1.77e-04 | PDE: 2.33e-10 | BC: 3.54e-05
  Step  2500 | Train: 6.36e-05 | Val: 5.92e-05 | PDE: 2.36e-10 | BC: 1.18e-05
  Step  3000 | Train: 7.08e-05 | Val: 4.94e-05 | PDE: 2.36e-10 | BC: 9.88e-06
  Step  3500 | Train: 9.68e-05 | Val: 4.75e-05 | PDE: 2.36e-10 | BC: 9.50e-06
  Step  4000 | Train: 2.70e-04 | Val: 8.46e-05 | PDE