In [None]:
import os
import warnings

import torch
from inference import load_real_binary_file, run_batch_inference

from viz import show_images

In [None]:
# --- Device ---
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Paths (edit these before running) ---
# Directory containing the model .pth file and potentially index files
MODEL_PARENT_DIR = "/path/to/your/data/please/edit/me"
# Model file NAME, relative to MODEL_PARENT_DIR. Keep it relative: if this is an
# absolute path, os.path.join(MODEL_PARENT_DIR, MODEL_FILENAME) silently drops
# MODEL_PARENT_DIR and uses this value alone.
MODEL_FILENAME = "model.pth"  # please edit me
# Directory containing the 'real' data to run inference on
DATA_REAL_DIR = "/path/to/your/data/please/edit/me"  # Example path
# Filename of the 'real' input data
DATA_REAL_INPUT_NAME = "cube5_extract.dat"  # Example filename
# Directory to save the predicted output
PREDICTED_DIR = MODEL_PARENT_DIR  # Save predictions in the model's directory
# Filename for the predicted output binary file
PREDICTED_NAME = "inference_for_real_cube5_extract_data.bin"

# Shape of the 'real' input data cube (assuming Z is the batch-like dim)
NX_real = 481
NY_real = 751
NZ_real = 481

# --- Inference Parameters ---
BATCH_SIZE = 12  # Adjust based on GPU memory
# Normalization functions corresponding to the *trained* model (must match training).
# These should be names from the utils.py module (batch-aware versions); they are
# passed by name to run_batch_inference.
NORMALIZE_INPUT = "normalize_input_to_max1"
UNNORM_OUTPUT = "unnormalize_output_to_max1"

# --- Path Validation ---
if os.path.isabs(MODEL_FILENAME):
    warnings.warn(
        f"MODEL_FILENAME is an absolute path ({MODEL_FILENAME}); "
        f"MODEL_PARENT_DIR ({MODEL_PARENT_DIR}) will be ignored."
    )

MODEL_PATH = os.path.join(MODEL_PARENT_DIR, MODEL_FILENAME)
REAL_DATA_PATH = os.path.join(DATA_REAL_DIR, DATA_REAL_INPUT_NAME)
PREDICTED_PATH = os.path.join(PREDICTED_DIR, PREDICTED_NAME)

# Missing inputs only warn here; the load cell below will raise if they are absent.
if not os.path.exists(MODEL_PATH):
    warnings.warn(f"Model file not found: {MODEL_PATH}")
if not os.path.exists(REAL_DATA_PATH):
    warnings.warn(f"Real data file not found: {REAL_DATA_PATH}")

# Ensure output directory exists (note: this runs even if PREDICTED_DIR is still
# the placeholder above).
os.makedirs(PREDICTED_DIR, exist_ok=True)

In [None]:
# Load the trained model. The checkpoint is presumably a fully pickled nn.Module
# (we call .eval() on the result), not a bare state_dict. Loading a whole module
# requires weights_only=False since PyTorch 2.6 changed torch.load's default to
# True. weights_only=False unpickles arbitrary objects, so only load checkpoints
# from a trusted source.
model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.eval()  # Evaluation mode (affects layers such as dropout/batch-norm, if present)
print(f"Model loaded successfully from {MODEL_PATH}")

# Load the raw 'real' data cube using the configured NX/NY/NZ dimensions.
real_imgs = load_real_binary_file(REAL_DATA_PATH, NX_real, NY_real, NZ_real)
print(f"Real data loaded successfully from {REAL_DATA_PATH}")
print(
    f"Data shape: {real_imgs.shape}, Min: {real_imgs.min():.2f}, Max: {real_imgs.max():.2f}"
)

In [None]:
# Run the model over the real cube in batches of BATCH_SIZE (batched along the
# "batch-like" dim noted in the config cell). Autograd is disabled so no
# computation graph is retained across batches. This is harmless if
# run_batch_inference already disables gradients internally.
with torch.no_grad():
    predicted_results = run_batch_inference(
        model,
        real_imgs,
        batch_size=BATCH_SIZE,
        normalize_input=NORMALIZE_INPUT,
        unnorm_output=UNNORM_OUTPUT,
        device=DEVICE,
    )

