Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TSNE and UMAP allow several distance types #4779

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f8a2c86
tsne allow distance types
tarang-jain Jun 14, 2022
dbefb19
Added other distance metrics for UMAP
tarang-jain Jun 16, 2022
d934f1e
Modified UMAPPARAMS
tarang-jain Jun 16, 2022
ce9def6
Restructured tsne code with supported distance metrics
tarang-jain Jun 17, 2022
08c3669
Added minkowski distance parameter p
tarang-jain Jun 22, 2022
81fcd23
Styling fixes
tarang-jain Jun 22, 2022
be0c6cc
Style fixes
tarang-jain Jun 22, 2022
67d1625
styling and metric changes
tarang-jain Jun 23, 2022
453df40
styling fixes (copyright)
tarang-jain Jun 23, 2022
f2be71f
update tsne tests
tarang-jain Jun 23, 2022
8e80c52
Merge branch 'branch-22.08' of github.com:rapidsai/cuml into fea-tsne…
tarang-jain Jun 23, 2022
d3120a5
Re-evaluate supported distance metrics, update tests
tarang-jain Jun 23, 2022
f9ef5df
Update UMAP metric docs
tarang-jain Jul 7, 2022
859ece3
Merge branch 'branch-22.08' of github.com:rapidsai/cuml into fea-tsne…
tarang-jain Jul 12, 2022
a82c081
correction in TSNE gtest
tarang-jain Jul 18, 2022
628c033
Style fix
tarang-jain Jul 18, 2022
e9c9195
Fix failing tests in CI
tarang-jain Jul 23, 2022
f74dc6f
Update documentation based on PR Reviews
tarang-jain Jul 25, 2022
40f742f
Merge branch 'branch-22.08' of github.com:rapidsai/cuml into fea-tsne…
tarang-jain Aug 3, 2022
43b2de5
Updated docs for CI failure
tarang-jain Aug 3, 2022
7286757
Merge branch 'rapidsai:branch-22.08' into fea-tsne-umap-user-configur…
tarang-jain Aug 4, 2022
e31352e
Merge remote-tracking branch 'rapidsai/branch-22.10' into fea-tsne-um…
cjnolet Aug 4, 2022
5a13a36
Merge remote-tracking branch 'rapidsai/branch-22.10' into fea-tsne-um…
cjnolet Aug 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion cpp/include/cuml/manifold/tsne.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include <cuml/common/logger.hpp>
#include <raft/distance/distance_type.hpp>

