In [1]:
import itertools
import numpy as np
from math import sqrt
from numpy import arange
from bisect import insort
from collections import defaultdict
from scipy.sparse import csgraph, csr_matrix, triu
from ase.neighborlist import NeighborList, natural_cutoffs
from ase.build.rotate import rotation_matrix_from_points
from ase.io import read, write
from collections import Counter

In [14]:
def rmsd(mol0, mol1):
    """Root-mean-square distance between two systems after superposition.

    Both position arrays are shifted so that their mean sits at the origin,
    then ``mol1`` is rotated with the matrix returned by ASE's
    ``rotation_matrix_from_points`` before the deviation is measured.

    Parameters
    ----------
    mol0, mol1 : objects exposing ``get_positions()``
        Atoms must be listed in the same order in both systems.
        Assumes positions are (N, 3) arrays -- TODO confirm for other inputs.

    Returns
    -------
    float
        ``sqrt(3 * mean((aligned1 - centred0)**2))``; for (N, 3) arrays this is
        the square root of the mean squared per-atom displacement.
    """
    reference = mol0.get_positions()
    moving = mol1.get_positions()

    # Remove the translational part by centring each cloud on its own mean.
    reference = reference - np.mean(reference, axis=0)
    moving = moving - np.mean(moving, axis=0)

    # Rotation that maps the moving cloud onto the reference one.
    rotation = rotation_matrix_from_points(moving.T, reference.T)
    aligned = np.dot(moving, rotation.T)

    deviation = aligned - reference
    return sqrt(3*np.mean(deviation**2))

def find_neighbors(mol):
    """Attach neighbour information to ``mol`` in place.

    A neighbour list is built from ASE's natural cutoffs scaled by 1.2, with
    no skin, no self interaction and both directions stored. Two attributes
    are then set on ``mol``:

    * ``mol.neighbors`` -- list with one ``set`` of neighbour indices per atom.
    * ``mol.connectivity_matrix`` -- the neighbour list's connectivity matrix.
    """
    radii = natural_cutoffs(mol, mult=1.2)
    nl = NeighborList(radii, skin=0, self_interaction=False, bothways=True)
    nl.update(mol)

    per_atom = []
    for atom in mol:
        # get_neighbors returns (indices, offsets); only the indices are kept.
        indices = nl.get_neighbors(atom.index)[0]
        per_atom.append(set(indices.tolist()))

    mol.neighbors = per_atom
    mol.connectivity_matrix = nl.get_connectivity_matrix()

def fragcount(mol):
    """Return the number of connected components of ``mol.connectivity_matrix``."""
    n_fragments, _labels = csgraph.connected_components(mol.connectivity_matrix)
    return n_fragments

def adjd(mol0, mol1):
    """Count upper-triangle positions where the two connectivity matrices differ.

    Both arguments must carry a ``connectivity_matrix`` attribute of the same
    shape. Only the upper triangle (diagonal included) is compared.
    """
    upper0 = triu(mol0.connectivity_matrix)
    upper1 = triu(mol1.connectivity_matrix)
    mismatch = upper0 != upper1
    return mismatch.count_nonzero()

class PartSet(set):
    """A set of items that also records, for each added item, which part it is in.

    The creator must assign two attributes before ``add`` is used:
    ``trace`` (the label of this part) and ``atom_types`` (a mapping that
    receives ``item -> trace`` on every ``add``).
    """

    def add(self, item):
        """Insert ``item`` into the set and label it with this part's trace."""
        set.add(self, item)
        label = self.trace
        self.atom_types[item] = label

class Partition(dict):
    """Dictionary of parts keyed by a trace tuple.

    Each key is built by ``new_part`` as ``(*parent_trace, n)`` where ``n``
    counts the parts created so far under that parent trace. All parts
    created by ``new_part`` share the partition's ``atom_types`` mapping.
    """

    def __init__(self):
        super().__init__()
        # item -> key of the part it was most recently added to.
        self.atom_types = {}
        # parent trace -> how many parts have been created under it.
        self.trace_index = defaultdict(int)

    def __le__(self, other):
        """True when every part of ``self`` is a subset of some part of ``other``."""
        return all(
            any(mine <= theirs for theirs in other.values())
            for mine in self.values()
        )

    def new_part(self, trace=()):
        """Create, store and return a new empty part under ``trace``."""
        serial = self.trace_index[trace] + 1
        self.trace_index[trace] = serial
        key = (*trace, serial)
        part = PartSet()
        part.trace = key
        part.atom_types = self.atom_types
        self[key] = part
        return part

    def counter(self, neighbors):
        """Return a hashable multiset of the types of the given ``neighbors``.

        The result is a frozenset of ``(type, count)`` pairs.
        """
        tally = Counter(self.atom_types[n] for n in neighbors)
        return frozenset(tally.items())

    def print(self):
        """Print a blank line, then one numbered line per part: key and members."""
        print()
        for row, key in enumerate(self, start=1):
            # Built by comprehension (not set(...)) so the printed order is unchanged.
            members = {atom for atom in self[key]}
            print(f'{row:3}:  {key}   ->   {members}')

