diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 2397bcfb851..e2d46f8e8d0 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -42,7 +42,7 @@ protected void onCreate(Bundle savedInstanceState) { .findFirst() .get(); - int numIter = intent.getIntExtra("num_iter", 10); + int numIter = intent.getIntExtra("num_iter", 50); // TODO: Format the string with a parsable format Stats stats = new Stats(); diff --git a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm index ce685335767..f6c6927e78e 100644 --- a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm +++ b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm @@ -85,7 +85,10 @@ @implementation GenericTests XCTFail("Unsupported tag %i at input %d", *input_tag, index); } } + XCTMeasureOptions *options = [[XCTMeasureOptions alloc] init]; + options.iterationCount = 20; [testCase measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ] + options:options block:^{ XCTAssertEqual(module->forward().error(), Error::Ok); }];