In [None]:
from datetime import datetime, timedelta
from typing import Callable, Any
from inspect import signature
import warnings, logging

import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import ipywidgets as widgets
import yfinance as yf
import pandas as pd
from IPython.display import display

from openbb_terminal.sdk import openbb, theme

# Render matplotlib figures inline in the notebook output area.
%matplotlib inline
# Ask the inline backend to emit figures in SVG format.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
# Apply the matplotlib style supplied by the imported OpenBB `theme` object.
theme.applyMPLstyle()

In [None]:
# Build the lookup tables that map the names shown in the UI to the
# corresponding `openbb.forecast` callables. Any missing attribute raises
# AttributeError, in which case the dashboard falls back to empty tables and
# `has_forecast` is set to False.
try:
    model_opts = {
        name: getattr(openbb.forecast, f"{name}_chart")
        for name in (
            "expo",
            "theta",
            "linregr",
            "regr",
            "rnn",
            "brnn",
            "nbeats",
            "tcn",
            "trans",
            "tft",
            "nhits",
        )
    }
    feat_engs = {
        name: getattr(openbb.forecast, name)
        for name in ("ema", "sto", "rsi", "roc", "mom", "atr", "delta", "signal")
    }
except AttributeError as err:
    print(err)
    has_forecast, model_opts, feat_engs = False, {}, {}
else:
    has_forecast = True

# Bar sizes offered by the "Interval" dropdown, from finest to coarsest.
interval_opts = "1m 2m 5m 15m 30m 60m 90m 1h 1d 5d 1wk 1mo 3mo".split()


