Skip to content

Commit

Permalink
Make mypy happy
Browse files Browse the repository at this point in the history
  • Loading branch information
skye committed May 19, 2020
1 parent 82f283a commit e87d0e7
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/interpreters/sharded_jit.py
Expand Up @@ -93,6 +93,7 @@ def _sharded_callable(
raise ValueError("sharded_jit only works on TPU!")

num_partitions = pxla.reconcile_num_partitions(jaxpr, num_partitions)
assert num_partitions is not None
if num_partitions > xb.local_device_count():
raise ValueError(
f"sharded_jit computation requires {num_partitions} devices, "
Expand Down

0 comments on commit e87d0e7

Please sign in to comment.