Skip to content

Commit

Permalink
Merge pull request #630 from m-dz/fix_620-patch
Browse files Browse the repository at this point in the history
Amendments and additional test for PR #624 (issue #620)
  • Loading branch information
topepo committed Apr 13, 2017
2 parents 0a7c6bc + 7967d20 commit 8d09280
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pkg/caret/R/ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ ggplot.train <- function(data = NULL, mapping = NULL, metric = data$metric[1], p
strip_lab <- as.character(subset(data$modelInfo$parameters, parameter %in% strip_vars)$label)
for(i in seq_along(strip_vars))
dat[, strip_vars[i]] <- factor(
paste(strip_lab[i], dat[, strip_vars[i]], sep = ": "),
levels = paste(strip_lab[i], sort(unique(dat[, strip_vars[i]])), sep = ": ")
paste(strip_lab[i], format(dat[, strip_vars[i]]), sep = ": "),
levels = paste(strip_lab[i], format(sort(unique(dat[, strip_vars[i]]))), sep = ": ")
)
}
## TODO: use factor(format(x)) to make a solid block of colors?
Expand Down
48 changes: 43 additions & 5 deletions pkg/caret/tests/testthat/test_ggplot.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
context("Test ggplot")

test_that("ggplot.train correctly orders factors", {
library(caret)
library(kernlab)
data(mtcars)
m <- train(mpg ~ cyl + disp,
data = mtcars,
method="svmRadial",
tuneGrid = expand.grid(C=1:2, sigma=c(0.0001, 0.01, 1)))
g <- ggplot(m, plotType="level")
m <- train(
mpg ~ cyl + disp,
data = mtcars,
method = "svmRadial",
tuneGrid = expand.grid(C = 1:2, sigma = c(0.0001, 0.01, 1))
)
g <- ggplot(m, plotType = "level")

# Test plot data
obj_sigma <- as.numeric(levels(g$data$sigma))
Expand All @@ -21,3 +25,37 @@ test_that("ggplot.train correctly orders factors", {
expect_equal(obj_x, sort(obj_x))
expect_equal(obj_y, sort(obj_y))
})

test_that("ggplot.train correctly orders facets' labels", {
library(caret)
library(kernlab)
data(mtcars)
m <- suppressWarnings(train(
mpg ~ cyl + disp,
data = mtcars,
method = "svmPoly",
tuneGrid = expand.grid(
degree = c(0.0001, 0.01, 1),
scale = c(0.0001, 0.01, 1),
C = c(0.0001, 0.01, 1)
)
))
g <- ggplot(m, plotType = "level", nameInStrip = TRUE)

# Test plot data
obj_C <- as.numeric(gsub(
'Cost: ',
'',
levels(g$data$C)
))
expect_equal(obj_C, sort(obj_C))

# Test axes' labels on a built plot
build <- ggplot2::ggplot_build(g)
obj_labels <- as.numeric(gsub(
'Cost: ',
'',
levels(build$layout$panel_layout$C)
))
expect_equal(obj_labels, sort(obj_labels))
})

0 comments on commit 8d09280

Please sign in to comment.