Commit 8a8ac4a: Merge remote-tracking branch 'upstream/master' into UDAF
yhuai committed Jul 21, 2015
2 parents: 88c7d4d + 87d890c
Showing 118 changed files with 4,027 additions and 1,584 deletions.
1 change: 1 addition & 0 deletions R/pkg/DESCRIPTION
@@ -29,6 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
'mllib.R'
'serialize.R'
'sparkR.R'
'utils.R'
4 changes: 4 additions & 0 deletions R/pkg/NAMESPACE
@@ -10,6 +10,10 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")

# MLlib integration
exportMethods("glm",
"predict")

# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })

#' @rdname glm
#' @export
setGeneric("glm")
73 changes: 73 additions & 0 deletions R/pkg/R/mllib.R
@@ -0,0 +1,73 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# mllib.R: Provides methods for MLlib integration

#' @title S4 class that represents a PipelineModel
#' @param model A Java object reference to the backing Scala PipelineModel
#' @export
setClass("PipelineModel", representation(model = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
#' @return a fitted MLlib model
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
#' df <- createDataFrame(sqlContext, iris)
#' model <- glm(Sepal_Length ~ Sepal_Width, df)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
alpha)
return(new("PipelineModel", model = model))
})

#' Make predictions from a model
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
#'
#' @param model A fitted MLlib model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted values
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
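
The two methods above make up the new SparkR model-fitting API. A minimal end-to-end sketch, combining the \dontrun examples with the usage exercised in the new test file (it assumes a local Spark installation with SparkR attached, and uses the underscore column names shown in the examples above, e.g. Sepal_Length rather than Sepal.Length):

sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
training <- createDataFrame(sqlContext, iris)                            # Spark DataFrame built from iris
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")  # fit a linear model via MLlib
predictions <- predict(model, select(training, "Sepal_Length"))          # DataFrame with a "prediction" column
showDF(predictions)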
13 changes: 8 additions & 5 deletions R/pkg/R/schema.R
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
#' @param ... further arguments passed to or from other methods
print.structType <- function(x, ...) {
cat("StructType\n",
sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
"\", type = \"", field$dataType.toString(),
"\", nullable = ", field$nullable(), "\n",
sep = "") })
, sep = "")
sapply(x$fields(),
function(field) {
paste("|-", "name = \"", field$name(),
"\", type = \"", field$dataType.toString(),
"\", nullable = ", field$nullable(), "\n",
sep = "")
}),
sep = "")
}

#' structField
33 changes: 22 additions & 11 deletions R/pkg/R/utils.R
@@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else { # if node[[1]] is length of 1, check for some R special functions.
} else {
# if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
if (nodeChar == "{" || nodeChar == "(") {
# Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
nodeChar == "<<-") { # Assignment Ops.
nodeChar == "<<-") {
# Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into defVars.
@@ -408,14 +411,16 @@
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "function") { # Function definition.
} else if (nodeChar == "function") {
# Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "$") { # Skip the field.
} else if (nodeChar == "$") {
# Skip the field.
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -429,7 +434,8 @@
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
if (!nodeChar %in% defVars$data) {
# Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
@@ -439,20 +445,24 @@
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
if (!isNamespace(func.env) ||
(getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
(getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) {
# Only include SparkR internals.

# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
if (is.function(obj)) { # If the node is a function call.
if (is.function(obj)) {
# If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
})
if (sum(found) > 0) { # If function has been examined, ignore.
if (sum(found) > 0) {
# If function has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
for (i in 1:(length(argNames) - 1)) {
# Remove the ending NULL in pairlist.
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
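The utils.R hunks above only move trailing comments in processClosure/cleanClosure onto their own lines; behaviour is unchanged. For context, a rough illustration of what cleanClosure is for (it is an internal helper, so the SparkR::: access and the exact contents of the captured environment are assumptions, not documented API):

y <- 2
f <- function(x) { x + y }
g <- SparkR:::cleanClosure(f)
# g is a copy of f whose environment captures only the variables its body
# references (here y), so the function can be serialized and shipped to workers
# independently of the driver's global environment.
ls(environment(g))   # expected to contain "y"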
42 changes: 42 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
@@ -0,0 +1,42 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

library(testthat)

context("MLlib functions")

# Tests for MLlib functions in SparkR

sc <- sparkR.init()

sqlContext <- sparkRSQL.init(sc)

test_that("glm and predict", {
training <- createDataFrame(sqlContext, iris)
test <- select(training, "Sepal_Length")
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
prediction <- predict(model, test)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
})

test_that("predictions match with native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
})
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -261,7 +261,7 @@
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.10</artifactId>
<artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.derby</groupId>
PrefixComparators.java (full file path not shown)
@@ -23,6 +23,7 @@

import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.Utils;

@Private
public class PrefixComparators {
@@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareFloats(a, b);
}

public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareDoubles(a, b);
}

public long computePrefix(double value) {
Spark UI JavaScript for the additional-metrics controls (file path not shown)
@@ -19,31 +19,63 @@
* to be registered after the page loads. */
$(function() {
$("span.expand-additional-metrics").click(function(){
var status = window.localStorage.getItem("expand-additional-metrics") == "true";
status = !status;

// Expand the list of additional metrics.
var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
$(additionalMetricsDiv).toggleClass('collapsed');

// Switch the class of the arrow from open to closed.
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');

window.localStorage.setItem("expand-additional-metrics", "" + status);
});

if (window.localStorage.getItem("expand-additional-metrics") == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem("expand-additional-metrics", "false");
$("span.expand-additional-metrics").trigger("click");
}

stripeSummaryTable();

$('input[type="checkbox"]').click(function() {
var column = "table ." + $(this).attr("name");
var name = $(this).attr("name")
var column = "table ." + name;
var status = window.localStorage.getItem(name) == "true";
status = !status;
$(column).toggle();
stripeSummaryTable();
window.localStorage.setItem(name, "" + status);
});

$("#select-all-metrics").click(function() {
var status = window.localStorage.getItem("select-all-metrics") == "true";
status = !status;
if (this.checked) {
// Toggle all un-checked options.
$('input[type="checkbox"]:not(:checked)').trigger('click');
} else {
// Toggle all checked options.
$('input[type="checkbox"]:checked').trigger('click');
}
window.localStorage.setItem("select-all-metrics", "" + status);
});

if (window.localStorage.getItem("select-all-metrics") == "true") {
$("#select-all-metrics").attr('checked', status);
}

$("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() {
var name = $(this).attr("name")
// If name is undefined, then skip it because it's the "select-all-metrics" checkbox
if (name && window.localStorage.getItem(name) == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem(name, "false");
$(this).trigger("click")
}
});

// Trigger a click on the checkbox if a user clicks the label next to it.

(The remaining changed files in this commit were not loaded in this view.)