<a href="https://colab.research.google.com/github/pranavkantgaur/gamd_sr/blob/main/deliverible.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Objective
1. Demonstrate the existence and merits of algorithm alignment in GNNs for physical simulations, if any.
2. Keep abstractions limited to experimental data and stable python libraries.

## Tasks
1. Loads a dataset from one of: LJ, TIP3P, TIP4P, DFT
2. Optionally, generates an MD dataset from a custom potential-energy expression passed as input
3. Loads relevant GNN model based on dataset complexity
4. Trains GNN and records edge messages of the converged model
5. Tests for a linear fit between the pair potential / pair force and the edge messages
6. Depending on the underlying custom potential, defines SR inputs and outputs
7. Trains SR
8. Plots pred vs gt curve for best SR equation
9. Discusses algorithm alignment of GNNs with the underlying physical phenomenon algorithm.
10. Perhaps demonstrates the benefits of algorithm alignment:
   1. Does algorithm alignment arise naturally in GNNs, or does it happen because of the inductive biases?

### Load dataset

#### Preprocessed dataset sources
1. LJ test-cases: https://drive.google.com/file/d/1jJdTAnhps1EIHDaBfb893fruaLPJzYKI/view?usp=sharing
2. Tip-3P test-cases: https://drive.google.com/file/d/18uvKVtN8Xm_5w7AJuezFiR1d2AqlHFKn/view?usp=sharing
3. Tip-4p test-cases: https://drive.google.com/file/d/1jBk78hN4ZPC9x-YXnznUzxFnXnpeKFRi/view?usp=sharing
4. DFT test-cases: https://drive.google.com/file/d/1b9P7EvIliGupN9ZIJpMGZzkm4ttG9Ul6/view?usp=sharing

#### Dataset classes and dataloaders

In [None]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np


class WaterDataNew(Dataset):
    """Frame-per-file dataset of classical water snapshots (TIP3P / TIP4P).

    Every sample is read from ``<dataset_path>/<case_prefix><seed>_<sample>.npz``,
    which must contain ``pos`` and ``forces`` arrays. The ``seed_num * sample_num``
    frame ids are shuffled with a fixed NumPy seed and cut into a train part
    (the first ``split[0]`` fraction) and a test part (the remainder).

    Args:
        dataset_path: Directory holding the ``.npz`` frames.
        sample_num: Number of frames per seed.
        case_prefix: File-name prefix placed before ``<seed>_<sample>``.
        seed_num: Number of seeds (independent runs).
        m_num: Number of molecules; ``3 * m_num`` feature rows are produced.
        split: ``(train_fraction, test_fraction)``; only the first entry is read.
        mode: ``'train'`` or ``'test'``.
        data_type: ``'tip4p'`` drops every fourth row of ``pos`` / ``forces``;
            any other value leaves the arrays untouched.
    """

    def __init__(self,
                 dataset_path,
                 sample_num,   # per seed
                 case_prefix='data_',
                 seed_num=10,
                 m_num=258,    # tip3p 258, tip4p 251
                 split=(0.9, 0.1),
                 mode='train',
                 data_type='tip3p',
                 ):
        self.dataset_path = dataset_path
        self.sample_num = sample_num
        self.case_prefix = case_prefix
        self.seed_num = seed_num

        self.data_type = data_type
        # Three rows per molecule; the first row of each triple is flagged 1
        # (presumably the oxygen -- confirm against the data generator).
        self.particle_type = (
            (np.arange(m_num * 3) % 3 == 0).astype(np.int64).reshape(-1, 1)
        )
        # Single-column float feature: 1.0 where the type flag is 1, else 0.0.
        self.particle_type_one_hot = self.particle_type.astype(np.float32)
        self.num_atom_type = self.particle_type.max() + 1
        print(f'Including atom type: {self.num_atom_type}')

        self.mode = mode
        assert mode in ['train', 'test']
        idxs = np.arange(seed_num * sample_num)
        # Fixed seed so that the 'train' and 'test' instances see the same
        # permutation. NOTE: this reseeds NumPy's global RNG.
        np.random.seed(0)
        np.random.shuffle(idxs)
        n_train = int(len(idxs) * split[0])
        self.idx = idxs[:n_train] if mode == 'train' else idxs[n_train:]

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, idx, get_path_name=False):
        """Load one frame.

        Args:
            idx: Position inside this split (not the global frame id).
            get_path_name: When true, also return the file path without the
                ``.npz`` extension.

        Returns:
            ``{'pos', 'feat', 'forces'}`` with float32 arrays, or
            ``(that dict, path)`` when ``get_path_name`` is true.
        """
        frame_id = self.idx[idx]
        sample_to_read = frame_id % self.sample_num
        seed = frame_id // self.sample_num
        # Honour the configured prefix (it used to be hard-coded to 'data_',
        # which silently ignored the constructor argument).
        fname = f'{self.case_prefix}{seed}_{sample_to_read}'
        data_path = os.path.join(self.dataset_path, fname)

        data = {}
        with np.load(data_path + '.npz') as raw_data:
            pos = raw_data['pos'].astype(np.float32)
            forces = raw_data['forces'].astype(np.float32)
        if self.data_type == 'tip4p':
            # Keep rows 0-2 of every group of four (the 4th is presumably the
            # virtual site -- confirm) so the rows line up with 'feat'.
            pos = pos[np.mod(np.arange(pos.shape[0]), 4) < 3]
            forces = forces[np.mod(np.arange(forces.shape[0]), 4) < 3]
        data['pos'] = pos
        data['feat'] = self.particle_type_one_hot
        data['forces'] = forces
        if get_path_name:
            return data, data_path
        return data


class LJDataNew(Dataset):
    """Frame-per-file dataset of Lennard-Jones snapshots.

    Every sample is read from ``<dataset_path>/<case_prefix><seed>_<sample>.npz``,
    which must contain ``pos`` and ``forces`` arrays. The ``seed_num * sample_num``
    frame ids are shuffled with a fixed NumPy seed and cut into a train part
    (the first ``split[0]`` fraction) and a test part (the remainder).

    Args:
        dataset_path: Directory holding the ``.npz`` frames.
        sample_num: Number of frames per seed.
        case_prefix: File-name prefix placed before ``<seed>_<sample>``
            (default ``'ljdata_'``).
        seed_num: Number of seeds (independent runs).
        split: ``(train_fraction, test_fraction)``; only the first entry is read.
        mode: ``'train'`` or ``'test'``.
    """

    def __init__(self,
                 dataset_path,
                 sample_num,   # per seed
                 case_prefix='ljdata_',
                 seed_num=10,
                 split=(0.9, 0.1),
                 mode='train',
                 ):
        self.dataset_path = dataset_path
        self.sample_num = sample_num
        self.case_prefix = case_prefix
        self.seed_num = seed_num

        self.mode = mode
        assert mode in ['train', 'test']
        idxs = np.arange(seed_num * sample_num)
        # Fixed seed so that the 'train' and 'test' instances see the same
        # permutation. NOTE: this reseeds NumPy's global RNG.
        np.random.seed(0)
        np.random.shuffle(idxs)
        n_train = int(len(idxs) * split[0])
        self.idx = idxs[:n_train] if mode == 'train' else idxs[n_train:]

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, idx, get_path_name=False):
        """Load one frame.

        Args:
            idx: Position inside this split (not the global frame id).
            get_path_name: When true, also return the file path without the
                ``.npz`` extension.

        Returns:
            ``{'pos', 'forces'}`` with float32 arrays, or ``(that dict, path)``
            when ``get_path_name`` is true.
        """
        frame_id = self.idx[idx]
        sample_to_read = frame_id % self.sample_num
        seed = frame_id // self.sample_num
        # Honour the configured prefix (it used to be hard-coded to 'ljdata_',
        # which silently ignored the constructor argument).
        fname = f'{self.case_prefix}{seed}_{sample_to_read}'
        data_path = os.path.join(self.dataset_path, fname)

        data = {}
        with np.load(data_path + '.npz') as raw_data:
            data['pos'] = raw_data['pos'].astype(np.float32)
            data['forces'] = raw_data['forces'].astype(np.float32)
        if get_path_name:
            return data, data_path
        return data


class WaterDataRealLarge(Dataset):
    """Dataset backed by one ``.npz`` archive that holds every frame.

    The archive must provide ``train_idx``, ``test_idx``, ``pos``, ``force``,
    ``box`` and ``atom_type``. All arrays are read into memory on construction.

    Args:
        dataset_path: Path of the ``.npz`` archive.
        mode: ``'train'`` selects ``train_idx``; any other value selects
            ``test_idx``.
        use_part: When true (train mode only), keep just the first 1500
            training indices.
    """

    def __init__(self,
                 dataset_path,
                 mode='train',
                 use_part=False
                 ):
        self.dataset_path = dataset_path
        self.use_part = use_part
        # allow_pickle=True lets NumPy unpickle object arrays, so the archive
        # must come from a trusted source.
        with np.load(self.dataset_path, allow_pickle=True) as archive:
            train_idx, test_idx = archive['train_idx'], archive['test_idx']
            self.pos = archive['pos']
            self.forces = archive['force']
            self.box_size = archive['box']
            self.atom_type = archive['atom_type']

        # Both messages are printed regardless of `mode`.
        reported_train = 1500 if use_part else len(train_idx)
        print(f'Using {reported_train} training samples')
        print(f'Using {len(test_idx)} testing samples')

        if mode != 'train':
            self.idx = test_idx
        elif use_part:
            self.idx = train_idx[:1500]
        else:
            self.idx = train_idx

    def __len__(self):
        return len(self.idx)

    def generate_atom_emb(self, particle_type):
        """Return an ``(n, 1)`` float32 column: 1.0 where the type is 1, else 0.0."""
        flat_types = np.array(particle_type).astype(np.int64).reshape(-1)
        return (flat_types == 1).astype(np.float32).reshape(-1, 1)

    def __getitem__(self, idx):
        """Return ``{'pos', 'feat', 'forces', 'box_size'}`` for split position ``idx``."""
        frame = self.idx[idx]
        return {
            'pos': self.pos[frame].copy().astype(np.float32),
            'feat': self.generate_atom_emb(self.atom_type[frame]),
            'forces': self.forces[frame].copy().astype(np.float32),
            'box_size': self.box_size[frame].copy().astype(np.float32),
        }

### For LJ system

In [None]:
# Cell-level settings. This cell used to read `self.data_dir` and
# `self.batch_size`, but `self` does not exist at notebook scope (NameError).
# The values below are placeholders -- edit them for your setup.
DATA_DIR = '.'     # folder that directly contains the LJ frame files
BATCH_SIZE = 1     # frames per batch

dataset = LJDataNew(dataset_path=os.path.join(DATA_DIR, ''),
                    sample_num=1000,
                    case_prefix='data_',
                    seed_num=10,
                    mode='train')

# Frames are kept as Python lists (not stacked), one entry per sample.
dataloader = DataLoader(dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True,
                        collate_fn=lambda batches: {
                            'pos': [batch['pos'] for batch in batches],
                            'forces': [batch['forces'] for batch in batches],
                        })

### For tip3p system

In [None]:
# Cell-level settings. This cell used to read `self.data_dir` and
# `self.batch_size`, but `self` does not exist at notebook scope (NameError).
# The values below are placeholders -- edit them for your setup.
DATA_DIR = '.'     # parent folder of 'water_data_tip3p'
BATCH_SIZE = 1     # frames per batch

# NUM_OF_ATOMS must already be defined when this cell runs; it is divided by 3
# to get the molecule count.
dataset = WaterDataNew(dataset_path=os.path.join(DATA_DIR, 'water_data_tip3p'),
                       sample_num=1000,
                       case_prefix='data_',
                       seed_num=10,
                       m_num=NUM_OF_ATOMS//3,
                       mode='train',
                       data_type='tip3p')

# 'feat' is concatenated over the batch along dim 0; 'pos' and 'forces' stay
# as per-sample lists.
dataloader = DataLoader(dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True,
                        collate_fn=lambda batches: {
                            'feat': torch.cat([torch.from_numpy(batch['feat']).float() for batch in batches], dim=0),
                            'pos': [batch['pos'] for batch in batches],
                            'forces': [batch['forces'] for batch in batches],
                        })


### For Tip4p system

In [None]:
# Cell-level settings. This cell used to read `self.data_dir` and
# `self.batch_size`, but `self` does not exist at notebook scope (NameError).
# The values below are placeholders -- edit them for your setup.
DATA_DIR = '.'     # parent folder of 'water_data_tip4p'
BATCH_SIZE = 1     # frames per batch

# NUM_OF_ATOMS must already be defined when this cell runs; it is divided by 3
# to get the molecule count.
dataset = WaterDataNew(dataset_path=os.path.join(DATA_DIR, 'water_data_tip4p'),
                       sample_num=1000,
                       case_prefix='data_',
                       seed_num=10,
                       m_num=NUM_OF_ATOMS//3,
                       mode='train',
                       data_type='tip4p')

# 'feat' is concatenated over the batch along dim 0; 'pos' and 'forces' stay
# as per-sample lists.
dataloader = DataLoader(dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True,
                        collate_fn=lambda batches: {
                            'feat': torch.cat([torch.from_numpy(batch['feat']).float() for batch in batches], dim=0),
                            'pos': [batch['pos'] for batch in batches],
                            'forces': [batch['forces'] for batch in batches],
                        })


### For DFT system

In [None]:
# Cell-level settings. This cell used to read `self.data_dir`, `self.use_part`
# and `self.batch_size`, and ended in a bare `return`; neither works at
# notebook scope (`self` is undefined and `return` outside a function is a
# SyntaxError). The values below are placeholders -- edit them for your setup.
DATA_DIR = '.'      # folder containing 'RPBE-data-processed.npz'
BATCH_SIZE = 1      # frames per batch
USE_PART = False    # True -> train on the first 1500 training frames only

dataset = WaterDataRealLarge(dataset_path=os.path.join(DATA_DIR, 'RPBE-data-processed.npz'), use_part=USE_PART)

# 'feat' is concatenated over the batch along dim 0; the other fields stay as
# per-sample lists.
dataloader = DataLoader(dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True,
                        collate_fn=lambda batches: {
                            'feat': torch.cat([torch.from_numpy(batch['feat']).float() for batch in batches], dim=0),
                            'pos': [batch['pos'] for batch in batches],
                            'forces': [batch['forces'] for batch in batches],
                            'box_size': [batch['box_size'] for batch in batches],
                        })

### Optionally, generate MD dataset

#### Generate LJ test-cases

In [None]:
import logging

import numpy as np


logger = logging.getLogger(__name__)

import time
import numpy as np
import sys
import os
from matplotlib import pyplot as plt

import openmm as mm
from openmm import *
from openmm.app import *
from openmm.unit import *
from openmm import unit
from openmm import app

from openmmtools import integrators
from openmmtools import testsystems


def subrandom_particle_positions(nparticles, box_vectors, method='sobol'):
    """Generate a deterministic list of subrandom particle positions.

    Each coordinate is a low-discrepancy sequence value scaled by the matching
    diagonal entry of ``box_vectors``; off-diagonal entries are not used.

    Parameters
    ----------
    nparticles : int
        The number of particles.
    box_vectors : openmm.unit.Quantity of (3,3) with units compatible with nanometer
        Periodic box vectors in which particles should lie.
    method : str, optional, default='sobol'
        Method for creating subrandom sequence (one of 'halton' or 'sobol')

    Returns
    -------
    positions : openmm.unit.Quantity of (natoms,3) with units compatible with nanometer
        The particle positions.

    Raises
    ------
    ValueError
        If ``method`` is neither 'halton' nor 'sobol'.

    Examples
    --------
    >>> nparticles = 216
    >>> box_vectors = openmm.System().getDefaultPeriodicBoxVectors()
    >>> positions = subrandom_particle_positions(nparticles, box_vectors)

    Use halton sequence:

    >>> nparticles = 216
    >>> box_vectors = openmm.System().getDefaultPeriodicBoxVectors()
    >>> positions = subrandom_particle_positions(nparticles, box_vectors, method='halton')

    """
    # Create positions array.
    positions = unit.Quantity(np.zeros([nparticles, 3], np.float32), unit.nanometers)

    if method == 'halton':
        # `halton_sequence` is not defined in this notebook cell; take it from
        # openmmtools, where this function originates.
        from openmmtools.testsystems import halton_sequence
        # Fill in each dimension.
        primes = [2, 3, 5]  # prime bases for Halton sequence
        for dim, base in enumerate(primes):
            x = halton_sequence(base, nparticles)
            l = box_vectors[dim][dim]
            positions[:, dim] = unit.Quantity(x * l / l.unit, l.unit)

    elif method == 'sobol':
        # Generate Sobol' sequence.
        from openmmtools import sobol
        ivec = sobol.i4_sobol_generate(3, nparticles, 1)
        x = np.array(ivec, np.float32)
        for dim in range(3):
            l = box_vectors[dim][dim]
            positions[:, dim] = unit.Quantity(x[dim, :] * l / l.unit, l.unit)

    else:
        # ValueError is still an Exception, so existing broad handlers keep working.
        raise ValueError("method '%s' must be 'halton' or 'sobol'" % method)

    return positions



def get_rotation_matrix():
    """Draw a random 3-D rotation matrix.

    Three angles are sampled uniformly from [-pi, pi) with NumPy's global RNG
    (one per Cartesian axis) and printed. The elementary rotations are composed
    as ``Rz @ Ry @ Rx``.

    Return:
      3x3 float32 array, the rotation matrix
    """
    angles = np.random.uniform(-1.0, 1.0, size=(3,)) * np.pi
    print(f'Using angle: {angles}')
    (cx, sx), (cy, sy), (cz, sz) = [(np.cos(a), np.sin(a)) for a in angles]

    rot_x = np.array([[1., 0, 0],
                      [0, cx, -sx],
                      [0, sx, cx]], dtype=np.float32)
    rot_y = np.array([[cy, 0, sy],
                      [0, 1, 0],
                      [-sy, 0, cy]], dtype=np.float32)
    rot_z = np.array([[cz, -sz, 0],
                      [sz, cz, 0],
                      [0, 0, 1]], dtype=np.float32)

    return rot_z @ (rot_y @ rot_x)

def center_positions(pos):
    """Shift ``pos`` so its mean over axis 0 is zero; return ``(shifted, mean)``."""
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid


# Generates Lennard-Jones (argon parameters) trajectories with OpenMM and dumps
# one .npz file (positions, velocities, forces) per sampled frame.
# BOX_SCALE is not referenced anywhere in this cell.
BOX_SCALE = 2
# Integration timestep in femtoseconds (see `timestep` below).
DT = 2
# Ten independent runs. `seed` only labels the log and frame files; no random
# number generator is seeded with it in this cell.
for seed in range(10):
    print(f'Running seed: {seed}')
    nparticles = 50

    reduced_density=0.05
    mass=39.9 * unit.amu  # argon
    sigma=3.4 * unit.angstrom  # argon,
    epsilon=0.238 * unit.kilocalories_per_mole  # argon,
    cutoff = 3.0 * sigma
    # NOTE(review): `mass` is assigned a second time with the same value, and
    # `charge` is not used again in this cell.
    mass=39.9 * unit.amu
    charge = 0.0 * unit.elementary_charge
    # Create an empty system object.
    system = openmm.System()

    # Determine volume and periodic box vectors.
    # number density = reduced_density / sigma^3; volume = N / number density;
    # the box is a cube with that volume.
    number_density = reduced_density / sigma**3
    volume = nparticles * (number_density ** -1)
    box_edge = volume ** (1. / 3.)
    a = unit.Quantity((box_edge,        0 * unit.angstrom, 0 * unit.angstrom))
    b = unit.Quantity((0 * unit.angstrom, box_edge,        0 * unit.angstrom))
    c = unit.Quantity((0 * unit.angstrom, 0 * unit.angstrom, box_edge))
    system.setDefaultPeriodicBoxVectors(a, b, c)

    # Define Lennard-Jones potential with periodic boundary conditions (global constants)
    lj_potential = '4*epsilon*((sigma/r)^12 - (sigma/r)^6)'  # Using epsilon and sigma as constants in the expression
    custom_force = CustomNonbondedForce(lj_potential)
    # Add the constants for epsilon and sigma as global parameters
    custom_force.addGlobalParameter("epsilon", epsilon)
    custom_force.addGlobalParameter("sigma", sigma)

    # Set periodic cutoff for nonbonded interactions
    custom_force.setNonbondedMethod(CustomNonbondedForce.CutoffPeriodic)
    custom_force.setCutoffDistance(cutoff)


    # Register every particle with both the system (mass) and the force.
    # The force has no per-particle parameters, hence the empty list.
    for particle_index in range(nparticles):
      system.addParticle(mass)
      custom_force.addParticle([])

    system.addForce(custom_force)


    # Define initial positions
    positions = subrandom_particle_positions(nparticles, system.getDefaultPeriodicBoxVectors())

    # Create topology.
    # One single-atom 'Ar' residue per particle, all in one chain.
    topology = app.Topology()
    element = app.Element.getBySymbol('Ar')
    chain = topology.addChain()
    for particle in range(system.getNumParticles()):
        residue = topology.addResidue('Ar', chain)
        topology.addAtom('Ar', element, residue)
    # NOTE(review): self-assignment, has no effect.
    topology = topology


    # Randomise the start: rotate the configuration about its centroid, then
    # add Gaussian jitter (std 0.005, applied to the values in angstrom) and
    # re-attach the angstrom unit.
    R = get_rotation_matrix()
    positions = positions.value_in_unit(unit.angstrom)
    positions, off = center_positions(positions)
    positions = np.matmul(positions, R)
    positions += off
    positions += np.random.randn(positions.shape[0], positions.shape[1]) * 0.005
    positions *= unit.angstrom

    timestep = DT * unit.femtoseconds
    temperature = 100 * unit.kelvin
    chain_length = 10
    friction = 25. / unit.picosecond
    num_mts = 5
    num_yoshidasuzuki = 5

    # All integrator arguments are passed positionally, in this order.
    integrator1 = integrators.NoseHooverChainVelocityVerletIntegrator(system,
                                                                      temperature,
                                                                      friction,
                                                                      timestep, chain_length, num_mts, num_yoshidasuzuki)


    # NOTE(review): the CUDA platform and device indices '0, 1, 2' are
    # hard-coded; presumably this needs three visible GPUs -- confirm before
    # running on another machine (e.g. Colab).
    platform = Platform.getPlatformByName("CUDA")
    platformProperties = {'Precision': 'mixed', 'DeviceIndex': '0, 1, 2'}


    simulation = Simulation(topology, system, integrator1, platform, platformProperties)

    simulation.context.setPositions(positions)
    simulation.context.setVelocitiesToTemperature(temperature)

    simulation.minimizeEnergy(tolerance=1*unit.kilojoule/(unit.mole*unit.nanometer))
    simulation.step(1)

    # NOTE(review): the output folder is hard-coded to 'run_7'.
    os.makedirs(f'./lj_data_ours/run_7', exist_ok=True)
    # 2000 frames per run, 5000 integration steps between consecutive frames.
    stepsPerIter = 5000
    totalIter = 2000
    totalSteps = stepsPerIter * totalIter
    # Tab-separated thermodynamic log, one row every `stepsPerIter` steps.
    dataReporter_gt = StateDataReporter(f'./lj_data_ours/run_7/log_nvt_lj_{seed}.txt', stepsPerIter, totalSteps=totalSteps,
        step=True, time=True, speed=True, progress=True, elapsedTime=True, remainingTime=True,
        potentialEnergy=True, kineticEnergy=True, totalEnergy=True, temperature=True,
                                     separator='\t')
    simulation.reporters.append(dataReporter_gt)

    for t in range(totalIter):
        #if (t+1)%100 == 0:
        #    print(f'Finished {(t+1)*stepsPerIter} steps')
        # The frame is captured BEFORE stepping, so frame 0 is the state right
        # after minimisation plus the single step above.
        state = simulation.context.getState(getPositions=True,
                                             getVelocities=True,
                                             getForces=True,
                                             enforcePeriodicBox=True)
        # Saved units: positions in angstrom, velocities in m/s,
        # forces in kJ/mol/nm.
        pos = state.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
        vel = state.getVelocities(asNumpy=True).value_in_unit(unit.meter / unit.second)
        force = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometer)
        # Files are named data_<seed>_<frame>.npz.
        np.savez(f'./lj_data_ours/run_7/data_{seed}_{t}.npz',
                 pos=pos,
                 vel=vel,
                 forces=force)
        simulation.step(stepsPerIter)



#### Generate tip-3p test-cases

In [None]:
from openmmtools import testsystems
from simtk.openmm.app import *
import simtk.unit as unit

import logging

import numpy as np

from openmmtools.constants import kB
from openmmtools import respa, utils

logger = logging.getLogger(__name__)

# Energy unit used by OpenMM unit system
from openmmtools import states, integrators
import time
import numpy as np
import sys
import os


