diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 0000000..710c918 --- /dev/null +++ b/NEWS.md @@ -0,0 +1,4 @@ +# lime 0.3.0.9999 + +* Added a `NEWS.md` file to track changes to the package. +* Fixed bug when explaining regression models, due to drop=TRUE defaults (#33) diff --git a/R/character.R b/R/character.R index e1c68cb..075e528 100644 --- a/R/character.R +++ b/R/character.R @@ -69,9 +69,9 @@ explain.character <- function(x, explainer, labels = NULL, n_labels = NULL, if (m_type == 'regression') { if (!is.null(labels) || !is.null(n_labels)) { warning('"labels" and "n_labels" arguments are ignored when explaining regression models') - n_labels <- 1 - labels <- NULL } + n_labels <- 1 + labels <- NULL } assert_that(is.null(labels) + is.null(n_labels) == 1, msg = "You need to choose between labels and n_labels parameters.") assert_that(is.count(n_features)) @@ -98,7 +98,7 @@ explain.character <- function(x, explainer, labels = NULL, n_labels = NULL, res <- lapply(seq_along(case_ind), function(ind) { i <- case_ind[[ind]] - res <- model_permutations(case_perm$tabular[i, ], case_res[i, ], case_perm$permutation_distances[i], labels, n_labels, n_features, feature_select) + res <- model_permutations(case_perm$tabular[i, ], case_res[i, , drop = FALSE], case_perm$permutation_distances[i], labels, n_labels, n_features, feature_select) res$feature_value <- res$feature res$feature_desc <- res$feature res$case <- ind diff --git a/R/dataframe.R b/R/dataframe.R index 21e0bd9..66d9621 100644 --- a/R/dataframe.R +++ b/R/dataframe.R @@ -76,9 +76,9 @@ explain.data.frame <- function(x, explainer, labels = NULL, n_labels = NULL, if (m_type == 'regression') { if (!is.null(labels) || !is.null(n_labels)) { warning('"labels" and "n_labels" arguments are ignored when explaining regression models') - n_labels <- 1 - labels <- NULL } + n_labels <- 1 + labels <- NULL } assert_that(is.null(labels) + is.null(n_labels) == 1, msg = "You need to choose between labels and n_labels parameters.") assert_that(is.count(n_features)) @@ -98,7 +98,7 @@ explain.data.frame <- function(x, explainer, labels = NULL, n_labels = NULL, perms <- numerify(case_perm[i, ], explainer$feature_type, explainer$bin_continuous, explainer$bin_cuts) dist <- c(0, dist(feature_scale(perms, explainer$feature_distribution, explainer$feature_type, explainer$bin_continuous), method = dist_fun)[seq_len(n_permutations-1)]) - res <- model_permutations(as.matrix(perms), case_res[i, ], kernel(dist), labels, n_labels, n_features, feature_select) + res <- model_permutations(as.matrix(perms), case_res[i, , drop = FALSE], kernel(dist), labels, n_labels, n_features, feature_select) res$feature_value <- unlist(case_perm[i[1], res$feature]) res$feature_desc <- describe_feature(res$feature, case_perm[i[1], ], explainer$feature_type, explainer$bin_continuous, explainer$bin_cuts) guess <- which.max(abs(case_res[i[1], ]))