Skip to content

Commit

Permalink
Set enable_mapped_nlist to expect mapping_fxn to take in box dimensio…
Browse files Browse the repository at this point in the history
…ns (#337)

* enable_mapped_nlist now expects mapping_fxn to handle box dims

* update test mapping function call
  • Loading branch information
RainierBarrett committed Oct 18, 2021
1 parent 3f2a11c commit d0fae8a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
36 changes: 34 additions & 2 deletions htf/simmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def mapped_positions(self, positions):
return positions[:self._map_i], positions[self._map_i:]

@tf.function
def precompute(self, dtype, positions_addr):
def precompute(self, dtype, positions_addr, box_addr):
if self._map_nlist:
tf_to_hoomd_module = load_htf_op_library('tf2hoomd_op')
tf_to_hoomd = tf_to_hoomd_module.tf_to_hoomd
Expand All @@ -300,7 +300,39 @@ def precompute(self, dtype, positions_addr):
T=dtype,
name='pos-input-pre'
)
cg_pos = self._map_fxn(pos[:self._map_i])

box = hoomd_to_tf(
address=box_addr,
shape=[3],
T=dtype,
name='box-input'
)

# check box skew
tf.Assert(tf.less(tf.reduce_sum(box[2]), 0.0001), ['box is skewed'])

# for TF2.4.1 we hack the box to have leading batch dimension
# because TF has 4k backlogged issues
# get and parse the version of the detected TF version
vtf = parse_version(tf.__version__)
if vtf >= parse_version('2.4'):
box = tf.SparseTensor(
indices=[[0, 0, 0],
[0, 0, 1],
[0, 0, 2],
[0, 1, 0],
[0, 1, 1],
[0, 1, 2],
[0, 2, 0],
[0, 2, 1],
[0, 2, 2]],
values=tf.reshape(box, (-1,)),
dense_shape=(tf.shape(pos)[0], 3, 3)
)

bs = box_size(box)

cg_pos = self._map_fxn(pos[:self._map_i], bs)
# types will NOT be overwritten, so we do not need to add offset
new_pos = tf.concat((pos[:self._map_i], cg_pos), axis=0)
tf_to_hoomd(tf.cast(new_pos, dtype),
Expand Down
14 changes: 8 additions & 6 deletions htf/tensorflowcompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,17 +207,19 @@ def enable_mapped_nlist(self, system, mapping_fxn):
:param system: hoomd system
:type system: hoomd system
:param mapping_fxn: a function whose signature is ``f(positions)`` where positions is an
``Nx4`` array of fine-grained positions and
whose return value is an ``Mx4`` array
of coarse-grained positions.
:param mapping_fxn: a function whose signature is ``f(positions, box)`` where
positions is an ``Nx4`` array of fine-grained positions and
box is a list containing Lx, Ly, and Lz of the simulation box,
and whose return value is an ``Mx4`` array of
coarse-grained positions.
:type mapping_fxn: python callable
'''

# get snapshot and insert cg beads
snap = system.take_snapshot()
cg_pos = mapping_fxn(
snap.particles.position.astype(self.model.dtype))
snap.particles.position.astype(self.model.dtype),
[snap.box.Lx, snap.box.Ly, snap.box.Lz])
M = cg_pos.shape[0]
AAN = snap.particles.N
aa_pos = snap.particles.position
Expand Down Expand Up @@ -306,7 +308,7 @@ def _start_update(self):
''' Perhaps suboptimal call to see if there is a precompute step.
'''
self.model.precompute(
self.dtype, self.cpp_force.getPositionsBuffer())
self.dtype, self.cpp_force.getPositionsBuffer(), self.cpp_force.getBoxBuffer())

def _finish_update(self, batch_index):
''' Allow TF to read output and we wait for it to finish.
Expand Down
2 changes: 1 addition & 1 deletion htf/test-py/build_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def compute(self, nlist, positions, box):


class MappedNlist(htf.SimModel):
def my_map(pos):
def my_map(pos, box):
x = tf.reduce_mean(pos[:, :3], axis=0, keepdims=True)
cg1 = tf.concat((x, tf.zeros((1, 1), dtype=x.dtype)), -1)
cg2 = tf.convert_to_tensor([[0, 0, 0.1, 1]], dtype=x.dtype)
Expand Down

0 comments on commit d0fae8a

Please sign in to comment.