Add Python-like apply method to Module to initialize weights and biases #61

Open

hmf opened this issue Oct 13, 2023 · 2 comments
hmf commented Oct 13, 2023

Add a weight and bias initialization method to nn.Module so that we can set these values via an apply method, as PyTorch does.

Reference to Python documentation here.
Code here.

This code is required to complete issue #51.
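For context, PyTorch's apply(fn) recursively visits every submodule and calls fn on each, including the module itself. A minimal sketch of that shape in Scala (assuming a module can enumerate its direct submodules; the childModules accessor here is illustrative, not Storch's actual API):

    trait ModuleLike:
      def childModules: Iterable[ModuleLike]

      /** Recursively apply `fn` to every submodule, then to this module. */
      def apply(fn: ModuleLike => Unit): this.type =
        childModules.foreach(_.apply(fn)) // depth-first over the module tree
        fn(this)                          // visit this module last, like PyTorch
        this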

hmf commented Oct 14, 2023

I am trying to re-implement the following Python function that initializes the values of a module's weights and biases:

    # better init, not covered in the original GPT video, but important, will cover in followup video
    self.apply(self._init_weights)  # recursively applies _init_weights to every submodule

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

After adding some additional init functions to Storch, I coded the following:

    private def init_weights[D <: FloatNN | ComplexNN](m: Module & HasWeight[D]): Unit =
      m match
        case lm: nn.Linear[?] =>
          torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
          // placeholder: LinearOptions is private, so there is currently no
          // way to ask the module whether it actually has a bias
          if true // lm.options.bias()
          then torch.nn.init.zeros_(lm.bias)
        case _: nn.Embedding[?] => ??? // TODO: init embedding weights
        case _                  => ??? // TODO: other module types

The first thing to note is that Module does not have a weight member, so I had to use HasWeight[D]. Unlike the other traits in Module, HasWeight[D] does not extend nn.Module.

The second thing to note is that we don't have a bias counterpart (adapted here from HasWeight[D]):

    trait HasBias[ParamType <: FloatNN | ComplexNN]:
      def bias: Tensor[ParamType]

The issue I now have is finding a way to test whether a Module has a bias. nn.Linear, for example, has LinearOptions that I could use, but it is private. I assume the objective is to keep it hidden to maintain an idiomatic Scala API. Moreover, not all modules have options that include a bias (Embedding, for example).

The simplest solution is a hasBias(): Boolean method. The Module trait could provide a default implementation that returns false, and any class that can have a bias would override it, consulting its options to return the correct value. A sketch of this option is below.
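Something like the following (a sketch; constructor and accessor names are assumed, not Storch's actual API):

    trait Module:
      /** Whether this module carries a bias tensor; defaults to false. */
      def hasBias(): Boolean = false

    // hypothetical module that can be built with or without a bias
    class Linear(inFeatures: Int, outFeatures: Int, addBias: Boolean = true) extends Module:
      override def hasBias(): Boolean = addBias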

Alternatively, one could add a HasBias trait with the hasBias(): Boolean method. In that case I am not sure whether overriding the method to return true is safe (does it depend on the order in which a class extends its traits?).
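For what it's worth, Scala's linearization rules should make this predictable: an explicit override in a mixed-in trait takes precedence over the default it overrides, independent of where the trait appears in the extends clause. A small self-contained example:

    trait Base:
      def hasBias(): Boolean = false

    trait WithBias extends Base:
      override def hasBias(): Boolean = true

    // linearization is Linear -> WithBias -> Base, so the override wins
    class Linear extends Base with WithBias

    @main def demo(): Unit =
      println(Linear().hasBias()) // prints: true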

Finally, we could try something fancier with type parameters so that the existence of a bias is known at compile time, but I am uncertain about this.

Any suggestions on how I should proceed?

TIA

sbrunk commented Oct 17, 2023

Sorry @hmf, I missed that somehow. I'd suggest we start with the simplest option: adding hasBias(): Boolean to Module.

Since enabling/disabling the bias is often a constructor parameter, I think it is harder to encode in the types than HasWeight. We can still improve this later if we see that it makes sense.
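With that in place, the init function above could end up looking something like this (a sketch assuming hasBias() lands on Module and that nn.Linear and nn.Embedding expose their weight as shown earlier):

    private def init_weights[D <: FloatNN | ComplexNN](m: Module & HasWeight[D]): Unit =
      m match
        case lm: nn.Linear[?] =>
          torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
          if lm.hasBias() then torch.nn.init.zeros_(lm.bias)
        case em: nn.Embedding[?] =>
          torch.nn.init.normal_(em.weight, mean = 0.0, std = 0.02)
        case _ => () // leave other module types untouched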
