In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt
from ase.io import read as ase_read
from collections import defaultdict

In [None]:
# Full atomic structure (O atoms included); environment indices used below index into it.
input_xyz = "./jp_dio-orig_min4.xyz"
atoms = ase_read(filename=input_xyz)

In [None]:
# Grain pair -> list of O-atom environments (each dict carries "index" and "neighbor_idxs").
with open("./grain_pair_to_envs_mapping_max60pcnt.json", "r") as f:
    grain_pair_mapping = json.load(f)

print(f"Loaded grain pair mapping with {len(grain_pair_mapping)} pairs")
print(f"Total environments: {sum(len(envs) for envs in grain_pair_mapping.values())}")

# Show the five grain pairs with the most environments
print("\nExample grain pairs and counts:")
largest_pairs = sorted(
    grain_pair_mapping.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)[:5]
for pair_key, envs in largest_pairs:
    print(f"  {pair_key}: {len(envs)} environments")

In [None]:
# Load PTM types computed on the O-free structure and project them onto `atoms`.

# noOidx2orig.json: index in the O-free structure -> index in the original structure (`atoms`).
with open("./noOidx2orig.json", "r") as f:
    noO_to_orig = json.load(f)

# Invert the mapping (orig -> noO) so original-structure indices can be looked up directly.
index_map = {int(v): int(k) for k, v in noO_to_orig.items()}
# Two noO indices mapping to the same orig index would silently overwrite each other here.
assert len(index_map) == len(noO_to_orig), "noOidx2orig.json is not a one-to-one mapping"

# Load grain and PTM data; the context manager closes the npz file once the arrays are read.
with np.load("./grains_ptm_111025_min4_fixed.npz") as grain_ptm_data:
    noO_grains = grain_ptm_data["grains"]
    noO_ptm_types = grain_ptm_data["ptm_types"]

# -1 marks atoms without a PTM type; O atoms are skipped below and keep it.
xyz_ptm_types = np.full(len(atoms), -1, dtype=int)

# Copy the PTM type of every non-O atom from its counterpart in the O-free structure.
for i, atm in enumerate(atoms):
    if atm.symbol == "O":
        continue
    xyz_ptm_types[i] = noO_ptm_types[index_map[i]]

In [None]:
# Compute pairwise distances with periodic boundary conditions
# FRAGILE - only valid for orthorhombic cells; anything else is rejected below.
def compute_distances_pbc(atoms, central_idx, neighbor_idxs):
    """
    Compute distances between a central atom and its neighbors,
    accounting for periodic boundary conditions.

    The minimum image convention is applied independently along each
    Cartesian axis, which is only exact for orthorhombic (diagonal) cells.

    Parameters:
    -----------
    atoms : ase.Atoms
        The atomic structure (its ``positions``, ``cell`` and ``pbc`` are used)
    central_idx : int
        Index of the central atom
    neighbor_idxs : iterable of int
        Indices of neighbor atoms

    Returns:
    --------
    dict
        Keys are neighbor indices, values are distances (in Angstroms)

    Raises:
    -------
    ValueError
        If the cell has non-zero off-diagonal components (non-orthorhombic).
    """
    cell = np.asarray(atoms.cell, dtype=float)

    # Checked once per call (not per neighbor), and with an explicit raise so the
    # guard is not stripped when Python runs with -O.
    off_diagonal = cell[~np.eye(3, dtype=bool)]
    if not np.allclose(off_diagonal, 0.0, rtol=0.0, atol=1e-12):
        raise ValueError("compute_distances_pbc only supports orthorhombic cells")

    neighbor_idxs = list(neighbor_idxs)
    if not neighbor_idxs:
        return {}

    # Index the positions array directly instead of building an Atom object per neighbor.
    positions = np.asarray(atoms.positions, dtype=float)
    deltas = positions[neighbor_idxs] - positions[central_idx]

    # Minimum image convention: wrap each periodic component into [-L/2, L/2].
    pbc = np.asarray(atoms.pbc, dtype=bool)
    lengths = np.diag(cell)
    deltas[:, pbc] -= lengths[pbc] * np.round(deltas[:, pbc] / lengths[pbc])

    distances = np.linalg.norm(deltas, axis=1)
    return dict(zip(neighbor_idxs, distances))

