3
3
TEST_P (CppAPITests, CompiledModuleIsClose) {
4
4
std::vector<torch::jit::IValue> jit_inputs_ivalues;
5
5
std::vector<torch::jit::IValue> trt_inputs_ivalues;
6
- for (auto in_shape : input_shapes) {
7
- auto in = at::randint (5 , in_shape, {at::kCUDA });
6
+ std::vector<torch_tensorrt::Input> shapes;
7
+ for (uint64_t i = 0 ; i < input_shapes.size (); i++) {
8
+ auto in = at::randint (5 , input_shapes[i], {at::kCUDA }).to (input_types[i]);
8
9
jit_inputs_ivalues.push_back (in.clone ());
9
10
trt_inputs_ivalues.push_back (in.clone ());
11
+ auto in_spec = torch_tensorrt::Input (input_shapes[i]);
12
+ in_spec.dtype = input_types[i];
13
+ shapes.push_back (in_spec);
14
+ std::cout << in_spec << std::endl;
10
15
}
11
16
12
17
torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward (mod, jit_inputs_ivalues);
13
18
std::vector<at::Tensor> jit_results;
14
- jit_results.push_back (jit_results_ivalues.toTensor ());
19
+ if (jit_results_ivalues.isTuple ()) {
20
+ auto tuple = jit_results_ivalues.toTuple ();
21
+ for (auto t : tuple->elements ()) {
22
+ jit_results.push_back (t.toTensor ());
23
+ }
24
+ } else {
25
+ jit_results.push_back (jit_results_ivalues.toTensor ());
26
+ }
27
+
28
+ auto spec = torch_tensorrt::ts::CompileSpec (shapes);
29
+ spec.truncate_long_and_double = true ;
15
30
16
- auto trt_mod = torch_tensorrt::ts::compile (mod, input_shapes );
31
+ auto trt_mod = torch_tensorrt::ts::compile (mod, spec );
17
32
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward (trt_mod, trt_inputs_ivalues);
18
33
std::vector<at::Tensor> trt_results;
19
- trt_results.push_back (trt_results_ivalues.toTensor ());
34
+ if (trt_results_ivalues.isTuple ()) {
35
+ auto tuple = trt_results_ivalues.toTuple ();
36
+ for (auto t : tuple->elements ()) {
37
+ trt_results.push_back (t.toTensor ());
38
+ }
39
+ } else {
40
+ trt_results.push_back (trt_results_ivalues.toTensor ());
41
+ }
20
42
21
43
for (size_t i = 0 ; i < trt_results.size (); i++) {
22
44
ASSERT_TRUE (
@@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P(
30
52
CompiledModuleForwardIsCloseSuite,
31
53
CppAPITests,
32
54
testing::Values (
33
- PathAndInSize ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
34
- PathAndInSize({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
35
- PathAndInSize({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
36
- PathAndInSize({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
37
- PathAndInSize({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
38
- PathAndInSize({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
39
- PathAndInSize({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 8e-3 }),
40
- PathAndInSize({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 8e-2 })));
55
+ PathAndInput ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
56
+ PathAndInput({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
57
+ PathAndInput({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
58
+ PathAndInput({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
59
+ PathAndInput({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
60
+ PathAndInput({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
61
+ PathAndInput({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 8e-3 }),
62
+ PathAndInput({" tests/modules/bert_base_uncased_traced.jit.pt" , {{1 , 14 }, {1 , 14 }}, {at::kInt , at::kInt }, 8e-2 }),
63
+ PathAndInput({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 8e-2 })));
41
64
42
65
#endif
0 commit comments