
# AI Part

In [39]:
# cell for autoreload includes
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Loading the stocks

In [40]:
from src.stocks import Stocks
# The warnings are due to the includes in the file

In [41]:
from pyspark.sql.types import DoubleType, StringType, DateType, StructType, StructField

In [42]:
data_schema = StructType([
    StructField('Date', DateType(), True),
    StructField('High', DoubleType(), True),
    StructField('Low', DoubleType(), True),
    StructField('Open', DoubleType(), True),
    StructField('Close', DoubleType(), True),
    StructField('Volume', DoubleType(), True),
    StructField('Adj Close', DoubleType(), True),
    StructField('company_name', StringType(), True)
])

In [43]:
stocks = Stocks(header=True, delimiter=',', schema=data_schema)

In [44]:
stock = stocks.stocks[0]

In [90]:
stock.predict.load_insights()

In [91]:
df = stock.predict.predDF

## Getting the AI started

In [93]:
df.show()

22/06/14 23:19:54 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.


+----------+------------------+------------------+------------------+------------------+----------+------------------+------------------+------------------+------------------+-------------------+-----------------+--------------------+--------------------+
|      Date|              High|               Low|              Open|             Close|    Volume|         Adj Close|        next_Close|               CCI|               roc|           momentum|    R_de_williams|moving_average_Close|    daily_return_day|
+----------+------------------+------------------+------------------+------------------+----------+------------------+------------------+------------------+------------------+-------------------+-----------------+--------------------+--------------------+
|2017-01-04|29.127500534057617|           28.9375|28.962499618530273|  29.0049991607666| 8.44724E7|27.247108459472656| 29.15250015258789|133.33333333333334|0.9988807156168351| -0.032501220703125|71.99968610491072|  29.02124977111816

### Regression

In [10]:
trainDF, testDF = df.randomSplit([.8, .2], seed=42)
print(f"There are {trainDF.cache().count()} rows in the training set, and {testDF.cache().count()} in the test set")

There are 828 rows in the training set, and 159 in the test set


In [11]:
(trainRepartitionDF, testRepartitionDF) = (df
                                           .repartition(24)
                                           .randomSplit([.8, .2], seed=42))
print(f"There are {trainRepartitionDF.cache().count()} rows in the training set, and {testRepartitionDF.cache().count()} in the test set")



There are 784 rows in the training set, and 203 in the test set
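
The two splits above use the same weights and seed, yet the counts differ (828/159 versus 784/203). Spark's `randomSplit` samples within each partition, so changing the partition layout changes which rows end up on each side. The sketch below is a toy model in plain Python, not Spark's implementation: the helper `toy_random_split` and its round-robin partitioning are assumptions made purely for illustration.

```python
import random


def toy_random_split(rows, num_partitions, train_fraction=0.8, seed=42):
    """Toy model of a per-partition seeded random split (not Spark's code)."""
    # Deal the rows round-robin into partitions, loosely mimicking a repartition.
    partitions = [rows[i::num_partitions] for i in range(num_partitions)]
    train, test = [], []
    for index, partition in enumerate(partitions):
        # Each partition draws from its own generator, seeded from the
        # global seed plus the partition index.
        rng = random.Random(seed + index)
        for row in partition:
            (train if rng.random() < train_fraction else test).append(row)
    return train, test


rows = list(range(987))  # same total row count as the splits above
train_1, test_1 = toy_random_split(rows, num_partitions=1)
train_24, test_24 = toy_random_split(rows, num_partitions=24)

print(f"1 partition:   {len(train_1)} train, {len(test_1)} test")
print(f"24 partitions: {len(train_24)} train, {len(test_24)} test")
```

The practical consequence is that a fixed seed only makes the split reproducible if the partitioning is fixed too. Writing the split out once and reloading it is one way to keep train and test sets stable across runs.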


                                                                                

In [12]:
(trainDF.select("Volume", "Close").summary()).show()

+-------+--------------------+------------------+
|summary|              Volume|             Close|
+-------+--------------------+------------------+
|  count|                 828|               828|
|   mean|1.2550357079589371E8| 56.78291882869702|
| stddev| 5.841617752747542E7|23.539322423050674|
|    min|           2.01978E7|  29.0049991607666|
|    25%|           8.72472E7|41.084999084472656|
|    50%|          1.105776E8| 47.83250045776367|
|    75%|          1.463228E8| 65.48999786376953|
|    max|            4.4794E8|134.17999267578125|
+-------+--------------------+------------------+



In [13]:
from pyspark.ml.feature import VectorAssembler


# This is wrong! But we do it to get an idea of what is possible.
vecAssembler = VectorAssembler(inputCols=["Volume", "Open"], outputCol="features")

vecTrainDF = vecAssembler.transform(trainDF)

vecTrainDF.select("features", "Close").show(10)

+--------------------+------------------+
|            features|             Close|
+--------------------+------------------+
|[1.151276E8,28.95...|29.037500381469727|
|[8.44724E7,28.962...|  29.0049991607666|
|[1.270076E8,29.19...|29.477500915527344|
|[1.342476E8,29.48...|29.747499465942383|
|[9.78484E7,29.692...| 29.77750015258789|
|[1.083448E8,29.72...|           29.8125|
|[1.377592E8,29.58...|              30.0|
|     [9.4852E7,30.0]|29.997499465942383|
|[1.023892E8,29.85...| 29.94499969482422|
|[1.303916E8,30.11...|              30.0|
+--------------------+------------------+
only showing top 10 rows



In [60]:
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(featuresCol="features", labelCol="Close")
#lrModel = lr.fit(vecTrainDF)

In [62]:
from pyspark.ml import Pipeline

pipeline = Pipeline(stages=[vecAssembler, lr])
pipelineModel = pipeline.fit(trainDF)

22/06/14 12:52:18 WARN Instrumentation: [65402247] regParam is zero, which might cause numerical instability and overfitting.


In [63]:
predDF = pipelineModel.transform(testDF)

predDF.select("Close", "prediction").show(10)

+-----------------+-----------------+
|            Close|       prediction|
+-----------------+-----------------+
|  794.02001953125| 788.046447971487|
|807.9099731445312|807.5158913894134|
|807.8800048828125|809.9208106445692|
|819.3099975585938|807.8378142630181|
|796.7899780273438|797.0407586626543|
|801.3400268554688|801.9694418669648|
|820.4500122070312|821.5192276807949|
|831.3300170898438|831.7263528499154|
|838.6799926757812|838.0507203871839|
|           843.25|844.3712353335196|
+-----------------+-----------------+
only showing top 10 rows
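
Comparing the two columns by eye only goes so far. As a rough sketch, the snippet below computes RMSE and R² in plain Python over the ten displayed rows only, so the numbers describe that sample and not the full test set. The helper names `rmse` and `r_squared` are illustrative. In Spark, `pyspark.ml.evaluation.RegressionEvaluator` provides the same metrics (`metricName="rmse"` or `"r2"`) on the whole `predDF`.

```python
import math

# (Close, prediction) pairs copied from the ten rows displayed above.
pairs = [
    (794.02001953125, 788.046447971487),
    (807.9099731445312, 807.5158913894134),
    (807.8800048828125, 809.9208106445692),
    (819.3099975585938, 807.8378142630181),
    (796.7899780273438, 797.0407586626543),
    (801.3400268554688, 801.9694418669648),
    (820.4500122070312, 821.5192276807949),
    (831.3300170898438, 831.7263528499154),
    (838.6799926757812, 838.0507203871839),
    (843.25, 844.3712353335196),
]


def rmse(pairs):
    """Root mean squared error between labels and predictions."""
    squared_errors = [(label - pred) ** 2 for label, pred in pairs]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


def r_squared(pairs):
    """Coefficient of determination: 1 - SS_residual / SS_total."""
    labels = [label for label, _ in pairs]
    mean_label = sum(labels) / len(labels)
    ss_residual = sum((label - pred) ** 2 for label, pred in pairs)
    ss_total = sum((label - mean_label) ** 2 for label in labels)
    return 1 - ss_residual / ss_total


rmse_value = rmse(pairs)
r2_value = r_squared(pairs)
print(f"RMSE on the displayed rows: {rmse_value:.3f}")
print(f"R2 on the displayed rows:   {r2_value:.3f}")
```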



Time series.

Regression.
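
If the data is treated as a time series, as the closing note suggests, one common adjustment is to split chronologically instead of at random, so the model is never trained on dates later than those it is tested on. This is an assumption about where the notebook is heading, not something it implements. The sketch below uses plain Python, with a hypothetical `chronological_split` helper and made-up `(date, close)` rows.

```python
from datetime import date, timedelta


def chronological_split(rows, train_fraction=0.8):
    """Split (date, value) rows so every training date precedes every test date."""
    ordered = sorted(rows, key=lambda row: row[0])
    cutoff = int(len(ordered) * train_fraction)
    return ordered[:cutoff], ordered[cutoff:]


# Hypothetical rows for illustration only; not taken from the stock data.
start = date(2017, 1, 4)
rows = [(start + timedelta(days=i), 29.0 + 0.1 * i) for i in range(10)]

train, test = chronological_split(rows)
print(f"train: {train[0][0]} to {train[-1][0]} ({len(train)} rows)")
print(f"test:  {test[0][0]} to {test[-1][0]} ({len(test)} rows)")
```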