In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt
from ase.io import read as ase_read
from collections import defaultdict
import copy

In [None]:
# Structure whose atom indices the environment / grain files below refer to.
input_xyz = "./jp_dio-orig_min4.xyz"
atoms = ase_read(filename=input_xyz)

In [None]:
# Grain-pair key -> list of environment dicts (each env is later read for
# "index" and "neighbor_idxs").
with open("./grain_pair_to_envs_mapping_max60pcnt.json", "r") as mapping_file:
    grain_pair_mapping = json.load(mapping_file)

n_envs_loaded = sum(map(len, grain_pair_mapping.values()))
print(f"Loaded grain pair mapping with {len(grain_pair_mapping)} pairs")
print(f"Total environments: {n_envs_loaded}")

# Show the five grain pairs with the most environments
print("\nExample grain pairs and counts:")
largest_pairs = sorted(
    grain_pair_mapping.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)[:5]
for pair_key, envs in largest_pairs:
    print(f"  {pair_key}: {len(envs)} environments")

In [None]:
# Load PTM and grain labels. These were stored in an O-free ("noO") atom
# numbering, so every lookup goes through an index map.

# Index mapping as stored on disk: noO index -> original (xyz) index.
with open("./noOidx2orig.json", "r") as f:
    noO_to_orig = json.load(f)

# Reverse the mapping (orig -> noO). A dict comprehension would silently drop
# duplicate targets, so verify the mapping is one-to-one.
index_map = {int(orig): int(noO) for noO, orig in noO_to_orig.items()}
if len(index_map) != len(noO_to_orig):
    raise ValueError(
        "noOidx2orig.json maps several noO indices to the same original index"
    )

# Load grain and PTM data. The context manager closes the .npz file handle;
# indexing an NpzFile reads each array fully into memory, so the arrays stay
# valid after the file is closed.
with np.load("./grains_ptm_111025_min4_fixed.npz") as grain_ptm_data:
    noO_grains = grain_ptm_data["grains"]
    noO_ptm_types = grain_ptm_data["ptm_types"]

# PTM type per atom of `atoms`; O atoms keep the -1 sentinel.
xyz_ptm_types = np.full(len(atoms), -1, dtype=int)

# Map every non-O atom (presumably Hf) through the orig -> noO index map.
# get_chemical_symbols() avoids building one Atom object per atom.
for i, symbol in enumerate(atoms.get_chemical_symbols()):
    if symbol == "O":
        continue
    xyz_ptm_types[i] = noO_ptm_types[index_map[i]]

In [None]:
# Compute pairwise distances with periodic boundary conditions.
# NOTE: the minimum-image wrap below is only valid for orthorhombic cells;
# the function refuses any other cell shape instead of returning wrong numbers.
def compute_distances_pbc(atoms, central_idx, neighbor_idxs):
    """
    Compute distances between a central atom and its neighbors,
    accounting for periodic boundary conditions (minimum image convention).

    Parameters:
    -----------
    atoms : ase.Atoms
        The atomic structure. Only ``positions``, ``cell`` and ``pbc`` are read.
    central_idx : int
        Index of the central atom
    neighbor_idxs : iterable of int
        Indices of neighbor atoms

    Returns:
    --------
    dict
        Keys are neighbor indices, values are distances (same units as
        ``atoms.positions``, i.e. Angstroms for ASE structures)

    Raises:
    -------
    ValueError
        If the cell has non-zero off-diagonal components (non-orthorhombic).
    """
    cell = np.asarray(atoms.cell, dtype=float)

    # Orthorhombic check, done once per call rather than once per neighbor.
    # A real exception (not assert) so it still fires under `python -O`.
    off_diagonal = cell - np.diag(np.diag(cell))
    if not np.allclose(off_diagonal, 0.0, rtol=0.0, atol=1e-12):
        raise ValueError(
            "compute_distances_pbc only supports orthorhombic cells "
            "(off-diagonal cell components must be zero)"
        )

    neighbor_idxs = list(neighbor_idxs)
    if not neighbor_idxs:
        return {}

    positions = np.asarray(atoms.positions, dtype=float)
    # The subtraction yields a fresh array, so the in-place wrap below never
    # touches the structure's own positions.
    deltas = positions[neighbor_idxs] - positions[central_idx]

    # Minimum image convention: wrap each periodic component to [-L/2, L/2].
    for axis in range(3):
        if atoms.pbc[axis]:
            cell_length = cell[axis, axis]
            deltas[:, axis] -= cell_length * np.round(deltas[:, axis] / cell_length)

    distances = np.linalg.norm(deltas, axis=1)
    return dict(zip(neighbor_idxs, distances))

