In [15]:
import jax.numpy as jnp
from jax import grad, jit
import MDAnalysis as mda

@jit
def distance(positions):
    dist = jnp.linalg.norm(positions[0]-positions[1])
    return dist

grad_distance = grad(distance)
dummy_box = jnp.zeros((3,3))    

### pdb file inspection

In [36]:
file = "/Users/sss/Documents/EnergyGap_project/positions_data/4dvd/4dvd_plumed.pdb"
u = mda.Universe(file)
#Here printed the alpha C positions and the shifted positions. 
#To be compared with the first few plumed printed positions reading from the same file.

# plumed_ca_positions = jnp.array([
#     [-4.02000008e+00,  1.45340004e+00,  1.77740002e+00],
#     [-4.15909996e+00,  1.77970009e+00,  1.68409996e+00],
#     [-4.24570007e+00,  1.77240009e+00,  1.41020002e+00],
#     [-4.30859985e+00,  1.50799999e+00,  1.35129995e+00]
# ])



#Checking the numbe rof trajectories from which the positions can come 
# frames=jnp.array([t.frame for t in u.trajectory])
# print(f'The file has {frames+1} frames.')

#Checking the coordinates of the alpha carbons in this case.
ca= u.select_atoms("name CA")
# ca_positions = ca.positions

# ca_shifted_idx=ca.ix_array+1
# ca_shifted_positions = u.atoms[ca_shifted_idx].positions

# print(ca_positions)
# print(ca_shifted_positions) 
# ca.ix_array
# ca.indices

indx = ca.ids-1
ca_positions = u.atoms[indx].positions
ca_positions
# ca= u.select_atoms("name CA")
# indx = ca.ids
# indx


array([[-4.0125e+01,  1.5904e+01,  1.7255e+01],
       [-4.2797e+01,  1.8511e+01,  1.6415e+01],
       [-4.2870e+01,  1.7366e+01,  1.2735e+01],
       [-4.3717e+01,  1.3775e+01,  1.3653e+01],
       [-4.6534e+01,  1.5079e+01,  1.5864e+01],
       [-4.7744e+01,  1.7538e+01,  1.3249e+01],
       [-4.7913e+01,  1.4689e+01,  1.0688e+01],
       [-4.9832e+01,  1.2414e+01,  1.2979e+01],
       [-5.2305e+01,  1.5148e+01,  1.3765e+01],
       [-5.2726e+01,  1.5765e+01,  1.0098e+01],
       [-5.3519e+01,  1.2266e+01,  8.7520e+00],
       [-5.5785e+01,  1.2164e+01,  1.1797e+01],
       [-5.7681e+01,  1.5327e+01,  1.0835e+01],
       [-5.7949e+01,  1.3914e+01,  7.2800e+00],
       [-5.9457e+01,  1.0684e+01,  8.5540e+00],
       [-6.1801e+01,  1.2807e+01,  1.0670e+01],
       [-6.2881e+01,  1.4941e+01,  7.7320e+00],
       [-6.3547e+01,  1.1775e+01,  5.7350e+00],
       [-6.5504e+01,  1.0178e+01,  8.5980e+00],
       [-6.7730e+01,  1.3215e+01,  8.8570e+00],
       [-6.8377e+01,  1.3027e+01,  5.125

In [39]:
import MDAnalysis as mda
filepdb = "/Users/sss/Documents/EnergyGap_project/positions_data/4dvd/4dvd_plumed.pdb"
filedcd = "/Users/sss/Documents/EnergyGap_project/positions_data/4dvd/4dvd.dcd"
u = mda.Universe(filepdb, filedcd)
print("Trajectory box (Å):", u.trajectory.ts.dimensions[:3])


Trajectory box (Å): [149.61 149.61  62.7 ]




#### Applying the distance function to the 
1. ca_positions  

2. plumed_ca_positions  

3. ca_shifted_posiitons  

Inorder to compare the distance values here with those computed through plumed. 


In [17]:
plumed_positions = jnp.array([
    [-4.02000008e+00,  1.45340004e+00,  1.77740002e+00],
    [-4.15909996e+00,  1.77970009e+00,  1.68409996e+00],
    [-4.24570007e+00,  1.77240009e+00,  1.41020002e+00],
    [-4.30859985e+00,  1.50799999e+00,  1.35129995e+00]
])

d_ca = distance(ca_positions)
d_plumed= distance(plumed_positions)
d_shifted = distance(ca_shifted_positions)

print(f'Distance from ca_positions: {d_ca} vs 0.366777 computed through plumed python Interface')
print(f'Distance form d_plumed: {d_plumed} vs 0.366777 computed through plumed python Interface and 0.146693 using plumed DISTANCE')
print(f'Distance from d_shifted: {d_shifted} vs 0.366777 computed through plumed python Interface')


Distance from ca_positions: 3.8264386653900146 vs 0.366777 computed through plumed python Interface
Distance form d_plumed: 0.36677712202072144 vs 0.366777 computed through plumed python Interface and 0.146693 using plumed DISTANCE
Distance from d_shifted: 3.8264386653900146 vs 0.366777 computed through plumed python Interface


### Is the problem in "linalg" ?

In [18]:
def scrach_distance(positions):
    x1,y1,z1 = positions[0]
    x2,y2,z2 = positions[1]
    dist = ((x1-x2)**2 + (y1-y2)**2 + (z1-z2)**2)**0.5
    return dist

scrach_distance(plumed_positions)

Array(0.36677712, dtype=float32)

### Compare the simple function below with the interface

In [32]:
def scrach_distance(positions):
    x,y,z = positions[0]
    fn = 2*x**2 + 2*y**2 + 2*z**2
    print(f'x,y,z: {x},{y},{z}')
    return fn

d = scrach_distance(ca_positions)
grad_d = grad(scrach_distance)

print(f'function value: {d} and gradient: {grad_d(ca_positions)}') #,d, grad_d(ca_positions))

x,y,z: -41.50699996948242,16.469999313354492,16.893999099731445
x,y,z: Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Array(-41.507, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x1285adb30>, in_tracers=(Traced<ShapedArray(float32[1]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x128858bd0; to 'JaxprTracer' at 0x128858ef0>], out_avals=[ShapedArray(float32[])], primitive=squeeze, params={'dimensions': (0,)}, effects=frozenset(), source_info=<jax._src.source_info_util.SourceInfo object at 0x128862350>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=False, xla_metadata={})),Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Array(16.47, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x1285adc30>, in