In [1]:
import ccxt

# Market shown on the dashboard and fed to the prediction models.
symbol = "BTC/USDT"

# Binance exchange client, created without any config/credentials; the
# notebook only uses it to pull OHLCV candles.
binance = ccxt.binance()


In [2]:
import import_ipynb
from predictor import XGBoostPredictor, RNNPredictor, LSTMPredictor

import numpy as np


importing Jupyter notebook from predictor.ipynb


In [3]:
# One predictor instance per selectable model, constructed once at start-up
# and looked up by name in `predict`. Paths are relative to the notebook's
# working directory.
rnn, lstm, xgb = (
    RNNPredictor("../models/rnn_model.h5"),
    LSTMPredictor("../models/lstm_model.h5"),
    XGBoostPredictor("../models/xgboost_model.json"),
)


def predict(inputs, model_name, n_to_predict):
    """Run the predictor selected by ``model_name`` on ``inputs``.

    Args:
        inputs: Passed unchanged to the chosen predictor's ``predict`` method
            (the dashboard callback passes a ``timestamp``/``close`` DataFrame).
        model_name: ``"RNN"``, ``"LSTM"`` or ``"XGBoost"``.
        n_to_predict: Forwarded as the second argument of the predictor's
            ``predict`` method.

    Returns:
        Whatever ``<predictor>.predict(inputs, n_to_predict)`` returns.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names
            (e.g. ``None`` when no radio option is selected).
    """
    if model_name == "RNN":
        model = rnn
    elif model_name == "LSTM":
        model = lstm
    elif model_name == "XGBoost":
        model = xgb
    else:
        # Without this branch `model` would be unbound and the return below
        # would fail with a confusing UnboundLocalError.
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            "expected 'RNN', 'LSTM' or 'XGBoost'"
        )

    return model.predict(inputs, n_to_predict)


In [4]:
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Output, Input, State

import plotly.graph_objs as go

import dash_bootstrap_components as dbc


The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [5]:
dbc_css = "https://cdn.jsdelivr.net/gh/AnnMarieW/dash-bootstrap-templates/dbc.min.css"
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc_css])

# Outer spacing shared by the controls card and the chart card.
CARD_STYLE = {"margin": "2rem 1rem 0rem"}

# Page structure: title bar, a controls card (model picker + feature
# selector), a card holding the live chart, a status line filled in by the
# graph callback, and the timer whose ticks re-trigger that callback.
app.layout = html.Div(
    [
        html.H4(
            "BTC-USDT Analysis Dashboard",
            className="bg-primary text-white p-2 mb-2 text-center",
        ),
        dbc.Card(
            [
                html.Div(
                    [
                        # Model picker, rendered as a Bootstrap button group.
                        dbc.Row(
                            [
                                dbc.Col(dbc.Label("Model"), width=2),
                                dbc.Col(
                                    html.Div(
                                        [
                                            dbc.RadioItems(
                                                id="model-name",
                                                className="btn-group",
                                                inputClassName="btn-check",
                                                labelClassName="btn btn-outline-primary",
                                                labelCheckedClassName="active",
                                                options=["LSTM", "RNN", "XGBoost"],
                                                value="LSTM",
                                            ),
                                        ],
                                        className="radio-group",
                                    ),
                                    width="auto",
                                ),
                            ]
                        ),
                        # NOTE(review): no callback visible in this notebook
                        # reads `features`, so this selection does not yet
                        # affect the chart.
                        dbc.Row(
                            [
                                dbc.Col(dbc.Label("Features extraction"), width=2),
                                dbc.Col(
                                    dcc.Dropdown(
                                        id="features",
                                        options=[
                                            "Close",
                                            "Price of Change",
                                            # "RSI",
                                            # "Bollinger Bands",
                                            # "Moving Average",
                                        ],
                                        # With multi=True the value is a list
                                        # of selected options, not a scalar.
                                        value=["Close"],
                                        multi=True,
                                    ),
                                    width=7,
                                ),
                            ]
                        ),
                    ]
                ),
            ],
            body=True,
            style=CARD_STYLE,
        ),
        dbc.Card(
            dbc.Tabs([dbc.Tab([dcc.Graph(id="live-graph")], label="Chart")]),
            style=CARD_STYLE,
        ),
        # Status text returned as the second output of the graph callback.
        html.Div(id="remaining-time"),
        # Ticks every 15 s; each tick re-runs the graph callback.
        dcc.Interval(
            id="interval-component",
            interval=15 * 1000,  # in milliseconds
            n_intervals=0,
        ),
    ]
)


In [6]:
import pandas as pd


