In [None]:
import os
from datetime import datetime, timedelta
import ipywidgets as widgets
import plotly.graph_objs as go
import yfinance as yf
import pandas as pd
from IPython.display import display

In [None]:
# Sampling intervals offered by the "Interval" dropdown.
interval_opts = [
    "1m", "2m", "5m", "15m", "30m", "60m", "90m",
    "1h", "1d", "5d", "1wk", "1mo", "3mo",
]

# Short display label for every info field shown in the summary table.
clean_row = dict(
    sector="Sector",
    marketCap="M Cap",
    beta="Beta",
    fiftyTwoWeekHigh="52W High",
    fiftyTwoWeekLow="52W Low",
    floatShares="Floats",
    sharesShort="Shorts",
    exDividendDate="Ex-Div",
)

# Info fields in table order (dicts keep insertion order, so this mirrors
# the order of ``clean_row``).
rows = list(clean_row)


def _rolling_mean(x, y):
    """Mean of ``x`` over a rolling window of ``y`` observations."""
    return x.rolling(y).mean()


def _rolling_var(x, y):
    """Variance of ``x`` over a rolling window of ``y`` observations."""
    return x.rolling(y).var()


def _rolling_std(x, y):
    """Square root of the rolling variance of ``x`` over ``y`` observations."""
    return x.rolling(y).var() ** 0.5


def _rolling_cv(x, y):
    """Rolling standard deviation divided by the rolling mean."""
    return (x.rolling(y).var() ** 0.5) / (x.rolling(y).mean())


# Transformations selectable in the calculation list. Every entry is called
# as ``view(frame, window)``; entries that need no window ignore it.
views = {
    "Raw Data": lambda x, y: x,
    "Percent Change": lambda x, y: x.pct_change(),
    "Rolling Average": _rolling_mean,
    "Rolling Variance": _rolling_var,
    "Rolling Standard Deviation": _rolling_std,
    "Rolling Coefficient of Variation": _rolling_cv,
}


def _or_na(fmt):
    """Wrap ``fmt`` so that a missing (``None``) value renders as "N/A"."""

    def formatter(x):
        if x is None:
            return "N/A"
        return fmt(x)

    return formatter


def _as_big_num(x):
    """Abbreviate ``x`` via ``big_num`` (looked up at call time because it is
    defined further down in this cell)."""
    return big_num(x)


def _as_dollars(x):
    """Render ``x`` rounded to two decimals with a leading dollar sign."""
    return f"${round(x,2)}"


def _as_date(x):
    """Render a timestamp as ``YYYY/MM/DD``.

    NOTE(review): ``datetime.fromtimestamp`` converts using the machine's
    local timezone -- confirm that is the intended reading of the value.
    """
    return datetime.fromtimestamp(x).strftime("%Y/%m/%d")


# Formatter for every info field; each one maps ``None`` to "N/A".
clean_data = {
    "sector": _or_na(lambda x: x),
    "marketCap": _or_na(_as_big_num),
    "beta": _or_na(lambda x: f"{round(x,2)}"),
    "fiftyTwoWeekHigh": _or_na(_as_dollars),
    "fiftyTwoWeekLow": _or_na(_as_dollars),
    "floatShares": _or_na(_as_big_num),
    "sharesShort": _or_na(_as_big_num),
    "exDividendDate": _or_na(_as_date),
}


def big_num(num):
    """Abbreviate a number with a T/B/M/K suffix, rounded to two decimals.

    Args:
        num: Integer or float to format.

    Returns:
        A string such as ``"2.5M"`` or ``"1.23K"``. Values whose magnitude is
        below one thousand are returned rounded to two decimals without a
        suffix. The suffix is chosen from the magnitude, so negative values
        keep their sign (``-2_500_000`` -> ``"-2.5M"``).
    """
    magnitude = abs(num)
    # Ordered from largest to smallest so the first match is the best unit.
    for threshold, suffix in (
        (1_000_000_000_000, "T"),
        (1_000_000_000, "B"),
        (1_000_000, "M"),
        (1_000, "K"),
    ):
        # ``>=`` so an exact boundary (e.g. 1_000_000) uses the larger unit.
        if magnitude >= threshold:
            # Round the quotient, not the divisor.
            return f"{round(num / threshold, 2)}{suffix}"
    return f"{round(num, 2)}"


def clean_str(string):
    """Turn a camelCase identifier into space-separated title case.

    A space is inserted in front of every uppercase character (including a
    leading one) and the result is passed through ``str.title``.
    """
    spaced = "".join(f" {char}" if char.isupper() else char for char in string)
    return spaced.title()


