In [1]:
from tips.io import load_ds
from ase import Atoms, neighborlist
from scipy import sparse
from tips.io import load_ds
import numpy as np
import nglview



In [2]:
def check_mols(mol_sets, atoms=None):
    """Check that the molecule scaffolds are exactly the two expected species.

    Parameters
    ----------
    mol_sets : sequence of index arrays
        Atom indices of each molecule (e.g. the heavy-atom sets built by
        ``atoms2adp``).
    atoms : ase.Atoms, optional
        Structure the indices refer to. If omitted, the notebook-global
        ``atoms`` is used. This keeps the original one-argument call working,
        but that only checks the right structure if the global happens to be
        the one the indices came from, so prefer passing ``atoms`` explicitly.

    Returns
    -------
    bool
        True if the distinct formulas ``str(atoms[idx].symbols)`` over all
        molecules are exactly {"C2O2", "CNC2NC"}.

    Raises
    ------
    NameError
        If ``atoms`` is omitted and no global ``atoms`` exists.
    """
    if atoms is None:
        # Backward-compatible fallback: the original version silently read
        # the module-level `atoms`.
        if "atoms" not in globals():
            raise NameError("check_mols: no `atoms` given and no global `atoms` defined")
        atoms = globals()["atoms"]
    formulas = {str(atoms[mol_set].symbols) for mol_set in mol_sets}
    return formulas == {"C2O2", "CNC2NC"}

def addADP(A, D, P, M, atoms):
    """Return a copy of ``atoms`` with A/D/P marker pseudo-atoms appended (for output).

    Each marker is an extra atom placed at the given coordinates, with an
    element number chosen per category:
    10 for acceptor rows (A), 11 for donor rows (D), 12 for ``P[:, 0]`` and
    13 for ``P[:, 1]`` (as built by ``atoms2adp`` these are the donor and
    acceptor rows of each pair, respectively).

    Parameters
    ----------
    A, D : array-like, shape (n, 4)
        Rows ``[mol_index, x, y, z]``; only the coordinates are used.
    P : array-like, shape (nP, 2, 4)
        Pairs of such rows. May be empty (e.g. ``np.array([])`` when no pair
        was found); empty inputs simply add no markers.
    M : unused
        Kept so the call signature stays ``(A, D, P, M, atoms)``.
    atoms : ase.Atoms
        Structure to decorate; it is copied, not modified.

    Returns
    -------
    ase.Atoms
        The copy with ``len(A) + len(D) + 2 * len(P)`` atoms appended.
    """
    from ase import Atoms
    atoms = atoms.copy()
    # Normalise shapes so that empty inputs (np.array([]) has shape (0,))
    # still support the column slicing below; no-op for well-formed input.
    A = np.asarray(A).reshape(-1, 4)
    D = np.asarray(D).reshape(-1, 4)
    P = np.asarray(P).reshape(-1, 2, 4)
    numbers = np.array([10]*len(A) + [11]*len(D) + [12]*len(P) + [13]*len(P))
    positions = np.concatenate([A[:, 1:], D[:, 1:], P[:, 0, 1:], P[:, 1, 1:]])
    atoms += Atoms(numbers, positions=positions)
    return atoms

