Best way of implementing spectral norm #53
Okay, quite a lot to unpack here.

First, just to frame the problem, spectral normalisation is really something being done to the weight matrix, not the linear/etc. layer it happens to be wrapped in. So I'd probably advocate wrapping the weight matrix, not the whole layer:

```python
# (assuming: import equinox as eqx; import jax.numpy as jnp)
class SpectralNorm(eqx.Module):
    weight: jnp.ndarray
    ...

    def __jax_array__(self):
        ...  # use self.weight
        return rescaled_weight

model = ...
model = eqx.tree_at(lambda m: m.foo.bar.weight, model, replace_fn=SpectralNorm)
```

Second, the `__jax_array__` method is what makes this work: whenever the wrapped weight is used inside a JAX operation it gets converted back into a regular array, so the layer it sits in never needs to know about the wrapping. However, gradients shouldn't flow into the power-iteration vectors `u` and `v` -- hence the `lax.stop_gradient`s below.

Putting the above pieces together, we end up with an implementation for the forward pass that looks like:

```python
# (additionally assuming: import functools as ft; import jax.random as jr; from jax import lax)
class SpectralNorm(eqx.Module):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __init__(self, weight, *, key, **kwargs):
        super().__init__(**kwargs)
        self.weight = weight
        u_dim, v_dim = weight.shape
        ukey, vkey = jr.split(key)
        u = jr.normal(ukey, (u_dim,))
        v = jr.normal(vkey, (v_dim,))
        self.u = ...  # Need to normalise these.
        self.v = ...  # I don't recall the exact procedure off the top of my head.

    def __jax_array__(self):
        u = lax.stop_gradient(self.u)
        v = lax.stop_gradient(self.v)
        σ = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / σ

model = ...
key = jr.PRNGKey(0)
model = eqx.tree_at(lambda m: m.foo.bar.weight, model, replace_fn=ft.partial(SpectralNorm, key=key))
```

In practice this implementation can still be improved a little bit -- it could be made to handle >2 dimensional weights, and it's usually best to run some power iterations on the initial `u` and `v` (a filled-in sketch of these elided pieces appears at the end of this comment).

On to making the updates for `u` and `v`:

```python
def _is_sn(leaf):
    return isinstance(leaf, SpectralNorm)

def _update_spectral_norm(module):
    if _is_sn(module):
        u = module.u
        v = module.v
        new_u = ...
        new_v = ...
        module = eqx.tree_at(lambda m: (m.u, m.v), module, (new_u, new_v))
    return module

def update_spectral_norm(module):
    return jax.tree_map(_update_spectral_norm, module, is_leaf=_is_sn)
```

On statefulness: yep, this is an option too. If/when I put together a library function for this I'll probably do it that way, but you're correct to be concerned about the sharp edges as an end user!

Finally, in terms of constraining a function to be 1-Lipschitz, and in case it's useful for your downstream problem: there are a few less-well-advertised alternatives to spectral norm floating around out there, often based around similar ideas. One approach is to carefully clip the entries of the matrix representing the linear map.
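The comment doesn't spell out the exact clipping rule, so the following is one standard way to make the idea concrete, not necessarily the variant the author had in mind: if every entry of an n×m matrix W is clipped to [-1/sqrt(nm), 1/sqrt(nm)], then the Frobenius norm of W is at most one, and since the spectral norm is bounded by the Frobenius norm, the map x -> W @ x is 1-Lipschitz.

```python
import jax.numpy as jnp

def clip_weight(weight):
    # Clip entries so that ||W||_F <= 1. Since ||W||_2 <= ||W||_F, the
    # linear map x -> W @ x is then 1-Lipschitz. The 1/sqrt(n*m) bound is
    # this sketch's choice, not necessarily the one the comment refers to.
    n, m = weight.shape
    bound = 1.0 / jnp.sqrt(n * m)
    return jnp.clip(weight, -bound, bound)
```

And here is the filled-in sketch promised above for the `...` placeholders: one step of the standard power iteration (as in Miyato et al.'s spectral normalisation; the details here are illustrative rather than taken from this thread). In `__init__`, normalising just means dividing each random vector by its norm; in `_update_spectral_norm`, the update is `new_u, new_v = power_iteration_step(module.weight, u, v)`.

```python
import jax.numpy as jnp

def power_iteration_step(weight, u, v, eps=1e-12):
    # Alternately push the estimates through the matrix and renormalise:
    # u approximates the leading left singular vector, v the right one,
    # and u @ weight @ v then approximates the largest singular value.
    v = weight.T @ u
    v = v / (jnp.linalg.norm(v) + eps)
    u = weight @ v
    u = u / (jnp.linalg.norm(u) + eps)
    return u, v
```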
Wow, thanks a lot for the comprehensive reply, that's more than I could have hoped for!

I completely agree with your observation that spectral normalization is not so much about the layer but rather about the weight matrix. The reason I proposed it that way is that I didn't see a straightforward way of modifying the weight matrix of a Linear layer. But that combination of `eqx.tree_at` and `__jax_array__` takes care of exactly that.

Thanks for providing the blueprint for the `update_spectral_norm` code as well. That should be more than enough to get me started! I'd be happy to provide a PR if there's interest.

Finally, thanks for the pointer to weight clipping for achieving 1-Lipschitz. I was indeed unaware of that; it looks like an interesting alternative! In my specific case, I was using spectral normalization along with a trainable scale parameter, so my gut feeling is that the benefit came from the decoupling of scale and direction rather than from the 1-Lipschitz property. But since my gut feeling has a history of being wrong, especially in the neural ODE domain, I might just as well give it a try :p
So this looked like an interesting problem, and as such I actually got distracted spending a fair chunk of time earlier today trying out an implementation for spectral norm, based on the experimental stateful operations. The result is #55, which introduces a `SpectralNorm` layer built on top of them.

What do you think?
Hi,
I'm coming to Equinox as a user of Diffrax, where I'd like to compare a Neural ODE use case I have implemented in DiffEqFlux.jl against a Diffrax implementation. In that use case it turned out to be beneficial to use spectral normalization, so I am wondering how best to implement spectral normalization as an Equinox layer.
The challenge is that spectral normalization has two stateful variables, u0 and v0, which are approximations of the left- and right-singular vectors of the weight matrix. The difference from batch normalization's statefulness is that their update rules depend only on the weights of the layer, not on the inputs to the layer. This means that the forward call can remain pure, with merely read access to u0 and v0; the update of u0 and v0 can happen outside the forward call, after the gradient descent step in the training loop.
So my current idea of how to implement this would look something like the following:
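(A sketch only -- `SpectralNormLinear` is a placeholder name and the layer details are illustrative; the point is that u0 and v0 are static fields that the forward pass merely reads:)

```python
import equinox as eqx
import jax.numpy as jnp

class SpectralNormLinear(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    u0: jnp.ndarray = eqx.static_field()
    v0: jnp.ndarray = eqx.static_field()

    def __call__(self, x):
        # The forward call only *reads* u0 and v0: estimate the spectral
        # norm as sigma ~= u0 . (W v0) and rescale the weight with it.
        σ = jnp.einsum("i,ij,j->", self.u0, self.weight, self.v0)
        return (self.weight / σ) @ x + self.bias
```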
The `update_spectral_norm` function would look something like the following:
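(Again a sketch; the power-iteration step is the standard one, and since static fields aren't pytree leaves, the layer is rebuilt rather than modified through `eqx.tree_at`:)

```python
import jax
import jax.numpy as jnp

def _is_sn(leaf):
    return isinstance(leaf, SpectralNormLinear)

def update_spectral_norm(model):
    # Poor man's multiple dispatch: traverse the model, updating only the
    # spectral-norm layers and passing every other leaf through unchanged.
    def _update(leaf):
        if _is_sn(leaf):
            # One power-iteration step, using the weights only -- no layer
            # inputs involved, so this can live outside the forward call.
            u = leaf.weight @ leaf.v0
            u = u / jnp.linalg.norm(u)
            v = leaf.weight.T @ u
            v = v / jnp.linalg.norm(v)
            # u0/v0 are static fields (not pytree leaves), so rebuild the
            # layer instead of using eqx.tree_at.
            leaf = SpectralNormLinear(weight=leaf.weight, bias=leaf.bias, u0=u, v0=v)
        return leaf
    return jax.tree_map(_update, model, is_leaf=_is_sn)
```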
This way of updating u0 and v0 outside the forward call, from within the training loop, should keep everything pure.
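For instance, usage in the training loop would be roughly the following (`make_step` is a stand-in for the usual loss/gradient/optimiser update, not defined here):

```python
for batch in dataset:
    model, opt_state = make_step(model, opt_state, batch)  # pure forward/backward
    model = update_spectral_norm(model)  # refresh u0/v0 from the new weights
```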
Could you comment on whether static_field would be enough to prevent u0 and v0 from being optimized, or whether I would also need a stop_gradient in the forward call?
Also, I'd be interested in your opinion on the poor man's multiple dispatch solution to updating the spectral normalization. I have seen that you have added some experimental way of handling statefulness. Using that could be an option as well, but given that it's explicitly marked as experimental and dangerous, I'm a bit hesitant about it.
Thanks for your help!