[Llama] Make Llama in torchao trainable #674

Closed · gau-nernst opened this issue Aug 14, 2024 · 0 comments · Fixed by #728
Labels: enhancement (New feature or request)

gau-nernst (Collaborator) commented Aug 14, 2024

While working on #644, @msaroufim suggested using the built-in Llama for testing the mini train recipe. I looked into it, and here are the two main changes that need to be made.

  1. Initialize freqs_cis without also initializing the KV-cache and causal mask. For reference, the current setup_caches() always builds all three:

     ```python
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
         head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         dtype = self.output.weight.dtype
         # For quantized layers, dtype is encoded in scales
         if hasattr(self.output, "scales"):
             dtype = self.output.scales.dtype
         elif hasattr(self.output, "scales_and_zeros"):
             dtype = self.output.scales_and_zeros.dtype
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
     ```
  2. Don't use an attention mask; pass is_causal=True directly instead (see the sketch after this list). The current attention call is:

     ```python
     y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
     ```
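A minimal sketch of what the training-path attention call could look like, with is_causal=True and no explicit mask. The tensor shapes here are toy values for illustration only, not taken from the model:

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration: (batch, n_heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# With is_causal=True, SDPA applies causal masking internally,
# so no precomputed causal_mask buffer is needed for training.
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
```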

This will make it convenient for some of our training recipes (e.g. QAT) to have mini training scripts directly in torchao, which would also act as self-contained examples.

API-wise, I think we can add a training flag to the Transformer.setup_caches() method.

  • When training=False (default), the old behavior is maintained.
  • When training=True, only freqs_cis is initialized, and in the .forward() method we don't pass the mask to TransformerBlock/Attention (a rough sketch follows below).
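A rough sketch of how the training flag could gate the cache and mask setup, reusing the names from the existing setup_caches() shown above. This is an assumption about the API shape, not the merged implementation in #728:

```python
def setup_caches(self, max_batch_size, max_seq_length, training: bool = False):
    if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        return
    head_dim = self.config.dim // self.config.n_head
    max_seq_length = find_multiple(max_seq_length, 8)
    self.max_seq_length = max_seq_length
    self.max_batch_size = max_batch_size
    dtype = self.output.weight.dtype
    # For quantized layers, dtype is encoded in scales
    if hasattr(self.output, "scales"):
        dtype = self.output.scales.dtype
    elif hasattr(self.output, "scales_and_zeros"):
        dtype = self.output.scales_and_zeros.dtype

    if not training:
        # Inference path: keep the existing KV-cache and causal-mask setup.
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

    # freqs_cis (RoPE frequencies) is needed for both training and inference.
    self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
```

Transformer.forward() would then skip passing the mask to TransformerBlock/Attention when no causal_mask was set up, and the attention layer would call SDPA with is_causal=True as in point 2 above.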