def atoms2adp(atoms, build_2nd=False, build_3rd=False, check=True, idx=None):
    """Assign hydrogen-bond acceptor (A), donor (D) and pair (P) molecules.

    Heavy (non-H) atoms are grouped into molecules with a distance-based
    connectivity graph. Each molecule is made whole across the periodic
    boundary and shifted by whole cell lengths so that its centre of mass
    lies in the cell.

    NOTE: ``atoms`` is modified in place (heavy-atom positions are rewritten;
    hydrogens are not moved). Only ``atoms.cell.diagonal()`` is used, so an
    orthogonal cell is assumed.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse (modified in place, see above).
    build_2nd : bool
        If False, stop after tagging the active protons and return them.
    build_3rd : bool
        Only widens the H-N / H-O neighbour cutoff from 4 to 6. No third
        stage is implemented: with ``build_2nd=True`` the function then
        returns None.
    check : bool
        Run sanity checks: the molecule species, exactly 32 active protons
        (assertions), and 32 acceptor molecules (prints a message otherwise).
        NOTE(review): the 32s are hard-coded, presumably for the a32b32
        system; confirm before reusing on other systems.
    idx : any, optional
        Only used to label the "Double Proton Found" message.

    Returns
    -------
    H_act : ndarray
        Indices of active protons (those whose nearest heavy neighbour is an
        O or a two-coordinated N), if ``build_2nd`` is False.
    (A, D, P, mol_sets, h_nn) : tuple
        If ``build_2nd`` is True and ``build_3rd`` is False:
        A : (nA, 4) rows ``[mol_index, x, y, z]`` for unpaired acceptor
            molecules, placed at the mean position of their active atoms;
        D : (nD, 4) rows ``[mol_index, x, y, z]`` for unpaired donors, placed
            at the heavy atom nearest to their active proton;
        P : array of ``[donor_row, acceptor_row]`` pairs removed from D / A;
        mol_sets : list of heavy-atom index arrays, one per molecule;
        h_nn : (nH, 2) array ``[hydrogen_index, nearest_neighbour_index]``.
    None
        If ``build_2nd`` and ``build_3rd`` are both True.
    """
    natoms = len(atoms)
    (heavy,) = np.where(atoms.numbers != 1)
    (hydro,) = np.where(atoms.numbers == 1)
    (nitro,) = np.where(atoms.numbers == 7)
    (oxyge,) = np.where(atoms.numbers == 8)
    O_act = oxyge  # all oxygens are active
    heavy2r = {k: i for i, k in enumerate(heavy)}  # atom index -> row within `heavy`

    h_rc = 6 if build_3rd else 4  # cutoff used for H-N and H-O pairs

    cutoff = {
      ("H", "C"): 2,
      ("H", "N"): h_rc,
      ("H", "O"): h_rc,
      ("C", "C"): 2,
      ("C", "N"): 2,
      ("C", "O"): 2,
    }

    nl_i, nl_j, nl_d = neighborlist.neighbor_list("ijd", atoms, cutoff, self_interaction=False)
    conMat = sparse.dok_matrix((natoms, natoms), dtype=np.int8)
    conMat[nl_i, nl_j] = 1  # symmetric adjacency matrix (adapted from the ASE docs)
    conMat[nl_j, nl_i] = 1

    # -- recognition of the overall scaffold (heavy-atom networks)
    n_mol, mol_assign = sparse.csgraph.connected_components(conMat[heavy, :][:, heavy])

    # -- make each molecule whole, then shift its centre of mass into the cell.
    # Positions and masses do not change inside the loop, so read them once.
    cell = atoms.cell.diagonal()
    heavy_pos = atoms.positions[heavy]
    heavy_mass = atoms.get_masses()[heavy]
    new_positions = atoms.positions.copy()
    for mol_i in range(n_mol):
        in_mol = mol_assign == mol_i
        pos_mol = heavy_pos[in_mol]  # boolean indexing -> copy
        mass_mol = heavy_mass[in_mol]
        # unwrap relative to the molecule's first atom (minimum image)
        pos_mol -= np.rint((pos_mol - pos_mol[:1, :]) / cell[None, :]) * cell[None, :]
        com_mol = mass_mol @ pos_mol / mass_mol.sum()
        pos_mol -= np.rint(com_mol / cell[None, :] - 0.5) * cell[None, :]
        new_positions[heavy[in_mol]] = pos_mol
    atoms.set_positions(new_positions)
    mol_sets = [heavy[mol_assign == mol_i] for mol_i in range(n_mol)]

    # -- active heavy atoms: every O, plus N bonded to exactly two heavy atoms.
    # ravel (not squeeze) keeps a 1-D array even when there is a single N.
    CN_N = np.asarray(conMat[nitro, :][:, heavy].sum(axis=1)).ravel()
    N_act = nitro[CN_N == 2]
    ALL_act = np.concatenate([O_act, N_act])
    mol_acts = [np.intersect1d(mol_set, ALL_act) for mol_set in mol_sets]
    if check:
        # Species check against the structure passed in, so it does not
        # depend on any notebook-global `atoms`.
        formulas = {str(atoms[mol_set].symbols) for mol_set in mol_sets}
        assert formulas == {"C2O2", "CNC2NC"}, str(mol_sets)

    # -- zeroth pass: tag active protons (nearest neighbour is an active atom)
    sel0 = [np.where(nl_i == h_ia)[0] for h_ia in hydro]
    h_n0a = np.array([nl_j[_sel][np.argmin(nl_d[_sel])] for _sel in sel0])
    h_nn = np.stack([hydro, h_n0a]).T
    H_act = hydro[np.isin(h_n0a, ALL_act)]
    if check:
        assert len(H_act) == 32

    if not build_2nd:
        return H_act

    # -- first pass: tag D(onors), the molecules owning an active proton
    sel1 = [np.where(nl_i == h_ia)[0] for h_ia in H_act]
    h_n1a = np.array([nl_j[_sel][np.argmin(nl_d[_sel])] for _sel in sel1])
    D_mol = np.array([mol_assign[heavy2r[n1a]] for n1a in h_n1a])
    h_act1 = [mol_acts[_di] for _di in D_mol]
    D = np.concatenate([np.array(D_mol)[:, None], atoms.positions[h_n1a]], axis=1)
    A_mol = np.setdiff1d(np.arange(n_mol), D_mol)
    A = np.concatenate([np.array(A_mol)[:, None],
                        [atoms.positions[mol_acts[a]].mean(axis=0) for a in A_mol]],
                       axis=1)
    if check and A_mol.shape[0] != 32:
        print(f'Double Proton Found @ {idx}')

    # -- second pass: for each donor proton, its nearest neighbour outside the
    #    donor molecule's active atoms; if that atom belongs to an acceptor,
    #    record the donor as a candidate for that acceptor
    sel2 = [_sel[~np.isin(nl_j[_sel], _act)] for _sel, _act in zip(sel1, h_act1)]
    h_n2a = np.array([nl_j[_sel][np.argmin(nl_d[_sel])] for _sel in sel2])
    h_n2d = np.array([nl_d[_sel][np.argmin(nl_d[_sel])] for _sel in sel2])
    ANL = {a: [] for a in A_mol}
    for di, n2a, n2d in zip(D_mol, h_n2a, h_n2d):
        a = mol_assign[heavy2r[n2a]]
        if a in A_mol:
            ANL[a].append((di, n2d))

    # -- pair each acceptor with its closest candidate donor; paired rows are
    #    moved out of A / D into P. (Iteration is over the original A_mol;
    #    rebinding A_mol inside the loop does not affect it.)
    P = []
    for a in A_mol:
        if ANL[a]:
            d = ANL[a][np.argmin([tmp[1] for tmp in ANL[a]])][0]
            P.append(np.concatenate([D[D_mol == d], A[A_mol == a]]))
            D = D[D_mol != d]
            A = A[A_mol != a]
            D_mol = D_mol[D_mol != d]
            A_mol = A_mol[A_mol != a]
    P = np.array(P)
    if not build_3rd:
        return A, D, P, mol_sets, h_nn
    # TODO: third pass not implemented; build_3rd currently only widens h_rc.
    return None

