In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window, Row
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import functions as F
from pyspark.sql import types as T
import pandas as pd
from pyspark.sql.window import Window

In [38]:
from pyspark.ml.regression import *
from pyspark.ml.evaluation import *
from pyspark.ml.feature import *

In [884]:
import numpy as np

In [621]:
spark = SparkSession.builder.appName('covid19').getOrCreate()

In [634]:
df_cases = spark.read.csv('covid_numconf.csv', header=True)
df_words = spark.read.csv('word_count_pivot.csv', header=True)

In [635]:
df_cases = df_cases.select(
        'date',
        F.col('numconf').cast('Long'),
        F.lit(1).alias('temp')
    )

In [636]:
most_corr_cols = [
    "battl",
    "case_death",
    "coronaviru_lockdown",
    "covid_lockdown",
    "covid_patient",
    "death",
    "donat",
    "equip",
    "hero",
    "lockdown",
    "new_york",
    "nurs",
    "polic",
    "ppe",
    "stayhom",
    "stayhomesavel",
    "staysaf",
    "total",
    "ventil",
    "york"
]

In [637]:
df_words = df_words.select('date', *most_corr_cols)

In [638]:
df_words.toPandas()

Unnamed: 0,date,battl,case_death,coronaviru_lockdown,covid_lockdown,covid_patient,death,donat,equip,hero,...,new_york,nurs,polic,ppe,stayhom,stayhomesavel,staysaf,total,ventil,york
0,2020-03-04,15,94,0,0,6,380,19,11,5,...,27,20,5,8,0,0,2,232,2,28
1,2020-03-05,38,156,3,2,28,714,20,38,6,...,37,58,29,13,0,0,19,422,4,39
2,2020-03-06,20,137,2,0,21,546,11,16,6,...,39,62,18,5,0,0,22,432,1,39
3,2020-03-07,19,153,2,0,11,492,9,10,5,...,46,34,33,6,1,0,15,360,5,46
4,2020-03-08,13,92,5,2,12,448,12,6,1,...,45,39,29,3,1,0,5,265,0,45
5,2020-03-09,57,233,29,16,49,1219,48,41,26,...,133,129,58,19,3,0,36,776,16,136
6,2020-03-10,199,509,69,47,348,3260,179,212,88,...,437,589,158,97,26,0,90,1855,122,448
7,2020-03-11,269,621,47,60,528,4212,296,335,179,...,457,907,175,170,83,0,178,2312,193,472
8,2020-03-12,348,503,129,689,391,4462,412,361,382,...,483,922,230,109,156,0,368,2330,200,510
9,2020-03-13,997,874,624,971,820,9863,2020,911,829,...,1214,2164,698,306,596,0,1166,4770,641,1302


In [901]:
window = Window.partitionBy('temp').orderBy('date')

data = df_words.join(df_cases, on='date').select(
        'date',
        'numconf',
        F.lag('numconf', 1).over(window).alias('numconf_lag1'),
        F.lag('numconf', 2).over(window).alias('numconf_lag2'),
        F.lag('numconf', 3).over(window).alias('numconf_lag3'),
        F.lag('numconf', 4).over(window).alias('numconf_lag4'),
        F.lag('numconf', 5).over(window).alias('numconf_lag5'),
        F.lead('numconf', 3).over(window).alias('label'),
        *[F.col(c).cast('Long') for c in most_corr_cols]
    )\
    .dropna()

In [902]:
data.toPandas()

