Skip to content

Commit

Permalink
Add a module to apply updates every k steps (and accumulate them otherwise) (google#2350)
Browse files Browse the repository at this point in the history
  • Loading branch information
perolat committed Mar 10, 2020
1 parent 863576c commit 5c3b478
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
30 changes: 30 additions & 0 deletions jax/experimental/optix.py
Expand Up @@ -324,6 +324,36 @@ def update_fn(updates, state): # pylint: disable=missing-docstring
return InitUpdate(init_fn, update_fn)


# State carried between calls of `apply_every`'s update function: a step
# counter (`count`) and the running sum of gradients (`grad_acc`).
ApplyEvery = collections.namedtuple("ApplyEvery", ["count", "grad_acc"])


def apply_every(k=1):
  """Accumulate gradients and apply them every k steps.

  On k - 1 out of every k calls the emitted update is all zeros while the
  incoming gradients are summed into an accumulator; on the k-th call the
  accumulated sum is emitted, and the following call starts a new sum.

  Args:
    k: apply the update every k steps otherwise accumulate the gradients.
      Must be at least 1.

  Returns:
    An (init_fn, update_fn) tuple.

  Raises:
    ValueError: if `k` is smaller than 1.
  """
  if k < 1:
    # k == 0 would be a modulo by zero and a negative k would never emit.
    raise ValueError("apply_every requires k >= 1, got k={}.".format(k))

  def init_fn(params):
    """Returns a zero counter and a zero accumulator shaped like `params`."""
    grad_acc = tree_multimap(jnp.zeros_like, params)
    return ApplyEvery(count=jnp.zeros([], jnp.int32), grad_acc=grad_acc)

  def update_fn(updates, state):
    """Adds `updates` to the accumulator and emits it on every k-th call."""
    # Position inside the current window of k steps. The modulo is kept here
    # so that a state whose counter was not stored modulo k still works.
    c = state.count % k

    # On the first step of a window the previous sum is dropped. Selecting
    # with `where` rather than multiplying by a 0/1 mask keeps a NaN/inf left
    # in the old accumulator out of the new window (0 * nan is nan).
    restart = c == 0
    grad_acc = tree_multimap(
        lambda g, ga: jnp.where(restart, g, ga + g), updates, state.grad_acc)

    # Emit the accumulated sum on the last step of the window, zeros otherwise.
    emit = c == (k - 1)
    updates = tree_multimap(
        lambda ga: jnp.where(emit, ga, jnp.zeros_like(ga)), grad_acc)

    # The counter is stored modulo k so it stays bounded and cannot overflow
    # int32 on long runs (an overflow would shift the window for most k).
    return updates, ApplyEvery(count=(state.count + 1) % k, grad_acc=grad_acc)

  return InitUpdate(init_fn, update_fn)


### Utilities for building and using custom optimizers. ###


Expand Down
39 changes: 39 additions & 0 deletions tests/optix_test.py
Expand Up @@ -60,6 +60,45 @@ def test_sgd(self):
for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
onp.testing.assert_allclose(x, y, rtol=1e-5)

def test_apply_every(self):
  """Checks sgd behind `apply_every(k)` against plain sgd.

  Every k-th step the two parameter sets must agree; on all other steps the
  update produced by the `apply_every` chain must be zero.
  """
  # The frequency of the application of sgd
  k = 4

  # experimental/optix.py sgd
  optix_sgd_params = self.init_params
  sgd = optix.sgd(LR, 0.0)
  state_sgd = sgd.init(optix_sgd_params)

  # experimental/optix.py sgd apply every
  optix_sgd_apply_every_params = self.init_params
  sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                optix.trace(decay=0, nesterov=False),
                                optix.scale(-LR))
  state_sgd_apply_every = sgd_apply_every.init(optix_sgd_apply_every_params)
  for i in range(STEPS):
    # Apply a step of sgd
    updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
    optix_sgd_params = optix.apply_updates(optix_sgd_params, updates_sgd)

    # Apply a step of sgd_apply_every
    updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
        self.per_step_updates, state_sgd_apply_every)
    optix_sgd_apply_every_params = optix.apply_updates(
        optix_sgd_apply_every_params, updates_sgd_apply_every)
    if i % k == k-1:
      # Check equivalence. The relative tolerance matches the one used by
      # test_sgd; a very large rtol would accept almost any value.
      for x, y in zip(
          tree_leaves(optix_sgd_apply_every_params),
          tree_leaves(optix_sgd_params)):
        onp.testing.assert_allclose(x, y, atol=1e-6, rtol=1e-5)
    else:
      # Check update is zero. Every leaf is compared with zeros of its own
      # shape so that no leaf can be skipped by a shorter reference tree.
      for x in tree_leaves(updates_sgd_apply_every):
        onp.testing.assert_allclose(
            x, jnp.zeros_like(x), atol=1e-10, rtol=1e-5)

def test_adam(self):
b1, b2, eps = 0.9, 0.999, 1e-8

Expand Down

0 comments on commit 5c3b478

Please sign in to comment.