def trackAD(A_set, D_set, P_set,
            A_prev, D_prev, P_prev,
            A_idx_prev, D_idx_prev, P_idx_prev,
            M, h_nn):
    """Match the current frame's A/D/P assignment against the previous frame.

    Parameters
    ----------
    A_set, D_set : array-like
        Molecule indices of the current acceptors / donors (e.g. ``A[:, 0]``).
    P_set : list of set
        Current pairs, each a set of two molecule indices.
    A_prev, D_prev, P_prev : ndarray or None
        The previous frame's full A, D, P arrays as returned by ``atoms2adp``
        (column 0 holds the molecule index). ``A_prev is None`` marks the
        first frame.
    A_idx_prev, D_idx_prev, P_idx_prev : ndarray or None
        Index arrays returned by the previous call. Row i of ``P_idx_prev``
        belongs to the previous call's ``P_set[i]``.
    M : list of ndarray
        Heavy-atom indices per molecule; used to print an atom selection
        (``@i,j,...``) for transitions.
    h_nn : ndarray, shape (nH, 2)
        ``[hydrogen_index, nearest_neighbour_index]`` rows; hydrogens bonded
        to the drawn atoms are added to the printed selection.

    Returns
    -------
    A_idx, D_idx, P_idx : ndarray
        On the first frame: consecutive labels. Afterwards, unchanged pairs
        inherit their previous ``P_idx`` row; all other entries (and, for now,
        all of A_idx / D_idx -- TODO) are 0.

    Raises
    ------
    NotImplementedError
        If a new pair overlaps more than one previous pair.
    """
    if A_prev is None:
        # make initial assignment
        A_idx = np.arange(len(A_set), dtype=int)
        D_idx = np.arange(len(D_set), dtype=int)
        P_idx = np.array([np.arange(len(A_set), len(A_set) + len(P_set), dtype=int),
                          np.arange(len(D_set), len(D_set) + len(P_set), dtype=int)]).T
        return A_idx, D_idx, P_idx

    # convert the previous A, D, P arrays to molecule indices / sets of them
    A_prev, D_prev, P_prev = A_prev[:, 0], D_prev[:, 0], [set(p) for p in P_prev[:, :, 0]]

    A_idx = np.zeros(len(A_set), dtype=int)
    D_idx = np.zeros(len(D_set), dtype=int)
    P_idx = np.zeros([len(P_set), 2], dtype=int)
    for i, p in enumerate(P_set):
        if p in P_prev:
            # Unchanged pair: inherit the indices of the *matching* previous
            # pair (its row in P_idx_prev, not row i of the current frame).
            P_idx[i, :] = P_idx_prev[P_prev.index(p), :]
            continue
        intersection = [bool(p.intersection(_p)) for _p in P_prev]
        n_overlap = sum(intersection)
        if n_overlap == 1:
            # XP -> PY: one molecule of the old pair now pairs with a new one
            p_prev = P_prev[np.where(intersection)[0][0]]
            X = p - p_prev
            Y = p_prev - p
            print(f'{X},{p_prev} -> {p}, {Y}', end='\t\t')
            to_draw = np.concatenate([*[M[int(pp)] for pp in X],
                                      *[M[int(pp)] for pp in p_prev]])
            # explicit int dtype: an empty list would otherwise promote the
            # selection to float and print "32.0"-style indices
            h_ids = np.array([h_id for h_id, nn in h_nn if nn in to_draw], dtype=int)
            to_draw = np.concatenate([to_draw, h_ids])
            print('@'+','.join(map(str, to_draw)))
        elif n_overlap == 0:
            # AD -> P: a new pair formed from previously unpaired molecules
            print(f'+{p}')
        else:
            print(A_prev, D_prev, P_prev)
            print(A_set, D_set, P_set)
            raise NotImplementedError(f'Cannot resolve the HB pair {p}')
    for p in P_prev:
        if not any(p.intersection(_p) for _p in P_set):
            # P -> AD: a previous pair dissolved completely
            print(f'-{p}')
    return A_idx, D_idx, P_idx

