In [1]:
import itertools
import numpy as np
from math import sqrt
from numpy import arange
from bisect import insort
from collections import defaultdict
from scipy.sparse import csgraph, csr_matrix, triu
from ase.neighborlist import NeighborList, natural_cutoffs
from ase.build.rotate import rotation_matrix_from_points
from ase.io import read, write
from collections import Counter

In [2]:
def rmsd(mol0, mol1):
    q0 = mol0.get_positions()
    q1 = mol1.get_positions()
    q0 = q0 - np.mean(q0, axis=0)
    q1 = q1 - np.mean(q1, axis=0)
    R = rotation_matrix_from_points(q1.T, q0.T)
    q1 = np.dot(q1, R.T)
    return sqrt(3*np.mean((q1 - q0)**2))

def find_neighbors(mol):
    cutoffs = natural_cutoffs(mol, mult=1.2)
    neighbor_list = NeighborList(cutoffs, skin=0, self_interaction=False, bothways=True)
    neighbor_list.update(mol)
    mol.neighbors = [set(neighbor_list.get_neighbors(atom.index)[0].tolist()) for atom in mol]
    mol.connectivity_matrix = neighbor_list.get_connectivity_matrix()

def fragcount(mol):
    return csgraph.connected_components(mol.connectivity_matrix)[0]

def adjd(mol0, mol1):
    return csr_matrix.count_nonzero(triu(mol0.connectivity_matrix) != triu(mol1.connectivity_matrix))

class PartSet(set):
    def add(self, item):
        super().add(item)
        self.atom_types[item] = self.type_key

class Partition(dict):
    def __init__(self):
        super().__init__()
        self.atom_types = {}
        self.leaf_index = {}
        self.tree = defaultdict(list)
    def __le__(self, other):
        for a in self.values():
            for b in other.values():
                if a <= b:
                    break
            else:
                return False
        return True 
    def new_leaf(self, branch):
        if branch in self.leaf_index:
            if self.leaf_index[branch] == 0:
                self[(*branch, 0)] = self.pop(branch)
            self.leaf_index[branch] += 1
            key = (*branch, self.leaf_index[branch])
        else:
            key = branch
            self.tree[branch] = []
            self.leaf_index[branch] = 0
        self[key] = PartSet()
        self[key].type_key = key
        self[key].atom_types = self.atom_types
        self.tree[branch].append(self[key])
        return self[key]
    def print(self):
        print()
        for i, k in enumerate(self, start=1):
            print(f'{i:3}:  {k}   ->   { {i+1 for i in self[k]} }')

def get_eltypes(mol):
    types = {}
    eltypes = Partition()
    for i, elsym in enumerate(mol.get_chemical_symbols()):
        if elsym not in types:
            types[elsym] = eltypes.new_leaf(elsym)
        types[elsym].add(i)
    return eltypes

def get_mnatypes(mol, typepart):
    while True:
        subtypepart = Partition()
        for h in typepart:
            for i in typepart[h]:
                neighbor_types = Counter([typepart.atom_types[k] for k in mol.neighbors[i]])
                for subtype in subtypepart.tree[h]:
                    if neighbor_types == subtype.neighbor_types:
                        subtype.add(i)
                        break
                else:
                    newtype = subtypepart.new_leaf(h)
                    newtype.neighbor_types = neighbor_types
                    newtype.add(i)
        if typepart == subtypepart:
            return typepart
        typepart = subtypepart

def break_type_symmetry(typepart):
    typepartlist = []
    for type_i in typepart:
        if len(typepart[type_i]) >= 2:
            subtypepart = Partition()
            for type_j in typepart:
                if type_i == type_j:
                    for i in typepart[type_i]:
                        subtypepart.new_leaf(type_i).add(i)
                else:
                    subtype = subtypepart.new_leaf(type_j)
                    for i in typepart[type_j]:
                        subtype.add(i)
            subtypepart = get_mnatypes(mol0, subtypepart)
            insort(typepartlist, subtypepart, key=len)
    distypetartlist = []
    for typepart in typepartlist:
        for distypepart in distypetartlist:
            if typepart <= distypepart:
                break
        else:
            distypetartlist.append(typepart)
    return distypetartlist

In [3]:
mol0 = read('input_ab.xyz', index=0)
mol1 = read('input_ab.xyz', index=1)

find_neighbors(mol0)
find_neighbors(mol1)
print(f"There are {len(mol0)} atoms and {fragcount(mol0)} fragment(s) in system 0")
print(f"There are {len(mol1)} atoms and {fragcount(mol1)} fragment(s) in system 1")
print(f"There are {adjd(mol0, mol1)} adjacency differences between mol and mol1")
print(f"The root mean square distance between mol0 and mol1 is {rmsd(mol0, mol1):.4f} ")
#print(triu(matrix0) != triu(matrix1))

eltypes0 = get_eltypes(mol0)
mnatypes0 = get_mnatypes(mol0, eltypes0)
#neighbor_mnatypes0 = get_neighbor_types(mol0, mnatypes0)
eltypes0.print()
mnatypes0.print()

