@@ -21,7 +21,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) {
2121 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
2222 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
2323
24- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
24+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
2525}
2626
2727TEST (Converters, ATenSigmoidConvertsCorrectly) {
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
4141 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
4242 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
4343
44- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
44+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
4545}
4646
4747TEST (Converters, ATenTanhConvertsCorrectly) {
@@ -61,5 +61,51 @@ TEST(Converters, ATenTanhConvertsCorrectly) {
6161 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
6262 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
6363
64- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
64+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
6565}
66+
67+ // TODO: Seems like the IR parser is not handling negative numbers well, need to follow up with the PyTorch Team
68+ // TEST(Converters, ATenHardTanhConvertsCorrectly) {
69+ // const auto graph = R"IR(
70+ // graph(%0 : Tensor):
71+ // %1 : float = prim::Constant[value=-1.0]()
72+ // %2 : float = prim::Constant[value=1.0]()
73+ // %3 : Tensor = aten::hardtanh(%0, %1, %2)
74+ // return (%3))IR";
75+
76+ // auto g = std::make_shared<torch::jit::Graph>();
77+ // torch::jit::script::parseIR(graph, &*g);
78+
79+ // auto in = at::randint(-5, 5, {5}, {at::kCUDA});
80+ // auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
81+ // auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
82+
83+ // in = at::clone(in);
84+ // params = trtorch::core::conversion::get_named_params(g->inputs(), {});
85+ // auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
86+
87+ // ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
88+ // }
89+
90+ TEST (Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
91+ const auto graph = R"IR(
92+ graph(%0 : Tensor):
93+ %1 : float = prim::Constant[value=0.0]()
94+ %2 : float = prim::Constant[value=6.0]()
95+ %3 : Tensor = aten::hardtanh(%0, %1, %2)
96+ return (%3))IR" ;
97+
98+ auto g = std::make_shared<torch::jit::Graph>();
99+ torch::jit::script::parseIR (graph, &*g);
100+
101+ auto in = at::randint (-5 , 5 , {5 }, {at::kCUDA });
102+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
103+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
104+
105+ in = at::clone (in);
106+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
107+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
108+
109+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
110+ }
111+
0 commit comments