[SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR
### What changes were proposed in this pull request?

This pull request adds a SparkR wrapper for `FMClassifier`:

- Supporting `org.apache.spark.ml.r.FMClassifierWrapper`.
- `FMClassificationModel` S4 class.
- Corresponding `spark.fmClassifier`, `predict`, `summary` and `write.ml` generics.
- Corresponding docs and tests.

### Why are the changes needed?

Feature parity.

### Does this PR introduce any user-facing change?

No (new API).

### How was this patch tested?

New unit tests.

Closes apache#27570 from zero323/SPARK-30820.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
zero323 authored and Seongjin Cho committed Apr 14, 2020
1 parent 7e4a591 commit 44a78fb
Showing 11 changed files with 484 additions and 36 deletions.
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
@@ -72,7 +72,8 @@ exportMethods("glm",
"spark.freqItemsets",
"spark.associationRules",
"spark.findFrequentSequentialPatterns",
"spark.assignClusters")
"spark.assignClusters",
"spark.fmClassifier")

# Job group lifecycle management methods
export("setJobGroup",
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -1479,6 +1479,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
setGeneric("spark.bisectingKmeans",
function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") })

#' @rdname spark.fmClassifier
setGeneric("spark.fmClassifier",
function(data, formula, ...) { standardGeneric("spark.fmClassifier") })

#' @rdname spark.gaussianMixture
setGeneric("spark.gaussianMixture",
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
157 changes: 157 additions & 0 deletions R/pkg/R/mllib_classification.R
@@ -42,6 +42,12 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' S4 class that represents a FMClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala FMClassifierWrapper
#' @note FMClassificationModel since 3.1.0
setClass("FMClassificationModel", representation(jobj = "jobj"))

#' Linear SVM Model
#'
#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
@@ -649,3 +655,154 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' Factorization Machines Classification Model
#'
#' \code{spark.fmClassifier} fits a factorization classification model against a SparkDataFrame.
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#' Only categorical data is supported.
#'
#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param factorSize dimensionality of the factors.
#' @param fitLinear whether to fit linear term. # TODO Can we express this with formula?
#' @param regParam the regularization parameter.
#' @param miniBatchFraction the mini-batch fraction parameter.
#' @param initStd the standard deviation of initial coefficients.
#' @param maxIter maximum iteration number.
#' @param stepSize stepSize parameter.
#' @param tol convergence tolerance of iterations.
#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "adamW".
#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of
#' class label 1 is > threshold, then predict 1, else 0. A high threshold
#' encourages the model to predict 0 more often; a low threshold encourages the
#' model to predict 1 more often. Note: Setting this with threshold p is
#' equivalent to setting thresholds c(1-p, p).
#' @param seed seed parameter for weights initialization.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
#' label column of string type.
#' Supported options: "skip" (filter out rows with invalid data),
#' "error" (throw an error), "keep" (put invalid data in
#' a special additional bucket, at index numLabels). Default
#' is "error".
#' @param ... additional arguments passed to the method.
#' @return \code{spark.fmClassifier} returns a fitted Factorization Machines Classification Model.
#' @rdname spark.fmClassifier
#' @aliases spark.fmClassifier,SparkDataFrame,formula-method
#' @name spark.fmClassifier
#' @seealso \link{read.ml}
#' @examples
#' \dontrun{
#' df <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm")
#'
#' # fit Factorization Machines Classification Model
#' model <- spark.fmClassifier(
#' df, label ~ features,
#' regParam = 0.01, maxIter = 10, fitLinear = TRUE
#' )
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.fmClassifier since 3.1.0
setMethod("spark.fmClassifier", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 0.0,
miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, stepSize=1.0,
tol = 1e-6, solver = c("adamW", "gd"), thresholds = NULL, seed = NULL,
handleInvalid = c("error", "keep", "skip")) {

formula <- paste(deparse(formula), collapse = "")

if (!is.null(seed)) {
seed <- as.character(as.integer(seed))
}

if (!is.null(thresholds)) {
thresholds <- as.list(thresholds)
}

solver <- match.arg(solver)
handleInvalid <- match.arg(handleInvalid)

jobj <- callJStatic("org.apache.spark.ml.r.FMClassifierWrapper",
"fit",
data@sdf,
formula,
as.integer(factorSize),
as.logical(fitLinear),
as.numeric(regParam),
as.numeric(miniBatchFraction),
as.numeric(initStd),
as.integer(maxIter),
as.numeric(stepSize),
as.numeric(tol),
solver,
seed,
thresholds,
handleInvalid)
new("FMClassificationModel", jobj = jobj)
})
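
The argument handling at the top of `spark.fmClassifier` can be tried out in plain R without a Spark session. The sketch below is only an illustration: `normalize_fm_args` is a hypothetical helper, not part of SparkR, that copies the `match.arg`, seed, and thresholds steps from the method above so their effect on the values passed to the JVM is visible.

```r
# Hypothetical helper (not part of SparkR). It mirrors the argument
# normalization performed by spark.fmClassifier before callJStatic().
normalize_fm_args <- function(solver = c("adamW", "gd"),
                              handleInvalid = c("error", "keep", "skip"),
                              seed = NULL, thresholds = NULL) {
  # match.arg() returns the first choice when the caller keeps the default
  # vector, and raises an error for a value outside the allowed set.
  solver <- match.arg(solver)
  handleInvalid <- match.arg(handleInvalid)

  # The seed is truncated to an integer and converted to a string.
  if (!is.null(seed)) {
    seed <- as.character(as.integer(seed))
  }

  # Thresholds are converted from a numeric vector to a list.
  if (!is.null(thresholds)) {
    thresholds <- as.list(thresholds)
  }

  list(solver = solver, handleInvalid = handleInvalid,
       seed = seed, thresholds = thresholds)
}

defaults <- normalize_fm_args()
custom <- normalize_fm_args(solver = "gd", seed = 42.9, thresholds = c(0.3, 0.7))

# An unsupported solver name is rejected by match.arg().
rejected <- tryCatch(normalize_fm_args(solver = "sgd"),
                     error = function(e) "error")
```

With the defaults, the solver resolves to `"adamW"` and `handleInvalid` to `"error"`, because each is the first element of its default vector.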

# Returns the summary of a FM Classification model produced by \code{spark.fmClassifier}

#' @param object a FM Classification model fitted by \code{spark.fmClassifier}.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' @rdname spark.fmClassifier
#' @note summary(FMClassificationModel) since 3.1.0
setMethod("summary", signature(object = "FMClassificationModel"),
function(object) {
jobj <- object@jobj
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
numClasses <- callJMethod(jobj, "numClasses")
numFeatures <- callJMethod(jobj, "numFeatures")
raw_factors <- unlist(callJMethod(jobj, "rFactors"))
factor_size <- callJMethod(jobj, "factorSize")

list(
coefficients = coefficients,
factors = matrix(raw_factors, ncol = factor_size),
numClasses = numClasses, numFeatures = numFeatures,
factorSize = factor_size
)
})
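
The reshaping done in `summary` can be followed with mocked values in place of the `callJMethod` results. This is a sketch under assumptions: the feature names and numbers below are invented for illustration, and nothing here says how the Scala wrapper orders the values it returns.

```r
# Mocked stand-ins for the values callJMethod() would return; the names
# and numbers are made up for this illustration.
features <- list("Sepal_Length", "Sepal_Width", "Petal_Length")
raw_coefficients <- list(0.5, -1.2, 0.3)
raw_factors <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
factor_size <- 2L

# Same steps as in the summary method: flatten the list, turn it into a
# one-column matrix, and label the rows with the feature names.
coefficients <- as.matrix(unlist(raw_coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)

# matrix() fills column by column, so the first three values form the
# first column and the last three form the second.
factors <- matrix(raw_factors, ncol = factor_size)
```

The result is a 3 x 1 `coefficients` matrix and a 3 x 2 `factors` matrix, that is, one row per feature and one column per factor dimension in this mocked case.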

# Predicted values based on an FMClassificationModel model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns the predicted values based on a FM Classification model.
#' @rdname spark.fmClassifier
#' @aliases predict,FMClassificationModel,SparkDataFrame-method
#' @note predict(FMClassificationModel) since 3.1.0
setMethod("predict", signature(object = "FMClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})

# Save fitted FMClassificationModel to the input path

#' @param path The directory where the model is saved.
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.fmClassifier
#' @aliases write.ml,FMClassificationModel,character-method
#' @note write.ml(FMClassificationModel, character) since 3.1.0
setMethod("write.ml", signature(object = "FMClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
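
The `thresholds` documentation above states that a single threshold `p` is equivalent to `thresholds = c(1 - p, p)`. The plain R sketch below checks that claim on a few probabilities. It assumes the per-class rule described in the Spark ML documentation, where the predicted class is the one with the largest ratio of probability to threshold; both functions are hypothetical helpers and do not call Spark.

```r
# Rule 1: a single cut-off p on the estimated probability of class 1.
predict_with_threshold <- function(prob1, p) {
  as.integer(prob1 > p)
}

# Rule 2 (assumed): choose the class whose probability / threshold ratio
# is largest. thresholds[1] belongs to class 0, thresholds[2] to class 1.
predict_with_thresholds <- function(prob1, thresholds) {
  vapply(prob1, function(q) {
    ratios <- c(1 - q, q) / thresholds
    which.max(ratios) - 1L  # which.max() is 1-based; labels are 0 and 1
  }, integer(1))
}

p <- 0.3
probs <- c(0.05, 0.2, 0.29, 0.31, 0.5, 0.95)

by_threshold <- predict_with_threshold(probs, p)
by_thresholds <- predict_with_thresholds(probs, c(1 - p, p))
```

Both rules agree because `q / p > (1 - q) / (1 - p)` rearranges to `q > p` for `p` strictly between 0 and 1.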
2 changes: 2 additions & 0 deletions R/pkg/R/mllib_utils.R
@@ -123,6 +123,8 @@ read.ml <- function(path) {
new("LinearSVCModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
new("FPGrowthModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMClassifierWrapper")) {
new("FMClassificationModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
34 changes: 34 additions & 0 deletions R/pkg/tests/fulltests/test_mllib_classification.R
@@ -488,4 +488,38 @@ test_that("spark.naiveBayes", {
expect_equal(class(collect(predictions)$clicked[1]), "character")
})

test_that("spark.fmClassifier", {
df <- withColumn(
suppressWarnings(createDataFrame(iris)),
"Species", otherwise(when(column("Species") == "Setosa", "Setosa"), "Not-Setosa")
)

model1 <- spark.fmClassifier(
df, Species ~ .,
regParam = 0.01, maxIter = 10, fitLinear = TRUE, factorSize = 3
)

prediction1 <- predict(model1, df)
expect_is(prediction1, "SparkDataFrame")
expect_equal(summary(model1)$factorSize, 3)

# Test model save/load
if (windows_with_hadoop()) {
modelPath <- tempfile(pattern = "spark-fmclassifier", fileext = ".tmp")
write.ml(model1, modelPath)
model2 <- read.ml(modelPath)

expect_is(model2, "FMClassificationModel")

expect_equal(summary(model1), summary(model2))

prediction2 <- predict(model2, df)
expect_equal(
collect(drop(prediction1, c("rawPrediction", "probability"))),
collect(drop(prediction2, c("rawPrediction", "probability")))
)
unlink(modelPath)
}
})

sparkR.session.stop()
20 changes: 20 additions & 0 deletions R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -523,6 +523,8 @@ SparkR supports the following machine learning models and algorithms.

* Naive Bayes

* Factorization Machines (FM) Classifier

#### Regression

* Accelerated Failure Time (AFT) Survival Model
@@ -705,6 +707,24 @@ naiveBayesPrediction <- predict(naiveBayesModel, titanicDF)
head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction"))
```

#### Factorization Machines Classifier

Factorization Machines for classification problems.

For background and details about the implementation of factorization machines,
refer to the [Factorization Machines section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).

```{r}
t <- as.data.frame(Titanic)
training <- createDataFrame(t)
model <- spark.fmClassifier(training, Survived ~ Age + Sex)
summary(model)
predictions <- predict(model, training)
head(select(predictions, predictions$prediction))
```
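
For readers who want to see what the fitted quantities mean, the following plain R sketch evaluates a second-order factorization machine score from a bias, linear weights, and a factor matrix. It assumes the standard formulation (bias plus linear term plus pairwise interactions weighted by dot products of factor rows) and a logistic link for the class probability; it is not taken from this commit, and `fm_score`, `fm_score_naive`, and all numbers are hypothetical.

```r
# Score of a second-order factorization machine for one observation.
# x: numeric feature vector (length n), w0: bias, w: linear weights (length n),
# V: n x k factor matrix with one row per feature.
fm_score <- function(x, w0, w, V) {
  linear <- sum(w * x)
  # Pairwise term through the usual O(n * k) identity:
  # 0.5 * sum_f ((sum_i V[i, f] * x[i])^2 - sum_i V[i, f]^2 * x[i]^2)
  xv <- as.vector(crossprod(V, x))
  x2v2 <- as.vector(crossprod(V^2, x^2))
  pairwise <- 0.5 * sum(xv^2 - x2v2)
  w0 + linear + pairwise
}

# Direct double loop over feature pairs, used only as a cross-check.
fm_score_naive <- function(x, w0, w, V) {
  n <- length(x)
  pairwise <- 0
  for (i in seq_len(n - 1)) {
    for (j in (i + 1):n) {
      pairwise <- pairwise + sum(V[i, ] * V[j, ]) * x[i] * x[j]
    }
  }
  w0 + sum(w * x) + pairwise
}

x <- c(1, 0, 2)
w0 <- 0.1
w <- c(0.5, -1.2, 0.3)
V <- matrix(c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6), ncol = 2)

fast <- fm_score(x, w0, w, V)
naive <- fm_score_naive(x, w0, w, V)

# Logistic link (assumed) to turn the score into a class-1 probability.
prob <- plogis(fast)
```

The two score functions should agree up to floating point error; the faster form avoids the explicit loop over all feature pairs.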

#### Accelerated Failure Time Survival Model

Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring.