# Train Joint RGT + Fault Model (Independent Pipeline)

This notebook trains a multitask model using:
- seismic input (`seis`)
- RGT labels (`rgt`)
- fault labels (`fault`)

It is independent of the RGT-only training notebook, so the two workflows stay separate.

## 1) Environment and Configuration

In [None]:
import os
import json
import time
import random
from pathlib import Path
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset as TorchDataset, DataLoader
from torch.optim import lr_scheduler
from tqdm.auto import tqdm

from sklearn.model_selection import train_test_split

from models import net3d
from lossf.loss import mse3DLoss, ssim3DLoss

# --- Reproducibility: seed Python, NumPy and torch (CPU + all CUDA devices). ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Input data layout: <cwd>/datasets/syn/{seis,rgt,fault}/<sample id> ---
ROOT = Path.cwd()
DATA_ROOT = ROOT.joinpath("datasets", "syn")
SEIS_DIR, RGT_DIR, FAULT_DIR = (DATA_ROOT / sub for sub in ("seis", "rgt", "fault"))

# --- Per-run output folders, created eagerly under <cwd>/sessions/<SESSION_NAME>/ ---
SESSION_NAME = "rgt_fault_{}_Train".format(datetime.now().strftime("%b%d_%H%M%S"))
SESSION_PATH = ROOT.joinpath("sessions", SESSION_NAME)
CKPT_PATH, HISTORY_PATH, FIG_PATH = (SESSION_PATH / sub for sub in ("checkpoint", "history", "picture"))
for p in (SESSION_PATH, CKPT_PATH, HISTORY_PATH, FIG_PATH):
    p.mkdir(parents=True, exist_ok=True)

# --- Training configuration (key order is preserved in the JSON dumps below). ---
CFG = dict(
    shape=(256, 256, 128),          # reshape target for every raw float32 volume
    batch_size=1,
    epochs=200,
    lr=8e-4,
    weight_decay=1e-4,
    lr_factor=0.5,                  # ReduceLROnPlateau factor
    lr_patience=2,                  # ReduceLROnPlateau patience (epochs)
    num_workers=2,
    encoder_channels=512,
    decoder_channels=16,
    pin_memory=True,
    mixed_precision=True,           # only effective when CUDA is available
    grad_clip=None,                 # max grad-norm, or None to disable clipping
    loss_rgt="SSIM",                # SSIM or MSE
    lambda_rgt=1.0,                 # weight of the RGT loss term
    lambda_fault=1.0,               # weight of the fault BCE term
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    max_samples=None,               # truncate the sample list (e.g. for smoke tests)
    use_augmentation=True,
    pretrained_rgt_ckpt=None,       # explicit warm-start path; None -> auto-discover
    save_every=5,                   # periodic checkpoint interval (epochs)
)

# --- Device selection: first CUDA device if present, otherwise CPU. ---
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(f"[GPU MODE] Using {torch.cuda.get_device_name(0)}" if use_cuda else "[CPU MODE] CUDA not available")

def resolve_pretrained_rgt_ckpt(explicit_path, root=None):
    """Choose the checkpoint used to warm-start the RGT branch.

    Args:
        explicit_path: User-supplied checkpoint path (``CFG["pretrained_rgt_ckpt"]``).
            When truthy it is used exclusively. It is returned as a ``Path`` if it is an
            existing file. Otherwise a warning is emitted and ``None`` is returned, with
            no fallback to auto-discovery.
        root: Directory to search when ``explicit_path`` is falsy. Defaults to the
            notebook-level ``ROOT``.

    Returns:
        The most recently modified file among
        ``<root>/sessions/*_Train/checkpoint/best_rgt_only.pth`` and
        ``<root>/checkpoints/*.pth``, or ``None`` when nothing matches.
    """
    import warnings

    if explicit_path:
        p = Path(explicit_path)
        if p.is_file():
            return p
        # An explicit path that is not a file is almost always a typo. Make the
        # fallback to training from scratch visible instead of silent.
        warnings.warn(
            f"pretrained_rgt_ckpt={str(explicit_path)!r} is not a file; ignoring it.",
            stacklevel=2,
        )
        return None

    search_root = ROOT if root is None else Path(root)
    # Sorted so that ties on mtime are broken deterministically by path.
    candidates = sorted((search_root / "sessions").glob("*_Train/checkpoint/best_rgt_only.pth"))
    candidates += sorted((search_root / "checkpoints").glob("*.pth"))

    if not candidates:
        return None
    return max(candidates, key=lambda c: c.stat().st_mtime)


