Commit e1fbfa2 (parent: b1398b1)

Backport PR #1665: Adapt flax to use any mutable states used by a model generically (#1699)

Co-authored-by: Adam Gayoso <adamgayoso@users.noreply.github.com>
meeseeksmachine and adamgayoso committed Sep 20, 2022

Showing 4 changed files with 24 additions and 26 deletions.
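In short: instead of a train state with a field hard-wired to Flax's "batch_stats" collection, the train state below carries a generic mapping holding every non-"params" variable collection the module creates. A minimal sketch of the pattern, assuming a hypothetical ToyNet module, a FrozenDict-returning init (as in the Flax versions this commit targets), and an arbitrary Adam learning rate; none of these specifics come from the commit itself:

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict
from flax.training import train_state


class TrainStateWithState(train_state.TrainState):
    # Generic bag of mutable collections ("batch_stats", or anything else),
    # instead of a field dedicated to batch norm statistics.
    state: FrozenDict[str, Any]


class ToyNet(nn.Module):  # hypothetical example module, not part of the diff
    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Dense(4)(x)
        return nn.BatchNorm(use_running_average=not training)(x)


module = ToyNet()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
# FrozenDict.pop returns (remaining collections, popped value).
state, params = variables.pop("params")
ts = TrainStateWithState.create(
    apply_fn=module.apply,
    params=params,
    tx=optax.adam(1e-3),
    state=state,  # TrainState.create forwards extra kwargs to the dataclass
)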
4 changes: 2 additions & 2 deletions scvi/module/base/__init__.py

@@ -5,7 +5,7 @@
     PyroBaseModuleClass,
 )
 from ._decorators import auto_move_data
-from ._jax_module_wrapper import JaxModuleWrapper, TrainStateWithBatchNorm
+from ._jax_module_wrapper import JaxModuleWrapper, TrainStateWithState

 __all__ = [
     "BaseModuleClass",
@@ -14,5 +14,5 @@
     "auto_move_data",
     "JaxBaseModuleClass",
     "JaxModuleWrapper",
-    "TrainStateWithBatchNorm",
+    "TrainStateWithState",
 ]
16 changes: 8 additions & 8 deletions scvi/module/base/_jax_module_wrapper.py

@@ -13,8 +13,8 @@
 from ._base_module import JaxBaseModuleClass


-class TrainStateWithBatchNorm(train_state.TrainState):
-    batch_stats: FrozenDict[str, Any]
+class TrainStateWithState(train_state.TrainState):
+    state: FrozenDict[str, Any]


 class JaxModuleWrapper:
@@ -92,7 +92,7 @@ def train(self):
     def _bound_module(self):
         """Module bound with parameters learned from training."""
         return self.module.bind(
-            {"params": self.params, "batch_stats": self.batch_stats},
+            {"params": self.params, **self.state},
             rngs=self.rngs,
         )

@@ -161,21 +161,21 @@ def _split_rngs(self):
         return ret_rngs

     @property
-    def train_state(self) -> TrainStateWithBatchNorm:
+    def train_state(self) -> TrainStateWithState:
         """Train state containing learned parameter values from training."""
         return self._train_state

     @train_state.setter
-    def train_state(self, train_state: TrainStateWithBatchNorm):
+    def train_state(self, train_state: TrainStateWithState):
         self._train_state = train_state

     @property
-    def params(self) -> flax.core.FrozenDict[str, Any]:
+    def params(self) -> FrozenDict[str, Any]:
         return self.train_state.params

     @property
-    def batch_stats(self) -> FrozenDict[str, Any]:
-        return self.train_state.batch_stats
+    def state(self) -> FrozenDict[str, Any]:
+        return self.train_state.state

     def state_dict(self) -> Dict[str, Any]:
         """Returns a serialized version of the train state as a dictionary."""
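For the _bound_module change above, splatting the stored collections back next to "params" rebuilds the full variable dictionary that Module.bind (and Module.apply) expects, whatever the collections happen to be named. Continuing the hypothetical ToyNet sketch above:

# Equivalent to binding the original `variables` dict returned by init().
bound = module.bind({"params": ts.params, **ts.state})
out = bound(jnp.ones((2, 3)), training=False)  # inference reads the stored running stats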
6 changes: 3 additions & 3 deletions scvi/train/_callbacks.py

@@ -160,6 +160,6 @@ def on_train_start(self, trainer, pl_module):
         else:
             dl = self.dataloader
         module_init = module.init(module.rngs, next(iter(dl)))
-        params = module_init["params"]
-        batch_stats = module_init["batch_stats"]
-        pl_module.set_train_state(params, batch_stats)
+        params = module_init.pop("params")
+        state = module_init
+        pl_module.set_train_state(params, state)
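The callback now separates "params" from everything else and treats the remainder as the module's mutable state, rather than reading a "batch_stats" key, so a module that defines some other collection needs no special handling. A hypothetical illustration, continuing the sketch above:

class CountingNet(nn.Module):  # hypothetical module with a custom mutable collection
    @nn.compact
    def __call__(self, x):
        calls = self.variable("counters", "calls", lambda: jnp.zeros((), jnp.int32))
        calls.value += 1
        return nn.Dense(4)(x)


counting = CountingNet()
variables = counting.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
state, params = variables.pop("params")  # state == {"counters": {"calls": ...}}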
24 changes: 11 additions & 13 deletions scvi/train/_trainingplans.py

@@ -21,7 +21,7 @@
     JaxModuleWrapper,
     LossRecorder,
     PyroBaseModuleClass,
-    TrainStateWithBatchNorm,
+    TrainStateWithState,
 )
 from scvi.nn import one_hot

@@ -970,7 +970,7 @@ def __init__(
         if optim_kwargs is not None:
             self.optim_kwargs.update(optim_kwargs)

-    def set_train_state(self, params, batch_stats=None):
+    def set_train_state(self, params, state=None):
         if self.module.train_state is not None:
             return

@@ -980,27 +980,27 @@ def set_train_state(self, params, batch_stats=None):
             optax.additive_weight_decay(weight_decay=weight_decay),
             optax.adam(**self.optim_kwargs),
         )
-        train_state = TrainStateWithBatchNorm.create(
+        train_state = TrainStateWithState.create(
             apply_fn=self.module.apply,
             params=params,
             tx=optimizer,
-            batch_stats=batch_stats,
+            state=state,
         )
         self.module.train_state = train_state

     @staticmethod
     @jax.jit
     def jit_training_step(
-        state: TrainStateWithBatchNorm,
+        state: TrainStateWithState,
         batch: Dict[str, np.ndarray],
         rngs: Dict[str, jnp.ndarray],
         **kwargs,
     ):
-        # batch stats can't be passed here
+        # state can't be passed here
        def loss_fn(params):
-            vars_in = {"params": params, "batch_stats": state.batch_stats}
+            vars_in = {"params": params, **state.state}
            outputs, new_model_state = state.apply_fn(
-                vars_in, batch, rngs=rngs, mutable=["batch_stats"], **kwargs
+                vars_in, batch, rngs=rngs, mutable=list(state.state.keys()), **kwargs
            )
            loss_recorder = outputs[2]
            loss = loss_recorder.loss

@@ -1010,9 +1010,7 @@ def loss_fn(params):
         (loss, (elbo, new_model_state)), grads = jax.value_and_grad(
             loss_fn, has_aux=True
         )(state.params)
-        new_state = state.apply_gradients(
-            grads=grads, batch_stats=new_model_state["batch_stats"]
-        )
+        new_state = state.apply_gradients(grads=grads, state=new_model_state)
         return new_state, loss, elbo

     def training_step(self, batch, batch_idx):

@@ -1044,12 +1042,12 @@ def training_step(self, batch, batch_idx):
     @partial(jax.jit, static_argnums=(0,))
     def jit_validation_step(
         self,
-        state: TrainStateWithBatchNorm,
+        state: TrainStateWithState,
         batch: Dict[str, np.ndarray],
         rngs: Dict[str, jnp.ndarray],
         **kwargs,
     ):
-        vars_in = {"params": state.params, "batch_stats": state.batch_stats}
+        vars_in = {"params": state.params, **state.state}
         outputs = self.module.apply(vars_in, batch, rngs=rngs, **kwargs)
         loss_recorder = outputs[2]
         loss = loss_recorder.loss
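Putting the training-plan changes together: the loss closure rebuilds the variable dict from the generic state, marks every stored collection as mutable so apply() also returns the updated collections, and apply_gradients() writes them back into the train state (TrainState.apply_gradients passes extra keyword arguments through to .replace()). A sketch continuing the ToyNet example, with a placeholder mean-square loss standing in for scvi-tools' loss:

@jax.jit
def train_step(ts: TrainStateWithState, batch: jnp.ndarray):
    def loss_fn(params):
        vars_in = {"params": params, **ts.state}
        out, new_state = ts.apply_fn(
            vars_in, batch, training=True, mutable=list(ts.state.keys())
        )
        return jnp.mean(out**2), new_state  # placeholder loss, not scvi-tools'

    (loss, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(ts.params)
    return ts.apply_gradients(grads=grads, state=new_state), loss


ts, loss = train_step(ts, jnp.ones((2, 3)))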
