Skip to content

Commit

Permalink
[SPARK-9223] [PySpark] Support model save/load in LDA
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jul 22, 2015
1 parent 89db3c0 commit c8e4ea7
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

if sys.version > '3':
xrange = range
basestring = str

from math import exp, log

Expand Down Expand Up @@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper):
Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
>>> from pyspark.mllib.linalg import Vectors
>>> from numpy.testing import assert_almost_equal
>>> from numpy.testing import assert_almost_equal, assert_equal
>>> data = [
... [1, Vectors.dense([0.0, 1.0])],
... [2, SparseVector(2, {0: 1.0})],
Expand All @@ -591,6 +592,19 @@ class LDAModel(JavaModelWrapper):
>>> topics = model.topicsMatrix()
>>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
>>> assert_almost_equal(topics, topics_expect, 1)
>>> import os, tempfile
>>> from shutil import rmtree
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = LDAModel.load(sc, path)
>>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
>>> sameModel.vocabSize() == model.vocabSize()
True
>>> try:
... rmtree(path)
... except OSError:
... pass
"""

def topicsMatrix(self):
Expand All @@ -601,6 +615,33 @@ def vocabSize(self):
"""Vocabulary size (number of terms or terms in the vocabulary)"""
return self.call("vocabSize")

def save(self, sc, path):
"""Save the LDAModel on to disk.
:param sc: SparkContext
:param path: str, path to where the model needs to be stored.
"""
if not isinstance(sc, SparkContext):
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
self._java_model.save(sc._jsc.sc(), path)

@classmethod
def load(cls, sc, path):
"""Load the LDAModel from disk.
:param sc: SparkContext
:param path: str, path to where the model is stored.
"""
if not isinstance(sc, SparkContext):
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
sc._jsc.sc(), path)
return cls(java_model)


class LDA(object):

Expand Down

0 comments on commit c8e4ea7

Please sign in to comment.