In [3]:
import numpy as np
import torch

import torch_geometric as torch_g

from scipy.spatial.distance import cdist, pdist
import scipy

import os
import ase.io
import ase
import sys
import copy

from torch_scatter import scatter_add

import numpy as np
import sys
sys.path.append("/home/share/DATA/NeuralOpt/Interpolations/Geodesic_interp")
from get_geodesic_energy import ATOMIC_RADIUS, morse_scaler

# Load Batch data

In [4]:
# Make float64 the default floating-point dtype for tensors created after this cell.
torch.set_default_dtype(torch.float64)

In [5]:
# NOTE(review): the star import is presumably what brings `GrambowDataModule`
# into scope — confirm and import it by name so the dependency is visible.
from torch_data_test.proc import *
from omegaconf import OmegaConf

# Build the data module from the YAML configuration file.
config = OmegaConf.load("torch_data_test/config.yaml")
datamodule = GrambowDataModule(config)

# Print only the first batch of the training loader as a quick sanity check.
loader = datamodule.train_dataloader()
for batch in loader:
    print(batch)
    break

debug] (GrambowDataModule) 
	base_path: /home/ksh/MolDiff/NeuralOpt/neural_opt
	root_path: /home/ksh/MolDiff/NeuralOpt/neural_opt/data
DataBatch(x=[29], edge_index=[2, 196], pos_1=[29, 3], pos_2=[29, 3], batch=[29], ptr=[3])


In [6]:
## wrapper class to save atoms object

