# Stock Comparison Tool

Trying to implement something similar to https://www.barchart.com/myBarchart/quotes/SPY/interactive-chart but in Gradio.

The code in the first code block of this notebook is executed when the `app.py` file is run. This notebook is used for the simplicity of interactive development and testing.

In [1]:
from datetime import timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import yfinance as yf

# retrieve historical stock prices
tickers = ["AAPL", "MSFT", "GOOGL"]
df = yf.download(
    tickers,
    interval="1d",
    period="max",
    progress=False,
).Close.bfill()
df

Ticker,AAPL,GOOGL,MSFT
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1980-12-12,0.098834,2.501941,0.059946
1980-12-15,0.093678,2.501941,0.059946
1980-12-16,0.086802,2.501941,0.059946
1980-12-17,0.088951,2.501941,0.059946
1980-12-18,0.091530,2.501941,0.059946
...,...,...,...
2024-12-26,259.019989,195.600006,438.109985
2024-12-27,255.589996,192.759995,430.529999
2024-12-30,252.199997,191.240005,424.829987
2024-12-31,250.419998,189.300003,421.500000


### Gradio plot -- this works but is slow to update

In [2]:
def plot_asset_prices(period, shift):
    end_date = df.index[-1] - timedelta(days=shift)
    start_date = end_date - timedelta(days=period)
    df_normalized = df[(df.index >= start_date) & (df.index <= end_date)]
    df_normalized = df_normalized / df_normalized.iloc[0] - 1
    df_normalized.reset_index(names="Date", inplace=True)
    return gr.LinePlot(
        value=pd.melt(df_normalized, id_vars=["Date"], var_name="Asset", value_name="Price"),
        x="Date",
        y="Price",
        color="Asset",
        title="Normalized Asset Prices",
        y_title="Relative Change",
        x_label_angle=45,
        # height=600,
    )


with gr.Blocks() as demo:
    plot = plot_asset_prices(365, 0)
    with gr.Row():
        period = gr.Radio(
            choices=[
                ("5y", 5 * 365),
                ("3y", 2 * 365),
                ("2y", 2 * 365),
                ("1y", 365),
                ("6mo", 182),
                ("1mo", 30),
                ("1w", 7),
            ],
            value=365,
            label="Period",
        )
        shift = gr.Slider(minimum=0, maximum=365, value=0, label="End Date")
    period.change(plot_asset_prices, inputs=[period, shift], outputs=plot)
    shift.change(plot_asset_prices, inputs=[period, shift], outputs=plot)

demo.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




### Plotly plot -- this is super responsive but does not update

In [3]:
dg = pd.DataFrame(np.random.randn(1000, 3).cumsum(axis=0), columns=["A", "B", "C"])

fig = go.Figure()
for col in dg:
    fig.add_trace(go.Scatter(x=dg.index, y=dg[col], mode="lines", name=col))
fig.update_layout(
    xaxis_rangeslider_visible=True,
)

In [4]:
# import numpy as np
# import plotly.graph_objects as go
# from dash import Dash, Input, Output, State, dcc, html

# # Generate sample data
# x = np.linspace(0, 100, 1000)
# y1 = np.sin(x)
# y2 = np.cos(x)

# # Initialize Dash app
# app = Dash(__name__)

# # Layout of the Dash app
# app.layout = html.Div(
#     [
#         dcc.Graph(id="interactive-plot"),
#     ]
# )


# @app.callback(
#     Output("interactive-plot", "figure"),
#     Input("interactive-plot", "relayoutData"),
#     State("interactive-plot", "figure"),  # Preserve current state to avoid resetting
# )
# def update_plot(relayout_data, current_figure):
#     # Determine current x-axis range
#     if relayout_data and "xaxis.range[0]" in relayout_data and "xaxis.range[1]" in relayout_data:
#         x_start = relayout_data["xaxis.range[0]"]
#         x_end = relayout_data["xaxis.range[1]"]
#     elif current_figure:  # Use the existing figure's range if available
#         x_start, x_end = current_figure["layout"]["xaxis"]["range"]
#     else:
#         x_start, x_end = 0, 20  # Default initial range

