# Prediction with a pre‑trained model (TensorFlow)

This notebook demonstrates how to use a pre‑trained **3D U‑Net** model to perform fault segmentation on new seismic data.  
It covers two scenarios:

1. **Simple prediction** – run inference on a single validation cube to verify the model.
2. **Complex prediction** – run inference on a large F3 field volume using an overlap‑and‑blend tiling strategy.


## Step&nbsp;1 · Import libraries and configure environment

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from tensorflow.keras.models import load_model
# --- custom loss function (passed to load_model via custom_objects) ---
from unet3_tf import cross_entropy_balanced   # unet3_tf.py must expose it

# ------------------------------------------------------------------
# GPU setup (same configuration as the training notebook):
# enable memory growth on every visible GPU.
# ------------------------------------------------------------------
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected – running on CPU.")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as exc:
        # Reported instead of re-raised so the rest of the notebook still runs.
        print(exc)

# ------------------------------------------------------------------
# Project root: the folder of this file when `__file__` is defined
# (script execution), otherwise the current working directory
# (notebook execution).
# ------------------------------------------------------------------
try:
    _this_file = __file__
except NameError:
    ROOT_DIR = os.getcwd()
else:
    ROOT_DIR = os.path.dirname(os.path.abspath(_this_file))
print("ROOT_DIR =", ROOT_DIR)


## Step&nbsp;2 · Define paths and parameters

In [None]:
# ── Folder layout (kept in step with the training notebook) ─────────
base_dir = os.path.abspath(
    os.path.join(ROOT_DIR, os.pardir, "data", "data_from_Wu")
)
processed_data_dir = os.path.join(ROOT_DIR, "data")  # holds the .npy cubes
validation_dir_new = os.path.join(processed_data_dir, "validation_npy")
prediction_dir_f3d = os.path.join(base_dir, "prediction", "f3d")
model_dir = os.path.join(ROOT_DIR, "model")

# ── Saved model to load; change `model_name` to switch models ────────
model_name = "unet_tf_model_200pairs_10epochs_2025-07-29_18-18-55.keras"
model_path = os.path.join(model_dir, model_name)

# ── Edge lengths of the model input patch (must match training) ─────
patch_n1 = 128
patch_n2 = patch_n1
patch_n3 = patch_n1

# Echo the resolved locations so a wrong path is spotted immediately.
for _label, _location in (
    ("Model           :", model_path),
    ("Validation cubes:", validation_dir_new),
    ("F3 field folder :", prediction_dir_f3d),
):
    print(_label, _location)


## Step&nbsp;3 · Load the trained model

In [None]:
# Stop with a readable message if the model file is missing.
model_file_present = os.path.exists(model_path)
if not model_file_present:
    raise FileNotFoundError(f"Model file not found: {model_path}")

# The custom loss is handed to Keras under its own name so it can be
# resolved while the saved model is being loaded.
custom_objects = {"cross_entropy_balanced": cross_entropy_balanced}
model = load_model(model_path, custom_objects=custom_objects)
print("\n✓ Model loaded")
model.summary()


## Step&nbsp;4 · Simple prediction on a validation sample

In [None]:
def plot_prediction_slices(gx_slice, fx_slice, fp_slice):
    """Show seismic input, fault labels and predicted faults side by side.

    Each argument is transposed (``.T``) before being handed to ``imshow``.
    The seismic panel uses the grey-scale range [-2, 2]; both fault panels
    use [0, 1].

    Args:
        gx_slice: Slice of the (normalised) seismic input.
        fx_slice: Matching slice of the ground-truth fault labels.
        fp_slice: Matching slice of the model prediction.
    """
    panels = (
        ("Input seismic", gx_slice, -2, 2),
        ("Ground‑truth faults", fx_slice, 0, 1),
        ("Predicted faults", fp_slice, 0, 1),
    )
    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    for axis, (title, data, lo, hi) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(data.T, cmap='gray', vmin=lo, vmax=hi)
    plt.tight_layout()
    plt.show()

# Load one validation cube together with its fault labels.
sample_id = "10"
gx_orig, fx_orig = (
    np.load(os.path.join(validation_dir_new, sub_dir, f"{sample_id}.npy"))
    for sub_dir in ("seis", "fault")
)

# Standardise the cube with its own statistics; the epsilon guards
# against a zero standard deviation.
gx_mean, gx_std = gx_orig.mean(), gx_orig.std()
gx_proc = (gx_orig - gx_mean) / (gx_std + 1e-8)