print("Distance computation function defined")

In [None]:
def compute_Hf_mask(env, atoms):
    """
    Boolean mask over ``env["neighbor_idxs"]`` marking which neighbors are Hf.

    Parameters
    ----------
    env : dict
        Environment dict; only its "neighbor_idxs" entry is read.
    atoms : ase.Atoms
        Structure indexed by the neighbor indices (``atoms[i].symbol`` is read).

    Returns
    -------
    np.ndarray of bool
        ``mask[k]`` is True iff ``atoms[env["neighbor_idxs"][k]].symbol == "Hf"``.
        The dtype is forced to bool so an empty neighbor list still yields a
        usable boolean index (``np.array([])`` alone would be float64).
    """
    indices = env["neighbor_idxs"]
    return np.array([atoms[i].symbol == "Hf" for i in indices], dtype=bool)

In [None]:
# Annotate every environment in place with its neighbor distances and Hf mask.
for pair_key, envs in grain_pair_mapping.items():
    print(pair_key)
    for env in envs:
        env.update(
            distances=compute_distances_pbc(atoms, env['index'], env['neighbor_idxs']),
            Hf_mask=compute_Hf_mask(env, atoms),
        )

In [None]:
# Indexable list of grain-pair keys, in the mapping's insertion order.
gpm_keys = [*grain_pair_mapping]
len(gpm_keys)

In [None]:
# Identify the first NN shell and compute first_shell_fract_hcp_c.
# Requires env["distances"], which the distance-annotation cell adds.

def compute_first_shell_fract_hcp_c(env, atoms, xyz_ptm_types,
                                    first_shell_cutoff=3.0,
                                    hcp_ptm_type=2,
                                    empty_value=0.0):
    """
    Identify the first nearest-neighbor shell of an environment and compute
    the fraction of its non-O neighbors that are NOT HCP.

    Parameters:
    -----------
    env : dict
        Environment dictionary. Must contain "distances" (neighbor index ->
        distance), e.g. as produced by compute_distances_pbc.
    atoms : ase.Atoms
        The atomic structure (only ``atoms[idx].symbol`` is read)
    xyz_ptm_types : array
        PTM type labels for all atoms, indexed like ``atoms``
    first_shell_cutoff : float
        Distance cutoff for first NN shell (same units as the distances;
        Angstroms for ASE structures). Neighbors with ``dist <= cutoff`` are kept.
    hcp_ptm_type : int
        PTM label that counts as HCP (2 in this notebook's PTM labeling).
    empty_value : object
        Returned as the fraction when the shell contains no non-O atoms.
        The default 0.0 keeps the historical behavior but is indistinguishable
        from "every neighbor is HCP"; pass None (skipped by
        select_top_first_shell_envs) to mark such environments as valueless.

    Returns:
    --------
    tuple of (first_shell_neighbor_idxs, first_shell_fract_hcp_c)
        first_shell_neighbor_idxs : list of int
            Indices of all atoms (O included) in the first NN shell, in the
            iteration order of env["distances"]
        first_shell_fract_hcp_c : float
            1 - (HCP non-O neighbors / non-O neighbors), or ``empty_value``
            when the shell has no non-O neighbors

    Raises:
    -------
    KeyError
        If ``env`` has no "distances" entry.
    """
    try:
        distances = env["distances"]
    except KeyError:
        raise KeyError(
            "env has no 'distances'; annotate it with compute_distances_pbc first"
        ) from None

    first_shell_idxs = [idx for idx, dist in distances.items()
                        if dist <= first_shell_cutoff]

    # Only non-O neighbors are counted: O atoms carry no PTM label
    # (they are left at -1 in xyz_ptm_types).
    shell_ptm_types = [xyz_ptm_types[nidx] for nidx in first_shell_idxs
                       if atoms[nidx].symbol != "O"]
    if not shell_ptm_types:
        return first_shell_idxs, empty_value

    hcp_count = sum(1 for ptm_type in shell_ptm_types if ptm_type == hcp_ptm_type)
    fract_hcp_c = 1.0 - hcp_count / len(shell_ptm_types)
    return first_shell_idxs, fract_hcp_c

