In [1]:
import itertools
import numpy as np
from math import sqrt
from numpy import arange
from random import choice, shuffle
from scipy.sparse import csgraph, csr_matrix, triu
from ase.neighborlist import NeighborList, natural_cutoffs
from ase.build.rotate import rotation_matrix_from_points
from ase.io import read, write
from collections import Counter

In [2]:
def rmsd(mol0, mol1):
    """Return the root-mean-square deviation between two structures.

    The positions of ``mol1`` are rotated with the matrix returned by
    ``rotation_matrix_from_points(q1.T, q0.T)`` and then compared with the
    positions of ``mol0`` atom by atom (same index, same atom).

    NOTE(review): no translation/centering is applied in this function;
    presumably both structures are expected to be centred already — confirm
    against the behaviour of ``rotation_matrix_from_points``.

    Parameters
    ----------
    mol0, mol1 : objects exposing ``get_positions()``
        Assumed to return arrays of identical shape (atoms x 3) — TODO confirm.

    Returns
    -------
    float
        ``sqrt(3 * mean(squared coordinate differences))``.
    """
    reference = mol0.get_positions()
    mobile = mol1.get_positions()
    rotation = rotation_matrix_from_points(mobile.T, reference.T)
    # Row-vector convention: rotate every position of mol1 by `rotation`.
    aligned = mobile @ rotation.T
    squared_diff = (aligned - reference) ** 2
    # The mean runs over every coordinate entry; the factor 3 turns it into a
    # per-atom mean squared displacement when there are 3 columns.
    return sqrt(3 * squared_diff.mean())

def find_neighbors(mol):
    """Attach neighbour information to ``mol`` in place.

    Builds a ``NeighborList`` from ``natural_cutoffs(mol, mult=1.2)`` (no skin,
    no self interaction, both directions listed) and stores two attributes on
    ``mol``:

    * ``mol.neighbors`` – list with one ``set`` of neighbour indices per atom,
      in atom order;
    * ``mol.connectivity_matrix`` – whatever
      ``NeighborList.get_connectivity_matrix()`` returns.

    Returns ``None``; the result is the side effect on ``mol``.
    """
    radii = natural_cutoffs(mol, mult=1.2)
    nl = NeighborList(radii, skin=0, self_interaction=False, bothways=True)
    nl.update(mol)

    per_atom = []
    for atom in mol:
        # get_neighbors(...) returns a sequence whose first item holds the
        # neighbour indices; only that item is used here.
        indices = nl.get_neighbors(atom.index)[0]
        per_atom.append(set(indices.tolist()))

    mol.neighbors = per_atom
    mol.connectivity_matrix = nl.get_connectivity_matrix()

def fragcount(mol):
    """Return the number of connected components of ``mol.connectivity_matrix``.

    ``mol.connectivity_matrix`` must already be set (``find_neighbors`` does
    that); it is passed unchanged to ``csgraph.connected_components``.
    """
    n_fragments, _labels = csgraph.connected_components(mol.connectivity_matrix)
    return n_fragments

def adjd(mol0, mol1):
    """Count the adjacency differences between two systems.

    Only the upper triangles (diagonal included) of the two
    ``connectivity_matrix`` attributes are compared, so a bond that is stored
    symmetrically is counted once. The result is the number of matrix
    positions at which the two upper triangles differ.
    """
    upper0 = triu(mol0.connectivity_matrix)
    upper1 = triu(mol1.connectivity_matrix)
    mismatch = upper0 != upper1
    return mismatch.count_nonzero()

class PartSet(set):
    """One class of a :class:`Partition`: a set of atom indices with a label.

    Every item added through :meth:`add` is also recorded in the shared
    ``atom_types`` mapping (item -> ``key``). Only ``add`` maintains that
    mapping; other ``set`` mutators (``remove``, ``update``, ...) leave it
    untouched.

    Parameters
    ----------
    key : hashable
        Label of this class (the key under which the owning Partition stores it).
    atom_types : dict
        Mapping shared with the owning Partition; updated on every ``add``.
    """

    def __init__(self, key, atom_types):
        super().__init__()
        self.key = key
        self.atom_types = atom_types

    def add(self, item):
        """Insert ``item`` and register ``item -> self.key`` in ``atom_types``."""
        super().add(item)
        self.atom_types[item] = self.key


