In [None]:
import itertools
import random
from collections import Counter
from math import sqrt
from random import choice, shuffle

import numpy as np
from ase.build.rotate import rotation_matrix_from_points
from ase.io import read, write
from ase.neighborlist import NeighborList, natural_cutoffs
from numpy import arange
from scipy.sparse import csgraph, csr_matrix, triu

In [None]:
class Bool:
    """Mutable boolean flag that can be switched in place with true()/false()."""
    def __init__(self, value):
        self.value = value
    def __bool__(self):
        # __bool__ must return an actual bool, otherwise Python raises
        # TypeError; coerce so values such as 1, 0 or None are accepted.
        return bool(self.value)
    def true(self):
        """Set the flag to True."""
        self.value = True
    def false(self):
        """Set the flag to False."""
        self.value = False
        
class Integer:
    """Mutable wrapper around a number, changed in place by increase()/update().

    Both ``update`` and ``<`` accept either a plain number or another ``Integer``.
    """
    def __init__(self, value):
        self.value = value
    def __lt__(self, value):
        # Unwrap another Integer: comparing a number with an Integer directly
        # would raise TypeError, since neither side knows how to compare them.
        if isinstance(value, Integer):
            value = value.value
        return self.value < value
    def __str__(self):
        return str(self.value)
    def copy(self):
        """Return an independent Integer holding the same value."""
        return Integer(self.value)
    def increase(self):
        """Add one to the stored value."""
        self.value += 1
    def update(self, value):
        """Replace the stored value with ``value`` (a number or an Integer)."""
        if isinstance(value, Integer):
            self.value = value.value
        else:
            self.value = value

class TypeList(dict):
    """Atom classes as a dict: class label -> set of atom indices.

    The ``typeof`` attribute holds the inverse map, atom index -> class label.
    Build it in one go from a per-atom label sequence (``TypeList(labels)``),
    or incrementally with :meth:`new` followed by :meth:`add`.
    """

    def __init__(self, typeof=()):
        # Immutable default: the argument is only iterated, and a shared
        # mutable default list is a trap if that ever changes.
        self.typeof = {}
        for index, label in enumerate(typeof):
            self.setdefault(label, set()).add(index)
            self.typeof[index] = label

    def new(self):
        """Open a new empty class labelled ``len(self)`` and make it the
        target of subsequent :meth:`add` calls."""
        self.key = len(self)
        self[self.key] = set()

    def add(self, item):
        """Put atom index ``item`` into the class most recently opened by
        :meth:`new` (raises AttributeError if :meth:`new` was never called)."""
        self[self.key].add(item)
        self.typeof[item] = self.key

def natural(indices):
    """Shift 0-based atom indices to 1-based numbering for display; returns a list."""
    return list(map(lambda idx: idx + 1, indices))

def permutations(indexlist):
    """List every reordering of ``indexlist`` as a tuple of (original, new) pairs.

    Each element pairs the items of ``indexlist`` (in their iteration order)
    with one permutation of those same items, e.g. ``[1, 2]`` yields
    ``((1, 1), (2, 2))`` and ``((1, 2), (2, 1))``.
    """
    pairings = []
    for ordering in itertools.permutations(indexlist):
        pairings.append(tuple(zip(indexlist, ordering)))
    return pairings

def shuffled(indices):
    """Return a new list holding the items of ``indices`` in random order.

    The input is not modified; randomness comes from the global ``random``
    generator (``random.shuffle``).
    """
    result = [*indices]
    shuffle(result)
    return result

def diffcount(mapping):
    """Count upper-triangle connectivity entries that differ between mol0 and mol1[mapping].

    Side effect: updates the module-level neighbor list ``nl1`` with
    ``mol1[mapping]`` before reading its connectivity matrix.
    """
    nl1.update(mol1[mapping])
    working_upper = triu(nl1.get_connectivity_matrix())
    reference_upper = triu(matrix0)
    return csr_matrix.count_nonzero(reference_upper != working_upper)

