Skip to content

Commit

Permalink
removed softmax loss, removed a redundant variable, moved mg files in…
Browse files Browse the repository at this point in the history
…to same directory
  • Loading branch information
lijinf2 committed Aug 25, 2023
1 parent 6dc3c0d commit e8d6ccf
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
int rank;
int64_t n_samples;
int n_ranks;
T l2;

GLMWithDataMG(raft::handle_t const& handle,
int rank,
Expand All @@ -99,15 +98,13 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
GLMObjective* obj,
const SimpleMat<T>& X,
const SimpleVec<T>& y,
SimpleDenseMat<T>& Z,
T l2)
SimpleDenseMat<T>& Z)
: ML::GLM::detail::GLMWithData<T, GLMObjective>(obj, X, y, Z)
{
this->handle_p = &handle;
this->rank = rank;
this->n_ranks = n_ranks;
this->n_samples = n_samples;
this->l2 = l2;
}

inline T operator()(const SimpleVec<T>& wFlat,
Expand All @@ -125,7 +122,7 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
auto reg = regularizer_obj->reg;
G.fill(0, stream);
float reg_host = 0;
if (this->l2 != 0) {
if (reg->l2_penalty != 0) {
reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
raft::update_host(&reg_host, dev_scalar, 1, stream);
// note: avoid syncing here because there's a sync before reg_host is used.
Expand Down
25 changes: 9 additions & 16 deletions cpp/src/glm/qn/qn_mg.cuh → cpp/src/glm/qn/mg/qn_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
*/

#include "glm_base_mg.cuh"
#include "glm_logistic.cuh"
#include "glm_regularizer.cuh"
#include "glm_softmax.cuh"
#include "glm_svm.cuh"
#include "qn_solvers.cuh"
#include "qn_util.cuh"
#include <glm/qn/glm_logistic.cuh>
#include <glm/qn/glm_regularizer.cuh>
#include <glm/qn/glm_softmax.cuh>
#include <glm/qn/glm_svm.cuh>
#include <glm/qn/qn_solvers.cuh>
#include <glm/qn/qn_util.cuh>

#include <cuml/linear_model/qn.h>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -56,8 +56,7 @@ int qn_fit_mg(const raft::handle_t& handle,
ML::GLM::detail::Tikhonov<T> reg(l2);
ML::GLM::detail::RegularizedGLM<T, LossFunction, decltype(reg)> regularizer_obj(&loss, &reg);

auto obj_function =
GLMWithDataMG(handle, rank, n_ranks, n_samples, &regularizer_obj, X, y, Z, l2);
auto obj_function = GLMWithDataMG(handle, rank, n_ranks, n_samples, &regularizer_obj, X, y, Z);
return ML::GLM::detail::qn_minimize(
handle, w0, fx, num_iters, obj_function, l1, opt_param, pams.verbose);
}
Expand Down Expand Up @@ -99,19 +98,13 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,

switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
ASSERT(C == 2, "qn.h: logistic loss invalid C");
ASSERT(C == 2, "qn_mg.cuh: logistic loss invalid C");
ML::GLM::detail::LogisticLoss<T> loss(handle, D, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
case QN_LOSS_SOFTMAX: {
ASSERT(C > 2, "qn.h: softmax invalid C");
ML::GLM::detail::Softmax<T> loss(handle, D, C, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
default: {
ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss);
ASSERT(false, "qn_mg.cuh: unknown loss function type (id = %d).", pams.loss);
}
}
}
Expand Down
6 changes: 1 addition & 5 deletions cpp/src/glm/qn_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
* limitations under the License.
*/

#include "qn/glm_logistic.cuh"
#include "qn/glm_regularizer.cuh"
#include "qn/qn_mg.cuh"
#include "qn/qn_solvers.cuh"
#include "qn/qn_util.cuh"
#include "qn/mg/qn_mg.cuh"
#include "qn/simple_mat/dense.hpp"
#include <cuda_runtime.h>
#include <cuml/common/logger.hpp>
Expand Down

0 comments on commit e8d6ccf

Please sign in to comment.