Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/onnxruntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,36 @@ ModelState::AutoCompleteConfig()
OrtAllocator* default_allocator;
std::string model_path;
{
TRITONSERVER_InstanceGroupKind kind = TRITONSERVER_INSTANCEGROUPKIND_CPU;

#ifdef TRITON_ENABLE_GPU
triton::common::TritonJson::Value instance_group;
ModelConfig().Find("instance_group", &instance_group);

// Earlier in the model lifecycle, device checks for the instance group
// have already occurred. If at least one instance group with
// "kind" = "KIND_GPU" then allow model to use GPU else autocomplete to
// "KIND_CPU"
for (size_t i = 0; i < instance_group.ArraySize(); ++i) {
triton::common::TritonJson::Value instance_obj;
instance_group.IndexAsObject(i, &instance_obj);

triton::common::TritonJson::Value instance_group_kind;
instance_obj.Find("kind", &instance_group_kind);
std::string kind_str;
RETURN_IF_ERROR(instance_group_kind.AsString(&kind_str));

if (kind_str == "KIND_GPU") {
kind = TRITONSERVER_INSTANCEGROUPKIND_GPU;
break;
}
}
#endif // TRITON_ENABLE_GPU

OrtSession* sptr = nullptr;
RETURN_IF_ERROR(LoadModel(
artifact_name, TRITONSERVER_INSTANCEGROUPKIND_AUTO, 0, &model_path,
&sptr, &default_allocator, nullptr));
artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
nullptr));
session.reset(sptr);
}
OnnxTensorInfoMap input_tensor_infos;
Expand Down