class Wrapper:
    def __init__(self, atoms_0, atoms_T, q_type="DM", alpha=1.7, beta=0.01, gamma=0.01, using_jacobian=True, svd_tol=1e-4):
        self.atoms_0 = atoms_0
        self.atoms_T = atoms_T
        # assert q_type in ["DM", "morse"]
        self.q_type = q_type
        self.svd_tol = svd_tol
        self.re = torch.Tensor(self.get_re(atoms_T))
        self.alpha, self.beta = alpha, beta
        self.gamma = gamma
        self.morse_scaler = morse_scaler(self.re, self.alpha, self.beta)
        self.scaler_factor = 1.0
        self.using_jacobian = using_jacobian
        return

    def get_re(self, atoms, threshold=np.inf):
        from scipy.spatial import KDTree

        rijset = set()
        tree = KDTree(atoms.positions)
        pairs = tree.query_pairs(threshold)
        rijset.update(pairs)
        rijlist = sorted(rijset)

        radius = np.array([ATOMIC_RADIUS.get(atom.capitalize(), 1.5) for atom in atoms.get_chemical_symbols()])
        re = np.array([radius[i] + radius[j] for i, j in rijlist])
        return re

    def calc_inverse_jacobian(self, pos, q_type):
        """Build one flattened (N, 3) row per atom pair from element-wise reciprocal expressions.

        Args:
            pos (torch.Tensor): positions; ``pos.size()`` is used as the shape of
                each per-pair block, so a tensor is required.
            q_type (str): "DM" or "morse"; for any other value the row of that
                pair stays all zeros.

        Returns:
            torch.Tensor: the per-pair rows stacked along dim 0.

        NOTE(review): the entries are computed by dividing by the displacement
        vector component-wise (``d / d_pos``), so any zero component produces
        inf. Whether this is the intended "inverse Jacobian" should be confirmed;
        no caller of this method is visible in this file.
        """
        # Only the pair indices are used here; the returned lengths are ignored.
        edge_index, edg_length = self.pos_to_dist(pos)
        distance = pdist(pos)
        distance_e = self.get_re(self.atoms_T)
        inverse_jacobain = []

        # One iteration per pair: (i, j), its distance d and its reference distance de.
        for ij, d, de in zip(edge_index.T, distance, distance_e):
            jacob = torch.zeros(size=pos.size())
            i, j = ij
            pos_i, pos_j = pos[i], pos[j]
            d_pos = pos_i - pos_j
            if q_type == "DM":
                # Element-wise division by the displacement components.
                dr_dd = d / d_pos
                jacob[i] = dr_dd
                jacob[j] = - dr_dd

            elif q_type == "morse":
                # Same element-wise division, with a Morse-type scalar prefactor.
                dr_dq = d ** 3 * de  / - (self.alpha * np.exp(- self.alpha * (d / de - 1)) + self.beta * de ** 2) / d_pos
                jacob[i] = dr_dq
                jacob[j] = - dr_dq
            inverse_jacobain.append(jacob.flatten())
        return torch.stack(inverse_jacobain, dim=0)
                
        
    def calc_jacobian(self, pos, q_type):
        debug = False
        # pos = Tensor, (N, 3)
        edge_index, edge_length = self.pos_to_dist(pos)
        distance = pdist(pos)
        distance_e = self.get_re(self.atoms_T)

        jacobian = []
        for i_idx in range(len(pos)):
            j_idx = list(range(len(pos)))
            j_idx.remove(i_idx)
            j_idx = torch.LongTensor(j_idx)

            j_mask = torch.any(edge_index == i_idx, axis=0)
            dd_dx = torch.zeros(size=(len(edge_length), 3))
            dq_dx = torch.zeros(size=(len(edge_length), 3))
            pos_i = pos[i_idx].reshape(1, -1)
            pos_j = pos[j_idx]
            if debug:
                print(f"debug] (old) \n\ti_idx = {i_idx}", end="")
                print(f"\n\t j_idx = \n\t\t{j_idx}")
                print(f"\n\t maked j_idx = \n\t\t{edge_index[:, j_mask]}")
            dist = distance[j_mask].reshape(-1, 1)
            dd_dx[j_mask] += (pos_i - pos_j) / dist

            if q_type == "DM":
                jacobian.append(dd_dx.T)

            elif q_type == "morse":
                dq_dd = - (self.alpha / distance_e[j_mask]) * np.exp(-self.alpha * (distance[j_mask] - distance_e[j_mask]) / distance_e[j_mask])
                dq_dd -= self.beta * distance_e[j_mask] / (distance[j_mask] ** 2)
                if debug:
                    print(f"debug] (old) dq_dd = {dq_dd}")
                dq_dx[j_mask] += dd_dx[j_mask] * dq_dd.reshape(-1, 1)
                jacobian.append(dq_dx.T)
                
            elif q_type == "morese+DM":
                raise NotImplementedError

        return torch.cat(jacobian, dim=0)
    
    def calc_distance_hessian(self, pos, edge_index, distance):
        N = len(pos)
        K = len(edge_index)
        hessian = torch.zeros(size=(K, 3 * N, 3 * N))
        for k, (ij, d_ij) in enumerate(zip(edge_index, distance)):
            i, j = ij
            pos_i, pos_j = pos[i], pos[j]

            # calculate hessian related to i, j atoms
            d_pos = pos_i - pos_j
            hess_ij = d_pos.reshape(1, -1) * d_pos.reshape(-1, 1)
            hess_ij /= d_ij ** 3
            hess_ij -= torch.eye(3) / d_ij
            
            # calculate hessian related to i, i atoms
            hess_ii = d_pos.reshape(1, -1) * d_pos.reshape(-1, 1)
            hess_ii /= - d_ij ** 3
            hess_ii += torch.eye(3) / d_ij
            
            # hess_ii = hess_jj
            hess_jj = hess_ii

            hessian[k, 3 * i:3 * (i + 1), 3 * i:3 * (i + 1)] += hess_ii
            hessian[k, 3 * j:3 * (j + 1), 3 * j:3 * (j + 1)] += hess_jj
            hessian[k, 3 * i:3 * (i + 1), 3 * j:3 * (j + 1)] += hess_ij
            hessian[k, 3 * j:3 * (j + 1), 3 * i:3 * (i + 1)] += hess_ij
            
        return hessian
    
    def calc_hessian(self, pos, q_type=None):
        if q_type is None:
            q_type = self.q_type
            
        edge_index, edge_length = self.pos_to_dist(pos)
        edge_index = edge_index.T
        distance = pdist(pos)
        distance_e = self.get_re(self.atoms_T)

        hessian = self.calc_distance_hessian(pos, edge_index, distance)
        
        if q_type == "DM":
            return hessian
        
        elif q_type == "morse":
            dq_dd = - self.alpha / distance_e * np.exp(-self.alpha * (distance - distance_e) / distance_e)
            dq_dd -= self.beta * distance_e / (distance ** 2)
            hessian_q = hessian * dq_dd.reshape(-1, 1, 1)
            
            for k, (ij, d_ij, de_ij) in enumerate(zip(edge_index, distance, distance_e)):
                i, j = ij
                pos_i, pos_j = pos[i], pos[j]
                # calculate hessian related to i, j atoms
                d_pos = pos_i - pos_j
                hess_ij = d_pos.reshape(1, -1) * d_pos.reshape(-1, 1)
                hess_ij /= - d_ij ** 2
                coeff = self.alpha ** 2 / de_ij ** 2 * np.exp(-self.alpha * (d_ij - de_ij) / de_ij)  + 2 * self.beta * de_ij / (d_ij ** 3)
                hess_ij *= coeff

                # calculate hessian related to i, i atoms
                hess_ii = - hess_ij
                
                hessian_q[k, 3 * i:3 * (i + 1), 3 * i:3 * (i + 1)] += hess_ii
                hessian_q[k, 3 * j:3 * (j + 1), 3 * j:3 * (j + 1)] += hess_ii
                hessian_q[k, 3 * i:3 * (i + 1), 3 * j:3 * (j + 1)] += hess_ij
                hessian_q[k, 3 * j:3 * (j + 1), 3 * i:3 * (i + 1)] += hess_ij
                
            return hessian_q
        
        elif q_type == "morese+DM":
            raise NotImplementedError
        return 
             
    def eq_transform(self, score_d, pos, edge_index, edge_length):
        if self.using_jacobian:
            jacobian = self.calc_jacobian(pos, q_type=self.q_type)
            
            score_pos = jacobian @ score_d.reshape(-1, 1)
            return score_pos.reshape(-1, 3)
            
        if self.q_type == "morse":
            edge_length = torch.Tensor(pdist(pos))
            
            N = pos.size(0)
            dd_dr = - (self.alpha / self.re) * torch.exp(-self.alpha * (edge_length - self.re) / self.re) / edge_length
            dd_dr -= self.beta * self.re / (edge_length ** 3)
            dd_dr = dd_dr.reshape(-1, 1)
            dd_dr = dd_dr * (pos[edge_index[0]] - pos[edge_index[1]])
            score_d = score_d.reshape(-1, 1)
            score_d *= self.scaler_factor
            score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N)
            score_pos += scatter_add(-dd_dr * score_d, edge_index[1], dim=0, dim_size=N)
        
        elif self.q_type == "DM":
            N = pos.size(0)
            dd_dr = (1.0 / edge_length).reshape(-1, 1) * (pos[edge_index[0]] - pos[edge_index[1]])
            score_d = score_d.reshape(-1, 1)
            score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N)
            score_pos += scatter_add(-dd_dr * score_d, edge_index[1], dim=0, dim_size=N)
        
        elif self.q_type == "morse+DM":
            edge_length = torch.Tensor(pdist(pos))
            N = pos.size(0)
            dd_dr = - (self.alpha / self.re) * torch.exp(-self.alpha * (edge_length - self.re) / self.re) / edge_length
            dd_dr -= self.beta * self.re / (edge_length ** 3)
            dd_dr += self.gamma / edge_length
            dd_dr = dd_dr.reshape(-1, 1)
            score_d *= self.scaler_factor
            dd_dr = dd_dr * (pos[edge_index[0]] - pos[edge_index[1]])
            score_d = score_d.reshape(-1, 1)
            score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N)
            score_pos += scatter_add(-dd_dr * score_d, edge_index[1], dim=0, dim_size=N)
        else:
            raise NotImplementedError
        return score_pos

    def pos_to_dist(self, pos, q_type=None):
        if q_type is None:
            q_type = self.q_type
        if q_type == "morse":
            rij = pdist(pos)
            wij = self.morse_scaler(rij)[0] * self.scaler_factor
            # print(wij, type(wij))
            # length = torch.Tensor(wij)
            length = wij
            index = torch.LongTensor(np.stack(np.triu_indices(len(pos), 1)))
        elif q_type == "DM":
            length = torch.Tensor(pdist(pos))
            index = torch.LongTensor(np.stack(np.triu_indices(len(pos), 1)))
        elif q_type == "morse+DM":
            rij = pdist(pos)
            wij = self.morse_scaler(rij)[0] * self.scaler_factor
            wij += self.gamma * rij
            length = torch.Tensor(wij)
            index = torch.LongTensor(np.stack(np.triu_indices(len(pos), 1)))
        else:
            raise NotImplementedError
        return index, length

    def reverse_diffusion_process(self, x_t, t, dt, params, x_0, x_T, coord="Cartesian", h_coeff=0.0, verbose=True,
                                  using_jacobian=True, sampling_test=1, inner_iteration=5):
        beta_t = params.beta(t)

        if coord == "Cartesian":
            diff, coeff, v1, v2, v3, v4 = self.reverse_score(x_t, t, params, x_0, x_T, verbose=verbose)
            reverse_score_ = diff * coeff
            dw = torch.sqrt(beta_t * dt) * torch.randn_like(diff)
            dx = - 1.0 * reverse_score_ * dt + dw
        else:
            diff, coeff, v1, v2, v3, v4 = self.reverse_score2(x_t, t, params, x_0, x_T, verbose=verbose)
            index, d_t = self.pos_to_dist(x_t)

            if sampling_test == 0:
                # every displacement is first calculated on the q-space, and then transformed to the Cartesian space
                reverse_score_ = diff * coeff
                dw = torch.sqrt(beta_t * dt) * torch.randn_like(diff)
                dd = - 1.0 * reverse_score_ * dt + dw

                dx = self.eq_transform(dd, x_t, index, d_t)
            
            elif sampling_test == 1:
                reverse_score_ = diff * coeff
                dw = torch.sqrt(beta_t * dt) * torch.randn_like(diff)
                dd = - 1.0 * reverse_score_ * dt + dw   
                x_tm1 = self.exponential_ode_solver(x_t, -dd, q_type=self.q_type, num_iter=inner_iteration, check_dot_every=3)
                dx = - x_tm1 + x_t
                
        x_tm1 = x_t - dx
        return x_tm1, v1, v2, v3, v4

    def reverse_ode_process(self, x_t, t, dt, params, x_0, x_T, coord="Cartesian", h_coeff=0.0, verbose=True,
                            using_jacobian=True, sampling_test=1, inner_iteration=5):
        beta_t = params.beta(t)

        if coord == "Cartesian":
            diff, coeff, v1, v2, v3, v4 = self.reverse_score(x_t, t, params, x_0, x_T, verbose=verbose)
            reverse_score_ = diff * coeff
            dx = - 0.5 * reverse_score_ * dt
            print(f"Debug ({t:0.3f}): \n\t1) diff norm and dx norm {diff.norm():0.4f}, {dx.norm():0.6f}")
        else:
            diff, coeff, v1, v2, v3, v4 = self.reverse_score2(x_t, t, params, x_0, x_T, verbose=verbose)
            index ,d_t = self.pos_to_dist(x_t)

            if sampling_test == 0:
                reverse_score_ = diff * coeff
                dd = - 0.5 * reverse_score_ * dt
                dx = self.eq_transform(dd, x_t, index, d_t)                
                # Want to check why eq-transform does not work well
                diff_d = diff
                diff_x = self.eq_transform(diff_d, x_t, index, d_t)
                print(f"Debug ({t:0.3f}): \n\t1) diff-d norm and diff-x norm {diff_d.norm():0.4f}, {diff_x.norm():0.4f} \n\t2) dd-norm and dx-norm {dd.norm():0.6f}, {dx.norm():0.6f}")
                print(f"\t3) dx-norm/dd-norm {dx.norm()/dd.norm():0.6f}")
                
            elif sampling_test == 1:
                reverse_score_ = diff * coeff
                dd = - 0.5 * reverse_score_ * dt
                print(f"debug ] time : {t:0.3f}")
                print(f"debug ] diff.norm() : {diff.norm()}")
                print(f"debug ] dd.norm() : {dd.norm()}")
                x_tm1 = self.exponential_ode_solver(x_t, -dd, q_type=self.q_type, num_iter=inner_iteration, check_dot_every=3)
                dx = - x_tm1 + x_t
                print(f"debug ] dx.norm() : {dx.norm()}")
                
        x_tm1 = x_t - dx
        return x_tm1, v1, v2, v3, v4

    def reverse_score(self, x_t, t, params, x_0, x_T, verbose=True):
        """Cartesian drift direction towards the interpolated mean, plus debug metrics.

        Args:
            x_t: current positions (``.numpy()`` is called on it below).
            t: time, forwarded to ``params`` and formatted into the log line.
            params: object providing ``beta(t)``, ``sigma_square(t)``,
                ``SNR(t)`` and the attribute ``sigma_1``.
            x_0, x_T: end-point positions.
            verbose (bool): print one tab-separated line of diagnostics.

        Returns:
            tuple: ``(diff, coeff, v1, v2, v3, v4)`` where
            ``diff = mu_hat - x_t``, ``coeff = beta_t / sigma_t_hat_square`` and
            ``v1..v4`` are the diagnostics computed below.

        Raises:
            NotImplementedError: if ``self.q_type`` is not one of "DM", "morse",
                "morse+DM" or "Cartesian".
        """
        # calculate parameters
        beta_t = params.beta(t)
        sigma_t_square = params.sigma_square(t)
        sigma_T_square = params.sigma_1  # read but not used below

        SNRTt = params.SNR(t)
        sigma_t_hat_square = sigma_t_square * (1 - SNRTt)

        # calc mu_hat: linear interpolation between x_0 and x_T weighted by SNRTt
        mu_hat = x_T * SNRTt + x_0 * (1 - SNRTt)

        # calc difference
        diff = mu_hat - x_t

        # calc_score
        coeff =  1 / (sigma_t_hat_square) * beta_t
        score = diff * coeff  # only used for the log line at the end

        # for debug
        # v1: mean |d(mu_hat) - d(x_t)|, v2: mean |d(x_0) - d(x_t)|,
        # v3: mean |mu_hat - x_t|,       v4: mean |mu_hat - x_T|
        if self.q_type == "DM":
        # if self.q_type in ["DM", "morse"]: # debugging # calculate err corresponding the metric
            _, d_T = self.pos_to_dist(x_T)
            _, d_t = self.pos_to_dist(x_t)
            _, d_0 = self.pos_to_dist(x_0)
            _, d_mu_hat = self.pos_to_dist(mu_hat)
            v1 = (d_mu_hat - d_t).abs().mean()
            # v2 = (d_mu_hat - d_T).abs().mean()
            v2 = (d_0 - d_t).abs().mean()
            # NOTE(review): subtracts a numpy array from a tensor and then calls
            # `.abs()`, which relies on the result still being a tensor — confirm.
            v3 = (mu_hat - x_t.numpy()).abs().mean()
            v4 = (mu_hat - x_T.numpy()).abs().mean()
        # elif self.q_type == "morse":
        elif self.q_type in ["morse", "morse+DM", "Cartesian"]:
            # `version` is hard-coded, so only the "DMAE" branch below ever runs.
            version = "DMAE"
            # version = "Morse-RMSD"
            if version == "DMAE":
                # Plain distance matrices, regardless of q_type.
                d_T = torch.Tensor(pdist(x_T))
                d_mu_hat = torch.Tensor(pdist(mu_hat))  # for typo=2 this probably should not be done this way
                d_t = torch.Tensor(pdist(x_t))
                d_0 = torch.Tensor(pdist(x_0))
                v1 = (d_mu_hat - d_t).abs().mean()
                # v2 = (d_mu_hat - d_T).abs().mean()
                v2 = (d_0 - d_t).abs().mean()
                v3 = abs(mu_hat - x_t.numpy()).mean()
                v4 = abs(mu_hat - x_T.numpy()).mean()
            else:
                # q-space norms instead of mean absolute distance errors.
                _, d_T = self.pos_to_dist(x_T)
                _, d_mu_hat = self.pos_to_dist(mu_hat)  # for typo=2 this probably should not be done this way
                _, d_t = self.pos_to_dist(x_t)
                _, d_0 = self.pos_to_dist(x_0)
                v1 = (d_mu_hat - d_t).norm()
                # v2 = (d_mu_hat - d_T).abs().mean()
                v2 = (d_0 - d_t).norm()
                v3 = abs(mu_hat - x_t.numpy()).mean()
                v4 = abs(mu_hat - x_T.numpy()).mean()
        else:
            raise NotImplementedError
        if verbose:
            # Columns: t, v1, v2, v3, v4, largest per-atom score norm.
            print(f"{t:0.3f}\t{v1:0.4f}\t\t{v2:0.4f}\t\t{v3:0.4f}\t\t{v4:0.4f}\t\t{torch.linalg.norm(score, dim=-1).max():0.4f}")
        return diff, coeff, v1, v2, v3, v4

    def reverse_score2(self, x_t, t, params, x_0, x_T, verbose=True):
        # calculate parameters
        beta_t = params.beta(t)
        sigma_t_square = params.sigma_square(t)
        sigma_T_square = params.sigma_1

        SNRTt = params.SNR(t)
        sigma_t_hat_square = sigma_t_square * (1 - SNRTt)

        # calc mu_hat
        typo = 2
        
        if typo == 1:
            mu_hat = x_T * SNRTt + x_0 * (1 - SNRTt)
            _, d_mu_hat = self.pos_to_dist(mu_hat)
        if typo == 2:
            _, d_0 = self.pos_to_dist(x_0)
            _, d_T = self.pos_to_dist(x_T)
            d_mu_hat = d_T * SNRTt + d_0 * (1 - SNRTt)
            mu_hat = x_T * SNRTt + x_0 * (1 - SNRTt)  # for debugging
        if typo == 3:
            mu_hat = interpolate_LST(x_0.numpy(), x_T.numpy(), SNRTt.item())
            _, d_mu_hat = self.pos_to_dist(mu_hat)
        
        # calc difference
        index, d_t = self.pos_to_dist(x_t)
        diff_d = d_mu_hat - d_t
        diff = diff_d
        coeff =  1 / (sigma_t_hat_square) * beta_t

        # for debugging
        d_T = torch.Tensor(pdist(x_T))
        d_mu_hat = torch.Tensor(pdist(mu_hat))  # typo=2의 경우, 이렇게 하면 안될 듯.
        d_t = torch.Tensor(pdist(x_t))
        d_0 = torch.Tensor(pdist(x_0))
        v_loss_mae = (d_mu_hat - d_t).abs().mean()  # DMAE
        v_acc_mae = (d_0 - d_t).abs().mean()  # DMAE

        original_q_type = copy.deepcopy(self.q_type)
        self.q_type = "morse"
        _, d_T = self.pos_to_dist(x_T)
        _, d_mu_hat = self.pos_to_dist(mu_hat)  # typo=2의 경우, 이렇게 하면 안될 듯.
        _, d_t = self.pos_to_dist(x_t)
        _, d_0 = self.pos_to_dist(x_0)
        v_loss_norm = (d_mu_hat - d_t).norm()  # q-norm
        v_acc_norm = (d_0 - d_t).norm()  # q-norm
        self.q_type = original_q_type
        

        if verbose:
            print(f"{t:0.3f}\t{v_loss_mae:0.4f}\t\t{v_acc_mae:0.4f}\t\t{v_loss_norm:0.4f}\t\t{v_acc_norm:0.4f}\t")
        return diff, coeff, v_loss_mae, v_acc_mae, v_loss_norm, v_acc_norm
    
    def exponential_ode_solver(self, x0, q_dot0, q_type="morse", num_iter=100, check_dot_every=10, thresh=1e-2, max_dt=1e-1, verbose=False):
        """Follow a q-space velocity from ``x0`` by explicit stepping of a second-order ODE.

        The initial q-space velocity ``q_dot0`` is mapped to a Cartesian
        velocity with the pseudo-inverse of the Jacobian. Its norm becomes the
        total integration time and the velocity is normalised. Each step then
        advances ``x`` by ``x_dot * dt`` and ``x_dot`` by ``x_ddot * dt``, where
        ``x_ddot`` is built from the coordinate Hessian and the pseudo-inverse
        Jacobian (see ``one_step``).

        Args:
            x0 (torch.Tensor): starting positions; flattened internally and
                reshaped to (-1, 3) on return.
            q_dot0 (torch.Tensor): initial velocity in q-space, one entry per pair.
            q_type (str): coordinate type passed to ``calc_jacobian`` /
                ``calc_hessian``.
            num_iter (int): the initial step is ``total_time / num_iter``,
                capped by ``thresh``.
            check_dot_every (int): currently unused (``do_refine`` is hard-coded
                to ``False``).
            thresh (float): upper bound on the initial step size.
            max_dt (float): upper bound on every later step size.
            verbose (bool | int): >= 1 prints a summary, == 1 a progress bar,
                >= 2 a line every 25 iterations, >= 3 the per-step diagnostics.

        Returns:
            tuple: ``(x, iter, total_dq)`` — final positions of shape (-1, 3),
            the number of steps taken, and the accumulated norm of the q-space
            change per step (q measured with ``self.q_type``).

        NOTE(review): ``one_step`` uses ``atol=1e-2`` for the pseudo-inverse
        while the initial projection uses ``atol=self.svd_tol`` — confirm that
        the difference is intentional.
        """

        def one_step(x, x_dot, q_type=q_type, dt=1e-2, wrapper=self, refine_xdot=False, verbose=False):
            """Advance (x, x_dot) by one explicit step of size ``dt``."""
            hess = wrapper.calc_hessian(x.reshape(-1, 3), q_type=q_type)
            jacob = wrapper.calc_jacobian(x.reshape(-1, 3), q_type=q_type).T

            # J, J_inv = wrapper.refine_jacobian(jacob)
            J = jacob
            J_inv = torch.linalg.pinv(J, rtol=1e-4, atol=1e-2)

            JG = J_inv.T
            if refine_xdot:
                # Project the velocity with J_inv @ J.
                x_dot = J_inv @ J @ x_dot
                if verbose:
                    print(f"\t\t\tdebug: x_dot norm = {x_dot.norm():0.6f}")
            # Contract the Hessians (index m = internal coordinate) with the
            # pseudo-inverse, then with the velocity twice, to get the acceleration.
            christoffel = torch.einsum("mij, mk->kij", hess, JG)
            x_ddot = - torch.einsum("j,kij,i->k", x_dot, christoffel, x_dot)

            # x_ddot and x_dot should be perpendicular
            q_ddot = J @ x_ddot
            q_dot = J @ x_dot

            new_x = x + x_dot * dt
            new_x_dot = x_dot + x_ddot * dt

            # dotproduct
            if verbose:
                print(f"\t\tdebug: <x_ddot, x_dot> = {(q_ddot * q_dot).sum()}")
                print(f"\t\tdebug: <x_ddot, x_dot> = {((jacob.T @ jacob) @ x_dot.reshape(-1, 1) * x_ddot).sum()}")
                print(f"\t\tdebug: x_dot size = {x_dot.norm():0.8f}, x_ddot size = {x_ddot.norm():0.8f}")
                print(f"\t\tdebug: dx norm = {(new_x - x).norm():0.8f}, dx_dot norm = {(new_x_dot - x_dot).norm():0.8f}")
            return new_x, new_x_dot


        jacob = self.calc_jacobian(x0, q_type=q_type).T
        # J, J_inv = self.refine_jacobian(jacob)
        J = jacob
        J_inv = torch.linalg.pinv(J, rtol=1e-4, atol=self.svd_tol)

        # debugging: how much of q_dot0 survives the round trip through J_inv and J
        proj_q_dot = J @ J_inv @ q_dot0
        if verbose >= 1:
            print(f"\tdebug: proj_q_dot norm = {proj_q_dot.norm():0.4f}")
            print(f"\tdebug: proj_q_dot norm-ratio = {(proj_q_dot - q_dot0).norm()/ q_dot0.norm():0.4f}")

        # initialization
        x_dot0 = J_inv @ q_dot0
        x = x0.flatten()
        x_dot = x_dot0

        # Unit initial velocity; its original norm is the total integration time.
        total_time = x_dot.norm()
        x_dot = x_dot / x_dot.norm()

        q = self.pos_to_dist(x.reshape(-1, 3))[1]
        # make time grid, 0 ~ total_time.
        # time spacing should be smaller than 1e-2
        # thresh = 5e-2
        # thresh = 1e-1
        # if total_time > num_iter * thresh:
        #     num_iter = int(total_time / thresh)

        # t = torch.linspace(0, total_time, num_iter + 1)[:-1]
        # # t = torch.linspace(0, 1, num_iter + 1)[:-1]
        # dt_ = t[1] - t[0]
        # dt = dt_
        # print(f"\tdebug: x_dot0.norm() = {x_dot0.norm()}")
        # print(f"\tdebug: x_dot0.norm() = {total_time.item():0.6f}")
        # solve the geodesic ODE iteratively
        # for i, t_i in enumerate(t):

        # Base step: an equal split of the total time, but never above `thresh`.
        dt_ = min(total_time / num_iter, thresh)
        dt = dt_
        if verbose >= 1:
            print(f"initial dt = {dt_:0.6f}, total_expected_iter = {total_time / dt_:1.0f}")

        if verbose == 1:
            print("Progress-bar\n0%[--------------------]100%")
            print("0%[", end="")
        current_time = 0
        iter = 0
        cnt = 0
        total_dq = 0
        while total_time > current_time:
            # do_refine = i % check_dot_every == 0
            do_refine = False
            x_new, x_dot_new = one_step(x, x_dot, q_type=q_type, dt=dt, wrapper=self, refine_xdot=do_refine, verbose=verbose >= 3)
            current_time += dt

            # calculate dq
            q_new = self.pos_to_dist(x_new.reshape(-1, 3))[1]
            dq = (q_new - q).norm()
            total_dq += dq
            q = q_new
            if verbose >= 2:
                if iter % 25 == 0:
                    print(f"\tdebug: time = ({(current_time / total_time) * 100:0.4f}%), iter = {iter}, dt = {dt:0.6f}, dq = {dq:0.6f}")
            iter += 1

            x = x_new
            x_dot = x_dot_new
            # dt = x_dot.norm() * dt_
            # Next step: base step divided by the current speed, never below the
            # base step and never above max_dt.
            dt = min(max(dt_, 1 / x_dot.norm() * dt_), max_dt)
            if current_time / total_time > cnt / 10 and verbose == 1:
                print(f"--", end="")
                cnt += 1
            # Clip the final step so the integration ends at total_time.
            if total_time - current_time < dt:
                dt = total_time - current_time
        # for i, t_i in enumerate(t):
        #     print(f"\tdebug: time = {t_i:0.4f} ({i/len(t) * 100:0.2f}%)")
        #     # do_refine = i % check_dot_every == 0
        #     do_refine = False
        #     x, x_dot = one_step(x, x_dot, q_type=q_type, dt=dt, wrapper=self, refine_xdot=do_refine)
        if verbose == 1:
            print("]100%")
        return x.reshape(-1, 3), iter, total_dq
    
    
    def svd(self, jacob, verbose=False):
        U, S, Vh = torch.linalg.svd(jacob)
        num_zeros = (S < self.svd_tol).sum()
        dim = len(S) - num_zeros
        S = S[:dim]
        U = U[:, :dim]
        Vh = Vh[:dim, :]
        if verbose:
            print(f"\t\t\tdebug: dim = {dim}, num_zeros = {num_zeros}, singular values = {S[-1].item():0.6f} ~ {S[0].item():0.6f}")
        return U, S, Vh

    def refine_jacobian(self, jacob):
        # find non-zero singular values
        U, S, Vh = self.svd(jacob)
        J = U @ torch.diag(S) @ Vh
        J_inv = Vh.T @ torch.diag(1 / S) @ U.T
        return J, J_inv
        

