gpt-oss model enablement #1754
base: main
Conversation
torchtitan/models/attention.py (outdated)
        block_mask = FlexAttention.block_masks[self.mask_key]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)

    def _forward_with_sink(
LGTM.
I'm curious how expensive it is to always return lse. If it is actually no cost, we can merge the FlexAttention call into the original forward.
cc @drisspg
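For context, here is a minimal sketch of what a sink-aware forward built on the returned lse could look like. It is hand-written for illustration, not the PR's code: `forward_with_sink` and `sink_logits` are made-up names, while `return_lse=True` is the stock `flex_attention` flag.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def forward_with_sink(q, k, v, sink_logits, block_mask=None, scale=None):
    # Ask flex_attention for the log-sum-exp of the attention scores in
    # addition to the output so the softmax can be re-normalized afterwards.
    out, lse = flex_attention(
        q, k, v, block_mask=block_mask, scale=scale, return_lse=True
    )
    # sink_logits: one learnable logit per head, shape [num_heads] (assumed).
    # Adding exp(sink) to the softmax denominator is equivalent to scaling
    # the output by exp(lse) / (exp(lse) + exp(sink)) = sigmoid(lse - sink).
    gate = torch.sigmoid(lse - sink_logits.view(1, -1, 1)).unsqueeze(-1)
    return out * gate.to(out.dtype)
```

If returning lse turns out to be essentially free, the same call could back both the plain and the sink path, with the gate applied only when sink logits are present.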
Branch updated from 48b2a11 to 07c0ff4.
Need to rebase onto #1776.
Looks great in general. Left some comments. It may need some rebasing onto recent and near-future development.
# TODO(jianiw): This need to be merged with expert_parallel
def expert_parallel(func: Callable) -> Callable:
Sorry, I'll merge my refactor, and then please rebase.
Are you referring to #1569?
        if use_sliding_attention:
            self.attn = build_attention(
                use_flex_attn=True,
                attn_mask_type="sliding_window",
I think `sliding_window` should be orthogonal to causal vs. block-causal. Namely, with document masking, the sliding window should only attend within single documents. This should be much easier after #1776.
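As an illustration of that composition, here is a hedged sketch using the public flex_attention mask utilities; it is not the PR's implementation, and `doc_ids`, the sizes, and the window length are made up.

```python
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

B, S, WINDOW = 2, 1024, 128
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical per-token document ids for a packed batch: two documents per sample.
doc_ids = torch.cat([
    torch.zeros(S // 2, dtype=torch.int32, device=device),
    torch.ones(S - S // 2, dtype=torch.int32, device=device),
]).expand(B, S)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx < WINDOW

def same_document(b, h, q_idx, kv_idx):
    return doc_ids[b, q_idx] == doc_ids[b, kv_idx]

# Intersect the predicates so the window never crosses a document boundary.
mask_mod = and_masks(causal, sliding_window, same_document)
block_mask = create_block_mask(mask_mod, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)
```

The resulting block_mask can then be passed to flex_attention just like the plain sliding-window mask.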
…ks but reduces mfu for 20b
Summary of current status: there are some prerequisite PRs:
Once these PRs land, I will refactor:
Keeps developing on top of #1559. Thanks @KhoomeiK for the initial contribution!
All runs are initialized from the same seed checkpoint, with seed=0 and deterministic=True (a sketch of this setup follows the run plots below).

Run 1: dp_shard = 2

Run 2: dp_shard = 2, TP degree = 2 (NGPU=4)

Run 3: dp_shard = 2, TP degree = 2, EP degree = 2 (NGPU=4)

Run 4: dp_shard = 2, TP degree = 2, EP degree = 2, ETP degree = 2 (NGPU=4)

Run 5: dp_shard = 2, EP degree = 2 (NGPU=2)