#     # Filter and shift data for the selected range
#     mask = (x >= x_start) & (x <= x_end)
#     x_window = x[mask]
#     y1_window = y1[mask]
#     y2_window = y2[mask]

#     if len(x_window) > 0:  # Avoid errors with empty slices
#         y1_shifted = y1_window - y1_window[0]
#         y2_shifted = y2_window - y2_window[0]
#     else:
#         y1_shifted, y2_shifted = [], []  # Empty data if the range is invalid

#     # Create the updated figure
#     fig = go.Figure()
#     fig.add_trace(go.Scatter(x=x_window, y=y1_shifted, mode="lines", name="sin(x)"))
#     fig.add_trace(go.Scatter(x=x_window, y=y2_shifted, mode="lines", name="cos(x)"))

#     # Keep the `rangeslider` showing the full range of the data
#     fig.update_layout(
#         xaxis=dict(
#             rangeslider=dict(visible=True, range=[x[0], x[-1]]),  # Full data extent
#             range=[x_start, x_end],  # Main plot range matches user interaction
#         ),
#         title="Interactive Plot with Dynamic Data Shifting",
#     )

#     return fig


# if __name__ == "__main__":
#     app.run_server(debug=True, port=8051)

In [5]:
# import numpy as np
# import plotly.graph_objects as go
# from dash import Dash, Input, Output, State, dcc, html

# # Generate sample data
# x = np.linspace(0, 100, 1000)
# y1 = np.sin(x)
# y2 = np.cos(x)

# # Initialize Dash app
# app = Dash(__name__)

# # Layout of the Dash app
# app.layout = html.Div(
#     [
#         dcc.Graph(id="interactive-plot"),
#     ]
# )


# @app.callback(
#     Output("interactive-plot", "figure"),
#     Input("interactive-plot", "relayoutData"),
#     State("interactive-plot", "figure"),  # Preserve current state to avoid resetting
# )
# def update_plot(relayout_data, current_figure):
#     # Determine current x-axis range
#     if relayout_data and "xaxis.range[0]" in relayout_data and "xaxis.range[1]" in relayout_data:
#         x_start = relayout_data["xaxis.range[0]"]
#         x_end = relayout_data["xaxis.range[1]"]
#     elif current_figure:  # Use the existing figure's range if available
#         x_start, x_end = current_figure["layout"]["xaxis"]["range"]
#     else:
#         x_start, x_end = 0, 20  # Default initial range

#     # Filter and shift data for the selected range
#     mask = (x >= x_start) & (x <= x_end)
#     x_window = x[mask]
#     y1_window = y1[mask]
#     y2_window = y2[mask]

#     if len(x_window) > 0:  # Avoid errors with empty slices
#         y1_shifted = y1_window - y1_window[0]
#         y2_shifted = y2_window - y2_window[0]
#     else:
#         y1_shifted, y2_shifted = [], []  # Empty data if the range is invalid

#     # Create the updated figure
#     fig = go.Figure()

#     # Main plot: Add shifted data
#     fig.add_trace(go.Scatter(x=x_window, y=y1_shifted, mode="lines", name="Shifted sin(x)"))
#     fig.add_trace(go.Scatter(x=x_window, y=y2_shifted, mode="lines", name="Shifted cos(x)"))

#     # Rangeslider: Add static unshifted data
#     fig.update_layout(
#         xaxis=dict(
#             rangeslider=dict(visible=True, range=[x[0], x[-1]], thickness=0.1),  # Full data extent
#             range=[x_start, x_end],  # Main plot range matches user interaction
#         ),
#         title="Interactive Plot with Static Rangeslider",
#     )