There are 90 atoms and 1 fragment(s) in system 0
There are 90 atoms and 1 fragment(s) in system 1
There are 0 adjacency differences between mol and mol1
The root mean square distance between mol0 and mol1 is 2.8721 

  1:  C   ->   {1, 2, 3, 4, 5, 6, 7, 8, 14, 16, 18, 19, 20, 21, 23, 28, 35, 36, 42, 44, 46, 47, 48, 49, 51, 56, 63, 64, 70, 72, 74, 75, 76, 77, 79, 84}
  2:  H   ->   {9, 10, 11, 12, 13, 22, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 65, 66, 67, 68, 69, 78, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90}
  3:  N   ->   {71, 73, 43, 45, 15, 17}

  1:  ('C', 0, 0)   ->   {1, 3, 5}
  2:  ('C', 0, 1)   ->   {2, 4, 6}
  3:  ('C', 0, 2)   ->   {48, 76, 20}
  4:  ('C', 1)   ->   {35, 7, 63}
  5:  ('C', 2, 0)   ->   {8, 64, 36}
  6:  ('C', 2, 1)   ->   {51, 23, 79}
  7:  ('C', 2, 2)   ->   {56, 84, 28}
  8:  ('C', 3)   ->   {42, 70, 14}
  9:  ('C', 4)   ->   {72, 16, 44}
 10:  ('C', 5)   ->   {74, 18, 46}
 11:  ('C', 6, 0)   -> 

In [4]:
brokenmnatypelist0 = break_type_symmetry(mnatypes0)
for typepart in brokenmnatypelist0:
    typepart.print()


  1:  ('C', 0, 0, 0)   ->   {1}
  2:  ('C', 0, 0, 1)   ->   {3}
  3:  ('C', 0, 0, 2)   ->   {5}
  4:  ('C', 0, 1, 0)   ->   {2}
  5:  ('C', 0, 1, 1)   ->   {4}
  6:  ('C', 0, 1, 2)   ->   {6}
  7:  ('C', 0, 2, 0)   ->   {20}
  8:  ('C', 0, 2, 1)   ->   {76}
  9:  ('C', 0, 2, 2)   ->   {48}
 10:  ('C', 1, 0)   ->   {35}
 11:  ('C', 1, 1)   ->   {63}
 12:  ('C', 1, 2)   ->   {7}
 13:  ('C', 2, 0, 0)   ->   {36}
 14:  ('C', 2, 0, 1)   ->   {64}
 15:  ('C', 2, 0, 2)   ->   {8}
 16:  ('C', 2, 1, 0)   ->   {51}
 17:  ('C', 2, 1, 1)   ->   {79}
 18:  ('C', 2, 1, 2)   ->   {23}
 19:  ('C', 2, 2, 0)   ->   {84}
 20:  ('C', 2, 2, 1)   ->   {28}
 21:  ('C', 2, 2, 2)   ->   {56}
 22:  ('C', 3, 0)   ->   {42}
 23:  ('C', 3, 1)   ->   {14}
 24:  ('C', 3, 2)   ->   {70}
 25:  ('C', 4, 0)   ->   {44}
 26:  ('C', 4, 1)   ->   {16}
 27:  ('C', 4, 2)   ->   {72}
 28:  ('C', 5, 0)   ->   {74}
 29:  ('C', 5, 1)   ->   {46}
 30:  ('C', 5, 2)   ->   {18}
 31:  ('C', 6, 0, 0)   ->   {75}
 32:  ('C', 6, 0, 1)

In [5]:
for typepart_i in brokenmnatypelist0:
    print('+' * 80)
    for typepart_j in break_type_symmetry(typepart_i):
        typepart_i.print()

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

  1:  ('C', 0, 0, 0)   ->   {1}
  2:  ('C', 0, 0, 1)   ->   {3}
  3:  ('C', 0, 0, 2)   ->   {5}
  4:  ('C', 0, 1, 0)   ->   {2}
  5:  ('C', 0, 1, 1)   ->   {4}
  6:  ('C', 0, 1, 2)   ->   {6}
  7:  ('C', 0, 2, 0)   ->   {20}
  8:  ('C', 0, 2, 1)   ->   {76}
  9:  ('C', 0, 2, 2)   ->   {48}
 10:  ('C', 1, 0)   ->   {35}
 11:  ('C', 1, 1)   ->   {63}
 12:  ('C', 1, 2)   ->   {7}
 13:  ('C', 2, 0, 0)   ->   {36}
 14:  ('C', 2, 0, 1)   ->   {64}
 15:  ('C', 2, 0, 2)   ->   {8}
 16:  ('C', 2, 1, 0)   ->   {51}
 17:  ('C', 2, 1, 1)   ->   {79}
 18:  ('C', 2, 1, 2)   ->   {23}
 19:  ('C', 2, 2, 0)   ->   {84}
 20:  ('C', 2, 2, 1)   ->   {28}
 21:  ('C', 2, 2, 2)   ->   {56}
 22:  ('C', 3, 0)   ->   {42}
 23:  ('C', 3, 1)   ->   {14}
 24:  ('C', 3, 2)   ->   {70}
 25:  ('C', 4, 0)   ->   {44}
 26:  ('C', 4, 1)   ->   {16}
 27:  ('C', 4, 2)   ->   {72}
 28:  ('C', 5, 0)   ->   {74}
 29:  ('C', 5, 1)   ->   {46}
 3