print("Distance computation function defined")

In [None]:
# Plot RDF for a single environment

def plot_rdf_single_env(distances, bins=50, r_max=None, figsize=(10, 6), title=None, ax=None):
    """
    Plot a neighbor-distance histogram for a single O atom environment.

    Note: the y-axis is raw counts per bin; it is not normalized by shell
    volume or number density as a true g(r) would be.

    Parameters:
    -----------
    distances : dict or sequence of float
        Dictionary mapping neighbor indices to distances (as returned by
        ``compute_distances_pbc``), or a plain sequence of distances
    bins : int
        Number of bins for histogram (default: 50)
    r_max : float or None
        Upper edge of the histogram range (default: None, uses max distance)
    figsize : tuple
        Figure size (width, height); only used when a new figure is created
    title : str or None
        Plot title (default: 'Radial Distribution Function')
    ax : matplotlib Axes or None
        Axes to draw into. If None (default), a new figure and axes are created.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects

    Raises:
    -------
    ValueError
        If ``distances`` is empty and ``r_max`` is not given.
    """
    # Accept either the {neighbor_idx: distance} dict or a plain sequence.
    if isinstance(distances, dict):
        dist_list = list(distances.values())
    else:
        dist_list = list(distances)

    # Resolve r_max before creating a figure so bad input leaves no orphan figure behind.
    if r_max is None:
        if not dist_list:
            raise ValueError(
                "plot_rdf_single_env: no distances to plot; pass r_max to draw an empty histogram"
            )
        r_max = max(dist_list)

    # Only create (and later lay out) a figure when the caller did not supply axes.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    counts, bin_edges = np.histogram(dist_list, bins=bins, range=(0, r_max))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    ax.plot(bin_centers, counts, linewidth=2)
    ax.fill_between(bin_centers, counts, alpha=0.3)

    ax.set_xlabel('Distance (Å)', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(title if title else 'Radial Distribution Function', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle=':')

    if created_fig:
        fig.tight_layout()
    return fig, ax

In [None]:
# Plot aggregate RDF for all environments in a grain pair set

def plot_rdf_grain_pair(environments_with_distances, bins=50, r_max=None,
                       figsize=(10, 6), grain_pair_label=None, ax=None):
    """
    Plot the pooled neighbor-distance histogram for all environments in a grain pair set.

    Note: the y-axis is raw counts per bin summed over environments; it is not
    normalized by shell volume or number density as a true g(r) would be.

    Parameters:
    -----------
    environments_with_distances : list of dict
        List of environment dicts, each containing a 'distances' key holding
        either a {neighbor_idx: distance} dict or a sequence of distances
    bins : int
        Number of bins for histogram (default: 50)
    r_max : float or None
        Upper edge of the histogram range (default: None, uses max distance)
    figsize : tuple
        Figure size (width, height); only used when a new figure is created
    grain_pair_label : str or None
        Label for the grain pair (e.g., "(2, 5)")
    ax : matplotlib Axes or None
        Axes to draw into. If None (default), a new figure and axes are created.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects

    Raises:
    -------
    ValueError
        If there are no distances at all and ``r_max`` is not given.
    """
    # Pool the distances from every environment.
    all_distances = []
    for env in environments_with_distances:
        env_distances = env['distances']
        if isinstance(env_distances, dict):
            all_distances.extend(env_distances.values())
        else:
            all_distances.extend(env_distances)

    # Resolve r_max before creating a figure so bad input leaves no orphan figure behind.
    if r_max is None:
        if not all_distances:
            raise ValueError(
                "plot_rdf_grain_pair: no distances to plot; pass r_max to draw an empty histogram"
            )
        r_max = max(all_distances)

    # Only create (and later lay out) a figure when the caller did not supply axes.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    counts, bin_edges = np.histogram(all_distances, bins=bins, range=(0, r_max))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    ax.plot(bin_centers, counts, linewidth=2, color='darkblue')
    ax.fill_between(bin_centers, counts, alpha=0.3, color='skyblue')

    ax.set_xlabel('Distance (Å)', fontsize=12)
    ax.set_ylabel('Total Count', fontsize=12)

    n_envs = len(environments_with_distances)
    if grain_pair_label:
        title = f'Aggregate RDF for Grain Pair {grain_pair_label}\n({n_envs} environments)'
    else:
        title = f'Aggregate RDF\n({n_envs} environments)'
    ax.set_title(title, fontsize=14, fontweight='bold')

    ax.grid(True, alpha=0.3, linestyle=':')

    if created_fig:
        fig.tight_layout()
    return fig, ax

print("Grain pair aggregate RDF plotting function defined")

In [None]:
# Sanity-check compute_distances_pbc on a single environment before applying it to all of them.

# Take the first environment of the first grain pair.
first_pair_key = next(iter(grain_pair_mapping))
sample_env = grain_pair_mapping[first_pair_key][0]

print(f"Testing on environment at O atom index: {sample_env['index']}")
print(f"Number of neighbors: {len(sample_env['neighbor_idxs'])}")

# Compute distances
test_distances = compute_distances_pbc(atoms, sample_env['index'], sample_env['neighbor_idxs'])
test_distance_values = np.fromiter(test_distances.values(), dtype=float)

print(f"Computed {len(test_distances)} distances")
print(f"Min distance: {test_distance_values.min():.3f} Å")
print(f"Max distance: {test_distance_values.max():.3f} Å")
print(f"Mean distance: {test_distance_values.mean():.3f} Å")

In [None]:
# Visual check: distance histogram of the sample environment from the cell above.
single_env_title = f"RDF for O atom {sample_env['index']}"
fig1, ax1 = plot_rdf_single_env(
    test_distances,
    bins=50,
    title=single_env_title,
)
plt.show()

In [None]:
# Attach neighbor distances to every environment. This mutates the environment
# dicts inside grain_pair_mapping in place by adding a "distances" key.
for pair_key in grain_pair_mapping:
    print(pair_key)
    envs = grain_pair_mapping[pair_key]
    for env in envs:
        env["distances"] = distances = compute_distances_pbc(
            atoms, env['index'], env['neighbor_idxs']
        )

In [None]:
gpm_keys = list(grain_pair_mapping.keys())
print(f"{len(gpm_keys)} grain pairs")

# Inspect one grain pair in detail. Index 30 is simply the pair chosen for
# inspection; clamp it so the cell still runs on a dataset with fewer pairs.
EXAMPLE_PAIR_INDEX = 30
example_idx = min(EXAMPLE_PAIR_INDEX, len(gpm_keys) - 1)
pair_key = gpm_keys[example_idx]
envs = grain_pair_mapping[pair_key]

fig, ax = plot_rdf_grain_pair(
    envs,
    bins=100,
    r_max=11.0,
    grain_pair_label=pair_key
)
plt.show()

In [None]:
proposed_cutoff_1 = 3.77
fig, axes = plt.subplots(8, 4, figsize=(16, 24))
axes = axes.flatten()
assert len(gpm_keys) <= len(axes), "more grain pairs than subplot panels"

for i in range(len(gpm_keys)):
    pair_key = gpm_keys[i]
    envs = grain_pair_mapping[pair_key]
    fig_tmp, ax_tmp = plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key
    )

    # Artists cannot be moved between axes: Axes children such as spines and axis
    # objects refuse .remove(), and moved plot artists keep ax_tmp's data transform.
    # Copy the plotted histogram data and title onto the target panel instead.
    hist_x, hist_y = ax_tmp.lines[0].get_data()
    panel_title = ax_tmp.get_title()
    plt.close(fig_tmp)

    target_ax = axes[i]
    target_ax.plot(hist_x, hist_y, linewidth=2, color='darkblue')
    target_ax.fill_between(hist_x, hist_y, alpha=0.3, color='skyblue')
    target_ax.axvline(x=proposed_cutoff_1, color='red', linestyle='--', linewidth=2)
    target_ax.set_title(panel_title, fontsize=8)
    target_ax.grid(True, alpha=0.3, linestyle=':')

