In [None]:
from ase.io import read as ase_read
from ase.io import write as ase_write
import numpy as np
import copy
import json
from collections import Counter
from matplotlib import pyplot as plt
import os, psutil

In [None]:
# Report the resident set size (RSS) of the current kernel process,
# converted from bytes to MiB (the label printed is "MB").
p = psutil.Process(os.getpid())
print(f"RSS: {p.memory_info().rss / 1024**2:.1f} MB")

In [None]:
# Path of the structure file (relative to the notebook's working directory).
input_xyz = "jp_dio-orig_min4.xyz"
# Read it with ASE; with no `index` argument ase.io.read returns a single frame.
min_orig_atoms = ase_read(input_xyz)

In [None]:
# The JSON file maps an index in the O-free ("noO") structure to the index of
# the same atom in the original structure. JSON keys are strings, so both
# sides are cast to int below.
with open("noOidx2orig.json", "r") as f:
    noO_to_orig = json.load(f)

# Invert the mapping: index_map[original_index] -> noO_index.
index_map = {
    int(orig_idx): int(noO_idx)
    for noO_idx, orig_idx in noO_to_orig.items()
}

In [None]:
# The neighbour list is stored as parallel arrays: entry k says that atom
# nn_j[k] is a neighbour of atom nn_i[k]. nn_S is loaded but not used below.
# NOTE(review): "i"/"j"/"S" look like ase.neighborlist output (S = cell shifts) - confirm.
nnlist_data = np.load("orig_dio_polycrystal_neighborlist10dot4.npz")
nn_i, nn_j, nn_S = (nnlist_data[key] for key in ("i", "j", "S"))

# Group the pairs into {central atom index: [neighbour indices, ...]} using
# plain Python ints, preserving the order in which pairs appear in the file.
nn_dict = {}
for pos, center in enumerate(nn_i):
    nn_dict.setdefault(int(center), []).append(int(nn_j[pos]))

In [None]:
# Per-atom grain assignment and PTM structure type. Both arrays are indexed by
# the O-free ("noO") atom index, so original indices must go through index_map.
grain_ptm_data = np.load("grains_ptm_111025_min4_fixed.npz") # notice! using the fixed version now
noO_grains = grain_ptm_data["grains"]
noO_ptm_types = grain_ptm_data["ptm_types"]

In [None]:
# Per-atom grain index and PTM type in the ordering of the original structure.
# O atoms are absent from the noO arrays, so they get the sentinel -1.
xyz_grain_idxs = []
xyz_ptm_types = []

for orig_idx, atom in enumerate(min_orig_atoms):
    if atom.symbol == "O":
        grain_value, ptm_value = -1, -1
    else:
        noO_idx = index_map[orig_idx]
        grain_value = noO_grains[noO_idx]
        ptm_value = noO_ptm_types[noO_idx]
    xyz_grain_idxs.append(grain_value)
    xyz_ptm_types.append(ptm_value)

In [None]:
def generate_temp_xyz(index, ref_index=90839, out_path="temp.xyz"):
    """
    Write a copy of the structure with one atom and its neighbours highlighted.

    The copy carries three extra per-atom arrays:
    ``is_neighbor`` (1.0 for atom ``index`` and every atom listed for it in
    ``nn_dict``, 0.0 elsewhere), ``grain_index`` and ``ptm_type`` (taken from
    ``xyz_grain_idxs`` / ``xyz_ptm_types``). The highlighted atoms have their
    chemical symbol replaced by "Np" so they stand out in the output file.

    Parameters:
    -----------
    index : int
        Index (in the original structure) of the atom to highlight.
    ref_index : int or None
        Second atom that is also relabelled "Np". Defaults to 90839, the value
        that used to be hard-coded; pass None to relabel only ``index``.
    out_path : str
        Destination of the extended-XYZ file. Defaults to "temp.xyz".

    Returns:
    --------
    None
        The result is written to ``out_path``.
    """
    min_orig_out = copy.deepcopy(min_orig_atoms)

    isneighbor = np.zeros(len(min_orig_atoms))
    # An atom without any neighbour inside the cutoff has no nn_dict entry;
    # the explicit int dtype keeps the empty case a valid index array.
    neighbor_idxs = np.asarray(nn_dict.get(index, []), dtype=int)
    isneighbor[neighbor_idxs] = 1
    isneighbor[index] = 1
    min_orig_out.set_array("is_neighbor", isneighbor)

    # get_chemical_symbols() already returns a fresh list, so both relabels
    # can be applied in a single pass.
    new_symbols = min_orig_out.get_chemical_symbols()
    new_symbols[index] = "Np"
    if ref_index is not None:
        new_symbols[ref_index] = "Np"
    min_orig_out.set_chemical_symbols(new_symbols)

    min_orig_out.set_array("grain_index", np.array(xyz_grain_idxs))
    min_orig_out.set_array("ptm_type", np.array(xyz_ptm_types))

    ase_write(out_path, min_orig_out, format="extxyz")



