## Linear regression with PySpark

In [45]:
import pandas as pd

from pyspark.sql import SparkSession

# Reuse the active Spark session if one exists; otherwise start a new one
# under this application name.
spark = (
    SparkSession.builder
    .appName("Pyspark course")
    .getOrCreate()
)

In [46]:
from sklearn.datasets import load_boston
# NOTE(review): `load_boston` was deprecated in scikit-learn 1.0 and removed in
# 1.2, so this cell only runs on older scikit-learn versions. Pin the version
# or switch to another dataset before relying on Restart & Run All.
data = load_boston()

# Display the feature names of the loaded dataset.
data.feature_names

array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
       'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')

In [47]:
# Build a pandas frame with one column per feature and append the regression
# target as a `label` column.
data_df = pd.DataFrame(
    data.data,
    columns=data['feature_names'],
).assign(label=data.target)

In [48]:
# Only two predictors plus the target are carried over to Spark.
selected_columns = ['CRIM', 'LSTAT', 'label']
data_sdf = spark.createDataFrame(data_df[selected_columns])

In [49]:
# Preview the first 20 rows of the Spark DataFrame.
data_sdf.show(n=20)

+-------+-----+-----+
|   CRIM|LSTAT|label|
+-------+-----+-----+
|0.00632| 4.98| 24.0|
|0.02731| 9.14| 21.6|
|0.02729| 4.03| 34.7|
|0.03237| 2.94| 33.4|
|0.06905| 5.33| 36.2|
|0.02985| 5.21| 28.7|
|0.08829|12.43| 22.9|
|0.14455|19.15| 27.1|
|0.21124|29.93| 16.5|
|0.17004| 17.1| 18.9|
|0.22489|20.45| 15.0|
|0.11747|13.27| 18.9|
|0.09378|15.71| 21.7|
|0.62976| 8.26| 20.4|
|0.63796|10.26| 18.2|
|0.62739| 8.47| 19.9|
|1.05393| 6.58| 23.1|
| 0.7842|14.67| 17.5|
|0.80271|11.69| 20.2|
| 0.7258|11.28| 18.2|
+-------+-----+-----+
only showing top 20 rows



In [50]:
from pyspark.ml.feature import VectorAssembler


# Pack the predictor columns into a single vector column named "features".
feature_columns = ['CRIM', 'LSTAT']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

# NOTE(review): this rebinds `data`, which previously held the scikit-learn
# dataset object, to a Spark DataFrame. Re-running the earlier pandas cell
# after this one will no longer see the original dataset.
data = assembler.transform(data_sdf)

In [51]:
# Print the column names and types of the assembled DataFrame, including the
# new `features` vector column.
data.printSchema()

root
 |-- CRIM: double (nullable = true)
 |-- LSTAT: double (nullable = true)
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [52]:
# Random split with weights 0.9 / 0.1 (train / test); the seed is fixed so the
# split can be repeated.
train, test = data.randomSplit(weights=[0.9, 0.1], seed=12345)

In [53]:
from pyspark.ml.regression import LinearRegression

# Column names are spelled out to match the assembled "features" vector and
# the "label" target column; both equal the estimator's defaults.
reg = LinearRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=100,
)

In [54]:
# Fit the linear regression on the training split.
model = reg.fit(train)

In [57]:
# Score the held-out rows; the result carries an added `prediction` column.
predictions = model.transform(test)
predictions.show()

+-------+-----+-----+---------------+------------------+
|   CRIM|LSTAT|label|       features|        prediction|
+-------+-----+-----+---------------+------------------+
|0.01951| 8.05| 33.0| [0.01951,8.05]|26.760799675492624|
|0.02763| 4.32| 30.8| [0.02763,4.32]|30.091008048496008|
|0.09744|11.41| 20.0|[0.09744,11.41]|23.756232907479912|
|0.12744| 4.84| 26.6| [0.12744,4.84]|29.621176661541597|
|0.13262|16.47| 19.5|[0.13262,16.47]|19.236030155124016|
| 0.1396|12.33| 20.1| [0.1396,12.33]|22.932405688085776|
|0.14932|13.15| 18.7|[0.14932,13.15]|22.199661091910347|
|0.22927| 18.8| 16.6| [0.22927,18.8]|17.150156503250034|
|0.25387|30.81| 14.4|[0.25387,30.81]| 6.424622911667935|
|0.62739| 8.47| 19.9| [0.62739,8.47]| 26.35224467243256|
|0.01439| 4.38| 29.1| [0.01439,4.38]|30.038161925473855|
| 0.0315| 4.56| 34.9|  [0.0315,4.56]| 29.87648968546921|
|0.05602| 4.45| 50.0| [0.05602,4.45]|29.973360641994073|
|0.06588| 7.56| 39.8| [0.06588,7.56]| 27.19578189712888|
| 0.0837| 5.39| 34.9|  [0.0837,