Unnamed: 0,date,numconf,numconf_lag1,numconf_lag2,numconf_lag3,numconf_lag4,numconf_lag5,label,battl,case_death,...,new_york,nurs,polic,ppe,stayhom,stayhomesavel,staysaf,total,ventil,york
0,2020-03-09,77,62,57,51,45,39,138,57,233,...,133,129,58,19,3,0,36,776,16,136
1,2020-03-10,90,77,62,57,51,45,176,199,509,...,437,589,158,97,26,0,90,1855,122,448
2,2020-03-11,103,90,77,62,57,51,193,269,621,...,457,907,175,170,83,0,178,2312,193,472
3,2020-03-12,138,103,90,77,62,57,249,348,503,...,483,922,230,109,156,0,368,2330,200,510
4,2020-03-13,176,138,103,90,77,62,324,997,874,...,1214,2164,698,306,596,0,1166,4770,641,1302
5,2020-03-14,193,176,138,103,90,77,424,516,836,...,580,1228,582,238,861,1,768,3466,426,599
6,2020-03-15,249,193,176,138,103,90,569,551,717,...,770,1331,659,262,4429,451,699,3722,849,793
7,2020-03-16,324,249,193,176,138,103,846,658,701,...,1154,1524,854,295,2685,1087,1304,3762,984,1210
8,2020-03-17,424,324,249,193,176,138,971,1044,873,...,1195,2263,1162,531,3465,895,2235,4776,1469,1285
9,2020-03-18,569,424,324,249,193,176,1302,864,854,...,890,1932,1105,587,2504,257,1588,3918,904,943


In [1019]:
train_data = data.filter(F.col('date') < '2020-03-25')
test_data = data.filter(F.col('date') > '2020-03-25').filter(F.col('date') <= '2020-04-01')

In [1020]:
feature_cols = train_data.columns[1:-1]
feature_cols

['numconf',
 'numconf_lag1',
 'numconf_lag2',
 'numconf_lag3',
 'numconf_lag4',
 'numconf_lag5',
 'label',
 'battl',
 'case_death',
 'coronaviru_lockdown',
 'covid_lockdown',
 'covid_patient',
 'death',
 'donat',
 'equip',
 'hero',
 'lockdown',
 'new_york',
 'nurs',
 'polic',
 'ppe',
 'stayhom',
 'stayhomesavel',
 'staysaf',
 'total',
 'ventil']

In [1021]:
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
train_data = assembler.transform(train_data)
test_data = assembler.transform(test_data)

In [1022]:
train_data = train_data.select(['features', 'label'])
test_data = test_data.select(['features', 'label'])

In [1023]:
train_data.toPandas()

Unnamed: 0,features,label
0,"[77.0, 62.0, 57.0, 51.0, 45.0, 39.0, 138.0, 57...",138
1,"[90.0, 77.0, 62.0, 57.0, 51.0, 45.0, 176.0, 19...",176
2,"[103.0, 90.0, 77.0, 62.0, 57.0, 51.0, 193.0, 2...",193
3,"[138.0, 103.0, 90.0, 77.0, 62.0, 57.0, 249.0, ...",249
4,"[176.0, 138.0, 103.0, 90.0, 77.0, 62.0, 324.0,...",324
5,"[193.0, 176.0, 138.0, 103.0, 90.0, 77.0, 424.0...",424
6,"[249.0, 193.0, 176.0, 138.0, 103.0, 90.0, 569....",569
7,"[324.0, 249.0, 193.0, 176.0, 138.0, 103.0, 846...",846
8,"[424.0, 324.0, 249.0, 193.0, 176.0, 138.0, 971...",971
9,"[569.0, 424.0, 324.0, 249.0, 193.0, 176.0, 130...",1302


In [1024]:
lr = LinearRegression(featuresCol='features', labelCol='label', maxIter=1000, regParam=0.6)
model = lr.fit(train_data)

In [1025]:
prediction = model.transform(test_data)
result = prediction.toPandas()

In [1026]:
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")

In [1027]:
rmse = evaluator.evaluate(prediction)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

Root Mean Squared Error (RMSE) on test data = 250.549


In [1028]:
np.abs(result['prediction']-result['label'])/result['label']

0    0.038449
1    0.026594
2    0.020486
3    0.036656
dtype: float64

In [1029]:
result

Unnamed: 0,features,label,prediction
0,"[4018.0, 3385.0, 1959.0, 1646.0, 1430.0, 1302....",6255,6014.498932
1,"[4675.0, 4018.0, 3385.0, 1959.0, 1646.0, 1430....",7424,7226.568536
2,"[5386.0, 4675.0, 4018.0, 3385.0, 1959.0, 1646....",8536,8361.132907
3,"[6255.0, 5386.0, 4675.0, 4018.0, 3385.0, 1959....",9595,9946.711199
