Skip to content

Commit

Permalink
add return values to lasso train (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
blbarker authored Sep 2, 2016
1 parent a7290d1 commit c83f0be
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ Consider the following frame containing two columns.
>>> model = ta.LassoModel()
<progress>

>>> model.train(frame, 'y', ['x1'])
>>> results = model.train(frame, 'y', ['x1'])
<progress>

>>> results
{u'intercept': 0.0, u'weights': [2.4387285895043913]}

>>> predicted_frame = model.predict(frame)
<progress>
>>> predicted_frame.inspect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ object MLLibJsonProtocol {
implicit val classificationWithSGDPredictFormat = jsonFormat3(ClassificationWithSGDPredictArgs)
implicit val classificationWithSGDTestFormat = jsonFormat4(ClassificationWithSGDTestArgs)
implicit val lassoTrainFormat = jsonFormat9(LassoTrainArgs)
implicit val lassoTrainReturnFormat = jsonFormat2(LassoTrainReturn)
implicit val lassoDataFormat = jsonFormat2(LassoData)
implicit val svmDataFormat = jsonFormat2(SVMData)
implicit val kmeansDataFormat = jsonFormat3(KMeansData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,15 @@ case class LassoTrainArgs(@ArgDoc("""Handle to the model to be used.""") model:
require(regParam >= 0, "regParam should be greater than or equal to 0")
}

case class LassoTrainReturn(@ArgDoc("""A list of n trained weights, where n is the number of features""") weights: List[Double],
@ArgDoc("""float value representing the independent term in decision function of the model""") intercept: Double)

@PluginDoc(oneLine = "Train Lasso Model",
extended = """Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
* in gradient descent are initialized using the initial weights provided.""")
class LassoTrainPlugin extends SparkCommandPlugin[LassoTrainArgs, UnitReturn] {
class LassoTrainPlugin extends SparkCommandPlugin[LassoTrainArgs, LassoTrainReturn] {
/**
* The name of the command.
*
Expand All @@ -86,7 +89,7 @@ class LassoTrainPlugin extends SparkCommandPlugin[LassoTrainArgs, UnitReturn] {
* @param arguments user supplied arguments to running this plugin
* @return a value of type declared as the Return type.
*/
override def execute(arguments: LassoTrainArgs)(implicit invocation: Invocation): UnitReturn = {
override def execute(arguments: LassoTrainArgs)(implicit invocation: Invocation): LassoTrainReturn = {
val model: Model = arguments.model
val frame: SparkFrame = arguments.frame
val weights: Array[Double] = if (arguments.initialWeights.isDefined) {
Expand All @@ -109,7 +112,7 @@ class LassoTrainPlugin extends SparkCommandPlugin[LassoTrainArgs, UnitReturn] {
val j = jsonModel.toJson.toString
println(s"jsonModel=$j")
model.data = jsonModel.toJson.asJsObject

LassoTrainReturn(trainedModel.weights.toArray.toList, trainedModel.intercept)
}

}
Expand Down

0 comments on commit c83f0be

Please sign in to comment.