Skip to content

Commit

Permalink
[iOS][OSS][BE] Add simulator tests for Metal
Browse files Browse the repository at this point in the history
Pull Request resolved: #64852


ghstack-source-id: 137849299

Differential Revision: [D30877961](https://our.internmc.facebook.com/intern/diff/D30877961/)
  • Loading branch information
xta0 committed Sep 11, 2021
1 parent f75482e commit 059a673
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .circleci/config.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions .circleci/verbatim-sources/job-specs/job-specs-custom.yml
Expand Up @@ -574,6 +574,8 @@
instruments -s -devices
if [ $(BUILD_LITE_INTERPRETER) == 1 ]; then
fastlane scan --only_testing TestAppTests/TestAppTests/testLiteInterpreter
elif [ $(USE_PYTORCH_METAL) == 1 ]; then
fastlane scan --only_testing TestAppTests/TestAppTests/testMetal
else
fastlane scan --only_testing TestAppTests/TestAppTests/testFullJIT
fi
Expand Down
10 changes: 10 additions & 0 deletions ios/TestApp/TestAppTests/TestAppTests.mm
Expand Up @@ -31,4 +31,14 @@ - (void)testFullJIT {
XCTAssertTrue(outputTensor.numel() == 1000);
}

- (void)testMetal {
NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_metal"
ofType:@"ptl"];
auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
c10::InferenceMode mode;
auto input = torch::ones({1, 3, 224, 224}, at::kFloat).metal();
auto outputTensor = module.forward({input}).toTensor().cpu();
XCTAssertTrue(outputTensor.numel() == 1000);
}

@end
12 changes: 12 additions & 0 deletions ios/TestApp/benchmark/setup.rb
Expand Up @@ -67,6 +67,7 @@
end
puts "Linking static libraries..."
libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate']
targets.each do |target|
target.frameworks_build_phases.clear
for lib in libs do
Expand All @@ -78,5 +79,16 @@
end
end

# link system frameworks
if frameworks
frameworks.each do |framework|
path = "System/Library/Frameworks/#{framework}.framework"
framework_ref = project.frameworks_group.new_reference(path)
framework_ref.name = "#{framework}.framework"
framework_ref.source_tree = 'SDKROOT'
target.frameworks_build_phases.add_file_reference(framework_ref)
end
end

project.save
puts "Done."
4 changes: 3 additions & 1 deletion ios/TestApp/benchmark/trace_model.py
Expand Up @@ -8,4 +8,6 @@
traced_script_module = torch.jit.trace(model, example)
optimized_scripted_module = optimize_for_mobile(traced_script_module)
torch.jit.save(optimized_scripted_module, '../models/model.pt')
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter("../models/model_lite.ptl")
optimized_scripted_module._save_for_lite_interpreter("../models/model_lite.ptl")
optimized_scripted_module = optimize_for_mobile(traced_script_module, backend='metal')
optimized_scripted_module._save_for_lite_interpreter("../models/model_metal.ptl")

0 comments on commit 059a673

Please sign in to comment.