In [None]:
import ast
import copy
import json
import math
from collections import defaultdict

import numpy as np
from ase.io import read as ase_read
from matplotlib import pyplot as plt

In [None]:
# Structure file to analyse; it is read with ase.io.read (imported as ase_read)
# and the result is used as `atoms` by the rest of the notebook.
input_xyz = "./jp_dio-orig_min4.xyz"
atoms = ase_read(input_xyz)

In [None]:
# Read the grain-pair -> list-of-environments mapping from disk.
with open("./grain_pair_to_envs_mapping_max60pcnt.json", "r") as f:
    grain_pair_mapping = json.load(f)

# Summary counts: number of grain pairs and total environments over all pairs.
n_pairs = len(grain_pair_mapping)
n_envs_total = sum(map(len, grain_pair_mapping.values()))
print(f"Loaded grain pair mapping with {n_pairs} pairs")
print(f"Total environments: {n_envs_total}")

# Preview the five grain pairs holding the most environments.
print("\nExample grain pairs and counts:")
pairs_by_size = sorted(
    grain_pair_mapping.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)
for pair_key, envs in pairs_by_size[:5]:
    print(f"  {pair_key}: {len(envs)} environments")

In [None]:
# Load PTM and grain labels

# Load index mapping. Judging by the file name, the file stores
# noO index -> original index; JSON keys arrive as strings.
with open("./noOidx2orig.json", "r") as f:
    noO_to_orig = json.load(f)

# Reverse the mapping (orig -> noO), casting both sides to int.
# A separate name is used for the raw mapping so `index_map` only ever
# means "original index -> noO index".
index_map = {int(orig_idx): int(noO_idx) for noO_idx, orig_idx in noO_to_orig.items()}

# Load grain and PTM data. np.load on an .npz returns a lazily-read archive
# that holds an open file handle; the context manager closes it once the two
# arrays have been materialised.
with np.load("./grains_ptm_111025_min4_fixed.npz") as grain_ptm_data:
    noO_grains = grain_ptm_data["grains"]
    noO_ptm_types = grain_ptm_data["ptm_types"]

# One PTM label per atom of the full structure; -1 is the placeholder kept
# for atoms that are skipped below (the O atoms).
xyz_ptm_types = np.full(len(atoms), -1, dtype=int)

# Copy the label of every non-O atom over from the O-free indexing.
for i, atm in enumerate(atoms):
    if atm.symbol == "O":
        continue
    xyz_ptm_types[i] = noO_ptm_types[index_map[i]]

In [None]:
# Compute pairwise distances with periodic boundary conditions
# NOTE: only valid for orthorhombic cells (verified once per call)
def compute_distances_pbc(atoms, central_idx, neighbor_idxs):
    """
    Compute distances between a central atom and its neighbors,
    accounting for periodic boundary conditions (minimum image convention).

    Only orthorhombic cells are supported: every off-diagonal cell component
    must be zero (to within 1e-12).

    Parameters:
    -----------
    atoms : ase.Atoms
        The atomic structure. Must provide ``atoms[i].position``,
        ``atoms.cell`` (3x3) and ``atoms.pbc`` (3 flags).
    central_idx : int
        Index of the central atom
    neighbor_idxs : list of int
        Indices of neighbor atoms

    Returns:
    --------
    dict
        Keys are neighbor indices (in the order given), values are distances
        in the length unit of the positions. An empty ``neighbor_idxs`` gives
        an empty dict.

    Raises:
    -------
    AssertionError
        If the cell is not orthorhombic.
    """
    cell = np.asarray(atoms.cell, dtype=float)

    # The cell does not change between neighbours, so the orthorhombic check
    # is done once here instead of once per neighbour.
    off_diagonal = cell - np.diag(np.diag(cell))
    assert np.allclose(off_diagonal, 0.0, rtol=0.0, atol=1e-12), \
        "compute_distances_pbc only supports orthorhombic cells"

    cell_lengths = np.diag(cell)
    periodic = np.asarray(atoms.pbc, dtype=bool)
    central_pos = np.asarray(atoms[central_idx].position, dtype=float)

    distances = {}
    for nidx in neighbor_idxs:
        # The subtraction creates a new array, so the stored positions are
        # never modified by the wrapping below.
        delta = np.asarray(atoms[nidx].position, dtype=float) - central_pos

        # Minimum image convention along the periodic axes only:
        # wrap each component into [-L/2, L/2].
        delta[periodic] -= cell_lengths[periodic] * np.round(
            delta[periodic] / cell_lengths[periodic]
        )

        distances[nidx] = np.linalg.norm(delta)

    return distances

