52 changes: 43 additions & 9 deletions torchbenchmark/e2e_models/hf_bert/__init__.py
@@ -23,6 +23,11 @@
 from torchbenchmark.util.framework.transformers.text_classification.dataset import prep_dataset, preprocess_dataset, prep_labels
 from torchbenchmark.util.framework.transformers.text_classification.args import parse_args, parse_torchbench_args
 
+try:
+    import torch._dynamo
+except ImportError:
+    pass
+
 # setup environment variable
 CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
 
@@ -212,22 +217,19 @@ def train(self) -> Optional[dict]:
         for _epoch in range(self.hf_args.num_train_epochs):
             self.model.train()
             for step, batch in enumerate(self.train_dataloader):
-                outputs = self.model(**batch)
-                loss = outputs.loss
+                loss = self.run_forward(batch)
                 loss = loss / self.hf_args.gradient_accumulation_steps
-                self.accelerator.backward(loss)
+                self.run_backward(loss)
                 if step % self.hf_args.gradient_accumulation_steps == 0 or step == len(self.train_dataloader) - 1:
-                    self.optimizer.step()
-                    self.lr_scheduler.step()
-                    self.optimizer.zero_grad()
+                    self.run_optimizer_step()
                     completed_steps += 1
 
                 if completed_steps >= self.hf_args.max_train_steps:
                     break
             if self.tb_args.validate_in_train:
                 self.model.eval()
                 for step, batch in enumerate(self.eval_dataloader):
-                    outputs = self.model(**batch)
+                    outputs = self.run_eval(batch)
                     predictions = outputs.logits.argmax(dim=-1) if not self.is_regression else outputs.logits.squeeze()
                     self.metric.add_batch(
                         predictions=self.accelerator.gather(predictions),
@@ -245,7 +247,7 @@ def train(self) -> Optional[dict]:
 
             self.model.eval()
             for step, batch in enumerate(eval_dataloader):
-                outputs = self.model(**batch)
+                outputs = self.run_eval(batch)
                 predictions = outputs.logits.argmax(dim=-1)
                 self.metric.add_batch(
                     predictions=self.accelerator.gather(predictions),
@@ -259,7 +261,7 @@ def eval(self) -> Optional[dict]:
         self.model.eval()
         for _step, batch in enumerate(self.eval_dataloader):
             with torch.no_grad():
-                outputs = self.model(**batch)
+                outputs = self.run_eval(batch)
             predictions = outputs.logits.argmax(dim=-1) if not self.is_regression else outputs.logits.squeeze()
             self.metric.add_batch(
                 predictions=self.accelerator.gather(predictions),
@@ -275,11 +277,43 @@ def run_forward(self, input):
         """
         compute model forward and return loss
         """
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_forward)(input)
+        else:
+            return self._run_forward(input)
 
+    def _run_forward(self, input):
         return self.model(**input).loss
 
     def run_backward(self, loss):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_backward)(loss)
+        else:
+            return self._run_backward(loss)
+
+    def _run_backward(self, loss):
         self.accelerator.backward(loss)
 
     def run_optimizer_step(self):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_optimizer_step)()
+        else:
+            return self._run_optimizer_step()
+
+    def _run_optimizer_step(self):
         self.optimizer.step()
+        self.lr_scheduler.step()
+        self.optimizer.zero_grad()
+
+    def run_eval(self, input):
+        if self.dynamo:
+            backend = self.opt_args.torchdynamo
+            return torch._dynamo.optimize(backend)(self._run_eval)(input)
+        else:
+            return self._run_eval(input)
+
+    def _run_eval(self, input):
+        return self.model(**input)
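
For readers skimming the diff: each public run_* method now dispatches through torch._dynamo.optimize when a TorchDynamo backend was requested (self.dynamo), and falls back to the plain eager helper otherwise. Below is a minimal, self-contained sketch of that dispatch pattern — it is not part of the diff; the toy Linear model, the StepRunner name, and the "eager" backend are illustrative stand-ins for the benchmark's real model and its configured opt_args.torchdynamo backend.

# Standalone sketch of the dynamo-dispatch pattern (assumes a PyTorch
# build that ships torch._dynamo; names here are illustrative).
import torch

try:
    import torch._dynamo
    HAS_DYNAMO = True
except ImportError:
    HAS_DYNAMO = False

class StepRunner:
    def __init__(self, use_dynamo: bool, backend: str = "eager"):
        self.model = torch.nn.Linear(4, 2)  # placeholder for the real model
        self.use_dynamo = use_dynamo and HAS_DYNAMO
        self.backend = backend

    def run_forward(self, x):
        # Compile the eager implementation only when a dynamo backend
        # was requested; otherwise call it directly, as in the diff.
        if self.use_dynamo:
            return torch._dynamo.optimize(self.backend)(self._run_forward)(x)
        return self._run_forward(x)

    def _run_forward(self, x):
        return self.model(x).sum()

runner = StepRunner(use_dynamo=True)
loss = runner.run_forward(torch.randn(8, 4))
loss.backward()

Splitting each step into a public dispatcher and a private _run_* body gives dynamo a stable callable to compile; since dynamo caches compiled code on the target function's code object, re-wrapping it on every training step should reuse the cached artifact rather than recompile.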