In [None]:
import py3Dmol

In [None]:
import numpy as np
import tensorflow as tf

# Start from an empty default graph each time this cell runs.
# NOTE(review): get_system() below creates named variables with
# tf.get_variable; presumably this reset is what lets the cell be re-run in
# the same kernel without "variable already exists" errors — confirm.
tf.reset_default_graph()

from timemachine.functionals import bonded
from timemachine import integrator
import xmltodict


def get_box_and_conf(path='/Users/hessian/Code/timemachine/examples/water/state.xml'):
    """Load the periodic box lengths and particle coordinates from a state XML file.

    Args:
        path: Location of the state XML file. The default is the path this
            notebook originally hard-coded, so zero-argument calls behave as
            before; pass another path to run on a different machine.

    Returns:
        A tuple ``(box, conf)`` of numpy float64 arrays. ``box`` holds the three
        values ``A/@x``, ``B/@y`` and ``C/@z`` read from ``PeriodicBoxVectors``.
        ``conf`` holds one ``(x, y, z)`` row per ``Position`` element.
    """
    with open(path) as fd:
        # Without force_list, xmltodict returns a bare dict (not a list) for a
        # lone child element, and the loop below would then iterate over the
        # dict's keys instead of over positions.
        doc = xmltodict.parse(fd.read(), force_list=('Position',))

    state = doc['State']

    # Only the diagonal component of each box vector is read.
    # NOTE(review): this ignores any off-diagonal components, i.e. it assumes
    # a rectangular box — confirm for the input files in use.
    vectors = state['PeriodicBoxVectors']
    box = np.array([
        np.float64(vectors['A']['@x']),
        np.float64(vectors['B']['@y']),
        np.float64(vectors['C']['@z']),
    ])

    geom = [
        (np.float64(elem['@x']), np.float64(elem['@y']), np.float64(elem['@z']))
        for elem in state['Positions']['Position']
    ]
    return box, np.array(geom)

def get_system(path='/Users/hessian/Code/timemachine/examples/water/system.xml'):
    """Read particle masses and build the harmonic bond/angle terms from a system XML file.

    Args:
        path: Location of the system XML file. The default is the path this
            notebook originally hard-coded, so zero-argument calls behave as
            before; pass another path to run on a different machine.

    Returns:
        A tuple ``(masses, hb, ha)``: ``masses`` is a numpy float64 array with
        one entry per ``Particle`` element, ``hb`` is a ``bonded.HarmonicBond``
        and ``ha`` is a ``bonded.HarmonicAngle``.

    Side effects:
        Creates four scalar float64 variables in the current default graph,
        named ``bond_k``, ``bond_d``, ``angle_k`` and ``angle_theta``.
    """
    with open(path) as fd:
        # Without force_list, xmltodict returns a bare dict (not a list) when
        # an element has a single child of a given name; the loops below would
        # then iterate over dict keys. Force every tag we iterate to be a list.
        doc = xmltodict.parse(
            fd.read(),
            force_list=('Particle', 'Force', 'Bond', 'Angle'),
        )

    # Named `system` rather than `sys` to avoid shadowing the stdlib module name.
    system = doc['System']

    masses = np.array(
        [np.float64(p['@mass']) for p in system['Particles']['Particle']]
    )

    def _scalar_variable(name, value):
        # One scalar float64 graph variable initialised to `value`.
        return tf.get_variable(
            name=name,
            dtype=tf.float64,
            shape=tuple(),
            initializer=tf.constant_initializer(value),
        )

    # A single (k, d) pair is created and shared by every bond. An earlier,
    # commented-out version of this notebook initialised bond_k to 462750.4.
    bond_params = [
        _scalar_variable('bond_k', 100.4),
        _scalar_variable('bond_d', 0.09572),
    ]

    # A single (k, theta) pair is created and shared by every angle. An earlier,
    # commented-out version of this notebook initialised angle_k to 836.8.
    # NOTE(review): 1.82421813418 is presumably in radians (~104.52 deg) — confirm.
    angle_params = [
        _scalar_variable('angle_k', 100.8),
        _scalar_variable('angle_theta', 1.82421813418),
    ]

    bond_idxs = []
    bond_param_idxs = []
    angle_idxs = []
    angle_param_idxs = []

    for force in system['Forces']['Force']:
        force_type = force['@type']
        if force_type == 'HarmonicBondForce':
            for b in force['Bonds']['Bond']:
                bond_idxs.append((np.int64(b['@p1']), np.int64(b['@p2'])))
                # Every bond gets the same pair (0, 1); presumably these are
                # positions in bond_params — verify against bonded.HarmonicBond.
                bond_param_idxs.append((0, 1))
        elif force_type == 'HarmonicAngleForce':
            for a in force['Angles']['Angle']:
                angle_idxs.append(
                    (np.int64(a['@p1']), np.int64(a['@p2']), np.int64(a['@p3']))
                )
                # Every angle gets the same pair (0, 1); presumably these are
                # positions in angle_params — verify against bonded.HarmonicAngle.
                angle_param_idxs.append((0, 1))
        # Any other force type in the file is ignored.

    hb = bonded.HarmonicBond(
        bond_params, np.array(bond_idxs), np.array(bond_param_idxs))
    ha = bonded.HarmonicAngle(
        angle_params, np.array(angle_idxs), np.array(angle_param_idxs))

    return masses, hb, ha

