fix view_copy kernel striding check logic #81553
Conversation
[ghstack-poisoned]
✅ No failures (0 pending) as of commit 2a9c14b (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. 💚 This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @bdhirsh. |
Summary:
The composite kernel that we generate for `view_copy` is special-cased a bit for efficiency, to avoid doing extra clones in some cases.
That logic was slightly wrong, and it is fixed here: it needs to mirror the logic in `reshape()`.
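For intuition, here is a small eager-mode illustration (plain PyTorch, not the generated kernel itself) of the rule that `reshape()` applies and that the generated kernel has to mirror: whether a zero-copy view is possible depends on whether a compatible set of strides can be computed for the requested shape, not on contiguity alone.

```python
import torch

# A non-contiguous tensor can still admit a zero-copy view...
t = torch.randn(4, 5)[:, :3]          # shape (4, 3), strides (5, 1)
assert not t.is_contiguous()
v = t.view(2, 2, 3)                   # succeeds: compatible strides exist
assert v.data_ptr() == t.data_ptr()   # same storage, nothing was copied

# ...while for other layouts no compatible strides exist: view() refuses,
# and reshape() falls back to making a copy.
u = t.permute(1, 0)                   # shape (3, 4), strides (1, 5)
try:
    u.view(12)
except RuntimeError:
    pass                              # "view size is not compatible ..."
r = u.reshape(12)                     # works, but materializes a copy
assert r.data_ptr() != u.data_ptr()
```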
It manifested as a debug assert firing for Lazy Tensor, which I confirmed no longer fires when running this script:
```python
# ran with "python test_ltc_only_torch.py --device=lazy --sync=1 --nvtx=1"
import torch
import torch._lazy
from torch._lazy.ts_backend import init as init_ts_backend
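# Register the TorchScript-based lazy tensor backend so device='lazy' is usable.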
init_ts_backend()
torch.manual_seed(42)
from transformers import BertForSequenceClassification
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--sync', type=bool, default=False)
    parser.add_argument('--nvtx', type=bool, default=False)
    return parser.parse_args()
args = parse_args()
device = args.device
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True)
from transformers import AdamW
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_batch = ["I love Pixar.", "I don't care for Pixar."]
encoding = tokenizer(text_batch, return_tensors='pt', padding=True, truncation=True)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
model = model.to(device)
model.train()
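# Standard BERT fine-tuning setup: exclude biases and LayerNorm weights from weight decay.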
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
labels = torch.tensor([1,0]).unsqueeze(0).to(device)
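# Each mark_step() marks a step boundary for the lazy backend and kicks off
# compilation/execution of the traced graph; wait_device_ops() blocks until the
# resulting device work finishes, so the NVTX ranges bracket real compute.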
for _ in range(6):
    torch.cuda.nvtx.range_push(f'Iter{_}')
    torch.cuda.nvtx.range_push('F')
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()
    loss = outputs.loss
    torch.cuda.nvtx.range_push('B')
    optimizer.zero_grad()
    loss.backward()
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()
    torch.cuda.nvtx.range_push('O')
    optimizer.step()
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()
    torch.cuda.nvtx.range_pop()
    torch._lazy.mark_step()
    torch._lazy.wait_device_ops()
```
Pull Request resolved: #81553
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ed0091f8db1265449f13e2bdd1647bf873bd1fea
Reviewed By: jeanschmidt
Differential Revision: D37990671
Pulled By: bdhirsh
fbshipit-source-id: 76a8292f7050e3f24cbe5bacdc6cb8c392ddd4fd
Stack from ghstack: