New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pointwise fuse cat with pointwise inputs or outputs and <= 4 inputs #111233
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111233
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 80e8b4d with merge base ba04d84 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: c176b253b176f070df85c694c46cb801f7b60fd0 Pull Request resolved: #111233
… 4 inputs" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
df1305e
to
1d0e974
Compare
ghstack-source-id: 24020613e06aacfbb0156e4c3fe356bc789873a9 Pull Request resolved: #111233
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise. I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
1d0e974
to
d17f000
Compare
ghstack-source-id: 5fbf97d115940f3ccc285ff67b7887ed8c19ffc6 Pull Request resolved: #111233
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like there's some test failures, but the overall optimization makes a lot of sense to me :)
Also need some tests. I think this is a good usage of the "read/write" based tests in test_perf.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High level looks fine, but tests are failing.
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise. I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
ghstack-source-id: 42577559e3daf6bb680844f4595f7f58ac7228e8 Pull Request resolved: #111233
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise. I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
ghstack-source-id: 0fd06c835f316257531bdf883fbc427d03d462a4 Pull Request resolved: #111233
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise. I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
ghstack-source-id: 909cb647129bb9cc174abe198f934f21cb2b1436 Pull Request resolved: #111233
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/6581362653 |
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
ghstack-source-id: 31421c15e7f61c33ccb9c492db3213557df6c5cd Pull Request resolved: #111233
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
ghstack-source-id: 14451f02ae1aaf879aa4b6bffa0ddd33bfb24359 Pull Request resolved: #111233
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
… 4 inputs" Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Successfully rebased |
ghstack-source-id: baa2f9d4c0e2a75b78b0fc1b1afee0de50748f15 Pull Request resolved: #111233
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ytorch#111233) Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. Pull Request resolved: pytorch#111233 Approved by: https://github.com/Chillee
…ytorch#111233) Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. Pull Request resolved: pytorch#111233 Approved by: https://github.com/Chillee
…ytorch#111233) Improves perf of llama_v2 locally from 1.55 -> 1.57 The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was: ``` def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): iota = torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length) unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0) position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]); unsqueeze = None # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` Also not sure if I should be more worried about concatting reduction->pointwise inputs. Pull Request resolved: pytorch#111233 Approved by: https://github.com/Chillee
Stack from ghstack (oldest at bottom):
Improves perf of llama_v2 locally from 1.55 -> 1.57
The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise.
Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:
Also not sure if I should be more worried about concatting reduction->pointwise inputs.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler