# Train RGT-Only Model (Independent Pipeline)

This notebook trains only the **RGT prediction model** (seismic ➜ RGT) and keeps it independent of any future multitask (fault + RGT) model.

## What this notebook includes
- GPU-first setup with explicit device message.
- Dataset checks and sample inspection.
- Train/validation/test splits and DataLoaders.
- RGT model training with progress bars and timing.
- Validation metrics and checkpointing.
- Final test evaluation and reproducibility artifacts.

## 1) Set Up Environment and Configuration

In [None]:
import json
import os
import random
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import utils
from models import net3d
from data.dataloader import Dataset
from data.augments import Reshape, ToTensor
from lossf.loss import mse3DLoss, ssim3DLoss
from lossf.metrics import Result

# --- Reproducibility: seed every RNG this notebook touches (Python, NumPy, torch CPU + all CUDA devices).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# --- Data locations (relative to the notebook's working directory).
ROOT = Path.cwd()
DATA_ROOT = ROOT.joinpath("datasets", "syn")
SEIS_DIR, RGT_DIR = DATA_ROOT / "seis", DATA_ROOT / "rgt"

# --- Per-run session folder, time-stamped so runs never overwrite each other.
_stamp = datetime.now().strftime("%b%d_%H%M%S")
session_name = f"rgt_only_{_stamp}_Train"
session_path = ROOT / "sessions" / session_name
checkpoint_path, history_path, picture_path = (
    session_path / sub for sub in ("checkpoint", "history", "picture")
)
for _dir in (session_path, checkpoint_path, history_path, picture_path):
    _dir.mkdir(parents=True, exist_ok=True)

CFG = dict(
    # Paper-inspired defaults
    shape=(256, 256, 128),
    n_channels=1,
    batch_size=1,
    epochs=400,
    lr=8e-4,
    weight_decay=1e-4,
    lr_factor=0.5,
    lr_patience=2,
    loss_type="SSIM",  # 'SSIM' or 'MSE'
    # Runtime
    num_workers=2,
    encoder_channels=512,
    decoder_channels=16,
    pin_memory=True,
    mixed_precision=True,
    grad_clip=None,
    # Splits
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    # Optional cap for quick tests
    max_samples=None,
    use_augmentation=True,
    # Save intervals
    save_every=10,
)

# --- Device selection: first CUDA device when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(
    f"[GPU MODE] Training will run on: {torch.cuda.get_device_name(0)}"
    if use_cuda
    else "[CPU MODE] CUDA not available. Training will run on CPU."
)

print("Session:", session_path)
print("Data root:", DATA_ROOT)
print("Config:")
print(json.dumps(CFG, indent=2))

  from .autonotebook import tqdm as notebook_tqdm


[GPU MODE] Training will run on: NVIDIA RTX A5000
Session: /home/roderickperez/DataScienceProjects/RGT_Net/sessions/rgt_only_Feb18_125329_Train
Data root: /home/roderickperez/DataScienceProjects/RGT_Net/datasets/syn
Config:
{
  "shape": [
    256,
    256,
    128
  ],
  "n_channels": 1,
  "batch_size": 1,
  "epochs": 400,
  "lr": 0.0008,
  "weight_decay": 0.0001,
  "lr_factor": 0.5,
  "lr_patience": 2,
  "loss_type": "MSE",
  "num_workers": 2,
  "encoder_channels": 512,
  "decoder_channels": 16,
  "pin_memory": true,
  "mixed_precision": true,
  "grad_clip": null,
  "train_ratio": 0.8,
  "val_ratio": 0.1,
  "test_ratio": 0.1,
  "max_samples": null,
  "save_every": 10
}


## 2) Load and Inspect RGT-Only Training Data

In [2]:
assert SEIS_DIR.exists(), f"Missing seismic directory: {SEIS_DIR}"
assert RGT_DIR.exists(), f"Missing RGT directory: {RGT_DIR}"

# Seismic and RGT volumes are paired by identical file name.
seis_names = {f.name for f in SEIS_DIR.iterdir() if f.is_file()}
rgt_names = {f.name for f in RGT_DIR.iterdir() if f.is_file()}
seis_files = sorted(seis_names)
rgt_files = sorted(rgt_names)
common_files = sorted(seis_names & rgt_names)

print(f"Seis files: {len(seis_files)}")
print(f"RGT files:  {len(rgt_files)}")
print(f"Matched pairs: {len(common_files)}")

missing_in_rgt = sorted(seis_names - rgt_names)
missing_in_seis = sorted(rgt_names - seis_names)
if missing_in_rgt:
    print(f"Warning: {len(missing_in_rgt)} seismic files missing in rgt")
