Skip to content

Commit

Permalink
Force a sync on non-CPU tensors for the benchmark to reflect the timing accurately.
Browse files Browse the repository at this point in the history

ghstack-source-id: 5c8e310984bf160719379dce88cdc697bed82241
Pull Request resolved: #47714
  • Loading branch information
Ashkan Aliabadi committed Dec 4, 2020
1 parent cb28508 commit bac0c75
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions binaries/speed_benchmark_torch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ int main(int argc, char** argv) {
FLAGS_warmup,
".");
for (int i = 0; i < FLAGS_warmup; ++i) {
module.forward(inputs);
if (FLAGS_vulkan) {
module.forward(inputs).toTensor().cpu();
} else {
module.forward(inputs);
}
}

std::cout << "Main runs." << std::endl;
Expand All @@ -231,7 +235,11 @@ int main(int argc, char** argv) {
auto micros = timer.MicroSeconds();
for (int i = 0; i < FLAGS_iter; ++i) {
auto start = high_resolution_clock::now();
module.forward(inputs);
if (FLAGS_vulkan) {
module.forward(inputs).toTensor().cpu();
} else {
module.forward(inputs);
}
auto stop = high_resolution_clock::now();
auto duration = duration_cast<microseconds>(stop - start);
times.push_back(duration.count());
Expand Down

0 comments on commit bac0c75

Please sign in to comment.