#     # Add static traces for the rangeslider (unshifted data)
#     fig.update_layout(
#         xaxis_rangeslider=dict(
#             visible=True,
#             bgcolor="lightgray",  # Optional: distinguish rangeslider background
#         )
#     )
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y1,
#             mode="lines",
#             name="Original sin(x)",
#             line=dict(width=0.5),
#             opacity=0.5,
#             showlegend=False,
#         )
#     )
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y2,
#             mode="lines",
#             name="Original cos(x)",
#             line=dict(width=0.5),
#             opacity=0.5,
#             showlegend=False,
#         )
#     )

#     return fig


# if __name__ == "__main__":
#     app.run_server(debug=True, port=8052)

In [6]:
# import numpy as np
# import plotly.graph_objects as go
# from dash import Dash, Input, Output, State, dcc, html

# # Generate sample data
# x = np.linspace(0, 100, 1000)
# y1 = np.sin(x)
# y2 = np.cos(x)

# # Initialize Dash app
# app = Dash(__name__)

# # Layout of the Dash app
# app.layout = html.Div(
#     [
#         dcc.Graph(id="interactive-plot"),
#     ]
# )


# @app.callback(
#     Output("interactive-plot", "figure"),
#     Input("interactive-plot", "relayoutData"),
#     State("interactive-plot", "figure"),  # Preserve current state to avoid resetting
# )
# def update_plot(relayout_data, current_figure):
#     # Determine current x-axis range
#     if relayout_data and "xaxis.range[0]" in relayout_data and "xaxis.range[1]" in relayout_data:
#         x_start = relayout_data["xaxis.range[0]"]
#         x_end = relayout_data["xaxis.range[1]"]
#     elif current_figure:  # Use the existing figure's range if available
#         x_start, x_end = current_figure["layout"]["xaxis"]["range"]
#     else:
#         x_start, x_end = 0, 20  # Default initial range

#     # Filter and shift data for the selected range
#     mask = (x >= x_start) & (x <= x_end)
#     x_window = x[mask]
#     y1_window = y1[mask]
#     y2_window = y2[mask]

#     if len(x_window) > 0:  # Avoid errors with empty slices
#         y1_shifted = y1_window - y1_window[0]
#         y2_shifted = y2_window - y2_window[0]
#     else:
#         y1_shifted, y2_shifted = [], []  # Empty data if the range is invalid

#     # Create the updated figure
#     fig = go.Figure()

#     # Main plot: Add shifted data
#     fig.add_trace(
#         go.Scatter(x=x_window, y=y1_shifted, mode="lines", name="Shifted sin(x)", showlegend=True)
#     )
#     fig.add_trace(
#         go.Scatter(x=x_window, y=y2_shifted, mode="lines", name="Shifted cos(x)", showlegend=True)
#     )

#     # Ensure rangeslider shows the original unshifted data
#     fig.update_layout(
#         xaxis=dict(
#             rangeslider=dict(
#                 visible=True,
#                 range=[x[0], x[-1]],  # Full data extent
#                 thickness=0.1,
#                 yaxis=dict(
#                     range=[
#                         min(min(y1), min(y2)),
#                         max(max(y1), max(y2)),
#                     ]  # Ensure rangeslider captures full y-range
#                 ),
#             ),
#             range=[x_start, x_end],  # Main plot range matches user interaction
#         ),
#         title="Interactive Plot with Separate Rangeslider",
#     )

#     # Add static traces for rangeslider
#     fig.update_xaxes(rangeslider=dict(visible=True))
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y1,
#             mode="lines",
#             name="Original sin(x)",
#             showlegend=False,
#             line=dict(width=1),
#             opacity=0.3,
#             visible=True,
#         )
#     )
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y2,
#             mode="lines",
#             name="Original cos(x)",
#             showlegend=False,
#             line=dict(width=1),
#             opacity=0.3,
#             visible=True,
#         )
#     )

#     return fig


# if __name__ == "__main__":
#     app.run_server(debug=True, port=8053)

In [7]:
# import numpy as np
# import plotly.graph_objects as go
# from dash import Dash, Input, Output, State, dcc, html

# # Generate sample data
# x = np.linspace(0, 100, 1000)
# y1 = np.sin(x)
# y2 = np.cos(x)