if missing_in_seis:
    print(f"Warning: {len(missing_in_seis)} rgt files missing in seis")

if CFG["max_samples"] is not None:
    common_files = common_files[:CFG["max_samples"]]

# Fail early with a clear message instead of an IndexError from random.choice below.
assert common_files, "No matched seismic/RGT pairs available (check data dirs and CFG['max_samples'])."
print("Using sample count:", len(common_files))

# Quick sample stats
n1, n2, n3 = CFG["shape"]
volume_shape = (n1, n2, n3, CFG["n_channels"])
expected_values = int(np.prod(volume_shape))


def load_first_channel(path):
    """Read a raw float32 file and return channel 0 as an (n1, n2, n3) array.

    Raises:
        ValueError: if the file does not hold exactly prod(volume_shape) float32 values,
            which would otherwise surface as an opaque reshape error.
    """
    flat = np.fromfile(path, dtype=np.float32)
    if flat.size != expected_values:
        raise ValueError(
            f"{path}: expected {expected_values} float32 values for shape {volume_shape}, found {flat.size}"
        )
    return flat.reshape(volume_shape)[..., 0]


sample_id = random.choice(common_files)
seis_raw = load_first_channel(SEIS_DIR / sample_id)
rgt_raw = load_first_channel(RGT_DIR / sample_id)

print("Random sample:", sample_id)
print("Seis stats  min/max/mean/std:", float(seis_raw.min()), float(seis_raw.max()), float(seis_raw.mean()), float(seis_raw.std()))
print("RGT stats   min/max/mean/std:", float(rgt_raw.min()), float(rgt_raw.max()), float(rgt_raw.mean()), float(rgt_raw.std()))
for label, vol in (("Seis", seis_raw), ("RGT", rgt_raw)):
    if not np.isfinite(vol).all():
        print(f"Warning: {label} volume {sample_id} contains NaN/Inf values")

# Middle sections along each axis: top row seismic, bottom row RGT.
mid_i, mid_x, mid_t = n1 // 2, n2 // 2, n3 // 2
fig, ax = plt.subplots(2, 3, figsize=(14, 8))
for row, (vol, cmap, label) in enumerate(((seis_raw, "gray", "Seis"), (rgt_raw, "jet", "RGT"))):
    sections = (
        (vol[mid_i, :, :], "Inline(mid)"),
        (vol[:, mid_x, :], "Xline(mid)"),
        (vol[:, :, mid_t], "Time(mid)"),
    )
    for col, (section, name) in enumerate(sections):
        ax[row, col].imshow(section, cmap=cmap, aspect="auto")
        ax[row, col].set_title(f"{label} {name}")
        ax[row, col].axis("off")
fig.suptitle(f"Sample {sample_id}")
plt.tight_layout()
plt.show()

Seis files: 500
RGT files:  500
Matched pairs: 500
Using sample count: 500
Random sample: 393.dat
Seis stats  min/max/mean/std: -7.617023944854736 8.524500846862793 0.0018466644687578082 1.166604995727539
RGT stats   min/max/mean/std: 19.889814376831055 170.3407745361328 91.55370330810547 30.695148468017578


  plt.show()


## 3) Build Dataset and DataLoader Pipelines

In [None]:
from sklearn.model_selection import train_test_split

all_ids = common_files.copy()

train_ratio = CFG["train_ratio"]
val_ratio = CFG["val_ratio"]
test_ratio = CFG["test_ratio"]
assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-8, "Ratios must sum to 1.0"
# Both splits below need a strictly positive held-out fraction: train_test_split rejects
# test_size=0, and val_rel divides by (val_ratio + test_ratio).
assert min(train_ratio, val_ratio, test_ratio) > 0, "train/val/test ratios must all be > 0"

# Two-stage split: train vs. held-out, then held-out into val/test.
train_ids, temp_ids = train_test_split(all_ids, test_size=(1.0 - train_ratio), random_state=SEED, shuffle=True)
val_rel = val_ratio / (val_ratio + test_ratio)
val_ids, test_ids = train_test_split(temp_ids, test_size=(1.0 - val_rel), random_state=SEED, shuffle=True)

# Guard against data leakage between splits and against silently dropped samples.
train_set, val_set, test_set = set(train_ids), set(val_ids), set(test_ids)
assert train_set.isdisjoint(val_set) and train_set.isdisjoint(test_set) and val_set.isdisjoint(test_set), \
    "Train/val/test splits overlap"
assert len(train_set) + len(val_set) + len(test_set) == len(all_ids), "Split sizes do not add up to the sample count"