@app.callback(
    [Output("live-graph", "figure"), Output("remaining-time", "children")],
    [Input("interval-component", "n_intervals"), Input("model-name", "value")],
)
def update_graph(n, model_name):
    """Redraw the candlestick chart together with the selected model's predictions.

    Runs on every ``interval-component`` tick and whenever the model radio
    selection changes.

    Args:
        n: Tick count from ``interval-component``; only acts as a trigger.
        model_name: Value of the ``model-name`` radio items, forwarded to
            ``predict``.

    Returns:
        ``(figure, status_text)``: a Plotly figure dict for ``live-graph`` and
        the text shown in ``remaining-time``.
    """
    # Latest 1000 one-minute candles; timestamps arrive as epoch milliseconds.
    ohlcv = binance.fetch_ohlcv(symbol, "1m", limit=1000)
    df = pd.DataFrame(
        ohlcv, columns=["timestamp", "open", "high", "low", "close", "volume"]
    )
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")

    n_to_predict = 100

    predictions = predict(df[["timestamp", "close"]], model_name, n_to_predict)
    # Flatten the predictor output (a sequence of arrays) into one flat list.
    predictions = list(np.concatenate(predictions))

    trace = go.Candlestick(
        x=df["timestamp"],
        open=df["open"],
        high=df["high"],
        low=df["low"],
        close=df["close"],
        name=symbol,
    )

    # x values for the prediction line: the candle time axis extended by one
    # minute, then its last `n_to_predict` points.
    # NOTE(review): only the final point lies in the future; if `predict`
    # returns an n-step-ahead forecast, these should instead be the next
    # `n_to_predict` minutes after the latest candle -- confirm.
    next_timestamp = df["timestamp"].max() + pd.Timedelta(minutes=1)
    # DataFrame.append was removed in pandas 2.0; concat the extra timestamp.
    timestamps = pd.concat(
        [df["timestamp"], pd.Series([next_timestamp])], ignore_index=True
    )

    trace_predictions = go.Scatter(
        x=timestamps.iloc[-n_to_predict:],
        y=predictions,
        mode="lines",
        name="Predicted Close",
        line=dict(color="orange"),
    )

    rangeselector = {
        "buttons": [
            dict(count=1, label="1m", step="minute", stepmode="backward"),
            dict(count=2, label="2m", step="minute", stepmode="backward"),
            dict(count=5, label="5m", step="minute", stepmode="backward"),
            dict(count=1, label="1h", step="hour", stepmode="backward"),
            dict(count=2, label="2h", step="hour", stepmode="backward"),
            dict(count=5, label="5h", step="hour", stepmode="backward"),
            dict(step="all"),
        ]
    }

    layout = go.Layout(
        title=f"Live OHLCV Chart - {symbol}",
        xaxis=dict(
            title="Time",
            rangeselector=rangeselector,
        ),
        yaxis=dict(title="Price"),
        showlegend=True,
        height=700,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )

    return {
        "data": [trace, trace_predictions],
        "layout": layout,
    }, f"input: {model_name}"


# Start the app.

In [7]:
import webbrowser

# True both in a notebook kernel and when run as a script, so executing this
# cell starts the (blocking) development server.
if __name__ == "__main__":
    # Uncomment to open the dashboard in the default browser on start-up:
    # webbrowser.open('http://127.0.0.1:8050/', new=0, autoraise=True)

    # `run_server` is a deprecated alias of `run` (removed in Dash 3); use
    # `run` when this Dash version provides it, otherwise fall back.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(debug=False)


before
[[0.29117394]
 [0.25068525]
 [0.22366669]
 [0.17229227]
 [0.13094213]
 [0.11441773]
 [0.11441773]
 [0.1441773 ]
 [0.14425562]
 [0.15232203]
 [0.17534654]
 [0.257342  ]
 [0.257342  ]
 [0.25655885]
 [0.27276999]
 [0.3032344 ]
 [0.33550004]
 [0.33479521]
 [0.30033675]
 [0.31247553]
 [0.31239721]
 [0.31247553]
 [0.37591041]
 [0.43801394]
 [0.45406845]
 [0.43965855]
 [0.42564022]
 [0.39157334]
 [0.39157334]
 [0.39157334]
 [0.42759809]
 [0.42759809]
 [0.42751977]
 [0.42751977]
 [0.43817057]
 [0.39165166]
 [0.39149503]
 [0.46174328]
 [0.43182708]
 [0.40715796]
 [0.41436291]
 [0.3966638 ]
 [0.3966638 ]
 [0.40731459]
 [0.47662307]
 [0.44921294]
 [0.37598872]
 [0.41506774]
 [0.37661524]
 [0.39157334]
 [0.34466286]
 [0.34458454]
 [0.33683139]
 [0.28185449]
 [0.28185449]
 [0.31701778]
 [0.31325867]
 [0.31239721]
 [0.3123189 ]
 [0.26462526]
 [0.24261884]
 [0.24254053]
 [0.24261884]
 [0.1932806 ]
 [0.18356958]
 [0.12162268]
 [0.10611638]
 [0.18325632]
 [0.1879552 ]
 [0.18803352]
 [0.01558462]