namespace raft {
class handle_t;
Expand Down Expand Up @@ -101,6 +102,12 @@ struct TSNEParams {
// behavior of Scikit-learn's T-SNE.
bool square_distances = true;

// Distance metric to use.
raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded;

// Value of p for Minkowski distance
float p = 2.0;

// Which implementation algorithm to use.
TSNE_ALGORITHM algorithm = TSNE_ALGORITHM::FFT;
};
Expand Down
1 change: 1 addition & 0 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cstddef>
#include <cstdint>
#include <cuml/manifold/umapparams.h>
#include <memory>

namespace raft {
Expand Down
7 changes: 6 additions & 1 deletion cpp/include/cuml/manifold/umapparams.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

#include <cuml/common/callback.hpp>
#include <cuml/common/logger.hpp>
#include <raft/distance/distance_type.hpp>

namespace ML {

Expand Down Expand Up @@ -157,6 +158,10 @@ class UMAPParams {
*/
bool deterministic = true;

raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded;

float p = 2.0;

Internals::GraphBasedDimRedCallback* callback = nullptr;
};

Expand Down
29 changes: 20 additions & 9 deletions cpp/src/tsne/distances.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,31 @@
namespace ML {
namespace TSNE {

auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded;

/**
* @brief Uses FAISS's KNN to find the top n_neighbors. This speeds up the attractive forces.
* @param[in] input: dense/sparse manifold input
* @param[out] indices: The output indices from KNN.
* @param[out] distances: The output sorted distances from KNN.
* @param[in] n_neighbors: The number of nearest neighbors you want.
* @param[in] stream: The GPU stream.
* @param[in] metric: The distance metric.
*/
template <typename tsne_input, typename value_idx, typename value_t>
void get_distances(const raft::handle_t& handle,
tsne_input& input,
knn_graph<value_idx, value_t>& k_graph,
cudaStream_t stream);
cudaStream_t stream,
raft::distance::DistanceType metric,
value_t p);

// dense, int64 indices
template <>
void get_distances(const raft::handle_t& handle,
manifold_dense_inputs_t<float>& input,
knn_graph<int64_t, float>& k_graph,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric,
float p)
{
// TODO: for TSNE transform first fit some points then transform with 1/(1+d^2)
// #861
Expand Down Expand Up @@ -89,15 +92,18 @@ void get_distances(const raft::handle_t& handle,
true,
true,
nullptr,
DEFAULT_DISTANCE_METRIC);
metric,
p);
}

// dense, int32 indices
template <>
void get_distances(const raft::handle_t& handle,
manifold_dense_inputs_t<float>& input,
knn_graph<int, float>& k_graph,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric,
float p)
{
throw raft::exception("Dense TSNE does not support 32-bit integer indices yet.");
}
Expand All @@ -107,7 +113,9 @@ template <>
void get_distances(const raft::handle_t& handle,
manifold_sparse_inputs_t<int, float>& input,
knn_graph<int, float>& k_graph,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric,
float p)
{
raft::sparse::selection::brute_force_knn(input.indptr,
input.indices,
Expand All @@ -127,15 +135,18 @@ void get_distances(const raft::handle_t& handle,
handle,
ML::Sparse::DEFAULT_BATCH_SIZE,
ML::Sparse::DEFAULT_BATCH_SIZE,
DEFAULT_DISTANCE_METRIC);
metric,
p);
}

// sparse, int64
template <>
void get_distances(const raft::handle_t& handle,
manifold_sparse_inputs_t<int64_t, float>& input,
knn_graph<int64_t, float>& k_graph,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric,
float p)
{
throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet.");
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/tsne/tsne.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "tsne_runner.cuh"
#include <cuml/manifold/tsne.h>
#include <raft/distance/distance_type.hpp>

namespace ML {

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/tsne/tsne_runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cuml/common/logger.hpp>
#include <cuml/manifold/common.hpp>
#include <raft/cudart_utils.h>
#include <raft/distance/distance_type.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/transform.h>
Expand Down Expand Up @@ -119,7 +120,7 @@ class TSNE_runner {
k_graph.knn_indices = indices.data();
k_graph.knn_dists = distances.data();

TSNE::get_distances(handle, input, k_graph, stream);
TSNE::get_distances(handle, input, k_graph, stream, params.metric, params.p);
}

if (params.square_distances) {
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ void launcher(const raft::handle_t& handle,
inputsB.n,
out.knn_indices,
out.knn_dists,
n_neighbors);
n_neighbors,
params->metric,
params->p);
}

// Instantiation for dense inputs, int indices
Expand Down Expand Up @@ -112,7 +114,8 @@ void launcher(const raft::handle_t& handle,
handle,
ML::Sparse::DEFAULT_BATCH_SIZE,
ML::Sparse::DEFAULT_BATCH_SIZE,
raft::distance::DistanceType::L2Expanded);
params->metric,
params->p);
}

template <>
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/sg/tsne_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cuml/manifold/tsne.h>
#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance_type.hpp>
#include <raft/linalg/map.hpp>

#include <cuml/common/logger.hpp>
Expand Down Expand Up @@ -109,6 +110,9 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
auto stream = handle.get_stream();
TSNEResults results;

auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded;
float minkowski_p = 2.0;

// Setup parameters
model_params.algorithm = algo;
model_params.dim = 2;
Expand All @@ -133,7 +137,7 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
input_dists.resize(n * model_params.n_neighbors, stream);
k_graph.knn_indices = input_indices.data();
k_graph.knn_dists = input_dists.data();
TSNE::get_distances(handle, input, k_graph, stream);
TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC, minkowski_p);
}
handle.sync_stream(stream);
TSNE_runner<manifold_dense_inputs_t<float>, knn_indices_dense_t, float> runner(
Expand Down
50 changes: 41 additions & 9 deletions python/cuml/manifold/t_sne.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_cuml_array
from cuml.common.mixins import CMajorInputTagMixin
from cuml.common.sparsefuncs import extract_knn_graph
from cuml.metrics.distance_type cimport DistanceType
import rmm

from libcpp cimport bool
Expand Down Expand Up @@ -79,6 +80,8 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":
int verbosity,
bool initialize_embeddings,
bool square_distances,
DistanceType metric,
float p,
TSNE_ALGORITHM algorithm


Expand Down Expand Up @@ -151,9 +154,10 @@ class TSNE(Base,
Used in the 'exact' and 'fft' algorithms. Consider reducing if
the embeddings are unsatisfactory. It's recommended to use a
smaller value for smaller datasets.
metric : str 'euclidean' only (default 'euclidean')
Currently only supports euclidean distance. Will support cosine in
a future release.
metric : str (default='euclidean').
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a disclaimer here and explicitly point out the square_distances argument. The math in the base TSNE algorithm itself assumes the distances can be squared (e.g., that Euclidean is used by default, which then becomes sqeuclidean) during the loss computation. We want to make sure users know that if they are using a different distance, they will likely want to set the square_distances argument to False.

We probably want to document this but also provide a warning when a distance other than Euclidean is used so that the users know to turn it off. We probably also want to turn this off in the pytests for all distances other than Euclidean.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjnolet sklearn has deprecated its square_distances argument, and in the docs they state that distances are always squared. Should we do something similar?

Distance metric to use. Supported distances are ['l1', 'cityblock',
'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'minkowski',
'chebyshev', 'cosine', 'correlation']
init : str 'random' (default 'random')
Currently supports random initialization.
verbose : int or boolean, default=False
Expand Down Expand Up @@ -189,11 +193,13 @@ class TSNE(Base,
During the late phases, less forcefully apply gradients.
square_distances : boolean, default=True
Whether TSNE should square the distance values.
Internally, this will be used to compute a kNN graph using 'euclidean'
Internally, this will be used to compute a kNN graph using the provided
metric and then squaring it when True. If a `knn_graph` is passed
to `fit` or `fit_transform` methods, all the distances will be
squared when True. For example, if a `knn_graph` was obtained using
'sqeuclidean' metric, the distances will still be squared when True.
Note: This argument should likely be set to False for distance metrics
other than 'euclidean' and 'l2'.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
Expand Down Expand Up @@ -259,6 +265,7 @@ class TSNE(Base,
n_iter_without_progress=300,
min_grad_norm=1e-07,
metric='euclidean',
metric_params=None,
init='random',
verbose=False,
random_state=None,
Expand Down Expand Up @@ -302,11 +309,6 @@ class TSNE(Base,
if n_iter <= 100:
warnings.warn("n_iter = {} might cause TSNE to output wrong "
"results. Set it higher.".format(n_iter))
if metric.lower() != 'euclidean':
# TODO https://github.com/rapidsai/cuml/issues/1653
warnings.warn("TSNE does not support {} (only Euclidean).".format(
metric))
metric = 'euclidean'
if init.lower() != 'random':
# TODO https://github.com/rapidsai/cuml/issues/3458
warnings.warn("TSNE does not support {} but only random "
Expand Down Expand Up @@ -353,6 +355,7 @@ class TSNE(Base,
self.n_iter_without_progress = n_iter_without_progress
self.min_grad_norm = min_grad_norm
self.metric = metric
self.metric_params = metric_params
self.init = init
self.random_state = random_state
self.method = method
Expand Down Expand Up @@ -426,6 +429,7 @@ class TSNE(Base,
convert_format=False)
n, p = self.X_m.shape
self.sparse_fit = True

# Handle dense inputs
else:
self.X_m, n, p, _ = \
Expand Down Expand Up @@ -497,6 +501,7 @@ class TSNE(Base,
self._build_tsne_params(algo)

cdef float kl_divergence = 0

if self.sparse_fit:
TSNE_fit_sparse(handle_[0],
<int*><uintptr_t>
Expand Down Expand Up @@ -583,6 +588,32 @@ class TSNE(Base,
params.initialize_embeddings = <bool> True
params.square_distances = <bool> self.square_distances
params.algorithm = algo

# metric
metric_parsing = {
"l2": DistanceType.L2SqrtExpanded,
"euclidean": DistanceType.L2SqrtExpanded,
"sqeuclidean": DistanceType.L2Expanded,
"cityblock": DistanceType.L1,
"l1": DistanceType.L1,
"manhattan": DistanceType.L1,
"minkowski": DistanceType.LpUnexpanded,
"chebyshev": DistanceType.Linf,
"cosine": DistanceType.CosineExpanded,
"correlation": DistanceType.CorrelationExpanded
}

if self.metric.lower() in metric_parsing:
params.metric = metric_parsing[self.metric.lower()]
else:
raise ValueError("Invalid value for metric: {}"
.format(self.metric))

if self.metric_params is None:
params.p = <float> 2.0
else:
params.p = <float>self.metric_params.get('p')

return <size_t> params

@property
Expand Down Expand Up @@ -625,6 +656,7 @@ class TSNE(Base,
"n_iter_without_progress",
"min_grad_norm",
"metric",
"metric_params",
"init",
"random_state",
"method",
Expand Down