Skip to content

Commit

Permalink
extend to to other label types
Browse files Browse the repository at this point in the history
  • Loading branch information
Saurabh7 committed Jun 30, 2016
1 parent aed17e6 commit 9acf8d4
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/shogun/features/DenseFeatures.cpp
Expand Up @@ -642,7 +642,7 @@ CFeatures* CDenseFeatures<ST>::shallow_subset_copy()
CFeatures* shallow_copy_features=nullptr;

SG_SDEBUG("Using underlying feature matrix with %d dimensions and %d feature vectors!\n", num_features, num_vectors);
SGMatrix<ST> shallow_copy_matrix = SGMatrix<ST>(feature_matrix.matrix, num_features, num_vectors, false);
SGMatrix<ST> shallow_copy_matrix(feature_matrix);
shallow_copy_features=new CDenseFeatures<ST>(shallow_copy_matrix);
if (m_subset_stack->has_subsets())
shallow_copy_features->add_subset(m_subset_stack->get_last_subset()->get_subset_idx());
Expand Down
13 changes: 13 additions & 0 deletions src/shogun/labels/BinaryLabels.cpp
Expand Up @@ -138,3 +138,16 @@ void CBinaryLabels::scores_to_probabilities(float64_t a, float64_t b)

SG_DEBUG("leaving CBinaryLabels::scores_to_probabilities()\n")
}

CLabels* CBinaryLabels::shallow_subset_copy()
{
CLabels* shallow_copy_labels=nullptr;
SGVector<float64_t> shallow_copy_vector(m_labels);
shallow_copy_labels=new CBinaryLabels(m_labels.size());

((CDenseLabels*) shallow_copy_labels)->set_labels(shallow_copy_vector);
if (m_subset_stack->has_subsets())
shallow_copy_labels->add_subset(m_subset_stack->get_last_subset()->get_subset_idx());

return shallow_copy_labels;
}
3 changes: 3 additions & 0 deletions src/shogun/labels/BinaryLabels.h
Expand Up @@ -116,6 +116,9 @@ class CBinaryLabels : public CDenseLabels
{
return "BinaryLabels";
}

virtual CLabels* shallow_subset_copy();

};
}
#endif
2 changes: 1 addition & 1 deletion src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -138,7 +138,7 @@ int32_t CMulticlassLabels::get_num_classes()
CLabels* CMulticlassLabels::shallow_subset_copy()
{
CLabels* shallow_copy_labels=nullptr;
SGVector<float64_t> shallow_copy_vector = SGVector<float64_t>(m_labels.vector, m_labels.size(), false);
SGVector<float64_t> shallow_copy_vector(m_labels);
shallow_copy_labels=new CMulticlassLabels(m_labels.size());

((CDenseLabels*) shallow_copy_labels)->set_labels(shallow_copy_vector);
Expand Down
12 changes: 12 additions & 0 deletions src/shogun/labels/RegressionLabels.cpp
Expand Up @@ -25,3 +25,15 @@ ELabelType CRegressionLabels::get_label_type() const
return LT_REGRESSION;
}

CLabels* CRegressionLabels::shallow_subset_copy()
{
CLabels* shallow_copy_labels=nullptr;
SGVector<float64_t> shallow_copy_vector(m_labels);
shallow_copy_labels=new CRegressionLabels(m_labels.size());

((CDenseLabels*) shallow_copy_labels)->set_labels(shallow_copy_vector);
if (m_subset_stack->has_subsets())
shallow_copy_labels->add_subset(m_subset_stack->get_last_subset()->get_subset_idx());

return shallow_copy_labels;
}
2 changes: 2 additions & 0 deletions src/shogun/labels/RegressionLabels.h
Expand Up @@ -65,6 +65,8 @@ class CRegressionLabels : public CDenseLabels

/** @return object name */
virtual const char* get_name() const { return "RegressionLabels"; }

virtual CLabels* shallow_subset_copy();
};
}
#endif
7 changes: 4 additions & 3 deletions src/shogun/multiclass/tree/CARTree.cpp
Expand Up @@ -603,6 +603,8 @@ int32_t CCARTree::compute_best_attribute(const SGMatrix<float64_t>& mat, const S
{
SGVector<float64_t> feats(num_vecs);
SGVector<index_t> sorted_args(num_vecs);
SGVector<int32_t> temp_count_indices(count_indices.size());
memcpy(temp_count_indices.vector, count_indices.vector, sizeof(int32_t)*count_indices.size());

if (m_pre_sort)
{
Expand All @@ -613,13 +615,12 @@ int32_t CCARTree::compute_best_attribute(const SGMatrix<float64_t>& mat, const S
{
if (indices_mask[sorted_indices[j]]>=0)
{
int32_t count_idx = count_indices[sorted_indices[j]];
while(count_idx>0)
while(temp_count_indices[sorted_indices[j]]>0)
{
feats[count]=temp_col[j];
sorted_args[count]=indices_mask[sorted_indices[j]];
++count;
--count_idx;
--temp_count_indices[sorted_indices[j]];
}
if (count==num_vecs)
break;
Expand Down

0 comments on commit 9acf8d4

Please sign in to comment.