# Hide panels left over after the last grain pair.
for unused_ax in axes[len(gpm_keys):]:
    unused_ax.axis("off")

fig.tight_layout()

In [None]:
proposed_cutoff_1 = 3.77
n_rows, n_cols = 8, 4

fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(24, 14),  # wider than tall
    sharex=True,
    sharey=True
)

axes = axes.flatten()
assert len(gpm_keys) <= len(axes), "more grain pairs than subplot panels"

for i in range(len(gpm_keys)):
    pair_key = gpm_keys[i]
    envs = grain_pair_mapping[pair_key]
    ax = axes[i]

    # plot_rdf_grain_pair always builds its own figure via plt.subplots, so
    # plt.sca(ax) cannot redirect it. Draw into a throwaway figure, copy the
    # histogram data onto this panel, and close the throwaway figure.
    fig_tmp, ax_tmp = plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key
    )
    hist_x, hist_y = ax_tmp.lines[0].get_data()
    panel_title = ax_tmp.get_title()
    plt.close(fig_tmp)

    ax.plot(hist_x, hist_y, linewidth=2, color='darkblue')
    ax.fill_between(hist_x, hist_y, alpha=0.3, color='skyblue')
    ax.set_title(panel_title, fontsize=8)
    ax.grid(True, alpha=0.3, linestyle=':')

    ax.axvline(
        x=proposed_cutoff_1,
        color='red',
        linestyle='--',
        linewidth=1.5
    )