In [3]:
def smart_wrap(atoms):
    """Wrap ``atoms`` into the cell while keeping every molecule whole (in place).

    Atoms are grouped into molecules by a distance-based connectivity graph
    using the element-pair cutoffs below. Each molecule is unwrapped relative
    to its first atom (minimum image) and then shifted by whole cell lengths
    so its centre of mass lies in the cell.

    Parameters
    ----------
    atoms : ase.Atoms
        Modified in place; the function returns None.

    Raises
    ------
    AssertionError
        If the cell is not orthogonal (all angles 90 degrees within float
        tolerance).
    """
    natoms = len(atoms)

    # np.allclose rather than exact ==, so tiny round-off in the computed
    # angles does not reject an orthogonal cell.
    assert np.allclose(atoms.cell.angles(), 90.), "Only orthogonal cells allowed."
    cell = atoms.cell.diagonal()
    atoms.wrap()

    cutoff = {
        ("H", "C"): 1.74,
        ("H", "N"): 1.9,
        ("H", "O"): 1.9,
        ("C", "C"): 2,
        ("C", "N"): 2,
        ("C", "O"): 2,
    }

    nl_i, nl_j = neighborlist.neighbor_list("ij", atoms, cutoff, self_interaction=False)
    conMat = sparse.dok_matrix((natoms, natoms), dtype=np.int8)
    conMat[nl_i, nl_j] = 1  # symmetric adjacency matrix (adapted from the ASE docs)
    conMat[nl_j, nl_i] = 1
    n_mol, mol_assign = sparse.csgraph.connected_components(conMat)

    masses = atoms.get_masses()  # constant during the loop; read once
    for mol_i in range(n_mol):
        in_mol = mol_assign == mol_i
        pos_mol = atoms.positions[in_mol]  # boolean indexing -> copy
        mass_mol = masses[in_mol]
        pos_mol -= np.rint((pos_mol - pos_mol[:1, :]) / cell[None, :]) * cell[None, :]
        com_mol = mass_mol @ pos_mol / mass_mol.sum()
        pos_mol -= np.rint(com_mol / cell[None, :] - 0.5) * cell[None, :]
        atoms.positions[in_mol] = pos_mol
        
def unwrap(atoms_prev, atoms_next):
    """Move each atom of ``atoms_next`` to the periodic image nearest its
    position in ``atoms_prev``.

    Only the diagonal of ``atoms_next.cell`` is used, so an orthogonal cell
    is assumed. ``atoms_next.positions`` is updated in place; returns None.
    """
    box = atoms_next.cell.diagonal()[None, :]
    image_shift = np.rint((atoms_next.positions - atoms_prev.positions) / box)
    atoms_next.positions -= image_shift * box

In [4]:
from ase.io import read, write

TRACK_TRAJ_PATH = '../exp/scan/prod/gen8/nvt-340k-5ns-0/a32b32i0-r1.08/asemd.traj'
traj = read(TRACK_TRAJ_PATH, index='2500:4000:10')

