Skip to content

Commit

Permalink
[FEA] Support multiple classes in multi-node-multi-gpu logistic regre…
Browse files Browse the repository at this point in the history
…ssion, from C++, Cython, to Dask Python class (#5565)

Github issue: #5501

This PR depends on and has included [PR 5558](#5558).

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Simon Adorf (https://github.com/csadorf)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5565
  • Loading branch information
lijinf2 committed Sep 29, 2023
1 parent 9c18259 commit 3c4ceb9
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 22 deletions.
12 changes: 12 additions & 0 deletions cpp/include/cuml/linear_model/qn_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,24 @@

#include <cumlprims/opg/matrix/data.hpp>
#include <cumlprims/opg/matrix/part_descriptor.hpp>
#include <vector>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Calculate unique class labels across multiple GPUs in a multi-node environment.
* @param[in] handle: the internal cuml handle object
* @param[in] input_desc: PartDescriptor object for the input
* @param[in] labels: labels data
* @returns host vector that stores the distinct labels
*/
std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels);

/**
* @brief performs MNMG fit operation for the logistic regression using quasi newton methods
* @param[in] handle: the internal cuml handle object
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/glm/qn/mg/qn_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
case QN_LOSS_SOFTMAX: {
ASSERT(C > 2, "qn_mg.cuh: softmax invalid C");
ML::GLM::detail::Softmax<T> loss(handle, D, C, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
default: {
ASSERT(false, "qn_mg.cuh: unknown loss function type (id = %d).", pams.loss);
}
Expand Down
66 changes: 55 additions & 11 deletions cpp/src/glm/qn_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,59 @@
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/comms.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/util/cudart_utils.hpp>
#include <vector>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

template <typename T>
std::vector<T> distinct_mg(const raft::handle_t& handle, T* y, size_t n)
{
cudaStream_t stream = handle.get_stream();
raft::comms::comms_t const& comm = raft::resource::get_comms(handle);
int rank = comm.get_rank();
int n_ranks = comm.get_size();

rmm::device_uvector<T> unique_y(0, stream);
raft::label::getUniquelabels(unique_y, y, n, stream);

rmm::device_uvector<size_t> recv_counts(n_ranks, stream);
auto send_count = raft::make_device_scalar<size_t>(handle, unique_y.size());
comm.allgather(send_count.data_handle(), recv_counts.data(), 1, stream);
comm.sync_stream(stream);

std::vector<size_t> recv_counts_host(n_ranks);
raft::copy(recv_counts_host.data(), recv_counts.data(), n_ranks, stream);

std::vector<size_t> displs(n_ranks);
size_t pos = 0;
for (int i = 0; i < n_ranks; ++i) {
displs[i] = pos;
pos += recv_counts_host[i];
}

rmm::device_uvector<T> recv_buff(displs.back() + recv_counts_host.back(), stream);
comm.allgatherv(
unique_y.data(), recv_buff.data(), recv_counts_host.data(), displs.data(), stream);
comm.sync_stream(stream);

rmm::device_uvector<T> global_unique_y(0, stream);
int n_distinct =
raft::label::getUniquelabels(global_unique_y, recv_buff.data(), recv_buff.size(), stream);

std::vector<T> global_unique_y_host(global_unique_y.size());
raft::copy(global_unique_y_host.data(), global_unique_y.data(), global_unique_y.size(), stream);

return global_unique_y_host;
}

template <typename T>
void qnFit_impl(const raft::handle_t& handle,
const qn_params& pams,
Expand All @@ -46,17 +90,6 @@ void qnFit_impl(const raft::handle_t& handle,
int rank,
int n_ranks)
{
switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
RAFT_EXPECTS(
C == 2,
"qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
} break;
default: {
RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
}
}

auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);

ML::GLM::opg::qn_fit_x_mg(handle,
Expand Down Expand Up @@ -113,6 +146,17 @@ void qnFit_impl(raft::handle_t& handle,
input_desc.uniqueRanks().size());
}

std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels)
{
RAFT_EXPECTS(labels.size() == 1,
"getUniqueLabelsMG currently does not accept more than one data chunk");
Matrix::Data<float>* data_y = labels[0];
int n_rows = input_desc.totalElementsOwnedBy(input_desc.rank);
return distinct_mg<float>(handle, data_y->ptr, n_rows);
}

void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
Expand Down
9 changes: 8 additions & 1 deletion python/cuml/dask/linear_model/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,11 @@ def _create_model(sessionId, datatype, **kwargs):
def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
inp_X = concatenate([X for X, _ in data])
inp_y = concatenate([y for _, y in data])
return f.fit([(inp_X, inp_y)], n_rows, n_cols, partsToSizes, rank)
n_ranks = max([p[0] for p in partsToSizes]) + 1
aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
for p in partsToSizes:
aggregated_partsToSizes[p[0]][1] += p[1]

return f.fit(
[(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
)
28 changes: 23 additions & 5 deletions python/cuml/linear_model/logistic_regression_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,18 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
float *f,
int *num_iters) except +

cdef vector[float] getUniquelabelsMG(
const handle_t& handle,
PartDescriptor &input_desc,
vector[floatData_t*] labels) except+


class LogisticRegressionMG(MGFitMixin, LogisticRegression):

def __init__(self, **kwargs):
super(LogisticRegressionMG, self).__init__(**kwargs)
if self.penalty != "l2" and self.penalty != "none":
assert False, "Currently only support 'l2' and 'none' penalty"

@property
@cuml.internals.api_base_return_array_skipall
Expand All @@ -102,8 +109,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):

self.solver_model.coef_ = value

def prepare_for_fit(self, n_classes):
self.solver_model.qnparams = QNParams(
def create_qnparams(self):
return QNParams(
loss=self.loss,
penalty_l1=self.l1_strength,
penalty_l2=self.l2_strength,
Expand All @@ -118,8 +125,11 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
penalty_normalized=self.penalty_normalized
)

def prepare_for_fit(self, n_classes):
self.solver_model.qnparams = self.create_qnparams()

# modified
qnpams = self.qnparams.params
qnpams = self.solver_model.qnparams.params

# modified qnp
solves_classification = qnpams['loss'] in {
Expand Down Expand Up @@ -174,8 +184,14 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
cdef float objective32
cdef int num_iters

# TODO: calculate _num_classes at runtime
self._num_classes = 2
cdef vector[float] c_classes_
c_classes_ = getUniquelabelsMG(
handle_[0],
deref(<PartDescriptor*><uintptr_t>input_desc),
deref(<vector[floatData_t*]*><uintptr_t>y))
self.classes_ = np.sort(list(c_classes_)).astype('float32')

self._num_classes = len(self.classes_)
self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
self.prepare_for_fit(self._num_classes)
cdef uintptr_t mat_coef_ptr = self.coef_.ptr
Expand All @@ -194,6 +210,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
self._num_classes,
<float*> &objective32,
<int*> &num_iters)
else:
assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"

self.solver_model._calc_intercept()

Expand Down
102 changes: 97 additions & 5 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ def _prep_training_data(c, X_train, y_train, partitions_per_worker):
return X_train_df, y_train_df


def make_classification_dataset(datatype, nrows, ncols, n_info):
def make_classification_dataset(datatype, nrows, ncols, n_info, n_classes=2):
X, y = make_classification(
n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0
n_samples=nrows,
n_features=ncols,
n_informative=n_info,
n_classes=n_classes,
random_state=0,
)
X = X.astype(datatype)
y = y.astype(datatype)
Expand Down Expand Up @@ -176,6 +180,16 @@ def imp():

assert_array_equal(preds, y, strict=True)

# assert error on float64
X = X.astype(np.float64)
y = y.astype(np.float64)
X_df, y_df = _prep_training_data(client, X, y, n_parts)
with pytest.raises(
RuntimeError,
match="dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589",
):
lr.fit(X_df, y_df)


def test_lbfgs_init(client):
def imp():
Expand Down Expand Up @@ -267,6 +281,7 @@ def test_lbfgs(
delayed,
client,
penalty="l2",
n_classes=2,
):
tolerance = 0.005

Expand All @@ -283,7 +298,9 @@ def imp():
n_info = 5
nrows = int(nrows)
ncols = int(ncols)
X, y = make_classification_dataset(datatype, nrows, ncols, n_info)
X, y = make_classification_dataset(
datatype, nrows, ncols, n_info, n_classes=n_classes
)

X_df, y_df = _prep_training_data(client, X, y, n_parts)

Expand All @@ -303,12 +320,13 @@ def imp():
assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)

# test predict
cu_preds = lr.predict(X_df, delayed=delayed)
accuracy_cuml = accuracy_score(y, cu_preds.compute().to_numpy())
cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
accuracy_cuml = accuracy_score(y, cu_preds)

sk_preds = sk_model.predict(X)
accuracy_sk = accuracy_score(y, sk_preds)

assert len(cu_preds) == len(sk_preds)
assert (accuracy_cuml >= accuracy_sk) | (
np.abs(accuracy_cuml - accuracy_sk) < 1e-3
)
Expand Down Expand Up @@ -336,3 +354,77 @@ def test_noreg(fit_intercept, client):
l1_strength, l2_strength = lr._get_qn_params()
assert l1_strength == 0.0
assert l2_strength == 0.0


def test_n_classes_small(client):
def assert_small(X, y, n_classes):
X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask

lr = cumlLBFGS_dask()
lr.fit(X_df, y_df)
assert lr._num_classes == n_classes
return lr

X = np.array([(1, 2), (1, 3)], np.float32)
y = np.array([1.0, 0.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
y = np.array([1.0, 0.0, 1.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 2.5), (1, 3)], np.float32)
y = np.array([1.0, 1.0, 0.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
y = np.array([10.0, 50.0, 20.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=3)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([10.0, 20.0, 50.0], np.float32)
)


@pytest.mark.parametrize("n_parts", [2, 23])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("n_classes", [8])
def test_n_classes(n_parts, fit_intercept, n_classes, client):
lr = test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=n_parts,
fit_intercept=fit_intercept,
datatype=np.float32,
delayed=True,
client=client,
penalty="l2",
n_classes=n_classes,
)

assert lr._num_classes == n_classes


@pytest.mark.parametrize("penalty", ["l1", "elasticnet"])
@pytest.mark.parametrize("l1_ratio", [0.1])
def test_l1_and_elasticnet(penalty, l1_ratio, client):
X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)

from cuml.dask.linear_model import LogisticRegression

lr = LogisticRegression(penalty=penalty, l1_ratio=l1_ratio)
with pytest.raises(
RuntimeError, match="Currently only support 'l2' and 'none' penalty"
):
lr.fit(X_df, y_df)

0 comments on commit 3c4ceb9

Please sign in to comment.