def get_eltypes(mol):
    """Partition atom indices of ``mol`` by chemical symbol.

    Parts are created in order of first appearance of each symbol in
    ``mol.get_chemical_symbols()``. Returns the resulting ``Partition``.
    """
    partition = Partition()
    part_of_symbol = {}
    for index, symbol in enumerate(mol.get_chemical_symbols()):
        part = part_of_symbol.get(symbol)
        if part is None:
            # First atom of this element: open a new top-level part for it.
            part = part_of_symbol[symbol] = partition.new_part()
        part.add(index)
    return partition

def get_mnatypes(mol, types):
    """Refine ``types`` with ``levelup_mnatypes`` until a level adds no new split.

    A level is considered to add nothing when every key it produces ends in 1.
    The partition returned is the one that was fed into that final level,
    not the level's own output.
    """
    refined = levelup_mnatypes(mol, types)
    while any(key[-1] != 1 for key in refined):
        types = refined
        refined = levelup_mnatypes(mol, types)
    return types

def levelup_mnatypes(mol, types):
    """Split every part of ``types`` by the typed neighbourhood of its members.

    Two members of the same part stay together only if
    ``types.counter(mol.neighbors[i])`` is equal for both. The new parts are
    keyed under the trace of the part they came from. Requires
    ``mol.neighbors`` to be indexable by atom index.
    """
    refined = Partition()
    for parent_key, members in types.items():
        part_of_env = {}
        for atom in members:
            env = types.counter(mol.neighbors[atom])
            part = part_of_env.get(env)
            if part is None:
                # Unseen neighbourhood within this parent: open a child part.
                part = part_of_env[env] = refined.new_part(parent_key)
            part.add(atom)
    return refined

def get_splittypes(types, trace):
    """Copy ``types`` into a new Partition, exploding one part into singletons.

    The part whose key equals ``trace`` is replaced by one single-member part
    per member; every other part is copied whole under its own key as parent.
    """
    result = Partition()
    for key, members in types.items():
        if key == trace:
            # Each member of the selected part gets a part of its own.
            for atom in members:
                result.new_part(trace).add(atom)
            continue
        clone = result.new_part(key)
        for atom in members:
            clone.add(atom)
    return result

def break_type_symmetry(symtypes, mol=None):
    """Build the distinct refinements obtained by splitting each multi-member part.

    For every part of ``symtypes`` holding more than one member, that part is
    exploded into singletons (``get_splittypes``) and the result is refined
    with ``get_mnatypes``. The candidates are then ordered by number of parts
    and a candidate is dropped when it is ``<=`` a candidate already kept.

    Parameters
    ----------
    symtypes : Partition
        The partition whose multi-member parts are to be split.
    mol : optional
        The system whose ``neighbors`` drive the refinement. When omitted the
        notebook-level global ``mol0`` is used, which is what this function
        always did before the parameter existed; pass it explicitly to avoid
        relying on hidden notebook state.

    Returns
    -------
    list of Partition
        Kept candidates, in ascending order of part count.
    """
    if mol is None:
        # Backward-compatible fallback for callers that pass only ``symtypes``.
        mol = mol0

    candidates = [
        get_mnatypes(mol, get_splittypes(symtypes, trace))
        for trace in symtypes
        if len(symtypes[trace]) > 1
    ]
    # A stable sort keeps equal-length candidates in creation order, the same
    # order repeated right-insertion gave, without needing bisect's ``key``
    # argument (Python 3.10+).
    candidates.sort(key=len)

    kept = []
    for types in candidates:
        if not any(types <= other for other in kept):
            kept.append(types)
    return kept

In [15]:
# Load images 0 and 1 of the same file as the two systems to compare.
mol0 = read('input_ab.xyz', index=0)
mol1 = read('input_ab.xyz', index=1)

