From c8e4ea71d1372589a7f6bb6a96ed9074067f2055 Mon Sep 17 00:00:00 2001
From: MechCoder
Date: Wed, 22 Jul 2015 12:14:14 +0530
Subject: [PATCH] [SPARK-9223] [PySpark] Support model save/load in LDA

---
Review notes (this commentary area is ignored by `git am`):
- LDAModel.load() always delegates to DistributedLDAModel.load(); confirm
  that every model LDAModel.save() can write is loadable through that class.
- The added doctest imports `os`, but none of the added doctest lines use it.

 python/pyspark/mllib/clustering.py | 43 +++++++++++++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 8a92f6911c24b..58ad99d46e23b 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -20,6 +20,7 @@
 
 if sys.version > '3':
     xrange = range
+    basestring = str
 
 from math import exp, log
 
@@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper):
     Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
 
     >>> from pyspark.mllib.linalg import Vectors
-    >>> from numpy.testing import assert_almost_equal
+    >>> from numpy.testing import assert_almost_equal, assert_equal
     >>> data = [
     ...     [1, Vectors.dense([0.0, 1.0])],
     ...     [2, SparseVector(2, {0: 1.0})],
@@ -591,6 +592,19 @@ class LDAModel(JavaModelWrapper):
     >>> topics = model.topicsMatrix()
     >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
     >>> assert_almost_equal(topics, topics_expect, 1)
+
+    >>> import os, tempfile
+    >>> from shutil import rmtree
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = LDAModel.load(sc, path)
+    >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
+    >>> sameModel.vocabSize() == model.vocabSize()
+    True
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
     """
 
     def topicsMatrix(self):
@@ -601,6 +615,33 @@ def vocabSize(self):
         """Vocabulary size (number of terms or terms in the vocabulary)"""
         return self.call("vocabSize")
 
+    def save(self, sc, path):
+        """Save the LDAModel on to disk.
+
+        :param sc: SparkContext
+        :param path: str, path to where the model needs to be stored.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        self._java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        """Load the LDAModel from disk.
+
+        :param sc: SparkContext
+        :param path: str, path to where the model is stored.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
+            sc._jsc.sc(), path)
+        return cls(java_model)
+
 
 class LDA(object):
 