Skip to content

Commit

Permalink
Merge 9fba12c into 21b87dc
Browse files Browse the repository at this point in the history
  • Loading branch information
undertherain committed Jun 19, 2020
2 parents 21b87dc + 9fba12c commit e14ac9d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion benchmarker/modules/problems/bert/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ def get_data(params):
params["problem"]["len_sequence"],
)
X = np.random.random(shape).astype(np.int64)
Y = np.ones((cnt_batches, params["batch_size"]))
Y = np.ones((cnt_batches, params["batch_size"]), dtype=np.int64)
return X, Y
13 changes: 12 additions & 1 deletion test/pytorch/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PytorchBertTest(unittest.TestCase):
def setUp(self):
self.name = "benchmarker"

def test_bert(self):
def test_bert_inference(self):
run_module(
self.name,
"--framework=pytorch",
Expand All @@ -20,3 +20,14 @@ def test_bert(self):
"--nb_epoch=1",
"--mode=inference",
)

def test_bert_training(self):
run_module(
self.name,
"--framework=pytorch",
"--problem=bert",
"--problem_size=4,8",
"--batch_size=2",
"--nb_epoch=1",
"--mode=training",
)

0 comments on commit e14ac9d

Please sign in to comment.