## Compute $d_e$, $d_{ij}$, and $q_{ij}$

In [7]:
# Lookup table indexed by atomic number; only H (1), C (6), N (7) and O (8) are
# filled in, every other entry stays 0.
atomic_radius = torch.zeros(100)
atomic_radius[[1, 6, 7, 8]] = torch.tensor([ATOMIC_RADIUS[symbol] for symbol in ("H", "C", "N", "O")])
atomic_radius = torch.Tensor(atomic_radius)

def compute_de(edge_index, atom_type, atomic_radius):
    """
    Equilibrium distance of every edge: the sum of the radii of its two atoms.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        atom_type (torch.Tensor): atom type tensor (N, ), used to index atomic_radius
        atomic_radius (torch.Tensor): pre-defined atomic radius tensor (100, )
    Returns:
        d_e_ij (torch.Tensor): equilibrium distance tensor (E, )
    """
    endpoint_types = atom_type[edge_index]          # atom type at both ends of each edge
    endpoint_radii = atomic_radius[endpoint_types]  # radius at both ends of each edge
    return torch.sum(endpoint_radii, dim=0)

def compute_d(edge_index, pos):
    """
    Euclidean length of every edge.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        pos (torch.Tensor): position tensor (N, 3)
    Returns:
        d_ij (torch.Tensor): distance tensor (E, )
    """
    src, dst = edge_index
    displacement = pos[src] - pos[dst]
    return displacement.norm(dim=1)