def nodematch(node, mapping):
    """Compare the neighbors of ``node`` in mol0 with those in mol1 reordered by ``mapping``.

    Side effect: updates the module-level neighbor list ``nl1`` with ``mol1[mapping]``.

    Returns
    -------
    (matches, mismatches0, mismatches1)
        Sets of neighbor indices bonded to ``node`` in both structures, only in
        mol0, and only in ``mol1[mapping]``.
    """
    nl1.update(mol1[mapping])
    ref_neighbors = set(nl0.get_neighbors(node)[0])
    work_neighbors = set(nl1.get_neighbors(node)[0])
    shared = ref_neighbors & work_neighbors
    return shared, ref_neighbors - shared, work_neighbors - shared

def moldiff(mapping):
    """Count upper-triangle connectivity entries that differ between mol0 and mol1[mapping].

    Same computation (and same side effect of updating ``nl1``) as
    ``diffcount``; delegates to it so the logic lives in one place.
    """
    return diffcount(mapping)

def moldist(mapping):
    """RMSD between mol0 and mol1 reordered by ``mapping`` after optimal superposition.

    Both coordinate sets are translated so their centroids sit at the origin
    before the rotation is fitted, because ``rotation_matrix_from_points``
    only minimizes the RMSD for centered inputs. The result is therefore
    independent of where either structure sits in space.

    Returns
    -------
    float
        sqrt(sum of squared atomic displacements / number of atoms).
    """
    p0 = mol0.get_positions()
    p1 = mol1[mapping].get_positions()
    # Remove the translational part first (rotation is fitted about the origin)
    p0 = p0 - p0.mean(axis=0)
    p1 = p1 - p1.mean(axis=0)
    R = rotation_matrix_from_points(p1.T, p0.T)
    p1 = np.dot(p1, R.T)
    # mean() averages over 3*N components, so 3*mean = sum / N
    return sqrt(3*((p1 - p0)**2).mean())

def nodedist(mapping, leaves):
    """Sum of squared displacements between selected mol0 atoms and their mapped mol1 partners.

    ``leaves`` is a list of mol0 atom indices and ``mapping[leaves]`` selects
    the paired mol1 atoms, so ``mapping`` must support indexing by a list
    (e.g. a numpy array). Coordinates are compared as stored; no
    superposition is applied.
    """
    reference = mol0[leaves].get_positions()
    candidate = mol1[mapping[leaves]].get_positions()
    return np.sum(np.square(candidate - reference))

def get_atom_types(mol, nl):
    """Partition the atoms of ``mol`` into topologically equivalent classes.

    The initial partition groups atoms by atomic number. Each pass splits every
    class so that its members have identical multisets of neighbor classes
    (neighbors taken from ``nl``). Iteration stops once a pass reproduces the
    previous partition.

    Parameters
    ----------
    mol
        Object providing ``get_atomic_numbers()`` (an ``ase.Atoms`` in this notebook).
    nl
        Neighbor list whose ``get_neighbors(i)[0]`` yields the neighbor
        indices of atom ``i``.

    Returns
    -------
    TypeList
        Class label -> set of atom indices, with ``.typeof`` mapping each atom
        index to its label.
    """
    types = TypeList(mol.get_atomic_numbers())
    while True:
        newtypes = TypeList()
        for members in types.values():
            # Group the class members by the multiset of their neighbors'
            # current labels. A Counter is unhashable, so its items are frozen
            # into a hashable key. Computing one signature per atom keeps each
            # pass linear in the class size.
            groups = {}
            for i in members:
                neighbor_labels = Counter(types.typeof[k] for k in nl.get_neighbors(i)[0])
                groups.setdefault(frozenset(neighbor_labels.items()), []).append(i)
            # dicts keep insertion order, so groups are labelled in order of
            # the first member encountered, matching a greedy first-come scan.
            for group in groups.values():
                newtypes.new()
                for i in group:
                    newtypes.add(i)
        # Labels are handed out deterministically, so an unchanged partition
        # also yields an equal dictionary.
        if newtypes == types:
            return types
        types = newtypes

