Skip to content

Commit

Permalink
Add typing-extensions, revert removal of Array alias
Browse files Browse the repository at this point in the history
  • Loading branch information
mcwitt committed Apr 12, 2022
1 parent 5006563 commit 4583cc5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -94,6 +94,7 @@ def build_extension(self, ext):
"pymbar>3.0.4",
"pyyaml",
"scipy",
"typing-extensions",
],
extras_require={
"dev": [
Expand Down
9 changes: 6 additions & 3 deletions timemachine/potentials/jax_utils.py
Expand Up @@ -2,9 +2,12 @@
import numpy as np
from jax import vmap
from numpy.typing import NDArray
from typing_extensions import TypeAlias

Array: TypeAlias = NDArray

def get_all_pairs_indices(n: int) -> NDArray:

def get_all_pairs_indices(n: int) -> Array:
"""all indices i, j such that i < j < n"""
n_interactions = n * (n - 1) / 2

Expand All @@ -15,7 +18,7 @@ def get_all_pairs_indices(n: int) -> NDArray:
return pairs


def pairs_from_interaction_groups(group_a_indices: NDArray, group_b_indices: NDArray) -> NDArray:
def pairs_from_interaction_groups(group_a_indices: Array, group_b_indices: Array) -> Array:
"""(a, b) for a in group_a_indices, b in group_b_indices"""
n_interactions = len(group_a_indices) * len(group_b_indices)

Expand Down Expand Up @@ -46,7 +49,7 @@ def compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutof
return w


def augment_dim(x3: NDArray, w: NDArray) -> NDArray:
def augment_dim(x3: Array, w: Array) -> Array:
"""(x,y,z) -> (x,y,z,w)"""

d4 = jnp.expand_dims(w, axis=-1)
Expand Down
7 changes: 5 additions & 2 deletions timemachine/potentials/nonbonded.py
Expand Up @@ -5,6 +5,7 @@
from jax import vmap
from jax.scipy.special import erfc
from numpy.typing import NDArray
from typing_extensions import TypeAlias

from timemachine.potentials import jax_utils
from timemachine.potentials.jax_utils import (
Expand All @@ -15,6 +16,8 @@
pairs_from_interaction_groups,
)

Array: TypeAlias = NDArray


def switch_fn(dij, cutoff):
return jnp.power(jnp.cos((jnp.pi * jnp.power(dij, 8)) / (2 * cutoff)), 2)
Expand Down Expand Up @@ -377,7 +380,7 @@ def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutof
return prefactor_i


def coulomb_prefactors_on_snapshot(x_ligand, x_env, q_env, box=None, beta=2.0, cutoff=np.inf) -> NDArray:
def coulomb_prefactors_on_snapshot(x_ligand, x_env, q_env, box=None, beta=2.0, cutoff=np.inf) -> Array:
"""Map coulomb_prefactor_on_atom over atoms in x_ligand
Parameters
Expand Down Expand Up @@ -431,7 +434,7 @@ def f_snapshot(coords, box):
return vmap(f_snapshot)(traj, boxes)


def coulomb_interaction_group_energy(q_ligand: NDArray, q_prefactors: NDArray) -> float:
def coulomb_interaction_group_energy(q_ligand: Array, q_prefactors: Array) -> float:
"""Assuming q_prefactors = coulomb_prefactors_on_snapshot(x_ligand, ...),
cheaply compute the energy of ligand-environment interaction group
Expand Down

0 comments on commit 4583cc5

Please sign in to comment.