In [1]:
import dash
from dash import dcc, html
import plotly.graph_objs as go
import pandas as pd
from dash.dependencies import Input, Output


### Import all data 
Here we bring in all the datasets for the Bitcoin data. This includes:
1. Linear
- Univariate:
    - Classic
    - Quantile
- Multivariate:
    - Classic
    - Quantile
2. DLNN
- Univariate
   - BDLSTM:
       - Classic
       - Quantile
   - CLSTM:
        - Classic
        - Quantile
- Multivariate
    - BDLSTM:
        - Classic
        - Quantile
    - EDLSTM:
        - Classic
        - Quantile

In [23]:
import dash
from dash import dcc, html
import plotly.graph_objs as go
import pandas as pd
from dash.dependencies import Input, Output

# Load the datasets
datasets = {
    'linear_uni_classic': {
        'predictions': pd.read_csv('linear_uni_classic_b_predictions.csv'),
        'mae': pd.read_csv('linear_uni_classic_b_mae.csv'),
        'mape': pd.read_csv('linear_uni_classic_b_mape.csv'),
        'rmse': pd.read_csv('linear_uni_classic_b_rmse.csv')
    },
    'linear_uni_quantile': {
        'predictions': pd.read_csv('linear_uni_quantile_b_predictions.csv'),
        'mae': pd.read_csv('linear_uni_quantile_b_mae.csv'),
        'mape': pd.read_csv('linear_uni_quantile_b_mape.csv'),
        'rmse': pd.read_csv('linear_uni_quantile_b_rmse.csv')
    },
    'linear_multi_classic': {
        'predictions': pd.read_csv('linear_multi_classic_b_predictions.csv'),
        'mae': pd.read_csv('linear_multi_classic_b_mae.csv'),
        'mape': pd.read_csv('linear_multi_classic_b_mape.csv'),
        'rmse': pd.read_csv('linear_multi_classic_b_rmse.csv')
    },
    'linear_multi_quantile': {
        'predictions': pd.read_csv('linear_multi_quantile_b_predictions.csv'),
        'mae': pd.read_csv('linear_multi_quantile_b_mae.csv'),
        'mape': pd.read_csv('linear_multi_quantile_b_mape.csv'),
        'rmse': pd.read_csv('linear_multi_quantile_b_rmse.csv')
    },
    'dlnn_uni_bdlstm_classic': {
        'predictions': pd.read_csv('dlnn_uni_classic_bdlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_uni_classic_bdlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_uni_classic_bdlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_uni_classic_bdlstm_b_rmse.csv')
    },
    'dlnn_uni_bdlstm_quantile': {
        'predictions': pd.read_csv('dlnn_uni_quantile_bdlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_uni_quantile_bdlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_uni_quantile_bdlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_uni_quantile_bdlstm_b_rmse.csv')
    },
    'dlnn_uni_clstm_classic': {
        'predictions': pd.read_csv('dlnn_uni_classic_clstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_uni_classic_clstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_uni_classic_clstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_uni_classic_clstm_b_rmse.csv')
    },
    'dlnn_uni_clstm_quantile': {
        'predictions': pd.read_csv('dlnn_uni_quantile_clstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_uni_quantile_clstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_uni_quantile_clstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_uni_quantile_clstm_b_rmse.csv')
    },
    'dlnn_multi_bdlstm_classic': {
        'predictions': pd.read_csv('dlnn_multi_classic_bdlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_multi_classic_bdlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_multi_classic_bdlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_multi_classic_bdlstm_b_rmse.csv')
    },
    'dlnn_multi_bdlstm_quantile': {
        'predictions': pd.read_csv('dlnn_multi_quantile_bdlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_multi_quantile_bdlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_multi_quantile_bdlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_multi_quantile_bdlstm_b_rmse.csv')
    },
    'dlnn_multi_edlstm_classic': {
        'predictions': pd.read_csv('dlnn_multi_classic_edlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_multi_classic_edlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_multi_classic_edlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_multi_classic_edlstm_b_rmse.csv')
    },
    'dlnn_multi_edlstm_quantile': {
        'predictions': pd.read_csv('dlnn_multi_quantile_edlstm_b_predictions.csv'),
        'mae': pd.read_csv('dlnn_multi_quantile_edlstm_b_mae.csv'),
        'mape': pd.read_csv('dlnn_multi_quantile_edlstm_b_mape.csv'),
        'rmse': pd.read_csv('dlnn_multi_quantile_edlstm_b_rmse.csv')
    }
}

In [42]:
import dash
from dash import dcc, html
import plotly.graph_objs as go
import pandas as pd
from dash.dependencies import Input, Output

def format_decimals(df, decimals=4):
    return df.applymap(lambda x: round(x, decimals) if isinstance(x, (int, float)) else x)