# Reverse the axis order of both cubes to match the training orientation.
gx_tr = gx_proc.T
fx_tr = fx_orig.T

# Add batch and channel axes, run the network, then strip them again.
inp = gx_tr.reshape((1, patch_n1, patch_n2, patch_n3, 1))
prediction = model.predict(inp, verbose=0)
fp_tr = prediction[0, ..., 0]

# Show one slice along the first axis of the transposed cubes.
slice_idx = 64
print(f"Displaying slice {slice_idx} of validation cube {sample_id}")
plot_prediction_slices(gx_tr[slice_idx], fx_tr[slice_idx], fp_tr[slice_idx])


## Step&nbsp;5 · Complex prediction on a large field image (tiling/patching)

In [None]:
def create_gaussian_mask(overlap, n1, n2, n3):
    """Build a 3-D blending mask that tapers towards every face of a patch.

    The outer ``overlap`` samples along each axis are weighted by a
    half-Gaussian ramp (sigma = overlap / 4) that reaches exactly 1 at the
    inner edge of the overlap zone; samples outside every overlap zone keep
    weight 1.  The mask is the product of the three per-axis tapers, so
    patch edges and corners are attenuated by two and three ramps.

    Args:
        overlap: Width of the tapered border in samples.  ``overlap <= 0``
            means "no blending" and yields an all-ones mask.
        n1: Patch size along axis 0.
        n2: Patch size along axis 1.
        n3: Patch size along axis 2.

    Returns:
        ``np.single`` array of shape ``(n1, n2, n3)``.

    Raises:
        ValueError: If ``overlap`` is larger than any patch dimension.
    """
    shape = (n1, n2, n3)
    if overlap <= 0:
        # Nothing to blend; this also avoids dividing by a zero sigma below.
        return np.ones(shape, dtype=np.single)
    if overlap > min(shape):
        raise ValueError(
            f"overlap ({overlap}) must not exceed the patch size {shape}"
        )

    sigma = overlap / 4.0
    inv_two_sigma_sq = 0.5 / (sigma * sigma)
    # Distances -(overlap - 1) ... 0 from the inner edge of the overlap zone.
    dist = np.arange(overlap) - overlap + 1
    ramp = np.exp(-(dist ** 2) * inv_two_sigma_sq).astype(np.single)

    def _axis_taper(n):
        """Return the 1-D weights for an axis of length *n*."""
        taper = np.ones(n, dtype=np.single)
        taper[:overlap] *= ramp            # rising ramp at the low end
        taper[n - overlap:] *= ramp[::-1]  # mirrored ramp at the high end
        return taper

    # Broadcasting the three 1-D tapers gives the same product as scaling
    # the full cube slice by slice, without the Python-level loop.
    return (
        _axis_taper(n1)[:, None, None]
        * _axis_taper(n2)[None, :, None]
        * _axis_taper(n3)[None, None, :]
    )

# 1️⃣  Load and transpose seismic volume
# The file is read as headerless float32 samples and reshaped to the
# hard-coded size below; a file of a different size makes reshape() fail.
f3_path = os.path.join(prediction_dir_f3d, "gxl.dat")
d, i, x = 512, 384, 128   # depth, inline, xline
gx = np.fromfile(f3_path, dtype=np.single).reshape(d, i, x)
print("Loaded F3 volume:", gx.shape, "(depth, inline, xline)")

# Reverse the axis order so the volume is laid out as (xline, inline, depth).
gx = gx.transpose(2, 1, 0)   # → (xline, inline, depth) == (nx, ny, nz)
nx, ny, nz = gx.shape
print("Transposed for model:", gx.shape)

# 2️⃣  Patch parameters
# Neighbouring patches share `overlap_size` samples along every axis, so the
# window advances by (patch size - overlap) per step.
overlap_size = 12
stride = (patch_n1 - overlap_size,
          patch_n2 - overlap_size,
          patch_n3 - overlap_size)

# Extra samples needed at the high end of each axis so that
# (padded size - patch size) is a whole multiple of the stride, i.e. the
# last window ends exactly on the padded boundary.
pad_x = (stride[0] - (nx - patch_n1) % stride[0]) % stride[0]
pad_y = (stride[1] - (ny - patch_n2) % stride[1]) % stride[1]
pad_z = (stride[2] - (nz - patch_n3) % stride[2]) % stride[2]

# mode='constant' pads with zeros (NumPy's default fill value).
gp = np.pad(gx, ((0, pad_x), (0, pad_y), (0, pad_z)), mode='constant')
px, py, pz = gp.shape
print("Padded to:", gp.shape)

