Skip to content

Commit

Permalink
[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Currently, `AFTAggregator` is not being merged correctly. For example, if there is any single empty partition in the data, this creates an `AFTAggregator` with zero total count which causes the exception below:

```
IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.'
```

Please see [AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) as well.

Just to be clear, the python example `aft_survival_regression.py` seems using 5 rows. So, if there exist partitions more than 5, it throws the exception above since it contains empty partitions which results in an incorrectly merged `AFTAggregator`.

Executing `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py` on a machine with CPUs more than 5 is being failed because it creates tasks with some empty partitions with defualt  configurations (AFAIK, it sets the parallelism level to the number of CPU cores).

## How was this patch tested?

An unit test in `AFTSurvivalRegressionSuite.scala` and manually tested by `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`.

Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>

Closes apache#13619 from HyukjinKwon/SPARK-15892.
  • Loading branch information
HyukjinKwon authored and jkbradley committed Jun 12, 2016
1 parent 0ff8a68 commit e355460
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
Expand Up @@ -538,7 +538,7 @@ private class AFTAggregator(
* @return This AFTAggregator object.
*/
def merge(other: AFTAggregator): this.type = {
if (totalCnt != 0) {
if (other.count != 0) {
totalCnt += other.totalCnt
lossSum += other.lossSum

Expand Down
Expand Up @@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
}

test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
// This `dataset` will contain an empty partition because it has two rows but
// the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
// being merged incorrectly when it has an empty partition, running the codes below
// should not throw an exception.
val dataset = spark.createDataFrame(
sc.parallelize(generateAFTInput(
1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
val trainer = new AFTSurvivalRegression()
trainer.fit(dataset)
}
}

object AFTSurvivalRegressionSuite {
Expand Down

0 comments on commit e355460

Please sign in to comment.