# ENM5320 Shearflow Assignment (Q1 + Q2)

This notebook addresses:
1. **Q1**: Load and verify the **mean concentration field only**.
2. **Q2**: Fit a **linear 3-point finite-difference stencil** to the mean concentration data.

The notebook intentionally uses only `tracer_mean` as requested.

## Training Strategy (Required Explanation for Q2)

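We model the mean concentration profile $u(t,x)$ on the uniform, periodic grid of the data with a constant linear 3-point stencil
$$L[u]_i = a u_{i-1} + b u_i + c u_{i+1},$$
where indices are taken modulo $N$ (here $c$ is the stencil coefficient, not the concentration $c(t,x)$ shown in the Q1 plots). We use implicit Euler time stepping
$$u^{n+1} = u^n + \Delta t\,L[u^{n+1}],$$
which yields the linear system
$$(I - \Delta t\,D(a,b,c))u^{n+1} = u^n,$$
where $D(a,b,c)$ is the $N\times N$ circulant matrix with $b$ on the diagonal, $a$ on the sub-diagonal, $c$ on the super-diagonal, and the corresponding periodic wrap-around entries.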

We optimize $(a,b,c)$ by rolling out these dynamics from the observed initial profile $u^0_{\text{true}}$ through all $T$ snapshots and minimizing the mean-squared error against the observed mean-concentration snapshots:
$$\mathcal{L} = \frac{1}{NT}\sum_{n,i}\left(u^{n}_{\text{pred},i} - u^{n}_{\text{true},i}\right)^2.$$

Optimization uses Adam. This is a baseline model and does **not** include velocity coupling or unresolved nonlinear effects, as discussed in class.

In [None]:
# Imports: numerics (numpy), plotting (matplotlib), autograd/linear algebra (torch),
# and stdlib helpers used to locate and download the dataset file.
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import urllib.request

# Make float32 torch's default floating dtype, so tensors created later without an
# explicit dtype (e.g. torch.tensor of Python floats) are float32.
torch.set_default_dtype(torch.float32)

In [None]:
# ========================
# MODIFY THESE PARAMETERS
# ========================
# Case selectors: these values are formatted into the dataset filename in the load cell.
Re: float = 5e5
Sc: float = 2.0e-1
ic_index: int = 0

case_summary = f'Target case: Re={Re:.0e}, Sc={Sc:.1e}, ic={ic_index}'
print(case_summary)

In [None]:
# Q1: Load ONLY mean concentration field
filename = f'shearflow_1d_profiles_Re{Re:.0e}_Sc{Sc:.1e}_ic{ic_index}.npz'
path = Path(filename)

if not path.exists():
    url = 'https://raw.githubusercontent.com/natrask/ENM5320-2026/main/NewMaterial/shearflow_project/' + filename
    print(f'Local file not found. Downloading from: {url}')
    # Download under a temporary name and rename only after the transfer finishes,
    # so an interrupted download never leaves a truncated .npz that the
    # `path.exists()` check above would then accept on every later run.
    partial_path = path.with_name(path.name + '.part')
    urllib.request.urlretrieve(url, partial_path)
    partial_path.replace(path)
    print('Download complete.')

# np.load on an .npz returns an NpzFile that keeps the archive open; the context
# manager closes it once the arrays we need have been read into memory.
with np.load(path) as data:
    loaded_keys = list(data.keys())
    # Load required arrays
    time = data['time']
    x = data['x']
    tracer_mean = data['tracer_mean']   # ONLY mean concentration field

# Sanity checks: one concentration profile per stored time, on the stored x grid.
assert tracer_mean.shape == (time.shape[0], x.shape[0]), (
    f'tracer_mean shape {tracer_mean.shape} does not match (len(time), len(x)) = '
    f'({time.shape[0]}, {x.shape[0]})'
)
assert np.all(np.isfinite(tracer_mean)), 'tracer_mean contains NaN or inf values'

print('Loaded keys:', loaded_keys)
print('tracer_mean shape:', tracer_mean.shape)
print('time range:', float(time[0]), '->', float(time[-1]))
print('x range:', float(x[0]), '->', float(x[-1]))

In [None]:
# Q1 evidence plot 1: snapshots of mean concentration
n_time = len(time)
# Snapshot indices: first time, roughly the quarter/half/three-quarter points, last time.
idxs = [0, n_time//4, n_time//2, 3*n_time//4, n_time-1]

fig, ax = plt.subplots(figsize=(8, 4))
for idx in idxs:
    ax.plot(x, tracer_mean[idx], label=f't={time[idx]:.2f}')
ax.set(
    xlabel='x',
    ylabel='mean concentration',
    title='Q1 Evidence: Mean concentration snapshots',
)
ax.grid(True, alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Q1 evidence plot 2 and 3: spacetime map + variance decay
# Spatial variance of the mean concentration at each stored time (variance over axis 1).
variance_t = tracer_mean.var(axis=1)

fig, axes = plt.subplots(1, 2, figsize=(13, 4))
ax_map, ax_var = axes

# Left panel: filled contours of the profile over (x, t).
im = ax_map.contourf(x, time, tracer_mean, levels=30, cmap='viridis')
ax_map.set(xlabel='x', ylabel='t', title='Q1 Evidence: Mean concentration $c(t,x)$')
plt.colorbar(im, ax=ax_map, label='c')

# Right panel: variance on a log axis.
ax_var.semilogy(time, variance_t, 'o-', ms=3)
ax_var.set(xlabel='t', ylabel='Var_x(c)', title='Q1 Evidence: Spatial variance decay')
ax_var.grid(True, alpha=0.3, which='both')

fig.tight_layout()
plt.show()

## Q2: Fit a Linear 3-Point Stencil

In [None]:
device = 'cpu'
u_true = torch.tensor(tracer_mean, dtype=torch.float32, device=device)  # shape (T, N)
t_torch = torch.tensor(time, dtype=torch.float32, device=device)
x_torch = torch.tensor(x, dtype=torch.float32, device=device)

T, N = u_true.shape
# dt and dx are read from the first two samples, so at least two of each are required.
assert T >= 2 and N >= 2, f'need at least 2 snapshots and 2 grid points, got T={T}, N={N}'
dt = float(time[1] - time[0])
dx = float(x[1] - x[0])

# The rollout uses one fixed dt and one constant stencil, which is only valid on
# uniformly spaced samples. Fail loudly instead of silently fitting with a wrong dt/dx.
# The loose rtol absorbs rounding in coordinates that may have been stored in float32.
assert np.allclose(np.diff(time), dt, rtol=1e-2, atol=0.0), 'time samples are not uniformly spaced'
assert np.allclose(np.diff(x), dx, rtol=1e-2, atol=0.0), 'x grid is not uniformly spaced'

print(f'T={T}, N={N}, dt={dt:.4e}, dx={dx:.4e}')

In [None]:
def build_stencil_matrix_periodic(N, coeffs, device='cpu'):
    """Return the N x N periodic (circulant) matrix D of the stencil [a, b, c].

    Row i of D applies L[u]_i = a*u[i-1] + b*u[i] + c*u[i+1] with indices taken
    modulo N, so ``D @ u`` evaluates the stencil on a periodic grid.

    D is assembled as a sum of shifted identity matrices instead of by indexed
    assignment. On tiny grids (N <= 2) the neighbours i-1 and i+1 wrap to the same
    column; summing makes their coefficients add up rather than overwrite each
    other. The construction has no in-place ops and is differentiable w.r.t. coeffs.

    Args:
        N: number of grid points.
        coeffs: tensor (or plain sequence) [a, b, c].
        device: device on which to build D.

    Returns:
        Tensor of shape (N, N) with the dtype of ``coeffs``.
    """
    if not torch.is_tensor(coeffs):
        coeffs = torch.tensor(coeffs, dtype=torch.get_default_dtype(), device=device)
    a, b, c = coeffs[0], coeffs[1], coeffs[2]
    eye = torch.eye(N, dtype=coeffs.dtype, device=device)
    # prev_nb[i, (i - 1) % N] = 1: selects u[i-1]
    prev_nb = torch.roll(eye, shifts=-1, dims=1)
    # next_nb[i, (i + 1) % N] = 1: selects u[i+1]
    next_nb = torch.roll(eye, shifts=1, dims=1)
    return a * prev_nb + b * eye + c * next_nb

def rollout_implicit_euler(u0, coeffs, T, dt):
    """Roll out implicit Euler steps of du/dt = D(a, b, c) u starting from u0.

    Each step solves (I - dt * D) u^{n+1} = u^n, where D is the periodic 3-point
    stencil matrix built from ``coeffs = [a, b, c]``. Gradients flow back to
    ``coeffs`` through the whole rollout.

    Args:
        u0: initial profile, 1-D tensor of length N.
        coeffs: tensor [a, b, c].
        T: number of snapshots to return, including u0.
        dt: time step.

    Returns:
        Tensor of shape (T, N) with u0 in row 0 (just u0 as a single row when T <= 1).
    """
    N = u0.shape[0]
    D = build_stencil_matrix_periodic(N, coeffs, device=u0.device)
    A = torch.eye(N, device=u0.device, dtype=u0.dtype) - dt * D

    # A is identical at every step, so factor it once (O(N^3)) and reuse the LU
    # factors for each step's O(N^2) triangular solves. Calling torch.linalg.solve
    # inside the loop would re-factor A on every time step.
    LU, pivots = torch.linalg.lu_factor(A)

    u_hist = [u0]
    u = u0
    for _ in range(T - 1):
        # lu_solve expects a matrix right-hand side: (N,) -> (N, 1) -> (N,)
        u = torch.linalg.lu_solve(LU, pivots, u.unsqueeze(-1)).squeeze(-1)
        u_hist.append(u)
    return torch.stack(u_hist, dim=0)

In [None]:
# Fit the stencil in dimensionless form: coeffs = theta / dx**2.
# Adam moves each parameter by roughly `lr` per step regardless of the gradient's
# scale. Optimizing the raw coefficients (which are O(1/dx**2)) directly would let
# them change by only about lr * num_epochs in absolute terms, a negligible relative
# change on fine grids. theta is O(1), so here lr acts as a relative step size.
stencil_scale = 1.0 / dx**2

# Initialize at the centered second-derivative stencil [1, -2, 1] / dx**2.
theta = torch.nn.Parameter(torch.tensor([1.0, -2.0, 1.0], device=device))
optimizer = torch.optim.Adam([theta], lr=5e-3)

num_epochs = 2000
log_every = 200
loss_history = []

for epoch in range(num_epochs):
    optimizer.zero_grad()

    coeffs = theta * stencil_scale  # physical-units stencil [a, b, c]
    u_pred = rollout_implicit_euler(u_true[0], coeffs, T=T, dt=dt)
    loss = torch.mean((u_pred - u_true) ** 2)

    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    loss_history.append(loss_value)

    # A non-finite loss means the rollout blew up; continuing would only propagate
    # NaN/inf into the coefficients and every later plot.
    if not np.isfinite(loss_value):
        raise RuntimeError(
            f'Non-finite loss at epoch {epoch}; coeffs={coeffs.detach().cpu().numpy()}'
        )

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        coeffs_np = coeffs.detach().cpu().numpy()
        print(f'Epoch {epoch:4d} | loss={loss_value:.4e} | coeffs={coeffs_np}')

# Later cells expect `coeffs` to hold the learned stencil in physical units.
coeffs = (theta * stencil_scale).detach()

print('Training complete.')
print('Learned coeffs [a,b,c] =', coeffs.cpu().numpy())

In [None]:
# Q2 results: loss curve + prediction comparison
with torch.no_grad():
    u_fit = rollout_implicit_euler(u_true[0], coeffs, T=T, dt=dt).cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(13, 4))

axes[0].semilogy(loss_history)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE loss')
axes[0].set_title('Q2 Training loss')
axes[0].grid(True, alpha=0.3, which='both')

idxs = [0, n_time//4, n_time//2, 3*n_time//4, n_time-1]
for k, idx in enumerate(idxs):
    # By default every plot() call advances the colour cycle, which would give the
    # solid (true) and dashed (fit) curves of one snapshot different colours.
    # Use one colour per snapshot so the pairs can be matched, and label the time.
    color = f'C{k % 10}'
    axes[1].plot(x, tracer_mean[idx], '-', color=color, lw=2, alpha=0.8,
                 label=f't={time[idx]:.2f}')
    axes[1].plot(x, u_fit[idx], '--', color=color, lw=1.5, alpha=0.8)
axes[1].set_xlabel('x')
axes[1].set_ylabel('concentration')
axes[1].set_title('Q2 True (solid) vs Linear-stencil fit (dashed)')
axes[1].grid(True, alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Report final errors: re-run the fitted rollout and compare it against the data.
with torch.no_grad():
    u_fit_torch = rollout_implicit_euler(u_true[0], coeffs, T=T, dt=dt)
    residual = u_fit_torch - u_true
    mse = residual.pow(2).mean().item()
    rel_l2 = (residual.norm() / u_true.norm()).item()

learned_stencil = coeffs.detach().cpu().numpy()
print(f'Final MSE: {mse:.6e}')
print(f'Relative L2 error: {rel_l2:.6e}')
print('Learned stencil [a,b,c]:', learned_stencil)

## Brief Discussion

This linear stencil baseline captures part of the mean-concentration evolution but cannot represent full nonlinear transport physics.
As noted in the assignment, the true process depends on velocity coupling and unresolved variability (e.g., standard deviations), which motivates nonlinear models in the next assignment.