class Partition(dict):
    """Mapping ``label tuple -> PartSet`` plus a reverse ``atom_types`` lookup.

    Attributes
    ----------
    atom_types : dict
        item -> label of the PartSet it was last added to.
    leaf_index : dict
        branch -> number of leaves created so far for that branch.
    """

    def __init__(self):
        super().__init__()
        self.atom_types = {}
        self.leaf_index = {}

    def new_type(self, branch):
        """Create, store and return a new empty PartSet under ``branch``.

        The first leaf requested for ``branch`` is stored under ``branch``
        itself. When a second leaf is requested, the first one is moved to
        ``(*branch, 1)`` and the new one gets ``(*branch, 2)``; later leaves
        get ``(*branch, 3)``, and so on.

        When the first leaf is moved, its ``key`` attribute and the
        ``atom_types`` entries of its current members are updated to the new
        label, so ``self[self.atom_types[i]]`` keeps working and later
        ``add`` calls on that leaf record the right label.
        """
        if branch in self.leaf_index:
            if self.leaf_index[branch] == 1:
                first_key = (*branch, 1)
                first_leaf = self.pop(branch)
                if isinstance(first_leaf, PartSet):
                    # Keep the leaf's label and the reverse map in step with
                    # the dictionary key it is now stored under.
                    first_leaf.key = first_key
                    for item in first_leaf:
                        self.atom_types[item] = first_key
                self[first_key] = first_leaf
            self.leaf_index[branch] += 1
            key = (*branch, self.leaf_index[branch])
        else:
            key = branch
            self.leaf_index[branch] = 1
        self[key] = PartSet(key, self.atom_types)
        return self[key]

    def print(self):
        """Print a blank line, then one numbered line per class.

        Members are shown with 1 added to each stored index.
        """
        print()
        for i, k in enumerate(self, start=1):
            print(f'{i:3}:  {k}   ->   { {member + 1 for member in self[k]} }')

def get_eltypes(mol):
    """Partition the atoms of ``mol`` by atomic number.

    Walks ``mol.get_atomic_numbers()`` in order; the first time an atomic
    number ``z`` is met a new class labelled ``(z,)`` is created, and every
    atom index is added to the class of its atomic number. The partition is
    printed before it is returned.

    Returns
    -------
    Partition
        One class per distinct atomic number, holding 0-based atom indices.
    """
    eltypes = Partition()
    class_of = {}
    for index, number in enumerate(mol.get_atomic_numbers().tolist()):
        bucket = class_of.get(number)
        if bucket is None:
            # First atom of this element: open its class.
            bucket = eltypes.new_type((number,))
            class_of[number] = bucket
        bucket.add(index)
    eltypes.print()
    return eltypes

def get_mnatypes(mol, intypes):
    """Refine an atom partition until neighbour types no longer split any class.

    In each pass every class of ``intypes`` is subdivided so that two atoms
    stay together only if the multisets of their neighbours' current labels
    (``intypes.atom_types`` looked up for each index in ``mol.neighbors``) are
    equal. Passes repeat, feeding each result into the next, until a pass
    reproduces its input; that partition is printed and returned.

    Parameters
    ----------
    mol : object with a ``neighbors`` attribute
        ``mol.neighbors[i]`` is the collection of neighbour indices of atom
        ``i`` (set by ``find_neighbors``).
    intypes : Partition
        Starting partition; every neighbour index must have an entry in its
        ``atom_types`` map.

    Returns
    -------
    Partition
        The first partition that is equal to the one it was refined from.
    """
    while True:
        types = Partition()
        for k in intypes:
            # Neighbour signature -> subclass opened for it within class k.
            # Subclasses are opened in the order their first member is met,
            # which is the order the pairwise comparison would create them in.
            subclasses = {}
            for atom in intypes[k]:
                if atom in types.atom_types:
                    continue
                label_counts = Counter(
                    intypes.atom_types[nbr] for nbr in mol.neighbors[atom]
                )
                # Counters are not hashable; the frozen (label, count) pairs
                # compare equal exactly when the Counters do.
                signature = frozenset(label_counts.items())
                subclass = subclasses.get(signature)
                if subclass is None:
                    subclass = types.new_type(k)
                    subclasses[signature] = subclass
                subclass.add(atom)
        if types == intypes:
            types.print()
            return types
        intypes = types

def get_neighbor_types(mol, atom_types):
    """Group each atom's neighbours by the class they belong to.

    Parameters
    ----------
    mol : iterable of atoms with an ``index`` attribute
        Must also expose ``mol.neighbors[i]`` as a set of neighbour indices.
    atom_types : mapping
        label -> set of atom indices.

    Returns
    -------
    dict
        ``{atom index: {label: neighbours of that atom found in that class}}``.
        Labels with no neighbour in common are left out, so an atom without
        typed neighbours maps to an empty dict.
    """
    neighbor_types = {}
    for atom in mol:
        i = atom.index
        around = mol.neighbors[i]
        grouped = {}
        for k, members in atom_types.items():
            shared = around & members
            if shared:
                grouped[k] = shared
        neighbor_types[i] = grouped
    return neighbor_types

In [3]:
# Load the first two frames of the input file as the two systems to compare.
mol0 = read('input_ab.xyz', index=0)
mol1 = read('input_ab.xyz', index=1)