In [None]:
FIRST_SHELL_CUTOFF = 2.6
grain_pair_mapping_2p6 = {}

total_envs = sum(map(len, grain_pair_mapping.values()))
print(f"Processing {total_envs} environments across {len(grain_pair_mapping)} grain pairs...")
print(f"First shell cutoff: {FIRST_SHELL_CUTOFF} Å")


def _annotate_first_shell(env, cutoff):
    """Return a deep copy of ``env`` extended with its first-shell neighbor
    indices and first-shell HCP-complement fraction (``env`` is untouched)."""
    shell_idxs, shell_fract_hcp_c = compute_first_shell_fract_hcp_c(
        env, atoms, xyz_ptm_types, first_shell_cutoff=cutoff
    )
    annotated = copy.deepcopy(env)
    annotated['first_shell_neighbor_idxs'] = shell_idxs
    annotated['first_shell_fract_hcp_c'] = shell_fract_hcp_c
    return annotated


processed_count = 0
for pair_key, envs in grain_pair_mapping.items():
    annotated_envs = []
    for env in envs:
        annotated_envs.append(_annotate_first_shell(env, FIRST_SHELL_CUTOFF))
        processed_count += 1
        # Progress line, overwritten in place via carriage return
        if processed_count % 100 == 0:
            print(f"  Processed {processed_count}/{total_envs} environments...", end='\r')
    grain_pair_mapping_2p6[pair_key] = annotated_envs

print(f"\nCompleted processing {processed_count} environments")

In [None]:
# Inspect the annotated environments of the second grain pair.
example_pair_key = gpm_keys[1]
grain_pair_mapping_2p6[example_pair_key]