def format_plotly(fig, data, start, end, chart, calc=None):
    """Apply the dashboard's shared layout and title to ``fig`` in place.

    Args:
        fig: Figure to style; its ``update_yaxes``, ``update_xaxes`` and
            ``update_layout`` methods are called.
        data: Name of the plotted field, used in the title.
        start: Range start; must provide ``strftime``.
        end: Range end; must provide ``strftime``.
        chart: ``"main"`` for the 500px chart, anything else for 300px.
        calc: Optional sequence of calculation names for the title. When
            empty or ``None`` the title is simply "Volume".
    """
    # Axis titles are dropped; the figure title carries the description.
    for clear_axis_title in (fig.update_yaxes, fig.update_xaxes):
        clear_axis_title(title=None)

    date_format = "%Y/%m/%d"
    start_t = start.strftime(date_format)
    end_t = end.strftime(date_format)

    if not calc:
        fig_title = "Volume"
    else:
        label = calc[0] if len(calc) == 1 else ", ".join(calc)
        fig_title = f"{label} of {data} from {start_t} to {end_t}"

    title = {
        "text": fig_title,
        "y": 0.95,
        "x": 0.5,
        "xanchor": "center",
        "yanchor": "top",
    }
    fig.update_layout(
        margin={"l": 0, "r": 10, "t": 10, "b": 10},
        autosize=False,
        width=900,
        height=300 if chart != "main" else 500,
        legend={"orientation": "h"},
        title=title,
    )


def create_line(visual, x, y, name, data, fig):
    """Build one trace of the requested kind and add it to ``fig``.

    Args:
        visual: ``"line"``, ``"scatter"`` or ``"candle"``.
        x: X-axis values for the trace.
        y: Mapping/frame of series. ``y[data]`` is plotted for line and
            scatter traces; candles read ``y["Open"]``, ``y["Close"]``,
            ``y["High"]`` and ``y["Low"]`` and ignore ``data``.
        name: Legend label of the trace.
        data: Key of the series to plot for line/scatter traces.
        fig: Figure receiving the trace through ``add_trace``.

    Raises:
        ValueError: If ``visual`` is not one of the supported kinds.
    """
    if visual == "line":
        plot = go.Scatter(x=x, y=y[data], mode="lines", name=name, connectgaps=True)
    elif visual == "scatter":
        plot = go.Scatter(x=x, y=y[data], mode="markers", name=name)
    elif visual == "candle":
        plot = go.Candlestick(
            x=x,
            open=y["Open"],
            close=y["Close"],
            high=y["High"],
            low=y["Low"],
            name=name,
        )
    else:
        # Fail with a clear message instead of an unbound ``plot`` below.
        raise ValueError(
            f"Unsupported chart type {visual!r}; expected 'line', 'scatter' or 'candle'."
        )
    fig.add_trace(plot)


def show_fig(fig):
    """Display ``fig`` with the dashboard's interaction settings.

    Tips are disabled and scroll-zoom is enabled. When the
    ``SERVER_SOFTWARE`` environment variable starts with "voila", the
    "notebook" renderer is requested explicitly; otherwise the default
    renderer is used.
    """
    show_kwargs = {"config": {"showTips": False, "scrollZoom": True}}
    server = os.environ.get("SERVER_SOFTWARE", "jupyter")
    if server.startswith("voila"):
        show_kwargs["renderer"] = "notebook"
    fig.show(**show_kwargs)


def table_data(infos):
    """Build the summary table shown next to the charts.

    Args:
        infos: Dict mapping a ticker symbol to its info dict. Fields missing
            from an info dict are passed to the formatter as ``None``.

    Returns:
        A DataFrame with one row per entry of ``rows``: a "Ticker" column
        holding the display labels from ``clean_row``, followed by one column
        per ticker holding the values formatted by ``clean_data``.
    """
    data = pd.DataFrame({"Ticker": [clean_row[field] for field in rows]})
    for ticker, info in infos.items():
        data[ticker] = [clean_data[field](info.get(field)) for field in rows]
    return data

