
Best way of implementing spectral norm #53

Closed
jaschau opened this issue Mar 29, 2022 · 3 comments · Fixed by #55

jaschau commented Mar 29, 2022

Hi,
I'm coming to Equinox as a user of Diffrax. There I'd like to compare a Neural ODE use-case I have implemented in DiffEqFlux.jl to a Diffrax implementation. In this use-case, it turned out to be profitable to use spectral normalization. So I am wondering how to best implement spectral normalization as an Equinox layer.

The challenge is that spectral normalization has two stateful variables u0, v0 which are approximations of the left- and right-singular vectors of the weight matrix. The difference to batch normalization's statefulness is that their update rules only depend on the weights of the layer, not the inputs to the layer. This means that the forward call can remain pure with merely read-access to u0 and v0. The update of u0 and v0 can happen outside of the forward call after the gradient descent step in the training loop.

for step in range(n_steps):
    # compute loss for model
    loss = compute_loss(model)
    # update spectral norm
    update_spectral_norm(model)

So my current idea how to implement this would look something like the following:

class SpectralNormalization(eqx.Module):
    # is static_field() enough to prevent back-propagation?
    layer: eqx.Module
    layer_var: str = static_field()
    u0: Array = static_field()
    v0: Array = static_field()

    def __init__(self, layer: eqx.Module, layer_var="weight"):
        self.layer = layer
        self.layer_var = layer_var
        # TODO: init u0 and v0 randomly with shape appropriate for layer's layer_var

    def __call__(self, x):
        W = getattr(self.layer, self.layer_var)
        v0 = self.v0  # does this need a stop_gradient()?
        u0 = self.u0  # does this need a stop_gradient()?
        sigma = jnp.matmul(jnp.matmul(v0, W), jnp.transpose(u0))[0, 0]
        # for a layer of the form W*x + b, i.e., linear and conv layers, this results in
        # computing (W/sigma)*x + b, which is precisely the spectral normalization
        return self.layer(x / sigma)

The update_spectral_norm function would look something like

def update_spectral_norm(model):
    # poor man's multiple dispatch
    if isinstance(model, SpectralNormalization):
        # update u0 and v0 with the power method
        ...
    else:
        pass

This way of updating u0 and v0 outside the forward call should keep everything pure.

Could you comment on whether static_field would be enough to prevent u0 and v0 from being optimized, or whether I would also need a stop_gradient in the forward call?
Also, I'd be interested in your opinion on the poor man's multiple dispatch solution for updating the spectral normalization. I have seen that you have added an experimental way of handling statefulness. Using that could be an option as well, but given that it's explicitly marked as experimental and dangerous, I'm a bit hesitant about it.
Thanks for your help!

patrick-kidger (Owner) commented Mar 29, 2022

Okay, quite a lot to unpack here.

First, just to frame the problem, spectral normalisation is really something being done to the weight matrix, not the linear/etc. layer it happens to be wrapped in. So I'd probably advocate wrapping the weight matrix, not the whole layer:

class SpectralNorm(eqx.Module):
    weight: jnp.ndarray
    ...
    
    def __jax_array__(self):
        ...  # use self.weight
        return rescaled_weight

model = ...
model = eqx.tree_at(lambda m: m.foo.bar.weight, model, replace_fn=SpectralNorm)

Second, static_field isn't the right tool for this job. This is intended for immutable metadata. The way to think about this is "if my static field changes somehow (i.e. a new instance of the Module with different metadata), then I would like to recompile any jitted methods that I hit". Here, of course, we don't want to re-jit on every step after having updated u and v.
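
For illustration (a minimal sketch of my own; the Scaled module and apply function are hypothetical names, not anything in Equinox): anything marked with static_field becomes part of the treedef rather than a leaf, so a new metadata value means a new compilation.

import jax
import jax.numpy as jnp
import equinox as eqx

class Scaled(eqx.Module):
    scale: float = eqx.static_field()  # metadata: part of the treedef, not a leaf

@jax.jit
def apply(module, x):
    return module.scale * x

x = jnp.arange(3.0)
apply(Scaled(2.0), x)  # compiles
apply(Scaled(3.0), x)  # different static metadata, so this recompiles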

However, lax.stop_gradient is the right tool for this job! This simply blocks gradient updates from occurring, and that's exactly what we want.
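
For example (a minimal, self-contained sketch; f, w and u are names made up for this snippet):

import jax
import jax.numpy as jnp
from jax import lax

def f(w, u):
    u = lax.stop_gradient(u)  # u is treated as a constant under differentiation
    return jnp.sum(w * u)

dw, du = jax.grad(f, argnums=(0, 1))(jnp.ones(3), jnp.ones(3))
# dw == [1. 1. 1.], du == [0. 0. 0.]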

Putting the above pieces together, we end up with an implementation for the forward pass that looks like:

import functools as ft

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jax import lax


class SpectralNorm(eqx.Module):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __init__(self, weight, *, key, **kwargs):
        super().__init__(**kwargs)
        self.weight = weight
        u_dim, v_dim = weight.shape
        ukey, vkey = jr.split(key)
        u = jr.normal(ukey, (u_dim,))
        v = jr.normal(vkey, (v_dim,))
        self.u = ...  # Need to normalise these.
        self.v = ...  # I don't recall the exact procedure off the top of my head.

    def __jax_array__(self):
        u = lax.stop_gradient(self.u)
        v = lax.stop_gradient(self.v)
        σ = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / σ


model = ...
key = jr.PRNGKey(0)
model = eqx.tree_at(lambda m: m.foo.bar.weight, model, replace_fn=ft.partial(SpectralNorm, key=key))

In practice this implementation can still be improved a little bit -- it could be made to handle >2 dimensional weights, and it's usually best to run some power iterations on the initial self.weight rather than leaving u, v completely random at initialisation. In both cases cf. the PyTorch implementation here.

On to making the updates for u, v. Your "poor man's multiple dispatch" is also essentially the right thing to do. You need to (a) tree_map it over your model, and (b) arrange for the updates to happen out-of-place, as in JAX most things are immutable (in particular, including eqx.Modules). So something like the following.

def _is_sn(leaf):
    return isinstance(leaf, SpectralNorm)

def _update_spectral_norm(module):
    if _is_sn(module):
        u = module.u
        v = module.v
        new_u = ...
        new_v = ...
        module = eqx.tree_at(lambda m: (m.u, m.v), module, (new_u, new_v))
    return module

def update_spectral_norm(module):
    return jax.tree_map(_update_spectral_norm, module, is_leaf=_is_sn)
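
The new_u and new_v updates are deliberately left blank above; as a rough sketch of my own (the helper names are made up, this isn't an Equinox API), a single power-iteration step consistent with the σ = einsum("i,ij,j->", u, weight, v) convention might look like:

import jax.numpy as jnp

def _l2_normalise(x, eps=1e-12):
    return x / (jnp.linalg.norm(x) + eps)

def _power_iteration_step(weight, u, v):
    # weight has shape (u_dim, v_dim); one step refines the approximations to
    # the leading left/right singular vectors of weight.
    v = _l2_normalise(weight.T @ u)
    u = _l2_normalise(weight @ v)
    return u, v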

On statefulness: yep, this is an option too. If/when I put together a library function for this I'll probably do it that way, but you're correct to be concerned about the sharp edges as an end user!

Finally, in terms of constraining a function to be 1-Lipschitz, and in case it's useful for your downstream problem: there are a few less-well-advertised alternatives to spectral norm floating around out there, often based on similar ideas. One approach is to carefully clip the entries of the weight matrix: given a linear map represented as a matrix in R^{n x m}, clipping its entries to the range [-1/m, 1/m] will do this, and in particular actually guarantees the 1-Lipschitz bound (unlike power iterations, which are an approximation scheme). In the context of neural SDEs that's done in this paper; some of the review comments also mention alternative (approximate) strategies that may be of interest to you.
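
As a concrete sketch of that clipping scheme (my own illustration; the function name is made up):

import jax.numpy as jnp

def lipschitz_clip(weight):
    # weight represents a linear map R^m -> R^n as an (n, m) matrix; clipping
    # every entry to [-1/m, 1/m] guarantees the map is 1-Lipschitz.
    m = weight.shape[-1]
    return jnp.clip(weight, -1 / m, 1 / m)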

jaschau (Author) commented Mar 29, 2022

Wow, thanks a lot for the comprehensive reply, that's more than I could have hoped for!

I completely agree with your observation that spectral normalization is not so much about the layer but rather about the weight matrix. The reason I proposed it that way is that I didn't see a straightforward way of modifying the weight matrix of a Linear layer. But that combination of eqx.tree_at and __jax_array__ looks like a neat and ingenious way of achieving that.

Thanks for providing the blueprint for the update_spectral_norm code as well. That should be more than enough to get me started! I'd be happy to provide a PR if there's interest.

Finally, thanks for the pointer to weight clipping for achieving 1-Lipschitz constraints. I was indeed unaware of that; it looks like an interesting alternative! In my specific case, I was using spectral normalization along with a trainable scale parameter, so my gut feeling is that the benefit came from decoupling scale and direction rather than from the 1-Lipschitz property. But since my gut feeling has a history of being wrong, especially in the neural ODE domain, I might just as well give it a try :p

patrick-kidger (Owner) commented

So this looked like an interesting problem, and I got distracted spending a fair chunk of time earlier today trying out an implementation of spectral norm based on the experimental stateful operations.

The result is #55, which introduces equinox.experimental.SpectralNorm. :)
(The PR also includes a few other unrelated things, but this file is the one to do with spectral normalisation.)

What do you think?
In retrospect I should have poked you in that direction instead, as it's always great to get the community involved in writing PRs!
If there are any changes you'd make, please let me know (/open a PR for those :D)