# Attach neighbour sets and connectivity matrices; the helpers below need them.
find_neighbors(mol0)
find_neighbors(mol1)
print(f"There are {len(mol0)} atoms and {fragcount(mol0)} fragment(s) in system 0")
print(f"There are {len(mol1)} atoms and {fragcount(mol1)} fragment(s) in system 1")
# The message names both compared systems (it previously said "mol" for mol0).
print(f"There are {adjd(mol0, mol1)} adjacency differences between mol0 and mol1")
print(f"The root mean square distance between mol0 and mol1 is {rmsd(mol0, mol1):.4f} ")
#print(triu(matrix0) != triu(matrix1))

# Element classes first, then refine them by neighbour environment.
eltypes0 = get_eltypes(mol0)
mnatypes0 = get_mnatypes(mol0, eltypes0)
#neighbor_mnatypes0 = get_neighbor_types(mol0, mnatypes0)

There are 90 atoms and 1 fragment(s) in system 0
There are 90 atoms and 1 fragment(s) in system 1
There are 0 adjacency differences between mol and mol1
The root mean square distance between mol0 and mol1 is 2.8732 

  1:  (6,)   ->   {1, 2, 3, 4, 5, 6, 7, 8, 14, 16, 18, 19, 20, 21, 23, 28, 35, 36, 42, 44, 46, 47, 48, 49, 51, 56, 63, 64, 70, 72, 74, 75, 76, 77, 79, 84}
  2:  (1,)   ->   {9, 10, 11, 12, 13, 22, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 65, 66, 67, 68, 69, 78, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90}
  3:  (7,)   ->   {71, 73, 43, 45, 15, 17}

  1:  (6, 1, 1)   ->   {1, 3, 5}
  2:  (6, 1, 2)   ->   {2, 4, 6}
  3:  (6, 1, 3)   ->   {48, 20, 76}
  4:  (6, 2)   ->   {35, 7, 63}
  5:  (6, 3, 1)   ->   {64, 8, 36}
  6:  (6, 3, 2)   ->   {51, 23, 79}
  7:  (6, 3, 3)   ->   {56, 28, 84}
  8:  (6, 4)   ->   {70, 42, 14}
  9:  (6, 5)   ->   {16, 72, 44}
 10:  (6, 6)   ->   {74, 18, 46}
 11:  (6, 7, 1)   ->   {19, 75, 47

In [4]:
# Take the first class of minimal size in the refined partition and give each
# of its atoms an element-level class of its own, then refine again.
# NOTE: this cell mutates eltypes0 in place, so it is not idempotent; re-run
# the previous cell before running it a second time.
the_key = min(mnatypes0, key=lambda label: len(mnatypes0[label]))
element_branch = the_key[0:1]
chosen_atoms = mnatypes0[the_key]

# Pass 1: take the chosen atoms out of their element class. This has to be
# finished before any new_type() call, because creating a second leaf for a
# branch moves the existing class to a different key.
for i in chosen_atoms:
    eltypes0[element_branch].remove(i)

# Pass 2: one fresh class per chosen atom under the same element branch.
for i in chosen_atoms:
    eltypes0.new_type(element_branch).add(i)

eltypes0.print()
mnatypes0 = get_mnatypes(mol0, eltypes0)


  1:  (1,)   ->   {9, 10, 11, 12, 13, 22, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 65, 66, 67, 68, 69, 78, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90}
  2:  (7,)   ->   {71, 73, 43, 45, 15, 17}
  3:  (6, 1)   ->   {2, 4, 6, 7, 8, 14, 16, 18, 19, 20, 21, 23, 28, 35, 36, 42, 44, 46, 47, 48, 49, 51, 56, 63, 64, 70, 72, 74, 75, 76, 77, 79, 84}
  4:  (6, 2)   ->   {1}
  5:  (6, 3)   ->   {3}
  6:  (6, 4)   ->   {5}

  1:  (1, 1, 1, 1, 1)   ->   {65, 66, 67}
  2:  (1, 1, 1, 1, 2)   ->   {37, 38, 39}
  3:  (1, 1, 1, 1, 3)   ->   {9, 10, 11}
  4:  (1, 1, 1, 2, 1)   ->   {80, 81, 82}
  5:  (1, 1, 1, 2, 2)   ->   {52, 53, 54}
  6:  (1, 1, 1, 2, 3)   ->   {24, 25, 26}
  7:  (1, 1, 1, 3, 1)   ->   {85, 86, 87}
  8:  (1, 1, 1, 3, 2)   ->   {57, 58, 59}
  9:  (1, 1, 1, 3, 3)   ->   {29, 30, 31}
 10:  (1, 1, 2, 1)   ->   {68, 69}
 11:  (1, 1, 2, 2)   ->   {40, 41}
 12:  (1, 1, 2, 3)   ->   {12, 13}
 13:  (1, 1, 3, 1, 1)   ->   {50}
 14:  (1,