In [1]:
import ccxt

# Market-data configuration: the traded pair and a Binance client created
# through ccxt without any API credentials.
symbol = "BTC/USDT"
binance = ccxt.binance()


In [2]:
import import_ipynb
from predictor import *

import numpy as np


importing Jupyter notebook from predictor.ipynb


In [3]:
# Pre-trained forecasters, one per model family (RNN / LSTM / XGBoost) and per
# target series ("close" and "poc"). Each predictor loads its own weights file
# named "<family>_<target>" under ../models.
rnn_close = RNNPredictor("../models/rnn_close.h5")
lstm_close = LSTMPredictor("../models/lstm_close.h5")
xgb_close = XGBoostPredictor("../models/xgb_close.json")
rnn_poc = RNNPredictor("../models/rnn_poc.h5")
# Must load the LSTM "poc" weights, not the RNN ones (rnn_poc.h5).
lstm_poc = LSTMPredictor("../models/lstm_poc.h5")
xgb_poc = XGBoostPredictor("../models/xgb_poc.json")


def ROC(cp, tf):
    """Rate of change of ``cp`` over a look-back of ``tf`` positions.

    Args:
        cp: Indexable sequence of prices. For a pandas Series, ``cp[i]`` is a
            label lookup, which matches positions for the default RangeIndex.
        tf: Look-back window, in number of elements.

    Returns:
        A list holding ``(cp[i] - cp[i - tf]) / cp[i - tf]`` for every index
        ``i`` from ``tf`` up to ``len(cp) - 1``. It is empty when ``len(cp) <= tf``.
        The first ``tf`` elements have no value because they lack a full window.
    """
    return [(cp[i] - cp[i - tf]) / cp[i - tf] for i in range(tf, len(cp))]


def predict(inputs, model_name, n_to_predict, target="close"):
    """Forecast ``n_to_predict`` steps with one of the pre-loaded predictors.

    Args:
        inputs: Data handed unchanged to the predictor's ``predict`` method
            (the dashboard passes a two-column ``[timestamp, value]`` frame).
        model_name: ``"RNN"``, ``"LSTM"`` or ``"XGBoost"``.
        n_to_predict: Number of steps to forecast; forwarded to the predictor.
        target: Which model set to use. ``"close"`` (default, the original
            behaviour) selects the ``*_close`` models. ``"poc"`` selects the
            ``*_poc`` models. Presumably those were trained on the
            rate-of-change series (TODO confirm).

    Returns:
        Whatever the selected predictor's ``predict(inputs, n_to_predict)`` returns.

    Raises:
        ValueError: If ``model_name`` or ``target`` is not recognised. Previously
            an unknown model name surfaced as an ``UnboundLocalError``.
    """
    if model_name not in ("RNN", "LSTM", "XGBoost"):
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'RNN', 'LSTM' or 'XGBoost'"
        )

    # Model globals are looked up at call time, so they only need to exist
    # by the time predict() is first called.
    if target == "close":
        models = {"RNN": rnn_close, "LSTM": lstm_close, "XGBoost": xgb_close}
    elif target == "poc":
        models = {"RNN": rnn_poc, "LSTM": lstm_poc, "XGBoost": xgb_poc}
    else:
        raise ValueError(f"Unknown target {target!r}; expected 'close' or 'poc'")

    return models[model_name].predict(inputs, n_to_predict)


In [4]:
import dash
from dash import dcc
from dash import html
from dash.dependencies import Output, Input

import plotly.graph_objs as go

import dash_bootstrap_components as dbc


In [5]:
dbc_css = "https://cdn.jsdelivr.net/gh/AnnMarieW/dash-bootstrap-templates/dbc.min.css"
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc_css])

# Seconds between chart refreshes. This is the period of
# `update-interval-component` and the maximum of the countdown progress bar.
graph_update_interval_seconds = 20

