Skip to content

Commit

Permalink
Merge branch 'master' into SPARK-7446-inverse-transform-for-string-in…
Browse files Browse the repository at this point in the history
…dexer
  • Loading branch information
holdenk committed Jul 31, 2015
2 parents 4f06c59 + 83670fc commit 64dd3a3
Show file tree
Hide file tree
Showing 298 changed files with 10,652 additions and 3,143 deletions.
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export("print.jobj")

# MLlib integration
exportMethods("glm",
"predict")
"predict",
"summary")

# Job group lifecycle management methods
export("setJobGroup",
Expand Down
4 changes: 3 additions & 1 deletion R/pkg/R/backend.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) {

# TODO: check the status code to output error information
returnStatus <- readInt(conn)
stopifnot(returnStatus == 0)
if (returnStatus != 0) {
stop(readString(conn))
}
readObject(conn)
}
2 changes: 1 addition & 1 deletion R/pkg/R/client.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack
jars <- paste("--jars", jars)
}

if (packages != "") {
if (!identical(packages, "")) {
packages <- paste("--packages", packages)
}

Expand Down
14 changes: 7 additions & 7 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues")

# @rdname intersection
# @export
setGeneric("intersection", function(x, other, numPartitions = 1) {
standardGeneric("intersection") })
setGeneric("intersection",
function(x, other, numPartitions = 1) {
standardGeneric("intersection")
})

# @rdname keys
# @export
Expand Down Expand Up @@ -489,9 +491,7 @@ setGeneric("sample",
#' @rdname sample
#' @export
setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) {
standardGeneric("sample_frac")
})
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })

#' @rdname saveAsParquetFile
#' @export
Expand Down Expand Up @@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn

#' @rdname withColumnRenamed
#' @export
setGeneric("withColumnRenamed", function(x, existingCol, newCol) {
standardGeneric("withColumnRenamed") })
setGeneric("withColumnRenamed",
function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })


###################### Column Methods ##########################
Expand Down
28 changes: 27 additions & 1 deletion R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' operators are supported, including '~', '+', '-', and '.'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
Expand Down Expand Up @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
#' @param model A fitted MLlib model
#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
#' summary.glm for more information.
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "PipelineModel"),
function(object) {
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", object@model)
weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelWeights", object@model)
coefficients <- as.matrix(unlist(weights))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
})
4 changes: 2 additions & 2 deletions R/pkg/R/pairRDD.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ setMethod("partitionBy",

packageNamesArr <- serialize(.sparkREnv$.packages,
connection = NULL)
broadcastArr <- lapply(ls(.broadcastNames), function(name) {
get(name, .broadcastNames) })
broadcastArr <- lapply(ls(.broadcastNames),
function(name) { get(name, .broadcastNames) })
jrdd <- getJRDD(x)

# We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])],
Expand Down
9 changes: 6 additions & 3 deletions R/pkg/R/sparkR.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
connExists <- function(env) {
tryCatch({
exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]])
}, error = function(err) {
},
error = function(err) {
return(FALSE)
})
}
Expand Down Expand Up @@ -153,7 +154,8 @@ sparkR.init <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort)
}, error = function(err) {
},
error = function(err) {
stop("Failed to connect JVM\n")
})

Expand Down Expand Up @@ -264,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) {
ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
stop("Spark SQL is not built with Hive support")
})

Expand Down
4 changes: 4 additions & 0 deletions R/pkg/inst/tests/test_client.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", {
expect_equal(gsub("[[:space:]]", "", args),
"")
})

test_that("multiple packages don't produce a warning", {
expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning()))
})
19 changes: 19 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,22 @@ test_that("predictions match with native glm", {
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

test_that("dot minus and intercept vs native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ . - Species + 0, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
coefs <- as.vector(stats$coefficients)
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))
expect_true(all(
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
})
11 changes: 9 additions & 2 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", {
df <- jsonFile(sqlContext, jsonPathNa)
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
skip("Hive is not build with SparkSQL, skipped")
})
sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
Expand Down Expand Up @@ -602,7 +603,8 @@ test_that("write.df() as parquet file", {
test_that("test HiveContext", {
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
skip("Hive is not build with SparkSQL, skipped")
})
df <- createExternalTable(hiveCtx, "json", jsonPath, "json")
Expand Down Expand Up @@ -1000,6 +1002,11 @@ test_that("crosstab() on a DataFrame", {
expect_identical(expected, ordered)
})

test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE)
})

unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
2 changes: 1 addition & 1 deletion R/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE

SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))

if [[ $FAILED != 0 ]]; then
Expand Down
39 changes: 10 additions & 29 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
<name>Spark Project Core</name>
<url>http://spark.apache.org/</url>
<dependencies>
<dependency>
<groupId>org.apache.avro</groupId>
<artifactId>avro-mapred</artifactId>
<classifier>${avro.mapred.classifier}</classifier>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
Expand Down Expand Up @@ -281,7 +286,7 @@
<dependency>
<groupId>org.tachyonproject</groupId>
<artifactId>tachyon-client</artifactId>
<version>0.6.4</version>
<version>0.7.0</version>
<exclusions>
<exclusion>
<groupId>org.apache.hadoop</groupId>
Expand All @@ -292,36 +297,12 @@
<artifactId>curator-recipes</artifactId>
</exclusion>
<exclusion>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-jsp</artifactId>
</exclusion>
<exclusion>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-webapp</artifactId>
</exclusion>
<exclusion>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</exclusion>
<exclusion>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-servlet</artifactId>
<groupId>org.tachyonproject</groupId>
<artifactId>tachyon-underfs-glusterfs</artifactId>
</exclusion>
<exclusion>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</exclusion>
<exclusion>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-junit4</artifactId>
</exclusion>
<exclusion>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.curator</groupId>
<artifactId>curator-test</artifactId>
<groupId>org.tachyonproject</groupId>
<artifactId>tachyon-underfs-s3</artifactId>
</exclusion>
</exclusions>
</dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ final class UnsafeShuffleExternalSorter {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);

private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@VisibleForTesting
static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;

private final int initialSize;
private final int numPartitions;
private final int pageSizeBytes;
@VisibleForTesting
final int maxRecordSizeBytes;
private final TaskMemoryManager memoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
Expand Down Expand Up @@ -109,7 +109,10 @@ public UnsafeShuffleExternalSorter(
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;

this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
conf.getSizeAsBytes("spark.buffer.pageSize", "64m"));
this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
initializeForWriting();
}
Expand Down Expand Up @@ -272,7 +275,11 @@ void spill() throws IOException {
}

private long getMemoryUsage() {
return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
long totalPageSize = 0;
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
return sorter.getMemoryUsage() + totalPageSize;
}

private long freeMemory() {
Expand Down Expand Up @@ -346,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
// TODO: we should track metrics on the amount of space wasted when we roll over to a new page
// without using the free space at the end of the current page. We should also do this for
// BytesToBytesMap.
if (requiredSpace > PAGE_SIZE) {
if (requiredSpace > pageSizeBytes) {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
PAGE_SIZE + ")");
pageSizeBytes + ")");
} else {
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
if (memoryAcquired < PAGE_SIZE) {
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquired < pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquired);
spill();
final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquiredAfterSpilling != pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
currentPage = memoryManager.allocatePage(PAGE_SIZE);
currentPage = memoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = PAGE_SIZE;
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ public UnsafeShuffleWriter(
open();
}

@VisibleForTesting
public int maxRecordSizeBytes() {
return sorter.maxRecordSizeBytes;
}

/**
* This convenience method should only be called in test code.
*/
Expand Down
Loading

0 comments on commit 64dd3a3

Please sign in to comment.