# Hide panels left over after the last grain pair.
for unused_ax in axes[len(gpm_keys):]:
    unused_ax.axis("off")

In [None]:
n_rows, n_cols = 8, 4
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(28, 12)
)
axes = axes.flatten()
assert len(gpm_keys) <= len(axes), "more grain pairs than subplot panels"

for i in range(len(gpm_keys)):
    pair_key = gpm_keys[i]
    envs = grain_pair_mapping[pair_key]

    # Let the function draw on its own figure (proposed_cutoff_1 comes from the cells above).
    fig_tmp, ax_tmp = plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key
    )

    ax_tmp.axvline(
        x=proposed_cutoff_1,
        color="red",
        linestyle="--",
        linewidth=1.5
    )

    # --- render the temporary figure to an RGBA array ---
    fig_tmp.canvas.draw()
    renderer = fig_tmp.canvas.get_renderer()
    img = np.asarray(fig_tmp.canvas.buffer_rgba())
    img_h, img_w = img.shape[:2]

    # --- tight bbox of the axes (ticks, labels, title), in display pixels ---
    # Display coordinates have their origin at the bottom-left, but image rows
    # start at the top, so the vertical range must be flipped before slicing.
    bbox = ax_tmp.get_tightbbox(renderer)
    col0 = max(int(np.floor(bbox.x0)), 0)
    col1 = min(int(np.ceil(bbox.x1)), img_w)
    row0 = max(int(np.floor(img_h - bbox.y1)), 0)
    row1 = min(int(np.ceil(img_h - bbox.y0)), img_h)

    # Copy so the crop does not stay a view onto the temporary figure's render buffer.
    cropped = img[row0:row1, col0:col1].copy()

    # --- display cropped axes ---
    axes[i].imshow(cropped)
    axes[i].axis("off")

    plt.close(fig_tmp)

# Hide panels left over after the last grain pair.
for unused_ax in axes[len(gpm_keys):]:
    unused_ax.axis("off")

fig.subplots_adjust(wspace=0.05, hspace=0.1)

plt.show()