# Resolve (explicit or auto-discovered) warm-start checkpoint and report the run setup.
PRETRAINED_RGT_CKPT = resolve_pretrained_rgt_ckpt(CFG["pretrained_rgt_ckpt"])
print(
    f"Using pretrained RGT checkpoint: {PRETRAINED_RGT_CKPT}"
    if PRETRAINED_RGT_CKPT is not None
    else "No pretrained RGT checkpoint found. Joint model will train from scratch."
)

print("Session:", SESSION_PATH)
print(json.dumps(CFG, indent=2))

## 2) Load and Inspect Seis + RGT + Fault Data

In [None]:
# Fail fast if any input folder is missing. Explicit raises are used because asserts are stripped under -O.
for _required_dir in (SEIS_DIR, RGT_DIR, FAULT_DIR):
    if not _required_dir.exists():
        raise FileNotFoundError(f"Missing: {_required_dir}")

# Sample ids are file names; only names present in all three folders form a usable triple.
seis_ids = {f.name for f in SEIS_DIR.iterdir() if f.is_file()}
rgt_ids = {f.name for f in RGT_DIR.iterdir() if f.is_file()}
fault_ids = {f.name for f in FAULT_DIR.iterdir() if f.is_file()}

ids = sorted(seis_ids & rgt_ids & fault_ids)
print(f"Seis: {len(seis_ids)}, RGT: {len(rgt_ids)}, Fault: {len(fault_ids)}")
print(f"Matched triples: {len(ids)}")

if CFG["max_samples"] is not None:
    ids = ids[:CFG["max_samples"]]
print("Using samples:", len(ids))
if not ids:
    # Without this check the failure surfaces as an opaque IndexError from random.choice.
    raise RuntimeError(
        f"No usable samples: need files with identical names in {SEIS_DIR}, {RGT_DIR} "
        f"and {FAULT_DIR} (and CFG['max_samples'] > 0)."
    )

# Load one random triple for visual inspection; fault labels are binarised at 0.5.
n1, n2, n3 = CFG["shape"]
rid = random.choice(ids)
seis = np.fromfile(SEIS_DIR / rid, dtype=np.float32).reshape(n1, n2, n3)
rgt = np.fromfile(RGT_DIR / rid, dtype=np.float32).reshape(n1, n2, n3)
fault = np.fromfile(FAULT_DIR / rid, dtype=np.float32).reshape(n1, n2, n3)
fault = (fault > 0.5).astype(np.float32)

print("Example sample:", rid)
print("Fault positive ratio:", float(fault.mean()))

# Central slices along each axis; axis 2 is treated as time.
mid_i, mid_x, mid_t = n1 // 2, n2 // 2, n3 // 2
fig, ax = plt.subplots(3, 3, figsize=(12, 10))
for r, (label, vol, cmap) in enumerate([("Seis", seis, "gray"), ("RGT", rgt, "jet"), ("Fault", fault, "magma")]):
    views = [("inline", vol[mid_i]), ("xline", vol[:, mid_x, :]), ("time", vol[:, :, mid_t])]
    for c, (view_name, sl) in enumerate(views):
        ax[r, c].imshow(sl, cmap=cmap, aspect="auto")
        ax[r, c].set_title(f"{label} {view_name}")
        ax[r, c].axis("off")
plt.tight_layout()
plt.show()

## 3) Build Dataset and DataLoader Pipelines