def make_xyz(masses, coords):
    """Format one configuration as a single XYZ-style text frame.

    Args:
        masses: Indexable per-atom masses; an atom whose mass is greater than
            2 is written as "O", any other atom as "H".
        coords: Array with a ``shape`` attribute and one row per atom; the
            first three entries of each row are written out.

    Returns:
        A string consisting of the atom count, an empty comment line, and one
        ``"<element> <x> <y> <z>"`` line per atom, each terminated by a newline.
        Every coordinate is multiplied by 10 before being written
        (NOTE(review): presumably a nm -> Angstrom conversion — confirm).
    """
    atom_count = coords.shape[0]

    def _atom_line(i):
        # Element is inferred purely from the mass threshold.
        symbol = "O" if masses[i] > 2 else "H"
        row = coords[i]
        scaled = (row[0]*10, row[1]*10, row[2]*10)
        return symbol + " " + " ".join(str(v) for v in scaled) + "\n"

    header = str(atom_count) + "\n" + "\n"
    return header + "".join(_atom_line(i) for i in range(atom_count))
    

# Load the starting configuration and build the bonded energy terms.
# NOTE(review): `box` is not referenced again in the visible cells.
box, x0 = get_box_and_conf()
masses, hb, ha = get_system()

num_atoms = x0.shape[0]

# Placeholder through which the current coordinates are fed on every step.
x_ph = tf.placeholder(name="x", shape=(num_atoms, 3), dtype=tf.float64)

# Integrator settings.
# NOTE(review): units are not stated anywhere in this notebook — confirm
# against timemachine.integrator.LangevinIntegrator.
friction = 10.0
dt = 0.03
temp = 100

# NOTE(review): the third positional argument is passed as None; its meaning
# is defined by LangevinIntegrator and is not visible here.
intg = integrator.LangevinIntegrator(masses, x_ph, None, [hb, ha], dt, friction, temp)
# step_op returns two ops; only dx_op is evaluated in the visible cells, where
# its value is added to the coordinates after each step.
dx_op, dxdp_op = intg.step_op(inference=True)

# NOTE(review): num_steps is not referenced in the visible cells — the
# simulation loop in the next cell hard-codes range(30000).
num_steps = 500000

# The session is created once and reused by the next cell; it is never closed
# in the visible cells.
sess = tf.Session()
sess.run(tf.initializers.global_variables())

In [None]:
# Work on a copy so the in-place updates below leave x0 untouched.
x = x0.copy()
# Accumulates the recorded frames as one concatenated XYZ string.
all_xyz = ""
for step in range(30000):
#     print(step)
    # Evaluate dx_op at the current coordinates and apply the result in place.
    dx_val = sess.run(dx_op, feed_dict={x_ph: x})
    x += dx_val
    # Record every 100th step (300 frames over the 30000 steps).
    if step % 100 == 0:
        all_xyz += make_xyz(masses, x)


# NOTE(review): lastModel is assigned but not used in the visible cells.
lastModel = None

# Load the recorded frames into a 400x400 py3Dmol viewer, draw atoms as
# spheres on a light grey background, and play the frames as an animation.
p = py3Dmol.view(width=400,height=400)
p.addModelsAsFrames(all_xyz,'xyz')
p.setStyle({'sphere':{}})
p.setBackgroundColor('0xeeeeee')
p.animate()
p.zoomTo()
p.show()