# Unique Canonical Representation

A test to determine if alignment procedures construct and align to a unique canonical representation.

1. Take a molecular structure from QM9
2. Normalize the structure using <METHOD>
3. For N times:

    i. Transform initial structure
   
    ii. Normalize transformed structure

    iii. Compute Wasserstein distance between structures

    iv. Store loss by rank/point group

Times are reported by TQDM, average loss is reported by rank/point group.

In [1]:
import torch
import ot  # Python Optimal Transport (POT) library

from pointgroup import PointGroup
from torch_geometric.datasets import QM9
from scipy.spatial.transform import Rotation as R
import numpy as np
import torch

from torch_geometric.loader import DataLoader


from tqdm import tqdm

import sys

In [2]:
def compute_wasserstein_distance(pc1, pc2):
    """
    Compute the Wasserstein distance (EMD) between two point clouds.

    Args:
    pc1 (torch.Tensor): First point cloud (NxD).
    pc2 (torch.Tensor): Second point cloud (MxD).

    Returns:
    float: The Wasserstein distance.
    """
    # Convert PyTorch tensors to NumPy arrays if necessary
    if isinstance(pc1, torch.Tensor):
        pc1 = pc1.detach().cpu().numpy()
    if isinstance(pc2, torch.Tensor):
        pc2 = pc2.detach().cpu().numpy()

    # Create uniform distribution for each point cloud
    n1, n2 = pc1.shape[0], pc2.shape[0]
    a, b = ot.unif(n1), ot.unif(n2)

    # Compute cost matrix
    M = ot.dist(pc1, pc2, metric='euclidean')

    # Compute Wasserstein distance (EMD)
    emd_distance = ot.emd2(a, b, M)

    return emd_distance

## PCA

In [3]:
from typing import Any

from abc import ABCMeta

import torch
from torch import Tensor


class SVDAlignment(metaclass=ABCMeta):
  def __call__(self, pointcloud: Tensor, *args: Any, **kwds: Any) -> Any:
    self.pointcloud = pointcloud # (n,m)-dimensional tensor
    pass


  def align_center(self, pointcloud):
    return pointcloud - pointcloud.mean(dim=0)

  def get_eigs(self, pointcloud):
    C = torch.matmul(pointcloud.t(), pointcloud)
    e, v = torch.linalg.eig(C)  # v[:,j] is j-th eigenvector
    return torch.view_as_real(e), v.real


  def svd_rotate(self, pointcloud):
    e, v = self.get_eigs(pointcloud)
    indices = e[:, 0].argsort(descending=True)
    v = v.t()[indices].t()
    return torch.matmul(pointcloud, v)

In [4]:
from pointgroup import PointGroup
from torch_geometric.datasets import QM9
from scipy.spatial.transform import Rotation as R
np.random.seed(42)

qm9 = QM9(root='../datasets/qm9-2.4.0/')
SVDA = SVDAlignment()
pg_losses = {}

atomic_number_to_symbol = {
    1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'
    }


rank1_loss, rank1_count = 0,0
rank2_loss, rank2_count = 0,0
rank3_loss, rank3_count = 0,0

for idx,data in enumerate(tqdm(qm9[:10000])):
    point_cloud = data.pos
    rank = torch.linalg.matrix_rank(point_cloud)
    cat_data = data.z.numpy()
    SVDA(point_cloud)
    normalized_data = SVDA.pointcloud  
    
    positions = data.pos.numpy()
    atomic_numbers = data.z.numpy()
    symbols = [atomic_number_to_symbol[i] for i in atomic_numbers]
    try:
        pg = PointGroup(positions, symbols).get_point_group()
    except:
        pg = 'C1'
    if pg not in pg_losses:
        pg_losses[pg]={'loss':0, 'count':0}
        
    rank = torch.linalg.matrix_rank(data.pos)

    for i in range(10):
        random_rotation = R.random().as_matrix()
        random_translation = np.random.rand(3)
        
        point_cloud = data.pos
        point_cloud = (random_rotation@(point_cloud+random_translation).numpy().T).T
        cat_data = data.z.numpy()
        SVDA(torch.from_numpy(point_cloud))
        new_normalization = SVDA.pointcloud  
        loss = compute_wasserstein_distance(normalized_data, new_normalization)

        if rank==1:
            rank1_loss += loss
            rank1_count +=1
        elif rank==2:
            rank2_loss += loss
            rank2_count +=1
        else:
            rank3_loss += loss
            rank3_count +=1
            
        pg_losses[pg]['loss']+=loss
        pg_losses[pg]['count']+=1


print(f'Rank 1 Loss: {rank1_loss/(rank1_count+1e-16):.5f}, Rank 2 Loss: {rank2_loss/(rank2_count+1e-16):.5f}, Rank 3 Loss: {rank3_loss/(rank3_count+1e-16):.5f}')