app.layout = html.Div(
    [
        html.H4(
            "BTC-USDT Analysis Dashboard",
            className="bg-primary text-white p-2 mb-2 text-center",
        ),
        # Model selector, rendered as a Bootstrap button group.
        dbc.Card(
            [
                html.Div(
                    [
                        dbc.Row(
                            [
                                dbc.Col(dbc.Label("Model"), width=2),
                                dbc.Col(
                                    html.Div(
                                        [
                                            dbc.RadioItems(
                                                id="model-name",
                                                className="btn-group",
                                                inputClassName="btn-check",
                                                labelClassName="btn btn-outline-primary",
                                                labelCheckedClassName="active",
                                                options=["LSTM", "RNN", "XGBoost"],
                                                value="LSTM",
                                            ),
                                        ],
                                        className="radio-group",
                                    ),
                                    width="auto",
                                ),
                            ]
                        ),
                    ]
                ),
            ],
            body=True,
            style={
                "margin": "0rem 1rem 0rem",
            },
        ),
        # Chart and table tabs; their contents come from the callback that
        # outputs to "live-graph" and "live-table".
        dbc.Card(
            dbc.Tabs(
                [
                    dbc.Tab(
                        [
                            # Dash `style` dicts use React camelCase CSS property
                            # names; a snake_case key such as "margin_bottom" is
                            # not a CSS property and is ignored by the browser.
                            dcc.Graph(id="live-graph", style={"marginBottom": 0}),
                            # Countdown bar showing the time left until the next refresh.
                            dbc.Progress(
                                id="fetch-progress",
                                max=graph_update_interval_seconds,
                                style={
                                    "margin": "-1rem 5rem 2rem auto",
                                    "height": "5px",
                                    "width": "200px",
                                },
                            ),
                        ],
                        label="Chart",
                    ),
                    dbc.Tab(
                        [html.Div(id="live-table")],
                        label="Table",
                    ),
                ]
            ),
            style={
                "margin": "2rem 1rem 0rem",
            },
        ),
        # 100 ms tick driving the progress bar. The timer callback converts its
        # n_intervals to seconds by dividing by 10, so keep the two in sync.
        # NOTE(review): each tick is a server round-trip (10 per second).
        dcc.Interval(
            id="interval-component",
            interval=100,  # in milliseconds
            n_intervals=0,
        ),
        # Triggers a fresh fetch + prediction once per refresh period.
        dcc.Interval(
            id="update-interval-component",
            interval=graph_update_interval_seconds * 1000,  # in milliseconds
            n_intervals=0,
        ),
        html.Div(id="remaining-time"),
    ]
)


In [6]:
import pandas as pd
from plotly.subplots import make_subplots


@app.callback(
    [Output("remaining-time", "children"), Output("fetch-progress", "value")],
    [Input("interval-component", "n_intervals")],
)
def update_timer(n):
    """Update the countdown progress bar under the live chart.

    ``n`` is the tick count of ``interval-component``. Dividing it by 10
    converts ticks to seconds (one tick = 0.1 s). The progress-bar value is the
    number of seconds left in the current ``graph_update_interval_seconds``
    cycle. The text output is always the empty string.
    """
    elapsed_in_cycle = (n / 10) % graph_update_interval_seconds
    seconds_left = graph_update_interval_seconds - elapsed_in_cycle
    # The textual countdown is disabled. To show it, return something like
    # f"Remaining time until next fetch: {seconds_left}" in place of "".
    return ["", seconds_left]