print(f"Train/Val/Test: {len(train_ids)}/{len(val_ids)}/{len(test_ids)}")

class AugmentedPairDataset:
    """Wrap a (seismic, RGT) pair dataset and optionally apply paired random augmentation.

    Items of ``base_ds`` are ``(x, y)`` tensors laid out as (C, D, H, W), where D is
    depth/time. D is never flipped or rotated. The same random transform is applied to
    x and y so input and target stay aligned.

    Augmentations, each drawn with Python's ``random`` module:
      * flip along H (p=0.5), flip along W (p=0.5);
      * rotation by k*90 degrees in the (H, W) plane, k uniform in {0, 1, 2, 3}. When H != W,
        only k in {0, 2} is used, because an odd k would swap H and W and change the sample shape.
    """

    def __init__(self, base_ds, augment=False):
        self.base_ds = base_ds
        self.augment = augment

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, index):
        x, y = self.base_ds[index]
        if self.augment:
            x, y = self._augment_pair(x, y)
        # flip/rot90 can return non-contiguous views; downstream ops expect contiguous memory.
        return x.contiguous(), y.contiguous()

    @staticmethod
    def _augment_pair(x, y):
        """Apply one identical random flip/rotation to x and y (spatial axes 2 and 3 only)."""
        for dim in (2, 3):
            if random.random() < 0.5:
                x = torch.flip(x, dims=[dim])
                y = torch.flip(y, dims=[dim])

        k = random.randint(0, 3)
        if x.shape[2] != x.shape[3]:
            # Map {0,1,2,3} -> {0,2,0,2}: keeps the sample shape fixed for non-square planes.
            k = 2 * (k % 2)
        if k > 0:
            x = torch.rot90(x, k=k, dims=(2, 3))
            y = torch.rot90(y, k=k, dims=(2, 3))
        return x, y

# Raw files are reshaped to (n1, n2, n3, C) and then converted to tensors by the project transforms.
transform = transforms.Compose([
    Reshape((CFG["shape"][0], CFG["shape"][1], CFG["shape"][2], CFG["n_channels"])),
    ToTensor(),
])

train_base_ds = Dataset(root_dir=str(DATA_ROOT), list_IDs=train_ids, transform=transform, only_load_input=False)
val_ds = Dataset(root_dir=str(DATA_ROOT), list_IDs=val_ids, transform=transform, only_load_input=False)
test_ds = Dataset(root_dir=str(DATA_ROOT), list_IDs=test_ids, transform=transform, only_load_input=False)
# Augmentation is applied to the training split only.
train_ds = AugmentedPairDataset(train_base_ds, augment=CFG["use_augmentation"])

assert len(train_ds) == len(train_ids), "Train dataset size does not match train_ids"
assert len(val_ds) == len(val_ids), "Val dataset size does not match val_ids"
assert len(test_ds) == len(test_ids), "Test dataset size does not match test_ids"


def make_loader(ds, batch_size, shuffle):
    """Build a DataLoader that uses the shared worker and pin-memory settings from CFG."""
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=CFG["num_workers"],
        # Pinned memory only helps host->GPU copies.
        pin_memory=CFG["pin_memory"] and use_cuda,
    )


train_loader = make_loader(train_ds, CFG["batch_size"], shuffle=True)
val_loader = make_loader(val_ds, 1, shuffle=False)
test_loader = make_loader(test_ds, 1, shuffle=False)