def format_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a flat, lower-cased copy of a downloaded price frame.

    Multi-level column labels are joined with ``"_"`` (e.g.
    ``("Open", "AAA")`` becomes ``"open_aaa"``), the index is moved into a
    regular column, and every column name is lower-cased.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose columns are either plain labels or a ``pd.MultiIndex``.

    Returns
    -------
    pd.DataFrame
        A new frame; the input is left untouched.
    """
    out = df.copy()
    # Decide on flattening from the column type itself rather than from the
    # number of columns: joining a plain string label would interleave "_"
    # between its characters, and a six-column MultiIndex would be skipped.
    if isinstance(out.columns, pd.MultiIndex):
        out.columns = [
            "_".join(str(part) for part in col).strip() for col in out.columns.values
        ]
    out = out.reset_index()
    out.columns = [str(name).lower() for name in out.columns]
    return out


def has_parameter(func: Callable[..., Any], parameter: str) -> bool:
    """Tell whether the signature of ``func`` declares a parameter named ``parameter``."""
    return parameter in signature(func).parameters

In [None]:
class Chart:
    """Holds the dashboard state and provides the widget callbacks.

    The instance caches the last downloaded price frame (``df``), the slice
    that should be forecast (``result``) and everything needed to call the
    selected forecast model when the "Run Model" button is pressed.
    """

    # Intervals for which the full history is downloaded once and the
    # start/end dates are applied locally as a filter.
    DAILY_INTERVALS = ("1d", "5d", "1wk", "1mo", "3mo")

    def __init__(self):
        self.last_tickers = ""
        self.last_interval = "1d"
        # (start, end) used for the last intraday download; None for the
        # daily-or-longer intervals, whose download ignores the date pickers.
        self.last_range = None
        self.df = pd.DataFrame()
        self.result = pd.DataFrame()
        self.infos = {}
        # Filled in by handle_changes / handle_eng; initialised here so the
        # click handlers can be invoked safely before any data is loaded.
        self.forecast_model = None
        self.target_column = None
        self.n_predict = 5
        self.kwargs = {}
        self.feature_target = None
        self.feature_model = None

    def _value_columns(self):
        """Return the names of all columns of ``self.df`` except ``"date"``."""
        # NOTE(review): only a column literally named "date" is excluded; an
        # intraday download may name its timestamp column differently - confirm.
        return [col for col in self.df.columns if col != "date"]

    def handle_changes(
        self,
        past_covariates,
        start,
        end,
        interval,
        tickers,
        target_column,
        model,
        naive,
        forecast_only,
    ):
        """Refresh data, widget options and the pending forecast call.

        Nothing happens unless ``tickers`` ends with a comma. Data is only
        downloaded again when the tickers, the interval or (for intraday
        intervals) the selected date range changed since the last download.
        The prepared inputs are stored on the instance for ``handle_click``.
        """
        if not tickers or tickers[-1] != ",":
            return
        # A cleared date picker yields None; there is nothing sensible to do.
        if start is None or end is None:
            return

        is_daily = interval in self.DAILY_INTERVALS
        # Intraday downloads are bounded by start/end, so the range is part
        # of the cache key; daily downloads always fetch the full history.
        date_range = None if is_daily else (start, end)
        if (
            tickers != self.last_tickers
            or interval != self.last_interval
            or date_range != self.last_range
        ):
            if is_daily:
                raw = yf.download(
                    tickers, period="max", interval=interval, progress=False
                )
            else:
                raw = yf.download(
                    tickers,
                    start=start,
                    end=end + timedelta(days=1),
                    interval=interval,
                    progress=False,
                )
            self.df = format_df(raw)
            self.last_tickers = tickers
            self.last_interval = interval
            self.last_range = date_range

        forecast_model = model_opts[model]
        self.forecast_model = forecast_model
        contains_covariates = has_parameter(forecast_model, "past_covariates")

        # Update inputs. Each early return stops after touching one widget;
        # presumably the option change re-triggers this callback through
        # interactive_output, which then continues with the next step.
        columns = self._value_columns()
        if list(target_widget.options) != columns:
            target_widget.options = columns
            return
        if list(past_covs_widget.options) != columns:
            past_covs_widget.options = columns
            past_covs_widget.disabled = not contains_covariates
            return
        if past_covs_widget.disabled == contains_covariates:
            past_covs_widget.disabled = not contains_covariates
        column_widget.options = columns

        if not columns:
            # No usable data: make sure a later click cannot run the model
            # on a stale slice from a previous download.
            self.result = pd.DataFrame()
            return

        if is_daily:
            start_n = datetime(start.year, start.month, start.day)
            end_n = datetime(end.year, end.month, end.day)
            in_range = (self.df["date"] >= start_n) & (self.df["date"] <= end_n)
            result = self.df.loc[in_range]
        else:
            result = self.df

        if not target_column:
            # Fall back to the first data column, never the "date" column.
            target_column = columns[0]

        kwargs = {}
        if contains_covariates and past_covariates:
            kwargs["past_covariates"] = ",".join(past_covariates)
        if has_parameter(forecast_model, "naive"):
            kwargs["naive"] = naive
        if has_parameter(forecast_model, "forecast_only"):
            kwargs["forecast_only"] = forecast_only

        # This sets up everything to run the function on button click
        self.result = result
        self.target_column = target_column
        self.kwargs = kwargs

    def handle_click(self, to_run):
        """Callback of the "Run Model" toggle.

        A pressed toggle is reset to ``False``; the call with ``to_run``
        false is the one that actually runs the stored forecast model.
        """
        if to_run:
            run_widget.value = False
            return
        # Skip when no model is prepared or no complete row is available.
        if self.forecast_model is None or self.result.dropna().empty:
            return
        self.forecast_model(
            self.result,
            target_column=self.target_column,
            n_predict=self.n_predict,
            **self.kwargs
        )

    def handle_eng(self, target, feature):
        """Remember the column and feature-engineering function to apply."""
        self.feature_target = target
        self.feature_model = feat_engs[feature]

    def handle_click2(self, to_run):
        """Callback of the "Add Feature" toggle.

        A pressed toggle is reset to ``False``; the call with ``to_run``
        false applies the selected feature function to ``self.df`` and
        refreshes the covariate options.
        """
        if to_run:
            add_widget.value = False
            return
        # Nothing to engineer before a feature is chosen and data is loaded.
        if self.feature_model is None or self.df.empty:
            return
        kwargs = {}
        if has_parameter(self.feature_model, "target_column"):
            kwargs["target_column"] = self.feature_target
        self.df = self.feature_model(self.df, **kwargs)
        # Same option list as handle_changes uses (the "date" column excluded).
        past_covs_widget.options = self._value_columns()

In [None]:
# Shared layout for widgets that should size themselves to their content.
w_auto = widgets.Layout(width="auto")

# Model and covariate selectors (empty when the forecast tools are missing).
model_value = list(model_opts)[0] if model_opts else None
model_widget = widgets.Select(
    options=list(model_opts),
    value=model_value,
    layout=widgets.Layout(
        width="8%",
    ),
)
past_covs_widget = widgets.SelectMultiple(
    options=[""],
    value=[""],
    layout=widgets.Layout(
        width="8%",
    ),
)

# Date range defaults to the last 365 days.
base_date = (datetime.today() - timedelta(days=365)).date()
start_widget = widgets.DatePicker(value=base_date, layout=w_auto, description="Start")
end_widget = widgets.DatePicker(
    value=datetime.today().date(), layout=w_auto, description="End"
)

target_widget = widgets.Dropdown(options=[""], value="", description="Target")
interval_widget = widgets.Dropdown(
    options=interval_opts, value="1d", layout=w_auto, description="Interval"
)
tickers_widget = widgets.Textarea(
    value="TSLA,", layout=widgets.Layout(width="auto", height="100%")
)

# Output box
naive_widget = widgets.ToggleButton(value=False, description="Show Naive")
forecast_only_widget = widgets.ToggleButton(value=False, description="Forecast Only")
run_widget = widgets.ToggleButton(value=False, description="Run Model")

# feat_box
feat_value = list(feat_engs.keys())[0] if feat_engs else None
column_widget = widgets.Dropdown(options=[""], value="", description="Target")
feat_widget = widgets.Dropdown(
    options=list(feat_engs.keys()), value=feat_value, description="Feature"
)
# `Layout` has no `align` trait; `align_self` with the CSS value "flex-end"
# is the supported way to push a single child to the end of its box.
add_widget = widgets.ToggleButton(
    value=False,
    description="Add Feature",
    layout=widgets.Layout(align_self="flex-end"),
)

# Group the controls into columns and lay the columns out side by side.
selection_box = widgets.VBox([tickers_widget, target_widget])
date_box = widgets.VBox([start_widget, end_widget, interval_widget])
output_box = widgets.VBox([naive_widget, forecast_only_widget, run_widget])
feat_box = widgets.VBox([column_widget, feat_widget, add_widget])
controls = widgets.HBox(
    [selection_box, model_widget, past_covs_widget, date_box, output_box, feat_box],
)
# Assemble the dashboard: the full interactive app when the forecast tools
# are available, otherwise only a title and an explanatory warning.
title_html = "<h1>Timeseries Forecasting Dashboard</h1>"
if has_forecast:
    chart = Chart()
    # Keep the chart state in sync with every input widget. The returned
    # output widget is not displayed; only the side effects are needed.
    widgets.interactive_output(
        chart.handle_changes,
        {
            "past_covariates": past_covs_widget,
            "start": start_widget,
            "end": end_widget,
            "interval": interval_widget,
            "tickers": tickers_widget,
            "target_column": target_widget,
            "model": model_widget,
            "naive": naive_widget,
            "forecast_only": forecast_only_widget,
        },
    )

    widgets.interactive_output(
        chart.handle_eng, {"target": column_widget, "feature": feat_widget}
    )
    widgets.interactive_output(chart.handle_click2, {"to_run": add_widget})

    # Output of the forecast run is the only callback output shown in the app.
    stocks_view = widgets.interactive_output(chart.handle_click, {"to_run": run_widget})

    # The stray "=" after the style attribute made the tag malformed HTML.
    warning_html = '<p style="color:red">Use a comma after EVERY stock typed.</p>'

    app_contents = [
        widgets.HTML(title_html),
        controls,
        widgets.HTML(warning_html),
        stocks_view,
    ]
    app = widgets.VBox(app_contents)
else:
    warning_html = (
        '<p style="color:red">The forecasting dependencies are not installed.</p>'
    )
    app = widgets.VBox([widgets.HTML(title_html), widgets.HTML(warning_html)])
display(app)