# Attach neighbour sets and the connectivity matrix to both systems.
find_neighbors(mol0)
find_neighbors(mol1)
print(f"There are {len(mol0)} atoms and {fragcount(mol0)} fragment(s) in system 0")
print(f"There are {len(mol1)} atoms and {fragcount(mol1)} fragment(s) in system 1")
print(f"There are {adjd(mol0, mol1)} adjacency differences between mol0 and mol1")
print(f"The root mean square distance between mol0 and mol1 is {rmsd(mol0, mol1):.4f} ")

# Element-level partition of system 0, then its neighbourhood-based refinement.
eltypes0 = get_eltypes(mol0)
mnatypes0 = get_mnatypes(mol0, eltypes0)
eltypes0.print()
mnatypes0.print()

There are 90 atoms and 1 fragment(s) in system 0
There are 90 atoms and 1 fragment(s) in system 1
There are 0 adjacency differences between mol and mol1
The root mean square distance between mol0 and mol1 is 2.8721 

  1:  (1,)   ->   {0, 1, 2, 3, 4, 5, 6, 7, 13, 15, 17, 18, 19, 20, 22, 27, 34, 35, 41, 43, 45, 46, 47, 48, 50, 55, 62, 63, 69, 71, 73, 74, 75, 76, 78, 83}
  2:  (2,)   ->   {8, 9, 10, 11, 12, 21, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 36, 37, 38, 39, 40, 49, 51, 52, 53, 54, 56, 57, 58, 59, 60, 61, 64, 65, 66, 67, 68, 77, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89}
  3:  (3,)   ->   {70, 72, 42, 44, 14, 16}

  1:  (1, 1, 1, 1)   ->   {0, 2, 4}
  2:  (1, 1, 2, 1)   ->   {1, 3, 5}
  3:  (1, 1, 3, 1)   ->   {19, 75, 47}
  4:  (1, 2, 1, 1)   ->   {34, 62, 6}
  5:  (1, 3, 1, 1)   ->   {35, 63, 7}
  6:  (1, 3, 2, 1)   ->   {50, 78, 22}
  7:  (1, 3, 3, 1)   ->   {27, 83, 55}
  8:  (1, 4, 1, 1)   ->   {41, 13, 69}
  9:  (1, 5, 1, 1)   ->   {43, 15, 71}
 10:  (1, 6, 1, 1)   ->   {73, 4

In [16]:
# Split each multi-member type of mnatypes0, re-refine, and keep the distinct
# results; then print every resulting partition.
brokenmnatypelist0 = break_type_symmetry(mnatypes0)
for types in brokenmnatypelist0:
    types.print()


  1:  (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {0}
  2:  (1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1)   ->   {2}
  3:  (1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1)   ->   {4}
  4:  (1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {1}
  5:  (1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1)   ->   {3}
  6:  (1, 1, 2, 1, 1, 3, 1, 1, 1, 1, 1, 1)   ->   {5}
  7:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {19}
  8:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1)   ->   {75}
  9:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 3, 1, 1)   ->   {47}
 10:  (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {34}
 11:  (1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1)   ->   {62}
 12:  (1, 2, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1)   ->   {6}
 13:  (1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {35}
 14:  (1, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1)   ->   {63}
 15:  (1, 3, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1)   ->   {7}
 16:  (1, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {50}
 17:  (1, 3, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1)   ->   {78}
 18:  (1, 3, 2, 1, 1, 1, 1, 1, 1, 1, 3, 1)   ->   {22}
 19:  (1, 3, 3, 1

In [17]:
# Apply the symmetry-breaking step once more to each partition from the
# previous cell; a dashed rule separates the output for each input partition.
for types_i in brokenmnatypelist0:
    print('-' * 60)
    for types_j in break_type_symmetry(types_i):
        types_j.print()

------------------------------------------------------------

  1:  (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {0}
  2:  (1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {2}
  3:  (1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {4}
  4:  (1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {1}
  5:  (1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1)   ->   {3}
  6:  (1, 1, 2, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1)   ->   {5}
  7:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {19}
  8:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1)   ->   {75}
  9:  (1, 1, 3, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1)   ->   {47}
 10:  (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {34}
 11:  (1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1)   ->   {62}
 12:  (1, 2, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1)   ->   {6}
 13:  (1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {35}
 14:  (1, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1)   ->   {63}
 15:  (1, 3, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1)   ->   {7}
 16:  (1, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)   ->   {50}
 17:  (1, 3, 2, 1,