def get_rotation_matrix():
    """Draw a random 3-D rotation matrix.

    Three angles are sampled uniformly from [-pi, pi) with NumPy's global RNG
    (one per Cartesian axis) and printed. The elementary rotations are composed
    as ``Rz @ Ry @ Rx``.

    Return:
      3x3 float32 array, the rotation matrix
    """
    angles = np.random.uniform(-1.0, 1.0, size=(3,)) * np.pi
    print(f'Using angle: {angles}')
    (cx, sx), (cy, sy), (cz, sz) = [(np.cos(a), np.sin(a)) for a in angles]

    rot_x = np.array([[1., 0, 0],
                      [0, cx, -sx],
                      [0, sx, cx]], dtype=np.float32)
    rot_y = np.array([[cy, 0, sy],
                      [0, 1, 0],
                      [-sy, 0, cy]], dtype=np.float32)
    rot_z = np.array([[cz, -sz, 0],
                      [sz, cz, 0],
                      [0, 0, 1]], dtype=np.float32)

    return rot_z @ (rot_y @ rot_x)

def center_positions(pos):
    """Shift ``pos`` so its mean over axis 0 is zero; return ``(shifted, mean)``."""
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid


# Generates TIP3P water trajectories with OpenMM / openmmtools and dumps one
# .npz file (positions, velocities, forces) per sampled frame.
# BOX_SCALE is not referenced anywhere in this cell.
BOX_SCALE = 2
# Integration timestep in femtoseconds (see `timestep` below).
DT = 2
# Ten independent runs. `seed` only labels the log and frame files; no random
# number generator is seeded with it in this cell.
for seed in range(10):
    print(f'Running seed: {seed}')

    # Pre-built water box with a 2 nm edge using the 'tip3p' model.
    waterbox = testsystems.WaterBox(
        box_edge=2 * unit.nanometers,
        model='tip3p')
    [topology, system, positions] = [waterbox.topology, waterbox.system, waterbox.positions]

    # Randomise the start: rotate the configuration about its centroid, then
    # add Gaussian jitter (std 0.005, applied to the values in angstrom) and
    # re-attach the angstrom unit.
    R = get_rotation_matrix()
    positions = positions.value_in_unit(unit.angstrom)
    positions, off = center_positions(positions)
    positions = np.matmul(positions, R)
    positions += off
    positions += np.random.randn(positions.shape[0], positions.shape[1]) * 0.005
    positions *= unit.angstrom

    # NOTE(review): `p_num` is not used again in this cell.
    p_num = positions.shape[0] // 3
    timestep = DT * unit.femtoseconds
    temperature = 300 * unit.kelvin
    chain_length = 10
    friction = 1. / unit.picosecond
    num_mts = 5
    num_yoshidasuzuki = 5

    # All integrator arguments are passed positionally, in this order.
    integrator = integrators.NoseHooverChainVelocityVerletIntegrator(system,
                                                                      temperature,
                                                                      friction,
                                                                      timestep, chain_length, num_mts, num_yoshidasuzuki)

    # No platform is specified, so OpenMM picks one itself.
    simulation = Simulation(topology, system, integrator)
    simulation.context.setPositions(positions)
    simulation.context.setVelocitiesToTemperature(temperature)

    simulation.minimizeEnergy(tolerance=1*unit.kilojoule/unit.mole)
    simulation.step(1)

    os.makedirs(f'./water_data_tip3p/', exist_ok=True)
    # Tab-separated thermodynamic log, one row every 50 steps; the log goes to
    # the current directory, not into the data folder.
    dataReporter_gt = StateDataReporter(f'./log_nvt_tip3p_{seed}.txt', 50, totalSteps=50000,
        step=True, time=True, speed=True, progress=True, elapsedTime=True, remainingTime=True,
        potentialEnergy=True, kineticEnergy=True, totalEnergy=True, temperature=True,
                                     separator='\t')
    simulation.reporters.append(dataReporter_gt)
    # 1000 frames per run, 50 integration steps between consecutive frames.
    for t in range(1000):
        if (t+1)%100 == 0:
            print(f'Finished {(t+1)*50} steps')
        # The frame is captured BEFORE stepping, so frame 0 is the state right
        # after minimisation plus the single step above.
        state = simulation.context.getState(getPositions=True,
                                             getVelocities=True,
                                             getForces=True,
                                             enforcePeriodicBox=True)
        # Saved units: positions in angstrom, velocities in m/s,
        # forces in kJ/mol/nm.
        pos = state.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
        vel = state.getVelocities(asNumpy=True).value_in_unit(unit.meter / unit.second)
        force = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometer)

        # Files are named data_<seed>_<frame>.npz.
        np.savez(f'./water_data_tip3p/data_{seed}_{t}.npz',
                 pos=pos,
                 vel=vel,
                 forces=force)
        simulation.step(50)




#### Generate tip-4p test-cases

In [None]:
from openmmtools import testsystems
from simtk.openmm.app import *
import simtk.unit as unit

import logging

import numpy as np

from openmmtools.constants import kB
from openmmtools import respa, utils

logger = logging.getLogger(__name__)

# Energy unit used by OpenMM unit system
from openmmtools import states, integrators
import time
import numpy as np
import sys
import os


def get_rotation_matrix():
    """Draw a random 3-D rotation matrix.

    Three angles are sampled uniformly from [-pi, pi) with NumPy's global RNG
    (one per Cartesian axis) and printed. The elementary rotations are composed
    as ``Rz @ Ry @ Rx``.

    Return:
      3x3 float32 array, the rotation matrix
    """
    angles = np.random.uniform(-1.0, 1.0, size=(3,)) * np.pi
    print(f'Using angle: {angles}')
    (cx, sx), (cy, sy), (cz, sz) = [(np.cos(a), np.sin(a)) for a in angles]

    rot_x = np.array([[1., 0, 0],
                      [0, cx, -sx],
                      [0, sx, cx]], dtype=np.float32)
    rot_y = np.array([[cy, 0, sy],
                      [0, 1, 0],
                      [-sy, 0, cy]], dtype=np.float32)
    rot_z = np.array([[cz, -sz, 0],
                      [sz, cz, 0],
                      [0, 0, 1]], dtype=np.float32)

    return rot_z @ (rot_y @ rot_x)

def center_positions(pos):
    """Shift ``pos`` so its mean over axis 0 is zero; return ``(shifted, mean)``."""
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid


# Generates TIP4P-Ew water trajectories with OpenMM / openmmtools and dumps one
# .npz file (positions, velocities, forces) per sampled frame.
# BOX_SCALE is not referenced anywhere in this cell.
BOX_SCALE = 2
# Integration timestep in femtoseconds (see `timestep` below).
DT = 2
# Ten independent runs. `seed` only labels the log and frame files; no random
# number generator is seeded with it in this cell.
for seed in range(10):
    print(f'Running seed: {seed}')

    # Pre-built water box with a 2 nm edge using the 'tip4pew' model.
    waterbox = testsystems.WaterBox(
        box_edge=2 * unit.nanometers,
        model='tip4pew')
    [topology, system, positions] = [waterbox.topology, waterbox.system, waterbox.positions]

    # Randomise the start: rotate the configuration about its centroid, then
    # add Gaussian jitter (std 0.005, applied to the values in angstrom) and
    # re-attach the angstrom unit.
    R = get_rotation_matrix()
    positions = positions.value_in_unit(unit.angstrom)
    positions, off = center_positions(positions)
    positions = np.matmul(positions, R)
    positions += off
    positions += np.random.randn(positions.shape[0], positions.shape[1]) * 0.005
    positions *= unit.angstrom

    # NOTE(review): `p_num` is not used again in this cell; it divides the row
    # count by 3 although this model presumably stores 4 sites per molecule.
    p_num = positions.shape[0] // 3
    timestep = DT * unit.femtoseconds
    temperature = 300 * unit.kelvin
    chain_length = 10
    friction = 1. / unit.picosecond
    num_mts = 5
    num_yoshidasuzuki = 5

    # All integrator arguments are passed positionally, in this order.
    integrator = integrators.NoseHooverChainVelocityVerletIntegrator(system,
                                                                      temperature,
                                                                      friction,
                                                                      timestep, chain_length, num_mts, num_yoshidasuzuki)

    # No platform is specified, so OpenMM picks one itself.
    simulation = Simulation(topology, system, integrator)
    simulation.context.setPositions(positions)
    simulation.context.setVelocitiesToTemperature(temperature)

    simulation.minimizeEnergy(tolerance=1*unit.kilojoule/unit.mole)
    simulation.step(1)

    os.makedirs(f'./water_data_tip4p/', exist_ok=True)
    # Tab-separated thermodynamic log, one row every 50 steps; the log goes to
    # the current directory, not into the data folder.
    dataReporter_gt = StateDataReporter(f'./log_nvt_tip4p_{seed}.txt', 50, totalSteps=50000,
        step=True, time=True, speed=True, progress=True, elapsedTime=True, remainingTime=True,
        potentialEnergy=True, kineticEnergy=True, totalEnergy=True, temperature=True,
                                     separator='\t')
    simulation.reporters.append(dataReporter_gt)
    # 1000 frames per run, 50 integration steps between consecutive frames.
    for t in range(1000):
        if (t+1)%100 == 0:
            print(f'Finished {(t+1)*50} steps')
        # The frame is captured BEFORE stepping, so frame 0 is the state right
        # after minimisation plus the single step above.
        state = simulation.context.getState(getPositions=True,
                                             getVelocities=True,
                                             getForces=True,
                                             enforcePeriodicBox=True)
        # Saved units: positions in angstrom, velocities in m/s,
        # forces in kJ/mol/nm. All sites returned by OpenMM are saved as-is.
        pos = state.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
        vel = state.getVelocities(asNumpy=True).value_in_unit(unit.meter / unit.second)
        force = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometer)

        # Files are named data_<seed>_<frame>.npz.
        np.savez(f'./water_data_tip4p/data_{seed}_{t}.npz',
                 pos=pos,
                 vel=vel,
                 forces=force)
        simulation.step(50)




#### Visualize the MD simulation dataset

##### .npz to .pdb file conversion

In [None]:
import numpy as np
import os

# Converts every '<prefix>_<simulation>_<frame>.npz' file in `input_directory`
# into '<output_directory>/simulation_<simulation>/frame_<frame>.pdb'.

# Set the input directory containing your .npz files
input_directory = '/home/pranav/gamd_sr/openmm_data_generation/lj_data_ours/run_5/'
# Set the output directory where you want to save .pdb files
output_directory = 'top_pymol_ours'

# Create output directory if it doesn't exist
os.makedirs(output_directory, exist_ok=True)

# PDB ATOM records are fixed-width: serial in columns 7-11, atom name 13-16,
# residue name 18-20, chain 22, residue number 23-26, x/y/z in 31-38 / 39-46 /
# 47-54, occupancy 55-60, B-factor 61-66, element 77-78. The previous format
# string started x in column 32, shifting every coordinate field by one.
PDB_ATOM_LINE = (
    "ATOM  {serial:5d} {name:<4s} {res_name:>3s} {chain:1s}{res_seq:4d}    "
    "{x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{b_factor:6.2f}          {element:>2s}\n"
)

# Loop through all .npz files in the input directory
for filename in os.listdir(input_directory):
    if not filename.endswith('.npz'):
        continue

    # Extract simulation_id and frame_id from the filename
    parts = filename.split('_')
    if len(parts) != 3:
        continue  # Skip files that do not match the expected naming convention

    simulation_id = parts[1]  # Get the simulation ID
    frame_id = parts[2].replace('.npz', '')  # Get the frame ID without extension

    # Load the positions and close the archive right away, so file handles do
    # not pile up over thousands of frames.
    with np.load(os.path.join(input_directory, filename)) as data:
        positions = data['pos']  # Assuming 'pos' is the key for positions

    # Create a subdirectory for each simulation if it doesn't exist
    sim_output_dir = os.path.join(output_directory, f'simulation_{simulation_id}')
    os.makedirs(sim_output_dir, exist_ok=True)

    # Save positions to a .pdb file in the corresponding subdirectory
    pdb_filename = f'frame_{frame_id}.pdb'
    with open(os.path.join(sim_output_dir, pdb_filename), 'w') as f:
        f.writelines(
            PDB_ATOM_LINE.format(serial=i + 1, name='Ar', res_name='RES', chain='A',
                                 res_seq=1, x=pos[0], y=pos[1], z=pos[2],
                                 occupancy=1.0, b_factor=0.0, element='AR')
            for i, pos in enumerate(positions)
        )
        f.write("END\n")

print("Conversion complete.")


##### .pdb to .mp4 movie file conversion

In [None]:
import os
import subprocess
from pymol import cmd

# Renders the first 50 frames of one simulation to PNG with PyMOL and stitches
# them into an .mp4 with ffmpeg.

# Set the input directory containing your .pdb files
input_directory = '../top_pymol/simulation_0'
# Set the output directory where you want to save the movie
output_directory = '../movie_dir'
# Define the output movie filename
output_movie_filename = os.path.join(output_directory, 'simulation_movie.mp4')

# Create output directory if it doesn't exist
os.makedirs(output_directory, exist_ok=True)

# Clear existing objects in PyMOL
cmd.reinitialize()


def _frame_order(pdb_name):
    """Sort key: numeric frame id first, so 'frame_2' comes before 'frame_10'."""
    frame_id = os.path.splitext(pdb_name)[0].rsplit('_', 1)[-1]
    if frame_id.isdigit():
        return (0, int(frame_id), pdb_name)
    return (1, 0, pdb_name)  # names without a numeric id go last, alphabetically


# Load all .pdb files from the input directory. A plain sorted() would order the
# names lexicographically (frame_0, frame_1, frame_10, frame_100, ...), so the
# first 50 would not be the first 50 frames in time.
pdb_files = sorted(
    (f for f in os.listdir(input_directory) if f.endswith('.pdb')),
    key=_frame_order,
)[:50]


# Create scenes for each frame
for i, pdb_file in enumerate(pdb_files):

    # Clear any previously loaded frames
    cmd.delete('all')  # Remove all previous objects from the scene

    # Load the current frame
    cmd.load(os.path.join(input_directory, pdb_file), f'frame_{i}')

    # Optionally set the view or other properties here
    cmd.show('spheres', f'frame_{i}')  # Show as spheres (or other representation)

    # You can customize colors or representations if needed
    #cmd.color('blue', f'frame_{i}')  # Color all frames blue

    # Store the current view as a named scene for this frame
    cmd.scene(f'scene_{i}', 'store')

    # Save the current scene as an image (PNG format)
    image_filename = os.path.join(output_directory, f'frame_{i}.png')
    cmd.png(image_filename, width=800, height=600, dpi=300)  # Adjust dimensions and DPI as needed

# Invoke FFmpeg to create a video from the saved frames. The arguments are
# passed as a list (no shell), so paths with spaces or shell metacharacters are
# handled correctly.
ffmpeg_args = [
    'ffmpeg',
    '-framerate', '1',
    '-i', f'{output_directory}/frame_%d.png',
    '-c:v', 'libx264',
    '-pix_fmt', 'yuv420p',
    output_movie_filename,
]
ffmpeg_result = subprocess.run(ffmpeg_args)

# Only claim success when ffmpeg actually succeeded.
if ffmpeg_result.returncode == 0:
    print(f"Movie saved to {output_movie_filename}")
else:
    print(f"ffmpeg failed with exit code {ffmpeg_result.returncode}; no movie written to {output_movie_filename}")

### Setup appropriate GNN model

#### Neighborhood calculation functions

In [None]:
from functools import partial
import jax
from jax_md import space, partition
from jax_md.space import pairwise_displacement
from jax import numpy as jnp

import warnings
warnings.filterwarnings("ignore")


class NeighborSearcher(object):
    """Builds and refreshes a jax_md neighbor list for a periodic box.

    Args:
        box_size: Box dimensions, converted to a JAX array.
        cutoff: Interaction cutoff handed to ``partition.neighbor_list``.
    """

    def __init__(self, box_size, cutoff):
        self.box_size = jnp.array(box_size)

        # Periodic displacement function; the shift function returned alongside
        # it is not needed here.
        periodic_displacement, _unused_shift = space.periodic(self.box_size)
        self.displacement_fn = periodic_displacement
        self.disp = jax.vmap(periodic_displacement)
        self.dist = jax.vmap(space.metric(periodic_displacement))
        self.cutoff = cutoff
        self.has_been_init = False
        # Positions wrapped with jnp.mod (below) include the particle itself
        # in its own neighborhood, since mask_self is off.
        self.neighbor_list_fn = partition.neighbor_list(
            periodic_displacement,
            self.box_size,
            cutoff,
            dr_threshold=cutoff / 6.,
            mask_self=False,
        )
        self.neighbor_list_fn_jit = jax.jit(self.neighbor_list_fn)
        # Despite the name this is the plain, un-jitted displacement function.
        self.neighbor_dist_jit = self.displacement_fn

    def init_new_neighbor_lst(self, pos):
        """Build a fresh neighbor list for ``pos`` (wrapped into the box first)."""
        wrapped = jnp.mod(pos, self.box_size)
        fresh = self.neighbor_list_fn(wrapped)
        self.has_been_init = True
        return fresh

    def update_neighbor_lst(self, pos, nbr):
        """Refresh ``nbr`` for ``pos``; rebuild from scratch if its buffer overflowed."""
        wrapped = jnp.mod(pos, self.box_size)
        # update_idx = np.any(self.dist(pos, nbr.reference_position) > (self.cutoff / 10.))

        refreshed = self.neighbor_list_fn_jit(wrapped, nbr)
        if not refreshed.did_buffer_overflow:
            return refreshed
        return self.neighbor_list_fn(wrapped)


def graph_network_nbr_fn(displacement_fn,
                         cutoff,
                         N):
    """Build a function that converts a neighbor-index table into an edge mask.

    The returned callable takes ``pos`` and ``neigh_idx`` and yields a boolean
    array that is True where the neighbor index differs from ``N`` and the
    squared displacement to that neighbor is smaller than ``cutoff ** 2``.
    (Presumably ``N`` is the index used to pad empty neighbor slots; confirm
    against the neighbor-list producer.)
    """
    neighbor_displacement = space.map_neighbor(partial(displacement_fn))
    cutoff_squared = cutoff ** 2

    def nbrlst_to_edge_mask(pos: jnp.ndarray, neigh_idx: jnp.ndarray):
        # pos has to be a jax array: the fancy indexing below fails otherwise
        neighbor_pos = pos[neigh_idx]
        squared_dist = space.square_distance(neighbor_displacement(pos, neighbor_pos))
        is_real_neighbor = neigh_idx != N
        is_within_cutoff = squared_dist < cutoff_squared
        return jnp.logical_and(is_real_neighbor, is_within_cutoff)

    return nbrlst_to_edge_mask


# def graph_network_nbr_fn_with_type_mask(displacement_fn,
#                                          cutoff,
#                                          N):
#
#     def nbrlst_to_edge_mask(pos: jnp.ndarray, neigh_idx: jnp.ndarray, type_mask: jnp.ndarray):
#         # notice here, pos must be jax numpy array, otherwise fancy indexing will fail
#         d = jax.partial(displacement_fn)
#         d = space.map_neighbor(d)
#         pos_neigh = pos[neigh_idx]
#         dR = d(pos, pos_neigh)
#
#         dr_2 = space.square_distance(dR)
#         mask = jnp.logical_and(neigh_idx != N, dr_2 < cutoff ** 2)
#         mask = jnp.logical_and(type_mask, mask)
#         return mask
#
#     return nbrlst_to_edge_mask


#### GNN model definitions

In [None]:
import numpy as np
import torch
import torch.nn as nn
import dgl.nn
import dgl.function as fn
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
import time
from md_module import get_neighbor
from sklearn.preprocessing import StandardScaler

from typing import List, Set, Dict, Tuple, Optional


def cubic_kernel(r, re):
    """Evaluate the compact kernel ``max(0, (1 - (r / re)**2)**3)`` element-wise.

    Entries of ``r`` that are not greater than 1e-3 are first replaced by
    ``re``, so they evaluate to zero.
    """
    min_distance = 1e-3
    clipped = torch.threshold(r, min_distance, re)
    ratio_squared = (clipped / re) ** 2
    return torch.relu((1. - ratio_squared) ** 3)


