Skip to content

Commit

Permalink
Merge pull request mlpack#1519 from chigur/master
Browse files Browse the repository at this point in the history
Replace copy and move overloads with pass-by-value for RASearch class
  • Loading branch information
rcurtin committed Oct 14, 2018
2 parents eee326b + 6209c34 commit a7e9cb5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 146 deletions.
83 changes: 9 additions & 74 deletions src/mlpack/methods/rann/ra_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class RASearch
* distance::MahalanobisDistance class).
*
* This method will copy the matrices to internal copies, which are rearranged
* during tree-building. You can avoid this extra copy by pre-constructing
* the trees and using the appropriate constructor, or by using the
* constructor that takes an rvalue reference to the data with std::move().
* during tree-building. If you don't need to keep the reference dataset,
* you can use std::move() to remove the overhead of making copies. Using
* std::move() transfers the ownership of the dataset.
*
* tau, the rank-approximation parameter, specifies that we are looking for k
* neighbors with probability alpha of being in the top tau percent of nearest
Expand Down Expand Up @@ -119,61 +119,7 @@ class RASearch
* @param singleSampleLimit The limit on the largest node that can be
* approximated by sampling. This defaults to 20.
*/
RASearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
const double tau = 5,
const double alpha = 0.95,
const bool sampleAtLeaves = false,
const bool firstLeafExact = false,
const size_t singleSampleLimit = 20,
const MetricType metric = MetricType());

