Skip to content

Commit

Permalink
Fix logic for concatenating Treelite objects (#5387)
Browse files Browse the repository at this point in the history
Closes #5359

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - William Hicks (https://github.com/wphicks)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: #5387
  • Loading branch information
hcho3 committed May 16, 2023
1 parent b03f9f1 commit 7af44d4
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cpp/src/randomforest/randomforest.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -396,21 +396,25 @@ void compare_concat_forest_to_subforests(ModelHandle concat_tree_handle,
*/
ModelHandle concatenate_trees(std::vector<ModelHandle> treelite_handles)
{
tl::Model& first_model = *(tl::Model*)treelite_handles[0];
/* TODO(hcho3): Use treelite::ConcatenateModelObjects(),
once https://github.com/dmlc/treelite/issues/474 is fixed. */
if (treelite_handles.empty()) { return nullptr; }
tl::Model& first_model = *static_cast<tl::Model*>(treelite_handles[0]);
tl::Model* concat_model = first_model.Dispatch([&treelite_handles](auto& first_model_inner) {
// first_model_inner is of the concrete type tl::ModelImpl<T, L>
using model_type = std::remove_reference_t<decltype(first_model_inner)>;
auto* concat_model = dynamic_cast<model_type*>(
tl::Model::Create(first_model_inner.GetThresholdType(), first_model_inner.GetLeafOutputType())
.release());
for (std::size_t forest_idx = 0; forest_idx < treelite_handles.size(); forest_idx++) {
tl::Model& model = *(tl::Model*)treelite_handles[forest_idx];
tl::Model& model = *static_cast<tl::Model*>(treelite_handles[forest_idx]);
auto& model_inner = dynamic_cast<model_type&>(model);
for (const auto& tree : model_inner.trees) {
concat_model->trees.push_back(tree.Clone());
}
}
concat_model->num_feature = first_model_inner.num_feature;
concat_model->task_type = first_model_inner.task_type;
concat_model->task_param = first_model_inner.task_param;
concat_model->average_tree_output = first_model_inner.average_tree_output;
concat_model->param = first_model_inner.param;
Expand Down

0 comments on commit 7af44d4

Please sign in to comment.