class MLP(nn.Module):
    """Configurable multi-layer perceptron.

    Layer layout, depending on ``hidden_layer``:
      * ``hidden_layer == 1``: a single ``Linear(in_feats, out_feats)``,
        preceded by the activation when ``activation_first`` is True.
      * ``hidden_layer >= 2``: ``Linear(in_feats, hidden_dim)``, then
        ``hidden_layer - 2`` blocks of ``Linear(hidden_dim, hidden_dim)``, then
        ``Linear(hidden_dim, out_feats)``.  An activation follows every linear
        layer except the last, and one more is placed in front when
        ``activation_first`` is True.
      * ``hidden_layer <= 0``: an empty ``nn.Sequential``.

    Args:
        in_feats: input feature width.
        out_feats: output feature width.
        hidden_dim: width of the hidden layers.
        hidden_layer: number of linear layers.
        activation_first: apply the activation to the input before the first
            linear layer.
        activation: one of 'relu', 'leaky_relu', 'sigmoid', 'tanh', 'elu',
            'gelu', 'silu'.
        init_param: when True, re-initialise the linear weights with Xavier
            uniform initialisation.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 hidden_dim=128,
                 hidden_layer=3,
                 activation_first=False,
                 activation='relu',
                 init_param=False):
        super(MLP, self).__init__()
        activation_factories = {
            'relu': nn.ReLU,
            'leaky_relu': lambda: nn.LeakyReLU(0.2),
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'elu': nn.ELU,
            'gelu': nn.GELU,
            'silu': nn.SiLU,
        }
        if activation not in activation_factories:
            # The message lists every supported name (gelu and silu included).
            raise ValueError('Only support: ' + ', '.join(activation_factories)
                             + ' as non-linear activation')
        # One stateless activation module is shared by every position.
        act_fn = activation_factories[activation]()

        mlp_layer = []
        if hidden_layer == 1:
            if activation_first:
                mlp_layer += [act_fn, nn.Linear(in_feats, out_feats)]
            else:
                print('Using MLP with no hidden layer and activations! Fall back to nn.Linear()')
                mlp_layer += [nn.Linear(in_feats, out_feats)]
        elif hidden_layer > 1:
            if activation_first:
                mlp_layer.append(act_fn)
            mlp_layer += [nn.Linear(in_feats, hidden_dim), act_fn]
            for _ in range(hidden_layer - 2):
                mlp_layer += [nn.Linear(hidden_dim, hidden_dim), act_fn]
            mlp_layer.append(nn.Linear(hidden_dim, out_feats))
        self.mlp_layer = nn.Sequential(*mlp_layer)
        if init_param:
            self._init_parameters()

    def _init_parameters(self):
        """Apply Xavier-uniform initialisation to every linear weight (biases are left as is)."""
        for layer in self.mlp_layer:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, feat):
        """Run ``feat`` through the layer stack."""
        return self.mlp_layer(feat)


class SmoothConvLayerNew(nn.Module):
    """Edge-conditioned message-passing layer.

    For every edge an embedding ``e_emb`` is produced by ``theta_edge`` from the
    sum of the encoded edge feature, the projected source-node feature and the
    projected destination-node feature.  Messages ``h_src * e_emb`` are summed
    at each destination node and merged with the node's own features by ``phi``.

    Two tensors are stored on the module during every forward pass so they can
    be inspected or regularised afterwards:
        edge_message_neigh_center: ``src_code * e_emb`` for every edge.
        input_node_embeddings: the (cloned) node features fed to the layer.

    Args:
        in_node_feats: width of the incoming node features.
        in_edge_feats: width of the edge features stored in ``g.edata['e']``.
        out_node_feats: width of the returned node features.
        hidden_dim: width of the internal projections.
        activation: activation name forwarded to ``MLP``.
        drop_edge: when True, 20% of the edges are randomly removed in
            training mode.
        update_edge_emb: when True, ``g.edata['e']`` is overwritten with the
            layer-normalised edge embedding after message passing.
    """

    def __init__(self,
                 in_node_feats,
                 in_edge_feats,
                 out_node_feats,
                 hidden_dim=128,
                 activation='relu',
                 drop_edge=True,
                 update_edge_emb=False):

        super(SmoothConvLayerNew, self).__init__()
        self.drop_edge = drop_edge
        self.update_edge_emb = update_edge_emb
        if self.update_edge_emb:
            self.edge_layer_norm = nn.LayerNorm(in_edge_feats)

        # self.theta_src = nn.Linear(in_node_feats, hidden_dim)
        self.edge_affine = MLP(in_edge_feats, hidden_dim, activation=activation, hidden_layer=2)
        self.src_affine = nn.Linear(in_node_feats, hidden_dim)
        self.dst_affine = nn.Linear(in_node_feats, hidden_dim)
        self.theta_edge = MLP(hidden_dim, in_node_feats,
                              hidden_dim=hidden_dim, activation=activation, activation_first=True,
                              hidden_layer=2)
        # self.theta = MLP(hidden_dim, hidden_dim, activation_first=True, hidden_layer=2)

        self.phi_dst = nn.Linear(in_node_feats, hidden_dim)
        self.phi_edge = nn.Linear(in_node_feats, hidden_dim)
        self.phi = MLP(hidden_dim, out_node_feats,
                       activation_first=True, hidden_layer=1, hidden_dim=hidden_dim, activation=activation)

    def forward(self, g: dgl.DGLGraph, node_feat: torch.Tensor) -> torch.Tensor:
        """Run one round of message passing.

        Args:
            g: graph whose ``edata['e']`` holds the edge features.
            node_feat: node features, one row per node.

        Returns:
            Updated node features produced by ``phi``.
        """
        h = node_feat.clone()
        with g.local_scope():
            if self.drop_edge and self.training:
                src_idx, dst_idx = g.edges()
                e_feat = g.edata['e'].clone()
                dropout_ratio = 0.2
                idx = np.arange(dst_idx.shape[0])
                np.random.shuffle(idx)
                # Slice by the number of edges to KEEP.  The previous form
                # idx[:-int(E * ratio)] turned into idx[:-0] (an empty slice)
                # whenever int(E * ratio) was 0, silently dropping every edge
                # of a small graph.
                num_keep = idx.shape[0] - int(idx.shape[0] * dropout_ratio)
                keep_idx = idx[:num_keep]
                src_idx = src_idx[keep_idx]
                dst_idx = dst_idx[keep_idx]
                e_feat = e_feat[keep_idx]
                # Pin the node count to the number of feature rows: without it
                # DGL infers the count from the largest surviving edge index,
                # so trailing nodes that lost all their edges would disappear
                # and the feature assignment below would no longer fit.
                g = dgl.graph((src_idx, dst_idx), num_nodes=h.shape[0])
                g.edata['e'] = e_feat
            # for multi batch training
            if g.is_block:
                h_src = h
                h_dst = h[:g.number_of_dst_nodes()]
            else:
                h_src = h_dst = h

            g.srcdata['h'] = h_src
            g.dstdata['h'] = h_dst
            edge_idx = g.edges()
            src_idx = edge_idx[0]
            dst_idx = edge_idx[1]
            edge_code = self.edge_affine(g.edata['e'])
            src_code = self.src_affine(h_src[src_idx])
            dst_code = self.dst_affine(h_dst[dst_idx])
            g.edata['e_emb'] = self.theta_edge(edge_code+src_code+dst_code)
            # Recorded for message regularisation.
            # NOTE(review): this stores src_code * e_emb, whereas the aggregation
            # below multiplies the raw source feature 'h' by e_emb -- confirm
            # which of the two is meant to be analysed.
            self.edge_message_neigh_center = src_code * g.edata['e_emb']
            # Recorded for node-embedding regularisation.
            self.input_node_embeddings = h

            if self.update_edge_emb:
                normalized_e_emb = self.edge_layer_norm(g.edata['e_emb'])
            # message = source feature * edge embedding, summed per destination
            g.update_all(fn.u_mul_e('h', 'e_emb', 'm'), fn.sum('m', 'h'))
            edge_emb = g.ndata['h']  # aggregated messages, one row per node

        if self.update_edge_emb:
            # Written outside the local scope so it persists on ``g`` (which is
            # the rebuilt graph when edges were dropped above).
            g.edata['e'] = normalized_e_emb
        node_feat = self.phi(self.phi_dst(h) + self.phi_edge(edge_emb))
        return node_feat


class SmoothConvBlockNew(nn.Module):
    """Residual stack of ``SmoothConvLayerNew`` layers with optional pre-normalisation.

    Only the first layer receives ``in_node_feats`` as its input width; every
    later layer maps ``out_node_feats`` to ``out_node_feats``.  When layer norm
    or batch norm is enabled, one norm module per layer is applied to the input
    of that layer, and the un-normalised input is added back to the layer
    output.
    """

    def __init__(self,
                 in_node_feats,
                 out_node_feats,
                 hidden_dim=128,
                 conv_layer=3,
                 edge_emb_dim=64,
                 use_layer_norm=False,
                 use_batch_norm=True,
                 drop_edge=False,
                 activation='relu',
                 update_egde_emb=False,
                 ):
        super().__init__()
        self.conv = nn.ModuleList()
        self.edge_emb_dim = edge_emb_dim
        self.use_layer_norm = use_layer_norm
        self.use_batch_norm = use_batch_norm
        self.drop_edge = drop_edge

        # the two normalisation flavours are mutually exclusive
        if use_batch_norm == use_layer_norm and use_batch_norm:
            raise Exception('Only one type of normalization at a time')
        if use_layer_norm or use_batch_norm:
            self.norm_layers = nn.ModuleList()

        for depth in range(conv_layer):
            layer_in_feats = in_node_feats if depth == 0 else out_node_feats
            self.conv.append(SmoothConvLayerNew(in_node_feats=layer_in_feats,
                                                in_edge_feats=self.edge_emb_dim,
                                                out_node_feats=out_node_feats,
                                                hidden_dim=hidden_dim,
                                                activation=activation,
                                                drop_edge=drop_edge,
                                                update_edge_emb=update_egde_emb))
            if use_layer_norm:
                self.norm_layers.append(nn.LayerNorm(out_node_feats))
            elif use_batch_norm:
                self.norm_layers.append(nn.BatchNorm1d(out_node_feats))

    def forward(self, h: torch.Tensor, graph: dgl.DGLGraph) -> torch.Tensor:
        """Apply every layer in turn, adding the layer input back to its output."""
        normalise = self.use_layer_norm or self.use_batch_norm
        for depth, layer in enumerate(self.conv):
            layer_input = self.norm_layers[depth](h) if normalise else h
            h = layer.forward(graph, layer_input) + h

        return h


# code from DGL documents
class RBFExpansion(nn.Module):
    r"""Radial-basis expansion of inter-node distances.

    Each distance :math:`d` is mapped to the vector

    .. math::
        \exp(- \gamma * ||d - \mu_k||^2)

    for a set of centers :math:`\mu_k` spread evenly over
    :math:`[\text{low}, \text{high}]`.  The number of centers is
    :math:`\lceil (\text{high} - \text{low}) / \text{gap} \rceil`; using fewer
    centers lowers the resolution of the filter.

    Parameters
    ----------
    low : float
        Smallest center. Default to 0.
    high : float
        Largest center. Default to 30.
    gap : float
        Nominal spacing used to derive the number of centers;
        :math:`\gamma` is its reciprocal. Default to 0.1.
    """

    def __init__(self, low=0., high=30., gap=0.1):
        super().__init__()

        center_count = int(np.ceil((high - low) / gap))
        center_values = np.linspace(low, high, center_count)
        # kept as a non-trainable parameter so it follows the module across devices
        self.centers = nn.Parameter(torch.tensor(center_values).float(), requires_grad=False)
        self.gamma = 1 / gap

    def reset_parameters(self):
        """Reinitialize model parameters."""
        current_device = self.centers.device
        detached = self.centers.clone().detach().float()
        self.centers = nn.Parameter(detached, requires_grad=False).to(current_device)

    def forward(self, edge_dists):
        """Expand distances.

        Parameters
        ----------
        edge_dists : float32 tensor of shape (E, 1)
            Distances between end nodes of edges, E for the number of edges.

        Returns
        -------
        float32 tensor of shape (E, len(self.centers))
            Expanded distances.
        """
        offset = edge_dists - self.centers
        exponent = (- self.gamma) * (offset ** 2)
        return torch.exp(exponent)


class WaterMDDynamicBoxNet(nn.Module):
    """Graph network that rebuilds its neighbor graph from positions on every call.

    Neighbors are found with ``get_neighbor`` for a per-sample box size, so each
    sample of a batch may live in a differently sized box.  Edge features are
    the negated unit displacement, the standardised distance and (optionally)
    an RBF expansion of that distance; when ``bond`` is given, a flag telling
    whether the pair is connected in the bond graph is appended.
    """

    def __init__(self,
                 in_feats,
                 encoding_size,
                 out_feats,
                 bond=None,       #
                 hidden_dim=128,
                 conv_layer=4,
                 edge_embedding_dim=128,
                 dropout=0.1,
                 drop_edge=True,
                 use_layer_norm=False,
                 update_edge=False,
                 expand_edge=True):
        super().__init__()
        self.graph_conv = SmoothConvBlockNew(in_node_feats=encoding_size,
                                             out_node_feats=encoding_size,
                                             hidden_dim=hidden_dim,
                                             conv_layer=conv_layer,
                                             edge_emb_dim=edge_embedding_dim,
                                             use_layer_norm=use_layer_norm,
                                             use_batch_norm=not use_layer_norm,
                                             drop_edge=drop_edge,
                                             activation='silu',
                                             update_egde_emb=update_edge)

        self.edge_emb_dim = edge_embedding_dim
        self.expand_edge = expand_edge
        if self.expand_edge:
            self.edge_expand = RBFExpansion(high=1, gap=0.025)
        self.edge_drop_out = nn.Dropout(dropout)
        self.use_bond = bond is not None

        # running statistics of the edge lengths, refreshed while training
        self.length_mean = nn.Parameter(torch.tensor([0.]), requires_grad=False)
        self.length_std = nn.Parameter(torch.tensor([1.]), requires_grad=False)
        self.length_scaler = StandardScaler()

        self.node_encoder = nn.Linear(in_feats, encoding_size)

        # raw edge feature width: 3 direction components (+1 bond flag when
        # bonds are used) + 1 distance (+ one column per RBF center)
        raw_edge_dim = (4 if self.use_bond else 3) + 1
        if self.expand_edge:
            raw_edge_dim += len(self.edge_expand.centers)
        self.edge_encoder = MLP(raw_edge_dim, self.edge_emb_dim, hidden_dim=hidden_dim,
                                activation='gelu')
        if self.use_bond:
            self.bond_graph = self.build_bond_graph(bond)

        self.edge_layer_norm = nn.LayerNorm(self.edge_emb_dim)
        self.graph_decoder = MLP(encoding_size, out_feats, hidden_layer=2, hidden_dim=hidden_dim, activation='gelu')

    def calc_edge_feat(self, rel_pos_periodic, rel_pos_norm):
        """Assemble raw edge features from displacements and their norms.

        In training mode the length statistics are updated with the incoming
        norms before they are used for standardisation.
        """
        if self.training:
            self.fit_length(rel_pos_norm)
            self._update_length_stat(self.length_scaler.mean_, np.sqrt(self.length_scaler.var_))

        # the direction uses the raw norm, the distance channel the standardised one
        unit_direction = -rel_pos_periodic / (rel_pos_norm + 1e-8)
        scaled_length = (rel_pos_norm - self.length_mean) / self.length_std

        columns = [unit_direction, scaled_length]
        if self.expand_edge:
            columns.append(self.edge_expand(scaled_length))
        return torch.cat(columns, dim=1)

    def build_graph(self, fluid_pos, cutoff, box_size, self_loop=True):
        """Build the neighbor graph of one sample and attach edge embeddings as ``edata['e']``."""
        if isinstance(box_size, torch.Tensor):
            box_size = box_size.to(fluid_pos.device)
        elif isinstance(box_size, np.ndarray):
            box_size = torch.from_numpy(box_size).to(fluid_pos.device)

        with torch.no_grad():
            edge_idx, distance, distance_norm, _ = get_neighbor(fluid_pos,
                                                                cutoff, box_size)
        center_idx = edge_idx[0, :]
        neigh_idx = edge_idx[1, :]
        # edges run from the neighbor to the center particle
        fluid_graph = dgl.graph((neigh_idx, center_idx))
        fluid_edge_feat = self.calc_edge_feat(distance, distance_norm.view(-1, 1))

        if self.use_bond:
            # flag pairs that are connected in the bond graph
            edge_type = self.bond_graph.has_edges_between(center_idx, neigh_idx)
            fluid_edge_feat = torch.cat((fluid_edge_feat, edge_type.view(-1, 1)), dim=1)

        fluid_edge_emb = self.edge_layer_norm(self.edge_encoder(fluid_edge_feat))  # [edge_num, edge_emb_dim]
        fluid_edge_emb = self.edge_drop_out(fluid_edge_emb)
        fluid_graph.edata['e'] = fluid_edge_emb

        # add self loop for fluid particles
        # NOTE(review): the return value is discarded; in recent DGL releases
        # add_self_loop returns a new graph instead of modifying this one --
        # confirm whether self loops are actually intended here.
        if self_loop:
            fluid_graph.add_self_loop()
        return fluid_graph

    def build_graph_batches(self, pos_lst, box_size_lst, cutoff):
        """Build one graph per (positions, box size) pair and batch them."""
        graphs = [self.build_graph(pos, cutoff, box_size)
                  for pos, box_size in zip(pos_lst, box_size_lst)]
        return dgl.batch(graphs)

    def build_bond_graph(self, bond):
        """Turn an array of bonded index pairs into a graph with edges in both directions."""
        if isinstance(bond, np.ndarray):
            bond = torch.from_numpy(bond).cuda()
        one_way = dgl.graph((bond[:, 0], bond[:, 1]))
        return dgl.add_reverse_edges(one_way)  # undirected and symmetric

    def _update_length_stat(self, new_mean, new_std):
        """Copy the first entries of the scaler statistics into the module parameters."""
        self.length_mean[0] = new_mean[0]
        self.length_std[0] = new_std[0]

    def fit_length(self, length):
        """Feed a batch of edge lengths to the incremental scaler."""
        if isinstance(length, np.ndarray):
            samples = length
        else:
            samples = length.detach().cpu().numpy().reshape(-1, 1)
        self.length_scaler.partial_fit(samples)

    def forward(self,
                fluid_pos_lst,  #   list of [N, 3]
                x,  # node feature    # [b*N, 3]
                box_size_lst,   #   list of scalar
                cutoff          # a scalar
                ):
        """Encode the node features, run message passing on the freshly built graph and decode."""
        if len(fluid_pos_lst) > 1:
            fluid_graph = self.build_graph_batches(fluid_pos_lst, box_size_lst, cutoff)
        else:
            fluid_graph = self.build_graph(fluid_pos_lst[0], cutoff, box_size_lst[0])

        encoded = self.node_encoder(x)
        propagated = self.graph_conv(encoded, fluid_graph)
        return self.graph_decoder(propagated)


class WaterMDNetNew(nn.Module):
    """Graph network for a fixed periodic box that consumes precomputed edge lists.

    Unlike the dynamic-box variant, the neighbor indices are supplied by the
    caller; this module only computes minimum-image displacements for those
    edges, encodes them and runs message passing.
    """

    def __init__(self,
                 in_feats,
                 encoding_size,
                 out_feats,
                 box_size,   # can also be array
                 bond=None,       #
                 hidden_dim=128,
                 conv_layer=4,
                 edge_embedding_dim=128,
                 dropout=0.1,
                 drop_edge=True,
                 use_layer_norm=False):
        super().__init__()
        self.graph_conv = SmoothConvBlockNew(in_node_feats=encoding_size,
                                             out_node_feats=encoding_size,
                                             hidden_dim=hidden_dim,
                                             conv_layer=conv_layer,
                                             edge_emb_dim=edge_embedding_dim,
                                             use_layer_norm=use_layer_norm,
                                             use_batch_norm=not use_layer_norm,
                                             drop_edge=drop_edge,
                                             activation='silu')

        self.edge_emb_dim = edge_embedding_dim
        self.edge_expand = RBFExpansion(high=1, gap=0.025)
        self.edge_drop_out = nn.Dropout(dropout)
        self.use_bond = bond is not None

        # running statistics of the edge lengths, refreshed while training
        self.length_mean = nn.Parameter(torch.tensor([0.]), requires_grad=False)
        self.length_std = nn.Parameter(torch.tensor([1.]), requires_grad=False)
        self.length_scaler = StandardScaler()

        # numpy boxes become float tensors; anything else is stored untouched
        if isinstance(box_size, np.ndarray):
            self.box_size = torch.from_numpy(box_size).float()
        else:
            self.box_size = box_size

        self.node_encoder = nn.Linear(in_feats, encoding_size)

        # 3 direction components (+1 bond flag) + 1 distance + RBF columns
        raw_edge_dim = (4 if self.use_bond else 3) + 1 + len(self.edge_expand.centers)
        self.edge_encoder = MLP(raw_edge_dim, self.edge_emb_dim, hidden_dim=hidden_dim,
                                activation='gelu')
        if self.use_bond:
            self.bond_graph = self.build_bond_graph(bond)

        self.edge_layer_norm = nn.LayerNorm(self.edge_emb_dim)
        self.graph_decoder = MLP(encoding_size, out_feats, hidden_layer=2, hidden_dim=hidden_dim, activation='gelu')

    def calc_edge_feat(self,
                       src_idx: torch.Tensor,
                       dst_idx: torch.Tensor,
                       pos_src: torch.Tensor,
                       pos_dst=None) -> torch.Tensor:
        """Compute the raw edge features for the given index pairs.

        The displacement ``pos_dst[dst_idx] - pos_src[src_idx]`` is wrapped into
        ``[-box/2, box/2)`` and split into a unit direction and a length.  The
        length is standardised with the running statistics and additionally
        expanded with the RBF basis.
        """
        if pos_dst is None:
            pos_dst = pos_src

        # geometry is computed without autograd tracking for speed
        with torch.no_grad():
            rel_pos = pos_dst[dst_idx.long()] - pos_src[src_idx.long()]
            box = self.box_size
            if isinstance(box, torch.Tensor):
                box = box.to(rel_pos.device)
            rel_pos_periodic = torch.remainder(rel_pos + 0.5 * box, box) - 0.5 * box

            rel_pos_norm = rel_pos_periodic.norm(dim=1).view(-1, 1)  # [edge_num, 1]
            rel_pos_periodic /= rel_pos_norm + 1e-8   # normalized

        if self.training:
            self.fit_length(rel_pos_norm)
            self._update_length_stat(self.length_scaler.mean_, np.sqrt(self.length_scaler.var_))

        scaled_length = (rel_pos_norm - self.length_mean) / self.length_std
        columns = (rel_pos_periodic, scaled_length, self.edge_expand(scaled_length))
        return torch.cat(columns, dim=1)

    def build_graph(self,
                    fluid_edge_idx: torch.Tensor,
                    fluid_pos: torch.Tensor,
                    self_loop=True) -> dgl.DGLGraph:
        """Build the graph of one sample from its edge list and attach edge embeddings."""
        center_idx = fluid_edge_idx[0, :]
        neigh_idx = fluid_edge_idx[1, :]
        # edges run from the neighbor to the center particle
        fluid_graph = dgl.graph((neigh_idx, center_idx))
        fluid_edge_feat = self.calc_edge_feat(center_idx, neigh_idx, fluid_pos)

        if self.use_bond:
            # flag pairs that are connected in the bond graph
            edge_type = self.bond_graph.has_edges_between(center_idx, neigh_idx)
            fluid_edge_feat = torch.cat((fluid_edge_feat, edge_type.view(-1, 1)), dim=1)

        fluid_edge_emb = self.edge_layer_norm(self.edge_encoder(fluid_edge_feat))  # [edge_num, edge_emb_dim]
        fluid_edge_emb = self.edge_drop_out(fluid_edge_emb)
        fluid_graph.edata['e'] = fluid_edge_emb

        # add self loop for fluid particles
        # NOTE(review): the return value is discarded; in recent DGL releases
        # add_self_loop returns a new graph instead of modifying this one --
        # confirm whether self loops are actually intended here.
        if self_loop:
            fluid_graph.add_self_loop()
        return fluid_graph

    def build_graph_batches(self, pos_lst, edge_idx_lst):
        """Build one graph per (positions, edge list) pair and batch them."""
        graphs = [self.build_graph(edge_idx, pos)
                  for pos, edge_idx in zip(pos_lst, edge_idx_lst)]
        return dgl.batch(graphs)

    def build_bond_graph(self, bond) -> dgl.DGLGraph:
        """Turn an array of bonded index pairs into a graph with edges in both directions."""
        if isinstance(bond, np.ndarray):
            bond = torch.from_numpy(bond).cuda()
        one_way = dgl.graph((bond[:, 0], bond[:, 1]))
        return dgl.add_reverse_edges(one_way)  # undirected and symmetric

    def _update_length_stat(self, new_mean, new_std):
        """Copy the first entries of the scaler statistics into the module parameters."""
        self.length_mean[0] = new_mean[0]
        self.length_std[0] = new_std[0]

    def fit_length(self, length):
        """Feed a batch of edge lengths to the incremental scaler."""
        if isinstance(length, np.ndarray):
            samples = length
        else:
            samples = length.detach().cpu().numpy().reshape(-1, 1)
        self.length_scaler.partial_fit(samples)

    def forward(self,
                fluid_pos_lst: List[torch.Tensor],  # list of [N, 3]
                x: torch.Tensor,  # node feature    # [b*N, 3]
                fluid_edge_lst: List[torch.Tensor]
                ) -> torch.Tensor:
        """Encode the node features, run message passing and decode."""
        if len(fluid_pos_lst) > 1:
            fluid_graph = self.build_graph_batches(fluid_pos_lst, fluid_edge_lst)
        else:
            fluid_graph = self.build_graph(fluid_edge_lst[0], fluid_pos_lst[0])

        encoded = self.node_encoder(x)
        propagated = self.graph_conv(encoded, fluid_graph)
        return self.graph_decoder(propagated)


class SimpleMDNetNew(nn.Module):  # no bond, no learnable node encoder
    """Graph network for single-species systems in a fixed periodic box.

    There are no per-node input features: every node starts from the same
    learnable embedding ``node_emb``, and all structural information enters
    through the edge features built from the supplied edge lists.
    """

    def __init__(self,
                 encoding_size,
                 out_feats,
                 box_size,   # can also be array
                 hidden_dim=128,
                 conv_layer=4,
                 edge_embedding_dim=128,
                 dropout=0.1,
                 drop_edge=True,
                 use_layer_norm=False):
        super().__init__()
        self.graph_conv = SmoothConvBlockNew(in_node_feats=encoding_size,
                                             out_node_feats=encoding_size,
                                             hidden_dim=hidden_dim,
                                             conv_layer=conv_layer,
                                             edge_emb_dim=edge_embedding_dim,
                                             use_layer_norm=use_layer_norm,
                                             use_batch_norm=not use_layer_norm,
                                             drop_edge=drop_edge,
                                             activation='silu')

        self.edge_emb_dim = edge_embedding_dim
        self.edge_expand = RBFExpansion(high=1, gap=0.025)
        self.edge_drop_out = nn.Dropout(dropout)

        # running statistics of the edge lengths, refreshed while training
        self.length_mean = nn.Parameter(torch.tensor([0.]), requires_grad=False)
        self.length_std = nn.Parameter(torch.tensor([1.]), requires_grad=False)
        self.length_scaler = StandardScaler()

        # numpy boxes become float tensors; anything else is stored untouched
        if isinstance(box_size, np.ndarray):
            self.box_size = torch.from_numpy(box_size).float()
        else:
            self.box_size = box_size

        # single embedding shared by all nodes
        self.node_emb = nn.Parameter(torch.randn((1, encoding_size)), requires_grad=True)

        # 3 direction components + 1 distance + RBF columns.
        # (A disabled variant with two extra prior columns used
        #  3 + 2 + 1 + len(self.edge_expand.centers) inputs.)
        raw_edge_dim = 3 + 1 + len(self.edge_expand.centers)
        self.edge_encoder = MLP(raw_edge_dim, self.edge_emb_dim, hidden_dim=hidden_dim,
                                activation='gelu')
        self.edge_layer_norm = nn.LayerNorm(self.edge_emb_dim)
        self.graph_decoder = MLP(encoding_size, out_feats, hidden_layer=2, hidden_dim=hidden_dim, activation='gelu')

    def calc_edge_feat(self,
                       src_idx: torch.Tensor,
                       dst_idx: torch.Tensor,
                       pos_src: torch.Tensor,
                       pos_dst=None) -> torch.Tensor:
        """Compute the raw edge features for the given index pairs.

        The displacement ``pos_dst[dst_idx] - pos_src[src_idx]`` is wrapped into
        ``[-box/2, box/2)`` and split into a unit direction and a length.  The
        length is standardised with the running statistics and additionally
        expanded with the RBF basis.
        """
        if pos_dst is None:
            pos_dst = pos_src

        # geometry is computed without autograd tracking for speed
        with torch.no_grad():
            rel_pos = pos_dst[dst_idx.long()] - pos_src[src_idx.long()]
            box = self.box_size
            if isinstance(box, torch.Tensor):
                box = box.to(rel_pos.device)
            rel_pos_periodic = torch.remainder(rel_pos + 0.5 * box, box) - 0.5 * box

            rel_pos_norm = rel_pos_periodic.norm(dim=1).view(-1, 1)  # [edge_num, 1]
            rel_pos_periodic /= rel_pos_norm + 1e-8   # normalized

        if self.training:
            self.fit_length(rel_pos_norm)
            self._update_length_stat(self.length_scaler.mean_, np.sqrt(self.length_scaler.var_))

        scaled_length = (rel_pos_norm - self.length_mean) / self.length_std

        # Disabled experiment (kept for reference): add inductive-bias columns
        # r^-6 and r^-12 computed from the standardised length, with inf values
        # replaced by 0, concatenated between the direction and the length
        # columns; the whole feature matrix was then min-max scaled
        # ((x - min) / (max - min)) to avoid NaNs downstream.
        columns = (rel_pos_periodic, scaled_length, self.edge_expand(scaled_length))
        return torch.cat(columns, dim=1)

    def build_graph(self,
                    fluid_edge_idx: torch.Tensor,
                    fluid_pos: torch.Tensor,
                    self_loop=True) -> dgl.DGLGraph:
        """Build the graph of one sample from its edge list and attach edge embeddings."""
        center_idx = fluid_edge_idx[0, :]
        neigh_idx = fluid_edge_idx[1, :]
        # edges run from the neighbor to the center particle
        fluid_graph = dgl.graph((neigh_idx, center_idx))
        fluid_edge_feat = self.calc_edge_feat(center_idx, neigh_idx, fluid_pos)

        fluid_edge_emb = self.edge_layer_norm(self.edge_encoder(fluid_edge_feat))  # [edge_num, edge_emb_dim]
        # (a disabled debug check used to abort here when the embedding contained NaNs)
        fluid_edge_emb = self.edge_drop_out(fluid_edge_emb)
        fluid_graph.edata['e'] = fluid_edge_emb

        # add self loop for fluid particles
        # NOTE(review): the return value is discarded; in recent DGL releases
        # add_self_loop returns a new graph instead of modifying this one --
        # confirm whether self loops are actually intended here.
        if self_loop:
            fluid_graph.add_self_loop()
        return fluid_graph

    def build_graph_batches(self, pos_lst, edge_idx_lst):
        """Build one graph per (positions, edge list) pair and batch them."""
        graphs = [self.build_graph(edge_idx, pos)
                  for pos, edge_idx in zip(pos_lst, edge_idx_lst)]
        return dgl.batch(graphs)

    def _update_length_stat(self, new_mean, new_std):
        """Copy the first entries of the scaler statistics into the module parameters."""
        self.length_mean[0] = new_mean[0]
        self.length_std[0] = new_std[0]

    def fit_length(self, length):
        """Feed a batch of edge lengths to the incremental scaler."""
        if isinstance(length, np.ndarray):
            samples = length
        else:
            samples = length.detach().cpu().numpy().reshape(-1, 1)
        self.length_scaler.partial_fit(samples)

    def forward(self,
                fluid_pos_lst: List[torch.Tensor],  # list of [N, 3]
                fluid_edge_lst: List[torch.Tensor]
                ) -> torch.Tensor:
        """Tile the shared node embedding over all nodes, run message passing and decode."""
        if len(fluid_pos_lst) > 1:
            fluid_graph = self.build_graph_batches(fluid_pos_lst, fluid_edge_lst)
        else:
            fluid_graph = self.build_graph(fluid_edge_lst[0], fluid_pos_lst[0])

        # total number of nodes across every sample of the batch
        total_nodes = np.sum([pos.shape[0] for pos in fluid_pos_lst])
        node_state = self.node_emb.repeat((total_nodes, 1))
        node_state = self.graph_conv(node_state, fluid_graph)
        return self.graph_decoder(node_state)




### Train GNN and record edge messages

### Train for LJ system test-cases

In [None]:
import argparse
import os, sys
import joblib
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import jax
import jax.numpy as jnp
import cupy

sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
from nn_module import SimpleMDNetNew
from train_utils import LJDataNew
from graph_utils import NeighborSearcher, graph_network_nbr_fn
import time
# NOTE(review): the value "2" exposes only the GPU with index 2 to this process;
# it does not hide the GPUs (an empty string would). The original remark
# "just to test if it works w/o gpu" does not match this value - confirm intent.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Ask JAX/XLA not to preallocate its GPU memory pool at start-up
# (presumably so PyTorch can share the same device - TODO confirm).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Simulation-box settings consumed by NeighborSearcher and build_model below.
# NOTE(review): the original header read "for water box" although this is the
# LJ training script; units are not stated in this file.
# Previously tried cutoffs, kept for reference:
#CUTOFF_RADIUS = 7.5
#CUTOFF_RADIUS = 10.2
CUTOFF_RADIUS = 15.2
BOX_SIZE = 27.27

# Particle count handed to graph_network_nbr_fn.
NUM_OF_ATOMS = 258

# Alternative settings left by the author (tagged tip4p):
# NUM_OF_ATOMS = 251 * 3  # tip4p
# CUTOFF_RADIUS = 3.4

# Loss weights. Neither is referenced by active code in this script; LAMBDA2
# only appears in the commented-out conservative-loss term of training_step.
LAMBDA1 = 100.
LAMBDA2 = 1e-3


def get_rotation_matrix():
    """Draw a random 3x3 rotation matrix used for data augmentation.

    With probability 0.3 each of the three axis angles is an integer drawn from
    {-2, -1, 0, 1} multiplied by pi; otherwise all three angles are zero.

    Returns:
        np.ndarray: float32 matrix of shape (3, 3), composed as ``Rz @ (Ry @ Rx)``.
    """
    if np.random.uniform() < 0.3:
        angle_x, angle_y, angle_z = np.random.randint(-2, 2, size=(3,)) * np.pi
    else:
        angle_x, angle_y, angle_z = 0., 0., 0.

    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([[1., 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]], dtype=np.float32)
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]], dtype=np.float32)
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]], dtype=np.float32)

    return np.matmul(rot_z, np.matmul(rot_y, rot_x))


def center_positions(pos):
    """Shift positions so their mean over axis 0 is zero.

    Returns:
        tuple: ``(centered_positions, offset)`` where ``offset`` is the removed mean.
    """
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid

def build_model(args, ckpt=None):
    """Instantiate ``SimpleMDNetNew`` from CLI hyper-parameters, optionally loading weights.

    Args:
        args: Parsed arguments providing ``encoding_size``, ``hidden_dim``,
            ``edge_embedding_dim``, ``drop_edge`` and ``use_layer_norm``.
        ckpt: Optional path to a file readable by ``torch.load`` that holds a
            state dict compatible with the model.

    Returns:
        The constructed model, with weights restored when ``ckpt`` is given.
    """
    param_dict = {
                  'encoding_size': args.encoding_size,
                  'out_feats': 3,
                  'hidden_dim': args.hidden_dim,
                  'edge_embedding_dim': args.edge_embedding_dim,
                  'conv_layer': 4,
                  'drop_edge': args.drop_edge,
                  'use_layer_norm': args.use_layer_norm,
                  'box_size': BOX_SIZE,
                  }

    print("Using following set of hyper-parameters")
    print(param_dict)
    model = SimpleMDNetNew(**param_dict)

    if ckpt is not None:
        print('Loading model weights from: ', ckpt)
        # Deserialize onto the CPU so a checkpoint written on a different (or
        # unavailable) CUDA device still loads; load_state_dict then copies the
        # values into the model's own parameters.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


class ParticleNetLightning(pl.LightningModule):
    def __init__(self, args, num_device=1, epoch_num=100, batch_size=1, learning_rate=3e-4, log_freq=1000,
                 model_weights_ckpt=None, scaler_ckpt=None):
        """Bundle the LJ force network with its force scaler and neighbour search.

        Args:
            args: Parsed CLI namespace; ``rotate_aug``, ``data_dir`` and ``loss`` are
                read here, the namespace is also forwarded to ``build_model``.
            num_device: Stored as ``self.num_device``.
            epoch_num: Total number of epochs (used by the LR schedule).
            batch_size: Batch size of the training dataloader.
            learning_rate: Learning rate handed to Adam.
            log_freq: Stored as ``self.log_freq``.
            model_weights_ckpt: Optional weight file forwarded to ``build_model``.
            scaler_ckpt: Optional ``.npz`` file with saved force statistics.
        """
        super().__init__()

        # --- network -------------------------------------------------------
        self.pnet_model = build_model(args, model_weights_ckpt)

        # --- optimisation / bookkeeping -------------------------------------
        self.epoch_num = epoch_num
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_device = num_device
        self.log_freq = log_freq

        # --- force normalisation statistics ---------------------------------
        self.train_data_scaler = StandardScaler()
        self.training_mean = np.array([0.])
        self.training_var = np.array([1.])
        if scaler_ckpt is not None:
            self.load_training_stats(scaler_ckpt)

        # --- neighbour search ------------------------------------------------
        self.cutoff = CUTOFF_RADIUS
        self.nbr_searcher = NeighborSearcher(BOX_SIZE, self.cutoff)
        edge_mask_fn = graph_network_nbr_fn(self.nbr_searcher.displacement_fn,
                                            self.cutoff,
                                            NUM_OF_ATOMS)
        self.nbrlst_to_edge_mask = jax.jit(edge_mask_fn)
        self.nbr_cache = {}

        # --- options taken from the CLI --------------------------------------
        self.rotate_aug = args.rotate_aug
        self.data_dir = args.data_dir
        self.loss_fn = args.loss
        supported_losses = ('mae', 'mse', 'l1_message', 'kl_message',
                            'l1_message_node_embed', 'constrain_msg_stds')
        assert self.loss_fn in supported_losses

    def load_training_stats(self, scaler_ckpt):
        """Restore the force normalisation statistics from a saved archive.

        Args:
            scaler_ckpt: Path to an ``.npz`` archive containing ``mean`` and ``var``
                arrays, or ``None`` to keep the current statistics.
        """
        if scaler_ckpt is None:
            return
        # np.load keeps the archive's file handle open; the context manager
        # closes it once both arrays have been read into memory.
        with np.load(scaler_ckpt) as scaler_info:
            self.training_mean = scaler_info['mean']
            self.training_var = scaler_info['var']

    def forward(self, pos, feat, edge_idx_tsr):
        """Run the wrapped network and undo the force standardisation on its output.

        NOTE(review): every other call site in this class invokes ``self.pnet_model``
        with two arguments (a positions list and an edge-index list); this method
        passes three, including ``feat``. It looks like a leftover from the water
        script - confirm against the model's ``forward`` before relying on it.
        """
        return self.denormalize(self.pnet_model(pos, feat, edge_idx_tsr.long()), self.training_var, self.training_mean)

    def denormalize(self, normalized_force, var, mean):
        """Undo standardisation: ``normalized_force * sqrt(var) + mean``."""
        std = np.sqrt(var)
        return normalized_force * std + mean

    def predict_forces(self, pos: np.ndarray, verbose=False):
        """Predict de-normalised forces for a single configuration.

        Args:
            pos: Particle positions as a NumPy array.
            verbose: When true, print the wall-clock time of the neighbour
                search and of the network evaluation.

        Returns:
            np.ndarray: Network output mapped back through ``denormalize``.
        """
        t_nbr_begin = time.time()
        edge_idx_tsr = self.search_for_neighbor(pos,
                                                self.nbr_searcher,
                                                self.nbrlst_to_edge_mask,
                                                'all')
        t_nbr_done = time.time()

        # enforce periodic boundary
        wrapped_pos = np.mod(pos, np.array(BOX_SIZE))
        pos_tsr = torch.from_numpy(wrapped_pos).float().cuda()

        t_force_begin = time.time()
        raw_pred = self.pnet_model([pos_tsr], [edge_idx_tsr])
        t_force_done = time.time()

        if verbose:
            print('=============================================')
            print(f'Nbr search used time: {t_nbr_done - t_nbr_begin}')
            print(f'Force eval used time: {t_force_done - t_force_begin}')

        pred_np = raw_pred.detach().cpu().numpy()
        return self.denormalize(pred_np, self.training_var, self.training_mean)

    def scale_force(self, force, scaler):
        """Update ``scaler`` with this batch of forces and return them standardised.

        All force components are flattened into one column, so a single mean and
        variance is shared by every component.

        Args:
            force: 2-D force array.
            scaler: Object offering ``partial_fit`` and ``transform``.

        Returns:
            torch.Tensor: float tensor with the same 2-D shape as ``force``.
        """
        n_rows, n_cols = force.shape
        column = force.reshape((-1, 1))
        scaler.partial_fit(column)
        standardized = scaler.transform(column)
        return torch.from_numpy(standardized).float().view(n_rows, n_cols)

    def get_edge_idx(self, nbrs, pos_jax, mask):
        """Turn a neighbour table plus validity mask into a ``[2, E]`` CUDA edge index.

        Row 0 of the result holds center indices, row 1 the matching neighbour
        indices; only entries where ``mask`` is set are kept.
        """
        def to_cuda_tensor(jax_array):
            # cast jax device array to cupy array so that it can be transferred to torch
            return torch.as_tensor(cupy.asarray(jax_array), device='cuda')

        # Table shaped like nbrs.idx in which every entry of row i equals i.
        row_ids = jnp.arange(pos_jax.shape[0]).reshape(-1, 1)
        center_table = nbrs.idx.copy().at[:].set(row_ids)

        center_idx_tsr = to_cuda_tensor(center_table.reshape(-1))
        neigh_idx_tsr = to_cuda_tensor(nbrs.idx.reshape(-1))
        flat_mask = to_cuda_tensor(mask).view(-1)

        kept_centers = center_idx_tsr[flat_mask].view(1, -1)
        kept_neighbours = neigh_idx_tsr[flat_mask].view(1, -1)
        return torch.cat((kept_centers, kept_neighbours), dim=0)

    def search_for_neighbor(self, pos, nbr_searcher, masking_fn, type_name):
        """Build or refresh the neighbour list for ``pos`` and return its edge index.

        Args:
            pos: Particle positions (moved to the first JAX GPU device).
            nbr_searcher: Searcher exposing ``has_been_init``,
                ``init_new_neighbor_lst`` and ``update_neighbor_lst``.
            masking_fn: Callable ``(pos_jax, nbrs.idx) -> mask`` of valid edges.
            type_name: Key under which the neighbour list is cached.

        Returns:
            torch.Tensor: int64 edge index produced by ``get_edge_idx``.
        """
        pos_jax = jax.device_put(pos, jax.devices("gpu")[0])

        if nbr_searcher.has_been_init:
            nbrs = nbr_searcher.update_neighbor_lst(pos_jax, self.nbr_cache[type_name])
        else:
            nbrs = nbr_searcher.init_new_neighbor_lst(pos_jax)
        self.nbr_cache[type_name] = nbrs

        valid_edges = masking_fn(pos_jax, nbrs.idx)
        return self.get_edge_idx(nbrs, pos_jax, valid_edges).long()

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch.

        Each sample is optionally rotated, wrapped into the periodic box, its
        forces are standardised with the running scaler, and its neighbour graph
        is built. Positions get small Gaussian noise before the forward pass.
        The loss is selected by ``self.loss_fn``; the message-based variants add
        a regulariser computed from the last graph-convolution layer.

        Args:
            batch: Dict with ``'pos'`` and ``'forces'`` lists (one entry per sample).
                Both lists are overwritten in place with CUDA tensors.
            batch_nb: Batch index (unused).

        Returns:
            The scalar loss tensor.
        """
        pos_lst = batch['pos']
        gt_lst = batch['forces']
        edge_idx_lst = []
        for idx, sample_force in enumerate(gt_lst):
            sample_pos = pos_lst[idx]

            if self.rotate_aug:
                sample_pos = np.mod(sample_pos, BOX_SIZE)
                sample_pos, offset = center_positions(sample_pos)
                rotation = get_rotation_matrix()
                sample_pos = np.matmul(sample_pos, rotation)
                sample_pos += offset
                sample_force = np.matmul(sample_force, rotation)

            sample_pos = np.mod(sample_pos, BOX_SIZE)

            gt_lst[idx] = self.scale_force(sample_force, self.train_data_scaler).cuda()
            pos_lst[idx] = torch.from_numpy(sample_pos).float().cuda()

            edge_idx_lst.append(self.search_for_neighbor(sample_pos,
                                                         self.nbr_searcher,
                                                         self.nbrlst_to_edge_mask,
                                                         'all'))
        gt = torch.cat(gt_lst, dim=0)
        # Jitter the inputs only after the neighbour graph has been built.
        pos_lst = [pos + torch.randn_like(pos) * 0.005 for pos in pos_lst]

        pred = self.pnet_model(pos_lst, edge_idx_lst)

        loss_kind = self.loss_fn
        if loss_kind == 'mae':
            loss = nn.L1Loss()(pred, gt)
        elif loss_kind == 'l1_message':
            messages = self.pnet_model.graph_conv.conv[-1].edge_message_neigh_center
            message_penalty = 1e-2 * torch.mean(torch.abs(messages))
            loss = nn.L1Loss()(pred, gt) + message_penalty
        elif loss_kind == 'kl_message':
            raw_msg = self.pnet_model.graph_conv.conv[-1].edge_message_neigh_center
            # Even columns are treated as means, odd columns as log-variances.
            mu, logvar = raw_msg[:, 0::2], raw_msg[:, 1::2]
            kl_penalty = torch.mean(torch.exp(logvar) + mu**2 - logvar)/2.0
            loss = nn.L1Loss()(pred, gt) + kl_penalty
        elif loss_kind == 'l1_message_node_embed':
            last_conv = self.pnet_model.graph_conv.conv[-1]
            weight = 1e-1
            message_penalty = weight * torch.mean(torch.abs(last_conv.edge_message_neigh_center))
            node_penalty = weight * torch.mean(torch.abs(last_conv.input_node_embeddings))
            loss = nn.L1Loss()(pred, gt) + message_penalty + node_penalty
        elif loss_kind == 'constrain_msg_stds':
            messages = self.pnet_model.graph_conv.conv[-1].edge_message_neigh_center
            # k = 1, to align edge messages with pair-potential prediction:
            # penalise the spread of every message component except the first.
            spread = torch.abs(torch.std(messages[:, 1:], dim=0))
            # inductive bias to push all info to first k message components.
            loss = nn.L1Loss()(pred, gt) + 1e-1 * torch.mean(spread)
        else:
            loss = nn.MSELoss()(pred, gt)

        #conservative_loss = (torch.mean(pred)).abs()
        #loss = loss + LAMBDA2 * conservative_loss

        # Expose the running scaler statistics for checkpointing / denormalize.
        self.training_mean = self.train_data_scaler.mean_
        self.training_var = self.train_data_scaler.var_

        self.log('total loss', loss, on_step=True, prog_bar=True, logger=True)
        self.log(f'{self.loss_fn} loss', loss, on_step=True, prog_bar=True, logger=True)
        self.log('var', torch.tensor(np.sqrt(self.training_var)), on_step=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Return Adam plus a step LR schedule.

        The learning rate is multiplied by ``0.001 ** (5 / epoch_num)`` every
        5 epochs, i.e. by 0.001 overall across ``epoch_num`` epochs.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        decay_per_step = 0.001**(5/self.epoch_num)
        scheduler = StepLR(optimizer, step_size=5, gamma=decay_per_step)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader`` over ``LJDataNew``.

        Samples are collated into a dict of per-sample lists (``'pos'``,
        ``'forces'``) rather than stacked tensors.
        """
        def collate(batches):
            return {
                'pos': [sample['pos'] for sample in batches],
                'forces': [sample['forces'] for sample in batches],
            }

        train_set = LJDataNew(dataset_path=os.path.join(self.data_dir, ''),
                              sample_num=1000,
                              case_prefix='data_',
                              seed_num=10,
                              mode='train')

        return DataLoader(train_set, num_workers=2, batch_size=self.batch_size, shuffle=True,
                          collate_fn=collate)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` (``mode='test'``, batch size 16, no shuffling).

        Samples are collated into a dict of per-sample lists (``'pos'``,
        ``'forces'``) rather than stacked tensors.
        """
        def collate(batches):
            return {
                'pos': [sample['pos'] for sample in batches],
                'forces': [sample['forces'] for sample in batches],
            }

        val_set = LJDataNew(dataset_path=os.path.join(self.data_dir, ''),
                            sample_num=1000,
                            case_prefix='data_',
                            seed_num=10,
                            mode='test')

        return DataLoader(val_set, num_workers=2, batch_size=16, shuffle=False,
                          collate_fn=collate)

    def validation_step(self, batch, batch_nb):
        """Log MSE, MAE and an outlier ratio for one validation batch.

        The outlier ratio is the fraction of force components whose absolute
        error exceeds ten times the magnitude of the prediction.

        NOTE(review): forces go through ``scale_force``, which also updates
        ``self.train_data_scaler`` with the validation data - confirm that this
        is intended.
        """
        with torch.no_grad():
            pos_lst = batch['pos']
            gt_lst = batch['forces']
            edge_idx_lst = []
            for idx, sample_force in enumerate(gt_lst):
                wrapped_pos = np.mod(pos_lst[idx], BOX_SIZE)

                gt_lst[idx] = self.scale_force(sample_force, self.train_data_scaler).cuda()
                pos_lst[idx] = torch.from_numpy(wrapped_pos).float().cuda()

                edge_idx_lst.append(self.search_for_neighbor(wrapped_pos,
                                                             self.nbr_searcher,
                                                             self.nbrlst_to_edge_mask,
                                                             'all'))
            gt = torch.cat(gt_lst, dim=0)

            pred = self.pnet_model(pos_lst, edge_idx_lst)

            flat_pred = pred.reshape(-1)
            flat_gt = gt.reshape(-1)
            ratio = torch.sqrt((flat_pred - flat_gt) ** 2) / (torch.abs(flat_pred) + 1e-8)
            outlier_ratio = ratio[ratio > 10.].shape[0] / ratio.shape[0]
            mse = nn.MSELoss()(pred, gt)
            mae = nn.L1Loss()(pred, gt)

            batch_size = len(gt_lst)
            for metric_name, metric_value in (('val outlier', outlier_ratio),
                                              ('val mse', mse),
                                              ('val mae', mae)):
                self.log(metric_name, metric_value, prog_bar=True, logger=True, batch_size = batch_size)


class ModelCheckpointAtEpochEnd(pl.Callback):
    """Save a trainer checkpoint and the force-scaler statistics at the end of training epochs.

    The hook is ``on_train_epoch_end``: the generic ``on_epoch_end`` callback hook
    was removed from PyTorch Lightning in 1.8, so a callback that only defines
    ``on_epoch_end`` is rejected or never invoked by the Trainer API used in this
    script (``devices`` / ``accelerator`` / ``strategy``).
    """

    def __init__(
            self,
            filepath,
            save_step_frequency,
            prefix="checkpoint",
    ):
        """
        Args:
            filepath: directory receiving ``<prefix>_<epoch>.ckpt`` and ``scaler_<epoch>.npz``
            save_step_frequency: save whenever ``epoch % save_step_frequency == 0``;
                the last epoch (``pl_module.epoch_num - 1``) is always saved
            prefix: file-name prefix of the checkpoint files
        """
        super().__init__()
        self.filepath = filepath
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: ParticleNetLightning):
        """Write the checkpoint and scaler statistics if this epoch is due for saving."""
        epoch = trainer.current_epoch
        if epoch % self.save_step_frequency == 0 or epoch == pl_module.epoch_num -1:
            filename = os.path.join(self.filepath, f"{self.prefix}_{epoch}.ckpt")
            scaler_filename = os.path.join(self.filepath, f"scaler_{epoch}.npz")

            # os.path.join discards the first component when ``filename`` is
            # absolute, so the callback dirpath only matters for a relative filepath.
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
            np.savez(scaler_filename,
                     mean=pl_module.training_mean,
                     var=pl_module.training_var,
                     )
            # joblib.dump(pl_module.train_data_scaler, scaler_filename)