In [None]:
# Indices (in the original structure) of every oxygen atom.
o_idxs = [
    atom_pos
    for atom_pos, atom in enumerate(min_orig_atoms)
    if atom.symbol == "O"
]

In [None]:
# For every O atom, summarise its neighbourhood: how many neighbours are O,
# and, among the non-O neighbours, the fraction with PTM type 2 (hcp), the
# fraction with PTM type 0 (other) and the fraction belonging to each grain.
oatom_envs = []
for idx in o_idxs:
    # An O atom with no neighbour inside the cutoff has no entry in nn_dict.
    neighbor_idxs = nn_dict.get(idx, [])

    local_grains = []
    local_ptm_types = []
    o_count = 0
    for nidx in neighbor_idxs:
        if min_orig_atoms[nidx].symbol == "O":
            # O atoms have no grain / PTM assignment; they are only counted.
            o_count += 1
            continue
        noO_idx = index_map[nidx]
        local_grains.append(int(noO_grains[noO_idx]))
        local_ptm_types.append(int(noO_ptm_types[noO_idx]))
    num_Hf_neighs = len(local_grains)

    if num_Hf_neighs > 0:
        fract_hcp = local_ptm_types.count(2) / num_Hf_neighs
        fract_hcp_c = 1 - fract_hcp  # complement: everything that is not hcp
        fract_other = local_ptm_types.count(0) / num_Hf_neighs
    else:
        # No non-O neighbour: the fractions are undefined. NaN (instead of a
        # ZeroDivisionError) keeps the entry in the list while making every
        # later threshold / range comparison on it evaluate to False.
        # json.dump writes these as the non-standard token NaN.
        fract_hcp = fract_hcp_c = fract_other = float("nan")

    # Empty when there is no non-O neighbour, so no division happens then.
    grain_fract = {
        grain: count / num_Hf_neighs
        for grain, count in Counter(local_grains).items()
    }

    oatom_envs.append({"index" : idx,
                       "neighbor_idxs" : neighbor_idxs,
                       "o_count" : o_count,
                       "fract_hcp" : fract_hcp,
                       "fract_hcp_c" : fract_hcp_c,
                       "fract_other": fract_other,
                       "grain_fract": grain_fract})



In [None]:
# Persist the per-O-atom environments. The integer grain ids used as keys of
# each "grain_fract" dict are written as strings, because JSON object keys
# are always strings.
with open("oatom_envs_jp_dio-orig_min4.json", "w") as f:
    json.dump(oatom_envs, f,indent=2)

In [None]:
# Non-hcp fractions to histogram. Only values above 0.05 are kept; drop the
# `if` clause to histogram every environment instead.
hcpc_bins = [
    hcpc
    for hcpc in (atm_env["fract_hcp_c"] for atm_env in oatom_envs)
    if hcpc > 0.05
]

In [None]:
# Bin the selected non-hcp fractions into 25 equal-width bins.
counts, bins = np.histogram(hcpc_bins, bins=25)
# Alternative way to draw the same pre-binned data:
#plt.stairs(counts,bins)
# Draw the pre-computed histogram: one sample per bin (its left edge),
# weighted by that bin's count, using the same bin edges.
plt.hist(bins[:-1],bins,weights=counts)

In [None]:
# Spot-check: per-grain neighbour fractions of one O-atom environment.
oatom_envs[8]["grain_fract"]

In [None]:
# Scratch check of an np.arange grid with step 0.1 (the end point of a float
# arange is subject to floating-point rounding).
np.arange(0,1.1,0.1)