In [None]:
class Chart:
    """State and render callbacks behind the dashboard's three outputs.

    The instance caches the most recent price download (``df``) and the
    per-ticker info dicts (``infos``) so that widget changes which do not
    affect the underlying data do not trigger a new request.

    ``create_volume`` only reads the cached frame; it relies on
    ``create_stock`` having run first for the current widget values.
    NOTE(review): this assumes the stock output's callback is invoked before
    the volume output's -- confirm against the widget wiring.
    """

    # Intervals for which the full history is downloaded once and the
    # requested date range is cut out locally. All other (intraday)
    # intervals are downloaded for the requested date range only.
    _FULL_HISTORY_INTERVALS = ("1d", "5d", "1wk", "1mo", "3mo")

    def __init__(self):
        self.last_tickers = ""
        self.last_interval = "1d"
        # (start, end) of the last intraday download; None after a
        # full-history download, where the dates do not affect the request.
        self.last_range = None
        self.df = pd.DataFrame()
        self.infos = {}

    @staticmethod
    def _is_complete(tickers):
        """Return True when the ticker text ends with the required comma."""
        return bool(tickers) and tickers[-1] == ","

    @staticmethod
    def _parse_tickers(tickers):
        """Split the comma-separated text into stripped, non-empty symbols."""
        return [part.strip() for part in tickers.split(",") if part.strip()]

    @classmethod
    def _select_range(cls, frame, start, end, interval):
        """Return the rows of ``frame`` between ``start`` and ``end`` inclusive.

        Frames for intraday intervals are returned unchanged: they were
        already downloaded for exactly the requested range, and comparing
        against midnight of ``end`` would drop the final day.
        """
        if interval not in cls._FULL_HISTORY_INTERVALS:
            return frame
        start_n = datetime(start.year, start.month, start.day)
        end_n = datetime(end.year, end.month, end.day)
        return frame.loc[(frame.index >= start_n) & (frame.index <= end_n)]

    def _refresh_prices(self, start, end, interval, tickers):
        """Download prices into ``self.df`` unless the cached frame still fits."""
        full_history = interval in self._FULL_HISTORY_INTERVALS
        wanted_range = None if full_history else (start, end)
        if (
            tickers == self.last_tickers
            and interval == self.last_interval
            and wanted_range == self.last_range
        ):
            return

        if full_history:
            self.df = yf.download(
                tickers, period="max", interval=interval, progress=False
            )
        else:
            # One extra day so that the selected end date itself is included.
            self.df = yf.download(
                tickers,
                start=start,
                end=end + timedelta(days=1),
                interval=interval,
                progress=False,
            )
        if not self.df.empty:
            # Drop timezone info so the index compares with naive datetimes.
            self.df.index = self.df.index.tz_localize(None)
        self.last_tickers = tickers
        self.last_interval = interval
        self.last_range = wanted_range

    def _add_traces(self, fig, frame, visual, field, tickers, suffix=""):
        """Add one trace per ticker found in ``frame`` to ``fig``.

        A frame with MultiIndex columns is split on column level 1, which is
        assumed to hold the ticker symbols (TODO confirm for the installed
        yfinance version). A frame with flat columns is treated as a single
        ticker and labelled with the first symbol typed by the user.
        """
        if isinstance(frame.columns, pd.MultiIndex):
            for val in frame.columns.levels[1]:
                vals = frame.xs(val, axis=1, level=1, drop_level=True)
                create_line(
                    visual, frame.index, vals, f"{val.upper()}{suffix}", field, fig
                )
        else:
            symbols = self._parse_tickers(tickers)
            label = symbols[0] if symbols else ""
            create_line(visual, frame.index, frame, f"{label}{suffix}", field, fig)

    def create_stock(
        self, calculation, data, rolling, start, end, interval, tickers, chart
    ):
        """Render the main price chart.

        Args:
            calculation: Iterable of keys of ``views`` to plot.
            data: Price field to plot for line/scatter charts.
            rolling: Window size handed to the rolling calculations.
            start: First date to show (``None`` while the picker is empty).
            end: Last date to show (``None`` while the picker is empty).
            interval: Bar interval passed to the download.
            tickers: Comma-separated symbols; nothing is rendered until the
                text ends with a comma.
            chart: Trace kind, forwarded to ``create_line``.
        """
        if not self._is_complete(tickers) or start is None or end is None:
            return

        self._refresh_prices(start, end, interval, tickers)
        if self.df.empty:
            print(f"No price data returned for {tickers!r}.")
            return

        fig = go.Figure()
        for item in calculation:
            calcs = views[item](self.df, rolling)
            result = self._select_range(calcs, start, end, interval)
            self._add_traces(fig, result, chart, data, tickers, f" {item}")

        format_plotly(fig, data, start, end, "main", calculation)
        show_fig(fig)

    def create_volume(self, start, end, interval, tickers):
        """Render the volume chart from the cached price frame.

        Nothing is rendered while no data has been downloaded or while one
        of the date pickers is empty.
        """
        if self.df.empty or start is None or end is None:
            return
        result = self._select_range(self.df, start, end, interval)
        fig = go.Figure()
        self._add_traces(fig, result, "line", "Volume", tickers)
        format_plotly(fig, "Volume", start, end, "volume")
        show_fig(fig)

    def _sync_infos(self, wanted):
        """Make ``self.infos`` hold exactly the symbols listed in ``wanted``.

        Missing symbols are fetched; cached symbols that are no longer
        requested are dropped. Membership is tested against the parsed list,
        not the raw text, so "A" is not kept alive by "AAPL".
        """
        for ticker in wanted:
            if ticker not in self.infos:
                self.infos[ticker] = yf.Ticker(ticker).info
        for ticker in [cached for cached in self.infos if cached not in wanted]:
            self.infos.pop(ticker)

    def create_table(self, tickers):
        """Render the per-ticker summary table.

        Nothing is rendered until ``tickers`` ends with a comma.
        """
        if not self._is_complete(tickers):
            return

        self._sync_infos(self._parse_tickers(tickers))
        result = table_data(self.infos)
        fig = go.Figure(
            data=[
                go.Table(
                    header=dict(
                        values=result.columns,
                        fill_color="lightgray",
                        font=dict(color="black"),
                        align="left",
                    ),
                    cells=dict(
                        values=[result[x] for x in result.columns],
                        font=dict(color="black"),
                        align="left",
                    ),
                )
            ],
        )
        fig.update_layout(margin=dict(l=0, r=20, t=0, b=0), width=350)
        show_fig(fig)