In [None]:
class RgtFaultDataset(TorchDataset):
    """Paired seismic / RGT / fault volumes stored as raw float32 files.

    Each sample id is a file name that must exist in ``<root>/seis``, ``<root>/rgt``
    and ``<root>/fault``. Volumes are read with ``np.fromfile`` and reshaped to
    ``shape`` (C order).

    Args:
        root_dir: Directory containing the ``seis``, ``rgt`` and ``fault`` folders.
        ids_list: File names served by this dataset, in order.
        shape: 3-D shape used to reshape every raw file.
        augment: If True, apply a random map-view flip/rotation to all three volumes.
        time_axis: Axis of the on-disk array that holds time/depth. This axis is never
            flipped or rotated. The default of 2 matches this notebook's visualisation
            cells, which title ``vol[:, :, mid]`` as "time".

    Each item is ``(seis, rgt, fault, file_id)``. The three tensors are float32 with
    shape ``(1, s2, s1, s0)`` for ``shape == (s0, s1, s2)``. Seismic and RGT are
    z-score normalised per volume; fault is binarised at 0.5.
    """

    def __init__(self, root_dir: Path, ids_list, shape, augment=False, time_axis=2):
        self.root = root_dir
        self.ids = ids_list
        self.shape = shape
        self.augment = augment
        self.time_axis = time_axis
        # The two remaining axes span the horizontal (map-view) plane.
        self.plane_axes = tuple(a for a in range(3) if a != time_axis)

    def __len__(self):
        return len(self.ids)

    def _read_vol(self, folder: str, file_id: str):
        """Read one raw float32 volume and reshape it to ``self.shape``."""
        arr = np.fromfile(self.root / folder / file_id, dtype=np.float32)
        return arr.reshape(self.shape)

    def _augment(self, vols):
        """Apply one random map-view symmetry identically to every volume in ``vols``.

        Only the two horizontal axes are flipped or rotated. Flipping the time axis
        would reverse the vertical ordering of the RGT label. Rotating a plane that
        contains time would mix time with space.
        """
        ax_a, ax_b = self.plane_axes
        for axis in (ax_a, ax_b):
            if random.random() > 0.5:
                vols = [np.flip(v, axis=axis) for v in vols]
        # Quarter turns swap the two plane axes, so they keep the volume shape only
        # when the plane is square; otherwise restrict to 0 or 180 degrees.
        if self.shape[ax_a] == self.shape[ax_b]:
            k = random.randint(0, 3)
        else:
            k = 2 * random.randint(0, 1)
        if k:
            vols = [np.rot90(v, k=k, axes=(ax_a, ax_b)) for v in vols]
        return vols

    @staticmethod
    def _to_tensor(vol):
        """Add a channel axis and reverse the spatial axes: (a0, a1, a2) -> (1, a2, a1, a0).

        ``.copy()`` yields a contiguous array; flips and rotations produce strided
        views, which ``torch.from_numpy`` rejects.
        """
        return torch.from_numpy(np.transpose(vol[None, ...], (0, 3, 2, 1)).copy()).float()

    def __getitem__(self, index):
        file_id = self.ids[index]
        seis = self._read_vol("seis", file_id)
        rgt = self._read_vol("rgt", file_id)
        fault = self._read_vol("fault", file_id)

        if self.augment:
            seis, rgt, fault = self._augment([seis, rgt, fault])

        # Per-volume z-score normalisation of seismic and RGT; fault is re-binarised.
        seis = (seis - seis.mean()) / (seis.std() + 1e-8)
        rgt = (rgt - rgt.mean()) / (rgt.std() + 1e-8)
        fault = (fault > 0.5).astype(np.float32)

        return self._to_tensor(seis), self._to_tensor(rgt), self._to_tensor(fault), file_id


# Two-stage split: train vs hold-out, then hold-out -> validation / test.
if len(ids) < 2:
    raise ValueError(f"Need at least 2 samples for a train/hold-out split, got {len(ids)}.")

train_ids, temp_ids = train_test_split(ids, test_size=(1.0 - CFG["train_ratio"]), random_state=SEED, shuffle=True)
# Share of the hold-out set assigned to validation; the remainder becomes the test set.
val_rel = CFG["val_ratio"] / (CFG["val_ratio"] + CFG["test_ratio"])
if len(temp_ids) >= 2:
    val_ids, test_ids = train_test_split(temp_ids, test_size=(1.0 - val_rel), random_state=SEED, shuffle=True)
else:
    # A single hold-out sample cannot be split again (typical for tiny max_samples
    # smoke tests). Reuse it for both sets instead of letting sklearn raise.
    print("Warning: only one hold-out sample; reusing it for both validation and test.")
    val_ids, test_ids = list(temp_ids), list(temp_ids)

print(f"Train/Val/Test = {len(train_ids)}/{len(val_ids)}/{len(test_ids)}")

# Only the training set is augmented.
train_ds = RgtFaultDataset(DATA_ROOT, train_ids, CFG["shape"], augment=CFG["use_augmentation"])
val_ds = RgtFaultDataset(DATA_ROOT, val_ids, CFG["shape"], augment=False)
test_ds = RgtFaultDataset(DATA_ROOT, test_ids, CFG["shape"], augment=False)