In [None]:
def to_jsonable(obj):
    """
    Recursively convert numpy containers/scalars into plain Python objects so
    the result can be passed to ``json.dump``.

    - ``np.ndarray`` -> nested lists (``ndarray.tolist``); object-dtype arrays
      are converted recursively since their elements are left as-is by tolist
    - numpy scalars (``np.generic``) -> Python scalars (``.item()``)
    - dicts -> dicts with converted values; numpy-scalar keys are converted
      too, because ``json`` rejects keys such as ``np.int64``
    - lists and tuples -> lists of converted elements (``json`` writes tuples
      as arrays anyway, but would fail on numpy values inside them)

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        if obj.dtype == object:
            return to_jsonable(obj.tolist())
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): to_jsonable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    return obj

In [None]:
# Persist the 2.6 Å first-shell annotation as plain JSON.
with open("grain_pair_mapping_2p6.json", "w") as out_file:
    json.dump(obj=to_jsonable(grain_pair_mapping_2p6), fp=out_file)

In [None]:
FIRST_SHELL_CUTOFF = 3.6
grain_pair_mapping_3p6 = {}

total_envs = sum(map(len, grain_pair_mapping.values()))
print(f"Processing {total_envs} environments across {len(grain_pair_mapping)} grain pairs...")
print(f"First shell cutoff: {FIRST_SHELL_CUTOFF} Å")

processed_count = 0
for pair_key, envs in grain_pair_mapping.items():
    updated_envs = []
    for env in envs:
        shell_idxs, shell_fract_hcp_c = compute_first_shell_fract_hcp_c(
            env, atoms, xyz_ptm_types, first_shell_cutoff=FIRST_SHELL_CUTOFF
        )
        # Work on a deep copy so grain_pair_mapping itself is never modified
        updated_envs.append({
            **copy.deepcopy(env),
            'first_shell_neighbor_idxs': shell_idxs,
            'first_shell_fract_hcp_c': shell_fract_hcp_c,
        })

        processed_count += 1
        if not processed_count % 100:
            print(f"  Processed {processed_count}/{total_envs} environments...", end='\r')

    grain_pair_mapping_3p6[pair_key] = updated_envs

print(f"\nCompleted processing {processed_count} environments")

In [None]:
# Persist the 3.6 Å first-shell annotation as plain JSON.
with open("grain_pair_mapping_3p6.json", "w") as out_file:
    json.dump(obj=to_jsonable(grain_pair_mapping_3p6), fp=out_file)

In [None]:
import math
import numpy as np


def _has_valid_score(env, value_key):
    """True if ``env`` carries a ranking value under ``value_key`` that is
    neither None nor NaN (NaN compares False against everything and would
    scramble the descending sort)."""
    value = env.get(value_key)
    if value is None:
        return False
    return not (isinstance(value, (float, np.floating)) and math.isnan(value))


def select_top_first_shell_envs(
    grain_pair_mapping,
    top_fraction=0.10,
    top_n=None,
    value_key="first_shell_fract_hcp_c",
):
    """
    Select top environments per grain pair based on first-shell HCP-C fraction.

    Parameters
    ----------
    grain_pair_mapping : dict[str, list[dict]]
        Mapping from grain-pair key to list of environment dicts
    top_fraction : float, optional
        Fraction (0 < top_fraction <= 1) of environments to select per grain pair.
        At least one environment is selected for every non-empty pair.
        Ignored if top_n is provided.
    top_n : int or None, optional
        Absolute number of environments to select per grain pair (Python or
        numpy integer). If provided, overrides top_fraction.
    value_key : str
        Key used to rank environments. Environments lacking it, or whose value
        is None or NaN, are skipped.

    Returns
    -------
    selected_grain_pair_mapping : dict[str, dict]
        Each value is a dict with:
            - "selected_envs": list[dict], sorted by value descending
              (ties keep their input order)
            - "mean_first_shell_fract_hcp_c": float, the mean of ``value_key``
              over the selection (NaN when nothing is selectable). The key
              name is fixed regardless of ``value_key``.

    Raises
    ------
    ValueError
        If top_fraction is outside (0, 1] (when top_n is None), or top_n is
        not a positive integer.
    """
    if top_n is None:
        if not (0.0 < top_fraction <= 1.0):
            raise ValueError("top_fraction must be in (0, 1].")
    else:
        # bool is an int subclass; floats like 2.5 would otherwise pass the
        # sign check and fail later as a slice bound.
        if (isinstance(top_n, bool)
                or not isinstance(top_n, (int, np.integer))
                or top_n <= 0):
            raise ValueError("top_n must be a positive integer.")
        top_n = int(top_n)

    selected_grain_pair_mapping = {}

    for pair_key, envs in grain_pair_mapping.items():
        valid_envs = [env for env in envs if _has_valid_score(env, value_key)]

        if not valid_envs:
            selected_grain_pair_mapping[pair_key] = {
                "selected_envs": [],
                "mean_first_shell_fract_hcp_c": np.nan,
            }
            continue

        # Descending by value; sorted() is stable, so ties keep input order.
        sorted_envs = sorted(
            valid_envs,
            key=lambda env: env[value_key],
            reverse=True,
        )

        if top_n is not None:
            n_select = min(top_n, len(sorted_envs))
        else:
            # Round away float representation noise before ceil: e.g.
            # 0.07 * 100 == 7.000000000000001 would otherwise select 8 envs.
            n_select = max(
                1,
                math.ceil(round(top_fraction * len(sorted_envs), 9)),
            )

        selected_envs = sorted_envs[:n_select]

        mean_val = float(
            np.mean([env[value_key] for env in selected_envs])
        )

        selected_grain_pair_mapping[pair_key] = {
            "selected_envs": selected_envs,
            "mean_first_shell_fract_hcp_c": mean_val,
        }

    return selected_grain_pair_mapping

In [None]:
# Top 10 % of environments per grain pair, for both first-shell cutoffs.
top_10_pcnt_envs_2p6, top_10_pcnt_envs_3p6 = (
    select_top_first_shell_envs(mapping, top_fraction=0.1)
    for mapping in (grain_pair_mapping_2p6, grain_pair_mapping_3p6)
)

In [None]:
# Persist both top-10 % selections as plain JSON.
for out_path, selection in (
    ("top_10_pcnt_grain_pair_mapping_2p6.json", top_10_pcnt_envs_2p6),
    ("top_10_pcnt_grain_pair_mapping_3p6.json", top_10_pcnt_envs_3p6),
):
    with open(out_path, "w") as out_file:
        json.dump(to_jsonable(selection), out_file)