# # Initialize Dash app
# app = Dash(__name__)

# # Layout of the Dash app
# app.layout = html.Div(
#     [
#         dcc.Graph(id="interactive-plot"),
#     ]
# )


# @app.callback(
#     Output("interactive-plot", "figure"),
#     Input("interactive-plot", "relayoutData"),
#     State("interactive-plot", "figure"),  # Preserve current state to avoid resetting
# )
# def update_plot(relayout_data, current_figure):
#     # Determine current x-axis range
#     if relayout_data and "xaxis.range[0]" in relayout_data and "xaxis.range[1]" in relayout_data:
#         x_start = relayout_data["xaxis.range[0]"]
#         x_end = relayout_data["xaxis.range[1]"]
#     elif current_figure:  # Use the existing figure's range if available
#         x_start, x_end = current_figure["layout"]["xaxis"]["range"]
#     else:
#         x_start, x_end = 0, 20  # Default initial range

#     # Filter and shift data for the selected range
#     mask = (x >= x_start) & (x <= x_end)
#     x_window = x[mask]
#     y1_window = y1[mask]
#     y2_window = y2[mask]

#     if len(x_window) > 0:  # Avoid errors with empty slices
#         y1_shifted = y1_window - y1_window[0]
#         y2_shifted = y2_window - y2_window[0]
#     else:
#         y1_shifted, y2_shifted = [], []  # Empty data if the range is invalid

#     # Create the updated figure
#     fig = go.Figure()

#     # Main plot: Add shifted data
#     fig.add_trace(
#         go.Scatter(x=x_window, y=y1_shifted, mode="lines", name="Shifted sin(x)", showlegend=True)
#     )
#     fig.add_trace(
#         go.Scatter(x=x_window, y=y2_shifted, mode="lines", name="Shifted cos(x)", showlegend=True)
#     )

#     # Configure rangeslider to only show unshifted lines
#     fig.update_layout(
#         xaxis=dict(
#             rangeslider=dict(visible=True, range=[x[0], x[-1]], thickness=0.1),  # Full data extent
#             range=[x_start, x_end],  # Main plot range matches user interaction
#         ),
#         title="Interactive Plot with Static Rangeslider",
#     )

#     # Add unshifted traces specifically for the rangeslider
#     fig.update_traces(
#         visible=True, selector=dict(name="Shifted sin(x)")
#     )  # Hide shifted lines in the rangeslider
#     fig.update_traces(
#         visible=True, selector=dict(name="Shifted cos(x)")
#     )  # Hide shifted lines in the rangeslider

#     # Add original unshifted traces explicitly
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y1,
#             mode="lines",
#             name="Original sin(x)",
#             showlegend=False,
#             line=dict(width=1),
#             opacity=0.3,
#         )
#     )
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y2,
#             mode="lines",
#             name="Original cos(x)",
#             showlegend=False,
#             line=dict(width=1),
#             opacity=0.3,
#         )
#     )

#     return fig


# if __name__ == "__main__":
#     app.run_server(debug=True, port=8054)

### Plotly MVE -- finally

In [26]:
import numpy as np
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, dcc, html

x = np.linspace(0, 100, 1000)
y = np.sin(6 * x) + 10  # raw
z = np.cos(x) * 2  # normalized


def plot_prices(x, y, z, x_range=(None, None)):
    fig = go.Figure()

    # rangeslider plot
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            xaxis="x1",
            yaxis="y1",
            showlegend=False,
        )
    )

    # main plot
    fig.add_trace(
        go.Scatter(
            x=x,
            y=z,
            name="normalized asset prices",
            xaxis="x2",
            yaxis="y2",
        )
    )

    # configure axes
    if all(x_range):
        xaxis1_dict = dict(
            rangeslider=dict(visible=True),
            range=x_range,
        )
        xaxis2_dict = dict(
            matches="x1",
            range=x_range,
        )
    else:
        xaxis1_dict = dict(rangeslider=dict(visible=True))
        xaxis2_dict = dict(matches="x1")

    yaxis1_dict = dict(
        range=[9, 11],
        showticklabels=False,
    )
    yaxis2_dict = dict(title="Relative Change", range=[-2, 2])

    fig.update_layout(
        xaxis1=xaxis1_dict,
        yaxis1=yaxis1_dict,
        xaxis2=xaxis2_dict,
        yaxis2=yaxis2_dict,
    )

    return fig

