# Physics-Informed Water Breakthrough Prediction

This notebook demonstrates the **end-to-end pipeline** for training and evaluating a
Physics-Informed Neural Network (PINN) that predicts water breakthrough in oil reservoirs.

The model combines:
- **LSTM temporal encoder** for sequential production data
- **Buckley-Leverett physics branch** (fractional flow, Corey relative permeability)
- **Data-driven MLP branch** for residual patterns
- **Log-normal survival head** for time-to-breakthrough uncertainty (P10/P50/P90)
- **Breakthrough classifier** (binary detection)

### Outline
1. [Setup & Imports](#1-setup--imports)
2. [Configuration](#2-configuration)
3. [Data Loading & Feature Engineering](#3-data-loading--feature-engineering)
4. [Exploratory Data Analysis](#4-exploratory-data-analysis)
5. [Dataset Preparation (Sequences & Splits)](#5-dataset-preparation)
6. [Model Architecture](#6-model-architecture)
7. [Training](#7-training)
8. [Evaluation Metrics](#8-evaluation-metrics)
9. [Prediction Visualizations](#9-prediction-visualizations)
10. [Physics Interpretation](#10-physics-interpretation)
11. [Survival Analysis](#11-survival-analysis)
12. [Save & Export](#12-save--export)

---
## 1. Setup & Imports

Import all required libraries and the project source modules.
The project expects `torch`, `numpy`, `pandas`, `matplotlib`, `scikit-learn`, and `scipy`.

In [None]:
import sys
from pathlib import Path

# Ensure the project root is on the Python path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from src.data_loader import (
    build_dataset,
    generate_synthetic_volve_data,
    compute_water_cut,
    compute_cumulative_production,
    compute_time_features,
    detect_breakthrough,
    filter_production_wells,
    VOLVE_COLUMNS,
)
from src.model import PhysicsInformedBreakthroughModel
from src.train import train_model, TrainConfig
from src.evaluate import (
    evaluate_model,
    print_metrics,
    plot_training_history,
    plot_predictions,
    plot_fractional_flow_curve,
    plot_survival_analysis,
)
from src.physics import compute_breakthrough_time_analytical
from src.survival import compute_survival_function, compute_hazard_function, compute_percentiles

%matplotlib inline
plt.rcParams.update({"figure.dpi": 120, "font.size": 11})

print(f"PyTorch version: {torch.__version__}")
print(f"Project root:    {PROJECT_ROOT}")

---
## 2. Configuration

Set the random seed for reproducibility, choose a device, and define training
hyperparameters. Modify these values to experiment.

| Parameter | Description |
|-----------|-------------|
| `SEED` | Random seed for reproducibility |
| `SEQ_LENGTH` | Number of time steps in each input sequence |
| `TEST_FRACTION` | Fraction of data held out for testing (temporal split) |
| `DATA_PATH` | Path to Volve CSV, or `None` for synthetic data |

In [None]:
# --- Reproducibility ---
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Data ---
DATA_PATH = None          # Set to a CSV path to use real Volve data
SEQ_LENGTH = 30           # LSTM input window
TEST_FRACTION = 0.2       # 80/20 temporal split

# --- Training hyperparameters ---
config = TrainConfig(
    epochs=200,
    batch_size=64,
    learning_rate=1e-3,
    hidden_size=64,
    lstm_layers=2,
    dropout=0.1,
    physics_anneal_end=0.1,
    lambda_survival=0.2,
    model_save_path=str(PROJECT_ROOT / "models" / "best_model.pt"),
)

SAVE_DIR = str(PROJECT_ROOT / "results")
print(f"Results will be saved to: {SAVE_DIR}")

---
## 3. Data Loading & Feature Engineering

The pipeline supports two modes:
- **Real data**: Provide a path to the [Volve production CSV](https://www.kaggle.com/datasets/lamyalbert/volve-production-data).
- **Synthetic data**: Auto-generated if no CSV is provided. Mimics realistic decline
  curves, sigmoid water breakthrough, and pressure dynamics across 5 wells over 1500 days.

Feature engineering steps performed:
1. **Water cut** computation: `water / (oil + water)`
2. **Cumulative production** per well
3. **Days on production** (time feature)
4. **Breakthrough labelling**: the first time the water cut exceeds a threshold for 3 consecutive days
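The water-cut and labelling steps can be sketched for a single well with NumPy. This is an illustration, not the `src.data_loader` implementation. The helper names `water_cut` and `label_breakthrough` are hypothetical, and so are three choices: the 0.1 threshold, a water cut of 0 on shut-in days, and labelling every day from the start of the qualifying run onward.

```python
import numpy as np

def water_cut(oil, water):
    """Water cut = water / (oil + water); days with zero liquid get 0.0 (assumption)."""
    oil = np.asarray(oil, dtype=float)
    water = np.asarray(water, dtype=float)
    liquid = oil + water
    return np.divide(water, liquid, out=np.zeros_like(liquid), where=liquid > 0)

def label_breakthrough(wc, threshold=0.1, n_consecutive=3):
    """Return (labels, start_index).

    Labels are 1 from the first day of the first run of `n_consecutive`
    days with wc > threshold, else all 0 (start_index is then None).
    """
    above = np.asarray(wc) > threshold
    labels = np.zeros(above.shape[0], dtype=int)
    run = 0
    for i, is_above in enumerate(above):
        run = run + 1 if is_above else 0
        if run == n_consecutive:
            start = i - n_consecutive + 1
            labels[start:] = 1
            return labels, start
    return labels, None

# Toy single-well series: the isolated spike on day 1 does not count
wc = water_cut(oil=[100, 95, 90, 80, 70, 60], water=[0, 20, 1, 20, 30, 40])
labels, bt_day = label_breakthrough(wc, threshold=0.1, n_consecutive=3)
```

Running this per well, for example inside a `groupby` over the well column, keeps one well's history from affecting another well's labels.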

In [None]:
# Generate or load raw data
if DATA_PATH and Path(DATA_PATH).exists():
    from src.data_loader import load_volve_data
    df_raw = load_volve_data(DATA_PATH)
    df_raw = filter_production_wells(df_raw)
    print(f"Loaded real Volve data: {len(df_raw)} rows")
else:
    df_raw = generate_synthetic_volve_data(n_wells=5, n_days=1500, seed=SEED)
    print(f"Generated synthetic data: {len(df_raw)} rows")

# Feature engineering
df = compute_water_cut(df_raw.copy())
df = compute_cumulative_production(df)
df = compute_time_features(df)
df = detect_breakthrough(df)

print(f"Wells: {df[VOLVE_COLUMNS['well']].nunique()}")
print(f"Date range: {df[VOLVE_COLUMNS['date']].min()} to {df[VOLVE_COLUMNS['date']].max()}")
df.head()

---
## 4. Exploratory Data Analysis

Visualize key production trends before training to understand the data.

### 4.1 Water Cut Over Time per Well

The sigmoid-like rise in water cut after breakthrough is the core signal the model
must capture. The vertical dashed line marks the labelled breakthrough point.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

well_col = VOLVE_COLUMNS["well"]

# Water cut over time
ax = axes[0]
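for well_id, group in df.groupby(well_col):
    (line,) = ax.plot(group["DAYS_ON_PRODUCTION"], group["WATER_CUT"],
                      label=f"Well {well_id}", linewidth=0.9, alpha=0.8)
    # Dashed vertical line at this well's first labelled breakthrough day
    bt_days = group.loc[group["BREAKTHROUGH"] == 1, "DAYS_ON_PRODUCTION"]
    if not bt_days.empty:
        ax.axvline(bt_days.iloc[0], color=line.get_color(), linestyle="--",
                   linewidth=0.8, alpha=0.6)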
ax.set_xlabel("Days on Production")
ax.set_ylabel("Water Cut")
ax.set_title("Water Cut Evolution per Well")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Oil rate decline
ax = axes[1]
for well_id, group in df.groupby(well_col):
    ax.plot(group["DAYS_ON_PRODUCTION"],
            group[VOLVE_COLUMNS["bore_oil_vol"]],
            label=f"Well {well_id}", linewidth=0.9, alpha=0.8)
ax.set_xlabel("Days on Production")
ax.set_ylabel("Oil Rate (Sm3/day)")
ax.set_title("Oil Rate Decline per Well")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 4.2 Pressure & Cumulative Production

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Downhole pressure
ax = axes[0]
for well_id, group in df.groupby(well_col):
    ax.plot(group["DAYS_ON_PRODUCTION"],
            group[VOLVE_COLUMNS["avg_downhole_press"]],
            label=f"Well {well_id}", linewidth=0.9, alpha=0.8)
ax.set_xlabel("Days on Production")
ax.set_ylabel("Downhole Pressure (bar)")
ax.set_title("Average Downhole Pressure")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Cumulative liquid
ax = axes[1]
for well_id, group in df.groupby(well_col):
    ax.plot(group["DAYS_ON_PRODUCTION"], group["CUM_LIQUID"],
            label=f"Well {well_id}", linewidth=0.9, alpha=0.8)
ax.set_xlabel("Days on Production")
ax.set_ylabel("Cumulative Liquid (Sm3)")
ax.set_title("Cumulative Liquid Production")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 4.3 Breakthrough Label Distribution

In [None]:
bt_counts = df["BREAKTHROUGH"].value_counts().sort_index()
print("Breakthrough label distribution:")
print(f"  Pre-breakthrough  (0): {bt_counts.get(0.0, 0):,}")
print(f"  Post-breakthrough (1): {bt_counts.get(1.0, 0):,}")
print(f"  Ratio: {bt_counts.get(1.0, 0) / len(df):.1%} post-breakthrough")

---
## 5. Dataset Preparation

The `build_dataset` function performs the full pipeline:
1. Feature scaling with `MinMaxScaler`
2. Sliding-window **sequence creation** (length = `SEQ_LENGTH`)
3. **Temporal train/test split** (no shuffling, preserves time order)
4. Conversion to **PyTorch tensors**
5. Preparation of **physics inputs** and **survival targets** (time-to-event, censoring)
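Steps 1-3 can be sketched for a single well. This sketch assumes each window is labelled with the target on the day immediately after it, and the names `make_sequences`, `temporal_split`, and `minmax_scale` are hypothetical. It windows and splits *before* scaling, so the min/max are fit on training windows only. The source does not show whether `build_dataset` fits its `MinMaxScaler` on training rows only.

```python
import numpy as np

def make_sequences(features, target, seq_length):
    """Window i covers rows i..i+seq_length-1; its label is target[i+seq_length] (assumption)."""
    n_windows = len(features) - seq_length
    if n_windows <= 0:
        raise ValueError("series is shorter than seq_length + 1")
    X = np.stack([features[i:i + seq_length] for i in range(n_windows)])
    y = target[seq_length:seq_length + n_windows]
    return X, y

def temporal_split(X, y, test_fraction):
    """No shuffling: the earliest windows train, the latest windows test."""
    n_train = int(round(len(X) * (1.0 - test_fraction)))
    return X[:n_train], y[:n_train], X[n_train:], y[n_train:]

def minmax_scale(X_train, X_test):
    """Fit per-feature min/max on training windows only, then apply to both sets."""
    flat = X_train.reshape(-1, X_train.shape[-1])
    lo, hi = flat.min(axis=0), flat.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # constant columns: avoid divide-by-zero
    # Broadcasting applies (lo, span) per feature across (batch, time, feature)
    return (X_train - lo) / span, (X_test - lo) / span

features = np.arange(100, dtype=float).reshape(50, 2)  # 50 days, 2 features
target = np.arange(50, dtype=float)
X, y = make_sequences(features, target, seq_length=10)
X_tr, y_tr, X_te, y_te = temporal_split(X, y, test_fraction=0.2)
X_tr, X_te = minmax_scale(X_tr, X_te)
```

With several wells, the windows would be built within each well so that no sequence spans two wells.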

In [None]:
dataset = build_dataset(
    data_path=DATA_PATH,
    seq_length=SEQ_LENGTH,
    test_fraction=TEST_FRACTION,
    device=DEVICE,
)

print(f"\nDataset summary:")
print(f"  Training samples:  {dataset['X_train'].shape[0]:,}")
print(f"  Test samples:      {dataset['X_test'].shape[0]:,}")
print(f"  Sequence length:   {dataset['seq_length']}")
print(f"  Input features:    {dataset['n_features']}")
print(f"  Wells:             {dataset['well_names']}")
print(f"  X_train shape:     {tuple(dataset['X_train'].shape)}")
print(f"  y_train shape:     {tuple(dataset['y_train'].shape)}")
print(f"  tte_train shape:   {tuple(dataset['tte_train'].shape)}")
print(f"  event_train shape: {tuple(dataset['event_train'].shape)}")

---
## 6. Model Architecture

The `PhysicsInformedBreakthroughModel` has the following structure:

```
Input (batch, seq_len, n_features)
  |-> LSTM Temporal Encoder -> h (batch, hidden_size)
       |
       |-> Physics Branch: h -> Saturation MLP -> S_w -> Corey Rel-Perm -> f_w(physics)
       |-> Data Branch:    h -> MLP -> f_w(data)
       |-> Gate Network:   h -> sigmoid gate alpha
       |
       |-> Blended water_cut = alpha * f_w(physics) + (1-alpha) * f_w(data)
       |-> Breakthrough Classifier: h -> logit
       |-> Survival Head: h -> (mu, sigma) for LogNormal time-to-breakthrough
```

Learnable physics parameters: `S_wc`, `S_or`, `kr_w_max`, `kr_o_max`, `n_w`, `n_o`,
`mu_w`, `mu_o`, `porosity`.
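For reference, the standard Corey and Buckley-Leverett relations that the physics branch is based on (neglecting gravity and capillary pressure) are given below, together with the gated blend from the diagram. The exact parameterization in `src.model` is not shown and may differ, for example in how `S_wc` or the exponents are constrained to valid ranges.

$$
\begin{aligned}
S_e &= \frac{S_w - S_{wc}}{1 - S_{wc} - S_{or}}, \\
k_{rw} &= k_{rw}^{\max}\, S_e^{\,n_w}, \qquad k_{ro} = k_{ro}^{\max}\, (1 - S_e)^{n_o}, \\
f_w^{\text{physics}} &= \frac{k_{rw}/\mu_w}{k_{rw}/\mu_w + k_{ro}/\mu_o}, \\
\hat{f}_w &= \alpha\, f_w^{\text{physics}} + (1 - \alpha)\, f_w^{\text{data}}, \qquad \alpha = \sigma\big(g(h)\big) \in (0, 1),
\end{aligned}
$$

where $g$ is the gate network and $\sigma$ the logistic sigmoid. The end-point mobility ratio follows from the same parameters as $M = (k_{rw}^{\max}/\mu_w)\,/\,(k_{ro}^{\max}/\mu_o)$.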

In [None]:
# Instantiate the model to inspect its architecture
model_preview = PhysicsInformedBreakthroughModel(
    n_features=dataset["n_features"],
    hidden_size=config.hidden_size,
    lstm_layers=config.lstm_layers,
    dropout=config.dropout,
).to(DEVICE)

total_params = sum(p.numel() for p in model_preview.parameters())
trainable_params = sum(p.numel() for p in model_preview.parameters() if p.requires_grad)
physics_params = sum(
    p.numel() for n, p in model_preview.named_parameters()
    if any(k in n for k in ["rel_perm", "log_mu", "log_porosity"])
)

print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Physics parameters:   {physics_params:,}")
print(f"\nInitial physics parameters:")
for k, v in model_preview.get_physics_parameters().items():
    print(f"  {k:12s}: {v:.4f}")

del model_preview  # free memory before training

---
## 7. Training

The training loop uses a **composite loss function** with six components:

| Loss Term | Weight | Purpose |
|-----------|--------|---------|
| Data MSE | 1.0 | Match observed water cut |
| Physics fit | annealed 0.01 -> 0.1 | Physics branch matches data |
| Monotonicity | 0.05 | Penalize non-physical water cut decreases |
| Breakthrough BCE | 0.3 | Binary breakthrough detection |
| Material balance | 0.05 | Physics/data branch consistency |
| Survival NLL | 0.2 | Log-normal time-to-breakthrough |

Optimization features:
- **AdamW** with separate learning rates (physics params get 0.1x)
- **ReduceLROnPlateau** scheduler
- **Gradient clipping** (max norm = 1.0)
- **Early stopping** (patience = 30 epochs)
- **Physics weight annealing** over first 50 epochs

In [None]:
model, history = train_model(dataset, config, DEVICE)

### 7.1 Training Curves

Inspect total loss, data-only loss, physics weight schedule, and learned parameter evolution.

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Total loss
ax = axes[0, 0]
ax.plot(history["train_loss"], label="Train", linewidth=1.5)
ax.plot(history["val_loss"], label="Validation", linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Total Loss")
ax.set_title("Total Loss (all components)")
ax.legend()
ax.set_yscale("log")
ax.grid(True, alpha=0.3)

# Data loss
ax = axes[0, 1]
ax.plot(history["train_data_loss"], label="Train", linewidth=1.5)
ax.plot(history["val_data_loss"], label="Validation", linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE")
ax.set_title("Data Loss (Water Cut MSE)")
ax.legend()
ax.set_yscale("log")
ax.grid(True, alpha=0.3)

# Physics weight annealing
ax = axes[1, 0]
ax.plot(history["physics_weight"], linewidth=1.5, color="green")
ax.set_xlabel("Epoch")
ax.set_ylabel("Weight")
ax.set_title("Physics Loss Weight (Annealing Schedule)")
ax.grid(True, alpha=0.3)

# Learned physics parameters
ax = axes[1, 1]
params_hist = history.get("learned_params", [])
if params_hist:
    epochs = range(len(params_hist))
    ax.plot(epochs, [p["s_wc"] for p in params_hist], label="S_wc")
    ax.plot(epochs, [p["s_or"] for p in params_hist], label="S_or")
    ax.plot(epochs, [p["n_w"] / 6 for p in params_hist], label="n_w/6")
    ax.plot(epochs, [p["n_o"] / 6 for p in params_hist], label="n_o/6")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Value")
    ax.set_title("Learned Physics Parameters")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 8. Evaluation Metrics

Evaluate the trained model on the held-out **test set** (temporal split).

Metrics computed:
- **Regression**: MAE, RMSE, and R² (in both scaled and original units)
- **Classification**: accuracy, precision, recall, and F1 for breakthrough detection
- **Survival**: P10/P50/P90 percentile statistics and the empirical coverage of the 80% (P10-P90) interval
- **Physics gate**: mean gate value (1 = physics-dominated, 0 = data-dominated)
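The metric definitions can be sketched in NumPy as follows. The helper names are hypothetical. Two details are assumptions rather than documented behaviour of `evaluate_model`: the 0.5 decision threshold on breakthrough probability, and computing interval coverage over uncensored samples only (`event == 1`).

```python
import numpy as np

def _safe_div(a, b):
    return a / b if b else 0.0

def regression_metrics(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {"mae": float(np.mean(np.abs(err))),
            "rmse": float(np.sqrt(np.mean(err ** 2))),
            "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")}

def classification_metrics(y_true, prob, threshold=0.5):
    y_true = np.asarray(y_true).astype(bool)
    y_hat = np.asarray(prob) >= threshold
    tp = int(np.sum(y_hat & y_true))
    fp = int(np.sum(y_hat & ~y_true))
    fn = int(np.sum(~y_hat & y_true))
    precision, recall = _safe_div(tp, tp + fp), _safe_div(tp, tp + fn)
    return {"accuracy": float(np.mean(y_hat == y_true)),
            "precision": precision, "recall": recall,
            "f1": _safe_div(2 * precision * recall, precision + recall)}

def interval_coverage(tte, event, p10, p90):
    """Fraction of uncensored samples whose observed time lies in [P10, P90]."""
    tte, p10, p90 = (np.asarray(a, dtype=float) for a in (tte, p10, p90))
    observed = np.asarray(event) == 1
    if not observed.any():
        return float("nan")
    inside = (tte >= p10) & (tte <= p90)
    return float(np.mean(inside[observed]))
```

For a well-calibrated survival head, `interval_coverage` should land near 0.8. A value well below that suggests the P10-P90 band is too narrow.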

In [None]:
metrics = evaluate_model(model, dataset, DEVICE)
print_metrics(metrics)

---
## 9. Prediction Visualizations

Compare actual vs predicted water cut on the test set, and inspect how the
physics-data gating mechanism allocates predictions.

In [None]:
model.eval()
with torch.no_grad():
    outputs = model(dataset["X_test"].to(DEVICE))

y_true = dataset["y_test"].cpu().numpy().flatten()
y_pred = outputs["water_cut"].cpu().numpy().flatten()
y_physics = outputs["water_cut_physics"].cpu().numpy().flatten()
y_data = outputs["water_cut_data"].cpu().numpy().flatten()
gate = outputs["gate_value"].cpu().numpy().flatten()

### 9.1 Water Cut: Actual vs Predicted Time Series

In [None]:
n_plot = min(500, len(y_true))

fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(range(n_plot), y_true[:n_plot], label="Actual", linewidth=1.2, alpha=0.8)
ax.plot(range(n_plot), y_pred[:n_plot], label="Predicted (blended)",
        linewidth=1.2, alpha=0.8)
ax.plot(range(n_plot), y_physics[:n_plot], label="Physics only",
        linewidth=0.8, alpha=0.5, linestyle="--")
ax.plot(range(n_plot), y_data[:n_plot], label="Data only",
        linewidth=0.8, alpha=0.5, linestyle=":")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Water Cut (scaled)")
ax.set_title("Test Set: Actual vs Predicted Water Cut")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### 9.2 Scatter Plot & Residuals

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Scatter
ax = axes[0]
ax.scatter(y_true, y_pred, alpha=0.3, s=10)
ax.plot([0, 1], [0, 1], "r--", linewidth=1.5, label="Perfect")
ax.set_xlabel("Actual Water Cut")
ax.set_ylabel("Predicted Water Cut")
ax.set_title(f"Prediction Scatter (R² = {metrics['r2']:.4f})")
ax.legend()
ax.grid(True, alpha=0.3)

# Residual distribution
ax = axes[1]
residuals = y_pred - y_true
ax.hist(residuals, bins=50, alpha=0.7, edgecolor="black", linewidth=0.5)
ax.axvline(x=0, color="red", linestyle="--")
ax.set_xlabel("Residual (Predicted - Actual)")
ax.set_ylabel("Count")
ax.set_title(f"Residual Distribution (mean={residuals.mean():.4f}, std={residuals.std():.4f})")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 9.3 Physics-Data Gate & Breakthrough Detection

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Gate
ax = axes[0]
ax.plot(range(n_plot), gate[:n_plot], linewidth=1, color="purple")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Gate Value")
ax.set_title("Physics vs Data Gate (1=Physics, 0=Data)")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)

# Breakthrough
ax = axes[1]
bt_true = dataset["bt_test"].cpu().numpy().flatten()
# torch.sigmoid avoids the overflow warnings np.exp can raise for large logits
bt_prob = torch.sigmoid(outputs["breakthrough_logit"]).cpu().numpy().flatten()
ax.plot(range(n_plot), bt_true[:n_plot], label="Actual", linewidth=1.2, alpha=0.8)
ax.plot(range(n_plot), bt_prob[:n_plot], label="Predicted probability",
        linewidth=1.2, alpha=0.8)
ax.axhline(y=0.5, color="r", linestyle="--", alpha=0.5, label="Decision boundary")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Breakthrough")
ax.set_title("Breakthrough Detection")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 10. Physics Interpretation

One of the key advantages of a PINN is **interpretability**: the model learns physically
meaningful reservoir parameters. Here we visualize the learned Buckley-Leverett
fractional flow curve and Corey relative permeability.

### 10.1 Learned Reservoir Parameters

In [None]:
params = model.get_physics_parameters()

print("Learned Reservoir Parameters:")
print(f"  Connate water saturation (S_wc):  {params['s_wc']:.4f}")
print(f"  Residual oil saturation (S_or):   {params['s_or']:.4f}")
print(f"  Max water rel-perm (kr_w_max):    {params['kr_w_max']:.4f}")
print(f"  Max oil rel-perm (kr_o_max):      {params['kr_o_max']:.4f}")
print(f"  Corey water exponent (n_w):       {params['n_w']:.4f}")
print(f"  Corey oil exponent (n_o):         {params['n_o']:.4f}")
print(f"  Water viscosity (mu_w):           {params['mu_w']:.4f} cP")
print(f"  Oil viscosity (mu_o):             {params['mu_o']:.4f} cP")
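print(f"  Viscosity ratio (mu_w/mu_o):      {params['mu_w']/params['mu_o']:.4f}")
m_end = (params["kr_w_max"] / params["mu_w"]) / (params["kr_o_max"] / params["mu_o"])
print(f"  End-point mobility ratio (M):     {m_end:.4f}")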
print(f"  Porosity:                         {params['porosity']:.4f}")

### 10.2 Fractional Flow Curve & Welge Tangent

The Buckley-Leverett fractional flow curve `f_w(S_w)` determines the displacement
efficiency. The Welge tangent from `(S_wc, 0)` identifies the shock front saturation
and breakthrough time in pore volumes.
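The Welge construction can be reproduced numerically. The tangent point is the saturation that maximizes the chord slope `f_w / (S_w - S_wc)`. Breakthrough in pore volumes injected is the reciprocal of that slope, and the tangent reaches `f_w = 1` at the average saturation behind the front. This sketch assumes standard Corey curves, no gravity or capillary pressure, and an S-shaped `f_w`. The function names are hypothetical, and this is not the implementation of `compute_breakthrough_time_analytical`.

```python
import numpy as np

def corey_rel_perm(s_w, s_wc, s_or, kr_w_max=1.0, kr_o_max=1.0, n_w=2.0, n_o=2.0):
    """Corey relative permeabilities evaluated on the normalized saturation S_e."""
    s_e = np.clip((s_w - s_wc) / (1.0 - s_wc - s_or), 0.0, 1.0)
    return kr_w_max * s_e ** n_w, kr_o_max * (1.0 - s_e) ** n_o

def fractional_flow(s_w, s_wc, s_or, mu_w, mu_o, **corey):
    """Buckley-Leverett f_w, neglecting gravity and capillary pressure."""
    kr_w, kr_o = corey_rel_perm(s_w, s_wc, s_or, **corey)
    lam_w, lam_o = kr_w / mu_w, kr_o / mu_o
    return lam_w / (lam_w + lam_o)

def welge_tangent(s_wc, s_or, mu_w, mu_o, n_points=20001, **corey):
    """Front saturation, breakthrough PV, and average saturation behind the front."""
    s_w = np.linspace(s_wc, 1.0 - s_or, n_points)[1:]  # drop S_wc to avoid 0/0
    f_w = fractional_flow(s_w, s_wc, s_or, mu_w, mu_o, **corey)
    slope = f_w / (s_w - s_wc)   # slope of the chord from (S_wc, 0)
    i = int(np.argmax(slope))    # tangent point = steepest chord
    return {
        "s_w_front": float(s_w[i]),
        "f_w_front": float(f_w[i]),
        "breakthrough_pv": float(1.0 / slope[i]),               # t_D at x_D = 1
        "s_w_avg_behind_front": float(s_wc + 1.0 / slope[i]),   # tangent meets f_w = 1
    }

result = welge_tangent(s_wc=0.2, s_or=0.2, mu_w=0.5, mu_o=2.0, n_w=2.0, n_o=2.0)
```

Passing the learned `params` (including `kr_w_max` and `kr_o_max`) to `welge_tangent` gives an independent cross-check on the analytical solution printed below, provided both use the same Corey form.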

In [None]:
# Analytical solution with learned parameters
analytical = compute_breakthrough_time_analytical(
    s_wc=params["s_wc"], s_or=params["s_or"],
    mu_w=params["mu_w"], mu_o=params["mu_o"],
    n_w=params["n_w"], n_o=params["n_o"],
)

print(f"Analytical Buckley-Leverett Solution:")
print(f"  Breakthrough time: {analytical['breakthrough_pv']:.3f} pore volumes")
print(f"  Front saturation:  {analytical['s_w_front']:.3f}")
print(f"  Avg saturation behind front: {analytical['s_w_avg_behind_front']:.3f}")

In [None]:
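# Evaluate the learned fractional flow and relative permeability. The grid is built
# on the CPU for plotting; the physics parameters live on DEVICE, so move the grid
# there and bring the results back.
model.eval()
s_w_tensor = torch.linspace(params["s_wc"], 1 - params["s_or"], 200).unsqueeze(1)
with torch.no_grad():
    s_w_dev = s_w_tensor.to(DEVICE)
    f_w_learned = model.physics_branch.fractional_flow(s_w_dev).cpu()
    kr_w, kr_o = (k.cpu() for k in model.physics_branch.rel_perm(s_w_dev))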

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Fractional flow
ax = axes[0]
ax.plot(analytical["saturation_grid"], analytical["fractional_flow_curve"],
        "b-", linewidth=2, label="f_w(S_w)")
ax.axvline(x=analytical["s_w_front"], color="r", linestyle="--", alpha=0.7,
           label=f"Front S_wf={analytical['s_w_front']:.3f}")
# Welge tangent
slope = analytical["f_w_front"] / (analytical["s_w_front"] - params["s_wc"])
s_tang = np.linspace(params["s_wc"], analytical["s_w_front"], 50)
ax.plot(s_tang, slope * (s_tang - params["s_wc"]), "g--", linewidth=1.5,
        label="Welge tangent")
ax.set_xlabel("Water Saturation S_w")
ax.set_ylabel("Fractional Flow f_w")
ax.set_title("Buckley-Leverett Fractional Flow")
ax.legend()
ax.grid(True, alpha=0.3)

# Relative permeability
ax = axes[1]
s_np = s_w_tensor.numpy().flatten()
ax.plot(s_np, kr_w.numpy().flatten(), "b-", linewidth=2, label="kr_w")
ax.plot(s_np, kr_o.numpy().flatten(), "r-", linewidth=2, label="kr_o")
ax.set_xlabel("Water Saturation S_w")
ax.set_ylabel("Relative Permeability")
ax.set_title(f"Corey Rel-Perm (n_w={params['n_w']:.2f}, n_o={params['n_o']:.2f})")
ax.legend()
ax.grid(True, alpha=0.3)

# df/dS
ax = axes[2]
ax.plot(analytical["saturation_grid"], analytical["df_ds"], "b-", linewidth=2)
ax.axvline(x=analytical["s_w_front"], color="r", linestyle="--", alpha=0.7,
           label=f"BT at {analytical['breakthrough_pv']:.2f} PV")
ax.set_xlabel("Water Saturation S_w")
ax.set_ylabel("df_w/dS_w")
ax.set_title("Fractional Flow Derivative (Saturation Wave Speed)")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 11. Survival Analysis

The log-normal survival head predicts a **distribution** over time-to-breakthrough for
each sample, providing:
- **P10** (early estimate: 10% probability that breakthrough has occurred by this time)
- **P50** (median estimate)
- **P90** (late estimate: 90% probability that breakthrough has occurred by this time)

This is critical for risk-based decision-making in reservoir management.
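Because the head is log-normal (`ln T ~ Normal(mu, sigma^2)`), each percentile has a closed form, `P_q = exp(mu + sigma * z_q)`, where `z_q` is the standard normal quantile (`z_0.1 ≈ -1.2816`, `z_0.9 ≈ 1.2816`). The sketch below uses the standard library's `statistics.NormalDist`. `lognormal_percentiles` is a hypothetical helper, and it is an assumption that the model's `survival_percentiles` output is computed exactly this way.

```python
from statistics import NormalDist

import numpy as np

def lognormal_percentiles(mu, sigma, quantiles=(0.1, 0.5, 0.9)):
    """If ln T ~ Normal(mu, sigma^2), the q-quantile of T is exp(mu + sigma * z_q)."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    std_normal = NormalDist()  # mean 0, standard deviation 1
    return {f"P{round(q * 100)}": np.exp(mu + sigma * std_normal.inv_cdf(q))
            for q in quantiles}

# Two samples: the wider sigma gives the wider P10-P90 band
pct = lognormal_percentiles(mu=[0.0, 0.5], sigma=[0.3, 0.6])
```

Because `exp` is monotone, the ordering P10 < P50 < P90 holds for every sample, and P50 is simply `exp(mu)`.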

### 11.1 Percentile Predictions with Uncertainty Bands

In [None]:
percentiles = outputs["survival_percentiles"]
p10 = percentiles["P10"].cpu().numpy().flatten()
p50 = percentiles["P50"].cpu().numpy().flatten()
p90 = percentiles["P90"].cpu().numpy().flatten()

tte_true = dataset["tte_test"].cpu().numpy().flatten()
event_true = dataset["event_test"].cpu().numpy().flatten()

fig, ax = plt.subplots(figsize=(14, 5))
n_surv = min(500, len(p50))
x_range = range(n_surv)

ax.fill_between(x_range, p10[:n_surv], p90[:n_surv],
                alpha=0.3, color="steelblue", label="P10-P90 interval")
ax.plot(x_range, p50[:n_surv], color="steelblue", linewidth=1.5, label="P50 (median)")

observed = event_true[:n_surv] == 1.0
if observed.any():
    ax.scatter(np.where(observed)[0], tte_true[:n_surv][observed],
              s=8, color="red", alpha=0.6, label="Observed BT time", zorder=5)

ax.set_xlabel("Sample Index")
ax.set_ylabel("Time to Breakthrough (normalized)")
ax.set_title("Predicted Time-to-Breakthrough with Uncertainty")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nSurvival Percentile Summary (test set):")
print(f"  P10 mean: {p10.mean():.4f} (+/- {p10.std():.4f})")
print(f"  P50 mean: {p50.mean():.4f} (+/- {p50.std():.4f})")
print(f"  P90 mean: {p90.mean():.4f} (+/- {p90.std():.4f})")
if "survival_calibration_80" in metrics:
    print(f"  80% interval calibration: {metrics['survival_calibration_80']:.1%} (ideal: 80%)")

### 11.2 Survival Curve & Hazard Function

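The survival function `S(t) = P(T > t)` gives the probability of **not** having
broken through by time `t`. The hazard function `h(t) = f(t)/S(t)` gives the
instantaneous breakthrough rate among samples that have not yet broken through.
Both curves below are evaluated at the test-set mean of the predicted `(mu, sigma)`,
which is not the same as averaging the per-sample survival curves.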
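For a log-normal time-to-breakthrough, both curves have closed forms: `S(t) = 1 - Phi((ln t - mu) / sigma)` and `h(t) = f(t) / S(t)`, where `f` is the log-normal density. The sketch below uses only NumPy and `math.erfc`. The function names are hypothetical, and it is an independent illustration rather than the code in `src.survival`.

```python
import math

import numpy as np

_erfc = np.vectorize(math.erfc)

def lognormal_survival(t, mu, sigma):
    """S(t) = P(T > t) = 1 - Phi(z) = 0.5 * erfc(z / sqrt(2)), z = (ln t - mu) / sigma."""
    z = (np.log(t) - mu) / sigma
    return 0.5 * _erfc(z / math.sqrt(2.0))

def lognormal_pdf(t, mu, sigma):
    z = (np.log(t) - mu) / sigma
    return np.exp(-0.5 * z ** 2) / (t * sigma * math.sqrt(2.0 * math.pi))

def lognormal_hazard(t, mu, sigma):
    """h(t) = f(t) / S(t): breakthrough rate given no breakthrough before t."""
    return lognormal_pdf(t, mu, sigma) / lognormal_survival(t, mu, sigma)

t = np.linspace(0.01, 2.0, 200)
surv = lognormal_survival(t, mu=-0.2, sigma=0.5)
haz = lognormal_hazard(t, mu=-0.2, sigma=0.5)
```

At `t = exp(mu)` the survival probability is exactly one half, which ties this curve back to the P50 estimate.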

In [None]:
mu_vals = outputs["survival_mu"].cpu()
sigma_vals = outputs["survival_sigma"].cpu()
mu_mean = mu_vals.mean()
sigma_mean = sigma_vals.mean()

t_grid = torch.linspace(0.01, 2.0, 200)
survival = compute_survival_function(
    mu_mean.unsqueeze(0), sigma_mean.unsqueeze(0), t_grid
).numpy().flatten()
hazard = compute_hazard_function(
    mu_mean.unsqueeze(0), sigma_mean.unsqueeze(0), t_grid
).numpy().flatten()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Survival curve
ax = axes[0]
ax.plot(t_grid.numpy(), survival, color="steelblue", linewidth=2)
for pval, label, ls in [(0.90, "P10", "--"), (0.50, "P50", "-"), (0.10, "P90", "--")]:
    ax.axhline(y=pval, color="gray", linestyle=ls, alpha=0.4)
    idx = np.searchsorted(-survival, -pval)
    if idx < len(t_grid):
        t_pct = t_grid[idx].item()
        ax.axvline(x=t_pct, color="gray", linestyle=ls, alpha=0.4)
        ax.annotate(label, xy=(t_pct, pval), xytext=(t_pct + 0.05, pval + 0.03), fontsize=9)
ax.set_xlabel("Time (normalized)")
ax.set_ylabel("Survival Probability S(t)")
ax.set_title("Mean Survival Curve")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)

# Hazard function
ax = axes[1]
ax.plot(t_grid.numpy(), hazard, color="darkred", linewidth=2)
ax.set_xlabel("Time (normalized)")
ax.set_ylabel("Hazard Rate h(t)")
ax.set_title("Mean Hazard Function")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 11.3 Percentile Distribution Histograms

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(p10, bins=30, alpha=0.5, color="green", label="P10", density=True)
ax.hist(p50, bins=30, alpha=0.5, color="steelblue", label="P50", density=True)
ax.hist(p90, bins=30, alpha=0.5, color="orange", label="P90", density=True)

observed_tte = tte_true[event_true == 1.0]
if len(observed_tte) > 0:
    ax.hist(observed_tte, bins=30, alpha=0.4, color="red",
            label="Observed BT times", density=True)

ax.set_xlabel("Time to Breakthrough (normalized)")
ax.set_ylabel("Density")
ax.set_title("Percentile Distributions vs Observed")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---
## 12. Save & Export

Save all plots to the `results/` directory and export the final metrics report.

In [None]:
# Save all standard plots via the evaluate module
plot_training_history(history, SAVE_DIR)
plot_predictions(model, dataset, SAVE_DIR, DEVICE)
plot_fractional_flow_curve(model, SAVE_DIR)
plot_survival_analysis(model, dataset, SAVE_DIR, DEVICE)

# Save metrics to text file
save_path = Path(SAVE_DIR)
save_path.mkdir(parents=True, exist_ok=True)
with open(save_path / "metrics.txt", "w") as f:
    f.write("Water Breakthrough Prediction - Evaluation Report\n")
    f.write("=" * 50 + "\n\n")
    for key, value in metrics.items():
        if isinstance(value, dict):
            f.write(f"\n{key}:\n")
            for k, v in value.items():
                f.write(f"  {k}: {v}\n")
        else:
            f.write(f"{key}: {value}\n")

print(f"\nAll outputs saved to: {SAVE_DIR}")
print(f"Model checkpoint:     {config.model_save_path}")

### Final Summary

In [None]:
print("=" * 60)
print("  FINAL SUMMARY")
print("=" * 60)
print(f"\n  Water Cut Prediction:")
print(f"    MAE:  {metrics['mae']:.4f}")
print(f"    RMSE: {metrics['rmse']:.4f}")
print(f"    R²:   {metrics['r2']:.4f}")
if "bt_f1" in metrics:
    print(f"\n  Breakthrough Detection:")
    print(f"    F1:       {metrics['bt_f1']:.4f}")
    print(f"    Accuracy: {metrics['bt_accuracy']:.4f}")
print(f"\n  Physics-Data Gate: {metrics['gate_mean']:.3f} (1=physics, 0=data)")
if "survival_p50_mean" in metrics:
    print(f"\n  Survival Percentiles:")
    print(f"    P10: {metrics['survival_p10_mean']:.4f}")
    print(f"    P50: {metrics['survival_p50_mean']:.4f}")
    print(f"    P90: {metrics['survival_p90_mean']:.4f}")
    if "survival_calibration_80" in metrics:
        print(f"    Calibration (80%): {metrics['survival_calibration_80']:.1%}")
print(f"\n  Learned Physics:")
p = metrics["physics_params"]
print(f"    S_wc={p['s_wc']:.3f}  S_or={p['s_or']:.3f}  "
      f"n_w={p['n_w']:.2f}  n_o={p['n_o']:.2f}  "
      f"mu_w/mu_o={p['mu_w']/p['mu_o']:.3f}")
print("=" * 60)