pin = CFG["pin_memory"] and use_cuda  # pinned host memory only helps CUDA transfers
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=CFG["num_workers"], pin_memory=pin)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False,
                         num_workers=CFG["num_workers"], pin_memory=pin)

# Sanity check: pull one training batch and show tensor shapes.
xb, y_rgt_b, y_fault_b, _ = next(iter(train_loader))
print("Batch shapes:", tuple(xb.shape), tuple(y_rgt_b.shape), tuple(y_fault_b.shape))

## 4) Define the Multitask Model (RGT Network + Cascaded Fault Head)

In [None]:
class MultiTaskRgtFaultModel(nn.Module):
    """RGT regression network followed by a cascaded fault-segmentation head.

    ``rgt_net`` (``net3d.model``) maps the seismic volume to an RGT volume. The fault
    head receives the seismic input concatenated with the predicted RGT along the
    channel axis. It returns raw fault logits; no sigmoid is applied.

    Args:
        pretrained_rgt_path: Optional checkpoint used to warm-start ``rgt_net``.
        encoder_channels: Forwarded to ``net3d.model``. The default of 512 is the
            previously hard-coded value; pass ``CFG["encoder_channels"]`` to follow
            the config.
        decoder_channels: Forwarded to ``net3d.model`` (default 16).
    """

    def __init__(self, pretrained_rgt_path=None, encoder_channels=512, decoder_channels=16):
        super().__init__()
        self.rgt_net = net3d.model({
            "input_channels": 1,
            "encoder_channels": encoder_channels,
            "decoder_channels": decoder_channels,
        })
        # 2 input channels = 1 seismic + predicted RGT (assumes rgt_net outputs a
        # single channel -- TODO confirm against net3d).
        self.fault_head = nn.Sequential(
            nn.Conv3d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(8, 1, kernel_size=1),
        )

        if pretrained_rgt_path is not None:
            self._load_rgt_weights(pretrained_rgt_path)

    def _load_rgt_weights(self, path):
        """Non-strictly load ``rgt_net`` weights from ``path`` and print key-match counts.

        Accepts a raw state dict or a dict holding ``"model_state_dict"``. A leading
        ``module.`` (DataParallel) and then ``rgt_net.`` prefix are stripped from each
        key. Keys that do not belong to ``rgt_net`` (e.g. ``fault_head.*``) are
        reported as unexpected and ignored.
        """
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location="cpu")
        state = ckpt.get("model_state_dict", ckpt)
        cleaned_state = {}
        for key, value in state.items():
            for prefix in ("module.", "rgt_net."):
                if key.startswith(prefix):
                    key = key[len(prefix):]
            cleaned_state[key] = value

        missing, unexpected = self.rgt_net.load_state_dict(cleaned_state, strict=False)
        print(f"Loaded pretrained RGT weights from: {path}")
        print(f"Warm-start missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

    def forward(self, seis):
        """Return ``(pred_rgt, pred_fault_logits)`` for a seismic batch."""
        pred_rgt = self.rgt_net(seis)
        fault_in = torch.cat([seis, pred_rgt], dim=1)
        pred_fault_logits = self.fault_head(fault_in)
        return pred_rgt, pred_fault_logits


# Build the model; on CUDA wrap it in DataParallel across every visible GPU.
model = MultiTaskRgtFaultModel(pretrained_rgt_path=PRETRAINED_RGT_CKPT)
if use_cuda:
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
model = model.to(device)

n_params = sum(map(torch.numel, model.parameters()))
print(f"Total params: {n_params:,}")

# Forward the sample batch once to confirm that both heads produce outputs.
with torch.no_grad():
    xb = xb.to(device)
    rgt_pred_b, fault_logit_b = model(xb)
print("Forward sanity:", tuple(rgt_pred_b.shape), tuple(fault_logit_b.shape))

## 5) Configure Loss, Optimizer, Scheduler

In [None]:
# RGT regression loss: exactly "SSIM" or "MSE" (case-insensitive). Any other value
# used to fall back silently to MSE.
_rgt_loss_name = CFG["loss_rgt"].upper()
if _rgt_loss_name == "SSIM":
    criterion_rgt = ssim3DLoss()
elif _rgt_loss_name == "MSE":
    criterion_rgt = mse3DLoss()
else:
    raise ValueError(f"CFG['loss_rgt'] must be 'SSIM' or 'MSE', got {CFG['loss_rgt']!r}")

# Estimate BCE pos_weight = (#negative voxels / #positive voxels) from up to 20 training batches.
pos = 0.0
neg = 0.0
for i, (_, _, f, _) in enumerate(train_loader):
    if i >= 20:
        break
    pos += float((f > 0.5).sum().item())
    neg += float((f <= 0.5).sum().item())
if pos > 0:
    _pos_weight_value = neg / pos
else:
    # With no positives sampled, neg/1 would give an enormous weight that explodes the
    # loss on the first fault voxel; fall back to an unweighted BCE instead.
    print("Warning: no fault voxels in the sampled batches; using pos_weight=1.0.")
    _pos_weight_value = 1.0
pos_weight = torch.tensor([_pos_weight_value], device=device)
criterion_fault = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = optim.Adam(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=CFG["lr_factor"], patience=CFG["lr_patience"])
# Loss scaling (and, in run_epoch, autocast) is active only on CUDA with mixed_precision.
scaler = torch.cuda.amp.GradScaler(enabled=(use_cuda and CFG["mixed_precision"]))

print("RGT loss:", criterion_rgt.getLossName())
print("Fault loss: BCEWithLogitsLoss")
print("Fault pos_weight:", float(pos_weight.item()))

## 6) Training and Validation Loops

In [None]:
def init_reg_meter():
    """Create a fresh accumulator for streaming regression-error statistics.

    Keys: ``abs_sum`` (running sum of |error|), ``sq_sum`` (running sum of
    error**2) and ``count`` (number of elements seen).
    """
    return dict(abs_sum=0.0, sq_sum=0.0, count=0)


def update_reg_meter(meter, y_true, y_pred):
    """Fold the element-wise error ``y_true - y_pred`` into ``meter`` (in place).

    ``meter`` is a dict as returned by ``init_reg_meter``; the tensors are detached,
    so no autograd graph is retained.
    """
    err = (y_true - y_pred).detach()
    abs_total = err.abs().sum().item()
    sq_total = (err * err).sum().item()
    meter["abs_sum"] += float(abs_total)
    meter["sq_sum"] += float(sq_total)
    meter["count"] += int(err.numel())


def finalize_reg_meter(meter):
    """Turn accumulated sums into ``(mae, rmse)`` floats.

    Returns ``(nan, nan)`` when the meter has seen no elements.
    """
    n = meter["count"]
    if not n:
        return np.nan, np.nan
    mean_sq = meter["sq_sum"] / n
    return float(meter["abs_sum"] / n), float(np.sqrt(mean_sq))


def compute_fault_stats(logits, target):
    """Confusion-matrix counts for binary fault prediction at probability 0.5.

    A voxel is predicted positive when ``sigmoid(logit) > 0.5``. Targets are
    compared with exactly 1 (positive) and exactly 0 (negative).

    Returns:
        ``(tp, fp, fn, tn)`` as Python ints.
    """
    predicted_pos = torch.sigmoid(logits) > 0.5
    predicted_neg = ~predicted_pos
    actual_pos = target == 1
    actual_neg = target == 0

    tp = (predicted_pos & actual_pos).sum().item()
    fp = (predicted_pos & actual_neg).sum().item()
    fn = (predicted_neg & actual_pos).sum().item()
    tn = (predicted_neg & actual_neg).sum().item()
    return tp, fp, fn, tn


def summarize_fault(tp, fp, fn, tn):
    """Derive precision, recall, F1, IoU and accuracy from confusion counts.

    Empty denominators are clamped (to 1 for count ratios, 1e-8 for F1), so every
    metric is 0.0 rather than a ZeroDivisionError when a class is absent.
    """
    predicted_pos = tp + fp
    actual_pos = tp + fn
    union = tp + fp + fn
    total = union + tn

    precision = tp / max(predicted_pos, 1)
    recall = tp / max(actual_pos, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": tp / max(union, 1),
        "acc": (tp + tn) / max(total, 1),
    }


def run_epoch(loader, train_mode: bool = True, epoch_idx: int = 0, total_epochs: int = 1) -> dict:
    """Run one pass over ``loader`` and return averaged losses and metrics.

    Relies on notebook-level globals: ``model``, ``optimizer``, ``scaler``,
    ``criterion_rgt``, ``criterion_fault``, ``CFG`` and ``device``.

    Args:
        loader: Yields ``(seis, rgt_gt, fault_gt, file_id)`` batches.
        train_mode: If True, switch the model to train mode and take one optimizer
            step per batch. Otherwise run in eval mode with gradients disabled.
            Every non-training pass, including the test pass, is labelled "Val" in
            the progress bar.
        epoch_idx: Zero-based epoch index, used only for the progress-bar label.
        total_epochs: Total epoch count, used only for the progress-bar label.

    Returns:
        A dict containing:
        - ``loss``, ``rgt_loss``, ``fault_loss``: summed per-batch losses divided by
          ``len(loader)``.
        - ``rgt_mae`` / ``rgt_rmse``: per-voxel statistics over the whole pass.
        - ``precision``, ``recall``, ``f1``, ``iou``, ``acc``: from
          ``summarize_fault`` on confusion counts pooled over the pass.
        - ``time_sec``: wall-clock duration of the pass.

    Raises:
        RuntimeError: If CUDA runs out of memory during the backward pass.
    """
    if train_mode:
        model.train()
        title = "Train"
    else:
        model.eval()
        title = "Val"

    t0 = time.time()
    total_loss = 0.0
    total_rgt_loss = 0.0
    total_fault_loss = 0.0

    # Streaming accumulators: voxel-level regression error and fault confusion counts.
    reg_meter = init_reg_meter()
    tp = fp = fn = tn = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch_idx+1}/{total_epochs} [{title}]", dynamic_ncols=True)

    for bidx, (seis, rgt_gt, fault_gt, _) in enumerate(pbar):
        seis = seis.to(device, non_blocking=True)
        rgt_gt = rgt_gt.to(device, non_blocking=True)
        fault_gt = fault_gt.to(device, non_blocking=True)

        if train_mode:
            optimizer.zero_grad(set_to_none=True)

        with torch.set_grad_enabled(train_mode):
            # Autocast follows the GradScaler switch, i.e. it is active only on CUDA with CFG["mixed_precision"].
            with torch.amp.autocast(device_type=device.type, enabled=scaler.is_enabled()):
                rgt_pred, fault_logits = model(seis)
                rgt_loss = criterion_rgt(rgt_pred, rgt_gt)
                fault_loss = criterion_fault(fault_logits, fault_gt)
                loss = CFG["lambda_rgt"] * rgt_loss + CFG["lambda_fault"] * fault_loss

            if train_mode:
                try:
                    scaler.scale(loss).backward()
                except torch.cuda.OutOfMemoryError as e:
                    # Free cached blocks, then re-raise with actionable advice (original error chained).
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    raise RuntimeError("CUDA OOM during backward. Reduce CFG['batch_size'], set CFG['loss_rgt']=MSE, or reduce CFG['max_samples'] for a smoke test.") from e
                if CFG["grad_clip"] is not None:
                    # Unscale first so the clipping threshold applies to the true (unscaled) gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
                scaler.step(optimizer)
                scaler.update()

        total_loss += float(loss.detach().cpu())
        total_rgt_loss += float(rgt_loss.detach().cpu())
        total_fault_loss += float(fault_loss.detach().cpu())

        update_reg_meter(reg_meter, rgt_gt, rgt_pred)
        ctp, cfp, cfn, ctn = compute_fault_stats(fault_logits.detach(), fault_gt.detach())
        tp += ctp; fp += cfp; fn += cfn; tn += ctn

        # Running averages shown live in the progress bar.
        avg_loss = total_loss / (bidx + 1)
        elapsed = time.time() - t0
        pbar.set_postfix(loss=f"{avg_loss:.4f}", rgt=f"{(total_rgt_loss/(bidx+1)):.4f}", fault=f"{(total_fault_loss/(bidx+1)):.4f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}", sec=f"{elapsed:.1f}")

    # max(..., 1) guards against an empty loader.
    mean_loss = total_loss / max(len(loader), 1)
    mean_rgt_loss = total_rgt_loss / max(len(loader), 1)
    mean_fault_loss = total_fault_loss / max(len(loader), 1)

    rgt_mae, rgt_rmse = finalize_reg_meter(reg_meter)

    fstats = summarize_fault(tp, fp, fn, tn)

    return {
        "loss": mean_loss,
        "rgt_loss": mean_rgt_loss,
        "fault_loss": mean_fault_loss,
        "rgt_mae": rgt_mae,
        "rgt_rmse": rgt_rmse,
        **fstats,
        "time_sec": time.time() - t0,
    }

