In [1]:
from pyspark.sql import SparkSession
# Create the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName("RLR")
    .getOrCreate()
)

In [7]:
from pyspark.ml.regression import LinearRegression

In [11]:
# Load the joined CAC 40 / COVID dataset from the local `data/` directory.
# The file is semicolon-separated with a header row; column types are inferred.
data = (
    spark.read
    .options(header=True, sep=";", inferSchema=True)
    .csv("data/joined_cac40_covid_data.csv")
)

In [4]:
# Load the same dataset from HDFS instead of the local filesystem.
# NOTE(review): this rebinds `data`, so if the local-read cell is also executed,
# whichever of the two cells runs last determines what the rest of the notebook uses.
data = (
    spark.read
    .options(header=True, sep=";", inferSchema=True)
    .csv("hdfs://localhost:9000/project/joined_cac40_covid_data.csv")
)

In [3]:
# Peek at the first five rows (returned as a list of Row objects) to sanity-check the load.
data.head(5)

[Row(Date=datetime.datetime(2020, 3, 18, 0, 0), Adj Close=3754.840088, hospitalized=2972, in intensive care=771, returning home=816, deceased=218),
 Row(Date=datetime.datetime(2020, 3, 19, 0, 0), Adj Close=3855.5, hospitalized=4073, in intensive care=1002, returning home=1180, deceased=327),
 Row(Date=datetime.datetime(2020, 3, 20, 0, 0), Adj Close=4048.800049, hospitalized=5226, in intensive care=1297, returning home=1587, deceased=450),
 Row(Date=datetime.datetime(2020, 3, 23, 0, 0), Adj Close=3914.310059, hospitalized=8673, in intensive care=2080, returning home=2567, deceased=860),
 Row(Date=datetime.datetime(2020, 3, 24, 0, 0), Adj Close=4242.700195, hospitalized=10176, in intensive care=2516, returning home=3281, deceased=1100)]

In [13]:
# Print the column names and the types produced by inferSchema.
data.printSchema()

root
 |-- Date: timestamp (nullable = true)
 |-- Adj Close: double (nullable = true)
 |-- hospitalized: integer (nullable = true)
 |-- in intensive care: integer (nullable = true)
 |-- returning home: integer (nullable = true)
 |-- deceased: integer (nullable = true)



In [16]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [17]:
# List the available column names before choosing the feature column(s).
data.columns

['Date',
 'Adj Close',
 'hospitalized',
 'in intensive care',
 'returning home',
 'deceased']

In [21]:
# Pack the single predictor column ('hospitalized') into a vector column
# named 'features', which is the column the regression reads its inputs from.
assembler = VectorAssembler(
    inputCols=['hospitalized'],
    outputCol='features',
)

In [22]:
# Append the assembled 'features' vector column to every row of the dataset.
output = assembler.transform(data)

In [23]:
# Inspect one row to confirm the new 'features' column sits alongside the original columns.
output.head(1)

[Row(Date=datetime.datetime(2020, 3, 18, 0, 0), Adj Close=3754.840088, hospitalized=2972, in intensive care=771, returning home=816, deceased=218, features=DenseVector([2972.0]))]

In [24]:
# Keep only what the model needs: the feature vector and the label ('Adj Close').
final_data = output.select(['features', 'Adj Close'])

In [40]:
# Display the modelling table (show() prints only the top rows).
final_data.show()

+---------+------------------+
| features|         Adj Close|
+---------+------------------+
| [2972.0]|       3754.840088|
| [4073.0]|            3855.5|
| [5226.0]|       4048.800049|
| [8673.0]|       3914.310059|
|[10176.0]|       4242.700195|
|[12072.0]|       4432.299805|
|[13879.0]|       4543.580078|
|[15701.0]| 4351.490234000001|
|[20946.0]| 4378.509765999999|
|[22672.0]|       4396.120117|
|[24543.0]| 4207.240234000001|
|[26131.0]|       4220.959961|
|[27302.0]|       4154.580078|
|[29569.0]|4346.1401369999985|
|[29871.0]|        4438.27002|
|[30217.0]|           4442.75|
|[30608.0]|       4506.850098|
|[32131.0]|       4523.910156|
|[31623.0]|       4353.720215|
|[31172.0]|       4350.160156|
+---------+------------------+
only showing top 20 rows



In [26]:
# Split the modelling table into roughly 70% training and 30% test rows.
# randomSplit assigns rows at random, so without a seed every run yields a
# different split and different downstream metrics. A fixed seed keeps the
# split repeatable for the same input data and partitioning.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [39]:
# Linear regression estimator: inputs come from 'features', the target is
# 'Adj Close', and predictions are written to a 'prediction' column.
lr = LinearRegression(
    featuresCol='features',
    labelCol='Adj Close',
    predictionCol='prediction',
)

In [41]:
# Fit the regression on the training split only; the test split stays held out.
lr_model = lr.fit(train_data)

In [42]:
# Evaluate the fitted model on the held-out split. The returned summary exposes
# the residuals, rootMeanSquaredError and r2 attributes inspected in the following cells.
test_results = lr_model.evaluate(test_data)

In [44]:
# Show the per-row residuals on the test split.
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|  8.616815223835147|
| 111.48987200987631|
|  49.24656405344194|
|  132.4974404572913|
| 188.91282401403078|
|-203.65181823431067|
| -77.10100532360411|
|-55.117782195969085|
|  -138.005833333581|
|-57.000503785469846|
+-------------------+



In [45]:
# Root mean squared error of the predictions on the test split.
test_results.rootMeanSquaredError

118.71156532254412

In [46]:
# Coefficient of determination (R²) on the test split.
test_results.r2

0.2751864637719861

In [47]:
# Drop the label column so only the feature vectors remain, mimicking unseen data.
unlabeled_data = test_data.select('features')

In [48]:
# Display the feature-only rows that will be passed to the model.
unlabeled_data.show()

+---------+
| features|
+---------+
| [5226.0]|
|[10176.0]|
|[20946.0]|
|[21530.0]|
|[22657.0]|
|[26131.0]|
|[29627.0]|
|[29871.0]|
|[29984.0]|
|[30217.0]|
+---------+



In [49]:
# Apply the fitted model to the unlabeled rows; the result carries a 'prediction' column.
pred = lr_model.transform(unlabeled_data)

In [50]:
# Display each feature vector next to its predicted 'Adj Close' value.
pred.show()

+---------+-----------------+
| features|       prediction|
+---------+-----------------+
| [5226.0]|4040.183233776165|
|[10176.0]|4131.210322990124|
|[20946.0]|4329.263201946557|
|[21530.0]|4340.002559542709|
|[22657.0]|4360.727312985968|
|[26131.0]| 4424.61177923431|
|[29627.0]|4488.900810323604|
|[29871.0]|4493.387802195969|
|[29984.0]|4495.465794333581|
|[30217.0]| 4499.75050378547|
+---------+-----------------+