In [None]:
# Analyze and Save Results
# predicted_results is checked against None because run_batch_inference may
# presumably return None on failure (TODO confirm against inference.py).
if predicted_results is not None:
    print(f"Predicted results shape: {predicted_results.shape}")
    print(
        f"Predicted Min: {predicted_results.min():.2f}, Max: {predicted_results.max():.2f}, Mean: {predicted_results.mean():.2f}"
    )

    # TODO: saving is currently disabled. `save_tensor_to_binary` is not imported
    # in this notebook, so uncommenting the line below alone will raise NameError.
    # NOTE(review): with the default config PREDICTED_PATH is the same file the
    # comparison cell reads as the "previous" prediction. Saving here before
    # running that cell would compare the current run with itself.
    # save_tensor_to_binary(predicted_results, PREDICTED_PATH)

    # Optional: visualize the middle slice along the first (indexed) axis.
    # The index comes from the tensors actually being indexed, not from NZ_real,
    # so it stays in bounds regardless of how the loader orders axes.
    slice_idx = min(len(real_imgs), len(predicted_results)) // 2
    print(f"\nVisualizing slice {slice_idx}...")
    show_images(
        real_imgs[slice_idx],
        predicted_results[slice_idx],
        titles=[f"Input Slice {slice_idx}", f"Predicted Slice {slice_idx}"],
    )
else:
    print("No prediction results to analyze or save.")

In [None]:
# Optional: Compare with a previously saved prediction, if one exists.
# NOTE(review): with the default config this path equals PREDICTED_PATH, i.e. the
# output location of the (currently disabled) save in the previous cell.
PREVIOUS_PREDICTED_NAME = "inference_for_real_cube5_extract_data.bin"
PREVIOUS_PREDICTED_PATH = os.path.join(MODEL_PARENT_DIR, PREVIOUS_PREDICTED_NAME)

if os.path.exists(PREVIOUS_PREDICTED_PATH) and predicted_results is not None:
    print(f"\nLoading previous prediction for comparison: {PREVIOUS_PREDICTED_PATH}")
    try:
        # The previous prediction is read with the same loader and cube dims as the input.
        previous_predicted = load_real_binary_file(
            PREVIOUS_PREDICTED_PATH, NX_real, NY_real, NZ_real
        )
        print(
            f"Previous prediction shape: {previous_predicted.shape}, Min: {previous_predicted.min():.2f}, Max: {previous_predicted.max():.2f}"
        )

        if tuple(previous_predicted.shape) != tuple(predicted_results.shape):
            # Elementwise subtraction of mismatched shapes either raises or
            # silently broadcasts into a meaningless MAE, so refuse to compare.
            print(
                f"Shape mismatch: current {tuple(predicted_results.shape)} vs "
                f"previous {tuple(previous_predicted.shape)}. Skipping comparison."
            )
        else:
            # Simple comparison (e.g., Mean Absolute Error)
            mae = torch.mean(torch.abs(predicted_results - previous_predicted))
            print(f"Mean Absolute Error between current and previous prediction: {mae:.4e}")

            # Visualize comparison for the middle slice. The bounds check uses
            # the real first-axis length; the old `< NZ_real` check was always true.
            slice_idx_comp = NZ_real // 2
            if slice_idx_comp < len(predicted_results):
                show_images(
                    predicted_results[slice_idx_comp],
                    previous_predicted[slice_idx_comp],
                    titles=[
                        f"Current Pred {slice_idx_comp}",
                        f"Previous Pred {slice_idx_comp}",
                    ],
                )
            else:
                print(f"Slice index {slice_idx_comp} out of bounds for visualization.")

    except Exception as e:
        # Best-effort optional step: report and continue rather than abort the notebook.
        print(f"Error loading or comparing previous prediction: {e}")
else:
    if not os.path.exists(PREVIOUS_PREDICTED_PATH):
        print(
            f"\nPrevious prediction file not found at {PREVIOUS_PREDICTED_PATH}. Cannot compare."
        )
    if predicted_results is None:
        print("\nCurrent prediction not available. Cannot compare.")