In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt
from ase.io import read as ase_read
from collections import defaultdict

In [None]:
# Path of the structure file analysed throughout this notebook.
input_xyz = "./jp_dio-orig_min4.xyz"
# Parse the file with ASE's reader (imported above as `ase_read`); the cells
# below read per-atom symbols and positions plus the cell and pbc from `atoms`.
atoms = ase_read(input_xyz)

In [None]:
# Load the grain-pair -> environments mapping. The encoding is stated
# explicitly so the read does not depend on the platform's locale default.
with open("./grain_pair_to_envs_mapping_max60pcnt.json", "r", encoding="utf-8") as f:
    grain_pair_mapping = json.load(f)

print(f"Loaded grain pair mapping with {len(grain_pair_mapping)} pairs")
print(f"Total environments: {sum(len(envs) for envs in grain_pair_mapping.values())}")

# Show the (up to) five grain pairs with the most environments, largest first.
# No enumerate() here: the index was never used and only leaked a stale `i`
# into the notebook namespace.
print("\nExample grain pairs and counts:")
for pair_key, envs in sorted(
    grain_pair_mapping.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)[:5]:
    print(f"  {pair_key}: {len(envs)} environments")

In [None]:
# Bring in the PTM / grain labels and spread the PTM labels over the full structure.

# Index bookkeeping: the JSON file is read as-is first ...
with open("./noOidx2orig.json", "r") as f:
    index_map = json.load(f)

# ... then flipped (orig -> noO), casting both sides to int on the way.
index_map = dict(
    (int(orig_idx), int(noO_idx)) for noO_idx, orig_idx in index_map.items()
)

# Arrays stored in the archive under "grains" and "ptm_types".
grain_ptm_data = np.load("./grains_ptm_111025_min4_fixed.npz")
noO_grains = grain_ptm_data["grains"]
noO_ptm_types = grain_ptm_data["ptm_types"]

# One slot per atom of `atoms`; -1 is the value left in slots never assigned below.
xyz_ptm_types = np.full(shape=len(atoms), fill_value=-1, dtype=int)

# Every atom whose symbol is not "O" takes the label of its mapped noO index.
for i, atm in enumerate(atoms):
    if atm.symbol != "O":
        xyz_ptm_types[i] = noO_ptm_types[index_map[i]]

In [None]:
# Compute pairwise distances with periodic boundary conditions.
# NOTE: only valid for orthorhombic cells; any other cell is rejected up front.
def compute_distances_pbc(atoms, central_idx, neighbor_idxs):
    """
    Compute distances between a central atom and its neighbors,
    accounting for periodic boundary conditions.

    The minimum-image convention is applied independently along each
    periodic axis, which is only correct for an orthorhombic cell. The
    cell is therefore validated once per call before any distance is
    computed.

    Parameters:
    -----------
    atoms : ase.Atoms
        The atomic structure; its ``positions``, ``cell`` and ``pbc``
        attributes are read.
    central_idx : int
        Index of the central atom
    neighbor_idxs : list of int
        Indices of neighbor atoms

    Returns:
    --------
    dict
        Keys are the entries of ``neighbor_idxs`` (in the given order),
        values are distances in the units of ``atoms.positions``
        (Angstroms for ASE structures). An empty ``neighbor_idxs`` gives
        an empty dict.

    Raises:
    -------
    AssertionError
        If any off-diagonal cell element differs from zero by more than
        1e-12 (raised through ``assert``, as before).
    """
    neighbor_idxs = list(neighbor_idxs)
    if not neighbor_idxs:
        # Nothing to measure; the cell is not inspected in this case.
        return {}

    cell = atoms.cell

    # Orthorhombic check, done once instead of once per neighbor.
    assert all(
        np.isclose(cell[i, j], 0.0, atol=1e-12)
        for i in range(3)
        for j in range(3)
        if i != j
    ), "compute_distances_pbc only supports orthorhombic cells"

    positions = np.asarray(atoms.positions, dtype=float)

    # Displacement vectors for all neighbors at once. Fancy indexing returns
    # a fresh array, so the in-place wrap below never touches atoms.positions.
    delta = positions[neighbor_idxs] - positions[central_idx]

    # Minimum image convention: wrap each periodic component to [-L/2, L/2].
    for axis in range(3):
        cell_length = cell[axis, axis]
        # A zero-length axis cannot be wrapped (it would divide by zero and
        # turn every distance into NaN), so it is left untouched.
        if atoms.pbc[axis] and cell_length != 0.0:
            delta[:, axis] -= cell_length * np.round(delta[:, axis] / cell_length)

    return dict(zip(neighbor_idxs, np.linalg.norm(delta, axis=1)))

print("Distance computation function defined")

In [None]:
# Plot RDF for a single environment