In [None]:
# https://claude.ai/chat/5681ebec-9acf-41e2-9193-b1988d81b436
def preprocess_grain_fractions(environments):
    """
    Reduce each environment's grain fractions to its two largest values.

    Parameters:
    -----------
    environments : list of dict
        Environment dictionaries; each must have a 'grain_fract' dict whose
        values are the fractions (the keys are ignored).

    Returns:
    --------
    list of lists
        One [largest_fract, second_largest_fract] pair per environment, in
        input order. A missing second (or first) fraction is reported as 0.0;
        an environment with no fractions at all also prints a warning.
    """
    top_two_pairs = []

    for env in environments:
        ranked = sorted(env['grain_fract'].values(), reverse=True)

        if not ranked:
            print("SOMETHINGS WRONG")

        # Pad with zeros so that fewer than two grains still gives a pair.
        top_two_pairs.append((ranked + [0.0, 0.0])[:2])

    return top_two_pairs


def plot_grain_scatter(environments, processed_fractions, figsize=(10, 8), vmin=None, vmax=None):
    """
    Scatter the two dominant grain fractions of each environment, coloured by fract_hcp_c.

    Parameters:
    -----------
    environments : list of dict
        Environment dictionaries; only their 'fract_hcp_c' entry is read.
        Colours are paired with points by position, so the order must match
        processed_fractions.
    processed_fractions : list of lists
        [largest_fract, second_largest_fract] pairs, e.g. the output of
        preprocess_grain_fractions()
    figsize : tuple
        Figure size (width, height)
    vmin, vmax : float or None
        Colour-scale limits passed through to ax.scatter.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    largest = [pair[0] for pair in processed_fractions]
    second = [pair[1] for pair in processed_fractions]
    hcpc_values = [env['fract_hcp_c'] for env in environments]

    fig, ax = plt.subplots(figsize=figsize)

    # Dashed guide lines bounding where (largest, second largest) can fall.
    diag_x = np.linspace(0, 0.5, 100)             # y = x, drawn for x <= 0.5
    sum_x = np.linspace(0.5, 1, 100)              # x + y = 1, drawn for x >= 0.5
    triplet_x = np.linspace(0.3333333, 1, 100)    # "triplet" line y = 0.5(1 - x)
    guide_lines = (
        (diag_x, diag_x, 'gray', 'y <= x'),
        (sum_x, 1.0 - sum_x, 'gray', 'x + y = 1.0'),
        (triplet_x, 0.5 - 0.5*triplet_x, 'green', 'y>0.5(1-x)'),
    )
    for line_x, line_y, line_color, line_label in guide_lines:
        ax.plot(line_x, line_y, '--', color=line_color, alpha=0.5, linewidth=1,
                label=line_label)

    points = ax.scatter(largest, second, c=hcpc_values, cmap='plasma',
                        s=5, alpha=0.6, edgecolors='none', vmin=vmin, vmax=vmax)

    cbar = plt.colorbar(points, ax=ax)
    cbar.set_label('fract_hcp_c', fontsize=12)

    ax.set_xlabel('Largest Grain Fraction', fontsize=12)
    ax.set_ylabel('Second Largest Grain Fraction', fontsize=12)
    ax.set_title('Grain Fraction Distribution (Scatter)', fontsize=14, fontweight='bold')
    # Ticks every 0.1; the axes extend slightly past the last tick.
    ax.set_xticks(np.arange(0,1.05,0.1))
    ax.set_yticks(np.arange(0,0.65,0.1))
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 0.65)
    ax.grid(True, alpha=0.3, linestyle=':')
    ax.legend(loc='upper right')

    plt.tight_layout()
    return fig, ax

def plot_grain_hexbin(processed_fractions, gridsize=30, vmax=None, figsize=(10, 8)):
    """
    Create hexbin (2D density) plot of grain fractions.

    Parameters:
    -----------
    processed_fractions : list of lists
        [largest_fract, second_largest_fract] pairs, e.g. the output of
        preprocess_grain_fractions()
    gridsize : int
        Number of hexagons in the x-direction
    vmax : float or None
        Maximum value for colormap. If None, uses the maximum count.
        Set this to cap the colormap and make lower-count bins more visible.
    figsize : tuple
        Figure size (width, height)

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    # Extract data
    x_vals = [pf[0] for pf in processed_fractions]
    y_vals = [pf[1] for pf in processed_fractions]

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Add constraint lines
    # y = x line (for x <= 0.5)
    x_line1 = np.linspace(0, 0.5, 100)
    y_line1 = x_line1
    ax.plot(x_line1, y_line1, '--', color='gray', alpha=0.5, linewidth=1,
            label='y = x')

    # x + y = 1.0 line (for x >= 0.5)
    x_line2 = np.linspace(0.5, 1, 100)
    y_line2 = 1.0 - x_line2
    ax.plot(x_line2, y_line2, '--', color='gray', alpha=0.5, linewidth=1,
            label='x + y = 1.0')

    # triplet constraint: y = 0.5(1 - x). Its legend label previously repeated
    # 'x + y = 1.0' (copy-paste), which mislabelled this line.
    x_line3 = np.linspace(0.3333333, 1, 100)
    y_line3 = 0.5 - 0.5*x_line3
    ax.plot(x_line3, y_line3, '--', color='green', alpha=0.5, linewidth=1,
            label='y = 0.5(1 - x)')

    # Create hexbin plot; mincnt=1 leaves hexagons without any point undrawn
    hexbin = ax.hexbin(x_vals, y_vals, gridsize=gridsize, cmap='Blues',
                       mincnt=1, edgecolors='white', linewidths=0.5, vmax=vmax)

    # Add colorbar
    cbar = plt.colorbar(hexbin, ax=ax)
    cbar.set_label('Count', fontsize=12)

    # Labels and formatting
    ax.set_xlabel('Largest Grain Fraction', fontsize=12)
    ax.set_ylabel('Second Largest Grain Fraction', fontsize=12)
    ax.set_title('Grain Fraction Distribution (Hexbin)', fontsize=14, fontweight='bold')
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)
    ax.legend(loc='upper right')

    plt.tight_layout()
    return fig, ax

