-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
298 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# ============================================================================ # | ||
# ___. __ # | ||
# ____ ____ _____ ______\_ |__ ____ ____ _______/ |_ # | ||
# _/ ___\/ _ \ / \\____ \| __ \ / _ \ / _ \/ ___/\ __\ # | ||
# \ \__( <_> ) Y Y \ |_> > \_\ ( <_> | <_> )___ \ | | # | ||
# \___ >____/|__|_| / __/|___ /\____/ \____/____ > |__| # | ||
# \/ \/|__| \/ \/ # | ||
# # | ||
# ============================================================================ # | ||
# | ||
# Compboost is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# Compboost is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# You should have received a copy of the GNU General Public License | ||
# along with Compboost. If not, see <http:#www.gnu.org/licenses/>. | ||
# | ||
# This file contains: | ||
# ------------------- | ||
# | ||
# R API which wraps the imported c++ class wrapper of the "Compboost" class | ||
# and acts as the accessor for the user to a high level function within R. | ||
# | ||
# Written by: | ||
# ----------- | ||
# | ||
# Daniel Schalk | ||
# Institut für Statistik | ||
# Ludwig-Maximilians-Universität München | ||
# Ludwigstraße 33 | ||
# D-80539 München | ||
# | ||
# https:#www.compstat.statistik.uni-muenchen.de | ||
# | ||
# =========================================================================== # | ||
|
||
#' @title Parameter plotter for a trained compboost object | ||
#' | ||
#' @description This function can be used to print the trace of the parameters | ||
#' of a trained compboost object. | ||
#' | ||
#' @param object [\code{character(1)}] \cr | ||
#' Trained compboost object. | ||
#' @param legend [\code{logical(1)}] \cr | ||
#' Logical to specify if a legend should be plotted. | ||
#' @param ... \cr | ||
#' Additional parameter given to plot. | ||
#' @export | ||
|
||
plotCompboostParameter = function (object, legend = TRUE, ...) | ||
{ | ||
if (! object$isTrained()) { | ||
warning ("Your given compboost object is not trained!") | ||
return (invisible(1)) | ||
} | ||
|
||
plot.params = list(...) | ||
|
||
parameter.matrix = object$getParameterMatrix() | ||
|
||
parameter.matrix.df = as.data.frame(parameter.matrix$parameter.matrix) | ||
colnames(parameter.matrix.df) = parameter.matrix$parameter.names | ||
|
||
if (! "ylim" %in% names(plot.params)) { | ||
ylim = c(min(parameter.matrix.df), max(parameter.matrix.df)) | ||
} | ||
if (! "xlab" %in% names(plot.params)) { | ||
xlab = "Iterations" | ||
} | ||
if (! "ylab" %in% names(plot.params)) { | ||
ylab = "Parameter Value" | ||
} | ||
if (! "type" %in% names(plot.params)) { | ||
type = "l" | ||
} | ||
if (! "col" %in% names(plot.params)) { | ||
col = rgb( | ||
red = seq(0, 154, length.out = ncol(parameter.matrix.df)), | ||
green = seq(178, 205, length.out = ncol(parameter.matrix.df)), | ||
blue = seq(238, 50, length.out = ncol(parameter.matrix.df)), | ||
alpha = 255, | ||
maxColorValue = 255 | ||
) | ||
} | ||
|
||
if (legend) { | ||
layout(mat = matrix( | ||
data = c( | ||
2, 2, 2, 1, 1, | ||
2, 2, 2, 1, 1, | ||
2, 2, 2, 1, 1 | ||
), | ||
nrow = 3, | ||
byrow = TRUE | ||
)) | ||
|
||
par(mar = c(0, 0, 0, 0)) | ||
|
||
plot(1, type = "n", xlab = "", ylab = "", axes = FALSE, | ||
xlim = c(0, 10), ylim = c(0, 10)) | ||
|
||
legend( | ||
x = 0, | ||
y = 5.5, | ||
legend = colnames(parameter.matrix.df), | ||
lty = 1, | ||
lwd = 2, | ||
col = col, | ||
yjust = 0.5, | ||
xpd = TRUE, | ||
bty = "n" | ||
) | ||
par(mar = c(5.1, 4.1, 4.1, 2.1)) | ||
} | ||
|
||
|
||
|
||
plot( | ||
x = seq_len(nrow(parameter.matrix.df)), | ||
y = parameter.matrix.df[, 1], | ||
ylim = ylim, | ||
xlab = xlab, | ||
ylab = ylab, | ||
type = type, | ||
col = col[1], | ||
... | ||
) | ||
|
||
if (ncol(parameter.matrix.df) > 1) { | ||
for (i in 2:ncol(parameter.matrix.df)) { | ||
points( | ||
x = seq_len(nrow(parameter.matrix.df)), | ||
y = parameter.matrix.df[, i], | ||
type = "l", | ||
col = col[i] | ||
) | ||
} | ||
} | ||
par(mfrow = c(1,1)) | ||
|
||
return (invisible(0)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
context("Plotter works") | ||
|
||
test_that("Parameter plotter works", { | ||
|
||
df = mtcars | ||
|
||
# # Create new variable to check the polynomial baselearner with degree 2: | ||
# df$hp2 = df[["hp"]]^2 | ||
|
||
# Data for the baselearner are matrices: | ||
X.hp = cbind(1, df[["hp"]]) | ||
X.wt = cbind(1, df[["wt"]]) | ||
|
||
# Target variable: | ||
y = df[["mpg"]] | ||
|
||
# Next lists are the same as the used data. Then we can have a look if the oob | ||
# and inbag logger and the train prediction and prediction on newdata are doing | ||
# the same. | ||
|
||
# List for oob logging: | ||
eval.oob.test = list( | ||
"hp" = X.hp, | ||
"wt" = X.wt | ||
) | ||
|
||
# List to test prediction on newdata: | ||
eval.data = eval.oob.test | ||
|
||
|
||
# Prepare compboost: | ||
# ------------------ | ||
|
||
## Baselearner | ||
|
||
# Create new linear baselearner of hp and wt: | ||
linear.factory.hp = PolynomialFactory$new(X.hp, "hp", 1) | ||
linear.factory.wt = PolynomialFactory$new(X.wt, "wt", 1) | ||
|
||
# Create new quadratic baselearner of hp: | ||
quadratic.factory.hp = PolynomialFactory$new(X.hp, "hp", 2) | ||
|
||
# Create new factory list: | ||
factory.list = FactoryList$new() | ||
|
||
# Register factorys: | ||
factory.list$registerFactory(linear.factory.hp) | ||
factory.list$registerFactory(linear.factory.wt) | ||
factory.list$registerFactory(quadratic.factory.hp) | ||
|
||
## Loss | ||
|
||
# Use quadratic loss: | ||
loss.quadratic = QuadraticLoss$new() | ||
|
||
|
||
## Optimizer | ||
|
||
# Use the greedy optimizer: | ||
optimizer = GreedyOptimizer$new() | ||
|
||
## Logger | ||
|
||
# Define logger. We want just the iterations as stopper but also track the | ||
# time, inbag risk and oob risk: | ||
log.iterations = LogIterations$new(TRUE, 500) | ||
log.time = LogTime$new(FALSE, 500, "microseconds") | ||
log.inbag = LogInbagRisk$new(FALSE, loss.quadratic, 0.05) | ||
log.oob = LogOobRisk$new(FALSE, loss.quadratic, 0.05, eval.oob.test, y) | ||
|
||
# Define new logger list: | ||
logger.list = LoggerList$new() | ||
|
||
# Register the logger: | ||
logger.list$registerLogger(log.iterations) | ||
logger.list$registerLogger(log.time) | ||
logger.list$registerLogger(log.inbag) | ||
logger.list$registerLogger(log.oob) | ||
|
||
# Run compboost: | ||
# -------------- | ||
|
||
# Initialize object: | ||
cboost = Compboost$new( | ||
response = y, | ||
learning_rate = 0.05, | ||
stop_if_all_stopper_fulfilled = FALSE, | ||
factory_list = factory.list, | ||
loss = loss.quadratic, | ||
logger_list = logger.list, | ||
optimizer = optimizer | ||
) | ||
|
||
suppressWarnings({ | ||
failed.plotter = plotCompboostParameter(cboost) | ||
}) | ||
|
||
# Train the model (we want to print the trace): | ||
cboost$train(trace = FALSE) | ||
|
||
plotter = plotCompboostParameter(cboost) | ||
|
||
# Test: | ||
# --------- | ||
|
||
expect_equal(failed.plotter, 1) | ||
expect_equal(plotter, 0) | ||
|
||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters