In [2]:
import json
import torch
import numpy as np
import mdtraj as md
import nglview as nv

import openmm as mm
import openmm.unit as unit
from openmm import app

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Four-index lists, one per backbone dihedral.
# NOTE(review): presumably atom indices into the 22-particle system that
# define the phi / psi torsions of alanine dipeptide — confirm against topology.
PHI_ANGLE = [4, 6, 8, 14]
PSI_ANGLE = [6, 8, 14, 16]
# Ten-index subsets; judging by the names these select the heavy atoms.
# The two lists differ only in their first entry (1 vs 0).
# NOTE(review): the "_TBG" variant presumably follows the TBG dataset's atom
# ordering — confirm before mixing the two.
ALANINE_HEAVY_ATOM_IDX = [1, 4, 5, 6, 8, 10, 14, 15, 16, 18]
ALANINE_HEAVY_ATOM_IDX_TBG = [0, 4, 5, 6, 8, 10, 14, 15, 16, 18]

# System size: 22 particles with 3 coordinates each, matching the trailing
# (22, 3) of the tensors loaded further down in this notebook.
n_particles = 22
n_dimensions = 3
# NOTE(review): `scaling` is not used in the visible cells; presumably a unit
# conversion factor (e.g. nm <-> Angstrom) — confirm at the point of use.
scaling = 10
# Length of one flattened configuration (22 * 3 = 66).
dim = n_particles * n_dimensions

def compute_dihedral(positions):
    """Compute one signed dihedral angle for every four-point set.

    Implements the projection-based formula from
    http://stackoverflow.com/q/20305272/1128289.

    Args:
        positions: Iterable of four-point sets. Each item is indexed as
            (point, coordinate) and may be a NumPy array, a nested sequence,
            or a tensor-like object exposing ``detach``/``cpu``/``numpy``
            (tensors that require grad or live on an accelerator are
            detached and copied to host memory first).

    Returns:
        np.ndarray: One angle per item, in radians, as produced by
        ``np.arctan2``. The inputs are never modified.
    """
    def to_numpy(p):
        # NumPy arrays pass through untouched.
        if isinstance(p, np.ndarray):
            return p
        # Tensor-like objects: a bare .numpy() raises when the tensor requires
        # grad or is not on the CPU, so detach and move to host first.
        if hasattr(p, "detach"):
            return p.detach().cpu().numpy()
        if hasattr(p, "numpy"):
            return p.numpy()
        # Plain nested sequences (lists / tuples).
        return np.asarray(p)

    def dihedral(p):
        p = to_numpy(p)
        # Bond vectors between consecutive points; the subtraction allocates a
        # new array, so the in-place sign flip below never touches the input.
        b = p[:-1] - p[1:]
        b[0] *= -1
        # Project the two outer bonds onto the plane perpendicular to the
        # central bond b[1].
        v = np.array([v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])

        # Normalize vectors
        v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
        b1 = b[1] / np.linalg.norm(b[1])
        # x is the cosine and y the sine of the angle between the projections.
        x = np.dot(v[0], v[1])
        m = np.cross(v[0], b1)
        y = np.dot(m, v[1])

        return np.arctan2(y, x)

    return np.array(list(map(dihedral, positions)))



In [11]:
def kabsch(
    reference_position: torch.Tensor,
    position: torch.Tensor,
) -> torch.Tensor:
    '''
        Kabsch algorithm for aligning two sets of points.

        Finds the proper rotation (no reflection) and translation that best
        superimpose ``position`` onto ``reference_position`` in the
        least-squares sense and returns the transformed ``position``.

        Args:
            reference_position (torch.Tensor): Reference positions, shape
                (N, 3) or (..., N, 3).
            position (torch.Tensor): Positions to align, shape (N, 3) or
                (..., N, 3). Leading batch dimensions broadcast against those
                of ``reference_position``, so a single (N, 3) reference can
                align a whole (B, N, 3) batch in one call.
        Returns:
            torch.Tensor: Aligned positions with the broadcast shape of the
                inputs ((N, 3) when both inputs are 2-D).
    '''
    # Compute centroids over the point axis and centre both point sets
    centroid_ref = torch.mean(reference_position, dim=-2, keepdim=True)
    centroid_pos = torch.mean(position, dim=-2, keepdim=True)
    ref_centered = reference_position - centroid_ref
    pos_centered = position - centroid_pos

    # Cross-covariance between the two centred point sets, shape (..., 3, 3)
    covariance = torch.matmul(ref_centered.transpose(-1, -2), pos_centered)
    U, S, Vt = torch.linalg.svd(covariance)
    V = Vt.transpose(-1, -2)
    Ut = U.transpose(-1, -2)

    # A negative determinant means the optimal orthogonal map is a reflection.
    # Flip the axis belonging to the smallest singular value (last column of V)
    # so that the result is a proper rotation. Done out of place so batched
    # inputs and autograd both work.
    d = torch.linalg.det(torch.matmul(V, Ut))
    sign = torch.where(d < 0, -torch.ones_like(d), torch.ones_like(d))
    flip = torch.cat([torch.ones_like(S[..., :-1]), sign.unsqueeze(-1)], dim=-1)
    rotation = torch.matmul(V * flip.unsqueeze(-2), Ut)

    # Align position to reference_position (row-vector convention)
    aligned_position = torch.matmul(pos_centered, rotation) + centroid_ref
    return aligned_position

