In [1]:
import itertools
import numpy as np
from math import sqrt
from numpy import arange
from itertools import permutations
from collections import defaultdict
from scipy.sparse import csgraph, csr_matrix, triu
from ase.neighborlist import NeighborList, natural_cutoffs
from ase.build.rotate import rotation_matrix_from_points
from ase.io import read, write
from collections import Counter

In [2]:
def rmsd(mol0, mol1):
    """Root-mean-square distance between two structures after rigid alignment.

    Both position arrays are shifted so that their mean position is at the
    origin, then ``mol1`` is rotated with the matrix returned by
    ``rotation_matrix_from_points`` before the deviation is measured.

    Parameters
    ----------
    mol0, mol1 : objects exposing ``get_positions()``
        The two position arrays must have the same shape for the
        subtraction to work.

    Returns
    -------
    float
        ``sqrt(3 * mean((aligned1 - centered0) ** 2))``; the mean runs over
        every coordinate, so the factor 3 turns it into a per-atom mean
        squared distance for 3-column position arrays.
    """
    reference = mol0.get_positions()
    moving = mol1.get_positions()
    # Remove the translational part by centering each set on its mean.
    reference = reference - reference.mean(axis=0)
    moving = moving - moving.mean(axis=0)
    # The rotation helper takes the coordinates transposed.
    rotation = rotation_matrix_from_points(moving.T, reference.T)
    aligned = np.dot(moving, rotation.T)
    displacement = aligned - reference
    return sqrt(3*np.mean(displacement**2))

def find_neighbors(mol, mult=1.2):
    """Attach neighbor sets and a connectivity matrix to ``mol`` in place.

    Parameters
    ----------
    mol : ase.Atoms
        Structure to analyse. Two attributes are set on it:
        ``mol.neighbors`` (a list with one ``set`` of neighbor indices per
        atom) and ``mol.connectivity_matrix`` (whatever
        ``NeighborList.get_connectivity_matrix()`` returns).
    mult : float, optional
        Multiplier forwarded to ``natural_cutoffs``. Defaults to 1.2, the
        value that used to be hard-coded, so existing calls are unaffected.

    Returns
    -------
    None
    """
    cutoffs = natural_cutoffs(mol, mult=mult)
    # bothways=True so every atom lists all of its neighbors, not only the
    # ones with a larger index; self_interaction=False keeps an atom out of
    # its own neighbor set.
    neighbor_list = NeighborList(cutoffs, skin=0, self_interaction=False, bothways=True)
    neighbor_list.update(mol)
    # get_neighbors returns (indices, offsets); only the indices are kept.
    mol.neighbors = [set(neighbor_list.get_neighbors(atom.index)[0].tolist()) for atom in mol]
    mol.connectivity_matrix = neighbor_list.get_connectivity_matrix()

def fragcount(mol):
    """Return the number of connected components of ``mol.connectivity_matrix``.

    ``mol`` must already carry a ``connectivity_matrix`` attribute.
    """
    component_count, _labels = csgraph.connected_components(mol.connectivity_matrix)
    return component_count

def adjd(mol0, mol1):
    """Count entries that differ between the two upper-triangular connectivity matrices.

    Both arguments must carry a ``connectivity_matrix`` attribute of the
    same shape. Only the upper triangle (diagonal included) is compared.
    """
    upper0 = triu(mol0.connectivity_matrix)
    upper1 = triu(mol1.connectivity_matrix)
    # Element-wise inequality gives a sparse boolean matrix; its non-zero
    # entries are exactly the positions where the two matrices disagree.
    mismatch = upper0 != upper1
    return mismatch.count_nonzero()

class Part(set):
    """A set of items that records each added item in a shared lookup dict.

    ``trace`` is the label of this part and ``atom_types`` is a mapping
    supplied by the caller. Every ``add`` stores ``item -> trace`` in that
    mapping. Only ``add`` is overridden: the other mutating methods
    inherited from ``set`` leave ``atom_types`` untouched.
    """

    def __init__(self, trace, atom_types):
        """Create an empty part labelled ``trace`` that reports into ``atom_types``."""
        set.__init__(self)
        self.atom_types = atom_types
        self.trace = trace

    def add(self, item):
        """Insert ``item`` and register it under this part's trace."""
        set.add(self, item)
        self.atom_types[item] = self.trace