In [None]:
def updated_plot_rdf_grain_pair(
    environments_with_distances,
    bins=50,
    r_max=None,
    figsize=(10, 6),
    grain_pair_label=None,
    ax=None,
):
    """
    Plot the pooled neighbor-distance histogram for all environments in a grain pair set.

    When ``ax`` is given, the plot is drawn in a compact style suited to subplot
    grids: the title is reduced to the grain pair label in a small font, and no
    per-panel axis labels are set.

    Parameters
    ----------
    environments_with_distances : list of dict
        Environment dicts, each with a 'distances' key holding either a
        {neighbor_idx: distance} dict or a sequence of distances.
    bins : int
        Number of histogram bins (default: 50).
    r_max : float or None
        Upper edge of the histogram range (default: None, uses max distance).
    figsize : tuple
        Figure size; only used when a new figure is created.
    grain_pair_label : str or None
        Label for the grain pair (e.g., "(2, 5)").
    ax : matplotlib Axes or None
        Axes to draw into. If None, a new figure and axes are created.

    Returns
    -------
    fig, ax : matplotlib figure and axes objects

    Raises
    ------
    ValueError
        If there are no distances at all and ``r_max`` is not given.
    """
    # Pool the distances from every environment.
    all_distances = []
    for env in environments_with_distances:
        env_distances = env["distances"]
        if isinstance(env_distances, dict):
            all_distances.extend(env_distances.values())
        else:
            all_distances.extend(env_distances)

    # Resolve r_max before creating a figure so bad input leaves no orphan figure behind.
    if r_max is None:
        if not all_distances:
            raise ValueError(
                "updated_plot_rdf_grain_pair: no distances to plot; pass r_max to draw an empty histogram"
            )
        r_max = max(all_distances)

    # Record ownership up front: `ax` is reassigned below, so checking `ax is None`
    # later would always be False and the figure would never be laid out.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
    compact = not created_fig

    # Histogram
    counts, bin_edges = np.histogram(
        all_distances, bins=bins, range=(0, r_max)
    )
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Plot
    ax.plot(bin_centers, counts, linewidth=2, color="darkblue")
    ax.fill_between(bin_centers, counts, alpha=0.3, color="skyblue")

    # Per-panel axis labels only for standalone figures; grids use shared labels.
    if not compact:
        ax.set_xlabel("Distance (Å)", fontsize=10)
        ax.set_ylabel("Total Count", fontsize=10)

    # Title
    n_envs = len(environments_with_distances)
    if grain_pair_label and compact:
        title = f"{grain_pair_label}"
    elif grain_pair_label:
        title = (
            f"Aggregate RDF for Grain Pair {grain_pair_label}\n"
            f"({n_envs} environments)"
        )
    else:
        title = f"Aggregate RDF\n({n_envs} environments)"
    ax.set_title(title, fontsize=8 if compact else 11, fontweight="bold")

    ax.grid(True, alpha=0.3, linestyle=":")

    # Only auto-layout a figure this function created.
    if created_fig:
        fig.tight_layout()

    return fig, ax

In [None]:
# First-shell cutoff marked on every panel (earlier candidates tried: 3.77, 3.6).
FIRST_SHELL_CUTOFF = 2.6

# 4 columns and as many rows as needed (8 rows for up to 32 grain pairs).
n_cols = 4
n_rows = max(1, -(-len(gpm_keys) // n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(28, 12),
    sharex=True,
    sharey=False,
)

axes = axes.flatten()

for ax, pair_key in zip(axes, gpm_keys):
    envs = grain_pair_mapping[pair_key]

    # Passing ax selects the compact style (small title, no per-panel axis labels).
    updated_plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key,
        ax=ax,
    )

    # First-shell cutoff
    ax.axvline(
        x=FIRST_SHELL_CUTOFF,
        color="red",
        linestyle="--",
        linewidth=1.2,
    )

    # Reduce tick clutter
    ax.tick_params(labelsize=8)

# Hide panels left over after the last grain pair.
for unused_ax in axes[len(gpm_keys):]:
    unused_ax.axis("off")

# Shared labels
fig.supxlabel("Distance (Å)", fontsize=14)
fig.supylabel("Total Count", fontsize=14)

fig.subplots_adjust(
    wspace=0.25,
    hspace=0.35,
)

plt.show()