def compute_q(edge_index, atom_type, pos, alpha=1.6, beta=2.3):
    """
    Morse-like coordinate of every edge.
    q_ij = exp(-alpha * (d_ij - d_e_ij) / d_e_ij) + beta * (d_e_ij / d_ij)
    The equilibrium distances d_e_ij are taken from the module-level
    `atomic_radius` table.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        atom_type (torch.Tensor): atom type tensor (N, )
        pos (torch.Tensor): position tensor (N, 3)
        alpha (float): decay rate of the exponential term
        beta (float): weight of the inverse-distance term
    Returns:
        q_ij (torch.Tensor): q tensor (E, )
    """
    d_ij = compute_d(edge_index, pos)
    d_e_ij = compute_de(edge_index, atom_type, atomic_radius)

    exponential_term = torch.exp(-alpha * (d_ij - d_e_ij) / d_e_ij)
    inverse_term = beta * (d_e_ij / d_ij)
    return exponential_term + inverse_term

In [14]:
def my_repeat(x, n):
    """Repeat every element of a 1-D tensor `n` times in a row: [a, b] -> [a, ..., a, b, ..., b]."""
    column = x.unsqueeze(-1)       # (L, 1)
    tiled = column.expand(-1, n)   # (L, n), each row holds one element n times
    return torch.flatten(tiled)

