Skip to content

Commit

Permalink
Merge pull request #148 from schalkdaniel/general_updates
Browse files Browse the repository at this point in the history
bernoulli loss printe + tests
  • Loading branch information
Daniel Schalk committed Mar 30, 2018
2 parents 558ca19 + f054064 commit f0ec965
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 4 deletions.
8 changes: 7 additions & 1 deletion R/class_printer.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
# Helper functions:
# -----------------

glueLoss = function (name, definition = NULL)
glueLoss = function (name, definition = NULL, additional.desc = "")
{
if (is.null(definition)) {
definition = "No function specified, probably you are using a custom loss."
Expand All @@ -54,6 +54,7 @@ glueLoss = function (name, definition = NULL)
{definition}
{additional.desc}
")

Expand Down Expand Up @@ -193,6 +194,11 @@ ignore.me = setMethod("show", "Rcpp_AbsoluteLoss", function (object) {
glueLoss("AbsoluteLoss", "|y - f(x)|")
})

setClass("Rcpp_BernoulliLoss")
ignore.me = setMethod("show", "Rcpp_BernoulliLoss", function (object) {
glueLoss("BernoulliLoss", "log(1 + exp(-yf(x))", "Labels should be coded as -1 and 1!")
})

setClass("Rcpp_CustomLoss")
ignore.me = setMethod("show", "Rcpp_CustomLoss", function (object) {
glueLoss("CustomLoss")
Expand Down
2 changes: 1 addition & 1 deletion man/cpp_man/mystyle.css
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ div.dynheader {
background-color: rgba(0, 0, 0, 0.1);
border-bottom-left-radius: 0px;
border-bottom-right-radius: 0px;
box-shadow: 5px 5px 5px rgba(0, 0, 0, 0.15);
box-shadow: none;
-moz-border-radius-bottomleft: 0px;
-moz-border-radius-bottomright: 0px;
-moz-box-shadow: none;
Expand Down
2 changes: 1 addition & 1 deletion src/compboost_modules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ RCPP_MODULE (loss_module)
.method("testConstantInitializer", &AbsoluteLossWrapper::testConstantInitializer, "Test the constant initializer function of th eloss")
;

class_<BernoulliLossWrapper> ("AbsoluteLoss")
class_<BernoulliLossWrapper> ("BernoulliLoss")
.derives<LossWrapper> ("Loss")
.constructor ()
.method("testLoss", &BernoulliLossWrapper::testLoss, "Test the defined loss function of the loss")
Expand Down
2 changes: 1 addition & 1 deletion src/loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ double AbsoluteLoss::constantInitializer (const arma::vec& true_value) const

arma::vec BernoulliLoss::definedLoss (const arma::vec& true_value, const arma::vec& prediction) const
{
return arma::log(1 + arma::exp(- true_value * prediction));
return arma::log(1 + arma::exp(- true_value % prediction));
}

/**
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,28 @@ test_that("Absolute loss works", {
absolute.loss$testConstantInitializer(true.value),
median.default(true.value)
)
})

test_that("Bernoulli loss works", {
true.value = rbinom(100, 1, 0.4) * 2 - 1
prediction = runif(100, -1, 1)

bernoulli.loss = BernoulliLoss$new()

# Tests:
# -----------
expect_equal(
bernoulli.loss$testLoss(true.value, prediction),
as.matrix(log(1 + exp(-true.value * prediction)))
)
expect_equal(
bernoulli.loss$testGradient(true.value, prediction),
as.matrix(-true.value / (1 + exp(true.value * prediction)))
)
expect_equal(
bernoulli.loss$testConstantInitializer(true.value),
log(mean(true.value > 0) / (1 - mean(true.value > 0))) / 2
)
})

test_that("Custom loss works", {
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_printer.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ test_that("Loss printer works", {

quadratic.loss = QuadraticLoss$new()
absolute.loss = AbsoluteLoss$new()
bernoulli.loss = BernoulliLoss$new()

# Function for Custom Loss:
myLossFun = function (true.value, prediction) NULL
Expand All @@ -58,6 +59,7 @@ test_that("Loss printer works", {
test.quadratic.printer = show(quadratic.loss)
test.absolute.printer = show(absolute.loss)
test.custom.printer = show(custom.loss)
test.bernoulliprinter = show(bernoulli.loss)

sink()
close(tc)
Expand All @@ -67,6 +69,7 @@ test_that("Loss printer works", {

expect_equal(test.quadratic.printer, "QuadraticLossPrinter")
expect_equal(test.absolute.printer, "AbsoluteLossPrinter")
expect_equal(test.bernoulliprinter, "BernoulliLossPrinter")
expect_equal(test.custom.printer, "CustomLossPrinter")

})
Expand Down

0 comments on commit f0ec965

Please sign in to comment.