In [4]:
"""
=================================================================================
CYCLIC VOLTAMMETRY: ADAPTIVE LOSS WEIGHTING
=================================================================================
PROBLEM DIAGNOSIS:
1. Loss scale mismatch (flux ~10^0, physics ~10^-5, conservation ~10^-6)
2. Fixed weights don't adapt to changing dynamics
3. Gradient magnitude imbalance causes oscillations

SOLUTION: GradNorm + improved scheduler + gradient clipping
=================================================================================
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F_nn

# Physical constants
F, R, T = 96485.0, 8.314, 298.0
E_start, E_vertex, E_end, E_0 = 0.2, -0.8, 0.2, -0.4
nu, D, n, alpha, k_0, A, L = 10.0, 1e-6, 1, 0.5, 1e-1, 1.0, 0.1
npts_x, npts_t = 100, 500
systems = [
    {'name': 'System 1', 'c_O': 1.0, 'c_R': 1e-2, 'ratio': 100},
    {'name': 'System 2', 'c_O': 1.0, 'c_R': 1.0, 'ratio': 1},
    {'name': 'System 3', 'c_O': 1e-2, 'c_R': 1.0, 'ratio': 0.01},
]

sweep_rate_V_per_s = nu / 1000.0
t_forward = abs(E_vertex - E_start) / sweep_rate_V_per_s
t_return = abs(E_end - E_vertex) / sweep_rate_V_per_s
T_final, t_switch = t_forward + t_return, t_forward

print("="*80)
print("CV SIMULATION: ADAPTIVE LOSS WEIGHTING")
print("="*80)

# Finite difference
del_x, del_t = L / (npts_x - 1), T_final / (npts_t - 1)
x_grid = np.linspace(0, L, npts_x)
t_grid = np.linspace(0, T_final, npts_t)
lambda_fd = D * del_t / del_x**2

def E_of_t_np(tt):
    return E_start - sweep_rate_V_per_s * tt if tt <= t_switch else E_vertex + sweep_rate_V_per_s * (tt - t_switch)

def run_finite_difference(c_O_bulk, c_R_bulk):
    c_R = np.full((npts_x, npts_t), c_R_bulk, dtype=float)
    c_O = np.full((npts_x, npts_t), c_O_bulk, dtype=float)
    E_t, i_t = np.zeros(npts_t), np.zeros(npts_t)
    E_eq = E_0 + (R * T / (n * F)) * np.log(c_O_bulk / c_R_bulk)
    M_to_mol_cm3 = 1e-3
    
    for j in range(npts_t):
        E_curr = E_of_t_np(t_grid[j])
        E_t[j], eta = E_curr, E_curr - E_eq
        k_red = k_0 * np.exp((-alpha * n * F * eta) / (R * T))
        k_ox = k_0 * np.exp(((1.0 - alpha) * n * F * eta) / (R * T))
        
        if j > 0:
            c_R_prev, c_O_prev = c_R[:, j-1], c_O[:, j-1]
            c_R[1:-1, j] = c_R_prev[1:-1] + lambda_fd * (c_R_prev[2:] - 2*c_R_prev[1:-1] + c_R_prev[:-2])
            c_O[1:-1, j] = c_O_prev[1:-1] + lambda_fd * (c_O_prev[2:] - 2*c_O_prev[1:-1] + c_O_prev[:-2])
            c_R[-1, j], c_O[-1, j] = c_R_bulk, c_O_bulk
            c_R_pred = c_R_prev[0] + lambda_fd * (c_R_prev[1] - c_R_prev[0])
            c_O_pred = c_O_prev[0] + lambda_fd * (c_O_prev[1] - c_O_prev[0])
            S, k_sum = c_R_pred + c_O_pred, k_red + k_ox
            if k_sum > 0:
                C_R_eq = (k_red / k_sum) * S
                c_R[0, j] = C_R_eq + (c_R_pred - C_R_eq) * np.exp(-k_sum * del_t)
                c_O[0, j] = S - c_R[0, j]
            else:
                c_R[0, j], c_O[0, j] = c_R_pred, c_O_pred
            c_R[0, j], c_O[0, j] = max(c_R[0, j], 0.0), max(c_O[0, j], 0.0)
            r_net = k_red * c_O[0, j] - k_ox * c_R[0, j]
            i_t[j] = n * F * A * M_to_mol_cm3 * r_net
    
    return {'c_R': c_R, 'c_O': c_O, 'E_t': E_t, 'i_t': i_t, 'E_eq': E_eq}

print("\nRunning FD...")
fd_results = [run_finite_difference(s['c_O'], s['c_R']) | {'system': s} for s in systems]
print("✓ FD complete")

# Neural networks
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
l_p_fixed = 3.5e-3

class PINN_Oxidized(nn.Module):
    def __init__(self, w=50, d=5):
        super().__init__()
        layers = [nn.Linear(2, w), nn.SiLU()]
        for _ in range(d - 2):
            layers += [nn.Linear(w, w), nn.SiLU()]
        layers.append(nn.Linear(w, 1))
        self.net = nn.Sequential(*layers)
    
    def forward(self, xt):
        x = xt[:, 0:1]
        recovery = 1.0 - torch.exp(-x / torch.tensor(l_p_fixed, device=xt.device))
        return F_nn.softplus(self.net(xt)) * recovery

class PINN_Reduced(nn.Module):
    def __init__(self, w=50, d=5):
        super().__init__()
        layers = [nn.Linear(2, w), nn.SiLU()]
        for _ in range(d - 2):
            layers += [nn.Linear(w, w), nn.SiLU()]
        layers.append(nn.Linear(w, 1))
        self.net = nn.Sequential(*layers)
    
    def forward(self, xt):
        x = xt[:, 0:1]
        decay = torch.exp(-x / torch.tensor(l_p_fixed, device=xt.device))
        return F_nn.softplus(self.net(xt)) * decay

# GradNorm weighting
class GradNormWeighting:
    def __init__(self, n=3, alpha=1.5):
        self.alpha, self.n, self.initial = alpha, n, None
        self.weights = torch.ones(n)
    
    def update(self, losses, params, lr=0.025):
        if self.initial is None:
            self.initial = torch.tensor([l.item() for l in losses])
        
        grad_norms = []
        for loss in losses:
            grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=False)
            grad_norms.append(torch.sqrt(sum([(g**2).sum() for g in grads])))
        grad_norms = torch.stack(grad_norms)
        
        mean_norm = grad_norms.mean()
        loss_ratios = torch.tensor([l.item() / (self.initial[i] + 1e-8) for i, l in enumerate(losses)])
        targets = mean_norm * (loss_ratios ** self.alpha)
        
        for i in range(self.n):
            if targets[i] > 0:
                self.weights[i] *= (1 + lr * (targets[i] / (grad_norms[i] + 1e-8) - 1))
                self.weights[i] = max(self.weights[i].item(), 0.01)
        self.weights = self.weights * self.n / self.weights.sum()
        return self.weights.clone()

# Helper functions
def create_interp(t_data, y_data, dev):
    t_t = torch.tensor(t_data, dtype=torch.float32, device=dev)
    y_t = torch.tensor(y_data, dtype=torch.float32, device=dev)
    def interp(t_q):
        tc = t_q.view(-1).clamp(min=t_t[0], max=t_t[-1])
        idx = torch.bucketize(tc, t_t, right=False).clamp(1, len(t_t)-1)
        i0, i1 = idx - 1, idx
        w = (tc - t_t[i0]) / (t_t[i1] - t_t[i0] + 1e-12)
        return (y_t[i0] + w * (y_t[i1] - y_t[i0])).view_as(t_q)
    return interp

def E_torch(t):
    ts = torch.tensor(t_switch, dtype=t.dtype, device=t.device)
    fwd = E_start - sweep_rate_V_per_s * t
    bwd = E_vertex + sweep_rate_V_per_s * (t - ts)
    return torch.where(t <= ts, fwd, bwd)

def k_red_t(t, Eeq):
    return k_0 * torch.exp((-alpha * n * F * (E_torch(t) - Eeq)) / (R * T))

def k_ox_t(t, Eeq):
    return k_0 * torch.exp(((1.0 - alpha) * n * F * (E_torch(t) - Eeq)) / (R * T))

def compute_c(model, x, t, cb, csf, stype):
    H = (t > 0).float()
    cs = csf(t)
    cf = torch.tensor(cb, dtype=t.dtype, device=t.device)
    xt = torch.cat([x, t], dim=1)
    cr = model(xt)
    c0 = model(torch.cat([torch.zeros_like(t), t], dim=1))
    cL = model(torch.cat([torch.full_like(t, L), t], dim=1))
    if stype == 'O':
        cn = ((cr - c0) / (cL - c0 + 1e-8)).clamp(0, 1)
        ct = cs + (cf - cs) * cn
    else:
        cn = ((cr - cL) / (c0 - cL + 1e-8)).clamp(0, 1)
        ct = cf + (cs - cf) * cn
    return cf * (1 - H) + ct * H

def compute_derivs(mO, mR, x, t, cOb, cRb, cOf, cRf):
    xr = x.detach().clone().requires_grad_(True)
    tr = t.detach().clone().requires_grad_(True)
    cO = compute_c(mO, xr, tr, cOb, cOf, 'O')
    cR = compute_c(mR, xr, tr, cRb, cRf, 'R')
    ones = torch.ones_like(cO)
    dcOdt = torch.autograd.grad(cO, tr, ones, create_graph=True)[0]
    dcOdx = torch.autograd.grad(cO, xr, ones, create_graph=True)[0]
    dcRdt = torch.autograd.grad(cR, tr, ones, create_graph=True)[0]
    dcRdx = torch.autograd.grad(cR, xr, ones, create_graph=True)[0]
    d2cOdx2 = torch.autograd.grad(dcOdx, xr, ones, create_graph=True)[0]
    d2cRdx2 = torch.autograd.grad(dcRdx, xr, ones, create_graph=True)[0]
    return {'cO': cO, 'dcOdt': dcOdt, 'dcOdx': dcOdx, 'd2cOdx2': d2cOdx2,
            'cR': cR, 'dcRdt': dcRdt, 'dcRdx': dcRdx, 'd2cRdx2': d2cRdx2}

def compute_losses(mO, mR, xi, ti, xb, tb, cOb, cRb, Eeq, cOf, cRf):
    d = compute_derivs(mO, mR, xi, ti, cOb, cRb, cOf, cRf)
    resO = d['dcOdt'] - D * d['d2cOdx2']
    resR = d['dcRdt'] - D * d['d2cRdx2']
    lp = (resO.pow(2).mean() + resR.pow(2).mean()) / 2
    
    db = compute_derivs(mO, mR, xb, tb, cOb, cRb, cOf, cRf)
    M = 1e-3
    rf = (k_red_t(tb, Eeq) * db['cO'] - k_ox_t(tb, Eeq) * db['cR']) * M
    flxO = (-D * db['dcOdx']) - rf
    flxR = (-D * db['dcRdx']) + rf
    lf = (flxO.pow(2).mean() + flxR.pow(2).mean()) / 2
    
    lc = (d['cO'] + d['cR'] - (cOb + cRb)).pow(2).mean()
    return lp, lf, lc

# Training setup
print("\n" + "="*80)
print("TRAINING WITH GRADNORM")
print("="*80)

N_int, N_bc, epochs, lr = 4096, 1024, 9000, 1e-3
torch.manual_seed(42)

pinn_models = []
for i, sys in enumerate(systems):
    torch.manual_seed(42 + i)
    mO, mR = PINN_Oxidized().to(DEVICE), PINN_Reduced().to(DEVICE)
    cOf = create_interp(t_grid, fd_results[i]['c_O'][0, :], DEVICE)
    cRf = create_interp(t_grid, fd_results[i]['c_R'][0, :], DEVICE)
    params = list(mO.parameters()) + list(mR.parameters())
    opt = torch.optim.Adam(params, lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=1000, eta_min=1e-5)
    gn = GradNormWeighting(n=3, alpha=1.5)
    pinn_models.append({'O': mO, 'R': mR, 'system': sys, 'cOf': cOf, 'cRf': cRf,
                       'opt': opt, 'sch': sch, 'gn': gn, 'hist': []})

print(f"Training {len(pinn_models)} systems for {epochs} epochs...")

for step in range(epochs + 1):
    for i, (md, fdr) in enumerate(zip(pinn_models, fd_results)):
        xi = torch.rand(N_int, 1, device=DEVICE) * L
        ti = torch.rand(N_int, 1, device=DEVICE) * T_final
        xb = torch.zeros(N_bc, 1, device=DEVICE)
        tb = torch.rand(N_bc, 1, device=DEVICE) * T_final
        
        sys = md['system']
        lp, lf, lc = compute_losses(md['O'], md['R'], xi, ti, xb, tb,
                                    sys['c_O'], sys['c_R'], fdr['E_eq'],
                                    md['cOf'], md['cRf'])
        losses = [lp, lf, lc]
        
        if step % 10 == 0 and step > 0:
            sp = list(md['O'].parameters())[:2]
            w = md['gn'].update(losses, sp)
            lt = sum(wi * li for wi, li in zip(w, losses))
        else:
            w = torch.tensor([1.0, 10.0, 1.0])
            lt = w[0]*lp + w[1]*lf + w[2]*lc
        
        md['opt'].zero_grad()
        lt.backward()
        torch.nn.utils.clip_grad_norm_(list(md['O'].parameters()) + list(md['R'].parameters()), 1.0)
        md['opt'].step()
        md['sch'].step()
        
        if step % 500 == 0:
            wn = w.detach().cpu().numpy() if isinstance(w, torch.Tensor) else w.numpy()
            md['hist'].append({'step': step, 'total': lt.item(), 'physics': lp.item(),
                             'flux': lf.item(), 'conservation': lc.item(), 'weights': wn.copy()})
    
    if step % 1000 == 0:
        al = np.mean([h[-1]['total'] for m in pinn_models for h in [m['hist']] if h])
        print(f"  Step {step:5d} | Avg Loss: {al:.2e}")

print("✓ Training complete!")

# Evaluation
print("\n" + "="*80)
print("EVALUATION")
print("="*80)

for md in pinn_models:
    md['O'].eval()
    md['R'].eval()

Nt, Nx = 500, 120
te = torch.linspace(0, T_final, Nt, device=DEVICE).view(-1, 1)
xe = torch.linspace(0, L, Nx, device=DEVICE).view(-1, 1)

def eval_pinn(md, fdr):
    sys = md['system']
    with torch.no_grad():
        x0 = torch.zeros_like(te)
        cOs = compute_c(md['O'], x0, te, sys['c_O'], md['cOf'], 'O')
        cRs = compute_c(md['R'], x0, te, sys['c_R'], md['cRf'], 'R')
        M = 1e-3
        kr = k_red_t(te, fdr['E_eq'])
        ko = k_ox_t(te, fdr['E_eq'])
        rn = kr * cOs - ko * cRs
        ip = n * F * A * M * rn
        Ep = E_torch(te)
        
        mask = (Ep >= -0.8) & (Ep < 0.1)
        im = ip.clone()
        im[~mask] = float('nan')
        
        if sys['ratio'] >= 1.0:
            tmp = torch.where(torch.isnan(im), torch.tensor(float('-inf'), device=im.device), im)
            ipk = torch.argmax(tmp).item()
        else:
            tmp = torch.where(torch.isnan(im), torch.tensor(float('inf'), device=im.device), im)
            ipk = torch.argmin(tmp).item()
        
        ib = max(0, int(0.8 * ipk))
        ia = min(Nt - 1, int(1.2 * ipk))
        
        profs = {}
        for lbl, idx in zip(['t1', 't2', 't3'], [ib, ipk, ia]):
            tp = te[idx].repeat(Nx, 1)
            cOp = compute_c(md['O'], xe, tp, sys['c_O'], md['cOf'], 'O')
            cRp = compute_c(md['R'], xe, tp, sys['c_R'], md['cRf'], 'R')
            profs[lbl] = {'cO': cOp.cpu().numpy(), 'cR': cRp.cpu().numpy(),
                         't': te[idx].item(), 'E': Ep[idx].item()}
        
        return {'E': Ep.cpu().numpy().flatten(), 'i': ip.cpu().numpy().flatten(),
                't': te.cpu().numpy().flatten(), 'profiles': profs, 'indices': [ib, ipk, ia],
                'k_red': kr.cpu().numpy().flatten(), 'k_ox': ko.cpu().numpy().flatten()}

pinn_results = [eval_pinn(md, fdr) | {'system': md['system']} for md, fdr in zip(pinn_models, fd_results)]

# Plotting
print("\nGenerating plots...")
plt.rcParams.update({'font.size': 11, 'axes.labelsize': 12, 'axes.titlesize': 13,
                    'legend.fontsize': 10, 'figure.dpi': 150, 'lines.linewidth': 2.5})

x_mm = xe.cpu().numpy().flatten() * 10
x_mm_fd = x_grid * 10
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
prof_colors = ['#d62728', '#2ca02c', '#9467bd']

# Plot 1: Voltammograms
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for i, (ax, fdr, pr) in enumerate(zip(axes, fd_results, pinn_results)):
    sys = fdr['system']
    mfd = (fdr['E_t'] >= -0.8) & (fdr['E_t'] < 0.1)
    mpn = (pr['E'] >= -0.8) & (pr['E'] < 0.1)
    ax.plot(fdr['E_t'][mfd], fdr['i_t'][mfd]*1e6, '--', lw=3, label='FD', color='gray', alpha=0.6)
    ax.plot(pr['E'][mpn], pr['i'][mpn]*1e6, lw=2.5, label='PINN', color=colors[i])
    for idx, col in zip(pr['indices'], prof_colors):
        if -0.8 <= pr['E'][idx] < 0.1:
            ax.axvline(pr['E'][idx], color=col, lw=1.5, alpha=0.5)
    ax.set_xlabel('Potential (V)', fontweight='bold')
    ax.set_ylabel('Current (μA)', fontweight='bold')
    ax.set_title(f"{sys['name']}: c_O/c_R = {sys['ratio']}", fontweight='bold')
    ax.set_xlim(-0.8, 0.1)
    ax.axhline(0, color='k', ls='--', alpha=0.3, lw=1)
    ax.legend()
    ax.grid(True, alpha=0.2)
plt.tight_layout()
plt.savefig('01_voltammogram_comparison_all_systems.png', dpi=300, bbox_inches='tight')
plt.savefig('01_voltammogram_comparison_all_systems.pdf', bbox_inches='tight')
print("  ✓ 01_voltammogram_comparison")
plt.close()

# Plot 2: Profiles
for i, (fdr, pr) in enumerate(zip(fd_results, pinn_results)):
    sys = fdr['system']
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f"{sys['name']}: Concentration Profiles and Dynamics", fontsize=15, fontweight='bold')
    
    # Potential
    axes[0, 0].plot(pr['t'], pr['E'], lw=2, color=colors[i])
    axes[0, 0].axhline(E_0, ls='--', color='k', lw=1.2, label='E₀')
    axes[0, 0].axhline(fdr['E_eq'], ls=':', color='red', lw=1.5, label='E_eq')
    axes[0, 0].axvline(t_switch, ls='--', color='gray', lw=1, label='Switch')
    for pf in pr['profiles'].values():
        axes[0, 0].axvline(pf['t'], ls=':', alpha=0.5, lw=1)
    axes[0, 0].set_xlabel('Time (s)', fontweight='bold')
    axes[0, 0].set_ylabel('Potential (V)', fontweight='bold')
    axes[0, 0].set_title('Potential Waveform', fontweight='bold')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.2)
    
    # Rate constants (dummy - not computed in eval)
    #[0, 1].text(0.5, 0.5, 'Rate Constants\n(see training)', ha='center', va='center', transform=axes[0, 1].transAxes)
    
    ax = axes[0, 1]
    ax.plot(pr['t'], pr['k_red'], linewidth=2, 
                label='k_red', color='#1f77b4')
    ax.plot(pr['t'], pr['k_ox'], linewidth=2,
                label='k_ox', color='#ff7f0e')
    ax.axvline(t_switch, linestyle='--', color='gray', linewidth=1, alpha=0.5)
    ax.set_xlabel('Time (s)', fontweight='bold')
    axes[0, 1].set_xlabel('Time (s)', fontweight='bold')
    axes[0, 1].set_ylabel('Rate Constant (cm/s)', fontweight='bold')
    axes[0, 1].set_title('Rate Constants', fontweight='bold')
    axes[0, 1].grid(True, alpha=0.2)
    
    # c_R profiles
    for j, (lbl, pf) in enumerate(pr['profiles'].items()):
        ifd = np.argmin(np.abs(t_grid - pf['t']))
        axes[1, 0].plot(x_mm_fd, fdr['c_R'][:, ifd] / sys['c_R'], '--', lw=2, color=prof_colors[j], alpha=0.5)
        axes[1, 0].plot(x_mm, pf['cR'] / sys['c_R'], lw=2.5, color=prof_colors[j],
                       label=f"t={pf['t']:.2f}s (E={pf['E']:.2f}V)")
    axes[1, 0].axhline(1.0, ls='--', color='k', lw=1.2)
    axes[1, 0].set_xlabel('Distance from Electrode (mm)', fontweight='bold')
    axes[1, 0].set_ylabel('c_R / c_R,bulk', fontweight='bold')
    axes[1, 0].set_title('Reduced Species', fontweight='bold')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.2)
    
    # c_O profiles
    for j, (lbl, pf) in enumerate(pr['profiles'].items()):
        ifd = np.argmin(np.abs(t_grid - pf['t']))
        axes[1, 1].plot(x_mm_fd, fdr['c_O'][:, ifd] / sys['c_O'], '--', lw=2, color=prof_colors[j], alpha=0.5)
        axes[1, 1].plot(x_mm, pf['cO'] / sys['c_O'], lw=2.5, color=prof_colors[j],
                       label=f"t={pf['t']:.2f}s (E={pf['E']:.2f}V)")
    axes[1, 1].axhline(1.0, ls='--', color='k', lw=1.2)
    axes[1, 1].set_xlabel('Distance from Electrode (mm)', fontweight='bold')
    axes[1, 1].set_ylabel('c_O / c_O,bulk', fontweight='bold')
    axes[1, 1].set_title('Oxidized Species', fontweight='bold')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.2)
    
    plt.tight_layout()
    plt.savefig(f'02_profiles_system{i+1}.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'02_profiles_system{i+1}.pdf', bbox_inches='tight')
    print(f"  ✓ 02_profiles_system{i+1}")
    plt.close()

# Plot 3: Training loss
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for i, (ax, md) in enumerate(zip(axes, pinn_models)):
    h = md['hist']
    if not h:
        continue
    steps = [hi['step'] for hi in h]
    ax.semilogy(steps, [hi['total'] for hi in h], lw=2.5, label='Total', color=colors[i])
    ax.semilogy(steps, [hi['physics'] for hi in h], lw=1.5, label='Physics', alpha=0.7)
    ax.semilogy(steps, [hi['flux'] for hi in h], lw=1.5, label='Flux', alpha=0.7)
    ax.semilogy(steps, [hi['conservation'] for hi in h], lw=1.5, label='Conservation', alpha=0.7)
    ax.set_xlabel('Training Step', fontweight='bold')
    ax.set_ylabel('Loss', fontweight='bold')
    ax.set_title(f"{md['system']['name']} Training Loss", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('03_training_loss_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('03_training_loss_curves.pdf', bbox_inches='tight')
print("  ✓ 03_training_loss_curves")
plt.close()

# Plot 4: Overlay
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
for i, fdr in enumerate(fd_results):
    sys = fdr['system']
    mfd = (fdr['E_t'] >= -0.8) & (fdr['E_t'] < 0.1)
    ax1.plot(fdr['E_t'][mfd], fdr['i_t'][mfd]*1e6, lw=2.5,
            label=f"{sys['name']} (c_O/c_R={sys['ratio']})", color=colors[i])

ax1.set_xlabel('Potential (V)', fontweight='bold')
ax1.set_ylabel('Current (μA)', fontweight='bold')
ax1.set_title('Finite Difference: All Systems', fontweight='bold', fontsize=14)
ax1.axhline(0, color='k', ls='--', alpha=0.3)
ax1.legend()
ax1.grid(True, alpha=0.2)

for i, pr in enumerate(pinn_results):
    sys = pr['system']
    mpn = (pr['E'] >= -0.8) & (pr['E'] < 0.1)
    ax2.plot(pr['E'][mpn], pr['i'][mpn]*1e6, lw=2.5,
            label=f"{sys['name']} (c_O/c_R={sys['ratio']})", color=colors[i])

ax2.set_xlabel('Potential (V)', fontweight='bold')
ax2.set_ylabel('Current (μA)', fontweight='bold')
ax2.set_title('PINN: All Systems', fontweight='bold', fontsize=14)
ax2.axhline(0, color='k', ls='--', alpha=0.3)
ax2.legend()
ax2.grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig('04_all_systems_overlay.png', dpi=300, bbox_inches='tight')
plt.savefig('04_all_systems_overlay.pdf', bbox_inches='tight')
print("  ✓ 04_all_systems_overlay")
plt.close()

# Summary
print("\n" + "="*80)
print("SIMULATION SUMMARY")
print("="*80)

for i, (fdr, pr) in enumerate(zip(fd_results, pinn_results)):
    sys = fdr['system']
    print(f"\n{sys['name']} (c_O={sys['c_O']}, c_R={sys['c_R']}, ratio={sys['ratio']}):")
    print(f"  FD Peak Current:   {np.max(np.abs(fdr['i_t']))*1e6:.2f} μA")
    print(f"  PINN Peak Current: {np.max(np.abs(pr['i']))*1e6:.2f} μA")
    err = np.abs(np.max(np.abs(pr['i'])) - np.max(np.abs(fdr['i_t']))) / np.max(np.abs(fdr['i_t'])) * 100
    print(f"  Relative Error:    {err:.2f}%")

print("\n" + "="*80)
print("✓ ALL PLOTS GENERATED WITH ADAPTIVE WEIGHTING!")
print("="*80)
print("\nGenerated files:")
print("  01_voltammogram_comparison_all_systems.png/pdf")
print("  02_profiles_system1.png/pdf")
print("  02_profiles_system2.png/pdf")
print("  02_profiles_system3.png/pdf")
print("  03_training_loss_curves.png/pdf")
print("  04_all_systems_overlay.png/pdf")
print("\n" + "="*80)
print("\nKEY IMPROVEMENTS:")
print("  1. GradNorm: Automatically balances loss gradients")
print("  2. Cosine annealing: Smooth learning rate decay")
print("  3. Gradient clipping: Prevents exploding gradients")
print("  4. Adaptive weights: Updated every 10 steps")
print("="*80)

CV SIMULATION: ADAPTIVE LOSS WEIGHTING

Running FD...
✓ FD complete

TRAINING WITH GRADNORM
Training 3 systems for 9000 epochs...
  Step     0 | Avg Loss: 3.31e+01
  Step  1000 | Avg Loss: 2.50e-04
  Step  2000 | Avg Loss: 1.06e-03
  Step  3000 | Avg Loss: 3.67e-03
  Step  4000 | Avg Loss: 9.14e-04
  Step  5000 | Avg Loss: 7.12e-06
  Step  6000 | Avg Loss: 1.77e-05
  Step  7000 | Avg Loss: 4.37e-06
  Step  8000 | Avg Loss: 1.47e-03
  Step  9000 | Avg Loss: 1.54e-05
✓ Training complete!

EVALUATION

Generating plots...
  ✓ 01_voltammogram_comparison
  ✓ 02_profiles_system1
  ✓ 02_profiles_system2
  ✓ 02_profiles_system3
  ✓ 03_training_loss_curves
  ✓ 04_all_systems_overlay

SIMULATION SUMMARY

System 1 (c_O=1.0, c_R=0.01, ratio=100):
  FD Peak Current:   21570949.20 μA
  PINN Peak Current: 1143739008.00 μA
  Relative Error:    5202.22%

System 2 (c_O=1.0, c_R=1.0, ratio=1):
  FD Peak Current:   35881974.37 μA
  PINN Peak Current: 1143739383808.00 μA
  Relative Error:    3187404.02%

Sy