/**
* Initialize the RASearch object, passing a reference dataset (this is the
* dataset that will be searched). Optionally, perform the computation in
* naive mode or single-tree mode. An initialized distance metric can be
* given, for cases where the metric has internal data (i.e. the
* distance::MahalanobisDistance class).
*
* This method will take ownership of the given reference set, avoiding a
* copy. If you need to use the reference set for other purposes, too,
* consider using the constructor that takes a const reference.
*
* tau, the rank-approximation parameter, specifies that we are looking for k
* neighbors with probability alpha of being in the top tau percent of nearest
* neighbors. So, as an example, if our dataset has 1000 points, and we want
* 5 nearest neighbors with 95% probability of being in the top 5% of nearest
* neighbors (or, the top 50 nearest neighbors), we set k = 5, tau = 5, and
* alpha = 0.95.
*
* The method will fail (and throw a std::invalid_argument exception) if the
* value of tau is too low: tau must be set such that the number of points in
* the corresponding percentile of the data is greater than k. Thus, if we
* choose tau = 0.1 with a dataset of 1000 points and k = 5, then we are
* attempting to choose 5 nearest neighbors out of the closest 1 point -- this
* is invalid.
*
* @param referenceSet Set of reference points.
* @param naive If true, the rank-approximate search will be performed by
* directly sampling the whole set instead of using the stratified
* sampling on the tree.
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search). This is useful when Search() will be called with
* few query points.
* @param metric An optional instance of the MetricType class.
* @param tau The rank-approximation in percentile of the data. The default
* value is 5%.
* @param alpha The desired success probability. The default value is 0.95.
* @param sampleAtLeaves Sample at leaves for faster but less accurate
* computation. This defaults to 'false'.
* @param firstLeafExact Traverse to the first leaf without approximation.
* This can ensure that the query definitely finds its (near) duplicate
* if there exists one. This defaults to 'false' for now.
* @param singleSampleLimit The limit on the largest node that can be
* approximated by sampling. This defaults to 20.
*/
RASearch(MatType&& referenceSet,
RASearch(MatType referenceSet,
const bool naive = false,
const bool singleMode = false,
const double tau = 5,
Expand Down Expand Up @@ -276,25 +222,14 @@ class RASearch

/**
* "Train" the model on the given reference set. If tree-based search is
* being used (if Naive() is false), this means rebuilding the reference tree.
* This particular method will make a copy of the given reference data. To
* avoid that copy, use the Train() method that takes an rvalue reference with
* std::move().
*
* @param referenceSet New reference set to use.
*/
void Train(const MatType& referenceSet);

/**
* "Train" the model on the given reference set, taking ownership of the data
* matrix. If tree-based search is being used (if Naive() is false), this
* also means rebuilding the reference tree. If you need to keep a copy of
* the reference data, use the Train() method that takes a const reference to
* the data.
* being used (if Naive() is false), the reference tree is rebuilt. The
* reference set is taken by value, so passing an lvalue makes a copy of the
* dataset. If you don't need to keep the dataset, you can avoid that copy by
* passing it with std::move(). This transfers the ownership of the dataset.
*
* @param referenceSet New reference set to use.
*/
void Train(MatType&& referenceSet);
void Train(MatType referenceSet);

/**
* Set the reference tree to a new reference tree.
Expand Down
74 changes: 2 additions & 72 deletions src/mlpack/methods/rann/ra_search_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,40 +46,6 @@ TreeType* BuildTree(

} // namespace aux

// Construct the object from a const reference to the reference set.
//
// This overload never takes ownership of the caller's matrix (setOwner is
// always false).  In naive mode the object stores the address of
// referenceSetIn directly, so the caller must keep that matrix alive for as
// long as this object uses it.  In tree mode the object points at the dataset
// held by the tree it builds and owns.
//
// @param referenceSetIn Set of reference points.
// @param naive If true, no reference tree is built.
// @param singleMode Request single-tree search; ignored when naive is true.
// @param tau Rank-approximation parameter, stored as given.
// @param alpha Desired success probability, stored as given.
// @param sampleAtLeaves Stored as given.
// @param firstLeafExact Stored as given.
// @param singleSampleLimit Stored as given.
// @param metric Instance of the metric; copied into the object.
template<typename SortPolicy,
         typename MetricType,
         typename MatType,
         template<typename TreeMetricType,
                  typename TreeStatType,
                  typename TreeMatType> class TreeType>
RASearch<SortPolicy, MetricType, MatType, TreeType>::
RASearch(const MatType& referenceSetIn,
         const bool naive,
         const bool singleMode,
         const double tau,
         const double alpha,
         const bool sampleAtLeaves,
         const bool firstLeafExact,
         const size_t singleSampleLimit,
         const MetricType metric) :
    // No tree in naive mode.  Otherwise the tree is built from the input; the
    // const_cast only strips constness so aux::BuildTree() accepts it.
    // NOTE(review): presumably BuildTree() copies the matrix for an lvalue
    // argument rather than rearranging the caller's data -- confirm against
    // aux::BuildTree().
    // NOTE(review): members are initialized in declaration order, not in the
    // order written here; oldFromNewReferences is handed to BuildTree() from
    // this initializer, so confirm it is declared before referenceTree.
    referenceTree(naive ? nullptr : aux::BuildTree<Tree>(
        const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
    // Naive mode aliases the caller's matrix; tree mode aliases the tree's
    // dataset.
    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
    treeOwner(!naive), // We own the tree exactly when we built one.
    setOwner(false), // The set is never owned by this overload.
    naive(naive),
    singleMode(!naive && singleMode), // No single mode if naive.
    tau(tau),
    alpha(alpha),
    sampleAtLeaves(sampleAtLeaves),
    firstLeafExact(firstLeafExact),
    singleSampleLimit(singleSampleLimit),
    metric(metric)
{
  // Nothing to do.
}

// Construct the object, taking ownership of the data matrix.
template<typename SortPolicy,
typename MetricType,
Expand All @@ -88,7 +54,7 @@ template<typename SortPolicy,
typename TreeStatType,
typename TreeMatType> class TreeType>
RASearch<SortPolicy, MetricType, MatType, TreeType>::
RASearch(MatType&& referenceSetIn,
RASearch(MatType referenceSetIn,
const bool naive,
const bool singleMode,
const double tau,
Expand Down Expand Up @@ -210,43 +176,7 @@ template<typename SortPolicy,
typename TreeStatType,
typename TreeMatType> class TreeType>
void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
const MatType& referenceSet)
{
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
delete referenceTree;

// We may need to rebuild the tree.
if (!naive)
{
referenceTree = aux::BuildTree<Tree>(referenceSet, oldFromNewReferences);
treeOwner = true;
}
else
{
treeOwner = false;
}

// Delete the old reference set, if we owned it.
if (setOwner && this->referenceSet)
delete this->referenceSet;

if (!naive)
this->referenceSet = &referenceTree->Dataset();
else
this->referenceSet = &referenceSet;
setOwner = false; // We don't own the set in either case.
}

// Train on a new reference set.
template<typename SortPolicy,
typename MetricType,
typename MatType,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
MatType&& referenceSet)
MatType referenceSet)
{
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
Expand Down

0 comments on commit a7e9cb5

Please sign in to comment.