Closed
Labels: bug (Something isn't working), bug: triaged [verified] (We can replicate the bug), component: partitioning
Description
Bug Description
/*
The following test is ambiguous and somehow works in TRT 8.2, which might have a bug.
This FP16 model has its inputs and weights configured to be FP16, but the builder
precision is set to FP32. During shape analysis, when the PyTorch/TRT segments are run
as PyTorch modules, the inputs of each segment are recorded as FP16. After TRT
conversion and inference, however, the TRT segments generate FP32 outputs, which become
FP32 inputs to the following segments. Hence the type check fails at runtime at
https://github.com/pytorch/TensorRT/blob/master/core/runtime/execute_engine.cpp#L91
TODO: Resolve the type system check in partitioning
*/
TEST(Partitioning, ComputeResNet50HalfFallbackGraphCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return;
  }
  mod.to(torch::kHalf);

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (auto in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA}).to(torch::kHalf);
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  auto in_shape = torch_tensorrt::core::ir::Input({1, 3, 224, 224});
  in_shape.dtype = nvinfer1::DataType::kHALF;
  std::vector<torch_tensorrt::core::ir::Input> input_ranges({in_shape});

  auto g = mod.get_method("forward").graph();
  torch_tensorrt::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
  cfg.partition_info.forced_fallback_operators.push_back("aten::add");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();

  // Lower threshold because FP16
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-1));
}
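To make the failure mode concrete, here is a minimal, library-free Python sketch (all function names hypothetical, not part of the Torch-TensorRT API) of how the dtypes recorded during shape analysis diverge from what the converted engine actually produces:

```python
# Hypothetical sketch of the mismatch described above: shape analysis records
# every segment's inputs as fp16 (the module was cast to half), but a TRT
# segment built with fp32 builder precision emits fp32 outputs, so the
# runtime type check on the *next* segment's inputs fails.

def run_trt_segment(inputs, builder_precision="fp32"):
    # Engine outputs follow the builder precision, not the incoming dtype.
    return [builder_precision for _ in inputs]

def check_input_types(expected, actual):
    # Analogous to the runtime check in execute_engine.cpp: each actual
    # input dtype must match the dtype recorded at compile time.
    return all(e == a for e, a in zip(expected, actual))

# Shape analysis recorded fp16 inputs for segment 2...
expected_seg2_inputs = ["fp16"]
# ...but segment 1 (a TRT engine built at fp32 precision) feeds it fp32.
actual_seg2_inputs = run_trt_segment(["fp16"], builder_precision="fp32")

print(check_input_types(expected_seg2_inputs, actual_seg2_inputs))  # False
```

The sketch shows why the check only fails at the segment *boundary*: each segment is internally consistent, but the dtype contract between segments was fixed during shape analysis and is not updated after conversion.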
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context