# Molecule Alignment

This notebook visualizes the results of [./unique_canonical_representation.ipynb](./unique_canonical_representation.ipynb). Note that the visualization software has very limited compatibility with Jupyter.

In [1]:
from abc import ABCMeta
import ast

import torch
import numpy as np

import matplotlib.pyplot as plt

In [2]:
import sys
sys.path.append('../pyorbit/')
from CategoricalPointCloud import CatFrame as Frame

In [3]:
'''
Please let me know if you have better visualization libraries ,-,

'''


import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem

import base64
from io import BytesIO
from PIL import Image
from IPython.display import display, HTML

def new_plot_mol(point_cloud, cat_data):
    """Show a set of atoms with 3D coordinates in a py3Dmol viewer.

    An RDKit molecule is assembled with one atom per entry of ``cat_data``
    (no bonds are added by this function), a conformer carrying the
    coordinates from ``point_cloud`` is attached, and the sanitized molecule
    is handed to py3Dmol as a MolBlock. Three red arrows of unit length are
    drawn from the origin along x, y and z as a reference frame.

    Args:
        point_cloud: Iterable of atom positions. Components 0, 1 and 2 of
            each entry are converted to ``float`` and used as x, y, z.
        cat_data: Iterable of atomic numbers (each converted with ``int``),
            in the same order as ``point_cloud``.

    Returns:
        The ``py3Dmol`` viewer, after ``show()`` and ``png()`` have been
        called on it.
    """
    # One bond-free atom per atomic number.
    rw_mol = Chem.RWMol()
    for atom_number in cat_data:
        rw_mol.AddAtom(Chem.Atom(int(atom_number)))

    # A single conformer stores the 3D position of every atom.
    conformer = Chem.Conformer(rw_mol.GetNumAtoms())
    for atom_idx, pos in enumerate(point_cloud):
        xyz = tuple(float(pos[component]) for component in range(3))
        conformer.SetAtomPosition(atom_idx, xyz)
    rw_mol.AddConformer(conformer)

    Chem.rdmolops.SanitizeMol(rw_mol)
    mol_block = Chem.MolToMolBlock(rw_mol)

    viewer = py3Dmol.view(width=1600, height=900)

    # Reference frame: one red unit arrow per coordinate axis, in x, y, z order.
    for axis in ('x', 'y', 'z'):
        tip = {'x': 0, 'y': 0, 'z': 0}
        tip[axis] = 1
        viewer.addArrow({
            'start': {'x': 0, 'y': 0, 'z': 0},
            'end': tip,
            'radius': 0.1,
            'color': 'red',
        })

    viewer.addModel(mol_block, 'mol')
    # Sticks plus small spheres so that individual atoms stay visible.
    viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.2}})

    # Camera: fit the model, zoom in, then tilt about x and y.
    viewer.zoomTo()
    viewer.zoom(2.5)
    viewer.rotate(30, 'x')
    viewer.rotate(30, 'y')

    viewer.show()
    viewer.png()
    return viewer

In [4]:
from torch_geometric.datasets import QM9
from scipy.spatial.transform import Rotation as R
np.random.seed(42)

qm9 = QM9(root='../datasets/qm9-2.4.0/')
frame = Frame()

# Visualise a single molecule: the one at dataset index k. It is fetched by
# direct indexing instead of scanning the dataset with continue/break, which
# also removes the loop index that the rotation loop used to shadow.
k = 1
data = qm9[k]

# Neither the centred coordinates nor the atomic numbers depend on the
# sampled rotation, so they are computed once.
centered_pos = (data.pos - data.pos.mean(dim=0)).numpy()
cat_data = data.z.numpy()