## 7) Train, Validate, and Save Checkpoints

In [None]:
# Main training loop: train + validate each epoch, step the plateau scheduler on the
# validation loss, and checkpoint the best (lowest total val loss) and periodic models.
history = []
best_val = float("inf")
best_epoch = -1
best_ckpt = CKPT_PATH / "best_rgt_fault.pth"

train_begin = time.time()
for epoch in range(CFG["epochs"]):
    train_stats = run_epoch(train_loader, train_mode=True, epoch_idx=epoch, total_epochs=CFG["epochs"])
    val_stats = run_epoch(val_loader, train_mode=False, epoch_idx=epoch, total_epochs=CFG["epochs"])

    # ReduceLROnPlateau (mode="min") monitors the combined validation loss.
    scheduler.step(val_stats["loss"])
    lr_now = optimizer.param_groups[0]["lr"]

    row = {
        "epoch": epoch,
        "lr": lr_now,
        "train": train_stats,
        "val": val_stats,
        "elapsed_total_sec": time.time() - train_begin,
    }
    history.append(row)

    print(
        f"Epoch {epoch+1}/{CFG['epochs']} | "
        f"train_loss={train_stats['loss']:.5f} val_loss={val_stats['loss']:.5f} | "
        f"train_f1={train_stats['f1']:.4f} val_f1={val_stats['f1']:.4f} | "
        f"train_rmse={train_stats['rgt_rmse']:.4f} val_rmse={val_stats['rgt_rmse']:.4f} | "
        f"lr={lr_now:.2e}"
    )

    # The DataParallel wrapper is unwrapped so saved keys carry no "module." prefix.
    save_obj = {
        "epoch": epoch,
        "model_state_dict": model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "cfg": CFG,
        "session_name": SESSION_NAME,
        "history_tail": history[-5:],
    }

    # Strict "<" comparison: a NaN validation loss never counts as an improvement.
    if val_stats["loss"] < best_val:
        best_val = val_stats["loss"]
        best_epoch = epoch
        torch.save(save_obj, best_ckpt)
        print(f"  -> Saved BEST checkpoint: {best_ckpt}")

    if (epoch + 1) % CFG["save_every"] == 0:
        ep_ckpt = CKPT_PATH / f"epoch_{epoch+1:04d}.pth"
        torch.save(save_obj, ep_ckpt)
        print(f"  -> Saved periodic checkpoint: {ep_ckpt}")