def get_neighbor_types(mol, nl, atom_types):
    """Group each atom's neighbors by atom class.

    Returns a dict ``atom index -> {class label: set of neighbor indices in
    that class}``; classes with no neighbor of the atom are omitted.
    ``atom_types`` maps class label -> set of atom indices, and ``nl`` must
    provide ``get_neighbors(i)[0]`` with the neighbor indices of atom ``i``.
    """
    result = {}
    for idx in (atom.index for atom in mol):
        neighbors = set(nl.get_neighbors(idx)[0])
        overlaps = ((label, neighbors & members) for label, members in atom_types.items())
        result[idx] = {label: shared for label, shared in overlaps if shared}
    return result

In [None]:
# Recursive function
def minadjdiff(node, mapping, tracked, depth=0):
    """Depth-first search that swaps mapping entries to reduce bond mismatches.

    Starting from ``node``, visits neighbors and compares each atom's bonds in
    ``mol0`` with those in ``mol1[mapping]`` (via ``nodematch``). Consider a
    neighbor ``i`` bonded only in mol0 and a neighbor ``j`` bonded only in the
    working molecule, where mol0's element at ``i`` equals the working
    molecule's element at ``j``. For such a pair the function:

    - swaps ``mapping[i]`` and ``mapping[j]`` on copies,
    - explores the branch rooted at ``i``,
    - keeps the result only if ``diffcount`` (number of differing
      upper-triangle bonds) decreases.

    Neighbor order is randomized with ``shuffled``.

    Parameters
    ----------
    node : int
        mol0 atom index to visit.
    mapping : index array
        ``mol1[mapping]`` is the working molecule. Modified in place
        (``mapping[:] = ...``) when a swap is accepted; the script cell passes
        ``arange(natom)``.
    tracked : set
        Atom indices already visited; modified in place.
    depth : int
        Recursion depth; only used by the commented-out debug print.

    Returns None. Reads the module-level ``mol0`` and ``mol1`` (and ``nl0``,
    ``nl1``, ``matrix0`` through ``nodematch``/``diffcount``).
    """
    # Split the neighbors of `node` into shared / mol0-only / working-only bonds
    matches, mismatches0, mismatches1 = nodematch(node, mapping)
    #debug: print step info (references `n`, undefined here; would need `node`)
    #print(' '*4*depth, f'{n+1}/{mapping[n]+1}:', natural(matches), natural(mismatches0), natural(mismatches1), moldiff(mapping), f'{moldist(mapping):.4f}')
    # Mark `node` as visited
    tracked.add(node)
    # Recurse into every neighbor whose bond already agrees in both molecules
    for i in shuffled(matches):
        if i not in tracked:
            minadjdiff(i, mapping, tracked, depth + 1)
    # Run over every mismatched neighbor in the reference molecule (mol0)
    for i in shuffled(mismatches0):
        # which has not been visited yet
        if i not in tracked:
            # Run over every mismatched neighbor in the working molecule
            for j in shuffled(mismatches1):
                # Only pair atoms of the same element (mol0 at i vs. working molecule at j).
                # NOTE(review): mol1[mapping] builds a new Atoms object on every iteration.
                if mol0.symbols[i] == mol1[mapping].symbols[j]:
                    # Work on local copies so a rejected swap leaves the caller's state untouched
                    tracked_branch = tracked.copy()
                    mapping_branch = mapping.copy()
                    # Swap the partners of i and j
                    mapping_branch[i], mapping_branch[j] = mapping_branch[j], mapping_branch[i]
                    # Explore the branch rooted at the swapped node
                    minadjdiff(i, mapping_branch, tracked_branch, depth + 1)
                    # Accept the branch only if it reduced the number of differing bonds
                    if diffcount(mapping_branch) < diffcount(mapping):
                        # Everything visited in the branch now counts as visited here too
                        tracked.update(tracked_branch)
                        # Commit the branch's swaps into the caller's mapping (in place)
                        mapping[:] = mapping_branch[:]
                        # Remove the now-paired neighbors from the mismatch sets.
                        # NOTE(review): the remaining mismatch sets were computed before
                        # this swap and are not recomputed; confirm that is intended.
                        mismatches0.remove(i)
                        mismatches1.remove(j)
                        # Stop trying partners for i; continue with the next mismatched neighbor
                        break
    # Continue the traversal through mol0-only neighbors that are still unvisited
    # (those for which no swap was accepted)
    for i in mismatches0:
        if i not in tracked:
            minadjdiff(i, mapping, tracked, depth + 1)