print("Distance computation function defined")

In [None]:
def plot_rdf_grain_pair(
    environments_with_distances,
    bins=50,
    r_max=None,
    figsize=(10, 6),
    grain_pair_label=None,
    ax=None,
    exclude_O = False,
    exclude_Hf = False
):
    """
    Plot aggregate RDF for all environments in a grain pair set.
    """

    # Collect all distances from all environments
    all_distances = np.array([])
    for env in environments_with_distances:
        # Strongly assuming env["distances"] values is the same order as neighbor_idxs
        env_distances = np.array(list(env["distances"].values()))
        if exclude_O:
            assert not exclude_Hf
            hf_mask = env["Hf_mask"]
            all_distances = np.concatenate((all_distances, env_distances[hf_mask]))
        elif exclude_Hf:
            assert not exclude_O
            o_mask = np.invert(env["Hf_mask"])
            all_distances = np.concatenate((all_distances, env_distances[o_mask]))
            #all_distances.extend(env_distances[o_mask])
        else:
            all_distances = np.concatenate((all_distances, env_distances))
            #all_distances.extend(env_distances)

    # --- figure / axes handling ---
    small_title = False
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
        small_title = True


    # Set r_max
    if r_max is None:
        r_max = max(all_distances)

    # Histogram
    counts, bin_edges = np.histogram(
        all_distances, bins=bins, range=(0, r_max)
    )
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Plot
    ax.plot(bin_centers, counts, linewidth=2, color="darkblue")
    ax.fill_between(bin_centers, counts, alpha=0.3, color="skyblue")

    # Labels
    if not small_title:
        ax.set_xlabel("Distance (Å)", fontsize=10)
        ax.set_ylabel("Total Count", fontsize=10)

    # Title
    if grain_pair_label:
        if not small_title:
            title = (
                f"Aggregate RDF for Grain Pair {grain_pair_label}\n"
                f"({len(environments_with_distances)} environments)"
            )
        else:
            title= f"{grain_pair_label}"
    else:
        title = f"Aggregate RDF\n({len(environments_with_distances)} environments)"

    if not small_title:
        ax.set_title(title, fontsize=11, fontweight="bold")
    else:
        ax.set_title(title, fontsize=8, fontweight="bold")


    ax.grid(True, alpha=0.3, linestyle=":")

    # Only auto-layout if we created the figure
    if ax is None:
        plt.tight_layout()

    return fig, ax

In [None]:
# Smoke-test compute_distances_pbc on a single environment.

# Use the first environment stored under the first grain-pair key.
first_pair_key = list(grain_pair_mapping)[0]
sample_env = grain_pair_mapping[first_pair_key][0]

center_idx = sample_env['index']
nbr_idxs = sample_env['neighbor_idxs']
print(f"Testing on environment at O atom index: {center_idx}")
print(f"Number of neighbors: {len(nbr_idxs)}")

# Distances from the central atom to each listed neighbour.
test_distances = compute_distances_pbc(atoms, center_idx, nbr_idxs)
dist_values = list(test_distances.values())

print(f"Computed {len(test_distances)} distances")
print(f"Min distance: {min(dist_values):.3f} Å")
print(f"Max distance: {max(dist_values):.3f} Å")
print(f"Mean distance: {np.mean(dist_values):.3f} Å")

