Conversation

@bdhirsh (Contributor) commented Jul 15, 2022

The composite kernel for `view_copy` that we generate is special-cased a bit for efficiency to avoid having to do extra clones in some cases.

That logic was slightly wrong though, and is fixed here (it needs to mirror the logic in `reshape()`).
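
For intuition: the clone can only be skipped when `reshape()` itself would already have produced a copy. Here is a minimal Python sketch of that decision, not the actual generated C++ kernel (which checks stride computability directly rather than catching the `view` error); the helper name is made up:

```
import torch

def view_copy_sketch(t: torch.Tensor, shape):
    # Rough analogue of the generated view_copy composite kernel: it must
    # always return a fresh tensor with the requested shape, never a view
    # of the input.
    try:
        # Fast path: the input's strides admit a true view, so view first
        # and clone once to materialize the copy.
        return t.view(shape).clone()
    except RuntimeError:
        # No view is possible (e.g. a non-contiguous input whose strides
        # can't be reinterpreted). reshape() already returns a fresh copy
        # here, so tacking an extra clone on top would be wasted work;
        # this is the branch that has to mirror reshape()'s behavior.
        return t.reshape(shape)
```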

The bug manifested as a debug assert firing for Lazy Tensor, which I confirmed no longer fires when running this script:

```
# ran with "python test_ltc_only_torch.py --device=lazy --sync=1 --nvtx=1"
import torch

import torch._lazy
from torch._lazy.ts_backend import init as init_ts_backend
init_ts_backend()
torch.manual_seed(42)
from transformers import BertForSequenceClassification

def parse_args():
  import argparse
  parser = argparse.ArgumentParser(description='')
  parser.add_argument('--device', type=str, default='cuda')
  parser.add_argument('--sync', type=bool, default=False)
  parser.add_argument('--nvtx', type=bool, default=False)
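  # note: argparse applies type=bool to the raw argument string, so any
  # non-empty value (even "--sync=0") parses as True; "--sync=1" just makes
  # the intent explicit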
  return parser.parse_args()

args = parse_args()

device = args.device
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True)

from transformers import AdamW
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_batch = ["I love Pixar.", "I don't care for Pixar."]
encoding = tokenizer(text_batch, return_tensors='pt', padding=True, truncation=True)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

model = model.to(device)
model.train()

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
labels = torch.tensor([1,0]).unsqueeze(0).to(device)
for _ in range(6):
  torch.cuda.nvtx.range_push(f'Iter{_}')

  torch.cuda.nvtx.range_push('F')
  outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
  if args.sync:
    torch._lazy.mark_step()
    torch._lazy.wait_device_ops()
  torch.cuda.nvtx.range_pop()

  loss = outputs.loss

  torch.cuda.nvtx.range_push('B')
  optimizer.zero_grad()
  loss.backward()
  if args.sync:
    torch._lazy.mark_step()
    torch._lazy.wait_device_ops()
  torch.cuda.nvtx.range_pop()

  torch.cuda.nvtx.range_push('O')
  optimizer.step()
  if args.sync:
    torch._lazy.mark_step()
    torch._lazy.wait_device_ops()
  torch.cuda.nvtx.range_pop()

  torch.cuda.nvtx.range_pop()
torch._lazy.mark_step()
torch._lazy.wait_device_ops()
```

Stack from ghstack:

@facebook-github-bot (Contributor) commented Jul 15, 2022

✅ No Failures (0 Pending)

As of commit 2a9c14b (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

bdhirsh added a commit that referenced this pull request Jul 15, 2022
ghstack-source-id: 233e3e3
Pull Request resolved: #81553
@bdhirsh requested a review from ezyang July 15, 2022 14:59
@bdhirsh (Contributor, Author) commented Jul 19, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator):
@pytorchbot successfully started a merge job. Check the current status here

@github-actions (Contributor):
Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jul 20, 2022
Summary: (same as the PR description and repro script above)

Pull Request resolved: #81553
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ed0091f8db1265449f13e2bdd1647bf873bd1fea

Reviewed By: jeanschmidt

Differential Revision: D37990671

Pulled By: bdhirsh

fbshipit-source-id: 76a8292f7050e3f24cbe5bacdc6cb8c392ddd4fd
@facebook-github-bot deleted the gh/bdhirsh/276/head branch July 22, 2022 14:19