In [None]:
def recursive_remap(node_a, node_b, mapping_ref, mapping, held, depth=0):
    """Carry the substituents of ``node_b`` over to ``node_a`` after an exchange.

    For every neighbor class ``group`` of ``node_a`` (from ``neighbor_types_0``),
    pairs the not-held neighbors of ``node_a`` in that class with the not-held
    neighbors of ``node_b`` in the same class. For each pair ``(i, j)`` with
    ``i != j`` it sets ``mapping[i] = mapping_ref[j]`` and recurses from that
    pair.

    Parameters
    ----------
    node_a, node_b : int
        mol0 atom indices; callers pass two atoms of the same class
        (presumably equivalent, so they share neighbor classes; verify).
    mapping_ref : index array
        Mapping before the exchange; only read.
    mapping : index array
        Mapping being rewritten; modified in place.
    held : set
        Atom indices excluded from pairing. Updated in place with the whole
        class of ``node_a``; each recursive call receives its own copy.
    depth : int
        Recursion depth; only used by the commented-out debug print.

    Reads the module-level ``atom_types_0`` and ``neighbor_types_0``.

    NOTE(review): pairs come from ``zip`` over two sets, so which neighbor of
    ``node_a`` is paired with which neighbor of ``node_b`` follows set
    iteration order, not geometry. ``zip`` also silently truncates if the two
    sets differ in size. ``neighbor_types_0[node_b][group]`` raises KeyError
    if ``node_b`` has no neighbor in ``group``.
    """
    #debug: print step information
    #print(' '*4*depth, [node_a, node_b], moldiff(mapping))
    # Exclude node_a's whole class so the remapping never walks back through it
    held.update(atom_types_0[atom_types_0.typeof[node_a]])
    for group in neighbor_types_0[node_a]:
        for i, j in zip(neighbor_types_0[node_a][group] - held, neighbor_types_0[node_b][group] - held):
            if i != j:
                # Atom i takes over the partner that atom j had before the exchange
                mapping[i] = mapping_ref[j]
                recursive_remap(i, j, mapping_ref, mapping, held.copy(), depth + 1)