print(f"DataLoaders ready. Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

Train/Val/Test: 400/50/50
DataLoaders ready. Train batches: 400, Val batches: 50, Test batches: 50


## 4) Define the RGT Model Architecture

In [4]:
# Build the network from CFG so the architecture always matches the config stored in checkpoints.
param_model = {
    "input_channels": CFG["n_channels"],
    "encoder_channels": CFG["encoder_channels"],
    "decoder_channels": CFG["decoder_channels"],
}
model = net3d.model(param_model)

if use_cuda:
    num_gpu = torch.cuda.device_count()
    model = torch.nn.DataParallel(model, device_ids=list(range(num_gpu))).to(device)
else:
    model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: total={num_params:,}, trainable={num_trainable:,}")

# Forward sanity check. It runs in eval mode so that any normalisation layers in net3d
# (not visible here, so this is a precaution) do not update running statistics from
# this single batch. train_one_epoch switches back to train mode anyway.
xb, yb = next(iter(train_loader))
xb = xb.to(device)
model.eval()
with torch.no_grad():
    pred = model(xb)
model.train()
print("Input shape:", tuple(xb.shape), "Pred shape:", tuple(pred.shape))
assert tuple(pred.shape) == tuple(yb.shape), f"Prediction shape {tuple(pred.shape)} != target shape {tuple(yb.shape)}"
# Release the sanity-check tensors; a full 3-D volume is large on the GPU.
del xb, yb, pred

Model parameters: total=56,641,012, trainable=56,641,012
Input shape: (1, 1, 128, 256, 256) Pred shape: (1, 1, 128, 256, 256)


## 5) Configure Loss, Optimizer, and Scheduler

In [5]:
# Pick the training criterion. Unknown names are rejected instead of silently falling back to MSE.
loss_type = CFG["loss_type"].upper()
if loss_type == "SSIM":
    criterion = ssim3DLoss()
elif loss_type == "MSE":
    criterion = mse3DLoss()
else:
    raise ValueError(f"Unsupported CFG['loss_type']={CFG['loss_type']!r}; expected 'SSIM' or 'MSE'.")

optimizer = optim.Adam(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
# Halve (lr_factor) the LR when validation loss has not improved for lr_patience epochs.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=CFG["lr_factor"],
    patience=CFG["lr_patience"],
)

# torch.cuda.amp.GradScaler is deprecated in recent PyTorch; torch.amp.GradScaler("cuda") replaces it.
scaler = torch.amp.GradScaler("cuda", enabled=(use_cuda and CFG["mixed_precision"]))

print(f"Loss: {criterion.getLossName()}")
print(f"Optimizer: Adam(lr={CFG['lr']}, weight_decay={CFG['weight_decay']})")
print(f"Scheduler: ReduceLROnPlateau(factor={CFG['lr_factor']}, patience={CFG['lr_patience']})")
print(f"Mixed precision enabled: {scaler.is_enabled()}")

Loss: MSE
Optimizer: Adam(lr=0.0008, weight_decay=0.0001)
Scheduler: ReduceLROnPlateau(factor=0.5, patience=2)
Mixed precision enabled: True


  scaler = torch.cuda.amp.GradScaler(enabled=(use_cuda and CFG["mixed_precision"]))


## 6) Define Per-Epoch Training and Validation Functions

In [6]:
def init_regression_meter():
    """Create an empty streaming-error accumulator.

    Returns:
        dict with running sum of absolute errors ("abs_sum"), running sum of squared
        errors ("sq_sum") and number of accumulated elements ("count").
    """
    return dict(abs_sum=0.0, sq_sum=0.0, count=0)


def update_regression_meter(meter, y_true, y_pred):
    """Add one batch's absolute and squared errors to ``meter`` in place.

    Args:
        meter: dict created by ``init_regression_meter`` ("abs_sum", "sq_sum", "count").
        y_true: target tensor.
        y_pred: prediction tensor of the same shape (may be lower precision, e.g. under AMP).

    Raises:
        ValueError: if the shapes differ. Broadcasting would otherwise silently inflate
            ``count`` and distort MAE/RMSE.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: y_true {tuple(y_true.shape)} vs y_pred {tuple(y_pred.shape)}")
    # Accumulate in float64: one 3-D volume holds millions of elements, and fp32 sums lose precision.
    diff = (y_true - y_pred).detach().to(torch.float64)
    meter["abs_sum"] += float(diff.abs().sum().item())
    meter["sq_sum"] += float((diff * diff).sum().item())
    meter["count"] += int(diff.numel())


def finalize_regression_meter(meter):
    """Turn an accumulator from ``init_regression_meter`` into summary metrics.

    Returns:
        {"mae": mean absolute error, "rmse": root mean squared error} as Python floats.
        Both are NaN when nothing was accumulated (count == 0).
    """
    n = meter["count"]
    if not n:
        return {"mae": np.nan, "rmse": np.nan}
    mean_abs = meter["abs_sum"] / n
    mean_sq = meter["sq_sum"] / n
    return {"mae": float(mean_abs), "rmse": float(np.sqrt(mean_sq))}


def train_one_epoch(model, loader, criterion, optimizer, scaler, device, epoch_idx, total_epochs):
    """Run one optimisation pass over ``loader`` and return epoch-level metrics.

    Mixed precision (autocast + loss scaling) is used when ``scaler`` is enabled.
    Gradient-norm clipping is applied when the notebook-global ``CFG["grad_clip"]`` is not None.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(seis, rgt)`` batches.
        criterion: loss callable ``criterion(pred, target)``.
        optimizer: optimizer, stepped through ``scaler``.
        scaler: GradScaler; ``scaler.is_enabled()`` also decides whether autocast is on.
        device: torch.device that batches are moved to.
        epoch_idx, total_epochs: used only for the progress-bar label.

    Returns:
        dict with "mae" and "rmse" (over every target element seen), "loss" (mean of
        per-batch losses) and "time_sec".

    Raises:
        RuntimeError: on CUDA out-of-memory in the forward or backward pass (with hints).
    """
    model.train()
    t0 = time.time()
    running_loss = 0.0
    meter = init_regression_meter()

    pbar = tqdm(loader, desc=f"Epoch {epoch_idx+1}/{total_epochs} [Train]", dynamic_ncols=True)
    for batch_idx, (seis, rgt) in enumerate(pbar):
        seis = seis.to(device, non_blocking=True)
        rgt = rgt.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # For 3-D volumes the forward pass allocates most of the activation memory,
        # so it is inside the OOM guard together with backward.
        try:
            with torch.amp.autocast(device_type=device.type, enabled=scaler.is_enabled()):
                pred = model(seis)
                loss = criterion(pred, rgt)
            scaler.scale(loss).backward()
        except torch.cuda.OutOfMemoryError as e:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise RuntimeError("CUDA OOM during forward/backward. Reduce CFG['batch_size'], switch CFG['loss_type'] to 'MSE', or reduce CFG['max_samples'] for a smoke test.") from e

        if CFG["grad_clip"] is not None:
            # Gradients must be unscaled before their norm is meaningful.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])

        scaler.step(optimizer)
        scaler.update()

        loss_value = float(loss.detach().cpu())
        running_loss += loss_value

        update_regression_meter(meter, rgt, pred)

        avg_loss = running_loss / (batch_idx + 1)
        current_lr = optimizer.param_groups[0]["lr"]
        elapsed = time.time() - t0
        pbar.set_postfix(loss=f"{loss_value:.4f}", avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}", sec=f"{elapsed:.1f}")

    epoch_loss = running_loss / max(len(loader), 1)
    metrics = finalize_regression_meter(meter)
    metrics["loss"] = float(epoch_loss)
    metrics["time_sec"] = float(time.time() - t0)
    return metrics


def validate_one_epoch(model, loader, criterion, device, epoch_idx, total_epochs):
    """Evaluate ``model`` on ``loader`` with gradients disabled and no autocast.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(seis, rgt)`` batches.
        criterion: loss callable ``criterion(pred, target)``.
        device: torch.device that batches are moved to.
        epoch_idx, total_epochs: used only for the progress-bar label.

    Returns:
        dict with "mae", "rmse", "loss" (mean of per-batch losses) and "time_sec",
        the same layout that ``train_one_epoch`` returns.
    """
    model.eval()
    start = time.time()
    loss_total = 0.0
    errors = init_regression_meter()
    progress = tqdm(loader, desc=f"Epoch {epoch_idx+1}/{total_epochs} [Val]", dynamic_ncols=True)

    with torch.no_grad():
        for step, (seis, rgt) in enumerate(progress, start=1):
            seis, rgt = seis.to(device, non_blocking=True), rgt.to(device, non_blocking=True)

            pred = model(seis)
            batch_loss = float(criterion(pred, rgt).detach().cpu())
            loss_total += batch_loss

            update_regression_meter(errors, rgt, pred)

            progress.set_postfix(
                loss=f"{batch_loss:.4f}",
                avg=f"{loss_total / step:.4f}",
                sec=f"{time.time() - start:.1f}",
            )

    results = finalize_regression_meter(errors)
    results["loss"] = float(loss_total / max(len(loader), 1))
    results["time_sec"] = float(time.time() - start)
    return results

## 7) Train with Per-Epoch Validation and Track Metrics

In [7]:
history = []
best_val_loss = float("inf")
best_epoch = -1
best_ckpt_path = checkpoint_path / "best_rgt_only.pth"
history_file = history_path / "train_history_rgt_only.json"


def build_checkpoint(epoch, **extra):
    """Return the checkpoint dict for ``epoch``. ``extra`` entries are inserted before "session_name"."""
    # Unwrap DataParallel so the saved state-dict keys carry no "module." prefix.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    return {
        "epoch": epoch,
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "config": CFG,
        **extra,
        "session_name": session_name,
    }


train_start = time.time()
for epoch in range(CFG["epochs"]):
    # LR actually used for this epoch's updates. ReduceLROnPlateau may lower it after validation.
    epoch_lr = optimizer.param_groups[0]["lr"]

    train_stats = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, epoch, CFG["epochs"])
    val_stats = validate_one_epoch(model, val_loader, criterion, device, epoch, CFG["epochs"])

    scheduler.step(val_stats["loss"])
    next_lr = optimizer.param_groups[0]["lr"]

    row = {
        "epoch": epoch,
        "lr": epoch_lr,
        "next_lr": next_lr,
        "train_loss": train_stats["loss"],
        "train_mae": train_stats["mae"],
        "train_rmse": train_stats["rmse"],
        "val_loss": val_stats["loss"],
        "val_mae": val_stats["mae"],
        "val_rmse": val_stats["rmse"],
        "train_time_sec": train_stats["time_sec"],
        "val_time_sec": val_stats["time_sec"],
        "elapsed_total_sec": time.time() - train_start,
    }
    history.append(row)

    # Rewrite the history file every epoch so an interrupted run still has its record on disk.
    with open(history_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    lr_text = f"{epoch_lr:.2e}" if next_lr == epoch_lr else f"{epoch_lr:.2e}->{next_lr:.2e}"
    print(
        f"Epoch {epoch+1}/{CFG['epochs']} | "
        f"train_loss={row['train_loss']:.5f}, val_loss={row['val_loss']:.5f}, "
        f"train_rmse={row['train_rmse']:.5f}, val_rmse={row['val_rmse']:.5f}, "
        f"lr={lr_text}, epoch_time={(row['train_time_sec'] + row['val_time_sec']):.1f}s"
    )

    if row["val_loss"] < best_val_loss:
        best_val_loss = row["val_loss"]
        best_epoch = epoch
        torch.save(build_checkpoint(epoch, best_val_loss=best_val_loss), best_ckpt_path)
        print(f"  -> Saved new best checkpoint: {best_ckpt_path} (val_loss={best_val_loss:.6f})")

    if (epoch + 1) % CFG["save_every"] == 0:
        latest_ckpt_path = checkpoint_path / f"epoch_{epoch+1:04d}.pth"
        torch.save(build_checkpoint(epoch), latest_ckpt_path)
        print(f"  -> Saved periodic checkpoint: {latest_ckpt_path}")

print(f"Training complete. Best epoch: {best_epoch+1}, best val_loss={best_val_loss:.6f}")

Epoch 1/400 [Train]: 100%|██████████| 400/400 [15:54<00:00,  2.39s/it, avg=0.2835, loss=0.0648, lr=8.00e-04, sec=954.6]
Epoch 1/400 [Val]: 100%|██████████| 50/50 [01:15<00:00,  1.50s/it, avg=0.2863, loss=0.2925, sec=75.2]


Epoch 1/400 | train_loss=0.28349, val_loss=0.28635, train_rmse=0.53244, val_rmse=0.53511, lr=8.00e-04, epoch_time=1029.8s
  -> Saved new best checkpoint: /home/roderickperez/DataScienceProjects/RGT_Net/sessions/rgt_only_Feb18_125329_Train/checkpoint/best_rgt_only.pth (val_loss=0.286348)


Epoch 2/400 [Train]: 100%|██████████| 400/400 [16:32<00:00,  2.48s/it, avg=0.2424, loss=0.1486, lr=8.00e-04, sec=992.7]
Epoch 2/400 [Val]: 100%|██████████| 50/50 [01:15<00:00,  1.50s/it, avg=0.2489, loss=0.2380, sec=75.2]


Epoch 2/400 | train_loss=0.24236, val_loss=0.24887, train_rmse=0.49230, val_rmse=0.49887, lr=8.00e-04, epoch_time=1068.0s
  -> Saved new best checkpoint: /home/roderickperez/DataScienceProjects/RGT_Net/sessions/rgt_only_Feb18_125329_Train/checkpoint/best_rgt_only.pth (val_loss=0.248875)


Epoch 3/400 [Train]: 100%|██████████| 400/400 [16:52<00:00,  2.53s/it, avg=0.2313, loss=0.2503, lr=8.00e-04, sec=1012.3]
Epoch 3/400 [Val]: 100%|██████████| 50/50 [01:14<00:00,  1.49s/it, avg=0.2533, loss=0.3748, sec=74.7]


Epoch 3/400 | train_loss=0.23126, val_loss=0.25330, train_rmse=0.48090, val_rmse=0.50329, lr=8.00e-04, epoch_time=1087.0s


Epoch 4/400 [Train]:  18%|█▊        | 71/400 [03:04<14:13,  2.59s/it, avg=0.1939, loss=0.2761, lr=8.00e-04, sec=181.5]


KeyboardInterrupt: 

## 8) Save Training History and Run Artifacts (checkpoints are saved during training in section 7)

In [None]:
history_file = history_path / "train_history_rgt_only.json"
with open(history_file, "w", encoding="utf-8") as f:
    json.dump(history, f, indent=2)

# The training loop writes the best checkpoint only after an epoch improves val_loss.
best_ckpt_exists = best_ckpt_path.exists()
run_info = {
    "session_name": session_name,
    "best_epoch": best_epoch,
    # float("inf") (no improving epoch yet) would be written as the non-standard JSON token Infinity.
    "best_val_loss": float(best_val_loss) if np.isfinite(best_val_loss) else None,
    "best_checkpoint": str(best_ckpt_path) if best_ckpt_exists else None,
    "config": CFG,
    "seed": SEED,
    "device": str(device),
}
run_info_file = history_path / "run_info_rgt_only.json"
with open(run_info_file, "w", encoding="utf-8") as f:
    json.dump(run_info, f, indent=2)

print("Saved:")
print(" -", history_file)
print(" -", run_info_file)
if best_ckpt_exists:
    print(" -", best_ckpt_path)
else:
    print(" - (no best checkpoint found at", best_ckpt_path, ")")

# Training curves (epochs shown 1-based), also saved under the session's picture folder.
if history:
    epochs_1based = [r["epoch"] + 1 for r in history]
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    for axis, metric, title in ((ax[0], "loss", "Loss"), (ax[1], "rmse", "RMSE")):
        axis.plot(epochs_1based, [r[f"train_{metric}"] for r in history], label="train")
        axis.plot(epochs_1based, [r[f"val_{metric}"] for r in history], label="val")
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(title)
        axis.legend()
    plt.tight_layout()
    fig.savefig(picture_path / "training_curves.png", dpi=120)
    plt.show()

## 9) Evaluate the Trained RGT Model

In [None]:
assert best_ckpt_path.exists(), f"No best checkpoint at {best_ckpt_path}; run the training cell first."
ckpt = torch.load(best_ckpt_path, map_location=device)
state = ckpt["model_state_dict"]
if isinstance(model, torch.nn.DataParallel):
    model.module.load_state_dict(state)
else:
    model.load_state_dict(state)
model.eval()

num_bins = 10  # bins per axis of the regression "confusion" table

test_loss = 0.0
abs_err_sum = 0.0
sq_err_sum = 0.0
n_voxels = 0
y_true_list, y_pred_list = [], []
# Only the middle depth slice of each case is kept for the worst-case plots.
# Keeping full seis/rgt/pred volumes for every test case would cost several GB of host RAM.
worst_cases = []  # (mean_abs_err, seis_mid_slice, rgt_mid_slice, pred_mid_slice)

with torch.no_grad():
    for seis, rgt in tqdm(test_loader, desc="Test", dynamic_ncols=True):
        seis = seis.to(device)
        rgt = rgt.to(device)
        pred = model(seis)
        loss = criterion(pred, rgt)
        test_loss += float(loss.detach().cpu())

        seis_np = seis.detach().cpu().numpy()
        y_true_np = rgt.detach().cpu().numpy()
        y_pred_np = pred.detach().cpu().numpy()
        y_true_list.append(y_true_np)
        y_pred_list.append(y_pred_np)

        # Stream MAE/RMSE in float64 instead of concatenating all volumes afterwards.
        err = y_true_np.astype(np.float64) - y_pred_np
        abs_err_sum += float(np.abs(err).sum())
        sq_err_sum += float(np.square(err).sum())
        n_voxels += err.size

        # test_loader uses batch_size=1, so [0, 0] selects the single sample/channel.
        # .copy() detaches the slice from the full volume so the volume can be freed.
        mid_d = seis_np.shape[2] // 2
        worst_cases.append((
            float(np.mean(np.abs(err))),
            seis_np[0, 0, mid_d].copy(),
            y_true_np[0, 0, mid_d].copy(),
            y_pred_np[0, 0, mid_d].copy(),
        ))

assert n_voxels > 0, "Test loader produced no samples."
mean_test_loss = test_loss / max(len(test_loader), 1)
test_mae = float(abs_err_sum / n_voxels)
test_rmse = float(np.sqrt(sq_err_sum / n_voxels))
print(f"Test loss: {mean_test_loss:.6f}")
print(f"Test MAE:  {test_mae:.6f}")
print(f"Test RMSE: {test_rmse:.6f}")

# For regression, confusion matrix is not standard; show a simple binned confusion table instead.
# Bins span the joint min/max of targets and predictions. Counts are accumulated per volume with
# np.bincount, because a Python loop over every voxel (~1e8+) would take hours.
lo = min(min(float(a.min()) for a in y_true_list), min(float(a.min()) for a in y_pred_list))
hi = max(max(float(a.max()) for a in y_true_list), max(float(a.max()) for a in y_pred_list))
bins = np.linspace(lo, hi, num_bins + 1)
conf = np.zeros((num_bins, num_bins), dtype=np.int64)
for yt, yp in zip(y_true_list, y_pred_list):
    true_bin = np.clip(np.digitize(yt.ravel(), bins) - 1, 0, num_bins - 1)
    pred_bin = np.clip(np.digitize(yp.ravel(), bins) - 1, 0, num_bins - 1)
    conf += np.bincount(true_bin * num_bins + pred_bin, minlength=num_bins * num_bins).reshape(num_bins, num_bins)
del y_true_list, y_pred_list  # free the full test volumes

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(conf, cmap="magma", aspect="auto")
ax.set_title("Binned RGT Confusion Table (Regression Proxy)")
ax.set_xlabel("Predicted bin")
ax.set_ylabel("True bin")
fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

# Inspect top-3 worst cases (by mean absolute error over the whole volume)
worst_cases = sorted(worst_cases, key=lambda c: c[0], reverse=True)[:3]
for i, (case_err, seis_mid, rgt_mid, pred_mid) in enumerate(worst_cases, 1):
    fig, ax = plt.subplots(1, 3, figsize=(13, 4))
    ax[0].imshow(seis_mid, cmap="gray", aspect="auto"); ax[0].set_title(f"Worst #{i} Seis")
    ax[1].imshow(rgt_mid, cmap="jet", aspect="auto"); ax[1].set_title("GT RGT")
    ax[2].imshow(pred_mid, cmap="jet", aspect="auto"); ax[2].set_title(f"Pred RGT (abs err {case_err:.4f})")
    for a in ax:
        a.axis("off")
    plt.tight_layout()
    plt.show()

## 10) Export RGT Inference Script and Reproducibility Metadata

In [None]:
inference_script = ROOT / "infer_rgt_only_from_notebook.py"

# Template for a standalone inference script. __PLACEHOLDER__ tokens are replaced with repr()
# of this run's values, so the script always matches the trained model's config and session.
# str.replace is used (not str.format) because the template contains literal braces.
INFER_TEMPLATE = '''import os
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader

from models import net3d
from data.dataloader import Dataset
from data.augments import Reshape, ToTensor

ROOT = os.path.abspath(".")
DATA_ROOT = os.path.join(ROOT, "datasets", "syn")
CKPT = os.path.join(ROOT, "sessions", __SESSION_NAME__, "checkpoint", "best_rgt_only.pth")

shape = __SHAPE__
n_channels = __N_CHANNELS__
param_model = __PARAM_MODEL__

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_list = sorted(os.listdir(os.path.join(DATA_ROOT, "seis")))
    transform = transforms.Compose([Reshape((shape[0], shape[1], shape[2], n_channels)), ToTensor()])
    ds = Dataset(root_dir=DATA_ROOT, list_IDs=data_list, transform=transform, only_load_input=True)
    dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

    model = net3d.model(param_model)
    ckpt = torch.load(CKPT, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device).eval()

    with torch.no_grad():
        x = next(iter(dl)).to(device)
        pred = model(x)
    print("Inference OK. Pred shape:", tuple(pred.shape))

if __name__ == "__main__":
    main()
'''

export_param_model = {
    "input_channels": CFG["n_channels"],
    "encoder_channels": CFG["encoder_channels"],
    "decoder_channels": CFG["decoder_channels"],
}
script_text = (
    INFER_TEMPLATE
    .replace("__SESSION_NAME__", repr(session_name))
    .replace("__SHAPE__", repr(tuple(CFG["shape"])))
    .replace("__N_CHANNELS__", repr(CFG["n_channels"]))
    .replace("__PARAM_MODEL__", repr(export_param_model))
)
inference_script.write_text(script_text, encoding="utf-8")

repro_file = history_path / "reproducibility_rgt_only.json"
repro = {
    "session_name": session_name,
    "best_checkpoint": str(best_ckpt_path),
    "seed": SEED,
    "config": CFG,
    "python_version": sys.version.split()[0],
    "numpy_version": np.__version__,
    "torch_version": torch.__version__,
    "torch_cuda_version": torch.version.cuda,
    "cuda_available": torch.cuda.is_available(),
    "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
}
with open(repro_file, "w", encoding="utf-8") as f:
    json.dump(repro, f, indent=2)

print("Exported inference script:", inference_script)
print("Saved reproducibility metadata:", repro_file)