class Partition(dict):
    """Mapping from trace tuples to ``Part`` objects.

    Attributes
    ----------
    atom_types : dict
        Shared with every ``Part`` created through ``new_part``; each part
        writes ``item -> trace`` into it when an item is added.
    trace_index : defaultdict(int)
        Number of parts created so far under each parent trace.
    """

    def __init__(self):
        dict.__init__(self)
        self.trace_index = defaultdict(int)
        self.atom_types = {}

    def new_part(self, trace=()):
        """Create, store and return an empty part nested under ``trace``.

        The new key is ``trace`` extended by a running number that starts
        at 1 for each distinct parent trace.
        """
        serial = self.trace_index[trace] + 1
        self.trace_index[trace] = serial
        child_trace = (*trace, serial)
        part = Part(child_trace, self.atom_types)
        self[child_trace] = part
        return part

    def counter(self, neighbors):
        """Return a hashable tally of the traces registered for ``neighbors``.

        The result is a frozenset of ``(trace, count)`` pairs. An item that
        is missing from ``atom_types`` raises ``KeyError``.
        """
        tally = Counter(self.atom_types[item] for item in neighbors)
        return frozenset(tally.items())

    def get_frozenpartition(self):
        """Return the parts as a frozenset of frozensets, dropping the traces."""
        return frozenset(map(frozenset, self.values()))

    def print(self):
        """Print a blank line, then one numbered row per part with items shown one-based."""
        print()
        for row, trace in enumerate(self, start=1):
            one_based = {item + 1 for item in self[trace]}
            print(f'{row:3}:  {trace}   ->   {one_based}')

def get_eltypes(mol):
    """Partition atom indices by atomic number.

    One part is created per distinct value returned by
    ``mol.get_atomic_numbers()``, in order of first appearance, and each
    atom index is added to the part of its element.

    Returns
    -------
    Partition
        Parts are keyed ``(1,)``, ``(2,)``, ... because every part is
        created under the empty parent trace.
    """
    eltypes = Partition()
    part_for_element = {}
    for index, number in enumerate(mol.get_atomic_numbers()):
        part = part_for_element.get(number)
        if part is None:
            # First atom of this element: open a new top-level part.
            part = part_for_element[number] = eltypes.new_part()
        part.add(index)
    return eltypes

def get_mnatypes(mol, types):
    """Refine ``types`` with ``levelup_mnatypes`` until a step splits nothing.

    A refinement step has split nothing when every key of its result ends
    in 1, i.e. each parent part produced exactly one child. The partition
    from *before* that final step is returned, so the result's traces
    carry no trailing level that adds no information.
    """
    refined = levelup_mnatypes(mol, types)
    while any(trace[-1] != 1 for trace in refined):
        types = refined
        refined = levelup_mnatypes(mol, types)
    return types

def levelup_mnatypes(mol, types):
    """Split every part of ``types`` by the neighbor-type tally of its atoms.

    For each atom the tally comes from ``types.counter(mol.neighbors[atom])``.
    Atoms of one part that share a tally stay together in a child part
    created under the parent's trace; a different tally opens a new child.

    Returns
    -------
    Partition
        A new partition whose keys are the parent traces extended by one
        running number.
    """
    uptypes = Partition()
    for parent_trace, members in types.items():
        # Children are tracked per parent so that equal tallies in
        # different parents never end up in the same part.
        child_by_tally = {}
        for atom in members:
            tally = types.counter(mol.neighbors[atom])
            child = child_by_tally.get(tally)
            if child is None:
                child = child_by_tally[tally] = uptypes.new_part(parent_trace)
            child.add(atom)
    return uptypes

def get_subtypes(types, trace):
    """Copy ``types`` one level deeper, isolating each atom of part ``trace``.

    The part keyed ``trace`` is broken up so that each of its atoms gets a
    child part of its own; every other part is carried over whole as a
    single child of its original trace.

    Returns
    -------
    Partition
        A new partition; ``types`` itself is not modified.
    """
    subtypes = Partition()
    for key, members in types.items():
        if key != trace:
            # Untouched type: one child holding all of its atoms.
            carried = subtypes.new_part(key)
            for atom in members:
                carried.add(atom)
            continue
        # Selected type: one singleton child per atom.
        for atom in members:
            subtypes.new_part(key).add(atom)
    return subtypes

def break_type_symmetry(mol, types):
    """Break the symmetry of each multi-atom type and keep the coarsest outcomes.

    For every part of ``types`` holding more than one atom, its atoms are
    separated with ``get_subtypes`` and the result is refined with
    ``get_mnatypes``. Types that lead to the same final partition are
    grouped, and any partition that is a refinement of another retained
    partition is dropped.

    Returns
    -------
    dict
        Maps each retained partition (a frozenset of frozensets of atom
        indices) to the list of parts of ``types`` that produced it.
    """
    # One refined partition per type that still contains several atoms.
    refined_by_trace = {
        trace: get_mnatypes(mol, get_subtypes(types, trace))
        for trace in types
        if len(types[trace]) > 1
    }

    # Group the originating parts by the partition they lead to.
    uniquepartitions = {}
    for trace, refined in refined_by_trace.items():
        uniquepartitions.setdefault(refined.get_frozenpartition(), []).append(types[trace])

    def refines(fine, coarse):
        # True when every block of `fine` lies inside some block of `coarse`.
        return all(any(block <= container for container in coarse) for block in fine)

    # Visit partitions with fewer blocks first, so that a coarser partition
    # is always accepted before any of its refinements is examined.
    retained = []
    for candidate in sorted(uniquepartitions, key=len):
        if not any(refines(candidate, accepted) for accepted in retained):
            retained.append(candidate)
    return {key: uniquepartitions[key] for key in retained}