@app.callback(
    [Output("live-graph", "figure"), Output("live-table", "children")],
    [Input("update-interval-component", "n_intervals"), Input("model-name", "value")],
)
def update_graph(n, model_name):
    """Fetch recent 1-minute candles, run the selected model and redraw chart + table.

    Fired on every ``update-interval-component`` tick and whenever the model radio
    selection changes.

    Args:
        n: Tick count of ``update-interval-component``. It is only a trigger and
            is not read.
        model_name: Value of the ``model-name`` radio group (``"LSTM"``, ``"RNN"``
            or ``"XGBoost"``), forwarded to ``predict``.

    Returns:
        ``(figure, table)``:
        - ``figure`` has two rows. The top row shows candles plus the predicted
          close; the bottom row shows ROC plus the predicted ROC.
        - ``table`` is a ``DataTable`` of the fetched rows, newest first. It also
          holds one extra future row and the prediction columns; ROC columns
          are scaled by 100 in the table.
    """
    roc_window = 30  # look-back, in candles, of the rate-of-change indicator
    n_to_predict = 100

    ohlcv = binance.fetch_ohlcv(symbol, "1m", limit=1000)
    df = pd.DataFrame(
        ohlcv, columns=["timestamp", "open", "high", "low", "close", "volume"]
    )
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    # ROC has no value for the first `roc_window` rows. Zero-pad by the actual
    # shortfall so the column always matches len(df), even for short fetches.
    roc_values = ROC(df["close"], roc_window)
    df["roc"] = np.concatenate((np.zeros(len(df) - len(roc_values)), roc_values))

    predictions_close = predict(df[["timestamp", "close"]], model_name, n_to_predict)
    predictions_close = list(np.concatenate(predictions_close))

    # NOTE(review): the ROC series is forecast with the same close-price model as
    # the close series; the separately loaded *_poc models are not used here.
    # Confirm whether a ROC-specific model was intended.
    predictions_roc = predict(df[["timestamp", "roc"]], model_name, n_to_predict)
    predictions_roc = list(np.concatenate(predictions_roc))

    trace_close = go.Candlestick(
        x=df["timestamp"],
        open=df["open"],
        high=df["high"],
        low=df["low"],
        close=df["close"],
        name=symbol,
    )

    trace_roc = go.Scatter(
        x=df["timestamp"],
        y=df["roc"],
        mode="lines",
        name=f"ROC ({roc_window})",
        line=dict(color="cornflowerblue", width=1),
    )

    # Append one empty row one minute past the last candle so the prediction
    # traces reach into the future. DataFrame.append was removed in pandas 2.0.
    new_timestamp = df["timestamp"].max() + pd.Timedelta(minutes=1)
    df = pd.concat(
        [df, pd.DataFrame({"timestamp": [new_timestamp]})], ignore_index=True
    )

    trace_predictions_close = go.Scatter(
        x=df["timestamp"].iloc[-n_to_predict:],
        y=predictions_close,
        mode="lines",
        name="Predicted Close",
        line=dict(color="orange", width=3),
    )

    trace_predictions_roc = go.Scatter(
        x=df["timestamp"].iloc[-n_to_predict:],
        y=predictions_roc,
        mode="lines",
        name=f"Predicted ROC ({roc_window})",
        line=dict(color="lightseagreen", width=3),
    )

    # Prediction columns for the table: zeros for rows without a prediction and
    # the predictions aligned to the end of the frame. The padding is derived from
    # the real lengths instead of a hard-coded 901, which assumed exactly 1000 candles.
    df["prediction_close"] = np.concatenate(
        (np.zeros(len(df) - len(predictions_close)), predictions_close)
    )
    df["prediction_roc"] = np.concatenate(
        (np.zeros(len(df) - len(predictions_roc)), predictions_roc)
    )

    # Show ROC values as percentages in the table. The traces above were built
    # from the unscaled values.
    df["roc"] = df["roc"] * 100
    df["prediction_roc"] = df["prediction_roc"] * 100

    rangeselector = {
        "buttons": [
            dict(count=1, label="1m", step="minute", stepmode="backward"),
            dict(count=2, label="2m", step="minute", stepmode="backward"),
            dict(count=5, label="5m", step="minute", stepmode="backward"),
            dict(count=1, label="1h", step="hour", stepmode="backward"),
            dict(count=2, label="2h", step="hour", stepmode="backward"),
            dict(count=5, label="5h", step="hour", stepmode="backward"),
            dict(step="all"),
        ]
    }

    layout = go.Layout(
        title=f"Live OHLCV Chart - {symbol}",
        xaxis=dict(
            rangeselector=rangeselector,
        ),
        yaxis=dict(title="Price", domain=[0.2, 1]),
        yaxis2=dict(title=f"ROC ({roc_window})", domain=[0.0, 0.2]),
        showlegend=True,
        height=700,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_rangeslider_visible=False,
    )

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1)
    fig.add_trace(trace_close, row=1, col=1)
    fig.add_trace(trace_predictions_close, row=1, col=1)
    fig.add_trace(trace_roc, row=2, col=1)
    fig.add_trace(trace_predictions_roc, row=2, col=1)
    fig.update_layout(layout)

    table = dash.dash_table.DataTable(
        df.iloc[::-1].to_dict("records"), style_cell={"font-family": "sans-serif"}
    )

    return fig, table


# Start the app.

In [7]:
import webbrowser

if __name__ == "__main__":
    # Optionally open the dashboard in a browser tab automatically:
    # webbrowser.open('http://127.0.0.1:8050/', new=0, autoraise=True)

    # `Dash.run_server` was deprecated in Dash 2 and removed in Dash 3 in favour
    # of `Dash.run`. Prefer `run` and fall back for older releases lacking it.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(debug=False)