# gy accumulates mask-weighted predictions, mk accumulates the mask weights;
# their ratio (computed in step 4) is the blended result.
gy = np.zeros_like(gp, dtype=np.single)
mk = np.zeros_like(gp, dtype=np.single)
mask = create_gaussian_mask(overlap_size, patch_n1, patch_n2, patch_n3)
# Single (batch, n1, n2, n3, channel) input buffer, reused for every patch.
patch_buf = np.zeros((1, patch_n1, patch_n2, patch_n3, 1), dtype=np.single)

# 3️⃣  Sliding-window inference
# Only the outermost (x) loop is wrapped in the progress bar.
print("\nStarting tiled prediction …")
for x0 in tqdm(range(0, px - patch_n1 + 1, stride[0]), desc="x‑dim"):
    for y0 in range(0, py - patch_n2 + 1, stride[1]):
        for z0 in range(0, pz - patch_n3 + 1, stride[2]):
            patch = gp[x0:x0+patch_n1, y0:y0+patch_n2, z0:z0+patch_n3]
            # Standardise each patch with its own mean/std (any zero padding
            # inside the patch is included); the epsilon guards a zero std.
            patch = (patch - patch.mean()) / (patch.std() + 1e-8)
            patch_buf[0, ..., 0] = patch

            # Drop the batch and channel axes of the network output.
            pred = model.predict(patch_buf, verbose=0)[0, ..., 0]
            gy[x0:x0+patch_n1, y0:y0+patch_n2, z0:z0+patch_n3] += pred * mask
            mk[x0:x0+patch_n1, y0:y0+patch_n2, z0:z0+patch_n3] += mask

# 4️⃣  Normalise & crop back to original size
# Weighted average: divide only where some weight was accumulated and leave
# zeros elsewhere.
gy = np.divide(gy, mk, out=np.zeros_like(gy), where=mk != 0)
fp_final = gy[:nx, :ny, :nz]          # still (xline, inline, depth)

# Save in (xline, inline, depth) orientation
# tofile() writes raw float32 values without a header, so a reader has to
# know the shape (nx, ny, nz) in advance.
pred_out_path = os.path.join(prediction_dir_f3d, "fp_tensorflow.dat")
fp_final.astype(np.single).tofile(pred_out_path)
print("\n✓ Prediction finished — saved to", pred_out_path)


## Step&nbsp;6 · Visualise the field prediction results

In [None]:
# Seismic volume, read back in its on-disk (depth, inline, xline) layout.
seismic_file = os.path.join(prediction_dir_f3d, 'gxl.dat')
gx_orig = np.fromfile(seismic_file, dtype=np.single).reshape(512, 384, 128)

# The prediction was written as (xline, inline, depth); reverse the axes so
# it lines up with the seismic volume.
fp_raw = np.fromfile(pred_out_path, dtype=np.single).reshape(128, 384, 512)
fp_orig = fp_raw.transpose(2, 1, 0)

print("Seismic shape    :", gx_orig.shape)
print("Prediction shape :", fp_orig.shape)

k_depth, k_inline, k_xline = 29, 29, 99


def _show_section(title, seismic_2d, fault_2d, **imshow_kwargs):
    """Draw one seismic / fault-prediction pair and return its figure.

    Both sections are transposed before display.  ``imshow_kwargs`` (for
    example ``aspect``) are forwarded to both ``imshow`` calls; the fault
    panel is additionally clipped to the range [0.4, 1.0].
    """
    figure = plt.figure(figsize=(9, 9))
    plt.suptitle(title, y=0.82, fontsize=16)

    plt.subplot(1, 2, 1)
    plt.title("Seismic")
    plt.imshow(seismic_2d.T, cmap='gray', **imshow_kwargs)

    plt.subplot(1, 2, 2)
    plt.title("Fault prediction")
    plt.imshow(fault_2d.T, cmap='gray', vmin=0.4, vmax=1.0, **imshow_kwargs)

    plt.show()
    return figure


# X‑line slice
fig = _show_section(
    f'X‑line slice {k_xline}',
    gx_orig[:, :, k_xline],
    fp_orig[:, :, k_xline],
    aspect=1.5,
)

# Inline slice
fig = _show_section(
    f'Inline slice {k_inline}',
    gx_orig[:, k_inline, :],
    fp_orig[:, k_inline, :],
    aspect=1.5,
)

# Depth slice (drawn with matplotlib's default aspect)
fig = _show_section(
    f'Depth slice {k_depth}',
    gx_orig[k_depth],
    fp_orig[k_depth],
)