# Aligned dataset

In [3]:
# Load the "current" coordinate tensor of the tbg-10n-lag30 dataset.
# NOTE(review): torch.load unpickles the file — only point it at trusted data.
current_xyz_tbg_loaded = torch.load("../dataset/alanine/300.0/tbg-10n-lag30/current-xyz.pt")
# Recorded output: torch.Size([10000, 22, 3]); the trailing (22, 3) matches
# n_particles x n_dimensions. Presumably the leading axis indexes frames.
print(current_xyz_tbg_loaded.shape)

torch.Size([10000, 22, 3])


In [10]:
# Load the reference structure used as the alignment target below; the loaded
# object is indexed with the key 'xyz'.
# NOTE(review): this path lives under ../data while the trajectories above are
# under ../dataset — confirm the different roots are intentional.
c5_state = torch.load("../data/alanine/c5-tbg.pt")
# Recorded output: torch.Size([1, 22, 3]) — a single configuration with a
# leading singleton axis.
print(c5_state['xyz'].shape)

torch.Size([1, 22, 3])


In [14]:
# Rigidly align every "current" frame onto the C5 reference structure.
# The reference is loop-invariant, so drop its singleton axis once up front
# instead of re-squeezing it for each of the frames.
reference_xyz = c5_state['xyz'].squeeze()

# Collect per-frame results under their own name so that `aligned_data_list`
# is only ever bound to the final stacked tensor.
aligned_frames = []
for data in tqdm(current_xyz_tbg_loaded):
    aligned_frames.append(kabsch(reference_xyz, data))

aligned_data_list = torch.stack(aligned_frames)
# Alignment is a rigid transform per frame, so the shape must be unchanged.
assert aligned_data_list.shape == current_xyz_tbg_loaded.shape, "alignment changed the data shape"
print(aligned_data_list.shape)

  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([10000, 22, 3])


In [15]:
# Persist the aligned "current" frames alongside the raw current-xyz.pt file.
torch.save(aligned_data_list, "../dataset/alanine/300.0/tbg-10n-lag30/aligned-current-xyz.pt")

In [16]:
# Load the time-lagged counterpart of the "current" frames.
# NOTE(review): torch.load unpickles the file — only point it at trusted data.
timelag_xyz_tbg_loaded = torch.load("../dataset/alanine/300.0/tbg-10n-lag30/timelag-xyz.pt")
print(timelag_xyz_tbg_loaded.shape)

# Align every time-lagged frame onto the same C5 reference. The reference is
# loop-invariant, so drop its singleton axis once instead of once per frame.
reference_xyz = c5_state['xyz'].squeeze()

# Collect per-frame results under their own name so that
# `aligned_timelag_data_list` is only ever bound to the final stacked tensor.
aligned_timelag_frames = []
for data in tqdm(timelag_xyz_tbg_loaded):
    aligned_timelag_frames.append(kabsch(reference_xyz, data))

aligned_timelag_data_list = torch.stack(aligned_timelag_frames)
# Alignment is a rigid transform per frame, so the shape must be unchanged.
assert aligned_timelag_data_list.shape == timelag_xyz_tbg_loaded.shape, "alignment changed the data shape"
print(aligned_timelag_data_list.shape)
torch.save(aligned_timelag_data_list, "../dataset/alanine/300.0/tbg-10n-lag30/aligned-timelag-xyz.pt")

torch.Size([10000, 22, 3])


  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([10000, 22, 3])