def clean_columns(df):
    df.columns = df.columns.str.strip()
    return df

def rename_columns(df):
    new_columns = ['Number of Experiments'] + [f'Time Step {i}' for i in range(1, 6)]
    df.columns = new_columns
    return df

for key in datasets.keys():
    datasets[key]['predictions'] = clean_columns(datasets[key]['predictions'])
    datasets[key]['mae'] = clean_columns(datasets[key]['mae'])
    datasets[key]['mape'] = clean_columns(datasets[key]['mape'])
    datasets[key]['rmse'] = clean_columns(datasets[key]['rmse'])
    datasets[key]['predictions'] = format_decimals(datasets[key]['predictions'])
    datasets[key]['mae'] = format_decimals(datasets[key]['mae'])
    datasets[key]['mape'] = format_decimals(datasets[key]['mape'])
    datasets[key]['rmse'] = format_decimals(datasets[key]['rmse'])
    datasets[key]['mae'] = rename_columns(datasets[key]['mae'])
    datasets[key]['mape'] = rename_columns(datasets[key]['mape'])
    datasets[key]['rmse'] = rename_columns(datasets[key]['rmse'])

close_price_df = pd.read_csv('close_price.csv', parse_dates=['Date'])
close_price_df = clean_columns(close_price_df)

close_price_df.sort_values('Date', inplace=True)

date_ranges = {
    '2013': ['2013-01-01', '2013-12-31'],
    '2014': ['2014-01-01', '2014-12-31'],
    '2015': ['2015-01-01', '2015-12-31'],
    '2016': ['2016-01-01', '2016-12-31'],
    '2017': ['2017-01-01', '2017-12-31'],
    '2018': ['2018-01-01', '2018-12-31'],
    '2019': ['2019-01-01', '2019-12-31'],
    '2020': ['2020-01-01', '2020-12-31'],
    '2021': ['2021-01-01', '2021-12-31']
}


for key in date_ranges:
    date_ranges[key] = [pd.Timestamp(date) for date in date_ranges[key]]

app = dash.Dash(__name__)