In [3]:
# Load the first two frames of the trajectory file as separate structures.
mol0 = read('test24.xyz', index=0)
mol1 = read('test24.xyz', index=1)

# Attach neighbor sets and connectivity matrices; the helpers below need them.
find_neighbors(mol0)
find_neighbors(mol1)
print(f"There are {len(mol0)} atoms and {fragcount(mol0)} fragment(s) in system 0")
print(f"There are {len(mol1)} atoms and {fragcount(mol1)} fragment(s) in system 1")
# The message previously said "mol and mol1" although mol0 and mol1 are compared.
print(f"There are {adjd(mol0, mol1)} adjacency differences between mol0 and mol1")
print(f"The root mean square distance between mol0 and mol1 is {rmsd(mol0, mol1):.4f} ")

# Classify the atoms of the second structure: first by element, then by
# iteratively refined neighborhood types.
mol = mol1
eltypes = get_eltypes(mol)
mnatypes = get_mnatypes(mol, eltypes)
eltypes.print()
mnatypes.print()

There are 86 atoms and 1 fragment(s) in system 0
There are 86 atoms and 1 fragment(s) in system 1
There are 171 adjacency differences between mol and mol1
The root mean square distance between mol0 and mol1 is 5.7485 

  1:  (1,)   ->   {1, 2, 4, 5, 8, 11, 15, 19, 21, 22, 23, 24, 30, 31, 33, 38, 40, 41, 50, 52, 55, 57, 60, 61, 69, 73, 74, 75, 78, 79, 82, 86}
  2:  (2,)   ->   {3, 6, 7, 9, 10, 12, 13, 14, 16, 17, 18, 20, 25, 26, 27, 28, 32, 34, 35, 36, 37, 39, 42, 43, 44, 45, 46, 47, 48, 49, 51, 53, 54, 58, 59, 63, 64, 65, 67, 68, 70, 72, 77, 80, 81, 83, 85}
  3:  (3,)   ->   {66, 29}
  4:  (4,)   ->   {56, 76}
  5:  (5,)   ->   {62, 71}
  6:  (6,)   ->   {84}

  1:  (1, 1, 1, 1, 1, 1)   ->   {1, 52}
  2:  (1, 1, 1, 2, 1, 1)   ->   {4, 23}
  3:  (1, 2, 1, 1, 1, 1)   ->   {2}
  4:  (1, 3, 1, 1, 1, 1)   ->   {5, 8, 11, 50, 21, 86}
  5:  (1, 3, 1, 1, 2, 1)   ->   {41, 78, 15, 79, 55, 61}
  6:  (1, 4, 1, 1, 1, 1)   ->   {33}
  7:  (1, 4, 2, 1, 1, 1)   ->   {19}
  8:  (1, 4, 3, 1, 1, 1)   ->

In [4]:
# For every retained partition, show which types produced it (as one-based
# atom numbers) followed by the numbered blocks of the partition itself.
uniquepartitions = break_type_symmetry(mol, mnatypes)
for partition, source_types in uniquepartitions.items():
    print()
    print([{atom + 1 for atom in members} for members in source_types])
    for number, block in enumerate(partition, start=1):
        one_based = {atom + 1 for atom in block}
        print(f'{number:3}:  {one_based}')


[{65, 45}]
  1:  {33}
  2:  {5, 8, 11, 50, 21, 86}
  3:  {57}
  4:  {58}
  5:  {29}
  6:  {42, 6}
  7:  {2}
  8:  {19}
  9:  {60, 30}
 10:  {41, 78, 15, 79, 55, 61}
 11:  {74}
 12:  {24}
 13:  {73}
 14:  {40}
 15:  {25}
 16:  {32, 64, 35, 68, 70, 72, 12, 14, 47, 16, 17, 18, 81, 83, 85, 48, 53, 28}
 17:  {84}
 18:  {82, 69}
 19:  {75, 31}
 20:  {45}
 21:  {38}
 22:  {76}
 23:  {26, 13}
 24:  {56}
 25:  {34, 67, 36, 37, 7, 9, 10, 43, 44, 77, 46, 59, 80, 49, 51, 54, 27, 63}
 26:  {39}
 27:  {3}
 28:  {4, 23}
 29:  {65}
 30:  {66}
 31:  {62, 71}
 32:  {20}
 33:  {22}
 34:  {1, 52}

[{82, 69}, {62, 71}]
  1:  {33}
  2:  {5, 8, 11, 50, 21, 86}
  3:  {57}
  4:  {58}
  5:  {29}
  6:  {42, 6}
  7:  {2}
  8:  {19}
  9:  {82}
 10:  {60, 30}
 11:  {41, 78, 15, 79, 55, 61}
 12:  {74}
 13:  {24}
 14:  {73}
 15:  {40}
 16:  {65, 45}
 17:  {25}
 18:  {32, 64, 35, 68, 70, 72, 12, 14, 47, 16, 17, 18, 81, 83, 85, 48, 53, 28}
 19:  {62}
 20:  {84}
 21:  {75, 31}
 22:  {69}
 23:  {38}
 24:  {76}
 25:  {26