for key, dct in pg_losses.items():
    val = dct['loss']/dct['count']
    print(f'\tPoint Group {key} : {val}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:22<00:00, 437.12it/s]

Rank 1 Loss: 1.89805, Rank 2 Loss: 2.10510, Rank 3 Loss: 2.43520
	Point Group Td : 1.9635413812800062
	Point Group C3v : 2.4066188267868553
	Point Group C2v : 2.699401920794461
	Point Group Dinfh : 2.223506417244732
	Point Group Cinfv : 2.5627570800703396
	Point Group D3d : 2.5278595713494263
	Point Group C1 : 2.3920799210127237
	Point Group Cs : 2.802666677890567
	Point Group D3h : 2.1726454635623296
	Point Group C2 : 2.5476797477799344
	Point Group C2h : 2.5438649185049442
	Point Group D2d : 1.9552719477082061
	Point Group C1v : 2.4932992610916296
	Point Group C1h : 2.23501501641634
	Point Group D6h : 2.716527320883651
	Point Group D2h : 2.672102587412916
	Point Group C3 : 1.9051950106446491
	Point Group S2 : 5.055503645681741
	Point Group C3h : 2.1741920364115215
	Point Group D3 : 1.596167541153601
	Point Group Ci : 2.7027557023920594





## Equivariant AE (GIAE)

In [None]:



sys.path.append("../training/models/")
from giae_model import Model

np.random.seed(42)



qm9 = QM9(root='../datasets/qm9-2.4.0/')
model = Model(hidden_dim=256, emb_dim=32, num_layers=5).to("cpu")
model.load_state_dict(torch.load('../training/models/giae_model.pth'))  # Load the state dict
pg_losses = {}

atomic_number_to_symbol = {
    1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'
    }


rank1_loss, rank1_count = 0,0
rank2_loss, rank2_count = 0,0
rank3_loss, rank3_count = 0,0

for idx,data in enumerate(tqdm(qm9[:10000])):
    loader = DataLoader([data], batch_size=1, shuffle=False)
    batch = next(iter(loader))
    batch.pos = batch.pos - batch.pos.mean(dim=0)
    pos_out, perm, vout, rot = model(batch, hard=False)
    pos_out = pos_out - pos_out.mean(dim=0)
    normalized_data = (pos_out@rot.squeeze()).detach().numpy()
        
    positions = data.pos.numpy()
    atomic_numbers = data.z.numpy()
    symbols = [atomic_number_to_symbol[i] for i in atomic_numbers]
    try:
        pg = PointGroup(positions, symbols).get_point_group()
    except:
        pg = 'C1'
    if pg not in pg_losses:
        pg_losses[pg]={'loss':0, 'count':0}
        
    rank = torch.linalg.matrix_rank(batch.pos)

    for i in range(10):
        random_rotation = R.random().as_matrix()
        random_translation = np.random.rand(3)
        
        batch.pos = torch.from_numpy((random_rotation@(batch.pos.detach().numpy()+random_translation).T).T).to(torch.float)
        pos_out, perm, vout, rot = model(batch, hard=False)
        pos_out = pos_out - pos_out.mean(dim=0)
        new_normalization = (pos_out@rot.squeeze()).detach().numpy()
        loss = compute_wasserstein_distance(normalized_data, new_normalization)

        if rank==1:
            rank1_loss += loss
            rank1_count +=1
        elif rank==2:
            rank2_loss += loss
            rank2_count +=1
        else:
            rank3_loss += loss
            rank3_count +=1
            
        pg_losses[pg]['loss']+=loss
        pg_losses[pg]['count']+=1


print(f'Rank 1 Loss: {rank1_loss/(rank1_count+1e-16):.5f}, Rank 2 Loss: {rank2_loss/(rank2_count+1e-16):.5f}, Rank 3 Loss: {rank3_loss/(rank3_count+1e-16):.5f}')

for key, dct in pg_losses.items():
    val = dct['loss']/dct['count']
    print(f'\tPoint Group {key} : {val}')

 66%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                           | 6565/10000 [20:34<12:36,  4.54it/s]

## ASUN

In [None]:
sys.path.append('../pyorbit/')
from CategoricalPointCloud import CatFrame as Frame
np.random.seed(42)

qm9 = QM9(root='../datasets/qm9-2.4.0/')
frame = Frame()
pg_losses = {}

atomic_number_to_symbol = {
    1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'
    }


rank1_loss, rank1_count = 0,0
rank2_loss, rank2_count = 0,0
rank3_loss, rank3_count = 0,0

for idx,data in enumerate(tqdm(qm9[:10000])):
    point_cloud = data.pos
    rank = torch.linalg.matrix_rank(point_cloud)
    cat_data = data.z.numpy()
    normalized_data, rot = frame.get_frame(point_cloud, cat_data)
        
    positions = data.pos.numpy()
    atomic_numbers = data.z.numpy()
    symbols = [atomic_number_to_symbol[i] for i in atomic_numbers]
    try:
        pg = PointGroup(positions, symbols).get_point_group()
    except:
        pg = 'C1'
    if pg not in pg_losses:
        pg_losses[pg]={'loss':0, 'count':0}
        
    rank = torch.linalg.matrix_rank(data.pos)

    for i in range(10):
        random_rotation = R.random().as_matrix()
        random_translation = np.random.rand(3)
        
        point_cloud = data.pos
        point_cloud = (random_rotation@(point_cloud+random_translation).numpy().T).T
        cat_data = data.z.numpy()
        new_normalization, rot = frame.get_frame(point_cloud, cat_data)
        loss = compute_wasserstein_distance(normalized_data, new_normalization)

        if rank==1:
            rank1_loss += loss
            rank1_count +=1
        elif rank==2:
            rank2_loss += loss
            rank2_count +=1
        else:
            rank3_loss += loss
            rank3_count +=1
            
        pg_losses[pg]['loss']+=loss
        pg_losses[pg]['count']+=1


print(f'Rank 1 Loss: {rank1_loss/(rank1_count+1e-16):.5f}, Rank 2 Loss: {rank2_loss/(rank2_count+1e-16):.5f}, Rank 3 Loss: {rank3_loss/(rank3_count+1e-16):.5f}')

for key, dct in pg_losses.items():
    val = dct['loss']/dct['count']
    print(f'\tPoint Group {key} : {val}')