Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
Add typing-extensions for TypeAlias annotation. See:
- python/mypy#7866
- https://peps.python.org/pep-0613/
  • Loading branch information
mcwitt committed Apr 15, 2022
1 parent 5e9d77b commit 992e941
Show file tree
Hide file tree
Showing 26 changed files with 139 additions and 115 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Expand Up @@ -20,6 +20,10 @@ plugins = ["numpy.typing.mypy_plugin"]
ignore_missing_imports = true
check_untyped_defs = false

[[tool.mypy.overrides]]
module = "timemachine._vendored.fire"
follow_imports = "silent"

[build-system]
requires = ["setuptools>=43.0.0", "wheel", "cmake==3.22.1", "versioneer-518"]
build-backend = "setuptools.build_meta"
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -8,3 +8,4 @@ pyyaml==5.4.1
hilbertcurve==1.0.5
networkx==2.5
importlib-resources==5.4.0
typing-extensions==4.1.1
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -94,6 +94,7 @@ def build_extension(self, ext):
"pymbar>3.0.4",
"pyyaml",
"scipy",
"typing-extensions",
],
extras_require={
"dev": [
Expand Down
11 changes: 6 additions & 5 deletions timemachine/fe/endpoint_correction.py
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray
from scipy.stats import special_ortho_group

from timemachine.potentials import bonded, rmsd
Expand Down Expand Up @@ -79,16 +80,16 @@ def estimate_delta_us(k_translation, k_rotation, core_idxs, core_params, beta, l
k_rotation: float
Force constant of the rotational restraint
core_idxs: int np.array (C, 2)
core_idxs: int NDArray (C, 2)
Atom mapping between the two cores
core_params: float np.array (C, 2)
core_params: float NDArray (C, 2)
Bonded parameters of the intractable restraint
lhs_xs: np.array [T, N, 3]
lhs_xs: NDArray [T, N, 3]
Samples from the intractable left hand state, with the restraints turned on.
rhs_xs: np.array [T, N, 3]
rhs_xs: NDArray [T, N, 3]
Samples from the non-interacting, fully unrestrained right hand state.
beta: 1/kT
Expand Down Expand Up @@ -177,7 +178,7 @@ def align(x, r, t):


# courtesy of jfass
def ecdf(x: np.array) -> Tuple[np.array, np.array]:
def ecdf(x: NDArray) -> Tuple[NDArray, NDArray]:
"""empirical cdf, from https://stackoverflow.com/a/37660583"""
xs = np.sort(x)
ys = np.arange(1, len(xs) + 1) / float(len(xs))
Expand Down
20 changes: 9 additions & 11 deletions timemachine/fe/estimator.py
@@ -1,7 +1,7 @@
import copy
import dataclasses
import time
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pymbar
Expand Down Expand Up @@ -32,7 +32,7 @@ class FreeEnergyModel:
v0: NDArray
integrator: LangevinIntegrator
barostat: MonteCarloBarostat
lambda_schedule: List[float]
lambda_schedule: Union[Sequence[float], NDArray]
equil_steps: int
prod_steps: int
beta: float
Expand All @@ -41,7 +41,7 @@ class FreeEnergyModel:

def equilibrate(
integrator: LangevinIntegrator,
barostat: LangevinIntegrator,
barostat: MonteCarloBarostat,
potentials: List,
coords: NDArray,
box: NDArray,
Expand Down Expand Up @@ -181,7 +181,7 @@ def simulate(
prod_steps: int,
x_interval: int,
u_interval: int,
lambda_windows: List[float],
lambda_windows: Union[Sequence[float], NDArray],
):
"""
Run a simulation and collect relevant statistics for this simulation.
Expand Down Expand Up @@ -321,13 +321,13 @@ def deltaG_from_results(
else:
sim_results = results

U_knk = []
U_knk_ = []
N_k = []
for result in sim_results:
U_knk.append(result.lambda_us)
U_knk_.append(result.lambda_us)
N_k.append(len(result.lambda_us)) # number of frames

U_knk = np.array(U_knk)
U_knk = np.array(U_knk_)

bar_dG = 0
bar_dG_err = 0
Expand All @@ -354,9 +354,8 @@ def deltaG_from_results(
)

# for MBAR we need to sanitize the energies
clean_U_knks = [] # [K, F, K]
for lambda_idx, full_us in enumerate(U_knk):
clean_U_knks.append(sanitize_energies(full_us, lambda_idx))
# [K, F, K]
clean_U_knks = np.array([sanitize_energies(full_us, lambda_idx) for lambda_idx, full_us in enumerate(U_knk)])

print(
model.prefix,
Expand All @@ -369,7 +368,6 @@ def deltaG_from_results(
)

K = len(model.lambda_schedule)
clean_U_knks = np.array(clean_U_knks) # [K, F, K]
U_kn = np.reshape(clean_U_knks, (-1, K)).transpose() # [K, F*K]
u_kn = U_kn * model.beta

Expand Down
6 changes: 3 additions & 3 deletions timemachine/fe/frames.py
@@ -1,15 +1,15 @@
from typing import Any, Generator, List, Tuple
from typing import Any, Iterable, List, Tuple

SimulationResult = Any
FrameIterator = Generator[Tuple[int, SimulationResult], None, None]
FrameIterator = Iterable[Tuple[int, SimulationResult]]


def all_frames(results: List[SimulationResult]) -> FrameIterator:
return enumerate(results)


def endpoint_frames_only(results: List[SimulationResult]) -> FrameIterator:
output = []
output: List[Tuple[int, SimulationResult]] = []
if len(results) == 0:
return output
output.append((0, results[0]))
Expand Down
2 changes: 1 addition & 1 deletion timemachine/fe/free_energy.py
Expand Up @@ -51,7 +51,7 @@ def log(self):
self.dG_solvent_decouple_error,
)

def __eq__(self, other: "RABFEResult") -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, RABFEResult):
return NotImplemented
equal = True
Expand Down
24 changes: 13 additions & 11 deletions timemachine/fe/model.py
@@ -1,9 +1,10 @@
import os
from pickle import dump, load
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray
from rdkit import Chem
from simtk import openmm

Expand All @@ -23,13 +24,13 @@ def __init__(
client: Optional[AbstractClient],
ff: Forcefield,
complex_system: openmm.System,
complex_coords: np.ndarray,
complex_box: np.ndarray,
complex_schedule: np.ndarray,
complex_coords: NDArray,
complex_box: NDArray,
complex_schedule: Union[Sequence[float], NDArray],
solvent_system: openmm.System,
solvent_coords: np.ndarray,
solvent_box: np.ndarray,
solvent_schedule: np.ndarray,
solvent_coords: NDArray,
solvent_box: NDArray,
solvent_schedule: Union[Sequence[float], NDArray],
equil_steps: int,
prod_steps: int,
barostat_interval: int = 25,
Expand All @@ -54,9 +55,9 @@ def __init__(
self.barostat_interval = barostat_interval
self.pre_equilibrate = pre_equilibrate
self.hmr = hmr
self._equil_cache = {}
self._equil_cache: Dict[str, Any] = {}

def _edge_hash(self, stage: str, mol_a: Chem.Mol, mol_b: Chem.Mol, core: np.ndarray) -> str:
def _edge_hash(self, stage: str, mol_a: Chem.Mol, mol_b: Chem.Mol, core: NDArray) -> str:
a = Chem.MolToSmiles(mol_a)
b = Chem.MolToSmiles(mol_b)
# Perhaps bad idea to have ordering of a and b?
Expand Down Expand Up @@ -131,6 +132,7 @@ def equilibrate_edges(
pots = []
for bp, params in zip(unbound_potentials, sys_params):
pots.append(bp.bind(np.asarray(params)))
assert self.client is not None
future = self.client.submit(
estimator.equilibrate, *[integrator, barostat, pots, coords, host_box, lamb, equilibration_steps]
)
Expand All @@ -147,7 +149,7 @@ def equilibrate_edges(
dump(self._equil_cache, ofs)
print(f"Saved equilibration_cache to {cache_path}")

def predict(self, ff_params: list, mol_a: Chem.Mol, mol_b: Chem.Mol, core: np.ndarray):
def predict(self, ff_params: list, mol_a: Chem.Mol, mol_b: Chem.Mol, core: NDArray):
"""
Predict the ddG of morphing mol_a into mol_b. This function is differentiable w.r.t. ff_params.
Expand All @@ -163,7 +165,7 @@ def predict(self, ff_params: list, mol_a: Chem.Mol, mol_b: Chem.Mol, core: np.nd
mol_b: Chem.Mol
Starting molecule corresponding to lambda = 1
core: np.ndarray
core: NDArray
N x 2 list of ints corresponding to the atom mapping of the core.
Returns
Expand Down
35 changes: 18 additions & 17 deletions timemachine/fe/model_rabfe.py
@@ -1,7 +1,8 @@
from abc import ABC
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
from numpy.typing import NDArray
from rdkit import Chem
from simtk import openmm

Expand All @@ -17,17 +18,17 @@
class AbsoluteModel(ABC):
def __init__(
self,
client: AbstractClient or None,
client: Optional[AbstractClient],
ff: Forcefield,
host_system: openmm.System,
host_schedule: np.ndarray,
host_schedule: Union[Sequence[float], NDArray],
host_topology: openmm.app.Topology,
temperature: float,
pressure: float,
dt: float,
equil_steps: int,
prod_steps: int,
frame_filter: Optional[callable] = None,
frame_filter: Optional[Callable] = None,
):

self.host_system = host_system
Expand Down Expand Up @@ -235,7 +236,7 @@ def __init__(
dt: float,
equil_steps: int,
prod_steps: int,
frame_filter: Optional[callable] = None,
frame_filter: Optional[Callable] = None,
k_core: float = 30.0,
):

Expand Down Expand Up @@ -445,9 +446,9 @@ def predict_from_futures(
assert len(futures) == 2
assert len(models) == 2
assert len(sys_params) == 2
err = 0
fwd_dG = 0
back_dG = 0
err = 0.0
fwd_dG = 0.0
back_dG = 0.0
for i, (params, model, sub_futures) in enumerate(zip(sys_params, models, futures)):
results = [fut.result() for fut in sub_futures]
dG, dG_err, results = estimator.deltaG_from_results(model, results, params)
Expand Down Expand Up @@ -481,9 +482,9 @@ def predict(
ff_params: list,
mol_a: Chem.Mol,
mol_b: Chem.Mol,
core_idxs: np.array,
x0: np.array,
box0: np.array,
core_idxs: NDArray,
x0: NDArray,
box0: NDArray,
prefix: str,
seed: int = 0,
):
Expand All @@ -508,7 +509,7 @@ def predict(
mol_b: Chem.Mol
Resulting molecule
core_idxs: np.array (Nx2), dtype int32
core_idxs: NDArray (Nx2), dtype int32
Atom mapping defining the core, mapping atoms from mol_a to atoms in mol_b.
x0: np.ndarray
Expand Down Expand Up @@ -606,7 +607,7 @@ def __init__(
dt: float,
equil_steps: int,
prod_steps: int,
frame_filter: Optional[callable] = None,
frame_filter: Optional[Callable] = None,
k_core: float = 30.0,
):

Expand Down Expand Up @@ -801,9 +802,9 @@ def predict(
ff_params: list,
mol_a: Chem.Mol,
mol_b: Chem.Mol,
core_idxs: np.array,
x0: np.array,
box0: np.array,
core_idxs: NDArray,
x0: NDArray,
box0: NDArray,
prefix: str,
seed: int = 0,
):
Expand All @@ -824,7 +825,7 @@ def predict(
mol_b: Chem.Mol
Resulting molecule
core_idxs: np.array (Nx2), dtype int32
core_idxs: NDArray (Nx2), dtype int32
Atom mapping defining the core, mapping atoms from mol_a to atoms in mol_b.
x0: np.ndarray
Expand Down
10 changes: 4 additions & 6 deletions timemachine/fe/restraints.py
Expand Up @@ -57,12 +57,10 @@ def setup_relative_restraints_by_distance(

row_idxs, col_idxs = linear_sum_assignment(rij)

core_idxs = []

for core_a, core_b in zip(row_idxs, col_idxs):
core_idxs.append((core_idxs_a[core_a], core_idxs_b[core_b]))

core_idxs = np.array(core_idxs, dtype=np.int32)
core_idxs = np.array(
[(core_idxs_a[core_a], core_idxs_b[core_b]) for core_a, core_b in zip(row_idxs, col_idxs)],
dtype=np.int32,
)

return core_idxs

Expand Down
4 changes: 2 additions & 2 deletions timemachine/fe/reweighting.py
Expand Up @@ -5,15 +5,15 @@
"interpret_as_mixture_potential",
]

from typing import Callable, Collection
from typing import Any, Callable, Collection

import numpy as np
from jax import numpy as jnp
from jax.scipy.special import logsumexp

Samples = Collection
Params = Collection
Array = jnp.ndarray
Array = Any # see https://github.com/google/jax/issues/943
Energies = Array

BatchedReducedPotentialFxn = Callable[[Samples, Params], Energies]
Expand Down

0 comments on commit 992e941

Please sign in to comment.