models.py
"""Definitions of models and loss functions."""
import jax
import jax.numpy as jnp
import optax
import haiku as hk
import kfac_jax
from functools import partial
def create_model(name, **constructor_kwargs):
"""Create and transform an instance of hk.nets.`name` using `kwargs`."""
model_constructor = getattr(hk.nets, name)
if 'activation' in constructor_kwargs:
constructor_kwargs['activation'] = getattr(jax.nn, constructor_kwargs['activation'])
return hk.without_apply_rng(
hk.transform_with_state(
lambda x, **kwargs: model_constructor(**constructor_kwargs)(x, **kwargs)))
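
# Usage sketch (illustrative, not part of the original module): 'MLP' and its
# `output_sizes`/`activation` arguments are standard `hk.nets.MLP` parameters;
# the input shapes below are arbitrary assumptions.
#
#   model = create_model('MLP', output_sizes=[128, 10], activation='relu')
#   params, state = model.init(jax.random.PRNGKey(0), jnp.ones((8, 32)))
#   logits, state = model.apply(params, state, jnp.ones((8, 32)))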


def create_loss(name, **kwargs):
    """Look up the loss function ``name`` in this module and partially apply ``kwargs``."""
    loss_function = globals()[name]
    return partial(loss_function, **kwargs)
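
# Usage sketch (illustrative): the name is looked up in this module's globals,
# so it must match one of the loss functions defined below.
#
#   loss_fn = create_loss('cross_entropy_loss', num_classes=10)
#   loss = loss_fn(logits, labels, kfac_mask)  # only the data arguments remain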


def cross_entropy_loss(logits, labels, kfac_mask, num_classes):
    """Cross-entropy loss function, with the registration calls KFAC-JAX needs."""
    # KFAC-JAX must be told to ignore padded data, but `mask` only zeroes those
    # entries, so a corresponding `weight` is set to correct the normalisation.
    kfac_jax.register_softmax_cross_entropy_loss(
        # Replace non-finite (padded) logits with zeros for the registration.
        jnp.where(jnp.isfinite(logits), logits, 0),
        labels,
        mask=kfac_mask,
        weight=kfac_mask.shape[0] / kfac_mask.sum())
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    # Padded entries produce NaN losses, which `nanmean` excludes from the mean.
    return jnp.nanmean(
        optax.softmax_cross_entropy(logits, one_hot_labels))
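
# Sketch of the padding convention assumed above (illustrative values): a padded
# row carries non-finite logits and a zero mask entry, so it is excluded both
# from the KFAC-JAX registration and, via NaN, from the `nanmean` reduction.
#
#   logits = jnp.array([[2.0, -1.0], [jnp.nan, jnp.nan]])  # second row padded
#   labels = jnp.array([0, 0])
#   loss = cross_entropy_loss(logits, labels, jnp.array([1.0, 0.0]), num_classes=2)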


def mse_loss(predictions, targets, kfac_mask):
    """MSE loss function, with the registration calls KFAC-JAX needs."""
    kfac_jax.register_squared_error_loss(
        predictions,
        targets,
        # As above, `weight` corrects for zero-padded entries in the batch.
        weight=kfac_mask.shape[0] / kfac_mask.sum())
    # Padded entries are assumed NaN, so `nanmean` drops them from the mean.
    return jnp.nanmean(
        optax.l2_loss(predictions, targets))
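
# Sketch (illustrative): the same padding convention is assumed, with NaN
# predictions marking padded entries and `kfac_mask` zero at those positions.
#
#   predictions = jnp.array([1.0, jnp.nan])  # second entry padded
#   targets = jnp.array([1.5, 0.0])
#   loss = mse_loss(predictions, targets, kfac_mask=jnp.array([1.0, 0.0]))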