In [None]:
import os
from datetime import datetime, timedelta
import ipywidgets as widgets
import plotly.graph_objs as go
import yfinance as yf
import pandas as pd
from IPython.display import display

In [None]:
# Bar sizes offered in the Interval dropdown, finest first.
interval_opts = (
    # Intraday bars.
    ["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h"]
    # Daily and coarser bars.
    + ["1d", "5d", "1wk", "1mo", "3mo"]
)

In [None]:
class Chart:
    """Interactive correlation-matrix table for a set of Yahoo Finance tickers.

    ``create`` is meant to be driven by ``ipywidgets.interactive_output``.
    Downloaded prices are cached on the instance (``self.df``) and re-fetched
    only when the tickers, the interval, or the Start/End window change. The
    window matters only for intraday intervals, whose download is bounded by
    it.
    """

    # Intervals downloaded with period="max" and clipped to Start/End locally;
    # every other interval is downloaded for the Start/End window only.
    _DAILY_INTERVALS = ("1d", "5d", "1wk", "1mo", "3mo")

    def __init__(self):
        self.last_tickers = ""
        self.last_interval = "1d"
        # Window of the last intraday download. It is (None, None) after a
        # daily download, because daily data is filtered locally instead.
        self.last_start = None
        self.last_end = None
        self.df = pd.DataFrame()

    def create(self, data, start, end, interval, tickers):
        """Download prices if needed and render their correlation table.

        Parameters
        ----------
        data : str
            Price column to correlate (e.g. "Close").
        start, end : datetime.date or None
            Date window. For daily-and-coarser intervals both ends are
            inclusive. Nothing is rendered while either picker is empty.
        interval : str
            yfinance bar interval.
        tickers : str
            Comma-separated tickers. Nothing happens until the text ends with
            a comma, so partially typed symbols are never sent to yfinance.

        Returns
        -------
        None
            The table is shown via ``fig.show``.
        """
        if not tickers or not tickers.endswith(","):
            return
        if start is None or end is None:
            # A DatePicker was cleared: there is no window to fetch or filter.
            return

        self._refresh(start, end, interval, tickers)

        if data not in self.df:
            # Empty/failed download: report instead of raising KeyError.
            print(f"No {data} data available for: {tickers}")
            return
        prices = self.df[data]
        if isinstance(prices, pd.Series):
            # A single column has no pairwise correlations to show.
            return

        self._show(self._correlation(prices, start, end, interval))

    def _refresh(self, start, end, interval, tickers):
        """Re-download into ``self.df`` when tickers, interval or intraday window changed."""
        daily = interval in self._DAILY_INTERVALS
        # Intraday data is downloaded for [start, end] only, so a new window
        # needs new data. Daily data covers all history and is clipped later.
        window = (None, None) if daily else (start, end)
        if (
            tickers == self.last_tickers
            and interval == self.last_interval
            and window == (self.last_start, self.last_end)
        ):
            return

        if daily:
            self.df = yf.download(
                tickers, period="max", interval=interval, progress=False
            )
        else:
            self.df = yf.download(
                tickers, start=start, end=end, interval=interval, progress=False
            )
        self.last_tickers = tickers
        self.last_interval = interval
        self.last_start, self.last_end = window

    def _correlation(self, prices, start, end, interval):
        """Pairwise column correlation; daily-and-coarser rows are clipped to [start, end]."""
        if interval in self._DAILY_INTERVALS:
            start_n = datetime(start.year, start.month, start.day)
            end_n = datetime(end.year, end.month, end.day)
            prices = prices.loc[(prices.index >= start_n) & (prices.index <= end_n)]
        return prices.corr()

    @staticmethod
    def _cell_colors(corr):
        """Column-major fill colours: a grey label column, then one list per ticker column.

        A value of exactly 1 is black, a positive value is green, and anything
        else (including NaN) is pink.
        """
        label_column = ["lightgray"] * corr.shape[0]
        value_columns = [
            ["black" if x == 1 else "lightgreen" if x > 0 else "lightpink" for x in column]
            for column in corr.to_numpy().T.tolist()
        ]
        return [label_column] + value_columns

    @staticmethod
    def _show(corr):
        """Render ``corr`` as a Plotly table with the row labels in the first column."""
        labels = [str(name) for name in corr.index]
        header = [""] + [str(name) for name in corr.columns]
        value_columns = corr.to_numpy().T.tolist()
        fig = go.Figure(
            data=[
                go.Table(
                    header=dict(
                        values=header,
                        fill_color="lightgray",
                        font=dict(color="black"),
                        align="left",
                    ),
                    cells=dict(
                        values=[labels] + value_columns,
                        fill_color=Chart._cell_colors(corr),
                        # One entry per displayed column: raw labels, then 2 d.p.
                        format=[""] + [".2f"] * len(value_columns),
                        font=dict(color="black"),
                        align="left",
                    ),
                )
            ],
        )
        fig.update_layout(autosize=True, height=600, showlegend=False)
        # When SERVER_SOFTWARE reports Voila, force the "notebook" renderer.
        if os.environ.get("SERVER_SOFTWARE", "jupyter").startswith("voila"):
            fig.show(config={"showTips": False}, renderer="notebook")
        else:
            fig.show(config={"showTips": False})


# Every control sizes itself to its content.
w_auto = widgets.Layout(width="auto")

# --- Left column: which tickers, and which price field to correlate --------
data_opts = ["Open", "Close", "High", "Low", "Volume"]
data_widget = widgets.Dropdown(
    description="Data", options=data_opts, value="Close", layout=w_auto
)
tickers_widget = widgets.Textarea(value="TSLA,AAPL,", layout=w_auto)
data_box = widgets.VBox(children=[tickers_widget, data_widget])

# --- Right column: date window (default: the past year) and bar interval ---
base_date = (datetime.today() - timedelta(days=365)).date()
start_widget = widgets.DatePicker(description="Start", value=base_date, layout=w_auto)
end_widget = widgets.DatePicker(
    description="End", value=datetime.today().date(), layout=w_auto
)
interval_widget = widgets.Dropdown(
    description="Interval", options=interval_opts, value="1d", layout=w_auto
)
date_box = widgets.VBox(children=[start_widget, end_widget, interval_widget])

controls = widgets.HBox(
    children=[data_box, date_box], layout=widgets.Layout(width="60%")
)

# --- Re-render the correlation table whenever any control changes ----------
chart = Chart()
stocks_view = widgets.interactive_output(
    chart.create,
    dict(
        data=data_widget,
        start=start_widget,
        end=end_widget,
        interval=interval_widget,
        tickers=tickers_widget,
    ),
)

# --- Assemble the page top-to-bottom and render it ---------------------------
title_html = "<h1>Correlation Analysis Dashboard</h1>"
warning_html = '<p style="color:red">Use a comma after EVERY stock typed.</p>'
app_contents = [
    widgets.HTML(title_html),
    controls,
    widgets.HTML(warning_html),
    stocks_view,
]
app = widgets.VBox(children=app_contents)
display(app)