Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointwise fuse cat with pointwise inputs or outputs and <= 4 inputs #111233

Closed
wants to merge 13 commits into from

Conversation

eellison
Copy link
Contributor

@eellison eellison commented Oct 13, 2023

Stack from ghstack (oldest at bottom):

Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise.

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@eellison eellison mentioned this pull request Oct 13, 2023
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111233

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 80e8b4d with merge base ba04d84 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: c176b253b176f070df85c694c46cb801f7b60fd0
Pull Request resolved: #111233
@eellison eellison changed the title Pointwise fuse cat with pointwise inputs or outputs and <= 4 inptus Pointwise fuse cat with pointwise inputs or outputs and <= 4 inputs Oct 13, 2023
… 4 inputs"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 24020613e06aacfbb0156e4c3fe356bc789873a9
Pull Request resolved: #111233
… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise.

I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 5fbf97d115940f3ccc285ff67b7887ed8c19ffc6
Pull Request resolved: #111233
Copy link
Contributor

@Chillee Chillee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like there's some test failures, but the overall optimization makes a lot of sense to me :)

Also need some tests. I think this is a good usage of the "read/write" based tests in test_perf.py.

torch/_inductor/lowering.py Outdated Show resolved Hide resolved
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High level looks fine, but tests are failing.

… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise.

I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 18, 2023
ghstack-source-id: 42577559e3daf6bb680844f4595f7f58ac7228e8
Pull Request resolved: #111233
… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise.

I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: 0fd06c835f316257531bdf883fbc427d03d462a4
Pull Request resolved: #111233
@eellison eellison requested review from jansel and Chillee and removed request for jansel and Chillee October 19, 2023 00:00
… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or InputBuffers or if all the outputs are pointwise.

I am going to do an OSS perf run, but I'd also be curious to thoughts from reviewers. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: 909cb647129bb9cc174abe198f934f21cb2b1436
Pull Request resolved: #111233
@eellison eellison requested a review from Chillee October 19, 2023 04:58
@eellison
Copy link
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict gh/eellison/555/orig returned non-zero exit code 1

Rebasing (1/2)
Rebasing (2/2)
Auto-merging torch/_inductor/graph.py
Auto-merging torch/_inductor/ir.py
Auto-merging torch/_inductor/lowering.py
CONFLICT (content): Merge conflict in torch/_inductor/lowering.py
Auto-merging torch/_inductor/utils.py
error: could not apply 4ca00b3aacb... Pointwise fuse cat with pointwise inputs or outputs and <= 4 inptus
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 4ca00b3aacb... Pointwise fuse cat with pointwise inputs or outputs and <= 4 inptus

Raised by https://github.com/pytorch/pytorch/actions/runs/6581362653

… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. 

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: 31421c15e7f61c33ccb9c492db3213557df6c5cd
Pull Request resolved: #111233
… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. 

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
eellison added a commit that referenced this pull request Oct 20, 2023
ghstack-source-id: 14451f02ae1aaf879aa4b6bffa0ddd33bfb24359
Pull Request resolved: #111233
@eellison
Copy link
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

… 4 inputs"


Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise. 

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)
    
    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Successfully rebased gh/eellison/555/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/111233)

pytorchmergebot pushed a commit that referenced this pull request Oct 20, 2023
ghstack-source-id: baa2f9d4c0e2a75b78b0fc1b1afee0de50748f15
Pull Request resolved: #111233
@eellison eellison added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 20, 2023
@eellison
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@eellison
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/eellison/555/head branch October 24, 2023 14:23
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
…ytorch#111233)

Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise.

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

Pull Request resolved: pytorch#111233
Approved by: https://github.com/Chillee
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…ytorch#111233)

Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise.

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

Pull Request resolved: pytorch#111233
Approved by: https://github.com/Chillee
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…ytorch#111233)

Improves perf of llama_v2 locally from 1.55 -> 1.57

The initial heuristic is to lower to pointwise if # of inputs is <= 4, and all the inputs are pointwise or cannot be memory planned away, or if all the outputs are pointwise.

Perf run was +3% on inference.. There are definitely instances where we should be lowering to foreach_kernels, but it's less flexible for fusion. The motivating example was:

```
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    iota =  torch.ops.prims.iota.default(512, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

    # File: /scratch/eellison/work/torchdynamo/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    unsqueeze = torch.ops.aten.unsqueeze.default(iota, 0)
    position_ids = torch.ops.aten.reshape.default(unsqueeze, [-1, 512]);  unsqueeze = None

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Also not sure if I should be more worried about concatting reduction->pointwise inputs.

Pull Request resolved: pytorch#111233
Approved by: https://github.com/Chillee
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants