In [None]:
# In your Jupyter Notebook (analysis/01_compare_decomposition.ipynb)

import openmc
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# --- Add project root to path ---
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# --------------------------------

# --- Import your new plotting function ---
from analysis.common_plotting import plot_phir_slice

# --- 1. Set paths, tally name, and what to plot ---
SP_FILE = '../data/run_full_source_flat/statepoint.10.h5'
TALLY_NAME = 'cyl_tally'  # <-- Make sure this matches your tally name in inputs.py
Z_SLICE_INDEX = 1  # Which Z-bin to plot (0-based, so 1 is the second z-bin)
SCORE = 'flux'     # Score to plot when the tally has more than one score
NUCLIDE = 'total'  # Nuclide to plot when the tally has more than one nuclide


def _select_index(available, wanted, kind):
    """Return the position of ``wanted`` in ``available``.

    A single-entry list is accepted regardless of its name, so a tally with
    exactly one score/nuclide always works.

    Raises:
        ValueError: if there are several entries and ``wanted`` is not one.
    """
    available = list(available)
    if len(available) == 1:
        return 0
    if wanted in available:
        return available.index(wanted)
    raise ValueError(
        f"Tally '{TALLY_NAME}' has {kind}s {available}; "
        f"set {kind.upper()} to one of them."
    )


def _extract_mesh_arrays(tally, mesh_filter, score, nuclide):
    """Return ``(mean, std_dev)`` on the mesh, each shaped ``mesh.dimension``.

    ``tally.mean`` holds every (filter-bin, nuclide, score) combination, so it
    is generally larger than the number of mesh cells. For example, a 40x95x3
    mesh with one extra 2-bin axis gives 22800 values. This helper:

    * keeps one nuclide and one score,
    * sums over any other filters,
    * reshapes the mesh axis with Fortran order.

    Fortran order is needed because OpenMC numbers mesh bins with the first
    (r) index varying fastest. The result is then indexed [r, phi, z], like
    ``mesh.volumes``.
    """
    # One axis per filter (in tally.filters order), then nuclide, then score.
    mean = tally.get_reshaped_data(value='mean')
    std = tally.get_reshaped_data(value='std_dev')

    mesh_axis = tally.filters.index(mesh_filter)
    nuc_idx = _select_index(tally.nuclides, nuclide, 'nuclide')
    score_idx = _select_index(tally.scores, score, 'score')

    # Keep the chosen nuclide/score, then move the mesh axis to the front so
    # every remaining axis belongs to some other filter.
    mean = np.moveaxis(mean[..., nuc_idx, score_idx], mesh_axis, 0)
    var = np.moveaxis(std[..., nuc_idx, score_idx] ** 2, mesh_axis, 0)

    others = [f for f in tally.filters if f is not mesh_filter and f.num_bins > 1]
    if others:
        print(f"Summing over additional filters: {[type(f).__name__ for f in others]}")

    n_mesh = mean.shape[0]
    mean = mean.reshape(n_mesh, -1).sum(axis=1)
    # NOTE: adding variances assumes the summed bins are uncorrelated, so the
    # combined std_dev is an approximation when other filters are present.
    std = np.sqrt(var.reshape(n_mesh, -1).sum(axis=1))

    shape = tuple(mesh_filter.mesh.dimension)
    return mean.reshape(shape, order='F'), std.reshape(shape, order='F')


# --- 2. Load data from the StatePoint ---
# The `with` block closes the file on every path. Errors are re-raised so
# the cell stops instead of failing later with a NameError.
with openmc.StatePoint(SP_FILE) as sp:
    try:
        tally = sp.get_tally(name=TALLY_NAME)
    except LookupError as err:  # get_tally signals "not found" with LookupError
        raise LookupError(f"Error: Tally '{TALLY_NAME}' not found.") from err

    # --- 3. Get the cylindrical mesh filter ---
    try:
        mesh_filter = tally.find_filter(openmc.MeshFilter)
    except ValueError as err:
        raise ValueError(
            f"Error: Tally '{TALLY_NAME}' does not have the correct mesh filter."
        ) from err
    if not isinstance(mesh_filter.mesh, openmc.CylindricalMesh):
        raise TypeError(f"Tally '{TALLY_NAME}' mesh is not CylindricalMesh.")

    # --- 4. Extract data, grids, and volumes ---
    mesh = mesh_filter.mesh

    # Grid edges
    r_grid = mesh.r_grid
    phi_grid = mesh.phi_grid
    z_grid = mesh.z_grid

    # Mean and std_dev, each shaped (num_r_bins, num_phi_bins, num_z_bins)
    flux_mean, flux_std_dev = _extract_mesh_arrays(tally, mesh_filter, SCORE, NUCLIDE)

    # Volume of each mesh bin, same [r, phi, z] indexing as the data
    volumes = mesh.volumes

# --- 5. Call your plotting function ---
# Validate the requested z-bin against the actual mesh rather than a
# hard-coded z_grid. Negative indices count from the end, as in numpy.
n_z_bins = len(z_grid) - 1
try:
    z_bin = range(n_z_bins)[Z_SLICE_INDEX]
except IndexError:
    raise IndexError(
        f"Z_SLICE_INDEX={Z_SLICE_INDEX} is out of range: the mesh has "
        f"{n_z_bins} z-bins (valid indices 0..{n_z_bins - 1})."
    ) from None

print(f"Plotting Z-slice index: {Z_SLICE_INDEX}")
print(f"  z-bin {z_bin} spans z = {z_grid[z_bin]} to {z_grid[z_bin + 1]}")

ax = plot_phir_slice(
    tally_data=flux_mean,  # 3D mean data, shape (num_r, num_phi, num_z)
    volumes=volumes,       # 3D bin volumes, same shape
    phi_grid=phi_grid,     # phi bin edges
    r_grid=r_grid,         # r bin edges
    slice_index=Z_SLICE_INDEX,
)

ax.set_title(f"Full Run Flux (Z-slice {Z_SLICE_INDEX})")
plt.show()

ValueError: cannot reshape array of size 22800 into shape (40,95,3)