In [None]:
# --- Input widgets ---------------------------------------------------------
w_auto = widgets.Layout(width="auto")

# Calculations to overlay on the main chart (multiple may be selected).
calc_widget = widgets.SelectMultiple(
    options=list(views.keys()), value=["Raw Data"], layout=w_auto
)

data_opts = ["Open", "Close", "High", "Low"]
data_widget = widgets.Dropdown(
    options=data_opts, value="Close", layout=w_auto, description="Data"
)
# Window size handed to the rolling calculations.
rolling_widget = widgets.Dropdown(
    options=list(range(2, 101)), value=60, layout=w_auto, description="Rolling"
)

# Default range: the last 365 days. "Today" is read once so both pickers are
# derived from the same instant.
today = datetime.today()
base_date = (today - timedelta(days=365)).date()
start_widget = widgets.DatePicker(value=base_date, layout=w_auto, description="Start")
end_widget = widgets.DatePicker(value=today.date(), layout=w_auto, description="End")

interval_widget = widgets.Dropdown(
    options=interval_opts, value="1d", layout=w_auto, description="Interval"
)
# Comma-separated symbols; the charts and table only render once the text
# ends with a comma.
tickers_widget = widgets.Textarea(
    value="TSLA,", layout=widgets.Layout(width="auto", height="100%")
)
chart_opts = ["line", "scatter", "candle"]
chart_widget = widgets.Dropdown(
    options=chart_opts, value="line", layout=w_auto, description="Chart"
)

# --- Control panel layout --------------------------------------------------
data_box = widgets.VBox([data_widget, rolling_widget, chart_widget])
date_box = widgets.VBox([start_widget, end_widget, interval_widget])
controls = widgets.HBox(
    [tickers_widget, calc_widget, date_box, data_box],
    layout=widgets.Layout(width="90%"),
)

# --- Outputs bound to the widgets -------------------------------------------
# One shared Chart instance so all three outputs reuse the same cached data.
chart = Chart()
stocks_view = widgets.interactive_output(
    chart.create_stock,
    {
        "calculation": calc_widget,
        "data": data_widget,
        "rolling": rolling_widget,
        "start": start_widget,
        "end": end_widget,
        "interval": interval_widget,
        "tickers": tickers_widget,
        "chart": chart_widget,
    },
)

volume_view = widgets.interactive_output(
    chart.create_volume,
    {
        "start": start_widget,
        "end": end_widget,
        "interval": interval_widget,
        "tickers": tickers_widget,
    },
)

table_view = widgets.interactive_output(chart.create_table, {"tickers": tickers_widget})

# --- Page assembly -----------------------------------------------------------
charts = widgets.VBox(
    [stocks_view, volume_view],
    layout=widgets.Layout(width="100%", padding="0", margin="0"),
)
figures = widgets.HBox(
    [charts, table_view], layout=widgets.Layout(padding="0", margin="0")
)

title_html = "<h1>Stock Analysis Dashboard</h1>"
# Fixed markup: the opening tag previously contained a stray "=" before ">".
warning_html = '<p style="color:red">Use a comma after EVERY stock typed.</p>'

app_contents = [widgets.HTML(title_html), controls, widgets.HTML(warning_html), figures]
app = widgets.VBox(app_contents)
display(app)