for i in range(5):
    print(f'ROTATION {i}')
    random_rotation = R.random().as_matrix()
    # Drawn but never applied to the point cloud; the draw is kept so that
    # the order of random draws after the seed stays as it was.
    random_translation = np.random.rand(3)

    # Rotate every (row-vector) atom position of the centred molecule.
    point_cloud = (random_rotation @ centered_pos.T).T
    new_plot_mol(point_cloud, cat_data)

    # Frame.get_frame returns the aligned coordinates and a second value
    # (bound to `rot`) that this cell does not use.
    aligned_data, rot = frame.get_frame(point_cloud, cat_data)
    print(f'ALIGNMENT {i}')
    new_plot_mol(aligned_data, cat_data)

ROTATION 0


ROTATION 0


ROTATION 1


ROTATION 1


ROTATION 2


ROTATION 2


ROTATION 3


ROTATION 3


ROTATION 4


ROTATION 4


In [11]:
from torch_geometric.datasets import QM9
from scipy.spatial.transform import Rotation as R
np.random.seed(42)

qm9 = QM9(root='../datasets/qm9-2.4.0/')
frame = Frame()

# Find the first molecule whose coordinate matrix has rank 1, show five
# randomly rotated copies next to their aligned versions, then stop.
for idx, data in enumerate(qm9):
    point_cloud = data.pos
    # NOTE(review): the rank is taken on the raw (uncentred) coordinates,
    # while the plotted cloud below is mean-centred - confirm this is intended.
    rank = torch.linalg.matrix_rank(point_cloud)
    if rank != 1:
        continue

    # Independent of the sampled rotation, so computed once per molecule.
    centered_pos = (data.pos - data.pos.mean(dim=0)).numpy()
    cat_data = data.z.numpy()

    for i in range(5):
        print(f'ROTATION {i}')
        random_rotation = R.random().as_matrix()
        # Drawn but never applied to the point cloud.
        random_translation = np.random.rand(3)
        point_cloud = (random_rotation @ centered_pos.T).T
        new_plot_mol(point_cloud, cat_data)
        aligned_data, rot = frame.get_frame(point_cloud, cat_data)
        print(f'ALIGNMENT {i}')
        new_plot_mol(aligned_data, cat_data)

    # Only the first matching molecule is visualised.
    break

ROTATION 0


ALIGNMENT 0


ROTATION 1


ALIGNMENT 1


ROTATION 2


ALIGNMENT 2


ROTATION 3


ALIGNMENT 3


ROTATION 4


ALIGNMENT 4


In [12]:
if __name__ == "__main__":
    from torch_geometric.datasets import QM9
    from scipy.spatial.transform import Rotation as R
    np.random.seed(42)

    qm9 = QM9(root='../datasets/qm9-2.4.0/')
    frame = Frame()

    # Find the first molecule whose coordinate matrix has rank 2 under a
    # loose tolerance, print its SMILES, show five randomly rotated copies
    # next to their aligned versions, then stop.
    for idx, data in enumerate(qm9):
        point_cloud = data.pos
        # NOTE(review): the rank is taken on the raw (uncentred) coordinates,
        # while the plotted cloud below is mean-centred - confirm this is intended.
        rank = torch.linalg.matrix_rank(point_cloud, tol=1e-1)
        if rank != 2:
            continue

        print(data.smiles)

        # Independent of the sampled rotation, so computed once per molecule.
        centered_pos = (data.pos - data.pos.mean(dim=0)).numpy()
        cat_data = data.z.numpy()

        for i in range(5):
            print(f'ROTATION {i}')
            random_rotation = R.random().as_matrix()
            # Drawn but never applied to the point cloud.
            random_translation = np.random.rand(3)
            point_cloud = (random_rotation @ centered_pos.T).T
            new_plot_mol(point_cloud, cat_data)
            aligned_data, rot = frame.get_frame(point_cloud, cat_data)
            print(f'ALIGNMENT {i}')
            new_plot_mol(aligned_data, cat_data)

        # Only the first matching molecule is visualised.
        break

[H]O[H]
ROTATION 0


ALIGNMENT 0


ROTATION 1


ALIGNMENT 1


ROTATION 2


ALIGNMENT 2


ROTATION 3


ALIGNMENT 3


ROTATION 4


ALIGNMENT 4
