Skip to content

Commit

Permalink
[SPARK-3541][MLLIB] New ALS implementation with improved storage
Browse files Browse the repository at this point in the history
This PR adds a new ALS implementation to `spark.ml` using the pipeline API, which should be able to scale to billions of ratings. Compared with the ALS under `spark.mllib`, the new implementation

1. uses the same algorithm,
2. uses float type for ratings,
3. uses primitive arrays to avoid GC,
4. sorts and compresses ratings on each block so that we can solve least squares subproblems one by one using only one normal equation instance.

The following figure shows performance comparison on copies of the Amazon Reviews dataset using a 16-node (m3.2xlarge) EC2 cluster (the same setup as in http://databricks.com/blog/2014/07/23/scalable-collaborative-filtering-with-spark-mllib.html):
![als-wip](https://cloud.githubusercontent.com/assets/829644/5659447/4c4ff8e0-96c7-11e4-87a9-73c1c63d07f3.png)

I keep the `spark.mllib`'s ALS untouched for easy comparison. If the new implementation works well, I'm going to match the features of the ALS under `spark.mllib` and then make it a wrapper of the new implementation, in a separate PR.

TODO:
- [X] Add unit tests for implicit preferences.

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#3720 from mengxr/SPARK-3541 and squashes the following commits:

1b9e852 [Xiangrui Meng] fix compile
5129be9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3541
dd0d0e8 [Xiangrui Meng] simplify test code
c627de3 [Xiangrui Meng] add tests for implicit feedback
b84f41c [Xiangrui Meng] address comments
a76da7b [Xiangrui Meng] update ALS tests
2a8deb3 [Xiangrui Meng] add some ALS tests
857e876 [Xiangrui Meng] add tests for rating block and encoded block
d3c1ac4 [Xiangrui Meng] rename some classes for better code readability add more doc and comments
213d163 [Xiangrui Meng] org imports
771baf3 [Xiangrui Meng] chol doc update
ca9ad9d [Xiangrui Meng] add unit tests for chol
b4fd17c [Xiangrui Meng] add unit tests for NormalEquation
d0f99d3 [Xiangrui Meng] add tests for LocalIndexEncoder
80b8e61 [Xiangrui Meng] fix imports
4937fd4 [Xiangrui Meng] update ALS example
56c253c [Xiangrui Meng] rename product to item
bce8692 [Xiangrui Meng] doc for parameters and project the output columns
3f2d81a [Xiangrui Meng] add doc
1efaecf [Xiangrui Meng] add example code
8ae86b5 [Xiangrui Meng] add a working copy of the new ALS implementation
  • Loading branch information
mengxr committed Jan 23, 2015
1 parent e0f7fb7 commit ea74365
Show file tree
Hide file tree
Showing 5 changed files with 1,584 additions and 2 deletions.
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.{Row, SQLContext}

/**
* An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
* Run with
* {{{
* bin/run-example ml.MovieLensALS
* }}}
*/
/**
 * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
 * Parses the "::"-delimited ratings and movies files, fits an ALS model on an 80/20
 * train/test split, reports test RMSE, and lists likely false positives
 * (low actual rating but high predicted rating).
 */
object MovieLensALS {

  /** A single rating event: user, movie, rating value, and epoch timestamp. */
  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  object Rating {
    /** Parses one "userId::movieId::rating::timestamp" line into a [[Rating]]. */
    def parseRating(str: String): Rating = {
      val fields = str.split("::")
      assert(fields.size == 4)
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
    }
  }

  /** A movie with its title and list of genres. */
  case class Movie(movieId: Int, title: String, genres: Seq[String])

  object Movie {
    /** Parses one "movieId::title::genre1|genre2|..." line into a [[Movie]]. */
    def parseMovie(str: String): Movie = {
      val fields = str.split("::")
      assert(fields.size == 3)
      // NOTE: String.split takes a regex, so "|" must be escaped; an unescaped
      // "|" matches the empty string and would split the genres field into
      // individual characters.
      Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
    }
  }

  /** Command-line parameters, with defaults matching a typical small run. */
  case class Params(
      ratings: String = null,
      movies: String = null,
      maxIter: Int = 10,
      regParam: Double = 0.1,
      rank: Int = 10,
      numBlocks: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("MovieLensALS") {
      head("MovieLensALS: an example app for ALS on MovieLens data.")
      opt[String]("ratings")
        .required()
        .text("path to a MovieLens dataset of ratings")
        .action((x, c) => c.copy(ratings = x))
      opt[String]("movies")
        .required()
        .text("path to a MovieLens dataset of movies")
        .action((x, c) => c.copy(movies = x))
      opt[Int]("rank")
        .text(s"rank, default: ${defaultParams.rank}")
        .action((x, c) => c.copy(rank = x))
      opt[Int]("maxIter")
        .text(s"max number of iterations, default: ${defaultParams.maxIter}")
        .action((x, c) => c.copy(maxIter = x))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      opt[Int]("numBlocks")
        .text(s"number of blocks, default: ${defaultParams.numBlocks}")
        .action((x, c) => c.copy(numBlocks = x))
      note(
        """
          |Example command line to run this app:
          |
          | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
          |  examples/target/scala-*/spark-examples-*.jar \
          |  --rank 10 --maxIter 15 --regParam 0.1 \
          |  --movies path/to/movielens/movies.dat \
          |  --ratings path/to/movielens/ratings.dat
        """.stripMargin)
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    } getOrElse {
      // scopt already printed a usage/error message; exit with failure status.
      System.exit(1)
    }
  }

  /** Loads the data, fits the ALS model, and prints evaluation results. */
  def run(params: Params) {
    val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()

    val numRatings = ratings.count()
    val numUsers = ratings.map(_.userId).distinct().count()
    val numMovies = ratings.map(_.movieId).distinct().count()

    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

    // Fixed seed (0L) keeps the 80/20 split reproducible across runs.
    val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

    // The split RDDs are cached; the raw ratings RDD is no longer needed.
    ratings.unpersist(blocking = false)

    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRank(params.rank)
      .setMaxIter(params.maxIter)
      .setRegParam(params.regParam)
      .setNumBlocks(params.numBlocks)

    val model = als.fit(training)

    val predictions = model.transform(test).cache()

    // Evaluate the model.
    // TODO: Create an evaluator to compute RMSE.
    // NaN predictions arise for users/movies unseen in training; skip them
    // rather than poisoning the mean.
    val mse = predictions.select('rating, 'prediction)
      .flatMap { case Row(rating: Float, prediction: Float) =>
        val err = rating.toDouble - prediction
        val err2 = err * err
        if (err2.isNaN) {
          None
        } else {
          Some(err2)
        }
      }.mean()
    val rmse = math.sqrt(mse)
    println(s"Test RMSE = $rmse.")

    // Inspect false positives: actual rating <= 1 but predicted >= 4.
    predictions.registerTempTable("prediction")
    sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
    sqlContext.sql(
      """
        |SELECT userId, prediction.movieId, title, rating, prediction
        |  FROM prediction JOIN movie ON prediction.movieId = movie.movieId
        | WHERE rating <= 1 AND prediction >= 4
        | LIMIT 100
      """.stripMargin)
      .collect()
      .foreach(println)

    sc.stop()
  }
}

0 comments on commit ea74365

Please sign in to comment.