# Stock Comparison Tool

An attempt to implement something similar to Barchart's interactive chart (https://www.barchart.com/myBarchart/quotes/SPY/interactive-chart), but in Gradio.

The code in the first code cell of this notebook is what runs when `app.py` is executed. The notebook itself exists only to make interactive development and testing easier.

In [1]:
from datetime import timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import yfinance as yf

# Daily closing prices for every ticker over its full available history.
# NOTE: bfill() copies each ticker's earliest close backwards over dates
# before that ticker has data, which produces flat placeholder stretches
# (see the constant GOOGL column in the 1980 rows of the output below).
tickers = ["AAPL", "MSFT", "GOOGL"]
df = yf.download(
    tickers,
    period="max",
    interval="1d",
    progress=False,
)["Close"].bfill()
df

Ticker,AAPL,GOOGL,MSFT
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1980-12-12,0.098834,2.501941,0.059947
1980-12-15,0.093678,2.501941,0.059947
1980-12-16,0.086802,2.501941,0.059947
1980-12-17,0.088951,2.501941,0.059947
1980-12-18,0.091530,2.501941,0.059947
...,...,...,...
2024-12-27,255.589996,192.759995,430.529999
2024-12-30,252.199997,191.240005,424.829987
2024-12-31,250.419998,189.300003,421.500000
2025-01-02,243.850006,189.429993,418.579987


### Gradio plot -- this works but is slow to update

In [2]:
def plot_asset_prices(period, shift):
    """Build a Gradio line plot of ticker prices rebased to a time window.

    The window ends ``shift`` days before the last date in the global ``df``
    and starts ``period`` days before that end date (both ends inclusive).
    Every column is rebased to its first row in the window, so the plotted
    value is the relative change since that day (0.0 on the first day).

    Args:
        period: Window length in days.
        shift: Days to move the window's end back from the latest date in
            ``df``.

    Returns:
        gr.LinePlot: long-format data with ``Date``, ``Asset`` and ``Price``
        columns, one line per ticker.
    """
    end_date = df.index[-1] - timedelta(days=shift)
    start_date = end_date - timedelta(days=period)
    window = df[(df.index >= start_date) & (df.index <= end_date)]

    if window.empty:
        # Nothing to rebase against; show an empty plot instead of raising
        # IndexError on ``iloc[0]``.
        relative_change = window
    else:
        relative_change = window / window.iloc[0] - 1

    # Long format (one row per date/ticker) is what LinePlot's ``color``
    # grouping expects.
    long_format = pd.melt(
        relative_change.reset_index(names="Date"),
        id_vars=["Date"],
        var_name="Asset",
        value_name="Price",
    )
    return gr.LinePlot(
        value=long_format,
        x="Date",
        y="Price",
        color="Asset",
        title="Normalized Asset Prices",
        y_title="Relative Change",
        x_label_angle=45,
    )


# Default window shown on load; shared by the initial plot and the radio so
# the two always agree.
DEFAULT_PERIOD_DAYS = 365

# Gradio UI: a period selector and an end-date offset slider, both of which
# re-render the plot through ``plot_asset_prices``.
with gr.Blocks() as demo:
    plot = plot_asset_prices(DEFAULT_PERIOD_DAYS, 0)
    with gr.Row():
        # (label, window length in days) pairs.
        period = gr.Radio(
            choices=[
                ("5y", 5 * 365),
                ("3y", 3 * 365),
                ("2y", 2 * 365),
                ("1y", 365),
                ("6mo", 182),
                ("1mo", 30),
                ("1w", 7),
            ],
            value=DEFAULT_PERIOD_DAYS,
            label="Period",
        )
        # Number of days to move the window's end back from the latest date.
        shift = gr.Slider(minimum=0, maximum=365, value=0, label="End Date")
    period.change(plot_asset_prices, inputs=[period, shift], outputs=plot)
    shift.change(plot_asset_prices, inputs=[period, shift], outputs=plot)

demo.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




### Plotly plot -- very responsive, but the plotted values do not update (re-normalize) when the visible range changes

In [3]:
# Synthetic random walks (one column per series) for trying the Plotly range
# slider without downloading data. Seeded so Restart & Run All reproduces
# the same figure.
rng = np.random.default_rng(0)
dg = pd.DataFrame(
    rng.standard_normal((1000, 3)).cumsum(axis=0),
    columns=["A", "B", "C"],
)

fig = go.Figure()
for col in dg.columns:
    fig.add_trace(go.Scatter(x=dg.index, y=dg[col], mode="lines", name=col))
fig.update_layout(
    xaxis_rangeslider_visible=True,
)

### Plotly + Dash minimal example -- finally re-normalizes on zoom/pan

In [4]:
import numpy as np
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, dcc, html


def plot_prices(x, y1, y2, x_range=(None, None)):
    """Build the two-trace figure used by the Dash app.

    The first trace sits on the ``x``/``y`` axis pair, which carries the range
    slider. The second ("normalized asset prices") sits on ``x2``/``y2``, and
    its x axis is tied to the first through ``matches``.

    Args:
        x: Shared x values for both traces.
        y1: Values for the range-slider trace (its y tick labels are hidden).
        y2: Values for the main trace.
        x_range: ``(start, end)`` to show on both x axes. If it is ``None`` or
            either bound is ``None``, no range is set and Plotly autoranges.

    Returns:
        go.Figure: the assembled figure.
    """
    fig = go.Figure()

    # Range-slider trace: no legend entry.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y1,
            xaxis="x1",
            yaxis="y1",
            showlegend=False,
        )
    )

    # Main trace.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y2,
            name="normalized asset prices",
            xaxis="x2",
            yaxis="y2",
        )
    )

    xaxis1_dict = dict(rangeslider=dict(visible=True))
    xaxis2_dict = dict(matches="x1")
    # Compare bounds against None explicitly: ``all(x_range)`` would also
    # treat a legitimate bound of 0 as "no range" and drop the user's zoom.
    if x_range is not None and all(bound is not None for bound in x_range):
        xaxis1_dict["range"] = x_range
        xaxis2_dict["range"] = x_range

    fig.update_layout(
        xaxis1=xaxis1_dict,
        yaxis1=dict(showticklabels=False),
        xaxis2=xaxis2_dict,
        yaxis2=dict(title="Relative Change"),
        # A constant uirevision asks Plotly to keep client-side UI state
        # (e.g. zoom) across the figures the Dash callback sends back.
        uirevision="constant-key",
    )

    return fig

