From e72c11271dbb1fe215e9f1615ae8e2aa5d393cc9 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sun, 13 Nov 2022 06:51:29 +0000
Subject: [PATCH 1/2] Enable E2E hf_bert dynamo+inductor train/eval

---
 torchbenchmark/e2e_models/hf_bert/__init__.py | 48 +++++++++++++++----
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/torchbenchmark/e2e_models/hf_bert/__init__.py b/torchbenchmark/e2e_models/hf_bert/__init__.py
index 663821c339..1cce6c66e3 100644
--- a/torchbenchmark/e2e_models/hf_bert/__init__.py
+++ b/torchbenchmark/e2e_models/hf_bert/__init__.py
@@ -1,5 +1,6 @@
 from accelerate.utils.dataclasses import DeepSpeedPlugin
 import torch
+import torch._dynamo
 import math
 import os
 from pathlib import Path
@@ -212,14 +213,11 @@ def train(self) -> Optional[dict]:
         for _epoch in range(self.hf_args.num_train_epochs):
             self.model.train()
             for step, batch in enumerate(self.train_dataloader):
-                outputs = self.model(**batch)
-                loss = outputs.loss
+                loss = self.run_forward(batch)
                 loss = loss / self.hf_args.gradient_accumulation_steps
-                self.accelerator.backward(loss)
+                self.run_backward(loss)
                 if step % self.hf_args.gradient_accumulation_steps == 0 or step == len(self.train_dataloader) - 1:
-                    self.optimizer.step()
-                    self.lr_scheduler.step()
-                    self.optimizer.zero_grad()
+                    self.run_optimizer_step()
                     completed_steps += 1
 
                 if completed_steps >= self.hf_args.max_train_steps:
@@ -227,7 +225,7 @@
             if self.tb_args.validate_in_train:
                 self.model.eval()
                 for step, batch in enumerate(self.eval_dataloader):
-                    outputs = self.model(**batch)
+                    outputs = self.run_eval(batch)
                     predictions = outputs.logits.argmax(dim=-1) if not self.is_regression else outputs.logits.squeeze()
                     self.metric.add_batch(
                         predictions=self.accelerator.gather(predictions),
@@ -245,7 +243,7 @@
 
             self.model.eval()
             for step, batch in enumerate(eval_dataloader):
-                outputs = self.model(**batch)
+                outputs = self.run_eval(batch)
                 predictions = outputs.logits.argmax(dim=-1)
                 self.metric.add_batch(
                     predictions=self.accelerator.gather(predictions),
@@ -259,7 +257,7 @@ def eval(self) -> Optional[dict]:
         self.model.eval()
         for _step, batch in enumerate(self.eval_dataloader):
             with torch.no_grad():
-                outputs = self.model(**batch)
+                outputs = self.run_eval(batch)
             predictions = outputs.logits.argmax(dim=-1) if not self.is_regression else outputs.logits.squeeze()
             self.metric.add_batch(
                 predictions=self.accelerator.gather(predictions),
@@ -275,11 +273,43 @@ def run_forward(self, input):
         """
         compute model forward and return loss
         """
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_forward)(input)
+        else:
+            return self._run_forward(input)
+
+    def _run_forward(self, input):
         return self.model(**input).loss
 
     def run_backward(self, loss):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_backward)(loss)
+        else:
+            return self._run_backward(loss)
+
+    def _run_backward(self, loss):
         self.accelerator.backward(loss)
 
     def run_optimizer_step(self):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_optimizer_step)()
+        else:
+            return self._run_optimizer_step()
+
+    def _run_optimizer_step(self):
         self.optimizer.step()
+        self.lr_scheduler.step()
+        self.optimizer.zero_grad()
 
+    def run_eval(self, input):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_eval)(input)
+        else:
+            return self._run_eval(input)
+
+    def _run_eval(self, input):
+        return self.model(**input)

From 5fc0e614c31255067b0733e851a610e90ae43f5a Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 16 Nov 2022 21:00:30 +0000
Subject: [PATCH 2/2] Wrap import torch._dynamo in try ... except ...

---
 torchbenchmark/e2e_models/hf_bert/__init__.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchbenchmark/e2e_models/hf_bert/__init__.py b/torchbenchmark/e2e_models/hf_bert/__init__.py
index 1cce6c66e3..8f61684337 100644
--- a/torchbenchmark/e2e_models/hf_bert/__init__.py
+++ b/torchbenchmark/e2e_models/hf_bert/__init__.py
@@ -1,6 +1,5 @@
 from accelerate.utils.dataclasses import DeepSpeedPlugin
 import torch
-import torch._dynamo
 import math
 import os
 from pathlib import Path
@@ -24,6 +23,11 @@
 from torchbenchmark.util.framework.transformers.text_classification.dataset import prep_dataset, preprocess_dataset, prep_labels
 from torchbenchmark.util.framework.transformers.text_classification.args import parse_args, parse_torchbench_args
 
+try:
+    import torch._dynamo
+except ImportError:
+    pass
+
 # setup environment variable
 CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
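
Note: taken together, the two commits apply one pattern — guard the
torch._dynamo import (it only exists in newer PyTorch builds), then
conditionally wrap each hot function with torch._dynamo.optimize(backend).
Below is a minimal standalone sketch of that pattern; the helper name
maybe_optimize, the toy Linear model, and the "eager" backend choice are
illustrative assumptions, not part of the benchmark code.

    import torch

    # Guard the import: older PyTorch builds do not ship torch._dynamo.
    try:
        import torch._dynamo
        HAS_DYNAMO = True
    except ImportError:
        HAS_DYNAMO = False

    def maybe_optimize(fn, backend):
        # Compile fn with TorchDynamo when available; otherwise run eager.
        if HAS_DYNAMO and backend is not None:
            return torch._dynamo.optimize(backend)(fn)
        return fn

    model = torch.nn.Linear(8, 2)

    def forward_loss(x):
        return model(x).sum()

    # "eager" is dynamo's pass-through backend; the benchmark would pass
    # the backend selected via self.opt_args.torchdynamo instead.
    compiled = maybe_optimize(forward_loss, backend="eager")
    loss = compiled(torch.randn(4, 8))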