In [None]:
# Recursive function
def deep_eqvatom_perm(node, mapping, tracked, held, depth=0):
    """Exhaustive search over permutations of equivalent neighbors (scored by full RMSD).

    For each class of neighbors of ``node`` that are neither held nor
    visited, the function:

    1. tries every permutation of those neighbors on its own copy of
       ``mapping``/``tracked``;
    2. remaps the substituents of each moved atom with ``recursive_remap``;
    3. recurses from every neighbor of the class;
    4. scores the permutation with ``moldist`` (whole-molecule RMSD).

    The best permutation's mapping and visited set are then committed into
    ``mapping`` and ``tracked``. Each nested level multiplies the number of
    branches explored.

    Parameters
    ----------
    node : int
        mol0 atom index to visit.
    mapping : index array
        Modified in place (``mapping[:] = ...``).
    tracked : set
        Visited atom indices; modified in place.
    held : set
        Atom indices whose class has already been handled; updated in place
        with ``node``'s class.
    depth : int
        Recursion depth; only used by the commented-out debug print.

    Reads the module-level ``atom_types_0`` and ``neighbor_types_0``, and
    increments the global ``total_permut``, which must exist before calling.
    Not called in the visible notebook cells (``eqvatom_perm`` is used
    instead).
    """
    # debug: the permutation counter is a global
    global total_permut
    # debug: print step information
    #print(' '*4*depth, f'{n+1}/{mapping[n]+1}:', moldiff(mapping), f'{moldist(mapping):.4f}')
    # Mark `node` as visited
    tracked.add(node)
    # Mark every atom of `node`'s class as held
    held.update(atom_types_0[atom_types_0.typeof[node]])
    # Iterate over each class of neighbors of `node`
    for equiv_atoms in neighbor_types_0[node].values():
        equiv_neighbors = equiv_atoms - held - tracked
        if equiv_neighbors:
            # Per-permutation local copies of the branch state
            tracked_permut = {}
            mapping_permut = {}
            moldist_permut = {}
            # Iterate over all permutations of the equivalent neighbors
            # (held and tracked are not modified in this loop, so this set equals equiv_neighbors)
            for p in permutations(equiv_atoms - held - tracked):
                # debug: count the total number of permutations
                total_permut += 1
                # Local copies of the mapping and visited set for this permutation
                tracked_permut[p] = tracked.copy()
                mapping_permut[p] = mapping.copy()
                # Apply the permutation pair by pair
                for i, j in p:
                    if i != j:
                        # Atom i takes the partner atom j had in the unpermuted mapping
                        mapping_permut[p][i] = mapping[j]
                        # Carry j's substituents over to i so the branch stays consistent
                        recursive_remap(i, j, mapping, mapping_permut[p], held.copy(), depth + 1)
                # Continue the traversal recursively from every equivalent neighbor
                for i in equiv_neighbors:
                    deep_eqvatom_perm(i, mapping_permut[p], tracked_permut[p], held.copy(), depth + 1)
                # Score the permutation by the whole-molecule RMSD
                moldist_permut[p] = moldist(mapping_permut[p])
            # Pick the permutation with the smallest distance
            p = min(moldist_permut, key=moldist_permut.get)
            # Merge the nodes visited under that permutation
            tracked.update(tracked_permut[p])
            # Commit that permutation's mapping into the caller's mapping (in place)
            mapping[:] = mapping_permut[p][:]

In [None]:
# Recursive function
def eqvatom_perm(node, mapping, tracked, held, depth=0):
    """Greedy reassignment of equivalent neighbors, scored locally with ``nodedist``.

    For each class of neighbors of ``node`` that are neither held nor
    visited, the function:

    1. evaluates every permutation of those neighbors by ``nodedist``: the
       summed squared displacement of just those atoms, on raw coordinates
       with no superposition;
    2. keeps the best permutation and remaps the substituents of each moved
       atom with ``recursive_remap``;
    3. commits the result into ``mapping``;
    4. recurses from each neighbor, sharing ``mapping`` and ``tracked``.

    NOTE(review): because ``nodedist`` does not superpose, the choice
    presumably assumes mol1 is already roughly aligned with mol0; confirm.

    Parameters
    ----------
    node : int
        mol0 atom index to visit.
    mapping : index array
        Modified in place (``mapping[:] = ...``). ``nodedist`` indexes it with
        a list, so it must support that (e.g. a numpy array).
    tracked : set
        Visited atom indices; modified in place and shared with recursive calls.
    held : set
        Atom indices whose class has already been handled; updated in place
        with ``node``'s class (recursive calls get copies).
    depth : int
        Recursion depth; only used by the commented-out debug print.

    Reads the module-level ``atom_types_0`` and ``neighbor_types_0``, and
    increments the global ``total_permut``, which must exist before calling.
    """
    # debug: the permutation counter is a global
    global total_permut
    # debug: print step information
    #print(' '*4*depth, f'{n+1}/{mapping[n]+1}:', moldiff(mapping), f'{moldist(mapping):.4f}')
    # Mark `node` as visited
    tracked.add(node)
    # Mark every atom of `node`'s class as held
    held.update(atom_types_0[atom_types_0.typeof[node]])
    # Iterate over each class of neighbors of `node`
    for equiv_atoms in neighbor_types_0[node].values():
        equiv_neighbors = equiv_atoms - held - tracked
        if equiv_neighbors:
            # Per-permutation local copies of the mapping and their scores
            mapping_permut = {}
            moldist_permut = {}
            # Iterate over all permutations of the equivalent neighbors
            for p in permutations(equiv_neighbors):
                # debug: count the total number of permutations
                total_permut += 1
                # Local copy of the mapping for this permutation
                mapping_permut[p] = mapping.copy()
                # Apply the permutation: atom i takes the partner atom j had
                for i, j in p:
                    mapping_permut[p][i] = mapping[j]
                # Score the permutation on the permuted atoms only
                moldist_permut[p] = nodedist(mapping_permut[p], list(equiv_neighbors))
            # Pick the permutation with the smallest distance
            p = min(moldist_permut, key=moldist_permut.get)
            # Carry the substituents of each moved atom along with it
            for i, j in p:
                if i != j:
                    recursive_remap(i, j, mapping, mapping_permut[p], held.copy())
            # Restore the connectivity of the remaining permuted branches? (experimental, disabled)