print(f"Training finished. Best epoch={best_epoch+1}, best val loss={best_val:.6f}")

# Persist the full per-epoch history plus a small run summary as JSON.
with open(HISTORY_PATH / "history_rgt_fault.json", "w", encoding="utf-8") as f:
    json.dump(history, f, indent=2)

with open(HISTORY_PATH / "run_info_rgt_fault.json", "w", encoding="utf-8") as f:
    json.dump({
        "session_name": SESSION_NAME,
        "best_epoch": best_epoch,
        "best_val_loss": best_val,
        "best_checkpoint": str(best_ckpt),
        "cfg": CFG,
        "seed": SEED,
        "device": str(device),
    }, f, indent=2)

print("Saved history + run info in", HISTORY_PATH)

## 8) Final Test Evaluation

In [None]:
# Evaluate the best validation checkpoint on the held-out test split.
if best_ckpt.exists():
    ckpt = torch.load(best_ckpt, map_location=device)
    # Checkpoints store the unwrapped state dict, so load into .module under DataParallel.
    eval_net = model.module if isinstance(model, torch.nn.DataParallel) else model
    eval_net.load_state_dict(ckpt["model_state_dict"])
else:
    # No epoch ever improved best_val (e.g. every val loss was NaN, or epochs == 0),
    # so the file was never written; torch.load would raise here.
    print(f"Warning: {best_ckpt} was never written; evaluating the current in-memory weights instead.")

