From aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 1 Mar 2015 16:26:57 -0800
Subject: [PATCH] [SPARK-6053][MLLIB] support save/load in PySpark's ALS

A simple wrapper to save/load `MatrixFactorizationModel` in Python.

jkbradley

Author: Xiangrui Meng

Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:

f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
---
 docs/mllib-collaborative-filtering.md    |  8 ++-
 .../spark/mllib/util/modelSaveLoad.scala |  2 +-
 python/pyspark/mllib/recommendation.py   | 20 ++++++-
 python/pyspark/mllib/util.py             | 58 +++++++++++++++++++
 4 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 27aa4d38b7617..76140282a2dd0 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -200,10 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
 We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
 recommendation by measuring the Mean Squared Error of rating prediction.
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
 
 # Load and parse the data
 data = sc.textFile("data/mllib/als/test.data")
@@ -220,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
 ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
 MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
 print("Mean Squared Error = " + str(MSE))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
 {% endhighlight %}
 
 If the rating matrix is derived from other source of information (i.e., it is inferred from other
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 4458340497f0b..526d055c87387 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -48,7 +48,7 @@ trait Saveable {
    *
    * @param sc  Spark context used to save model data.
    * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be created if needed.
+   *              If the directory already exists, this method throws an exception.
    */
   def save(sc: SparkContext, path: String): Unit
 
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0d99e6dedfad9..03d7d011474cb 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,8 @@
 
 from pyspark import SparkContext
 from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.util import Saveable, JavaLoader
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -39,7 +40,8 @@ def __reduce__(self):
         return Rating, (int(self.user), int(self.product), float(self.rating))
 
 
-class MatrixFactorizationModel(JavaModelWrapper):
+@inherit_doc
+class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
     """A matrix factorisation model trained by regularized alternating
     least-squares.
 
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
     >>> model.predict(2,2)
     0.43...
+
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = MatrixFactorizationModel.load(sc, path)
+    >>> sameModel.predict(2,2)
+    0.43...
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
     """
     def predict(self, user, product):
         return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ def userFeatures(self):
     def productFeatures(self):
         return self.call("getProductFeatures")
 
+    def save(self, sc, path):
+        self.call("save", sc._jsc.sc(), path)
+
 
 class ALS(object):
 
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4ed978b45409c..17d43eadba12b 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -168,6 +168,64 @@ def loadLabeledPoints(sc, path, minPartitions=None):
         return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
 
 
+class Saveable(object):
+    """
+    Mixin for models and transformers which may be saved as files.
+    """
+
+    def save(self, sc, path):
+        """
+        Save this model to the given path.
+
+        This saves:
+         * human-readable (JSON) model metadata to path/metadata/
+         * Parquet formatted data to path/data/
+
+        The model may be loaded using :py:meth:`Loader.load`.
+
+        :param sc: Spark context used to save model data.
+        :param path: Path specifying the directory in which to save
+                     this model. If the directory already exists,
+                     this method throws an exception.
+        """
+        raise NotImplementedError
+
+
+class Loader(object):
+    """
+    Mixin for classes which can load saved models from files.
+    """
+
+    @classmethod
+    def load(cls, sc, path):
+        """
+        Load a model from the given path. The model should have been
+        saved using :py:meth:`Saveable.save`.
+
+        :param sc: Spark context used for loading model files.
+        :param path: Path specifying the directory to which the model
+                     was saved.
+        :return: model instance
+        """
+        raise NotImplementedError
+
+
+class JavaLoader(Loader):
+    """
+    Mixin for classes which can load saved models using their Scala
+    implementation.
+    """
+
+    @classmethod
+    def load(cls, sc, path):
+        java_package = cls.__module__.replace("pyspark", "org.apache.spark")
+        java_class = ".".join([java_package, cls.__name__])
+        java_obj = sc._jvm
+        for name in java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return cls(java_obj.load(sc._jsc.sc(), path))
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext
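
For reference, `JavaLoader.load` above derives the fully qualified Scala class name from the Python class's module path and name, then resolves that dotted name against the py4j JVM view (`sc._jvm`) one `getattr` at a time before delegating to the Scala `load`. The name mapping itself needs no running JVM; the sketch below is a minimal illustration of just that string manipulation, where `FakeModel` and `_java_class_name` are hypothetical stand-ins introduced only for this example and are not part of PySpark.

# Minimal sketch of the name mapping performed by JavaLoader.load.
# FakeModel and _java_class_name are hypothetical helpers for
# illustration only; they do not exist in PySpark.


def _java_class_name(cls):
    # Mirror JavaLoader.load: swap the "pyspark" prefix of the module
    # path for "org.apache.spark", then append the class name unchanged.
    java_package = cls.__module__.replace("pyspark", "org.apache.spark")
    return ".".join([java_package, cls.__name__])


class FakeModel(object):
    # Pretend this class is defined in pyspark.mllib.recommendation,
    # as MatrixFactorizationModel is.
    __module__ = "pyspark.mllib.recommendation"


if __name__ == "__main__":
    # pyspark.mllib.recommendation.FakeModel maps to the Scala class
    # org.apache.spark.mllib.recommendation.FakeModel.
    assert (_java_class_name(FakeModel) ==
            "org.apache.spark.mllib.recommendation.FakeModel")
    print(_java_class_name(FakeModel))

Because the mapping is purely conventional, any Python wrapper whose package mirrors its Scala package gains `load` simply by mixing in `JavaLoader`, as `MatrixFactorizationModel` does in this patch.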