In [5]:
def get_x_range(relayout_data, current_figure, x_range=(None, None)):
    """Return the x-range the user is currently viewing.

    Sources are checked in this order:

    1. ``relayout_data`` autorange flags (e.g. double-click / reset axes):
       returns ``(None, None)`` so the caller autoranges.
    2. ``relayout_data`` range keys for ``xaxis2``, ``xaxis1``, then ``xaxis``.
       These appear either as separate ``<axis>.range[0]``/``<axis>.range[1]``
       keys (zoom/pan) or as a single ``<axis>.range`` pair (e.g. range-slider
       drag).
    3. The ``range`` stored in ``current_figure["layout"]`` for ``xaxis2``,
       ``xaxis``, then ``xaxis1``.

    Args:
        relayout_data: The graph's ``relayoutData`` dict, or ``None``/empty.
        current_figure: The graph's current figure dict (``None`` tolerated).
        x_range: Value returned when no range can be found.

    Returns:
        tuple: ``(start, end)``; ``(None, None)`` after an autorange; otherwise
        ``x_range`` when nothing was found.
    """
    axis_names = ("xaxis2", "xaxis1", "xaxis")

    if relayout_data:
        if any(relayout_data.get(f"{axis}.autorange") for axis in axis_names):
            # Honour a reset instead of re-applying the stale range that is
            # still stored in ``current_figure``.
            return (None, None)
        for axis in axis_names:
            start_key, end_key = f"{axis}.range[0]", f"{axis}.range[1]"
            if start_key in relayout_data and end_key in relayout_data:
                return (relayout_data[start_key], relayout_data[end_key])
            pair = relayout_data.get(f"{axis}.range")
            if pair:
                return tuple(pair)

    # Nothing usable in the event: fall back to the figure's stored layout.
    layout = (current_figure or {}).get("layout") or {}
    # NOTE(review): Plotly may store ``xaxis1`` under ``xaxis``; check both.
    for axis in ("xaxis2", "xaxis", "xaxis1"):
        stored = (layout.get(axis) or {}).get("range")
        if stored:
            return tuple(stored)

    return x_range

In [6]:
def normalize_prices(x, y, z, x_range):
    """Rebase ``z`` so it reads 0 at the first sample inside ``x_range``.

    Args:
        x: Sample positions aligned with ``z``. Assumed sorted ascending,
            because ``np.searchsorted`` requires it.
        y: Unused; kept for call-signature compatibility.
        z: Values to rebase.
        x_range: ``(start, end)`` of the visible window. If it is ``None`` or
            either bound is ``None``, no range is set.

    Returns:
        ``z`` unchanged when no range is set or ``z`` is empty. Otherwise
        ``z - z[i]``, where ``i`` is the first index with ``x[i] >= start``,
        clamped to the last sample.
    """
    if x_range is None or any(bound is None for bound in x_range) or len(z) == 0:
        return z
    # Find the first visible sample from the data itself instead of assuming
    # 10 samples per x unit. Clamping means a view starting past the last
    # sample neither raises IndexError nor (for a negative start) wraps
    # around to the end of the array.
    idx = int(np.searchsorted(x, x_range[0], side="left"))
    idx = min(idx, len(z) - 1)
    return z - z[idx]

In [7]:
def setup_app(x, y, z):
    """Create a Dash app whose price plot re-normalizes on zoom/pan/slider.

    Args:
        x: Shared x values, passed through to ``plot_prices``.
        y: Values for the range-slider trace.
        z: Values for the main trace, before rebasing.

    Returns:
        The configured ``Dash`` app; serving it is left to the caller.
    """
    graph_id = "plotly-normalized-asset-prices"
    app = Dash(__name__)
    initial_figure = plot_prices(x, y, z)
    app.layout = html.Div([dcc.Graph(id=graph_id, figure=initial_figure)])

    def update_plot(relayout_data, current_figure):
        """Re-render with ``z`` rebased to the x-range the user is viewing."""
        visible_range = get_x_range(relayout_data, current_figure)
        rebased = normalize_prices(x, y, z, visible_range)
        return plot_prices(x, y, rebased, visible_range)

    # Register the callback explicitly (equivalent to the decorator form).
    app.callback(
        Output(graph_id, "figure"),
        Input(graph_id, "relayoutData"),
        State(graph_id, "figure"),
    )(update_plot)

    return app

In [8]:
# Synthetic demo data on a shared x grid.
x = np.linspace(0, 100, 1000)
y = np.sin(6 * x) + 10  # stand-in for raw prices (range-slider trace)
z = np.cos(x) * 2  # stand-in for normalized prices (main trace)

app = setup_app(x, y, z)
# ``run_server`` is deprecated in Dash 2.x and removed in Dash 3;
# ``run`` is the supported replacement.
app.run(debug=True)