def my_stack(x):
    """Expand every index `a` of a 1-D tensor into the triple 3a, 3a + 1, 3a + 2, kept in order."""
    base = 3 * x
    triples = torch.stack((base, base + 1, base + 2))  # (3, L)
    return triples.T.flatten()

def sparse_calc_jacobian_d(edge_index, pos):
    """
    Compute the jacobian matrix of d_ij as a sparse tensor.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        pos (torch.Tensor): position tensor (N, 3)
    Returns:
        jacobian (torch.Tensor): sparse COO jacobian of size (E, 3N); row k holds
            dd_k/dx for the atoms of edge k
    """
    N = pos.size(0)
    E = edge_index.size(1)

    i, j = edge_index
    # Row index of every edge, created on the same device as the column indices
    # so that they can be stacked together.
    k = torch.arange(E, device=edge_index.device)

    d_ij = compute_d(edge_index, pos)
    # dd_ij/dx_i = (x_i - x_j) / d_ij
    dd_dx = ((pos[i] - pos[j]) / d_ij[:, None]).flatten()

    # Every edge fills three consecutive columns for each of its two atoms.
    k = k.unsqueeze(-1).expand(E, 3).flatten()
    i = torch.stack([3 * i, 3 * i + 1, 3 * i + 2]).T.reshape(-1)
    j = torch.stack([3 * j, 3 * j + 1, 3 * j + 2]).T.reshape(-1)

    jacobian = torch.sparse_coo_tensor(
        torch.stack([k, i]), dd_dx, (E, 3 * N)
    )
    # dd_ij/dx_j = -dd_ij/dx_i
    jacobian += torch.sparse_coo_tensor(
        torch.stack([k, j]), -dd_dx, (E, 3 * N)
    )
    return jacobian