model.eval()
test_stats = run_epoch(test_loader, train_mode=False, epoch_idx=0, total_epochs=1)
print("Test metrics:")
print(json.dumps(test_stats, indent=2))

## 9) 3×5 Visualization Panel (Inline / Xline / Time)

In [None]:
def slices_3(v):
    """Return the central 2-D slice along each axis of a 3-D array.

    For ``v.shape == (n0, n1, n2)`` the result is
    ``[v[n0 // 2], v[:, n1 // 2, :], v[:, :, n2 // 2]]`` (views, not copies).
    Unpacking the shape makes non-3-D input raise ``ValueError``.
    """
    size0, size1, size2 = v.shape
    first = v[size0 // 2, :, :]
    second = v[:, size1 // 2, :]
    third = v[:, :, size2 // 2]
    return [first, second, third]

# Predict the first test sample and plot a 3x5 panel: rows = central slice along each
# axis, columns = seismic / fault GT / fault probability / RGT GT / RGT prediction.
model.eval()
with torch.no_grad():
    seis_b, rgt_b, fault_b, ids_b = next(iter(test_loader))
    seis_b = seis_b.to(device)
    rgt_pred_b, fault_logit_b = model(seis_b)

# Take the first sample/channel; the dataset reversed the on-disk axes for the model.
seis_v = seis_b[0, 0].detach().cpu().numpy()
rgt_gt_v = rgt_b[0, 0].detach().cpu().numpy()
rgt_pr_v = rgt_pred_b[0, 0].detach().cpu().numpy()
fault_gt_v = fault_b[0, 0].detach().cpu().numpy()
fault_pr_v = torch.sigmoid(fault_logit_b[0, 0]).detach().cpu().numpy()

# Reverse the axes again to return to the on-disk layout used in section 2.
seis_v = np.transpose(seis_v, (2, 1, 0))
rgt_gt_v = np.transpose(rgt_gt_v, (2, 1, 0))
rgt_pr_v = np.transpose(rgt_pr_v, (2, 1, 0))
fault_gt_v = np.transpose(fault_gt_v, (2, 1, 0))
fault_pr_v = np.transpose(fault_pr_v, (2, 1, 0))

rows = [
    ("Mid Inline", lambda v: slices_3(v)[0]),
    ("Mid Xline", lambda v: slices_3(v)[1]),
    ("Mid Time", lambda v: slices_3(v)[2]),
]

col_titles = ["Seismic", "Fault GT", "Fault Pred", "RGT GT", "RGT Pred"]
vols = [seis_v, fault_gt_v, fault_pr_v, rgt_gt_v, rgt_pr_v]
cmaps = ["gray", "magma", "magma", "jet", "jet"]

fig, ax = plt.subplots(3, 5, figsize=(18, 10))
for r, (row_name, row_fn) in enumerate(rows):
    for c in range(5):
        cell = ax[r, c]
        cell.imshow(row_fn(vols[c]), cmap=cmaps[c], aspect="auto")
        # axis("off") would also hide the row ylabel, so only strip ticks and spines.
        cell.set_xticks([])
        cell.set_yticks([])
        for spine in cell.spines.values():
            spine.set_visible(False)
        if r == 0:
            cell.set_title(col_titles[c])
        if c == 0:
            cell.set_ylabel(row_name)

sample_name = ids_b[0]
plt.suptitle(f"Sample: {sample_name}", y=0.98)
plt.tight_layout()

panel_path = FIG_PATH / f"panel_3x5_{sample_name}.png"
# Save before show(): non-inline backends may close or clear the figure once shown.
fig.savefig(panel_path, dpi=150)
plt.show()
print("Saved panel figure:", panel_path)

## 10) Minimal Inference Helper (Optional)

In [None]:
def predict_one(file_id: str):
    """Run the current model on one sample id and return numpy arrays.

    The sample is loaded without augmentation through ``RgtFaultDataset``, so every
    array keeps the dataset's tensor layout (channel axis first, spatial axes
    reversed relative to the raw file).

    Returns:
        dict with ``seis``, ``rgt_gt``, ``fault_gt`` (inputs/labels), ``rgt_pred``
        and ``fault_pred`` (sigmoid probabilities), each without a batch axis.
    """
    sample_ds = RgtFaultDataset(DATA_ROOT, [file_id], CFG["shape"])
    seis_t, rgt_t, fault_t, _ = sample_ds[0]
    batch = seis_t.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        rgt_out, fault_logits = model(batch)
    fault_prob = torch.sigmoid(fault_logits)

    return dict(
        seis=seis_t.numpy(),
        rgt_gt=rgt_t.numpy(),
        fault_gt=fault_t.numpy(),
        rgt_pred=rgt_out[0].cpu().numpy(),
        fault_pred=fault_prob[0].cpu().numpy(),
    )

print("Use: out = predict_one(test_ids[0])")