def plot_rdf_single_env(distances, bins=50, r_max=None, figsize=(10, 6), title=None):
    """
    Draw a histogram-based radial distribution plot for one O atom environment.

    Parameters:
    -----------
    distances : dict or list
        Either a mapping of neighbor index -> distance (only the values
        are used) or a plain list of distances
    bins : int
        How many histogram bins to use (default: 50)
    r_max : float or None
        Upper end of the histogram range; when None the largest distance
        is used
    figsize : tuple
        (width, height) handed to ``plt.subplots``
    title : str or None
        Title of the plot; a generic title is used when falsy

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    # A dict contributes its values; anything else is taken as-is.
    dist_list = list(distances.values()) if isinstance(distances, dict) else distances

    fig, ax = plt.subplots(figsize=figsize)

    # Without an explicit limit the histogram ends at the largest distance.
    upper = max(dist_list) if r_max is None else r_max

    counts, bin_edges = np.histogram(dist_list, bins=bins, range=(0, upper))
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # Line on top of a translucent fill.
    ax.plot(bin_centers, counts, linewidth=2)
    ax.fill_between(bin_centers, counts, alpha=0.3)

    ax.set_xlabel('Distance (Å)', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    heading = title if title else 'Radial Distribution Function'
    ax.set_title(heading, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle=':')

    plt.tight_layout()
    return fig, ax

In [None]:
def plot_rdf_grain_pair(
    environments_with_distances,
    bins=50,
    r_max=None,
    figsize=(10, 6),
    grain_pair_label=None,
    ax=None,
    exclude_O = False,
    exclude_Hf = False
):
    """
    Plot aggregate RDF for all environments in a grain pair set.

    Parameters:
    -----------
    environments_with_distances : sequence of dict
        Each entry needs a "distances" dict (only its values are used).
        When ``exclude_O`` or ``exclude_Hf`` is set it also needs an
        "Hf_mask" entry, a boolean sequence aligned with the order of the
        "distances" values.
    bins : int
        Number of histogram bins (default: 50)
    r_max : float or None
        Upper end of the histogram range; when None the largest selected
        distance is used
    figsize : tuple
        Figure size, only used when ``ax`` is None
    grain_pair_label : str or None
        Label shown in the title
    ax : matplotlib axes or None
        Axes to draw into. When given, the plot is rendered compactly
        (label-only title in a small font, no axis labels) and the
        figure layout is left alone. When None, a new figure is created.
    exclude_O : bool
        Keep only the distances flagged True in "Hf_mask"
    exclude_Hf : bool
        Keep only the distances flagged False in "Hf_mask"

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects

    Raises:
    -------
    AssertionError
        If both ``exclude_O`` and ``exclude_Hf`` are set.
    ValueError
        If ``r_max`` is None and there is no distance left to derive it from.
    """
    assert not (exclude_O and exclude_Hf), (
        "exclude_O and exclude_Hf are mutually exclusive"
    )

    # Collect one array per environment and concatenate once at the end;
    # growing an array inside the loop copies it on every iteration.
    selected = []
    for env in environments_with_distances:
        # Strongly assuming env["distances"] values is the same order as neighbor_idxs
        env_distances = np.asarray(list(env["distances"].values()), dtype=float)
        if exclude_O or exclude_Hf:
            # Coerce to bool so list masks and empty masks index correctly.
            hf_mask = np.asarray(env["Hf_mask"], dtype=bool)
            env_distances = env_distances[hf_mask if exclude_O else ~hf_mask]
        selected.append(env_distances)
    all_distances = np.concatenate(selected) if selected else np.array([], dtype=float)

    # Resolve r_max before any figure is created so a failure leaves no
    # half-built figure behind.
    if r_max is None:
        if all_distances.size == 0:
            raise ValueError(
                "no distances to plot; pass r_max explicitly to draw an empty RDF"
            )
        r_max = all_distances.max()

    # --- figure / axes handling ---
    # Remember whether the figure is ours: `ax` is rebound below, so testing
    # `ax is None` later would always be False.
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
    small_title = not created_figure

    # Histogram
    counts, bin_edges = np.histogram(
        all_distances, bins=bins, range=(0, r_max)
    )
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Plot
    ax.plot(bin_centers, counts, linewidth=2, color="darkblue")
    ax.fill_between(bin_centers, counts, alpha=0.3, color="skyblue")

    # Labels (skipped in compact mode to keep grid panels uncluttered)
    if not small_title:
        ax.set_xlabel("Distance (Å)", fontsize=10)
        ax.set_ylabel("Total Count", fontsize=10)

    # Title
    if grain_pair_label:
        if not small_title:
            title = (
                f"Aggregate RDF for Grain Pair {grain_pair_label}\n"
                f"({len(environments_with_distances)} environments)"
            )
        else:
            title = f"{grain_pair_label}"
    else:
        title = f"Aggregate RDF\n({len(environments_with_distances)} environments)"

    ax.set_title(title, fontsize=8 if small_title else 11, fontweight="bold")

    ax.grid(True, alpha=0.3, linestyle=":")

    # Only auto-layout if we created the figure
    if created_figure:
        fig.tight_layout()

    return fig, ax

In [None]:
# Smoke test: run the distance computation on a single environment.

# Sample = first environment listed under the first grain pair.
first_pair_key = list(grain_pair_mapping)[0]
sample_env = grain_pair_mapping[first_pair_key][0]

print(f"Testing on environment at O atom index: {sample_env['index']}")
print(f"Number of neighbors: {len(sample_env['neighbor_idxs'])}")

# Distances from the sample's central atom to each of its listed neighbors.
test_distances = compute_distances_pbc(
    atoms,
    central_idx=sample_env['index'],
    neighbor_idxs=sample_env['neighbor_idxs'],
)

# Summary statistics over the computed distances.
print(f"Computed {len(test_distances)} distances")
print(f"Min distance: {min(test_distances.values()):.3f} Å")
print(f"Max distance: {max(test_distances.values()):.3f} Å")
print(f"Mean distance: {np.mean([*test_distances.values()]):.3f} Å")

In [None]:
# Quick look at the chemical symbol of the atom at index 1 (shown as cell output).
atoms[1].symbol

In [None]:
def compute_Hf_mask(env, atoms):
    """
    Flag which neighbors of an environment are Hf atoms.

    Parameters:
    -----------
    env : dict
        Environment record; only its "neighbor_idxs" entry is read
    atoms : ase.Atoms
        The atomic structure the indices refer to

    Returns:
    --------
    numpy.ndarray of bool
        One entry per index in ``env["neighbor_idxs"]`` (same order),
        True where that atom's symbol is "Hf". The dtype is forced to
        bool so that an empty neighbor list still yields a usable
        boolean mask (NumPy would otherwise create an empty float array,
        which cannot be inverted or used for boolean indexing).
    """
    indices = env["neighbor_idxs"]
    return np.array([atoms[i].symbol == "Hf" for i in indices], dtype=bool)

In [None]:
# Attach, to every environment of every grain pair, its neighbor distances
# and the mask telling which of those neighbors are Hf.
for pair_key, envs in grain_pair_mapping.items():
    print(pair_key)
    for env in envs:
        # Both quantities are computed first, then stored on the env together.
        distances, Hf_mask = (
            compute_distances_pbc(atoms, env['index'], env['neighbor_idxs']),
            compute_Hf_mask(env, atoms),
        )
        env.update({"distances": distances, "Hf_mask": Hf_mask})

In [None]:
# Sanity check: the keys of each env's "distances" dict must come in the same
# order as its "neighbor_idxs" list.
for pair_key, envs in grain_pair_mapping.items():
    print(pair_key)
    count = 0
    for env in envs:
        neigh_idxs = env["neighbor_idxs"]
        distances_keys = list(map(int, env["distances"]))
        assert distances_keys == neigh_idxs
        # Echo only the first environment of each pair as a visual spot check.
        if not count:
            print(f"dkeys: {distances_keys} \nnidxs: {neigh_idxs}")
        count += 1

In [None]:
# Inspect the sample environment. The commented-out lines are alternative
# look-ups kept for reference; only the "distances" entry is displayed.
#sample_env["o_count"]
#sample_env["Hf_mask"]
#len([i for i in range(len(sample_env["Hf_mask"])) if sample_env["Hf_mask"][i]==False])
sample_env["distances"]

In [None]:
gpm_keys = list(grain_pair_mapping.keys())
# A bare `len(...)` in the middle of a cell is evaluated and thrown away
# (only a cell's last expression is displayed), so print the count instead.
print(f"Number of grain pairs: {len(gpm_keys)}")

# Position in gpm_keys of the grain pair to plot.
i = 1
pair_key = gpm_keys[i]
envs = grain_pair_mapping[pair_key]

# Aggregate RDF of that pair, restricted to the non-Hf neighbors.
fig, ax = plot_rdf_grain_pair(
    envs,
    bins=100,
    r_max=11.0,
    grain_pair_label=pair_key,
    #exclude_O=True,
    exclude_Hf=True
)

In [None]:
FIRST_SHELL_CUTOFF = 3.6
#FIRST_SHELL_CUTOFF = 2.6

# Size the grid from the number of grain pairs instead of assuming exactly 31:
# 4 columns and as many rows as needed (ceiling division), 1.5 in of height per
# row. For 31 pairs this reproduces the former 8 x 4 grid at (28, 12).
n_cols = 4
n_rows = max(1, -(-len(gpm_keys) // n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(28, 1.5 * n_rows),
    sharex=True,
    sharey=False,
    squeeze=False,  # always a 2-D array, whatever the grid shape
)

axes = axes.flatten()

# One panel per grain pair, showing only the Hf neighbors (exclude_O=True).
for pair_key, ax in zip(gpm_keys, axes):
    envs = grain_pair_mapping[pair_key]

    plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key,
        ax=ax,
        exclude_O=True
    )

    # First-shell cutoff
    ax.axvline(
        x=FIRST_SHELL_CUTOFF,
        color="red",
        linestyle="--",
        linewidth=1.2,
    )

    # Reduce tick clutter
    ax.tick_params(labelsize=8)

# Hide every panel that did not receive a grain pair (none when the grid is
# exactly full, so a real plot is never switched off).
for unused_ax in axes[len(gpm_keys):]:
    unused_ax.axis("off")

# Shared labels
#fig.supxlabel("Distance (Å)", fontsize=14)
#fig.supylabel("Total Count", fontsize=14)

fig.subplots_adjust(
    wspace=0.25,
    hspace=0.35,
)

plt.show()