def sparse_calc_jacobian_q(edge_index, atom_type, pos, alpha=1.6, beta=2.3):
    """
    Compute the jacobian matrix of q_ij as a sparse tensor.
    Uses the chain rule dq/dx = dq/dd * dd/dx; the equilibrium distances are
    taken from the module-level `atomic_radius` table.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E), we consider directed graph
        atom_type (torch.Tensor): atom type tensor (N, )
        pos (torch.Tensor): position tensor (N, 3)
        alpha (float): decay rate of the exponential term of q
        beta (float): weight of the inverse-distance term of q
    Returns:
        jacobian (torch.Tensor): sparse COO jacobian of size (E, 3N)
    """
    N = pos.size(0)
    E = edge_index.size(1)

    i, j = edge_index
    # Row index of every edge, created on the same device as the column indices
    # so that they can be stacked together.
    k = torch.arange(E, device=edge_index.device)

    d_ij = compute_d(edge_index, pos)
    d_e_ij = compute_de(edge_index, atom_type, atomic_radius)
    # dd_ij/dx_i = (x_i - x_j) / d_ij
    dd_dx = (pos[i] - pos[j]) / d_ij[:, None]

    # dq_ij/dd_ij = - alpha / d_e_ij * exp(- alpha / d_e_ij * (d_ij - d_e_ij)) - beta * d_e_ij / d_ij ** 2
    dq_dd = - alpha / d_e_ij * torch.exp(- alpha / d_e_ij * (d_ij - d_e_ij)) - beta * d_e_ij / d_ij ** 2

    dq_dx = dq_dd.unsqueeze(-1) * dd_dx  # (E, 3)

    # Every edge fills three consecutive columns for each of its two atoms.
    k = k.unsqueeze(-1).expand(E, 3).flatten()
    i = torch.stack([3 * i, 3 * i + 1, 3 * i + 2]).T.reshape(-1)
    j = torch.stack([3 * j, 3 * j + 1, 3 * j + 2]).T.reshape(-1)

    jacobian = torch.sparse_coo_tensor(
        torch.stack([k, i]), dq_dx.flatten(), (E, 3 * N)
    )
    # dq_ij/dx_j = -dq_ij/dx_i
    jacobian += torch.sparse_coo_tensor(
        torch.stack([k, j]), -dq_dx.flatten(), (E, 3 * N)
    )
    return jacobian