app.layout = html.Div(
    children=[
        html.H1(children='Cryptocurrency Close Price Dashboard: Bitcoin'),
        html.Table(
            [
                html.Thead(
                    html.Tr(
                        [html.Th("Category", style={'border': '1px solid lightgray'}), 
                         html.Th("Subcategory", style={'border': '1px solid lightgray'}), 
                         html.Th("Type", style={'border': '1px solid lightgray'}), 
                         html.Th("Loss function", style={'border': '1px solid lightgray'})]
                    )
                ),
                html.Tbody([
                    html.Tr([html.Td("Linear", rowSpan=4, style={'border': '1px solid lightgray'}), 
                             html.Td("Univariate", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Multivariate", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("DLNN", rowSpan=8, style={'border': '1px solid lightgray'}), 
                             html.Td("Univariate", rowSpan=4, style={'border': '1px solid lightgray'}), 
                             html.Td("BDLSTM", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("CLSTM", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Multivariate", rowSpan=4, style={'border': '1px solid lightgray'}), 
                             html.Td("BDLSTM", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("EDLSTM", rowSpan=2, style={'border': '1px solid lightgray'}), 
                             html.Td("Classic", style={'border': '1px solid lightgray'})]),
                    html.Tr([html.Td("Quantile", style={'border': '1px solid lightgray'})]),
                ])
            ],
            style={
                'width': '50%',
                'margin': 'auto',
                'border': '1px solid lightgray',
                'borderCollapse': 'collapse',
                'textAlign': 'center'
            }
        ),
        dcc.Dropdown(
            id='model-dropdown',
            options=[
                {'label': 'Linear Univariate Classic', 'value': 'linear_uni_classic'},
                {'label': 'Linear Univariate Quantile', 'value': 'linear_uni_quantile'},
                {'label': 'Linear Multivariate Classic', 'value': 'linear_multi_classic'},
                {'label': 'Linear Multivariate Quantile', 'value': 'linear_multi_quantile'},
                {'label': 'DLNN Univariate BDLSTM Classic', 'value': 'dlnn_uni_bdlstm_classic'},
                {'label': 'DLNN Univariate BDLSTM Quantile', 'value': 'dlnn_uni_bdlstm_quantile'},
                {'label': 'DLNN Univariate CLSTM Classic', 'value': 'dlnn_uni_clstm_classic'},
                {'label': 'DLNN Univariate CLSTM Quantile', 'value': 'dlnn_uni_clstm_quantile'},
                {'label': 'DLNN Multivariate BDLSTM Classic', 'value': 'dlnn_multi_bdlstm_classic'},
                {'label': 'DLNN Multivariate BDLSTM Quantile', 'value': 'dlnn_multi_bdlstm_quantile'},
                {'label': 'DLNN Multivariate EDLSTM Classic', 'value': 'dlnn_multi_edlstm_classic'},
                {'label': 'DLNN Multivariate EDLSTM Quantile', 'value': 'dlnn_multi_edlstm_quantile'}
            ],
            value='linear_uni_classic'  # Default value
        ),
        dcc.Dropdown(
            id='date-range-dropdown',
            options=[{'label': k, 'value': k} for k in date_ranges.keys()],
            value='2013'
        ),
        dcc.Graph(id='close-price-plot'),
        dcc.Slider(
            id='prediction-slider',
            min=0,
            max=1,
            step=1,
            value=0,
            marks={}
        ),
        html.Div(id='prediction-detail-plots'),
        dcc.Store(id='filtered-predictions'),
        html.H2('Error Metrics'),
        dcc.Dropdown(
            id='error-type-dropdown',
            options=[
                {'label': 'MAE', 'value': 'mae'},
                {'label': 'MAPE', 'value': 'mape'},
                {'label': 'RMSE', 'value': 'rmse'}
            ],
            value='mae'  # Default value
        ),
        html.Div(id='error-table')
    ],
    style={
        'backgroundColor': 'white',
        'fontFamily': 'Helvetica',
        'padding': '10px',
        'width': '100%',
        'boxSizing': 'border-box',
        'overflow': 'auto'
    }
)

@app.callback(
    [Output('close-price-plot', 'figure'),
     Output('prediction-slider', 'marks'),
     Output('prediction-slider', 'max'),
     Output('filtered-predictions', 'data')],
    [Input('date-range-dropdown', 'value'),
     Input('model-dropdown', 'value')]
)
def update_close_price_plot(selected_range, model_value):
    
    predictions_df = datasets[model_value]['predictions']

    
    start_date, end_date = date_ranges[selected_range]
    filtered_df = close_price_df[(close_price_df['Date'] >= start_date) & (close_price_df['Date'] <= end_date)]

    
    close_price_trace = go.Scatter(
        x=filtered_df['Date'],
        y=filtered_df['Close'],
        mode='lines',
        name='Close Price'
    )

    
    pred_traces = []
    slider_marks = {}
    filtered_predictions = []

    if 'classic' in model_value:
        for idx, row in predictions_df.iterrows():
            dates = close_price_df.index[close_price_df['Close'] == row['Actual 0']].tolist()
            if dates:
                start_idx = dates[0]
                end_idx = start_idx + 5
                if (start_idx + 4) < len(close_price_df):
                    if close_price_df.iloc[start_idx]['Date'] >= start_date and close_price_df.iloc[end_idx - 1]['Date'] <= end_date:
                        pred_trace = go.Scatter(
                            x=close_price_df['Date'][start_idx:end_idx],
                            y=[row[f'Pred {i}'] for i in range(5)],
                            mode='lines',
                            name=f'Prediction {idx}',
                            customdata=[idx] * 5,
                            hoverinfo='x+y+name'
                        )
                        pred_traces.append(pred_trace)
                        slider_marks[len(filtered_predictions)] = str(idx)
                        filtered_predictions.append(idx)
    else:  
        for idx, row in predictions_df.iterrows():
            dates = close_price_df.index[close_price_df['Close'] == row['Actual 0']].tolist()
            if dates:
                start_idx = dates[0]
                end_idx = start_idx + 5
                if (start_idx + 4) < len(close_price_df):
                    if close_price_df.iloc[start_idx]['Date'] >= start_date and close_price_df.iloc[end_idx - 1]['Date'] <= end_date:
                        median_trace = go.Scatter(
                            x=close_price_df['Date'][start_idx:end_idx],
                            y=[row[f'Q3_T{i + 1}'] for i in range(5)],
                            mode='lines',
                            line=dict(dash='dot'),
                            name=f'Median Quantile {idx}',
                            customdata=[idx] * 5,
                            hoverinfo='x+y+name'
                        )
                        pred_traces.append(median_trace)
                        slider_marks[len(filtered_predictions)] = str(idx)
                        filtered_predictions.append(idx)

    
    figure = {
        'data': [close_price_trace] + pred_traces,
        'layout': {
            'title': f'Close Price with Predictions ({selected_range})',
            'xaxis': {'title': 'Date'},
            'yaxis': {'title': 'Price'},
            'clickmode': 'event+select'
        }
    }

    return figure, slider_marks, len(filtered_predictions) - 1, filtered_predictions

@app.callback(
    Output('prediction-detail-plots', 'children'),
    [Input('prediction-slider', 'value'),
     Input('filtered-predictions', 'data'),
     Input('model-dropdown', 'value')]
)
def update_prediction_detail_plots(slider_value, filtered_predictions, model_value):
    if slider_value is None or slider_value >= len(filtered_predictions):
        return []

    prediction_index = filtered_predictions[slider_value]

    # Load the relevant data
    predictions_df = datasets[model_value]['predictions']

    if 'classic' in model_value:
        row = predictions_df.iloc[prediction_index]
        actual_values = [row[f'Actual {i}'] for i in range(5)]
        pred_values = [row[f'Pred {i}'] for i in range(5)]

        dates = close_price_df.index[close_price_df['Close'] == row['Actual 0']].tolist()
        if not dates:
            return []

        start_idx = dates[0]
        end_idx = start_idx + 5
        close_dates = close_price_df['Date'][start_idx:end_idx]

        actual_trace = go.Scatter(
            x=close_dates,
            y=actual_values,
            mode='lines+markers',
            name='Actual Close Prices'
        )
        pred_trace = go.Scatter(
            x=close_dates,
            y=pred_values,
            mode='lines+markers',
            name='Predicted Prices'
        )

        figure = {
            'data': [actual_trace, pred_trace],
            'layout': {
                'title': f'Prediction Detail for Index {prediction_index}',
                'xaxis': {'title': 'Date'},
                'yaxis': {'title': 'Price'}
            }
        }
    else:  # Quantile model
        row = predictions_df.iloc[prediction_index]
        actual_values = [row[f'Actual {i}'] for i in range(5)]
        median_values = [row[f'Q3_T{i + 1}'] for i in range(5)]
        quantile_traces = []

        dates = close_price_df.index[close_price_df['Close'] == row['Actual 0']].tolist()
        if not dates:
            return []

        start_idx = dates[0]
        end_idx = start_idx + 5
        close_dates = close_price_df['Date'][start_idx:end_idx]

       
        quantile_colors = ['rgba(31, 119, 180, 0.2)',  # Blue
                           'rgba(255, 127, 14, 0.2)',  # Orange
                           'rgba(44, 160, 44, 0.2)',   # Green
                           'rgba(214, 39, 40, 0.2)',   # Red
                           'rgba(148, 103, 189, 0.2)'] # Purple

        
        for q in range(1, 6):
            quantile_trace = go.Scatter(
                x=close_dates,
                y=[row[f'Q{q}_T{i + 1}'] for i in range(5)],
                mode='lines',
                name=f'Quantile {q}',
                line=dict(color=quantile_colors[q-1][:16] + ')'),
                showlegend=(q == 1)
            )
            quantile_traces.append(quantile_trace)

        median_trace = go.Scatter(
            x=close_dates,
            y=median_values,
            mode='lines',
            line=dict(dash='dot', color='rgba(44, 160, 44, 1)'),  
            name='Median Quantile'
        )

        actual_trace = go.Scatter(
            x=close_dates,
            y=actual_values,
            mode='lines+markers',
            name='Actual Close Prices',
            line=dict(color='black')  
        )

        fills = []
        for i in range(1, 5):
            fill = go.Scatter(
                x=close_dates.tolist() + close_dates[::-1].tolist(),
                y=[row[f'Q{i}_T{j + 1}'] for j in range(5)] + [row[f'Q{i + 1}_T{j + 1}'] for j in range(5)][::-1],
                fill='toself',
                fillcolor=quantile_colors[i-1],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False
            )
            fills.append(fill)

        figure = {
            'data': fills + quantile_traces + [median_trace, actual_trace],
            'layout': {
                'title': f'Prediction Detail for Index {prediction_index}',
                'xaxis': {'title': 'Date'},
                'yaxis': {'title': 'Price'}
            }
        }

    return html.Div(
        children=[html.Div(dcc.Graph(figure=figure), style={'flex': '1 0 100%', 'margin': '10px'})],
        style={'display': 'flex', 'flexWrap': 'wrap', 'justifyContent': 'center'}
    )

@app.callback(
    Output('error-table', 'children'),
    [Input('error-type-dropdown', 'value'),
     Input('model-dropdown', 'value')]
)
def update_error_table(error_type, model_value):
    df = datasets[model_value][error_type]
    
    return html.Table(
        [
            html.Thead(html.Tr([html.Th(col, style={'textAlign': 'center'}) for col in df.columns])),
            html.Tbody([
                html.Tr([
                    html.Td(df.iloc[i][col], style={'textAlign': 'center'}) for col in df.columns
                ]) for i in range(len(df))
            ])
        ],
        style={'width': '100%', 'margin': '20px 0', 'textAlign': 'center'}
    )


if __name__ == '__main__':
    app.run_server(debug=True, port=8050)



DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.