In [27]:
def get_x_range(relayout_data, current_figure, x_range_default=(None, None)):
    # Default
    x_start, x_end = x_range_default

    if relayout_data:
        # Check the main axis first (common when user zooms the top plot)
        if "xaxis2.range[0]" in relayout_data:
            x_start = relayout_data["xaxis2.range[0]"]
            x_end = relayout_data["xaxis2.range[1]"]

        # If not found, check for "xaxis1.range[0]" (if Plotly is labeling the slider that way)
        elif "xaxis1.range[0]" in relayout_data:
            x_start = relayout_data["xaxis1.range[0]"]
            x_end = relayout_data["xaxis1.range[1]"]

        # If still not found, it might be "xaxis.range[0]" for the slider
        elif "xaxis.range[0]" in relayout_data:
            x_start = relayout_data["xaxis.range[0]"]
            x_end = relayout_data["xaxis.range[1]"]
        else:
            # If no recognized keys are present, we try to read from the figure layout
            layout = current_figure["layout"]
            if layout.get("xaxis2", {}).get("range"):
                x_start, x_end = layout["xaxis2"]["range"]
            elif layout.get("xaxis1", {}).get("range"):
                x_start, x_end = layout["xaxis1"]["range"]
            else:
                x_start, x_end = x_range_default
    else:
        # No relayout_data at all, read from the figure layout
        layout = current_figure["layout"]
        if layout.get("xaxis2", {}).get("range"):
            x_start, x_end = layout["xaxis2"]["range"]
        elif layout.get("xaxis1", {}).get("range"):
            x_start, x_end = layout["xaxis1"]["range"]
        else:
            x_start, x_end = x_range_default

    return x_start, x_end

In [28]:
app = Dash(__name__)

app.layout = html.Div([dcc.Graph(id="plotly-normalized-asset-prices", figure=plot_prices(x, y, z))])


@app.callback(
    Output("plotly-normalized-asset-prices", "figure"),
    Input("plotly-normalized-asset-prices", "relayoutData"),
    State("plotly-normalized-asset-prices", "figure"),
)
def update_plot(relayout_data, current_figure):
    """Preserve whichever x-range the user has set by zoom/pan or slider."""
    x_start, x_end = get_x_range(relayout_data, current_figure)
    print(f"{x_start=}, {x_end=}")

    return plot_prices(x, y, z, (x_start, x_end))


if __name__ == "__main__":
    app.run_server(debug=True)

x_start=None, x_end=None
x_start=0, x_end=100
x_start=14.884516680923868, x_end=64.41402908468777
x_start=14.884516680923868, x_end=64.41402908468777
x_start=15.055603079555175, x_end=64.58511548331909
x_start=16.5098374679213, x_end=66.03934987168522
x_start=18.81950384944397, x_end=68.34901625320788
x_start=20.35928143712575, x_end=69.88879384088966
x_start=21.813515825491873, x_end=71.3430282292558
x_start=23.86655260906758, x_end=73.39606501283149
x_start=24.978614200171087, x_end=74.508126603935
x_start=25.66295979469632, x_end=75.19247219846024
x_start=26.43284858853721, x_end=75.96236099230113
x_start=27.11719418306245, x_end=76.64670658682635
x_start=27.97262617621899, x_end=77.5021385799829
x_start=28.656971770744228, x_end=78.18648417450814
x_start=29.683490162532078, x_end=79.21300256629598
x_start=30.624465355004276, x_end=80.15397775876819
x_start=31.479897348160822, x_end=81.00940975192474
x_start=31.822070145423435, x_end=81.35158254918736
x_start=31.993156544054745, x_e