#            minadjdiff(node, mapping, tracked.copy(), depth)
            # Commit the chosen permutation into the caller's mapping (in place)
            mapping[:] = mapping_permut[p][:]
            # Continue the traversal recursively from each equivalent neighbor
            for i in equiv_neighbors:
                eqvatom_perm(i, mapping, tracked, held.copy(), depth + 1)

In [None]:
# Seed the global `random` generator: `shuffled` draws from it, so the traversal
# order of `minadjdiff` (and therefore its result) depends on this value.
# Set SEED = None to get a different random traversal on every run.
SEED = 0
random.seed(SEED)

path = './test24a.xyz'
mol0 = read(path, index=0)
mol1 = read(path, index=1)
# A permutation can only map one structure onto the other if both have the same atom count.
assert len(mol0) == len(mol1), f"frame 0 has {len(mol0)} atoms but frame 1 has {len(mol1)}"
cutoffs0 = natural_cutoffs(mol0, mult=1.2)
# NOTE(review): nl1 also uses mol0's per-atom cutoffs. nl1 is only ever updated with
# mol1 indexed by a mapping (identity at first), so this presumably relies on the
# mapping keeping elements aligned with mol0 index by index. Confirm before changing.
cutoffs1 = natural_cutoffs(mol0, mult=1.2)
nl0 = NeighborList(cutoffs0, skin=0, self_interaction=False, bothways=True)
nl1 = NeighborList(cutoffs1, skin=0, self_interaction=False, bothways=True)
nl0.update(mol0)
nl1.update(mol1)
matrix0 = nl0.get_connectivity_matrix()
matrix1 = nl1.get_connectivity_matrix()
natom = len(mol0)
print(f"There are {len(mol0)} atoms and {csgraph.connected_components(matrix0)[0]} molecule(s) in system 0")
print(f"There are {len(mol1)} atoms and {csgraph.connected_components(matrix1)[0]} molecule(s) in system 1")
print()
print(f"There are {csr_matrix.count_nonzero(triu(matrix0) != triu(matrix1))} initial differences with a distance of {moldist(list(range(natom))):.4f}")
#print(triu(matrix0) != triu(matrix1))

mapping = arange(natom)
minadjdiff(0, mapping, set())
print(f"There are {moldiff(mapping)} differences left after backtracking with a distance of {moldist(mapping):.4f}")
# Intermediate result after backtracking; written to its own file so the final
# result below (output1.xyz) does not overwrite it.
write('output0.xyz', [mol0, mol1[mapping]])
#if tracked != set(range(natom)):
#    print(natural(set(tracked) - set(range(natom))), natural(set(range(natom)) - set(tracked)))

#mapping1 = arange(natom)
mapping1 = mapping.copy()
atom_types_0 = get_atom_types(mol0, nl0)
neighbor_types_0 = get_neighbor_types(mol0, nl0, atom_types_0)
total_permut = 0
eqvatom_perm(0, mapping1, set(), set())
print(f"The smallest distance found among {total_permut} equivalent {moldiff(mapping1)}-difference permutations is {moldist(mapping1):.4f}")
# Final result: reference structure followed by the reordered working structure
write('output1.xyz', [mol0, mol1[mapping1]])