Adapt flax to handle any mutable state used by a model generically #1665

Merged (5 commits, Sep 20, 2022)
4 changes: 2 additions & 2 deletions scvi/module/base/__init__.py
@@ -5,7 +5,7 @@
PyroBaseModuleClass,
)
from ._decorators import auto_move_data
from ._jax_module_wrapper import JaxModuleWrapper, TrainStateWithBatchNorm
from ._jax_module_wrapper import JaxModuleWrapper, TrainStateWithState

__all__ = [
"BaseModuleClass",
@@ -14,5 +14,5 @@
"auto_move_data",
"JaxBaseModuleClass",
"JaxModuleWrapper",
"TrainStateWithBatchNorm",
"TrainStateWithState",
]
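
With the rename above, the class stays part of the package's public surface; downstream code would simply import the new name (a one-line illustration based only on the export change shown):

from scvi.module.base import JaxModuleWrapper, TrainStateWithState  # renamed export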
16 changes: 8 additions & 8 deletions scvi/module/base/_jax_module_wrapper.py
@@ -13,8 +13,8 @@
from ._base_module import JaxBaseModuleClass


class TrainStateWithBatchNorm(train_state.TrainState):
batch_stats: FrozenDict[str, Any]
class TrainStateWithState(train_state.TrainState):
state: FrozenDict[str, Any]


class JaxModuleWrapper:
@@ -92,7 +92,7 @@ def train(self):
def _bound_module(self):
"""Module bound with parameters learned from training."""
return self.module.bind(
{"params": self.params, "batch_stats": self.batch_stats},
{"params": self.params, **self.state},
rngs=self.rngs,
)

@@ -161,21 +161,21 @@ def _split_rngs(self):
return ret_rngs

@property
def train_state(self) -> TrainStateWithBatchNorm:
def train_state(self) -> TrainStateWithState:
"""Train state containing learned parameter values from training."""
return self._train_state

@train_state.setter
def train_state(self, train_state: TrainStateWithBatchNorm):
def train_state(self, train_state: TrainStateWithState):
self._train_state = train_state

@property
def params(self) -> flax.core.FrozenDict[str, Any]:
def params(self) -> FrozenDict[str, Any]:
return self.train_state.params

@property
def batch_stats(self) -> FrozenDict[str, Any]:
return self.train_state.batch_stats
def state(self) -> FrozenDict[str, Any]:
return self.train_state.state

def state_dict(self) -> Dict[str, Any]:
"""Returns a serialized version of the train state as a dictionary."""
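The wrapper now keeps every non-parameter collection in a single state dict and binds it alongside the parameters, instead of special-casing batch_stats. A minimal, hypothetical sketch of the pattern (ToyModule, the "counters" collection, and all names are illustrative, not scvi-tools code), assuming a Flax version whose init returns a FrozenDict as in the PR's era:

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict
from flax.training import train_state


class TrainStateWithState(train_state.TrainState):
    # Every collection other than "params" (batch_stats, counters, ...) lives here.
    state: FrozenDict[str, Any]


class ToyModule(nn.Module):  # hypothetical module with two mutable collections
    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(4)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)  # creates "batch_stats"
        calls = self.variable("counters", "calls", lambda: jnp.zeros((), jnp.int32))
        if train:
            calls.value += 1  # a second mutable collection, "counters"
        return x


module = ToyModule()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
state, params = variables.pop("params")  # state keeps batch_stats, counters, ...

ts = TrainStateWithState.create(
    apply_fn=module.apply, params=params, tx=optax.adam(1e-3), state=state
)

# Mirrors JaxModuleWrapper._bound_module: bind params plus whatever else exists.
bound = module.bind({"params": ts.params, **ts.state})
out = bound(jnp.ones((2, 3)), train=False)  # inference reads, but never mutates, state

Because the bound variables are built with **ts.state, the wrapper never needs to know which collections a particular model defines.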
6 changes: 3 additions & 3 deletions scvi/train/_callbacks.py
@@ -160,6 +160,6 @@ def on_train_start(self, trainer, pl_module):
else:
dl = self.dataloader
module_init = module.init(module.rngs, next(iter(dl)))
params = module_init["params"]
batch_stats = module_init["batch_stats"]
pl_module.set_train_state(params, batch_stats)
state, params = module_init.pop("params")
pl_module.set_train_state(params, state)
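The callback no longer extracts a specific collection: whatever init returns besides "params" is handed to set_train_state as opaque state, and a model with no mutable collections simply yields an empty dict. A small illustration of that split (PlainModule is hypothetical):

import flax.linen as nn
import jax
import jax.numpy as jnp


class PlainModule(nn.Module):  # hypothetical module with parameters only
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)


variables = PlainModule().init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
state, params = variables.pop("params")  # FrozenDict.pop -> (rest, popped value)
assert len(state) == 0  # empty state: set_train_state(params, state) still works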
24 changes: 11 additions & 13 deletions scvi/train/_trainingplans.py
@@ -21,7 +21,7 @@
JaxModuleWrapper,
LossRecorder,
PyroBaseModuleClass,
TrainStateWithBatchNorm,
TrainStateWithState,
)
from scvi.nn import one_hot

@@ -970,7 +970,7 @@ def __init__(
if optim_kwargs is not None:
self.optim_kwargs.update(optim_kwargs)

def set_train_state(self, params, batch_stats=None):
def set_train_state(self, params, state=None):
if self.module.train_state is not None:
return

@@ -980,27 +980,27 @@ def set_train_state(self, params, batch_stats=None):
optax.additive_weight_decay(weight_decay=weight_decay),
optax.adam(**self.optim_kwargs),
)
train_state = TrainStateWithBatchNorm.create(
train_state = TrainStateWithState.create(
apply_fn=self.module.apply,
params=params,
tx=optimizer,
batch_stats=batch_stats,
state=state,
)
self.module.train_state = train_state

@staticmethod
@jax.jit
def jit_training_step(
state: TrainStateWithBatchNorm,
state: TrainStateWithState,
batch: Dict[str, np.ndarray],
rngs: Dict[str, jnp.ndarray],
**kwargs,
):
# batch stats can't be passed here
# state can't be passed here
def loss_fn(params):
vars_in = {"params": params, "batch_stats": state.batch_stats}
vars_in = {"params": params, **state.state}
outputs, new_model_state = state.apply_fn(
vars_in, batch, rngs=rngs, mutable=["batch_stats"], **kwargs
vars_in, batch, rngs=rngs, mutable=list(state.state.keys()), **kwargs
)
loss_recorder = outputs[2]
loss = loss_recorder.loss
@@ -1010,9 +1010,7 @@ def loss_fn(params):
(loss, (elbo, new_model_state)), grads = jax.value_and_grad(
loss_fn, has_aux=True
)(state.params)
new_state = state.apply_gradients(
grads=grads, batch_stats=new_model_state["batch_stats"]
)
new_state = state.apply_gradients(grads=grads, state=new_model_state)
return new_state, loss, elbo

def training_step(self, batch, batch_idx):
@@ -1044,12 +1042,12 @@ def training_step(self, batch, batch_idx):
@partial(jax.jit, static_argnums=(0,))
def jit_validation_step(
self,
state: TrainStateWithBatchNorm,
state: TrainStateWithState,
batch: Dict[str, np.ndarray],
rngs: Dict[str, jnp.ndarray],
**kwargs,
):
vars_in = {"params": state.params, "batch_stats": state.batch_stats}
vars_in = {"params": state.params, **state.state}
outputs = self.module.apply(vars_in, batch, rngs=rngs, **kwargs)
loss_recorder = outputs[2]
loss = loss_recorder.loss
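Putting the pieces together, the jitted training step asks Flax to treat every collection held in the train state as mutable, and threads whatever comes back into the next state via apply_gradients. A condensed, self-contained sketch of that loop (the quadratic loss, batch layout, and the TrainStateWithState redefinition are illustrative stand-ins, not scvi-tools code):

from typing import Any, Dict

import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from flax.training import train_state


class TrainStateWithState(train_state.TrainState):  # as in the wrapper above
    state: FrozenDict[str, Any]


@jax.jit
def train_step(ts: TrainStateWithState, batch: Dict[str, jnp.ndarray], rngs=None):
    def loss_fn(params):
        vars_in = {"params": params, **ts.state}
        # Mark every non-params collection as mutable so batch norm statistics,
        # counters, etc. are all returned together as new_model_state.
        preds, new_model_state = ts.apply_fn(
            vars_in, batch["x"], rngs=rngs, mutable=list(ts.state.keys())
        )
        loss = jnp.mean((preds - batch["y"]) ** 2)  # stand-in loss
        return loss, new_model_state

    (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(ts.params)
    # apply_gradients forwards extra keyword fields to .replace(), so the updated
    # collections ride along with the optimizer step, whatever they contain.
    return ts.apply_gradients(grads=grads, state=new_model_state), loss

Because mutable is derived from the state keys, a model with only batch_stats, several custom collections, or none at all flows through the same code path, which is the point of the generalization in this PR.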