# Rigidly translate every frame so that the point `shift_from` lands on
# `shift_to`. NOTE(review): the origin of `shift_from` is not recorded here;
# document it (presumably a region of interest to centre in the view).
shift_from = np.array([15.30250755, -0.32583541, 12.2960064 ])
shift_to = np.array([9, 9, 9])
for atoms in traj:
    atoms.positions -= shift_from
    atoms.positions += shift_to

# Frame-to-frame H-bond tracking (currently disabled):
# A_prev, D_prev, P_prev, M_prev, A_idx_prev, D_idx_prev, P_idx_prev = [None]*7
# for i, atoms in enumerate(traj):
#     assert (M_prev is None) or (M==M_prev), f"{M}!={M_prev}"
#     A, D, P, M, h_nn = atoms2adp(atoms, check=False, build_2nd=True)
#
#     A_set, D_set, P_set = A[:,0], D[:,0], [set(p) for p in P[:,:,0]]
#     A_idx, D_idx, P_idx = trackAD(A_set, D_set, P_set,
#                                   A_prev, D_prev, P_prev,
#                                   A_idx_prev, D_idx_prev, P_idx_prev,
#                                   M, h_nn)
#
#     A_prev, D_prev, P_prev, M_prev = A, D, P, M
#     A_idx_prev, D_idx_prev, P_idx_prev = A_idx, D_idx, P_idx
#     print(f'{i}-- A{A_idx_prev.shape[0]}, D{D_idx_prev.shape[0]}')

In [7]:
# Re-wrap every frame with molecules kept whole. smart_wrap works in place
# and returns None, so use a plain loop rather than a list comprehension
# built only for its side effects.
for frame in traj:
    smart_wrap(frame)

v = nglview.show_asetraj(traj)
v.add_unitcell()
# NOTE(review): _remote_call is a private nglview API and may change between versions.
v._remote_call("setSize", target="Widget",
               args=['800px', '800px'])
v.clear_representations()
# Atom indices to highlight, in the same "@i,j,..." form that trackAD prints.
sel = "@32,33,34,35,192,193,194,195,616,617,618,619,620,621,36,37,38,39,196,197,198,199,622,623,624,625,626,627"
v.add_spacefill(
    sel,
    radius_type='vdw', radius_scale=0.5,
    roughness=1, metalness=0,
)
# Context layer: all atoms, drawn nearly transparent.
v.add_spacefill(
    radius_type='vdw', radius_scale=0.5,
    roughness=1, metalness=0, opacity=0.1
)
v.parameters = dict(clipDist=0, sampleLevel=2)
v

NGLWidget(max_frame=149)

In [12]:
# Export the current view of the NGL widget `v` as an image (nglview download_image).
v.download_image()

In [621]:
# Load the trajectory as a tips dataset; index='::1000' presumably keeps every
# 1000th frame (confirm against tips.io.load_ds).
DS_TRAJ_PATH = '../trajs/al-adam1-sin-run1-gen27/nvt-340k-1ns/m0i32-r1.16/asemd.traj'
ds = load_ds(DS_TRAJ_PATH, fmt='asetraj', index='::1000')

In [20]:
from ase.io import read

# Alternative inputs tried earlier:
#   traj = ds.convert(fmt='ase')[:100]
#   traj = read('defect.traj', '::1000')
#   traj = ds[::200].convert(fmt='ase')
#   traj = read('../a_defect.traj', index='10000:20000:1000')
traj = read('../exp/scan/prod/gen8/nvt-340k-5ns-0/a32b32i0-r1.08/asemd.traj', index='1000::1000')


def _label_frame(atoms):
    """Run the A/D/P assignment on one frame and return the marker-decorated copy."""
    # atoms2adp(build_2nd=True) returns (A, D, P, mol_sets, h_nn), but addADP
    # takes (A, D, P, M, atoms). Splatting the whole tuple plus `atoms` passes
    # six arguments and raises TypeError, so drop h_nn explicitly.
    A, D, P, M, _h_nn = atoms2adp(atoms, check=False, build_2nd=True)
    # atoms2adp rewrote atoms' positions in place, so the markers and the
    # decorated structure share the same coordinates.
    return addADP(A, D, P, M, atoms)


traj_labelled = [_label_frame(atoms) for atoms in traj]

In [22]:
from ase.io import write

# NOTE(review): absolute, user-specific output path; consider a configurable
# output directory so the notebook runs on other machines.
LABELLED_XYZ_PATH = '/cephyr/users/yunqi/Alvis/labelled.xyz'
write(LABELLED_XYZ_PATH, traj_labelled)