In [None]:
def compute_Hf_mask(env, atoms):
    """
    Build a boolean mask marking which neighbours of an environment are Hf.

    Parameters
    ----------
    env : dict
        Environment dict; only env["neighbor_idxs"] is read.
    atoms : indexable
        Structure whose items expose a ``symbol`` attribute.

    Returns
    -------
    numpy.ndarray of bool
        One entry per neighbour index, in the same order; True where the
        neighbour's symbol is "Hf".
    """
    indices = env["neighbor_idxs"]
    # dtype=bool matters for an empty neighbour list: np.array([]) defaults to
    # float64, which can be used neither as a boolean index nor with np.invert.
    return np.array([atoms[i].symbol == "Hf" for i in indices], dtype=bool)

In [None]:
# Attach to every environment (in place) its neighbour distances and the
# mask flagging which neighbours are Hf.
for pair_key, envs in grain_pair_mapping.items():
    print(pair_key)
    for env in envs:
        env.update(
            distances=compute_distances_pbc(atoms, env['index'], env['neighbor_idxs']),
            Hf_mask=compute_Hf_mask(env, atoms),
        )

In [None]:
# Grain-pair keys in mapping order; the cell displays how many there are.
gpm_keys = list(grain_pair_mapping)
len(gpm_keys)

In [None]:
# Grid of Hf-only aggregate RDFs, one panel per grain pair, with the
# first-shell cutoff marked in each panel.
FIRST_SHELL_CUTOFF = 3.6
#FIRST_SHELL_CUTOFF = 2.6

