In [2]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.sql import SparkSession

if __name__ == "__main__":
    # Example: hyper-parameter tuning of a LinearRegression with
    # TrainValidationSplit, which evaluates every parameter combination once
    # on a single train/validation split of the input data.
    spark = (
        SparkSession.builder
        .appName("TrainValidationSplit")
        .getOrCreate()
    )

    # Everything that uses the session lives inside try/finally so the session
    # is stopped even when loading, fitting or showing raises (e.g. the data
    # file is missing); otherwise the session would be left running.
    try:
        # Prepare training and test data: 90% for tuning, 10% held out.
        # The fixed seed keeps the split reproducible between runs.
        data = spark.read.format("libsvm").load("sample_linear_regression_data.txt")
        train, test = data.randomSplit([0.9, 0.1], seed=12345)

        lr = LinearRegression(maxIter=10)

        # We use a ParamGridBuilder to construct a grid of parameters to search
        # over (2 x 2 x 3 = 12 combinations). TrainValidationSplit will try all
        # combinations of values and determine the best model using the
        # evaluator.
        paramGrid = (
            ParamGridBuilder()
            .addGrid(lr.regParam, [0.1, 0.01])
            .addGrid(lr.fitIntercept, [False, True])
            .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
            .build()
        )

        # In this case the estimator is simply the linear regression.
        # A TrainValidationSplit requires an Estimator, a set of Estimator
        # ParamMaps, and an Evaluator (here a RegressionEvaluator with its
        # default settings).
        tvs = TrainValidationSplit(
            estimator=lr,
            estimatorParamMaps=paramGrid,
            evaluator=RegressionEvaluator(),
            # 80% of the data will be used for training, 20% for validation.
            trainRatio=0.8,
        )

        # Run TrainValidationSplit, and choose the best set of parameters.
        model = tvs.fit(train)

        # Make predictions on the held-out test data. `model` is the model
        # with the combination of parameters that performed best.
        predictions = model.transform(test)
        predictions.select("features", "label", "prediction").show()
    finally:
        spark.stop()

+--------------------+--------------------+--------------------+
|            features|               label|          prediction|
+--------------------+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...| -17.026492264209548|  -1.780062242348691|
|(10,[0,1,2,3,4,5,...|  -16.71909683360509| -0.1893325701092588|
|(10,[0,1,2,3,4,5,...| -15.375857723312297|  0.7252323736487188|
|(10,[0,1,2,3,4,5,...| -13.772441561702871|  3.2696413241677718|
|(10,[0,1,2,3,4,5,...| -13.039928064104615| 0.18817684046065775|
|(10,[0,1,2,3,4,5,...|   -9.42898793151394|  -3.449987079269568|
|(10,[0,1,2,3,4,5,...|    -9.2679651250406|-0.33109075490696316|
|(10,[0,1,2,3,4,5,...|  -9.173693798406978|-0.42727135281551937|
|(10,[0,1,2,3,4,5,...| -7.1500991588127265|   2.936884251408867|
|(10,[0,1,2,3,4,5,...|  -6.930603551528371|-0.02839768193150...|
|(10,[0,1,2,3,4,5,...|  -6.456944198081549| -0.9224776887934015|
|(10,[0,1,2,3,4,5,...| -3.2843694575334834| -1.0821208483033875|
|(10,[0,1,2,3,4,5,...|   