In [None]:
import json
import numpy as np
from ase.io import read as ase_read
from ase.io import write as ase_write
import copy
from scipy.spatial import cKDTree

In [None]:
# Per-O-atom environment records for the jp_dio-orig_min4 structure
# (presumably a list of dicts; see filter_bulk_like_envs for the keys used).
with open("oatom_envs_jp_dio-orig_min4.json", "r") as envs_file:
    oatom_envs = json.load(envs_file)

In [None]:
def filter_bulk_like_envs(all_envs):
    """Select O-atom environments that look bulk-like.

    An environment qualifies when it touches exactly one grain
    (``len(env["grain_fract"]) == 1``) and its ``fract_hcp`` equals 1.0
    to within an absolute tolerance of 1e-12.

    Parameters
    ----------
    all_envs : iterable of dict
        Environment records; each must provide ``"grain_fract"`` (any sized
        container) and ``"fract_hcp"`` (a number).

    Returns
    -------
    list of dict
        The qualifying records in input order (the same dict objects, not copies).
    """
    # np.isclose defaults to rtol=1e-5, which would swamp atol=1e-12 and accept
    # values such as 0.99999 as "1.0"; rtol=0 makes the tolerance purely absolute.
    return [
        env
        for env in all_envs
        if len(env["grain_fract"]) == 1
        and np.isclose(env["fract_hcp"], 1.0, rtol=0.0, atol=1e-12)
    ]


In [None]:
# Keep only the single-grain, fully-HCP environments.
bulk_like_envs = filter_bulk_like_envs(all_envs=oatom_envs)

In [None]:
# Number of bulk-like O environments.
n_bulk_like = len(bulk_like_envs)
n_bulk_like

In [None]:
input_xyz = "./jp_dio-orig_min4.xyz"
min_orig_atoms = ase_read(input_xyz)

# Mapping from index in the O-free ("noO") data -> index in the full structure.
# JSON object keys are always strings, hence the int() casts below.
with open("./noOidx2orig.json", "r") as f:
    noOidx2orig = json.load(f)

# Reverse the mapping: orig index -> noO index
orig2noO = {int(v): int(k) for k, v in noOidx2orig.items()}
# Inverting a mapping with a repeated target index would silently drop entries.
assert len(orig2noO) == len(noOidx2orig), "noOidx2orig is not one-to-one"

# np.load on an .npz returns an NpzFile that keeps the file open; read the arrays
# inside a context manager so the handle is released.
with np.load("./grains_ptm_111025_min4_fixed.npz") as grain_ptm_data:
    noO_grains = grain_ptm_data["grains"]
    noO_ptm_types = grain_ptm_data["ptm_types"]

# PTM type for every atom of the full structure. O atoms get the sentinel -1;
# every other atom is looked up through its noO index.
xyz_ptm_types = [
    -1 if atm.symbol == "O" else int(noO_ptm_types[orig2noO[i]])
    for i, atm in enumerate(min_orig_atoms)
]


In [None]:
# Indices of non-O atoms (presumably Hf) whose PTM type is not 2 (treated as
# HCP, per the variable name). Type -1 is the O-atom sentinel and is excluded too.
non_hcp_hf_idxs = [
    atom_idx
    for atom_idx, ptm_type in enumerate(xyz_ptm_types)
    if ptm_type not in (2, -1)
]

In [None]:
# Number of non-HCP Hf atoms.
n_non_hcp_hf = len(non_hcp_hf_idxs)
n_non_hcp_hf

In [None]:
# Cartesian positions wrapped into the periodic cell, then the rows for the
# non-HCP Hf atoms (in the same order as non_hcp_hf_idxs).
all_positions = min_orig_atoms.get_positions(wrap=True)
non_hcp_hf_positions = all_positions[non_hcp_hf_idxs, :]

In [None]:
# Spot-check: wrapped position of the first non-HCP Hf atom.
non_hcp_hf_positions[0]

In [None]:
cell = min_orig_atoms.get_cell()
# cKDTree's periodic `boxsize` describes an axis-aligned box, so using the cell
# edge lengths only gives correct minimum-image distances for an orthorhombic cell.
# NOTE(review): this also assumes the structure is periodic along all three axes.
assert cell.orthorhombic, "cKDTree boxsize requires an orthorhombic (diagonal) cell"
boxsize = cell.lengths()
boxsize

In [None]:
# Periodic KD-tree over the non-HCP Hf positions; tree indices follow non_hcp_hf_idxs.
tree = cKDTree(data=non_hcp_hf_positions, boxsize=boxsize)

In [None]:
# Nearest non-HCP Hf for every bulk-like O atom as (distance, tree index) pairs,
# in the same order as bulk_like_envs. The tree index points into
# non_hcp_hf_positions; map it through non_hcp_hf_idxs to get an atom index.
# A single vectorised query replaces one scalar query per atom.
bulk_like_idxs = [env["index"] for env in bulk_like_envs]
bulk_like_positions = min_orig_atoms.positions[bulk_like_idxs]
nn_dists, nn_tree_idxs = tree.query(bulk_like_positions)
results = list(zip(nn_dists, nn_tree_idxs))

In [None]:
# Earlier sampled results, used below as a reference for the recomputed distances.
with open("sample_4k_bulk-like_env_idxs.json", "r") as sample_file:
    sample_4k_envs = json.load(sample_file)

In [None]:
# Cross-check against the earlier sampled run. For each of the first 4000
# sampled entries, compare its stored 'min_distance_to_non_hcp' with
# results[i][0] and print every disagreement.
# Note: np.isclose also applies its default rtol, so the tolerance is looser than 1e-8.
# NOTE(review): assumes sample_4k_envs is in the same order as bulk_like_envs — TODO confirm.
for pos_in_sample, sample_env in enumerate(sample_4k_envs[:4000]):
    stored_dist = sample_env['min_distance_to_non_hcp']
    recomputed = results[pos_in_sample]
    if np.isclose(recomputed[0], stored_dist, atol=1e-8):
        continue
    print(sample_env)
    print(f"{recomputed}\n")


In [None]:
# Preview nine (distance, tree index) pairs starting from the second entry.
# NOTE(review): the slice starts at 1, so results[0] is skipped — confirm intended.
preview = results[1:10]
for entry in preview:
    print(entry)

In [None]:
# (atom index, distance to nearest non-HCP Hf) for every bulk-like O atom,
# sorted farthest-first. The distances already in `results` (same order as
# bulk_like_envs) are reused rather than re-querying the tree per atom.
assert len(results) == len(bulk_like_envs), "results is out of sync with bulk_like_envs"
sorted_bulk_like_results = sorted(
    ((env["index"], res[0]) for env, res in zip(bulk_like_envs, results)),
    key=lambda pair: pair[1],
    reverse=True,
)

In [None]:
# The 50 bulk-like O atoms farthest from any non-HCP Hf.
sorted_bulk_like_results[0:50]

In [None]:
# Persist the ranking; each (index, distance) tuple is written as a JSON list.
with open("sorted_bulk_like_results.json", "w") as out_file:
    json.dump(sorted_bulk_like_results, out_file, indent=2)

In [None]:
# Lookup from atom index -> environment record, built here so this cell runs on a
# fresh kernel (nothing earlier in the notebook defines it).
oatom_envs_dict = {env["index"]: env for env in oatom_envs}

# Report grain fractions for the 100 bulk-like O atoms farthest from non-HCP Hf.
for idx, distance in sorted_bulk_like_results[:100]:
    grain_fract = oatom_envs_dict[idx]["grain_fract"]
    print(f"idx: {idx}\ngrain_fract: {grain_fract}\ndistance: {distance}\n")