def train_model(args):
    """Build the Lightning module and trainer from ``args`` and run the fit loop.

    Args:
        args: Parsed CLI namespace (see ``main``). Reads ``lr``, ``num_gpu``,
            ``cp_dir``, ``min_epoch``, ``max_epoch``, ``state_ckpt_dir`` and
            ``batch_size``; the whole namespace is forwarded to ``ParticleNetLightning``.
    """
    lr = args.lr
    num_gpu = args.num_gpu
    check_point_dir = args.cp_dir
    min_epoch = args.min_epoch
    max_epoch = args.max_epoch
    weight_ckpt = args.state_ckpt_dir
    batch_size = args.batch_size

    model = ParticleNetLightning(epoch_num=max_epoch,
                                 num_device=num_gpu if num_gpu != -1 else 1,
                                 learning_rate=lr,
                                 model_weights_ckpt=weight_ckpt,
                                 batch_size=batch_size,
                                 args=args)
    cwd = os.getcwd()
    model_check_point_dir = os.path.join(cwd, check_point_dir)
    os.makedirs(model_check_point_dir, exist_ok=True)
    #print("Checkpoints will be saved at: ", model_check_point_dir)
    epoch_end_callback = ModelCheckpointAtEpochEnd(filepath=model_check_point_dir, save_step_frequency=1)
    checkpoint_callback = pl.callbacks.ModelCheckpoint()

    trainer = Trainer(
        devices=[0],#num_gpu,  # Use 'devices' instead of 'gpus'
        accelerator='gpu',  # Specify the accelerator as 'gpu'
        callbacks=[epoch_end_callback, checkpoint_callback],
        min_epochs=min_epoch,
        max_epochs=max_epoch,
        precision=16,  # Use 'precision' for mixed precision if needed
        benchmark=True,
        strategy='ddp',  # Use 'strategy' for distributed training
        # Root the Lightning logs/checkpoints in the directory derived from
        # --cp_dir instead of a machine-specific absolute path, so the location
        # printed below is the one actually used.
        default_root_dir=model_check_point_dir,
    )

    # Get the version number used for this training run
    version_number = trainer.logger.version
    print(f"Checkpoints are saved in: {model_check_point_dir}/lightning_logs/version_{version_number}/checkpoints/")
    trainer.fit(model)


