In [None]:
"""Fetch raw data"""

import os
import warnings

from pyspark.sql import SparkSession
from stockpredict.fetcher.main import execute

local_path = "../data/raw"
markers = ["AAPL", "MSFT", "AMZN"]  # ticker symbols to download

# Read the key from the environment instead of pasting it into the notebook,
# where it would end up in version control and in shared copies of the file.
api_key = os.environ.get("STOCK_API_KEY", "")
if not api_key:
    warnings.warn(
        "STOCK_API_KEY is not set; calling the fetcher with an empty API key."
    )

builder = (
    SparkSession.builder.appName("Fetcher")
    # Delta Lake package: a single Maven coordinate (no stray trailing comma).
    .config("spark.jars.packages", "io.delta:delta-core_2.12:2.3.0")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.sql.catalogImplementation", "hive")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "20g")
    # Driver memory (like jars.packages) only takes effect when this cell starts
    # the kernel's first Spark session; the driver JVM is not relaunched later.
    .config("spark.driver.memory", "20g")
    .config("spark.executor.memory", "20g")
)

spark = builder.getOrCreate()

execute(api_key, markers, local_path, spark)

In [None]:
"""Process data for training"""
# The Spark UI is available at http://127.0.0.1:4040/ while the session is running.
from pyspark.sql import SparkSession
from stockpredict.etl.cleaner import Preprocess, save, normalize

input_path = "../data/raw"
output_path = "../data/clean"

# NOTE: if the fetch cell already ran in this kernel, getOrCreate() returns that
# existing session. Only runtime SQL configurations are applied to it; the "ETL"
# app name and static settings (jars, memory, catalog implementation) are
# ignored, and Spark logs a warning. Restart the kernel for a fresh "ETL" session.
builder = (
    SparkSession.builder.appName("ETL")
    # Delta Lake package: a single Maven coordinate (no stray trailing comma).
    .config("spark.jars.packages", "io.delta:delta-core_2.12:2.3.0")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.sql.catalogImplementation", "hive")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "20g")
    .config("spark.driver.memory", "20g")
    .config("spark.executor.memory", "20g")
)

spark = builder.getOrCreate()

# Read raw data from input_path, normalize it, and write the result to output_path.
data = Preprocess(spark, input_path).execute()

normalized = normalize(data)
save(normalized, output_path)

In [None]:
"""Load tensorboard to show metrics"""
"""If you will not see any output window here - open http://127.0.0.1:6006/"""
%reload_ext tensorboard
%tensorboard --logdir=../model/lightning_logs/ --bind_all

In [None]:
"""Train network"""
from stockpredict.train.main import run_train

# Location of the cleaned training data and where the trained model is written.
input_data = "../data/clean"
output_model = "../model"

# Training hyperparameters, gathered in one place so a run is easy to tweak.
hyperparams = {
    "batch_size": 64,
    "shuffle": True,
    "learning_rate": 0.01,
    "epochs": 100,
    "progress_bar": False,
}

run_train(input_data=input_data, output_model=output_model, **hyperparams)


In [None]:
"""Inference"""
import json
import os
import warnings
from unittest import mock

from stockpredict.inference.stock import init, run

# Patch targets must name the module that `init`/`run` were imported from.
# "inference.stock" is a different (or non-importable) module, so a mock placed
# there would not be seen by `run`.
STOCK_MODULE = "stockpredict.inference.stock"
CHECKPOINT_PATH = "../model/checkpoint.ckpt"

raw_data = json.dumps(
    {
        "symbol": "AAPL",
        "start_date": "2023-10-09",
        "end_date": "2023-12-25",
    }
)

# Point init() at the local checkpoint.
# NOTE(review): "<module>.os.path.join" resolves to the shared os.path module, so
# EVERY os.path.join call made during init() returns CHECKPOINT_PATH. Confirm
# that init() does not rely on os.path.join for anything else.
with mock.patch(f"{STOCK_MODULE}.os.path.join") as mock_path:
    mock_path.return_value = CHECKPOINT_PATH
    init()

# Read the key from the environment instead of hardcoding it in the notebook.
api_key = os.environ.get("STOCK_API_KEY", "")
if not api_key:
    warnings.warn("STOCK_API_KEY is not set; running inference with an empty API key.")

with mock.patch(f"{STOCK_MODULE}.get_api_key") as mock_get_api_key:
    mock_get_api_key.return_value = api_key
    result = run(raw_data)
result