In [None]:
# [largest, second largest] grain fraction for every O-atom environment.
processed_grain_fracts = preprocess_grain_fractions(oatom_envs)

In [None]:
# Scatter of all O-atom environments, coloured by their non-hcp fraction.
fig1, ax1 = plot_grain_scatter(oatom_envs, processed_grain_fracts)

In [None]:
# Density view of the same data; vmax caps the colour scale at 1000 counts so
# that sparsely populated bins remain visible.
fig2, ax2 = plot_grain_hexbin(processed_grain_fracts, gridsize=100, vmax=1000)

In [None]:

def alt_preprocess_grain_fractions(environments, low_bound=0.0, high_bound=1.0):
    """
    Top-two grain fractions for the environments whose fract_hcp_c is in range.

    Parameters:
    -----------
    environments : list of dict
        Environment dictionaries containing 'grain_fract' and 'fract_hcp_c' keys
    low_bound, high_bound : float
        Inclusive limits on 'fract_hcp_c'; environments outside
        [low_bound, high_bound] are left out. A NaN 'fract_hcp_c' never
        satisfies the range test and is therefore left out as well.

    Returns:
    --------
    (list of lists, list of int)
        The [largest_fract, second_largest_fract] pair of every kept
        environment, and the positions of those environments in the input
        list (same order, same length).
    """
    kept_pairs = []
    kept_positions = []

    for position, env in enumerate(environments):
        grain_fract = env['grain_fract']
        hcpc = env['fract_hcp_c']

        if low_bound <= hcpc <= high_bound:
            ranked = sorted(grain_fract.values(), reverse=True)

            if not ranked:
                print("SOMETHINGS WRONG")

            # Pad with zeros so that fewer than two grains still gives a pair.
            kept_pairs.append((ranked + [0.0, 0.0])[:2])
            kept_positions.append(position)

    return kept_pairs, kept_positions

In [None]:
# Window on fract_hcp_c (inclusive at both ends) used to pick a subset of
# O-atom environments.
ub = 0.35
lb = 0.3

# subset_idxs are positions into oatom_envs, so the matching environment
# dicts can be gathered directly and stay aligned with subset_processed.
subset_processed, subset_idxs = alt_preprocess_grain_fractions(oatom_envs,low_bound=lb, high_bound=ub)
oatom_subset = [oatom_envs[i] for i in subset_idxs]

In [None]:
# Scatter of the selected subset only, with a fixed colour range.
# NOTE(review): 0.634 is presumably the largest fract_hcp_c in the full data
# set, keeping colours comparable with the full plot - confirm.
fig3, ax3 = plot_grain_scatter(oatom_subset, subset_processed, vmin=0.0, vmax=0.634)

In [None]:
# TODO: still need the filtering scripts, i.e. finding the roughly 50/50 cases