def calc_jacobian_d(edge_index, pos):
    """
    Dense version of `sparse_calc_jacobian_d`.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        pos (torch.Tensor): position tensor (N, 3)
    Returns:
        jacobian (torch.Tensor): dense jacobian matrix, expected size (E, 3N)
    """
    sparse_jacobian = sparse_calc_jacobian_d(edge_index, pos)
    return sparse_jacobian.to_dense()

def calc_jacobian_q(edge_index, atom_type, pos, alpha=1.6, beta=2.3):
    """
    Dense version of `sparse_calc_jacobian_q`.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        atom_type (torch.Tensor): atom type tensor (N, )
        pos (torch.Tensor): position tensor (N, 3)
        alpha (float): decay rate of the exponential term of q
        beta (float): weight of the inverse-distance term of q
    Returns:
        jacobian (torch.Tensor): dense jacobian matrix, expected size (E, 3N)
    """
    # Forward the caller's alpha/beta; they used to be replaced by the
    # hard-coded defaults 1.6 and 2.3.
    jacobian = sparse_calc_jacobian_q(edge_index, atom_type, pos, alpha=alpha, beta=beta).to_dense()
    return jacobian
    
def calc_hessian_d(edge_index, pos):
    """
    Compute hessian matrix of d_ij.
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        pos (torch.Tensor): position tensor (N, 3)
    Returns:
        hessian (torch.Tensor): dense hessian tensor, expected size (E, 3N, 3N)
    """
    N = pos.size(0)
    E = edge_index.size(1)
    d_ij = compute_d(edge_index, pos)
    i, j = edge_index
    # Edge counter on the same device as the atom indices it is stacked with.
    k = torch.arange(E, device=edge_index.device)

    # First calculate the (3, 3) blocks of every edge, shape (E, 3, 3):
    # mixed block d2d/dx_i dx_j = r r^T / d^3 - I / d
    d_pos = pos[i] - pos[j]  # (E, 3)
    hess_d_ij = d_pos.reshape(-1, 1, 3) * d_pos.reshape(-1, 3, 1) / (d_ij.reshape(-1, 1, 1) ** 3)
    eye = torch.eye(3).reshape(1, 3, 3).to(d_ij.device)
    hess_d_ij -= eye / d_ij.reshape(-1, 1, 1)

    # The same-atom blocks are the negative of the mixed block.
    hess_d_ii = - hess_d_ij
    hess_d_jj = - hess_d_ij
    hess_d_ji = hess_d_ij

    # The full hessian has shape (E, 3N, 3N); assemble it as a sparse tensor.
    # Index layout per edge follows the row-major order of the flattened (3, 3)
    # block: the row index is constant over three consecutive entries while the
    # column index cycles.
    k = my_repeat(k, 9)
    col_i = my_stack(my_repeat(i, 3))
    row_i = my_repeat(my_stack(i), 3)
    col_j = my_stack(my_repeat(j, 3))
    row_j = my_repeat(my_stack(j), 3)
    hess = torch.sparse_coo_tensor(
        torch.stack([k, row_i, col_i]), hess_d_ii.flatten(), (E, 3 * N, 3 * N)
    )
    hess += torch.sparse_coo_tensor(
        torch.stack([k, row_j, col_j]), hess_d_jj.flatten(), (E, 3 * N, 3 * N)
    )
    hess += torch.sparse_coo_tensor(
        torch.stack([k, row_i, col_j]), hess_d_ij.flatten(), (E, 3 * N, 3 * N)
    )
    hess += torch.sparse_coo_tensor(
        torch.stack([k, row_j, col_i]), hess_d_ji.flatten(), (E, 3 * N, 3 * N)
    )

    hessian = hess.to_dense()
    return hessian

