diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc
index ea86624..397d439 100644
--- a/src/onnxruntime.cc
+++ b/src/onnxruntime.cc
@@ -682,10 +682,36 @@ ModelState::AutoCompleteConfig()
   OrtAllocator* default_allocator;
   std::string model_path;
   {
+    TRITONSERVER_InstanceGroupKind kind = TRITONSERVER_INSTANCEGROUPKIND_CPU;
+
+#ifdef TRITON_ENABLE_GPU
+    triton::common::TritonJson::Value instance_group;
+    ModelConfig().Find("instance_group", &instance_group);
+
+    // Earlier in the model lifecycle, device checks for the instance group
+    // have already occurred. If at least one instance group with
+    // "kind" = "KIND_GPU" then allow model to use GPU else autocomplete to
+    // "KIND_CPU"
+    for (size_t i = 0; i < instance_group.ArraySize(); ++i) {
+      triton::common::TritonJson::Value instance_obj;
+      instance_group.IndexAsObject(i, &instance_obj);
+
+      triton::common::TritonJson::Value instance_group_kind;
+      instance_obj.Find("kind", &instance_group_kind);
+      std::string kind_str;
+      RETURN_IF_ERROR(instance_group_kind.AsString(&kind_str));
+
+      if (kind_str == "KIND_GPU") {
+        kind = TRITONSERVER_INSTANCEGROUPKIND_GPU;
+        break;
+      }
+    }
+#endif  // TRITON_ENABLE_GPU
+
     OrtSession* sptr = nullptr;
     RETURN_IF_ERROR(LoadModel(
-        artifact_name, TRITONSERVER_INSTANCEGROUPKIND_AUTO, 0, &model_path,
-        &sptr, &default_allocator, nullptr));
+        artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
+        nullptr));
     session.reset(sptr);
   }
   OnnxTensorInfoMap input_tensor_infos;
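
For context, the autocomplete path above is driven by the model's instance_group setting. As a purely illustrative sketch (not taken from this PR), a config.pbtxt fragment like the following would cause the loop to find "KIND_GPU" and load the temporary session with TRITONSERVER_INSTANCEGROUPKIND_GPU, while a config whose groups are all KIND_CPU (or one with no instance_group entry) is expected to keep the CPU default:

    # hypothetical config.pbtxt fragment for illustration only
    instance_group [
      {
        count: 1
        kind: KIND_GPU
      }
    ]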