# Size the grid from the number of grain pairs instead of assuming exactly 31
# of them: with a fixed 8x4 grid, more pairs raise IndexError and fewer leave
# empty panels visible.
n_cols = 4
n_rows = max(1, math.ceil(len(gpm_keys) / n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(28, 1.5 * n_rows),  # 1.5 per row reproduces (28, 12) for 8 rows
    sharex=True,
    sharey=False,
    squeeze=False,  # always a 2-D array, even for a single row
)

axes = axes.flatten()

for ax, pair_key in zip(axes, gpm_keys):
    envs = grain_pair_mapping[pair_key]

    plot_rdf_grain_pair(
        envs,
        bins=100,
        r_max=11.0,
        grain_pair_label=pair_key,
        ax=ax,
        exclude_O=True
    )

    # First-shell cutoff
    ax.axvline(
        x=FIRST_SHELL_CUTOFF,
        color="red",
        linestyle="--",
        linewidth=1.2,
    )

    # Reduce tick clutter
    ax.tick_params(labelsize=8)

# Hide every panel that did not receive a grain pair
for ax in axes[len(gpm_keys):]:
    ax.axis("off")

# Shared labels
#fig.supxlabel("Distance (Å)", fontsize=14)
#fig.supylabel("Total Count", fontsize=14)

fig.subplots_adjust(
    wspace=0.25,
    hspace=0.35,
)

plt.show()

In [None]:
# First nearest-neighbour shell and its non-HCP fraction.
# env["distances"] must already be populated before this is called.

def compute_first_shell_fract_hcp_c(env, atoms, xyz_ptm_types,
                                    first_shell_cutoff=3.0):
    """
    Find the first-shell neighbours of an environment and the fraction of the
    non-O ones that are not labelled HCP.

    Parameters:
    -----------
    env : dict
        Environment dictionary; env["distances"] maps neighbour index to
        distance.
    atoms : ase.Atoms
        The atomic structure (only ``atoms[i].symbol`` is read)
    xyz_ptm_types : array
        PTM type label per atom, indexed like ``atoms``
    first_shell_cutoff : float
        Neighbours at a distance <= this value belong to the first shell

    Returns:
    --------
    tuple of (first_shell_neighbor_idxs, first_shell_fract_hcp_c)
        first_shell_neighbor_idxs : list of int
            Indices of all atoms (any species) inside the first shell
        first_shell_fract_hcp_c : float
            1 - (HCP count / non-O count) over the first shell; 0.0 when the
            shell holds no non-O atom
    """
    shell = [
        nbr
        for nbr, dist in env["distances"].items()
        if dist <= first_shell_cutoff
    ]

    # PTM labels of the non-O atoms in the shell
    metal_labels = [xyz_ptm_types[nbr] for nbr in shell if atoms[nbr].symbol != "O"]
    n_metal = len(metal_labels)

    # Nothing to take a fraction of
    if n_metal == 0:
        return shell, 0.0

    # PTM type 2 is HCP
    n_hcp = sum(1 for label in metal_labels if label == 2)

    return shell, 1.0 - n_hcp / n_metal

In [None]:
# Annotate every environment with its first-shell neighbours and non-HCP
# fraction for BOTH candidate cutoffs. Later cells read grain_pair_mapping_2p6
# as well as grain_pair_mapping_3p6, so both must exist after a single
# top-to-bottom run (previously one of them required editing and re-running
# this cell).

def _annotate_first_shell(mapping, first_shell_cutoff):
    """
    Return a copy of `mapping` whose environments carry first-shell data.

    Each environment is deep-copied (the input mapping is left untouched) and
    gains 'first_shell_neighbor_idxs' and 'first_shell_fract_hcp_c', computed
    with compute_first_shell_fract_hcp_c at the given cutoff.
    """
    total_envs = sum(len(envs) for envs in mapping.values())
    print(f"Processing {total_envs} environments across {len(mapping)} grain pairs...")
    print(f"First shell cutoff: {first_shell_cutoff} Å")

    annotated = {}
    processed_count = 0
    for pair_key, envs in mapping.items():
        updated_envs = []

        for env in envs:
            # Compute first shell properties
            first_shell_idxs, first_shell_fract_hcp_c = compute_first_shell_fract_hcp_c(
                env, atoms, xyz_ptm_types, first_shell_cutoff=first_shell_cutoff
            )

            # Create updated environment dict
            updated_env = copy.deepcopy(env)
            updated_env['first_shell_neighbor_idxs'] = first_shell_idxs
            updated_env['first_shell_fract_hcp_c'] = first_shell_fract_hcp_c
            updated_envs.append(updated_env)

            processed_count += 1
            if processed_count % 100 == 0:
                # '\r' returns to the start of the line, so each progress
                # message overwrites the previous one.
                print(f"  Processed {processed_count}/{total_envs} environments...", end='\r')

        annotated[pair_key] = updated_envs

    print(f"\nCompleted processing {processed_count} environments")
    return annotated


grain_pair_mapping_2p6 = _annotate_first_shell(grain_pair_mapping, 2.6)
grain_pair_mapping_3p6 = _annotate_first_shell(grain_pair_mapping, 3.6)

# Names kept from the single-cutoff version of this cell (3.6 Å was active).
FIRST_SHELL_CUTOFF = 3.6
updated_grain_pair_mapping = grain_pair_mapping_3p6

Very strange that when you run this, you only see the last "Processed" string, i.e. "Processed 7900/7994 environments...".
Oh, it's the carriage return (`end='\r'`): each progress message overwrites the previous one on the same line.
Also, with `copy.deepcopy()` the loop is now slow enough that it's obvious what's happening.

In [None]:
# Inspect one annotated environment: the first environment of the second
# grain-pair key, taken from the 2.6 Å mapping.
# Requires grain_pair_mapping_2p6 to exist in the kernel.
#grain_pair_mapping[gpm_keys[1]][0]
grain_pair_mapping_2p6[gpm_keys[1]][0]

In [None]:
# Display the second grain-pair key (the one used to index the mapping in the previous cell).
gpm_keys[1]

In [None]:
def extract_first_shell_fract_hcp_c(envs):
    """
    Collect the "first_shell_fract_hcp_c" value of every environment that has
    one, preserving order. Environments lacking the key, or holding None for
    it, are left out.
    """
    collected = []
    for env in envs:
        if "first_shell_fract_hcp_c" not in env:
            continue
        fraction = env["first_shell_fract_hcp_c"]
        if fraction is None:
            continue
        collected.append(fraction)
    return collected

def plot_first_shell_fract_hcp_c_hist(
    envs,
    bins=50,
    ax=None,
    grain_pair_label=None,
    density=False,
):
    """
    Draw a histogram of the first_shell_fract_hcp_c values of `envs`.

    Parameters
    ----------
    envs : list[dict]
        Environment dicts; values are gathered with
        extract_first_shell_fract_hcp_c
    bins : int
        Number of histogram bins
    ax : matplotlib.axes.Axes or None
        Axis to draw on. When None a new 6x4 figure is created and the axes
        get x/y labels; a supplied axis is left unlabelled.
    grain_pair_label : str or None
        Used as the title when not None
    density : bool
        Passed to ``ax.hist``; also switches the y label to "Density"

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was drawn on
    """
    values = extract_first_shell_fract_hcp_c(envs)

    # Axis labels are only added to an axis this function created itself.
    owns_axis = ax is None
    if owns_axis:
        _, ax = plt.subplots(figsize=(6, 4))

    hist_style = dict(bins=bins, density=density, alpha=0.75, edgecolor="black")
    ax.hist(values, **hist_style)

    if owns_axis:
        ax.set_xlabel("First-shell fraction HCP-C")
        ax.set_ylabel("Density" if density else "Count")

    if grain_pair_label is not None:
        ax.set_title(grain_pair_label, fontsize=10)

    return ax


In [None]:
# Histogram of first_shell_fract_hcp_c for a single grain pair (the first key
# in gpm_keys), read from the mapping selected below. The commented-out line
# is the alternative mapping.
#modified_grain_pair_map = grain_pair_mapping_2p6
modified_grain_pair_map = grain_pair_mapping_3p6

pair_key = gpm_keys[0]

hist_envs = modified_grain_pair_map[pair_key]

# No `ax` is passed, so the helper creates its own figure and labels the axes.
plot_first_shell_fract_hcp_c_hist(
    hist_envs,
    bins=40,
    grain_pair_label=pair_key,
)

plt.show()

In [None]:
# Grid of first_shell_fract_hcp_c histograms, one panel per grain pair, read
# from the mapping selected below (requires that mapping to exist in the kernel).
modified_grain_pair_map = grain_pair_mapping_2p6
#modified_grain_pair_map = grain_pair_mapping_3p6

# Size the grid from the number of grain pairs: a fixed 8x4 grid raises
# IndexError as soon as there are more than 32 of them.
n_cols = 4
n_rows = max(1, math.ceil(len(gpm_keys) / n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(28, 1.5 * n_rows),  # 1.5 per row reproduces (28, 12) for 8 rows
    sharex=True,
    sharey=False,
    squeeze=False,  # always a 2-D array, even for a single row
)

axes = axes.flatten()

for ax, pair_key in zip(axes, gpm_keys):
    envs = modified_grain_pair_map[pair_key]

    plot_first_shell_fract_hcp_c_hist(
        envs,
        bins=40,
        grain_pair_label=pair_key,
        ax=ax,
    )

    # Reduce tick clutter
    ax.tick_params(labelsize=8)

# Hide unused axes
for ax in axes[len(gpm_keys):]:
    ax.axis("off")

fig.subplots_adjust(
    wspace=0.25,
    hspace=0.35,
)

plt.show()

My sense is that the 2.6 Å shell will be sufficient here.

In [None]:
import math
import numpy as np


def select_top_first_shell_envs(
    grain_pair_mapping,
    top_fraction=0.10,
    top_n=None,
    value_key="first_shell_fract_hcp_c",
):
    """
    For every grain pair, keep the environments with the highest `value_key`.

    Parameters
    ----------
    grain_pair_mapping : dict[str, list[dict]]
        Grain-pair key -> list of environment dicts
    top_fraction : float, optional
        Share of the valid environments to keep per grain pair, in (0, 1].
        The count is rounded up and is never below 1. Not used (and not
        validated) when top_n is given.
    top_n : int or None, optional
        Absolute number of environments to keep per grain pair; takes
        precedence over top_fraction. Capped at the number available.
    value_key : str
        Key the environments are ranked by. Environments without the key, or
        with None stored under it, are ignored.

    Returns
    -------
    selected_grain_pair_mapping : dict[str, dict]
        Per grain pair a dict with:
            - "selected_envs": list[dict], highest value first (ties keep
              their input order)
            - "mean_first_shell_fract_hcp_c": float mean of `value_key` over
              the selection, NaN when nothing could be selected

    Raises
    ------
    ValueError
        If the active selection parameter is out of range.
    """
    by_fraction = top_n is None
    if by_fraction and not (0.0 < top_fraction <= 1.0):
        raise ValueError("top_fraction must be in (0, 1].")
    if not by_fraction and top_n <= 0:
        raise ValueError("top_n must be a positive integer.")

    result = {}

    for pair_key, envs in grain_pair_mapping.items():
        # Rank the usable environments, largest value first (stable sort).
        ranked = sorted(
            (env for env in envs if value_key in env and env[value_key] is not None),
            key=lambda env: env[value_key],
            reverse=True,
        )

        # No usable environment: record an empty selection.
        if not ranked:
            result[pair_key] = {
                "selected_envs": [],
                "mean_first_shell_fract_hcp_c": np.nan,
            }
            continue

        if by_fraction:
            keep = max(1, math.ceil(top_fraction * len(ranked)))
        else:
            keep = min(top_n, len(ranked))

        chosen = ranked[:keep]
        result[pair_key] = {
            "selected_envs": chosen,
            "mean_first_shell_fract_hcp_c": float(
                np.mean([env[value_key] for env in chosen])
            ),
        }

    return result

In [None]:
# Top 10% of environments per grain pair (ranked by first_shell_fract_hcp_c),
# taken from the 2.6 Å mapping; requires grain_pair_mapping_2p6 to exist.
top_10_pcnt_envs = select_top_first_shell_envs(grain_pair_mapping_2p6, top_fraction=0.1)

In [None]:
# Same selection as for top_10_pcnt_envs, but from the 3.6 Å mapping, for comparison.
alt_top_10_pcnt_envs = select_top_first_shell_envs(grain_pair_mapping_3p6, top_fraction=0.1)

In [None]:
# Print, per grain pair, the mean first-shell HCP-C fraction of the selected
# environments. Single quotes are used inside the f-string because reusing
# the enclosing quote type is a SyntaxError before Python 3.12.
for k, v in top_10_pcnt_envs.items():
    print(f"{k}: {v['mean_first_shell_fract_hcp_c']}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_mean_first_shell_hcp_c_heatmap(
    selected_grain_pair_mapping,
    grain_ids,
    figsize=(12, 10),
    cmap="viridis",
    upper_triangle_only=True,
    missing_color="lightgray",
    value_fmt="{:.3f}",
):
    """
    Create a heat map of mean first-shell HCP-C fraction for grain pairs.

    NOTE(review): a second function with this same name is defined later in
    the notebook and replaces this one once its cell has run.

    Parameters
    ----------
    selected_grain_pair_mapping : dict[str, dict]
        Mapping from grain-pair key to a dict containing
        "mean_first_shell_fract_hcp_c". A key is either a 2-tuple or the
        string form of one (e.g. "(1, 3)"); strings are parsed with
        ast.literal_eval, so only Python literals are accepted.
    grain_ids : list[int]
        Grain IDs in display order (e.g. [1, 2, ..., 11]). Pairs that mention
        an ID outside this list are ignored.
    figsize : tuple
        Figure size (width, height).
    cmap : str
        Colormap name.
    upper_triangle_only : bool
        If True, cells below the diagonal are blanked out.
    missing_color : str
        Color used for cells without a value.
    value_fmt : str
        Format string for cell annotations.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    ValueError, SyntaxError
        If a string key is not a valid Python literal.
    """
    n = len(grain_ids)
    grain_id_to_idx = {gid: i for i, gid in enumerate(grain_ids)}

    # Initialize matrix with NaNs (missing pairs)
    value_matrix = np.full((n, n), np.nan)

    # Fill matrix from dict (symmetrically)
    for pair_key, data in selected_grain_pair_mapping.items():
        # Keys come from a file on disk: parse them as literals rather than
        # executing them with eval.
        i, j = ast.literal_eval(pair_key) if isinstance(pair_key, str) else pair_key
        if i not in grain_id_to_idx or j not in grain_id_to_idx:
            continue

        row = grain_id_to_idx[i]
        col = grain_id_to_idx[j]

        mean_val = data["mean_first_shell_fract_hcp_c"]
        value_matrix[row, col] = mean_val
        value_matrix[col, row] = mean_val

    # Blank the strict lower triangle if requested
    display_matrix = value_matrix.copy()
    if upper_triangle_only:
        display_matrix[np.tril_indices(n, k=-1)] = np.nan

    # Create colormap with NaNs colored as missing_color
    cmap_obj = plt.get_cmap(cmap).copy()
    cmap_obj.set_bad(color=missing_color)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(display_matrix, cmap=cmap_obj, aspect="auto")

    # Set ticks and labels
    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(grain_ids)
    ax.set_yticklabels(grain_ids)

    plt.setp(
        ax.get_xticklabels(),
        rotation=45,
        ha="right",
        rotation_mode="anchor",
    )

    # The largest displayed value is the same for every cell, so compute it
    # once; skip it entirely when nothing is displayed (nanmax would warn).
    has_values = not np.all(np.isnan(display_matrix))
    vmax = np.nanmax(display_matrix) if has_values else np.nan

    # Annotate cells; NaN cells (missing or blanked) get no text
    for i in range(n):
        for j in range(n):
            val = display_matrix[i, j]
            if np.isnan(val):
                continue

            # Light text on the upper part of the value range, dark below
            text_color = "white" if val > vmax * 0.6 else "black"
            ax.text(
                j,
                i,
                value_fmt.format(val),
                ha="center",
                va="center",
                fontsize=9,
                color=text_color,
                fontweight="bold",
            )

    # Colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label(
        "Mean first-shell fraction HCP-C",
        rotation=270,
        labelpad=20,
        fontsize=12,
    )

    # Labels and title
    ax.set_xlabel("Grain ID", fontsize=12, fontweight="bold")
    ax.set_ylabel("Grain ID", fontsize=12, fontweight="bold")
    title_suffix = " (Upper Triangle)" if upper_triangle_only else ""
    ax.set_title(
        f"Mean First-Shell HCP-C Fraction Heat Map{title_suffix}",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    return fig, ax

In [None]:
# Heat map of the per-pair means stored in top_10_pcnt_envs, for grain IDs 1..11.
grain_ids = list(range(1, 12))

fig, ax = plot_mean_first_shell_hcp_c_heatmap(
    top_10_pcnt_envs,
    grain_ids,
    cmap="plasma",
)

plt.show()

In [None]:
def plot_mean_first_shell_hcp_c_heatmap(
    selected_grain_pair_mapping,
    grain_ids,
    figsize=(12, 10),
    cmap="viridis",
    missing_color="lightgray",
    diagonal_color="#b0b0b0",
    value_fmt="{:.3f}",
):
    """
    Heat map of mean first-shell HCP-C fraction for grain pairs, upper
    triangle only.

    NOTE(review): this re-uses the name of a function defined earlier in the
    notebook and replaces it once this cell has run.

    - Each pair is drawn once, above the diagonal, whichever order its two
      grain IDs are given in.
    - Cells without a value (the lower triangle and any missing pair) are
      drawn in `missing_color`.
    - Diagonal cells are covered by rectangles in `diagonal_color`.

    Parameters
    ----------
    selected_grain_pair_mapping : dict[str, dict]
        Mapping from grain-pair key to a dict containing
        "mean_first_shell_fract_hcp_c". A key is either a 2-tuple or the
        string form of one (e.g. "(1, 3)"); strings are parsed with
        ast.literal_eval, so only Python literals are accepted. If a pair
        appears in both orders, the entry iterated last wins.
    grain_ids : list[int]
        Grain IDs in display order. Pairs that mention an ID outside this
        list, and pairs of a grain with itself, are ignored.
    figsize : tuple
        Figure size (width, height).
    cmap : str
        Colormap name.
    missing_color : str
        Color for cells without a value.
    diagonal_color : str
        Color of the diagonal cells.
    value_fmt : str
        Format string for cell annotations.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    ValueError, SyntaxError
        If a string key is not a valid Python literal.
    """
    n = len(grain_ids)
    grain_id_to_idx = {gid: i for i, gid in enumerate(grain_ids)}

    # Initialize with NaNs (everything missing by default)
    value_matrix = np.full((n, n), np.nan)

    # Fill upper triangle only
    for pair_key, data in selected_grain_pair_mapping.items():
        # Keys come from a file on disk: parse them as literals rather than
        # executing them with eval.
        i, j = ast.literal_eval(pair_key) if isinstance(pair_key, str) else pair_key
        if i not in grain_id_to_idx or j not in grain_id_to_idx:
            continue

        row = grain_id_to_idx[i]
        col = grain_id_to_idx[j]
        if row == col:
            continue

        # Order the indices so a pair keyed as (larger, smaller) still lands
        # above the diagonal instead of being dropped.
        upper_row, upper_col = min(row, col), max(row, col)
        value_matrix[upper_row, upper_col] = data["mean_first_shell_fract_hcp_c"]

    # Mask lower triangle INCLUDING diagonal (we will draw diagonal manually)
    mask = np.tril(np.ones_like(value_matrix, dtype=bool))
    display_matrix = np.ma.masked_where(mask, value_matrix)

    # Colormap setup
    cmap_obj = plt.get_cmap(cmap).copy()
    cmap_obj.set_bad(color=missing_color)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(display_matrix, cmap=cmap_obj, aspect="auto")

    # Draw diagonal cells manually (cell k is centred on (k, k), hence -0.5)
    for i in range(n):
        rect = plt.Rectangle(
            (i - 0.5, i - 0.5),
            1,
            1,
            facecolor=diagonal_color,
            edgecolor="white",
            linewidth=1.0,
        )
        ax.add_patch(rect)

    # Ticks and labels
    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(grain_ids)
    ax.set_yticklabels(grain_ids)

    plt.setp(
        ax.get_xticklabels(),
        rotation=45,
        ha="right",
        rotation_mode="anchor",
    )

    # Annotate only valid upper-triangle values. The maximum is skipped when
    # the matrix holds no value at all (nanmax would warn on all-NaN input).
    has_values = not np.all(np.isnan(value_matrix))
    vmax = np.nanmax(value_matrix) if has_values else np.nan
    for i in range(n):
        for j in range(i + 1, n):
            val = value_matrix[i, j]
            if np.isnan(val):
                continue

            # Light text on the upper part of the value range, dark below
            text_color = "white" if val > 0.6 * vmax else "black"
            ax.text(
                j,
                i,
                value_fmt.format(val),
                ha="center",
                va="center",
                fontsize=9,
                fontweight="bold",
                color=text_color,
            )

    # Colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label(
        "Mean first-shell fraction HCP-C",
        rotation=270,
        labelpad=20,
        fontsize=12,
    )

    # Labels and title
    ax.set_xlabel("Grain ID", fontsize=12, fontweight="bold")
    ax.set_ylabel("Grain ID", fontsize=12, fontweight="bold")
    ax.set_title(
        "Mean First-Shell HCP-C Fraction (Upper Triangle Only)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    return fig, ax

In [None]:
# Heat map (using the most recently defined plot_mean_first_shell_hcp_c_heatmap)
# of the per-pair means stored in top_10_pcnt_envs, for grain IDs 1..11.
grain_ids = list(range(1, 12))

fig, ax = plot_mean_first_shell_hcp_c_heatmap(
    top_10_pcnt_envs,
    grain_ids,
    cmap="viridis",
)

plt.show()

In [None]:
# Same heat map for alt_top_10_pcnt_envs (the selection made from the 3.6 Å
# mapping), for grain IDs 1..11.
grain_ids = list(range(1, 12))

fig, ax = plot_mean_first_shell_hcp_c_heatmap(
    alt_top_10_pcnt_envs,
    grain_ids,
    cmap="viridis",
)

plt.show()