def calc_hessian_q(edge_index, atom_type, pos, alpha=1.6, beta=2.3):
    """
    Compute hessian matrix of q_ij from the hessian and jacobian of d_ij.
        d^2q/dadb = K1(d) * d^2d/dadb + K2(d) * dd/da * dd/db
        K1(d) = dq/dd     = -alpha / d_e * exp(-alpha * (d - d_e) / d_e) - beta * d_e / d ** 2
        K2(d) = d^2q/dd^2 = (alpha / d_e) ** 2 * exp(-alpha * (d - d_e) / d_e) + 2 * beta * d_e / d ** 3
    Args:
        edge_index (torch.Tensor): edge index tensor (2, E)
        atom_type (torch.Tensor): atom type tensor (N, )
        pos (torch.Tensor): position tensor (N, 3)
        alpha (float): decay rate of the exponential term of q
        beta (float): weight of the inverse-distance term of q
    Returns:
        hessian (torch.Tensor): hessian tensor, expected size (E, 3N, 3N)
    """
    hessian_d = calc_hessian_d(edge_index, pos)
    jacobian_d = calc_jacobian_d(edge_index, pos)

    d_ij = compute_d(edge_index, pos)
    d_e_ij = compute_de(edge_index, atom_type, atomic_radius)

    # First and second derivative of q with respect to the distance.
    first_derivative = - (alpha / d_e_ij) * torch.exp(-alpha * (d_ij - d_e_ij) / d_e_ij) - beta * d_e_ij / d_ij ** 2
    second_derivative = (alpha / d_e_ij) ** 2 * torch.exp(-alpha * (d_ij - d_e_ij) / d_e_ij) + 2 * beta * d_e_ij / d_ij ** 3

    # Per-edge outer product of the distance gradient, scaled by d^2q/dd^2.
    curvature_part = second_derivative.reshape(-1, 1, 1) * jacobian_d.unsqueeze(1) * jacobian_d.unsqueeze(2)
    return first_derivative.reshape(-1, 1, 1) * hessian_d + curvature_part

In [9]:
# Pick one reaction from the training set and build the objects used by the
# consistency checks below.
idx = 0
a_data = datamodule.train_dataset[idx]
atom_type = a_data.x
# edge_index = a_data.edge_index
# Use every unordered atom pair (upper triangle of the N x N matrix) as an edge.
index = np.triu_indices(len(atom_type), 1)
# Stack the two index arrays into one integer array first: building the tensor
# from a tuple of arrays is slow and goes through a floating-point tensor.
# This is the same construction used in Wrapper.pos_to_dist.
edge_index = torch.LongTensor(np.stack(index))

pos_1 = a_data.pos_1
pos_2 = a_data.pos_2
atoms_0 = ase.Atoms(symbols=atom_type, positions=pos_1)
atoms_T = ase.Atoms(symbols=atom_type, positions=pos_2)
wrapper = Wrapper(atoms_0, atoms_T, q_type="morse", alpha=1.6, beta=2.3, gamma=0.0, using_jacobian=True, svd_tol=1e-4)

pos = pos_1

  edge_index = torch.Tensor(index).long()


Check that every differentiation function is correct

In [16]:
torch.set_printoptions(precision=8, sci_mode=False)


def _all_within(old, new, tol=1e-6):
    """Return whether every element-wise difference between `old` and `new` is below `tol`."""
    return (abs(old - new) < tol).all()


# --- coordinates: old Wrapper implementation vs. new free functions ---
d_e_old = torch.Tensor(wrapper.get_re(atoms_0))
d_e_new = compute_de(edge_index, atom_type, atomic_radius)
print(f"d_e is same? : {_all_within(d_e_old, d_e_new)}")

d_old = torch.Tensor(pdist(pos_1))
d_new = compute_d(edge_index, pos_1)
print(f"d is same? : {_all_within(d_old, d_new)}")

q_old = wrapper.pos_to_dist(pos_1)[1]
q_new = compute_q(edge_index, atom_type, pos_1)
print(f"q is same? : {_all_within(q_old, q_new)}")
print("------------------------------------")

# --- jacobians (the Wrapper returns (3N, E), hence the transpose) ---
jacob_old = wrapper.calc_jacobian(pos_1, q_type="DM").T
jacob_new = calc_jacobian_d(edge_index, pos_1)
print(f"d-jacobian is same? : {_all_within(jacob_old, jacob_new)}")
print(f"d-jacobian norm diff = {(jacob_old - jacob_new).norm()}")

jacob_old = wrapper.calc_jacobian(pos_1, q_type="morse").T
jacob_new = calc_jacobian_q(edge_index, atom_type, pos_1, alpha=1.6, beta=2.3)
print(f"q-jacobian is same? : {_all_within(jacob_old, jacob_new)}")
print(f"q-jacobian norm diff = {(jacob_old - jacob_new).norm()}")
print("------------------------------------")

# --- hessians ---
hess_old = wrapper.calc_hessian(pos_1, q_type="DM")
hess_new = calc_hessian_d(edge_index, pos_1)
print(f"d-hessian is same? : {_all_within(hess_old, hess_new)}")
print(f"d-hessian norm diff = {(hess_old - hess_new).norm()}")

# The q-hessian is compared with a looser tolerance (1e-5) than the other checks.
hess_old = wrapper.calc_hessian(pos_1, q_type="morse")
hess_new = calc_hessian_q(edge_index, atom_type, pos_1, alpha=1.6, beta=2.3)
print(f"q-hessian is same? : {_all_within(hess_old, hess_new, tol=1e-5)}")
print(f"q-hessian norm diff = {(hess_old - hess_new).norm()}")


d_e is same? : True
d is same? : True
q is same? : True
------------------------------------
d-jacobian is same? : True
d-jacobian norm diff = 4.082286002664526e-07
q-jacobian is same? : True
q-jacobian norm diff = 1.625652109717903e-06
------------------------------------
d-hessian is same? : True
d-hessian norm diff = 7.666445624956989e-07
q-hessian is same? : True
q-hessian norm diff = 7.569197015296669e-06