def main():
    """Parse the command-line options of the LJ training script and start training."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # training schedule / optimisation
    add('--min_epoch', type=int, default=30)
    add('--max_epoch', type=int, default=30)
    add('--lr', type=float, default=3e-4)
    add('--cp_dir', default='model_ckpt')
    add('--state_ckpt_dir', type=str, default=None)
    add('--batch_size', type=int, default=1)

    # network size
    add('--encoding_size', type=int, default=256)
    add('--hidden_dim', type=int, default=128)
    add('--edge_embedding_dim', type=int, default=256)

    # switches
    add('--drop_edge', action='store_true')
    add('--use_layer_norm', action='store_true')
    add('--disable_rotate_aug', dest='rotate_aug', action='store_false', default=True)

    # data / objective / hardware
    add('--data_dir', default='./md_dataset')
    add('--loss', default='mae')
    add('--num_gpu', type=int, default=-1)

    train_model(parser.parse_args())


if __name__ == '__main__':
    main()



In [None]:
python train_network_lj.py --num_gpu 1 --min_epoch 30 --max_epoch 30 --batch_size 1 --encoding_size 128 --hidden_dim 128 --edge_embedding_dim 128 --loss constrain_msg_stds --data_dir ../../../../top/  --use_layer_norm

### Train for tip3p test-cases

In [None]:
import argparse
import os, sys
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import jax
import jax.numpy as jnp
import cupy

sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
from nn_module import WaterMDNetNew
from train_utils import WaterDataNew
from graph_utils import NeighborSearcher, graph_network_nbr_fn
# Disabled switch left by the author; an empty value would hide all GPUs
# from this process (author's remark follows).
# os.environ["CUDA_VISIBLE_DEVICES"] = "" # just to test if it works w/o gpu
# Ask JAX/XLA not to preallocate its GPU memory pool at start-up
# (presumably so PyTorch can share the same device - TODO confirm).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# for water box
# Cutoff and periodic box edge consumed by NeighborSearcher and build_model
# below; units are not stated in this file.
CUTOFF_RADIUS = 4.2
left_bound = 0.0
right_bound = 20.0
BOX_SIZE = right_bound - left_bound

# Three atoms per water molecule, 258 molecules.
NUM_OF_ATOMS = 258 * 3

# Alternative settings left by the author (tagged tip4p):
# NUM_OF_ATOMS = 251 * 3  # tip4p
# CUTOFF_RADIUS = 3.4

# Loss weights: LAMBDA2 scales the |mean(pred)| term added in training_step;
# LAMBDA1 is not referenced by the code visible in this script.
LAMBDA1 = 100.
LAMBDA2 = 1e-3


def create_water_bond(total_atom_num):
    """List the two intramolecular bonds of every consecutive atom triplet.

    For each index ``i`` in ``range(0, total_atom_num, 3)`` the pairs
    ``[i, i + 1]`` and ``[i, i + 2]`` are emitted, in that order.

    Returns:
        np.ndarray: array built from the list of index pairs.
    """
    pairs = [
        [first, first + step]
        for first in range(0, total_atom_num, 3)
        for step in (1, 2)
    ]
    return np.array(pairs)


def get_rotation_matrix():
    """Draw a random 3x3 rotation matrix used for data augmentation.

    With probability 0.3 each of the three axis angles is an integer drawn from
    {-2, -1, 0, 1} multiplied by pi; otherwise all three angles are zero.

    Returns:
        np.ndarray: float32 matrix of shape (3, 3), composed as ``Rz @ (Ry @ Rx)``.
    """
    if np.random.uniform() < 0.3:
        angle_x, angle_y, angle_z = np.random.randint(-2, 2, size=(3,)) * np.pi
    else:
        angle_x, angle_y, angle_z = 0., 0., 0.

    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([[1., 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]], dtype=np.float32)
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]], dtype=np.float32)
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]], dtype=np.float32)

    return np.matmul(rot_z, np.matmul(rot_y, rot_x))


def center_positions(pos):
    """Shift positions so their mean over axis 0 is zero.

    Returns:
        tuple: ``(centered_positions, offset)`` where ``offset`` is the removed mean.
    """
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid

def build_model(args, ckpt=None):
    """Instantiate ``WaterMDNetNew`` from CLI hyper-parameters, optionally loading weights.

    Args:
        args: Parsed arguments providing ``encoding_size``, ``hidden_dim``,
            ``edge_embedding_dim``, ``drop_edge`` and ``use_layer_norm``.
        ckpt: Optional path to a file readable by ``torch.load`` that holds a
            state dict compatible with the model.

    Returns:
        The constructed model, with weights restored when ``ckpt`` is given.
    """
    bond_info = create_water_bond(NUM_OF_ATOMS)

    param_dict = {'in_feats': 1,
                  'encoding_size': args.encoding_size,
                  'out_feats': 3,
                  'bond': bond_info,
                  'hidden_dim': args.hidden_dim,
                  'edge_embedding_dim': args.edge_embedding_dim,
                  'conv_layer': 4,
                  'drop_edge': args.drop_edge,
                  'use_layer_norm': args.use_layer_norm,
                  'box_size': BOX_SIZE,
                  }

    print("Using following set of hyper-parameters")
    print(param_dict)
    model = WaterMDNetNew(**param_dict)

    if ckpt is not None:
        print('Loading model weights from: ', ckpt)
        # Deserialize onto the CPU so a checkpoint written on a different (or
        # unavailable) CUDA device still loads; load_state_dict then copies the
        # values into the model's own parameters.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


class ParticleNetLightning(pl.LightningModule):
    def __init__(self, args, num_device=1, epoch_num=100, batch_size=1, learning_rate=3e-4, log_freq=1000,
                 model_weights_ckpt=None, scaler_ckpt=None):
        """Bundle the water force network with its force scaler and neighbour search.

        Args:
            args: Parsed CLI namespace; ``rotate_aug``, ``data_dir`` and ``loss`` are
                read here, the namespace is also forwarded to ``build_model``.
            num_device: Stored as ``self.num_device``.
            epoch_num: Total number of epochs (used by the LR schedule).
            batch_size: Batch size of the training dataloader.
            learning_rate: Learning rate handed to Adam.
            log_freq: Stored as ``self.log_freq``.
            model_weights_ckpt: Optional weight file forwarded to ``build_model``.
            scaler_ckpt: Optional ``.npz`` file with saved force statistics.
        """
        super().__init__()

        # --- network -------------------------------------------------------
        self.pnet_model = build_model(args, model_weights_ckpt)

        # --- optimisation / bookkeeping -------------------------------------
        self.epoch_num = epoch_num
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_device = num_device
        self.log_freq = log_freq

        # --- force normalisation statistics ---------------------------------
        self.train_data_scaler = StandardScaler()
        self.training_mean = np.array([0.])
        self.training_var = np.array([1.])
        if scaler_ckpt is not None:
            self.load_training_stats(scaler_ckpt)

        # --- neighbour search ------------------------------------------------
        self.cutoff = CUTOFF_RADIUS
        self.nbr_searcher = NeighborSearcher(BOX_SIZE, self.cutoff)
        edge_mask_fn = graph_network_nbr_fn(self.nbr_searcher.displacement_fn,
                                            self.cutoff,
                                            NUM_OF_ATOMS)
        self.nbrlst_to_edge_mask = jax.jit(edge_mask_fn)
        self.nbr_cache = {}

        # --- options taken from the CLI --------------------------------------
        self.rotate_aug = args.rotate_aug
        self.data_dir = args.data_dir
        self.loss_fn = args.loss
        supported_losses = ('mae', 'mse')
        assert self.loss_fn in supported_losses

    def load_training_stats(self, scaler_ckpt):
        """Restore the force normalisation statistics from a saved archive.

        Args:
            scaler_ckpt: Path to an ``.npz`` archive containing ``mean`` and ``var``
                arrays, or ``None`` to keep the current statistics.
        """
        if scaler_ckpt is None:
            return
        # np.load keeps the archive's file handle open; the context manager
        # closes it once both arrays have been read into memory.
        with np.load(scaler_ckpt) as scaler_info:
            self.training_mean = scaler_info['mean']
            self.training_var = scaler_info['var']

    def forward(self, pos, feat, edge_idx_tsr):
        """Run the wrapped network and undo the force standardisation on its output."""
        normalized = self.pnet_model(pos, feat, edge_idx_tsr.long())
        return self.denormalize(normalized, self.training_var, self.training_mean)

    def denormalize(self, normalized_force, var, mean):
        """Undo standardisation: ``normalized_force * sqrt(var) + mean``."""
        std = np.sqrt(var)
        return normalized_force * std + mean

    def predict_forces(self, feat: torch.Tensor, pos: np.ndarray):
        """Predict de-normalised forces for a single configuration.

        Args:
            feat: Per-atom feature tensor; its device is used for the positions.
            pos: Particle positions as a NumPy array.

        Returns:
            np.ndarray: Network output mapped back through ``denormalize``.
        """
        edge_idx_tsr = self.search_for_neighbor(pos,
                                                self.nbr_searcher,
                                                self.nbrlst_to_edge_mask,
                                                'all')
        # enforce periodic boundary
        wrapped_pos = np.mod(pos, np.array(BOX_SIZE))
        pos_tsr = torch.from_numpy(wrapped_pos).float().to(feat.device)

        raw_pred = self.pnet_model([pos_tsr], feat, [edge_idx_tsr])
        pred_np = raw_pred.detach().cpu().numpy()

        return self.denormalize(pred_np, self.training_var, self.training_mean)

    def scale_force(self, force, scaler):
        """Update ``scaler`` with this batch of forces and return them standardised.

        All force components are flattened into one column, so a single mean and
        variance is shared by every component.

        Args:
            force: 2-D force array.
            scaler: Object offering ``partial_fit`` and ``transform``.

        Returns:
            torch.Tensor: float tensor with the same 2-D shape as ``force``.
        """
        n_rows, n_cols = force.shape
        column = force.reshape((-1, 1))
        scaler.partial_fit(column)
        standardized = scaler.transform(column)
        return torch.from_numpy(standardized).float().view(n_rows, n_cols)

    def get_edge_idx(self, nbrs, pos_jax, mask):
        """Turn a neighbour table plus validity mask into a ``[2, E]`` CUDA edge index.

        Row 0 of the result holds center indices, row 1 the matching neighbour
        indices; only entries where ``mask`` is set are kept.

        Args:
            nbrs: Neighbour list object exposing the index table as ``nbrs.idx``.
            pos_jax: JAX array of positions; only its first dimension is used.
            mask: JAX array marking which entries of ``nbrs.idx`` are valid edges.
        """
        dummy_center_idx = nbrs.idx.copy()
        # ``jax.ops.index_update`` has been removed from JAX; the ``.at[...].set``
        # indexed update is its replacement. The (N, 1) column of row ids is
        # broadcast so that every entry of row i becomes i.
        dummy_center_idx = dummy_center_idx.at[:].set(jnp.arange(pos_jax.shape[0]).reshape(-1, 1))
        center_idx = dummy_center_idx.reshape(-1)
        center_idx_ = cupy.asarray(center_idx)
        center_idx_tsr = torch.as_tensor(center_idx_, device='cuda')

        neigh_idx = nbrs.idx.reshape(-1)

        # cast jax device array to cupy array so that it can be transferred to torch
        neigh_idx = cupy.asarray(neigh_idx)
        mask = cupy.asarray(mask)
        mask = torch.as_tensor(mask, device='cuda')
        flat_mask = mask.view(-1)
        neigh_idx_tsr = torch.as_tensor(neigh_idx, device='cuda')

        edge_idx_tsr = torch.cat((center_idx_tsr[flat_mask].view(1, -1), neigh_idx_tsr[flat_mask].view(1, -1)),
                                 dim=0)
        return edge_idx_tsr

    def search_for_neighbor(self, pos, nbr_searcher, masking_fn, type_name):
        """Build or refresh the neighbour list for ``pos`` and return its edge index.

        Args:
            pos: Particle positions (moved to the first JAX GPU device).
            nbr_searcher: Searcher exposing ``has_been_init``,
                ``init_new_neighbor_lst`` and ``update_neighbor_lst``.
            masking_fn: Callable ``(pos_jax, nbrs.idx) -> mask`` of valid edges.
            type_name: Key under which the neighbour list is cached.

        Returns:
            torch.Tensor: int64 edge index produced by ``get_edge_idx``.
        """
        pos_jax = jax.device_put(pos, jax.devices("gpu")[0])

        if nbr_searcher.has_been_init:
            nbrs = nbr_searcher.update_neighbor_lst(pos_jax, self.nbr_cache[type_name])
        else:
            nbrs = nbr_searcher.init_new_neighbor_lst(pos_jax)
        self.nbr_cache[type_name] = nbrs

        valid_edges = masking_fn(pos_jax, nbrs.idx)
        return self.get_edge_idx(nbrs, pos_jax, valid_edges).long()

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch.

        Each sample is optionally rotated, wrapped into the periodic box, its
        forces are standardised with the running scaler, and its neighbour graph
        is built. Positions get small Gaussian noise before the forward pass.
        The loss is MAE or MSE (per ``self.loss_fn``) plus ``LAMBDA2`` times the
        absolute mean of the prediction.

        Args:
            batch: Dict with a ``'feat'`` tensor and ``'pos'`` / ``'forces'`` lists
                (one entry per sample). Both lists are overwritten in place with
                tensors on ``feat``'s device.
            batch_nb: Batch index (unused).

        Returns:
            The scalar loss tensor.
        """
        feat, pos_lst = batch['feat'], batch['pos']
        gt_lst = batch['forces']
        edge_idx_lst = []
        for b in range(len(gt_lst)):
            pos, gt = pos_lst[b], gt_lst[b]

            if self.rotate_aug:
                pos = np.mod(pos, BOX_SIZE)
                pos, off = center_positions(pos)
                R = get_rotation_matrix()
                pos = np.matmul(pos, R)
                pos += off
                gt = np.matmul(gt, R)

            pos = np.mod(pos, BOX_SIZE)

            gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
            pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
            gt_lst[b] = gt

            edge_idx_tsr = self.search_for_neighbor(pos,
                                                    self.nbr_searcher,
                                                    self.nbrlst_to_edge_mask,
                                                    'all')
            edge_idx_lst += [edge_idx_tsr]
        gt = torch.cat(gt_lst, dim=0)
        # Jitter the inputs only after the neighbour graph has been built.
        pos_lst = [pos + torch.randn_like(pos) * 0.005 for pos in pos_lst]

        pred = self.pnet_model(pos_lst,
                               feat,
                               edge_idx_lst,
                               )

        if self.loss_fn == 'mae':
            loss = nn.L1Loss()(pred, gt)
        else:
            loss = nn.MSELoss()(pred, gt)

        conservative_loss = (torch.mean(pred)).abs()
        loss = loss + LAMBDA2 * conservative_loss

        # Expose the running scaler statistics for checkpointing / denormalize.
        self.training_mean = self.train_data_scaler.mean_
        self.training_var = self.train_data_scaler.var_

        self.log('total loss', loss, on_step=True, prog_bar=True, logger=True)
        self.log(f'{self.loss_fn} loss', loss, on_step=True, prog_bar=True, logger=True)
        # self.log expects a number or tensor, not a NumPy array, so wrap the
        # value the same way the LJ training script does.
        self.log('var', torch.tensor(np.sqrt(self.training_var)), on_step=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Return Adam plus a step LR schedule.

        The learning rate is multiplied by ``0.001 ** (5 / epoch_num)`` every
        5 epochs, i.e. by 0.001 overall across ``epoch_num`` epochs.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        decay_per_step = 0.001**(5/self.epoch_num)
        scheduler = StepLR(optimizer, step_size=5, gamma=decay_per_step)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader`` over the tip3p ``WaterDataNew`` set.

        Per-sample features are concatenated into one float tensor; positions and
        forces stay as per-sample lists.
        """
        def collate(batches):
            return {
                'feat': torch.cat([torch.from_numpy(sample['feat']).float() for sample in batches], dim=0),
                'pos': [sample['pos'] for sample in batches],
                'forces': [sample['forces'] for sample in batches],
            }

        train_set = WaterDataNew(dataset_path=os.path.join(self.data_dir, 'water_data_tip3p'),
                                 sample_num=1000,
                                 case_prefix='data_',
                                 seed_num=10,
                                 m_num=NUM_OF_ATOMS//3,
                                 mode='train',
                                 data_type='tip3p')

        return DataLoader(train_set, num_workers=2, batch_size=self.batch_size, shuffle=True,
                          collate_fn=collate)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` (``mode='test'``, batch size 16, no shuffling).

        Per-sample features are concatenated into one float tensor; positions and
        forces stay as per-sample lists.
        """
        def collate(batches):
            return {
                'feat': torch.cat([torch.from_numpy(sample['feat']).float() for sample in batches], dim=0),
                'pos': [sample['pos'] for sample in batches],
                'forces': [sample['forces'] for sample in batches],
            }

        val_set = WaterDataNew(dataset_path=os.path.join(self.data_dir, 'water_data_tip3p'),
                               sample_num=1000,
                               case_prefix='data_',
                               seed_num=10,
                               m_num=NUM_OF_ATOMS//3,
                               mode='test',
                               data_type='tip3p')

        return DataLoader(val_set, num_workers=2, batch_size=16, shuffle=False,
                          collate_fn=collate)

    def validation_step(self, batch, batch_nb):
        """Log MSE, MAE and an outlier ratio for one validation batch.

        The outlier ratio is the fraction of force components whose absolute
        error exceeds ten times the magnitude of the prediction.

        NOTE(review): forces go through ``scale_force``, which also updates
        ``self.train_data_scaler`` with the validation data - confirm that this
        is intended.
        """
        with torch.no_grad():

            feat, pos_lst = batch['feat'], batch['pos']
            gt_lst = batch['forces']
            edge_idx_lst = []
            for b in range(len(gt_lst)):
                pos, gt = pos_lst[b], gt_lst[b]
                pos = np.mod(pos, BOX_SIZE)

                gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
                pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
                gt_lst[b] = gt

                edge_idx_tsr = self.search_for_neighbor(pos,
                                                        self.nbr_searcher,
                                                        self.nbrlst_to_edge_mask,
                                                        'all')
                edge_idx_lst += [edge_idx_tsr]
            gt = torch.cat(gt_lst, dim=0)

            pred = self.pnet_model(pos_lst,
                                   feat,
                                   edge_idx_lst,
                                   )
            ratio = torch.sqrt((pred.reshape(-1) - gt.reshape(-1)) ** 2) / (torch.abs(pred.reshape(-1)) + 1e-8)
            outlier_ratio = ratio[ratio > 10.].shape[0] / ratio.shape[0]
            mse = nn.MSELoss()(pred, gt)
            mae = nn.L1Loss()(pred, gt)

            # The batch is a dict of Python lists, so the number of samples is
            # passed explicitly instead of being inferred by the logger
            # (same as in the LJ training script).
            batch_size = len(gt_lst)
            self.log('val outlier', outlier_ratio, prog_bar=True, logger=True, batch_size=batch_size)
            self.log('val mse', mse, prog_bar=True, logger=True, batch_size=batch_size)
            self.log('val mae', mae, prog_bar=True, logger=True, batch_size=batch_size)


class ModelCheckpointAtEpochEnd(pl.Callback):
    """
       Periodically save a model checkpoint together with the scaler statistics.

       Every ``save_step_frequency`` epochs, and on the last epoch, the files
       ``<prefix>_<epoch>.ckpt`` and ``scaler_<epoch>.npz`` are written to
       ``filepath``.
    """
    def __init__(
            self,
            filepath,
            save_step_frequency,
            prefix="checkpoint",
    ):
        """
        Args:
            filepath: directory that receives the checkpoint and scaler files
            save_step_frequency: how often to save, in epochs
            prefix: prefix of the checkpoint file name
        """
        super().__init__()
        self.filepath = filepath
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix

    def on_epoch_end(self, trainer: pl.Trainer, pl_module: ParticleNetLightning):
        """ Save checkpoint and scaler statistics if the current epoch is due """
        epoch = trainer.current_epoch
        is_last_epoch = epoch == pl_module.epoch_num - 1
        if epoch % self.save_step_frequency != 0 and not is_last_epoch:
            return

        # Both files go to self.filepath. The path used to be routed through
        # trainer.checkpoint_callback.dirpath, which fails when no checkpoint
        # callback/dirpath exists and, for a relative filepath, separated the
        # checkpoint from its scaler file.
        ckpt_path = os.path.join(self.filepath, f"{self.prefix}_{epoch}.ckpt")
        scaler_filename = os.path.join(self.filepath, f"scaler_{epoch}.npz")

        trainer.save_checkpoint(ckpt_path)
        np.savez(scaler_filename,
                 mean=pl_module.training_mean,
                 var=pl_module.training_var,
                 )


def train_model(args):
    """Create the Lightning module and trainer from parsed options, then fit."""
    lr, gpu_count, ckpt_subdir = args.lr, args.num_gpu, args.cp_dir
    epoch_lo, epoch_hi = args.min_epoch, args.max_epoch
    init_weights, batch_size = args.state_ckpt_dir, args.batch_size

    model = ParticleNetLightning(
        args=args,
        epoch_num=epoch_hi,
        # -1 is forwarded to the Trainer unchanged but counted as one device here
        num_device=1 if gpu_count == -1 else gpu_count,
        learning_rate=lr,
        model_weights_ckpt=init_weights,
        batch_size=batch_size,
    )

    # resolve the checkpoint directory against the current working directory
    ckpt_dir = os.path.join(os.getcwd(), ckpt_subdir)
    os.makedirs(ckpt_dir, exist_ok=True)

    callbacks = [
        ModelCheckpointAtEpochEnd(filepath=ckpt_dir, save_step_frequency=5),
        pl.callbacks.ModelCheckpoint(),
    ]
    trainer_options = {
        'gpus': gpu_count,
        'callbacks': callbacks,
        'min_epochs': epoch_lo,
        'max_epochs': epoch_hi,
        'amp_backend': 'apex',
        'amp_level': 'O1',
        'benchmark': True,
        'distributed_backend': 'ddp',
    }
    Trainer(**trainer_options).fit(model)


def main(argv=None):
    """Parse the command-line options and start training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; an explicit list
            such as ``[]`` lets the function be called from an interactive
            session whose ``sys.argv`` carries unrelated options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--min_epoch', default=30, type=int)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--cp_dir', default='./model_ckpt',
                        help='checkpoint directory, resolved against the working directory')
    parser.add_argument('--state_ckpt_dir', default=None, type=str,
                        help='optional path to model weights to load before training')
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--encoding_size', default=256, type=int)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--edge_embedding_dim', default=256, type=int)
    parser.add_argument('--drop_edge', action='store_true')
    parser.add_argument('--use_layer_norm', action='store_true')
    parser.add_argument('--disable_rotate_aug', dest='rotate_aug', default=True, action='store_false')
    parser.add_argument('--data_dir', default='./md_dataset')
    # reject unsupported losses at the command line instead of deep inside the model
    parser.add_argument('--loss', default='mae', choices=('mae', 'mse'))
    parser.add_argument('--num_gpu', default=-1, type=int,
                        help='value forwarded to the Trainer "gpus" option')
    args = parser.parse_args(argv)
    train_model(args)


# Script entry point: start training only when executed directly, not on import.
if __name__ == '__main__':
    main()



### Train for tip4p test-cases

In [None]:
import argparse
import os, sys
import joblib
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import jax
import jax.numpy as jnp
import cupy

sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
from nn_module import WaterMDNetNew
from train_utils import WaterDataNew
from graph_utils import NeighborSearcher, graph_network_nbr_fn
# os.environ["CUDA_VISIBLE_DEVICES"] = "" # just to test if it works w/o gpu
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Settings of the periodic water box used by the tip4p training script.
CUTOFF_RADIUS = 4.2  # cutoff handed to NeighborSearcher / graph_network_nbr_fn
left_bound = 0.0
right_bound = 20.0
BOX_SIZE = right_bound - left_bound  # box edge length used for periodic wrapping (np.mod)

# 251 molecules x 3 atoms; the dataset class notes 258 molecules for tip3p and 251 for tip4p.
NUM_OF_ATOMS = 251 * 3  # 258 *3

LAMBDA1 = 100.  # not referenced by the code visible in this script
LAMBDA2 = 1e-3  # weight of the |mean(pred)| penalty added to the training loss


def create_water_bond(total_atom_num):
    """Return the bonded index pairs for atoms grouped in consecutive triples.

    Each triple ``(i, i+1, i+2)`` contributes the pairs ``[i, i+1]`` and
    ``[i, i+2]``; the result is a NumPy array of those pairs.
    """
    pairs = [
        [first, first + step]
        for first in range(0, total_atom_num, 3)
        for step in (1, 2)
    ]
    return np.array(pairs)


def get_rotation_matrix():
    """ Draw a random rotation matrix used to augment the dataset
        With probability 0.3 each of the three axis angles is an integer
        multiple of pi taken from {-2, -1, 0, 1} * pi, otherwise all angles
        are zero.
        Return:
          3x3 float32 array, the product Rz * Ry * Rx
    """
    if np.random.uniform() < 0.3:
        angles = np.random.randint(-2, 2, size=(3,)) * np.pi
    else:
        angles = [0., 0., 0.]

    ang_x, ang_y, ang_z = angles
    cos_x, sin_x = np.cos(ang_x), np.sin(ang_x)
    cos_y, sin_y = np.cos(ang_y), np.sin(ang_y)
    cos_z, sin_z = np.cos(ang_z), np.sin(ang_z)

    rot_x = np.array([[1., 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]], dtype=np.float32)
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]], dtype=np.float32)
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]], dtype=np.float32)

    return np.matmul(rot_z, np.matmul(rot_y, rot_x))


def center_positions(pos):
    """Subtract the mean position; return ``(centered_positions, mean)``."""
    centroid = np.mean(pos, axis=0)
    centered = pos - centroid
    return centered, centroid

def build_model(args, ckpt=None):
    """Create the WaterMDNetNew force model and optionally load saved weights.

    Args:
        args: namespace providing ``encoding_size``, ``hidden_dim``,
            ``edge_embedding_dim``, ``drop_edge`` and ``use_layer_norm``.
        ckpt: optional path to a saved ``state_dict``; ``None`` keeps the
            freshly initialised weights.

    Returns:
        The constructed model.
    """
    bond_info = create_water_bond(NUM_OF_ATOMS)

    param_dict = {'in_feats': 1,
                  'encoding_size': args.encoding_size,
                  'out_feats': 3,
                  'bond': bond_info,
                  'hidden_dim': args.hidden_dim,
                  'edge_embedding_dim': args.edge_embedding_dim,
                  'conv_layer': 4,
                  'drop_edge': args.drop_edge,
                  'use_layer_norm': args.use_layer_norm,
                  'box_size': BOX_SIZE,
                  }

    print("Using following set of hyper-parameters")
    print(param_dict)
    model = WaterMDNetNew(**param_dict)

    if ckpt is not None:
        print('Loading model weights from: ', ckpt)
        # Map storages to CPU so weights saved on a GPU also load on a
        # CPU-only host; load_state_dict copies them into the model's own
        # parameters, wherever those live.
        model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    return model


class ParticleNetLightning(pl.LightningModule):
    """Lightning module that trains the tip4p water force-prediction network.

    It bundles the model returned by ``build_model`` with a running
    ``StandardScaler`` for the target forces and a periodic neighbour
    searcher that converts particle positions into graph edges.
    """
    def __init__(self, args, num_device=1, epoch_num=100, batch_size=1, learning_rate=3e-4, log_freq=1000,
                 model_weights_ckpt=None, scaler_ckpt=None):
        """
        Args:
            args: namespace with the hyper-parameters read by ``build_model``
                plus ``rotate_aug``, ``data_dir`` and ``loss``
            num_device: stored on the module; not used elsewhere in this class
            epoch_num: total number of epochs, used by the LR schedule
            batch_size: batch size of the training dataloader
            learning_rate: initial Adam learning rate
            log_freq: stored on the module; not used elsewhere in this class
            model_weights_ckpt: optional path to model weights to load
            scaler_ckpt: optional ``.npz`` file holding ``mean`` and ``var``
        """
        super(ParticleNetLightning, self).__init__()
        self.pnet_model = build_model(args, model_weights_ckpt)
        self.epoch_num = epoch_num
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_device = num_device
        self.log_freq = log_freq
        self.train_data_scaler = StandardScaler()
        # identity normalisation until real statistics are fitted or loaded
        self.training_mean = np.array([0.])
        self.training_var = np.array([1.])

        if scaler_ckpt is not None:
            self.load_training_stats(scaler_ckpt)

        self.cutoff = CUTOFF_RADIUS
        self.nbr_searcher = NeighborSearcher(BOX_SIZE, self.cutoff)
        self.nbrlst_to_edge_mask = jax.jit(graph_network_nbr_fn(self.nbr_searcher.displacement_fn,
                                                                    self.cutoff,
                                                                    NUM_OF_ATOMS))
        # most recent neighbour list per search type, reused for updates
        self.nbr_cache = {}
        self.rotate_aug = args.rotate_aug
        self.data_dir = args.data_dir
        self.loss_fn = args.loss
        assert self.loss_fn in ['mae', 'mse']

    def load_training_stats(self, scaler_ckpt):
        """Load the force mean/variance from an ``.npz`` file (no-op for ``None``)."""
        if scaler_ckpt is not None:
            scaler_info = np.load(scaler_ckpt)
            self.training_mean = scaler_info['mean']
            self.training_var = scaler_info['var']

    def forward(self, pos, feat, edge_idx_tsr):
        """Run the network and map its output back to un-normalised forces."""
        return self.denormalize(self.pnet_model(pos, feat, edge_idx_tsr.long()), self.training_var, self.training_mean)

    def denormalize(self, normalized_force, var, mean):
        """Invert the standardisation: ``normalized_force * sqrt(var) + mean``."""
        return normalized_force * \
                np.sqrt(var) +\
                mean

    def predict_forces(self, feat: torch.Tensor, pos: np.ndarray):
        """Predict un-normalised forces for one configuration as a NumPy array."""
        # enforce periodic boundary *before* the neighbour search so that the
        # search sees wrapped coordinates, as in training and validation
        pos = np.mod(pos, np.array(BOX_SIZE))
        edge_idx_tsr = self.search_for_neighbor(pos,
                                                self.nbr_searcher,
                                                self.nbrlst_to_edge_mask,
                                                'all')
        pos = torch.from_numpy(pos).float().to(feat.device)
        pred = self.pnet_model([pos],
                               feat,
                               [edge_idx_tsr],
                               )
        pred = pred.detach().cpu().numpy()

        pred = self.denormalize(pred, self.training_var, self.training_mean)

        return pred

    def scale_force(self, force, scaler):
        """Update ``scaler`` with ``force`` and return the standardised forces.

        All components are flattened into one column, so a single mean and
        variance is shared by every force component. The result is a float
        tensor with the input's two-dimensional shape.
        """
        b_pnum, dims = force.shape
        force_flat = force.reshape((-1, 1))
        scaler.partial_fit(force_flat)
        force = torch.from_numpy(scaler.transform(force_flat)).float().view(b_pnum, dims)
        return force

    def get_edge_idx(self, nbrs, pos_jax, mask):
        """Convert a neighbour list and edge mask into a (2, E) CUDA index tensor.

        Row 0 holds the centre particle of every kept edge, row 1 its neighbour.
        """
        # NOTE(review): jax.ops.index_update only exists in older JAX releases.
        # Index ``None`` addresses the whole array, so this presumably fills
        # row i of the neighbour table with the particle index i.
        dummy_center_idx = nbrs.idx.copy()
        dummy_center_idx = jax.ops.index_update(dummy_center_idx, None,
                                                jnp.arange(pos_jax.shape[0]).reshape(-1, 1))
        center_idx = dummy_center_idx.reshape(-1)
        center_idx_ = cupy.asarray(center_idx)
        center_idx_tsr = torch.as_tensor(center_idx_, device='cuda')

        neigh_idx = nbrs.idx.reshape(-1)

        # cast jax device array to cupy array so that it can be transferred to torch
        neigh_idx = cupy.asarray(neigh_idx)
        mask = cupy.asarray(mask)
        mask = torch.as_tensor(mask, device='cuda')
        flat_mask = mask.view(-1)
        neigh_idx_tsr = torch.as_tensor(neigh_idx, device='cuda')

        # keep only the (centre, neighbour) pairs selected by the mask
        edge_idx_tsr = torch.cat((center_idx_tsr[flat_mask].view(1, -1), neigh_idx_tsr[flat_mask].view(1, -1)),
                                 dim=0)
        return edge_idx_tsr

    def search_for_neighbor(self, pos, nbr_searcher, masking_fn, type_name):
        """Build or refresh the neighbour list for ``pos`` and return edge indices.

        The first call initialises the list; later calls update the list
        cached under ``type_name``. Requires a JAX GPU device.
        """
        pos_jax = jax.device_put(pos, jax.devices("gpu")[0])

        if not nbr_searcher.has_been_init:
            nbrs = nbr_searcher.init_new_neighbor_lst(pos_jax)
        else:
            nbrs = nbr_searcher.update_neighbor_lst(pos_jax, self.nbr_cache[type_name])
        self.nbr_cache[type_name] = nbrs

        edge_mask_all = masking_fn(pos_jax, nbrs.idx)
        edge_idx_tsr = self.get_edge_idx(nbrs, pos_jax, edge_mask_all)
        return edge_idx_tsr.long()

    def training_step(self, batch, batch_nb):
        """Run one optimisation step and return the regularised loss.

        The per-sample lists in ``batch`` are overwritten in place with
        device tensors.
        """
        feat, pos_lst = batch['feat'], batch['pos']
        gt_lst = batch['forces']
        edge_idx_lst = []
        for b in range(len(gt_lst)):
            pos, gt = pos_lst[b], gt_lst[b]

            if self.rotate_aug:
                # rotate positions about their centroid and forces with the same matrix
                pos = np.mod(pos, BOX_SIZE)
                pos, off = center_positions(pos)
                R = get_rotation_matrix()
                pos = np.matmul(pos, R)
                pos += off
                gt = np.matmul(gt, R)

            pos = np.mod(pos, BOX_SIZE)

            gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
            pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
            gt_lst[b] = gt

            edge_idx_tsr = self.search_for_neighbor(pos,
                                                    self.nbr_searcher,
                                                    self.nbrlst_to_edge_mask,
                                                    'all')
            edge_idx_lst.append(edge_idx_tsr)
        gt = torch.cat(gt_lst, dim=0)
        # jitter the network input; the edges were built from the clean positions
        pos_lst = [pos + torch.randn_like(pos) * 0.005 for pos in pos_lst]

        pred = self.pnet_model(pos_lst,
                               feat,
                               edge_idx_lst,
                               )

        if self.loss_fn == 'mae':
            data_loss = nn.L1Loss()(pred, gt)
        else:
            data_loss = nn.MSELoss()(pred, gt)

        # penalise a non-zero mean of the predicted forces
        conservative_loss = (torch.mean(pred)).abs()
        loss = data_loss + LAMBDA2 * conservative_loss

        self.training_mean = self.train_data_scaler.mean_
        self.training_var = self.train_data_scaler.var_

        self.log('total loss', loss, on_step=True, prog_bar=True, logger=True)
        # log the data term on its own; it previously duplicated the total loss
        self.log(f'{self.loss_fn} loss', data_loss, on_step=True, prog_bar=True, logger=True)
        # the metric is named 'var' but holds the standard deviation
        self.log('var', np.sqrt(self.training_var), on_step=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Adam with a step schedule that decays the LR by 1e-3 over ``epoch_num`` epochs."""
        optim = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sched = StepLR(optim, step_size=5, gamma=0.001**(5/self.epoch_num))
        return [optim], [sched]

    @staticmethod
    def _collate_batches(batches):
        """Merge samples: concatenate 'feat', keep 'pos' and 'forces' as per-sample lists."""
        return {
            'feat': torch.cat([torch.from_numpy(batch['feat']).float() for batch in batches], dim=0),
            'pos': [batch['pos'] for batch in batches],
            'forces': [batch['forces'] for batch in batches],
        }

    def train_dataloader(self):
        """Shuffled loader over the 'train' split of the tip4p dataset."""
        dataset = WaterDataNew(dataset_path=os.path.join(self.data_dir, 'water_data_tip4p'),
                               sample_num=1000,
                               case_prefix='data_',
                               seed_num=10,
                               m_num=NUM_OF_ATOMS//3,
                               mode='train',
                               data_type='tip4p')

        return DataLoader(dataset, num_workers=2, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self._collate_batches)

    def val_dataloader(self):
        """Unshuffled loader over the 'test' split with a fixed batch size of 16."""
        dataset = WaterDataNew(dataset_path=os.path.join(self.data_dir, 'water_data_tip4p'),
                               sample_num=1000,
                               case_prefix='data_',
                               seed_num=10,
                               m_num=NUM_OF_ATOMS//3,
                               mode='test',
                               data_type='tip4p')

        return DataLoader(dataset, num_workers=2, batch_size=16, shuffle=False,
                          collate_fn=self._collate_batches)

    def validation_step(self, batch, batch_nb):
        """Evaluate one validation batch and log outlier ratio, MSE and MAE."""
        with torch.no_grad():

            feat, pos_lst = batch['feat'], batch['pos']
            gt_lst = batch['forces']
            edge_idx_lst = []
            for b in range(len(gt_lst)):
                pos, gt = pos_lst[b], gt_lst[b]
                pos = np.mod(pos, BOX_SIZE)

                # NOTE(review): scale_force calls partial_fit, so validation
                # forces also update the training scaler statistics.
                gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
                pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
                gt_lst[b] = gt

                edge_idx_tsr = self.search_for_neighbor(pos,
                                                        self.nbr_searcher,
                                                        self.nbrlst_to_edge_mask,
                                                        'all')
                edge_idx_lst.append(edge_idx_tsr)
            gt = torch.cat(gt_lst, dim=0)

            pred = self.pnet_model(pos_lst,
                                   feat,
                                   edge_idx_lst,
                                   )
            # per-component relative error; entries above 10 count as outliers
            ratio = torch.sqrt((pred.reshape(-1) - gt.reshape(-1)) ** 2) / (torch.abs(pred.reshape(-1)) + 1e-8)
            outlier_ratio = ratio[ratio > 10.].shape[0] / ratio.shape[0]
            mse = nn.MSELoss()(pred, gt)
            mae = nn.L1Loss()(pred, gt)
            self.log('val outlier', outlier_ratio, prog_bar=True, logger=True)
            self.log('val mse', mse, prog_bar=True, logger=True)
            self.log('val mae', mae, prog_bar=True, logger=True)


class ModelCheckpointAtEpochEnd(pl.Callback):
    """
       Periodically save a model checkpoint together with the scaler statistics.

       Every ``save_step_frequency`` epochs, and on the last epoch, the files
       ``<prefix>_<epoch>.ckpt`` and ``scaler_<epoch>.npz`` are written to
       ``filepath``.
    """
    def __init__(
            self,
            filepath,
            save_step_frequency,
            prefix="checkpoint",
    ):
        """
        Args:
            filepath: directory that receives the checkpoint and scaler files
            save_step_frequency: how often to save, in epochs
            prefix: prefix of the checkpoint file name
        """
        super().__init__()
        self.filepath = filepath
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix

    def on_epoch_end(self, trainer: pl.Trainer, pl_module: ParticleNetLightning):
        """ Save checkpoint and scaler statistics if the current epoch is due """
        epoch = trainer.current_epoch
        due_by_frequency = epoch % self.save_step_frequency == 0
        if not (due_by_frequency or epoch == pl_module.epoch_num - 1):
            return

        # Write both files to self.filepath. Going through
        # trainer.checkpoint_callback.dirpath failed when no checkpoint
        # callback/dirpath was available and, for a relative filepath, put
        # the checkpoint in a different directory than its scaler file.
        ckpt_path = os.path.join(self.filepath, f"{self.prefix}_{epoch}.ckpt")
        scaler_filename = os.path.join(self.filepath, f"scaler_{epoch}.npz")

        trainer.save_checkpoint(ckpt_path)
        np.savez(scaler_filename,
                 mean=pl_module.training_mean,
                 var=pl_module.training_var,
                 )


def train_model(args):
    """Set up the tip4p Lightning module and trainer from ``args`` and run fit."""
    lr, gpu_count, ckpt_subdir = args.lr, args.num_gpu, args.cp_dir
    epoch_lo, epoch_hi = args.min_epoch, args.max_epoch
    init_weights, batch_size = args.state_ckpt_dir, args.batch_size

    # -1 is handed to the Trainer as-is but treated as one device by the module
    device_count = gpu_count if gpu_count != -1 else 1
    model = ParticleNetLightning(
        args=args,
        epoch_num=epoch_hi,
        num_device=device_count,
        learning_rate=lr,
        model_weights_ckpt=init_weights,
        batch_size=batch_size,
    )

    # resolve the checkpoint directory against the current working directory
    ckpt_dir = os.path.join(os.getcwd(), ckpt_subdir)
    os.makedirs(ckpt_dir, exist_ok=True)

    periodic_saver = ModelCheckpointAtEpochEnd(filepath=ckpt_dir, save_step_frequency=5)
    default_saver = pl.callbacks.ModelCheckpoint()

    trainer = Trainer(
        gpus=gpu_count,
        callbacks=[periodic_saver, default_saver],
        min_epochs=epoch_lo,
        max_epochs=epoch_hi,
        amp_backend='apex',
        amp_level='O1',
        benchmark=True,
        distributed_backend='ddp',
    )
    trainer.fit(model)


def main(argv=None):
    """Parse the command-line options and start tip4p training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; an explicit list
            such as ``[]`` lets the function be called from an interactive
            session whose ``sys.argv`` carries unrelated options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--min_epoch', default=30, type=int)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--cp_dir', default='./model_ckpt',
                        help='checkpoint directory, resolved against the working directory')
    parser.add_argument('--state_ckpt_dir', default=None, type=str,
                        help='optional path to model weights to load before training')
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--encoding_size', default=256, type=int)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--edge_embedding_dim', default=256, type=int)
    parser.add_argument('--drop_edge', action='store_true')
    parser.add_argument('--use_layer_norm', action='store_true')
    parser.add_argument('--disable_rotate_aug', dest='rotate_aug', default=True, action='store_false')
    parser.add_argument('--data_dir', default='./md_dataset')
    # reject unsupported losses at the command line instead of deep inside the model
    parser.add_argument('--loss', default='mae', choices=('mae', 'mse'))
    parser.add_argument('--num_gpu', default=-1, type=int,
                        help='value forwarded to the Trainer "gpus" option')
    args = parser.parse_args(argv)
    train_model(args)


# Script entry point: start training only when executed directly, not on import.
if __name__ == '__main__':
    main()

### Train for DFT test-cases

In [None]:
import argparse
import os, sys
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import cupy

sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
from nn_module import WaterMDDynamicBoxNet
from train_utils import WaterDataRealLarge
# os.environ["CUDA_VISIBLE_DEVICES"] = "" # just to test if it works w/o gpu
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Fixed water-box constants are disabled in this script: the cutoff is read
# from the --cutoff option and the box size from each sample's 'box_size'.
# CUTOFF_RADIUS = 9.5
# left_bound = 0.0
# right_bound = 12.4172
# BOX_SIZE = right_bound - left_bound
# NUM_OF_ATOMS = 258 * 3

# CUTOFF_RADIUS = 3.4

LAMBDA1 = 100.  # not referenced by the code visible in this script
LAMBDA2 = 0.5e-2  # weight of the |mean(pred)| penalty added to the training loss


def create_water_bond(total_atom_num):
    """Return index pairs linking the first atom of each triple to the other two.

    Atoms are taken in consecutive groups of three starting at index 0; the
    group starting at ``a`` yields ``[a, a+1]`` and ``[a, a+2]``.
    """
    anchors = range(0, total_atom_num, 3)
    return np.array([pair for a in anchors for pair in ([a, a + 1], [a, a + 2])])


def get_rotation_matrix():
    """ Draw a random rotation matrix used to augment the dataset
        With probability 0.3 each of the three axis angles is an integer
        multiple of pi taken from {-2, -1, 0, 1} * pi, otherwise all angles
        are zero.
        Return:
          3x3 float32 array, the product Rz * Ry * Rx
    """
    rotate = np.random.uniform() < 0.3
    angles = np.random.randint(-2, 2, size=(3,)) * np.pi if rotate else [0., 0., 0.]

    cx, cy, cz = (np.cos(a) for a in angles)
    sx, sy, sz = (np.sin(a) for a in angles)

    about_x = np.array([[1., 0, 0],
                        [0, cx, -sx],
                        [0, sx, cx]], dtype=np.float32)
    about_y = np.array([[cy, 0, sy],
                        [0, 1, 0],
                        [-sy, 0, cy]], dtype=np.float32)
    about_z = np.array([[cz, -sz, 0],
                        [sz, cz, 0],
                        [0, 0, 1]], dtype=np.float32)

    return np.matmul(about_z, np.matmul(about_y, about_x))


def center_positions(pos):
    """Move the mean of ``pos`` to the origin.

    Returns the shifted positions and the mean that was removed, so the shift
    can be undone by adding it back.
    """
    mean_pos = np.mean(pos, axis=0)
    shifted = pos - mean_pos
    return shifted, mean_pos


def build_model(args, ckpt=None):
    """Create the WaterMDDynamicBoxNet force model and optionally load weights.

    Args:
        args: namespace providing ``encoding_size``, ``hidden_dim``,
            ``edge_embedding_dim``, ``conv_layer``, ``drop_edge``,
            ``use_layer_norm``, ``update_edge`` and ``expand_edge``.
        ckpt: optional path to a saved ``state_dict``; ``None`` keeps the
            freshly initialised weights.

    Returns:
        The constructed model.
    """
    param_dict = {'in_feats': 1,
                  'encoding_size': args.encoding_size,
                  'out_feats': 3,
                  'hidden_dim': args.hidden_dim,
                  'edge_embedding_dim': args.edge_embedding_dim,
                  'conv_layer': args.conv_layer,
                  'drop_edge': args.drop_edge,
                  'use_layer_norm': args.use_layer_norm,
                  'update_edge': args.update_edge,
                  }

    print("Using following set of hyper-parameters")
    print(args)

    model = WaterMDDynamicBoxNet(**param_dict, expand_edge=args.expand_edge)

    if ckpt is not None:
        print('Loading model weights from: ', ckpt)
        # Map storages to CPU so weights saved on a GPU also load on a
        # CPU-only host; load_state_dict copies them into the model's own
        # parameters, wherever those live.
        model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    return model


class ParticleNetLightning(pl.LightningModule):
    """Lightning module that trains the force network on the DFT water data.

    It bundles the model returned by ``build_model`` with a running
    ``StandardScaler`` for the target forces. Each sample carries its own
    ``box_size``, which is passed to the network together with the cutoff.
    """
    def __init__(self, args, num_device=1, epoch_num=100, batch_size=1, learning_rate=3e-4, log_freq=1000,
                 model_weights_ckpt=None, scaler_ckpt=None):
        """
        Args:
            args: namespace with the hyper-parameters read by ``build_model``
                plus ``cutoff``, ``rotate_aug``, ``data_dir``, ``loss`` and
                ``use_part``
            num_device: stored on the module; not used elsewhere in this class
            epoch_num: total number of epochs, used by the LR schedule
            batch_size: training batch size (validation uses twice this value)
            learning_rate: initial Adam learning rate
            log_freq: stored on the module; not used elsewhere in this class
            model_weights_ckpt: optional path to model weights to load
            scaler_ckpt: optional ``.npz`` file holding ``mean`` and ``var``
        """
        super(ParticleNetLightning, self).__init__()
        self.pnet_model = build_model(args, model_weights_ckpt)
        self.epoch_num = epoch_num
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_device = num_device
        self.log_freq = log_freq
        self.train_data_scaler = StandardScaler()
        # identity normalisation until real statistics are fitted or loaded
        self.training_mean = np.array([0.])
        self.training_var = np.array([1.])

        if scaler_ckpt is not None:
            self.load_training_stats(scaler_ckpt)

        self.cutoff = args.cutoff
        self.rotate_aug = args.rotate_aug
        self.data_dir = args.data_dir
        self.loss_fn = args.loss
        self.use_part = args.use_part
        assert self.loss_fn in ['mae', 'mse']

    def load_training_stats(self, scaler_ckpt):
        """Load the force mean/variance from an ``.npz`` file (no-op for ``None``)."""
        if scaler_ckpt is not None:
            scaler_info = np.load(scaler_ckpt)
            self.training_mean = scaler_info['mean']
            self.training_var = scaler_info['var']

    def forward(self, pos, feat, edge_idx_tsr):
        """Run the network and map its output back to un-normalised forces.

        NOTE(review): the other methods call ``pnet_model`` with
        ``(pos_lst, feat, box_size_lst, cutoff)``; this three-argument call
        with edge indices looks like a leftover - confirm before relying on it.
        """
        return self.denormalize(self.pnet_model(pos, feat, edge_idx_tsr.long()), self.training_var, self.training_mean)

    def build_graph(self, edge_idx):
        """Delegate to the model's ``build_partial_graph``."""
        return self.pnet_model.build_partial_graph(edge_idx)

    def denormalize(self, normalized_force, var, mean):
        """Invert the standardisation: ``normalized_force * sqrt(var) + mean``."""
        return normalized_force * \
                np.sqrt(var) +\
                mean

    def predict_forces(self, feat: torch.Tensor, pos: np.ndarray, box_size):
        """Predict un-normalised forces for one configuration as a NumPy array."""
        # enforce periodic boundary
        pos = np.mod(pos, box_size)
        pos = torch.from_numpy(pos).float().to(feat.device)
        pred = self.pnet_model([pos],
                               feat,
                               [box_size],
                               self.cutoff
                               )

        pred = pred.detach().cpu().numpy()

        pred = self.denormalize(pred, self.training_var, self.training_mean)

        return pred

    def scale_force(self, force, scaler):
        """Update ``scaler`` with ``force`` and return the standardised forces.

        All components are flattened into one column, so a single mean and
        variance is shared by every force component. The result is a float
        tensor with the input's two-dimensional shape.
        """
        b_pnum, dims = force.shape
        force_flat = force.reshape((-1, 1))
        scaler.partial_fit(force_flat)
        force = torch.from_numpy(scaler.transform(force_flat)).float().view(b_pnum, dims)
        return force

    @staticmethod
    def _collate_batches(batches):
        """Merge samples: concatenate 'feat', keep the other fields as per-sample lists."""
        return {
            'feat': torch.cat([torch.from_numpy(batch['feat']).float() for batch in batches], dim=0),
            'pos': [batch['pos'] for batch in batches],
            'forces': [batch['forces'] for batch in batches],
            'box_size': [batch['box_size'] for batch in batches],
        }

    def training_step(self, batch, batch_nb):
        """Run one optimisation step and return the regularised loss.

        The per-sample lists in ``batch`` are overwritten in place.
        """
        feat, pos_lst, box_size_lst = batch['feat'], batch['pos'], batch['box_size']
        gt_lst = batch['forces']

        for b in range(len(gt_lst)):
            pos, box_size, gt = pos_lst[b], box_size_lst[b], gt_lst[b]
            # honour the rotate_aug flag; it used to be stored but ignored here
            if self.rotate_aug:
                # rotate positions about their centroid; box and forces use the same matrix
                pos, off = center_positions(pos)
                R = get_rotation_matrix()
                pos = np.matmul(pos, R)
                pos += off
                # NOTE(review): R only holds rotations by multiples of pi, so
                # this can flip the sign of box components (assuming box_size
                # is a length-3 vector); np.mod then wraps into a non-positive
                # range - confirm the model expects that.
                box_size = np.matmul(box_size, R)
                box_size_lst[b] = box_size
                gt = np.matmul(gt, R)
            # enforce periodic boundary
            pos = np.mod(pos, box_size)
            gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
            pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
            gt_lst[b] = gt
        gt = torch.cat(gt_lst, dim=0)

        # jitter the network input with small Gaussian noise
        pos_lst = [pos + torch.randn_like(pos) * 0.00025 for pos in pos_lst]
        pred = self.pnet_model(pos_lst,
                               feat,
                               box_size_lst,
                               self.cutoff
                               )

        if self.loss_fn == 'mae':
            data_loss = nn.L1Loss()(pred, gt)
        else:
            data_loss = nn.MSELoss()(pred, gt)

        # penalise a non-zero mean of the predicted forces
        conservative_loss = (torch.mean(pred)).abs()
        loss = data_loss + LAMBDA2*conservative_loss

        self.training_mean = self.train_data_scaler.mean_
        self.training_var = self.train_data_scaler.var_

        self.log('total loss', loss, on_step=True, prog_bar=True, logger=True)
        # log the data term on its own; it previously duplicated the total loss
        self.log(f'{self.loss_fn} loss', data_loss, on_step=True, prog_bar=True, logger=True)
        # the metric is named 'var' but holds the standard deviation
        self.log('var', np.sqrt(self.training_var), on_step=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Adam with a step schedule that decays the LR by 1e-3 over ``epoch_num`` epochs."""
        optim = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sched = StepLR(optim, step_size=100, gamma=0.001**(100/self.epoch_num))
        return [optim], [sched]

    def train_dataloader(self):
        """Shuffled loader over the training part of the RPBE dataset."""
        dataset = WaterDataRealLarge(dataset_path=os.path.join(self.data_dir, 'RPBE-data-processed.npz'), use_part=self.use_part)
        return DataLoader(dataset, num_workers=2, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self._collate_batches)

    def val_dataloader(self):
        """Unshuffled loader over the 'test' split with twice the training batch size."""
        dataset = WaterDataRealLarge(dataset_path=os.path.join(self.data_dir, 'RPBE-data-processed.npz'),
                                     mode='test')
        return DataLoader(dataset, num_workers=2, batch_size=self.batch_size*2, shuffle=False,
                          collate_fn=self._collate_batches)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log MSE and MAE."""
        with torch.no_grad():
            feat, pos_lst, box_size_lst = batch['feat'], batch['pos'], batch['box_size']
            gt_lst = batch['forces']

            for b in range(len(gt_lst)):
                pos, box_size, gt = pos_lst[b], box_size_lst[b], gt_lst[b]

                # enforce periodic boundary
                pos = np.mod(pos, box_size)
                # NOTE(review): scale_force calls partial_fit, so validation
                # forces also update the training scaler statistics.
                gt = self.scale_force(gt, self.train_data_scaler).to(feat.device)
                pos_lst[b] = torch.from_numpy(pos).float().to(feat.device)
                gt_lst[b] = gt
            gt = torch.cat(gt_lst, dim=0)
            pred = self.pnet_model(pos_lst,
                                   feat,
                                   box_size_lst,
                                   self.cutoff
                                   )
            mse = nn.MSELoss()(pred, gt)
            mae = nn.L1Loss()(pred, gt)

            self.training_mean = self.train_data_scaler.mean_
            self.training_var = self.train_data_scaler.var_

            self.log('val mse', mse, prog_bar=True, logger=True)
            self.log('val mae', mae, prog_bar=True, logger=True)


class ModelCheckpointAtEpochEnd(pl.Callback):
    """
       Periodically save a model checkpoint together with the scaler statistics.

       Every ``save_step_frequency`` epochs, and on the last epoch, the files
       ``<prefix>_<epoch>.ckpt`` and ``scaler_<epoch>.npz`` are written to
       ``filepath``.
    """
    def __init__(
            self,
            filepath,
            save_step_frequency,
            prefix="checkpoint",
    ):
        """
        Args:
            filepath: directory that receives the checkpoint and scaler files
            save_step_frequency: how often to save, in epochs
            prefix: prefix of the checkpoint file name
        """
        super().__init__()
        self.filepath = filepath
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix

    def on_epoch_end(self, trainer: pl.Trainer, pl_module: ParticleNetLightning):
        """ Save checkpoint and scaler statistics if the current epoch is due """
        epoch = trainer.current_epoch
        should_save = (epoch % self.save_step_frequency == 0
                       or epoch == pl_module.epoch_num - 1)
        if not should_save:
            return

        # Write both files to self.filepath. Going through
        # trainer.checkpoint_callback.dirpath failed when no checkpoint
        # callback/dirpath was available and, for a relative filepath, put
        # the checkpoint in a different directory than its scaler file.
        ckpt_path = os.path.join(self.filepath, f"{self.prefix}_{epoch}.ckpt")
        scaler_filename = os.path.join(self.filepath, f"scaler_{epoch}.npz")

        trainer.save_checkpoint(ckpt_path)
        np.savez(scaler_filename,
                 mean=pl_module.training_mean,
                 var=pl_module.training_var,
                 )


def train_model(args):
    """Build the Lightning module, callbacks and trainer from `args`, then fit.

    `args` is the namespace produced by the argument parser; the attributes read
    here are lr, num_gpu, cp_dir, min_epoch, max_epoch, state_ckpt_dir and
    batch_size. The whole namespace is also forwarded to the Lightning module.
    """
    learning_rate, gpu_count, ckpt_subdir = args.lr, args.num_gpu, args.cp_dir
    fewest_epochs, most_epochs = args.min_epoch, args.max_epoch
    initial_weights, samples_per_batch = args.state_ckpt_dir, args.batch_size

    # -1 is forwarded to the Trainer unchanged below, but the module is told
    # there is a single device in that case.
    device_count = 1 if gpu_count == -1 else gpu_count
    model = ParticleNetLightning(
        epoch_num=most_epochs,
        num_device=device_count,
        learning_rate=learning_rate,
        model_weights_ckpt=initial_weights,
        batch_size=samples_per_batch,
        args=args,
    )

    # Checkpoints go to <current working directory>/<cp_dir>.
    ckpt_root = os.path.join(os.getcwd(), ckpt_subdir)
    os.makedirs(ckpt_root, exist_ok=True)

    periodic_saver = ModelCheckpointAtEpochEnd(filepath=ckpt_root, save_step_frequency=50)
    default_saver = pl.callbacks.ModelCheckpoint()

    trainer_options = dict(
        gpus=gpu_count,
        callbacks=[periodic_saver, default_saver],
        min_epochs=fewest_epochs,
        max_epochs=most_epochs,
        amp_backend='apex',
        amp_level='O2',
        benchmark=True,
        distributed_backend='ddp',
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)


def main(argv=None):
    """Parse the training options and launch `train_model`.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing an explicit
            list (for example ``[]``) allows calling this function from a
            notebook cell, where the process arguments belong to the kernel.
    """
    parser = argparse.ArgumentParser()

    # training schedule and checkpointing
    parser.add_argument('--min_epoch', default=800, type=int)
    parser.add_argument('--max_epoch', default=800, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--cp_dir', default='./model_ckpt')
    parser.add_argument('--state_ckpt_dir', default=None, type=str)

    # network / graph construction options
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--encoding_size', default=256, type=int)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--edge_embedding_dim', default=256, type=int)
    parser.add_argument('--cutoff', default=9.5, type=float)
    parser.add_argument('--conv_layer', default=5, type=int)
    parser.add_argument('--drop_edge', action='store_true')
    parser.add_argument('--use_layer_norm', action='store_true')
    parser.add_argument('--update_edge', action='store_true')
    # the two --disable_* flags switch off options that are on by default
    parser.add_argument('--disable_expand_edge', dest='expand_edge', default=True, action='store_false')

    parser.add_argument('--disable_rotate_aug', dest='rotate_aug', default=True, action='store_false')
    parser.add_argument('--data_dir', default='./md_dataset')
    parser.add_argument('--use_part', action='store_true')    # use only part of the training data?
    parser.add_argument('--loss', default='mae')
    parser.add_argument('--num_gpu', default=-1, type=int)

    args = parser.parse_args(argv)
    train_model(args)


# Entry point: run training with the command-line arguments when this module
# is executed as the main program.
if __name__ == '__main__':
    main()



### Setup model and dataloaders

In [None]:
def load_model_and_dataset(gamdnet_model_filename, gamdnet_official_model_checkpoint_filename, md_filedir):
    '''
    Load the official GAMD network and an MD dataloader for symbolic regression.

    Args:
        gamdnet_model_filename: checkpoint of the in-house GAMDNet. Currently
            unused because that code path is disabled; kept so existing callers
            do not break.
        gamdnet_official_model_checkpoint_filename: checkpoint file whose
            'state_dict' entry stores the network weights under the
            'pnet_model.' key prefix.
        md_filedir: directory handed to MDDataset.

    Returns:
        (None, gamdnet_official, dataloader). The first slot is the placeholder
        for the disabled in-house model; the official model is returned in eval
        mode.
    '''
    # Disabled in-house model path, kept for reference:
    #   GAMDNet(embed_dim=128, hidden_dim=128, num_mpnn_layers=4 (LJ system, as
    #   per paper), num_mlp_layers=3, num_atom_type_classes=1 (Ar only),
    #   num_edge_types=1 (non-bonded only), num_rbfs=10), weights taken from
    #   torch.load(gamdnet_model_filename)['model_state_dict'].

    train_data_fraction = 1.0 # select 9k for training
    avg_num_neighbors = 20 # criteria for connectivity of atoms for any frame
    rotation_aug = False # online rotation augmentation for a frame
    # create train data-loader
    return_train_data = True
    num_input_files = 1# 40#len(os.listdir(md_filedir))
    batch_size = num_input_files # number of graphs in a batch
    print("Loading input files: ", num_input_files)
    dataset = MDDataset(md_filedir, rotation_aug, avg_num_neighbors, train_data_fraction, return_train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
    print("Dataloader initialized.")

    print("Loading official GAMDNET...")
    param_dict = {
                'encoding_size': 128,
                'out_feats': 3,
                'hidden_dim': 128,
                'edge_embedding_dim': 128,
                'conv_layer': 4,
                'drop_edge': False,
                'use_layer_norm': True,
                'box_size': 27.27,
                }
    gamdnet_official = SimpleMDNetNew(**param_dict).to(device)
    # Map the stored tensors onto the same device the model lives on; the
    # previous hard-coded 'cuda:0' failed on CPU-only hosts.
    checkpoint = torch.load(gamdnet_official_model_checkpoint_filename, map_location=device)

    # The checkpoint keys carry the attribute name of the wrapping module as a
    # prefix; strip it so they match the bare network, and drop every other key.
    prefix_to_remove = 'pnet_model.'
    state_dict_without_prefix = {
        key[len(prefix_to_remove):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix_to_remove)
    }

    gamdnet_official.load_state_dict(state_dict_without_prefix)

    print("GAMD official model weights loaded successfully.")

    gamdnet_official.eval()

    return None, gamdnet_official, dataloader


### Collect edge messages and LJ potential for analysis

In [None]:
def generate_edge_msg_dict(self, gamdnet_official, dataloader):
    """Run one batch through the model and collect edge messages plus LJ targets.

    Only the first batch of `dataloader` is used. After the forward pass the
    edge messages and aggregated messages are read from the last graph-conv
    layer, the analytic pair force/potential is computed for every edge, and
    edges whose analytic force is NaN or numerically zero are dropped.

    Args:
        self: unused; kept because existing callers pass it.
        gamdnet_official: model exposing `graph_conv.conv[-1]` with the recorded
            `edge_message_neigh_center`, `aggregate_edge_messages` and
            `node_embeddings` attributes.
        dataloader: iterable yielding `(pos, edge_index_list, force_gt)`.

    Returns:
        `(msg_force_dict, edge_messages)`: the dictionary with keys
        'edge_messages', 'aggregate_edge_messages', 'radial_distance',
        'force_gt', 'dx', 'dy', 'dz', 'potential_gt', 'net_force_gt', 'pos',
        'node_embeddings', and the filtered edge messages on their own.
        NOTE(review): the original returned two undefined names; this pairing
        (full dict as SR input, edge messages as SR output) is the
        reviewer's reading of the downstream code - confirm.

    Raises:
        ValueError: if `dataloader` yields no batch.
    """
    try:
        pos, edge_index_list, force_gt = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader yielded no batches") from None

    msg_force_dict = {}
    with torch.no_grad():  # Disable gradient calculation for inference
        force_pred_official = gamdnet_official([pos],
                                               [edge_index_list])
        last_conv = gamdnet_official.graph_conv.conv[-1]
        edge_messages = last_conv.edge_message_neigh_center
        msg_force_dict['aggregate_edge_messages'] = last_conv.aggregate_edge_messages

        print("Results from official model:")
        evaluate(force_gt, force_pred_official)

        # analytic per-edge quantities used as regression targets
        lj_force, lj_potential, radial_distance, valid_indices, dx, dy, dz = compute_lj_force_and_potential(pos, edge_index_list)

        # first filter: drop edges whose analytic force is NaN
        lj_force = lj_force[valid_indices]
        edge_messages = edge_messages[valid_indices]

        # second filter: drop edges whose force vector is ~0 in every component
        threshold = 1e-5
        non_zero_rows_mask = ~(torch.abs(lj_force) < threshold).all(dim=1)

        def keep(per_edge):
            """Apply both filters to a tensor with one row per edge."""
            return per_edge[valid_indices][non_zero_rows_mask]

        print("Before zero row removal, ", lj_force.shape, edge_messages.shape)
        msg_force_dict['edge_messages'] = edge_messages[non_zero_rows_mask]
        msg_force_dict['force_gt'] = lj_force[non_zero_rows_mask]
        msg_force_dict['radial_distance'] = keep(radial_distance)
        msg_force_dict['dx'] = keep(dx)
        msg_force_dict['dy'] = keep(dy)
        msg_force_dict['dz'] = keep(dz)
        msg_force_dict['potential_gt'] = keep(lj_potential)
        msg_force_dict['net_force_gt'] = force_gt
        msg_force_dict['pos'] = pos
        msg_force_dict['node_embeddings'] = last_conv.node_embeddings

    return msg_force_dict, msg_force_dict['edge_messages']

In [None]:
def load_sr_inputs_and_outputs(self, sr_config):
    """Return the data for symbolic regression, generated or read from disk.

    The previous body was Julia (PyCall) code pasted under a Python `def`; this
    is the equivalent Python.

    Args:
        self: forwarded to `generate_edge_msg_dict`; otherwise unused.
        sr_config: mapping with key 'use_pregenerated_edge_msgs?'. The value 'n'
            regenerates the edge messages with the model; any other value loads
            the pickled dictionary named below.

    Returns:
        `(msg_force_dict, edge_messages)` on success, where `edge_messages` is
        `msg_force_dict.get('edge_messages')`; `(None, None)` if the pickle
        could not be read (the failure is printed, not raised).
        NOTE(review): the original returned two undefined names; this pairing
        is the reviewer's reading of the downstream code - confirm.
    """
    if sr_config['use_pregenerated_edge_msgs?'] == 'n':
        # NOTE(review): `gamdnet_official` and `dataloader` are read from the
        # enclosing/global scope here - confirm they are defined before the call.
        return generate_edge_msg_dict(self, gamdnet_official, dataloader)

    # else, read edge messages from input file
    import pickle

    # Other runs that have been used here:
    #msg_force_dict_pkl_filename = "/content/msg_force_dict_epoch=29-step=270000_edge_msg_constrained_std_trained_over_9k_samples_our_run_5.pkl"
    #msg_force_dict_pkl_filename = "/content/msg_force_dict_epoch=29-step=270000_edge_msg_constrained_std_trained_over_9k_samples_our_run_5_on_gamd_dataset.pkl"
    #msg_force_dict_pkl_filename = "/content/msg_force_dict_epoch=29-step=135000_edge_msg_constrained_std_trained_over_4.5k_samples_custom_potential.pkl"
    msg_force_dict_pkl_filename = "msg_force_dict_epoch=39-step=360000_edge_msg_constrained_std.pkl"

    # SECURITY: pickle executes arbitrary code while loading; only point this at
    # files produced by your own runs.
    try:
        with open(msg_force_dict_pkl_filename, "rb") as pkl_file:
            msg_force_dict = pickle.load(pkl_file)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as e:
        print("Error loading pickle file: ", e)
        msg_force_dict = None

    if msg_force_dict is None:
        print("Failed to load data.")
        return None, None

    print("Loaded data: ", list(msg_force_dict.keys()))
    return msg_force_dict, msg_force_dict.get('edge_messages')

### Define SR inputs and outputs

In [None]:
def preprocess_sr_input_outputs(sr_inputs, sr_outputs, sr_config):
    """Filter and window-average one edge-message component against distance.

    The previous body was Julia code that ignored its parameters, read a global
    `msg_force_dict`, and returned its inputs unprocessed. This Python port
    performs the same steps on the arguments and returns the processed arrays.

    Steps:
      1. X = radial distance per edge, Y = one component of the edge messages.
      2. Keep only edges with Y >= 0.
      3. For every distinct X value, average Y over all edges whose X lies
         within +/- window_size/2 of it.

    Args:
        sr_inputs: mapping with key "radial_distance" (array or torch tensor,
            one value per edge). If `sr_outputs` is None, "edge_messages" is
            read from this mapping as well.
        sr_outputs: edge messages, either 1-D (already a single component) or
            2-D `[num_edges, emb_dim]`; array or torch tensor.
        sr_config: currently unused.

    Returns:
        `(X, Y)`: 1-D float64 arrays of the distinct distances (ascending) and
        the window-averaged message values.
    """
    def _as_numpy(values):
        """Convert a torch tensor or array-like to a NumPy array."""
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    def _scatter(x, y, label, ylabel, title):
        """Scatter-plot y against x; skipped when matplotlib is missing."""
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("matplotlib is not available; skipping plot: ", title)
            return
        plt.figure()
        plt.scatter(x, y, label=label)
        plt.xlabel("Radial Distance (X)")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.show()

    if sr_outputs is None:
        sr_outputs = sr_inputs["edge_messages"]

    X = _as_numpy(sr_inputs["radial_distance"]).astype(np.float64).ravel()
    edge_messages = _as_numpy(sr_outputs)

    # The Julia source indexed column 3 (1-based), i.e. index 2 here.
    # NOTE(review): the old comment called it "the first column" - confirm which
    # message component is intended.
    message_component = 2
    if edge_messages.ndim == 1:
        Y = edge_messages.astype(np.float64)
    else:
        Y = edge_messages[:, message_component].astype(np.float64)

    # Keep only non-negative message values.
    mask = Y >= 0
    X = X[mask]
    Y = Y[mask]
    if X.size == 0:
        print("No edge messages left after filtering.")
        return X, Y

    std_dev_Y = np.std(Y)
    print("Standard Deviation of Y values: ", std_dev_Y)

    _scatter(X, Y, "Edge messages", "Edge messages (Y)",
             "Edge messages vs Radial Distance")

    # Downsample edge messages (to map it to a function)
    window_size = 0.1  # Adjust this value as needed
    half_window = window_size / 2

    # Sort once, then locate every window with binary search and take its mean
    # from a running sum; the window bounds are inclusive on both sides.
    order = np.argsort(X, kind="stable")
    sorted_X = X[order]
    sorted_Y = Y[order]
    centers = np.unique(sorted_X)
    first = np.searchsorted(sorted_X, centers - half_window, side="left")
    last = np.searchsorted(sorted_X, centers + half_window, side="right")
    running_sum = np.concatenate(([0.0], np.cumsum(sorted_Y)))
    # every window contains at least its own centre, so last - first >= 1
    averaged_Y = (running_sum[last] - running_sum[first]) / (last - first)

    X = centers
    Y = averaged_Y

    print("Downsampled X: ", X)
    print("Downsampled Y: ", Y)

    _scatter(X, Y, "Downsampled Edge messages", "Average Edge messages (Y)",
             "Downsampled Edge messages vs Radial Distance")

    return X, Y


### Test for linearity of fit between pair potential and pair-force vs edge messages

In [None]:
def compute_lj_force_and_potential(pos, edge_index_list, epsilon=0.238, sigma=3.4, box_size=None):
    """Compute the Lennard-Jones pair force and potential for every edge.

    With V(r) = 4*epsilon*((sigma/r)**12 - (sigma/r)**6) the force magnitude is
    -dV/dr = 24*epsilon*(2*(sigma/r)**12 - (sigma/r)**6) / r. The previous
    expression dropped the factor 1/2 on the r**-7 term, which put the zero of
    the force at r = sigma instead of at the potential minimum 2**(1/6)*sigma.

    Args:
        pos: tensor of particle positions, indexed by node id, 3 columns.
        edge_index_list: tensor whose row 0 holds the centre-node index and
            row 1 the neighbour-node index of every edge.
        epsilon: LJ well depth (default is the previously hard-coded value;
            presumably argon in kcal/mol - confirm).
        sigma: LJ length scale (default is the previously hard-coded value;
            presumably argon in Angstrom - confirm).
        box_size: optional periodic box length (scalar or per-axis). When given,
            the separation vector is wrapped with the minimum-image convention;
            when None (default) raw coordinate differences are used as before.

    Returns:
        force_vector: [n, 3], force magnitude times the unit vector of
            (neighbour - centre); positive magnitude means repulsion.
        potential_vector: [n, 3], scalar potential times the same unit vector.
        r: [n, 1] pair distance.
        valid_indices: [n] bool mask, False where the force contains NaN
            (e.g. zero-length edges).
        dx, dy, dz: [n] components of the separation vector.
    """
    center_node_idx = edge_index_list[0, :]
    neigh_node_idx = edge_index_list[1, :]

    # separation vector pointing from the centre node to its neighbour
    r_vec = pos[neigh_node_idx] - pos[center_node_idx]  # Shape: [n, 3]
    if box_size is not None:
        box = torch.as_tensor(box_size, dtype=r_vec.dtype, device=r_vec.device)
        r_vec = r_vec - box * torch.round(r_vec / box)

    r = torch.norm(r_vec, dim=1).unsqueeze(1)  # Shape: [n, 1]
    unit_vec = r_vec / r  # NaN for zero-length edges; masked out below

    sr6 = (sigma / r) ** 6
    sr12 = sr6 ** 2

    potential_magnitude = 4 * epsilon * (sr12 - sr6)  # Shape: [n, 1]
    force_magnitude = 24 * epsilon * (2 * sr12 - sr6) / r  # Shape: [n, 1]

    force_vector = force_magnitude * unit_vec  # Shape: [n, 3]
    # The scalar potential is spread over the unit vector so that it has the
    # same [n, 3] layout as the force.
    potential_vector = potential_magnitude * unit_vec

    valid_indices = ~torch.isnan(force_vector).any(dim=1)

    dx = r_vec[:, 0]
    dy = r_vec[:, 1]
    dz = r_vec[:, 2]

    return force_vector, potential_vector, r, valid_indices, dx, dy, dz

In [None]:
from scipy.optimize import minimize

# Module-level state shared with the objective functions handed to
# scipy.optimize.minimize, which calls them with the parameter vector only.
# The `are_edge_msgs_gt_*_correlated` functions fill these in before fitting.
msg_most_imp = None  # selected edge-message components (one column per component)
expected_forces = None  # per-edge ground-truth pair force
expected_potentials = None  # per-edge ground-truth pair potential

def percentile_sum(x):
    """Sum the values up to the 90th percentile, rescaled by the kept fraction.

    The input is flattened, values above its 90th percentile are discarded, and
    the sum of the remaining values is divided by the fraction of values kept.
    """
    flat = x.ravel()
    lower_bound = flat.min()
    upper_bound = np.percentile(flat, 90)
    keep = (flat >= lower_bound) & (flat <= upper_bound)
    kept_fraction = keep.sum() / len(flat)
    kept_total = flat[keep].sum()
    return kept_total / kept_fraction


def linear_transformation_3d_force(alpha):
    """Objective: misfit between three message components and affine force maps.

    `alpha` holds 12 numbers, four per message component: three weights for the
    x/y/z columns of the global `expected_forces` followed by a bias. The score
    is the mean over the three components of `percentile_sum` applied to the
    squared residual against the matching column of the global `msg_most_imp`.
    """
    global msg_most_imp
    global expected_forces

    fx = expected_forces[:, 0]
    fy = expected_forces[:, 1]
    fz = expected_forces[:, 2]

    def affine(start):
        # weights alpha[start:start+3], bias alpha[start+3]
        return (alpha[start] * fx + alpha[start + 1] * fy + alpha[start + 2] * fz) + alpha[start + 3]

    trimmed = [
        percentile_sum(np.square(msg_most_imp[:, component] - affine(4 * component)))
        for component in range(3)
    ]
    score = (trimmed[0] + trimmed[1] + trimmed[2]) / 3.0

    # Disabled alternative objective (mean absolute error), kept for reference:
    #   score = np.mean([np.abs(msg_most_imp[:, 0] - lincomb1) +
    #       np.abs(msg_most_imp[:, 1] - lincomb2) +
    #       np.abs(msg_most_imp[:, 2] - lincomb3)]) / 3.0
    #   print("Alpha now is: ", alpha)
    #   print("Score now is: ", score)
    return score


def out_linear_transformation_3d_force(alpha):
    """Evaluate the three affine force maps described by `alpha`.

    Uses the same parameter layout as `linear_transformation_3d_force` (three
    weights and one bias per component) and returns the three fitted columns as
    a tuple, computed from the global `expected_forces`.
    """
    global msg_most_imp
    global expected_forces

    fx = expected_forces[:, 0]
    fy = expected_forces[:, 1]
    fz = expected_forces[:, 2]

    fitted = []
    for start in (0, 4, 8):
        w_x, w_y, w_z, bias = alpha[start], alpha[start + 1], alpha[start + 2], alpha[start + 3]
        fitted.append((w_x * fx + w_y * fy + w_z * fz) + bias)
    return fitted[0], fitted[1], fitted[2]


def linear_transformation_3d_potential(alpha):
    """Objective: misfit between three message components and affine potential maps.

    `alpha` holds 12 numbers, four per message component: three weights for the
    three columns of the global `expected_potentials` followed by a bias. The
    score is the mean over the three components of `percentile_sum` applied to
    the squared residual against the matching column of the global
    `msg_most_imp`.
    """
    global msg_most_imp
    global expected_potentials

    ux = expected_potentials[:, 0]
    uy = expected_potentials[:, 1]
    uz = expected_potentials[:, 2]

    def affine(start):
        # weights alpha[start:start+3], bias alpha[start+3]
        return (alpha[start] * ux + alpha[start + 1] * uy + alpha[start + 2] * uz) + alpha[start + 3]

    trimmed = [
        percentile_sum(np.square(msg_most_imp[:, component] - affine(4 * component)))
        for component in range(3)
    ]
    score = (trimmed[0] + trimmed[1] + trimmed[2]) / 3.0

    # Disabled alternative objective (mean absolute error), kept for reference:
    #   score = np.mean([np.abs(msg_most_imp[:, 0] - lincomb1) + np.abs(msg_most_imp[:, 1] - lincomb2) + np.abs(msg_most_imp[:, 2] - lincomb3)]) / 3.0
    #   print("Alpha now is: ", alpha)
    #   print("Score now is: ", score)
    return score


def out_linear_transformation_3d_potential(alpha):
    """Evaluate the three affine potential maps described by `alpha`.

    Uses the same parameter layout as `linear_transformation_3d_potential`
    (three weights and one bias per component), prints `alpha`, and returns the
    three fitted columns as a tuple, computed from the global
    `expected_potentials`.
    """
    global msg_most_imp
    global expected_potentials

    ux = expected_potentials[:, 0]
    uy = expected_potentials[:, 1]
    uz = expected_potentials[:, 2]

    fitted = []
    for start in (0, 4, 8):
        w_x, w_y, w_z, bias = alpha[start], alpha[start + 1], alpha[start + 2], alpha[start + 3]
        fitted.append((w_x * ux + w_y * uy + w_z * uz) + bias)

    print("alphas: ", alpha)
    return fitted[0], fitted[1], fitted[2]



def are_edge_msgs_gt_force_correlated(msg_force_dict):
    '''
    Fit the three highest-spread edge-message components as affine functions of
    the ground-truth pair force and plot fitted vs. actual values.

    msg_force_dict: reads 'edge_messages' [total_edges, emb_dim] and
                    'force_gt' [total_edges, 3]

    Returns (are_correlated, msg_most_imp). `are_correlated` is currently always
    False - no decision threshold is applied to the fit score, which is only
    printed.
    '''
    global msg_most_imp
    global expected_forces

    edge_messages = msg_force_dict['edge_messages']
    pair_forces = msg_force_dict['force_gt']

    print("edge message shape: ", edge_messages.shape)
    print("force shape: ", pair_forces.shape)

    # Standard deviation of every message component across all edges
    per_component_std = torch.std(edge_messages, axis=0)

    # argsort is ascending, so the last three entries are the largest spreads
    top_std_indices = torch.argsort(per_component_std)[-3:]

    # Publish the fit data through the module globals read by the objective
    msg_most_imp = edge_messages[:, top_std_indices].cpu()
    expected_forces = pair_forces.cpu()

    dim = 3
    start_point = np.ones(dim**2 + dim)  # 3 weights + 1 bias per component
    min_result = minimize(linear_transformation_3d_force, start_point, method='Powell')

    print("Fit score: ", min_result.fun/edge_messages.shape[0])

    # Visualize the fit: one scatter plot per selected component
    fitted_columns = out_linear_transformation_3d_force(min_result.x)
    for i in range(dim):
        plt.scatter(fitted_columns[i], msg_most_imp[:, i])
        plt.show()

    are_correlated = False
    return are_correlated, msg_most_imp



def are_edge_msgs_gt_potential_correlated(msg_force_dict):
    '''
    Fit the three highest-spread edge-message components as affine functions of
    the ground-truth pair potential and plot fitted vs. actual values.

    msg_force_dict: reads 'edge_messages' [total_edges, emb_dim] and
                    'potential_gt' [total_edges, 3]

    Returns (are_correlated, msg_most_imp).
    TODO: `are_correlated` is always False - no decision threshold is applied to
    the fit score yet; the score is only printed.
    '''
    global msg_most_imp
    global expected_potentials

    print("edge message shape: ", msg_force_dict['edge_messages'].shape)
    # this label used to read "force shape" although the potential is printed
    print("potential shape: ", msg_force_dict['potential_gt'].shape)

    # Standard deviation of every message component across all edges
    msg_comp_std = torch.std(msg_force_dict['edge_messages'], axis=0)

    # argsort is ascending, so the last three entries are the largest spreads
    top_std_indices = torch.argsort(msg_comp_std)[-3:]

    # Publish the fit data through the module globals read by the objective
    msg_most_imp = msg_force_dict['edge_messages'][:, top_std_indices].cpu()
    expected_potentials = msg_force_dict['potential_gt'].cpu()

    dim = 3
    # 3 weights + 1 bias per message component
    min_result = minimize(linear_transformation_3d_potential, np.ones(dim**2 + dim), method='Powell')

    print("Fit score: ", min_result.fun/msg_force_dict['edge_messages'].shape[0])

    # Visualize the fit. The fitted columns do not depend on the loop index, so
    # they are evaluated once instead of once per plot.
    fitted_columns = out_linear_transformation_3d_potential(min_result.x)
    for i in range(dim):
        plt.scatter(fitted_columns[i], msg_most_imp[:, i])
        plt.show()

    are_correlated = False
    return are_correlated, msg_most_imp


### Train SR

In [None]:
# Fit a symbolic-regression model to (X, Y) and return the fitted machine.
#
# NOTE(review): the header is a Python `def`, but the body is Julia
# (SymbolicRegression.jl driven through the MLJ interface), so this cell does
# not run as Python as written.
# NOTE(review): `sr_inputs`, `sr_outputs` and `sr_config` are not used; `X` and
# `Y` are read from the enclosing scope - confirm where they are defined.
def regress_edge_message_equations(sr_inputs, sr_outputs, sr_config):
  using SymbolicRegression
  # Template combinator: (attr_func(rad))^-12 minus (rep_func(rad))^-6, evaluated
  # element-wise; the result is flagged valid only if both sub-expressions are.
  # NOTE(review): in the usual Lennard-Jones form the r^-12 term is the
  # repulsive one, so the attr/rep names look swapped - confirm.
  function lj_potential_structure((; attr_func, rep_func), (rad, ))
    _attr_func = attr_func(rad)^-12
    _rep_func = rep_func(rad)^-6

    out = map((attr_func_i, rep_func_i) -> (attr_func_i - rep_func_i), _attr_func.x, _rep_func.x)
    return ValidVector(out, _attr_func.valid && _rep_func.valid)
  end
  # Only referenced by the commented-out `expression_options` line below.
  lj_structure = TemplateStructure{(:attr_func, :rep_func)}(lj_potential_structure)

  # Per-sample loss: absolute error.
  elementwise_loss = ((x1), (y1)) -> abs(y1 - x1)

  using MLJBase

  # Search over expressions built from * and / only, up to size 25; the template
  # expression type is currently switched off (see commented lines).
  model = SRRegressor(;
      niterations=10000,
      selection_method=SymbolicRegression.MLJInterfaceModule.choose_best,
      binary_operators=(*, /),
      maxsize=25,
      elementwise_loss=elementwise_loss,
      #expression_type=TemplateExpression,
      # Note - this is where we pass custom options to the expression type:
      #expression_options=(; structure = lj_structure),
      batching=true,
  )


  # Bind the model to the data and run the search.
  mach = machine(model, X, Y)
  fit!(mach)
  return mach

### Plot pred vs gt curve for best SR equation

In [None]:
# Print the best equation found by the fitted machine and save a scatter plot
# of ground truth against prediction.
#
# NOTE(review): the header is a Python `def`, but `report`, `predict`, `plot`
# with `seriestype = :scatter`, and `savefig` are Julia (MLJ / Plots) calls, so
# this cell does not run as Python as written.
# NOTE(review): `X` and `Y` are read from the enclosing scope, not passed in.
def plot_pred_vs_gt_sr(mach):
  r = report(mach)
  idx = r.best_idx
  best_expr = r.equations[idx]
  print("Best equation: ", best_expr)
  # NOTE(review): `attr_func` / `rep_func` are the sub-expression names of the
  # LJ template; presumably they only exist when the regressor was built with
  # the template expression type enabled - confirm.
  best_attr = get_contents(best_expr).attr_func
  best_rep = get_contents(best_expr).rep_func

  print("\nAttr term: ", best_attr)
  print("\nRep term: ", best_rep)

  y_pred = predict(mach, X)
  # Scatter the ground truth Y (y-axis) against the prediction (x-axis)
  plot(y_pred, Y, seriestype = :scatter, label = "GT vs Pred (x-axis)", xlabel = "Pred", ylabel = "GT", title = "GT vs Pred", legend = true)
  savefig("pred_vs_gt_without_temp_exp.png")

### Algorithm alignment of GNNs ?

### Benefits of algo. alignment ?

### GAMD-SR class for setting up the E2E process

In [None]:
class GAMDSR(object):
  """End-to-end driver: set up data and GAMD model, run SR, report results.

  NOTE(review): this class is still a skeleton. This revision fixes the
  unambiguous defects (an indentation error, misspelt names, a missing `self.`,
  a call with the wrong number of arguments) but the following names are
  referenced without being defined in this class and must be wired up before it
  can run: `self.dataset_config`, `gamd_model_config`, `sr_config`,
  `model_checkpoint` (in `run`), `iid_dataset` / `ood_dataset` / `output_dir`
  (in `evaluate`), the `generate_*_dataset` helpers, the dataset/model classes,
  `are_edge_msg_pair_potential_correlated`, `are_edge_msg_pair_force_correlated`
  and `evaluate_metric`.
  """

  def __init__(self):
    pass

  def run(self, question_dict):
    """Build the pipeline for one question and return its checkpoints.

    Args:
      question_dict: configuration of the question. NOTE(review): not read yet;
        the settings below come from `self.dataset_config` instead.

    Returns:
      (model_checkpoint, sr_checkpoint)
    """
    # parse the question dict to setup pipline to answer that question.
    # setup dataset, dataloader and GAMD model
    dataset_type = self.dataset_config.dataset_type
    generate_dataset = self.dataset_config.generate_dataset
    dataset_dir = self.dataset_config.dataset_dir
    openmm_config = self.dataset_config.openmm_config

    # Every branch either generates a fresh dataset or reuses `dataset_dir`,
    # then picks the dataset class and model matching the system.
    # NOTE(review): the comparison is case-sensitive ('lj', 'tip3p', 'tip4p');
    # any other value falls through to the DFT branch.
    if dataset_type == 'lj':
      raw_dataset_dir = generate_lj_dataset(openmm_config) if generate_dataset else dataset_dir
      dataset = LJDataset(raw_dataset_dir)
      dataloader = DataLoader(dataset)
      model = SimpleMDNet(gamd_model_config)
      sr_inputs = dict(radial_dist, dx, dy, dz) # TODO
    elif dataset_type == 'tip3p':
      raw_dataset_dir = generate_tip3p_dataset(openmm_config) if generate_dataset else dataset_dir
      # NOTE(review): the dataset class visible at the top of the file is named
      # WaterDataNew - confirm which class is meant here and for tip4p.
      dataset = WaterDatasetNew(raw_dataset_dir)
      dataloader = DataLoader(dataset)
      model = WaterMDNet(gamd_model_config)
      sr_inputs = dict(radial_dist, dx, dy, dz, theta, charge1, charge2) # TODO
    elif dataset_type == 'tip4p':
      raw_dataset_dir = generate_tip4p_dataset(openmm_config) if generate_dataset else dataset_dir
      dataset = WaterDatasetNew(raw_dataset_dir)
      dataloader = DataLoader(dataset)
      model = WaterMDNet(gamd_model_config)
      sr_inputs = dict(radial_dist, dx, dy, dz, theta, charge1, charge2) # TODO
    else:
      raw_dataset_dir = generate_dft_dataset(openmm_config) if generate_dataset else dataset_dir
      dataset = WaterDatasetLarge(raw_dataset_dir)
      dataloader = DataLoader(dataset)
      model = WaterMDNetDynamic(gamd_model_config)
      sr_inputs = dict(radial_dist, dx, dy, dz, theta, charge1, charge2) # TODO

    # was `gamd_model_config.model`, inconsistent with the `.mode` check below
    if gamd_model_config.mode == 'train':
      model.fit()
    elif gamd_model_config.mode == 'eval':
      model.load_state_dict(gamd_model_config.checkpoint_dir)

    # setup SR model
    # `load_sr_inputs_and_outputs` is declared as (self, sr_config)
    sr_inputs, sr_outputs = load_sr_inputs_and_outputs(self, sr_config)
    sr_inputs_preprocessed, sr_outputs_preprocessed = preprocess_sr_input_outputs(sr_inputs, sr_outputs, sr_config)

    # check for linearity hypothesis
    fit_score_potential = are_edge_msg_pair_potential_correlated(sr_inputs, dataset_type)
    fit_score_force = are_edge_msg_pair_force_correlated(sr_inputs, dataset_type)

    print("Fit score for pair potential is: ", fit_score_potential)
    print("Fit score for pair force is: ", fit_score_force)

    sr_checkpoint = regress_edge_message_equations(sr_inputs_preprocessed, sr_outputs_preprocessed, sr_config)

    return model_checkpoint, sr_checkpoint

  def compute_answer(self, eval_config_dict, model_checkpoint, sr_checkpoint, iid_dataset,
                     ood_dataset_config, output_dir):
    """Optionally evaluate the model on OOD data, then plot the SR fit.

    The evaluation runs only when `eval_config_dict['evaluate?']` is 'y'; a
    missing key is treated as "do not evaluate".
    """
    if eval_config_dict.get('evaluate?') == 'y':
      evaluate_metric(model_checkpoint, ood_dataset_config)  # report generalization performance

    # report SR performance
    plot_pred_vs_gt_sr(sr_checkpoint)

  def evaluate(self, questions_dict):
    """Run the pipeline and the reporting step for every question in turn."""
    for question in questions_dict:
      model_checkpoint, sr_checkpoint = self.run(question)
      self.compute_answer(question['eval_config'], model_checkpoint, sr_checkpoint, iid_dataset, ood_dataset, output_dir)

In [None]:
# Configuration for question 1: dataset, GAMD model, symbolic regression and
# evaluation settings. Keys ending in '?' are 'y'/'n' switches.
question1_dict = {
    'dataset_config': {
        'dataset_type': 'LJ',
        'generate_dataset?': 'n',
        'pregenerated_dataset_dirs': {
            'train_val_dir': 'top',
            'ood_data_dir': 'top_ood',
        },
        'openmm_config': {
            'train_val_config': {
                'num_particles': 258,
                'particle_type': 'Ar',
                'num_iters': 1000,
                'steps_per_iter': 50,
                'reduced_density': 0.05,
                'box_size': 27.27,
                'rad_cutoff': 10.2,
                'custom_potential_expression': 'n',
                'output_dir': 'top/run_1',
            },
            'ood_config': {
                'num_particles': 1000,
                'particle_type': 'Ne',
                'num_iters': 1000,
                'steps_per_iter': 50,
                'reduced_density': 0.5,
                'rad_cutoff': 15.2,
                'custom_potential_expression': 'n',
                'output_dir': 'top_ood/run_1',
            },
        },
        'export_movie?': 'y',
        'movie_simulation_dir': 'top/run_1',
        'movie_sim_frame_range': [500, 999],
        'movie_output_dir': 'sim_movies/run_1/',
    },
    'GAMD_config': {
        'train_model?': 'n',
        'pretrained_model_checkpoint_dir': 'gamd_model_ckpt_dir/version_1/',
        'model_param_config': {
            'in_feats': 1,
            'encoding_size': 128,
            'out_feats': 3,
            'box_size': 27.27,
            'bond': None,
            'hidden_dim': 128,
            'conv_layer': 4,
            'edge_embedding_dim': 128,
            'dropout': 0.1,
            'drop_edge': True,
            'use_layer_norm': False,
        },
        'model_hyperparam_config': {
            'loss_fn': 'constrained_edge_std',
            'lr': 0.00001,
            'reg_factor': 0.001,
            'n_epochs': 30,
            'batch_size': 1,
        },
    },
    'SR_config': {
        'input_config': {
            'preprocessing_pipeline': ['avg_filter'],
        },
        'sr_regressor_config': {
            'n_iter': 1000,
            'unary_operators': [],
            # operators are given by name; the bare `(*, /)` was not valid Python
            'binary_operators': ('*', '/'),
            'custom_operators': [],
            'maxsize': 25,
            'batching': 'True',
            'template_expression': 'TemplateExpression',
            'custom_objective': '((x1), (y1)) -> abs(y1 - x1)',
            'model_selection_criteria': 'SymbolicRegression.MLJInterfaceModule.choose_best',
        },
        'sr_model_checkpoint_dir': 'sr_model_ckpt_dir/run_1',
    },
    'eval_config': {
        # key was misspelt 'evalute'; the evaluation code looks up 'evaluate?'
        'evaluate?': 'n',
        'metric': 'MAE',
    },
}

### Q1. Do the edge messages of GAMD, trained on the original dataset with the constrained message-std loss, correlate with the LJ potential function under unguided SR?

In [None]:
# Instantiate the pipeline driver and evaluate it on question 1 only.
gamd_sr = GAMDSR()
gamd_sr.evaluate([question1_dict])

### Q2. Do the edge messages of GAMD, trained on the run_5 dataset with the constrained message-std loss, correlate with the LJ potential function under unguided SR?

In [None]:
# Experiment configuration for Q2: a pretrained GAMD model (no retraining),
# the pregenerated 'run_5' LJ dataset, the 'constrained_edge_std' loss, and
# symbolic regression restricted to the '*' and '/' binary operators.
#
# Flag-like settings are the strings 'y' / 'n', not booleans; they are kept
# exactly as written so the consuming code sees the same values.
question2_dict = {
    'dataset_config': {
        'dataset_type': 'LJ',
        # 'n': use the pregenerated directories below instead of generating.
        'generate_dataset?': 'n',
        'pregenerated_dataset_dirs': {
            'train_val_dir': 'run_5',
            'ood_data_dir': 'our_runs_ood',
        },
        # NOTE(review): presumably only consulted when 'generate_dataset?'
        # is 'y' -- confirm against the GAMDSR implementation.
        'openmm_config': {
            'train_val_config': {
                'num_particles': 258,
                'particle_type': 'Ar',
                'num_iters': 1000,
                'steps_per_iter': 50,
                'reduced_density': 0.05,
                'box_size': 27.27,
                'rad_cutoff': 10.2,
                'custom_potential_expression': 'n',
                'output_dir': 'top/run_1',
            },
            'ood_config': {
                'num_particles': 1000,
                'particle_type': 'Ne',
                'num_iters': 1000,
                'steps_per_iter': 50,
                'reduced_density': 0.5,
                'rad_cutoff': 15.2,
                'custom_potential_expression': 'n',
                'output_dir': 'top_ood/run_1',
            },
        },
        'export_movie?': 'y',
        'movie_simulation_dir': 'top/run_1',
        'movie_sim_frame_range': [500, 999],
        'movie_output_dir': 'sim_movies/run_1/',
    },
    'GAMD_config': {
        # 'n': load the checkpoint below rather than training from scratch.
        'train_model?': 'n',
        'pretrained_model_checkpoint_dir': 'gamd_model_ckpt_dir/version_7/',
        'model_param_config': {
            'in_feats': 1,
            'encoding_size': 128,
            'out_feats': 3,
            # Same value as train_val_config['box_size'] above.
            'box_size': 27.27,
            'bond': None,
            'hidden_dim': 128,
            'conv_layer': 4,
            'edge_embedding_dim': 128,
            'dropout': 0.1,
            'drop_edge': True,
            'use_layer_norm': False,
        },
        'model_hyperparam_config': {
            'loss_fn': 'constrained_edge_std',
            'lr': 0.00001,
            'reg_factor': 0.001,
            'n_epochs': 30,
            'batch_size': 1,
        },
    },
    'SR_config': {
        'input_config': {
            'preprocessing_pipeline': ['avg_filter'],
        },
        'sr_regressor_config': {
            'n_iter': 1000,
            'unary_operators': [],
            # Operators are given as strings; the original bare `(*, /)` was
            # not valid Python. A list is used to match 'unary_operators'.
            'binary_operators': ['*', '/'],
            # The original left this value blank (a syntax error).
            # NOTE(review): None is assumed to mean "no custom operators" --
            # confirm the type the GAMDSR code expects here.
            'custom_operators': None,
            'maxsize': 25,
            # NOTE(review): this is the string 'True', not the boolean --
            # kept as in the original; verify the consumer expects a string.
            'batching': 'True',
            'template_expression': 'TemplateExpression',
            'custom_objective': '((x1), (y1)) -> abs(y1 - x1)',
            'model_selection_criteria': 'SymbolicRegression.MLJInterfaceModule.choose_best',
        },
        'sr_model_checkpoint_dir': 'sr_model_ckpt_dir/run_1',
    },
    'eval_config': {
        'metric': 'MAE',
    },
}