# Appendix: helper functions

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pandas as pd
from IPython.display import Image, display

Here we create some helper functions that will be used across notebooks using the magic `%%writefile`.

## Data visualisation

Data exploration, in particular based on visualisation, is crucial to modern data science. `Pandas` has a lot of plotting functionality (e.g. see the graph below), but we will find it useful to use a custom set of `plot` functions.

In [None]:
%%writefile ../skfin/plot.py
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from skfin.metrics import sharpe_ratio

# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so pick whichever spelling this
# installation provides. If neither is available the default style is kept.
for _style in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
    if _style in plt.style.available:
        plt.style.use(_style)
        break


def set_axis(ax=None, figsize=(8, 5), title=None, fig=None):
    """Return a ``(fig, ax)`` pair, creating a new figure when no axis is given.

    Parameters
    ----------
    ax : matplotlib axes, optional
        Axis to draw on. When None, a new single-axis figure is created with
        ``plt.subplots`` and any ``fig`` argument is replaced by that figure.
    figsize : tuple
        Passed to ``plt.subplots``; only used when a new figure is created.
    title : str, optional
        Title set on the axis when not None.
    fig : matplotlib figure, optional
        Figure to return alongside ``ax``. When omitted, it is looked up from
        the axis with ``ax.get_figure()``.

    Returns
    -------
    tuple
        ``(fig, ax)``.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    elif fig is None:
        # Callers that need the figure (e.g. to attach a colorbar) would
        # otherwise receive None when only an axis was supplied.
        fig = ax.get_figure()
    if title is not None:
        ax.set_title(title)
    return fig, ax


def line(
    df,
    sort=True,
    figsize=(8, 5),
    ax=None,
    title="",
    cumsum=False,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    legend_sharpe_ratio=None,
    legend=True,
    yscale=None,
    start_date=None,
):
    """Plot every column of ``df`` as a line on a single axis.

    Parameters
    ----------
    df : DataFrame, Series, dict or list
        Data to plot. A dict or list is concatenated column-wise with
        ``pd.concat(..., axis=1)``; a Series is converted to a one-column frame.
    sort : bool
        When True, columns are ordered by their last row, largest first, which
        fixes the order of the legend entries.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    ax : matplotlib axes, optional
        Axis to draw on; a new figure is created when None.
    title : str
        Axis title.
    cumsum : bool
        When True, the cumulative sum of each column is plotted.
    loc, bbox_to_anchor
        Forwarded to ``ax.legend``. ``bbox_to_anchor`` is reset to None when
        ``loc == "best"``.
    legend_sharpe_ratio : bool, optional
        When truthy, each column label is suffixed with the value returned by
        ``sharpe_ratio`` for that column (computed before any cumulative sum).
        Defaults to True when ``cumsum`` is set and this argument is None.
    legend : bool
        Whether to draw the legend.
    yscale : str, optional
        Only ``"log"`` has an effect: it switches the y-axis to a log scale.
    start_date : optional
        When not None, rows are restricted with ``df[start_date:]``.

    Returns
    -------
    None
    """
    if loc == "best":
        bbox_to_anchor = None
    if isinstance(df, (dict, list)):
        df = pd.concat(df, axis=1)
    elif isinstance(df, pd.Series):
        df = df.to_frame()
    else:
        # Column labels may be rewritten below; never touch the caller's frame.
        df = df.copy()
    if start_date is not None:
        df = df[start_date:]
    if cumsum and legend_sharpe_ratio is None:
        legend_sharpe_ratio = True
    if legend_sharpe_ratio:
        df.columns = [f"{c}: sr={sharpe_ratio(df[c]): 3.2f}" for c in df.columns]
    if cumsum:
        df = df.cumsum()
    if sort and len(df) > 0:
        # Ordering relies on the last row, which does not exist for empty data.
        df = df.loc[:, lambda x: x.iloc[-1].sort_values(ascending=False).index]
    fig, ax = set_axis(ax=ax, figsize=figsize, title=title)
    ax.plot(df.index, df.values)
    if legend:
        ax.legend(df.columns, loc=loc, bbox_to_anchor=bbox_to_anchor)
    if yscale == "log":
        ax.set_yscale("log")

def bar(
    df,
    err=None,
    sort=True,
    figsize=(8, 5),
    ax=None,
    title=None,
    horizontal=False,
    baseline=None,
    rotation=0,
):
    """Draw a bar chart of a one-dimensional collection of values.

    Parameters
    ----------
    df : Series, DataFrame or dict
        Values to plot. A DataFrame is reduced with ``squeeze()`` and a dict is
        wrapped in a Series; the index provides the tick labels.
    err : optional
        Error-bar sizes, re-ordered with ``err.loc[df.index]`` so that they
        follow the (possibly sorted) values.
    sort : bool
        When True, bars are sorted by value in ascending order.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    ax : matplotlib axes, optional
        Axis to draw on; a new figure is created when None.
    title : str, optional
        Axis title; an existing title is left untouched when None.
    horizontal : bool
        Draw horizontal bars (``barh``) instead of vertical ones.
    baseline : optional
        Index label of one bar to highlight: when present in the index, that
        bar is overdrawn in light green.
    rotation : float
        Rotation applied to the tick labels of the category axis.

    Returns
    -------
    None
    """
    if isinstance(df, pd.DataFrame):
        df = df.squeeze()
    if isinstance(df, dict):
        df = pd.Series(df)
    if sort:
        df = df.sort_values()
    if err is not None:
        err = err.loc[df.index]
    labels = df.index
    x = np.arange(len(labels))
    fig, ax = set_axis(ax=ax, figsize=figsize, title=title)

    # Both orientations follow the same steps; only the matplotlib entry
    # points and the error-bar keyword differ.
    if horizontal:
        draw, set_ticks, set_ticklabels = ax.barh, ax.set_yticks, ax.set_yticklabels
        err_kwargs = {"xerr": err}
    else:
        draw, set_ticks, set_ticklabels = ax.bar, ax.set_xticks, ax.set_xticklabels
        err_kwargs = {"yerr": err}

    draw(x, df.values, capsize=5, **err_kwargs)
    set_ticks(x)
    set_ticklabels(labels, rotation=rotation)

    if baseline is not None and baseline in df.index:
        # Overdraw only the baseline bar: every other value is zeroed out.
        highlighted = df.copy()
        highlighted[df.index != baseline] = 0
        draw(x, highlighted.values, color="lightgreen")


def heatmap(
    df,
    ax=None,
    fig=None,
    figsize=(8, 5),
    title=None,
    vmin=None,
    vmax=None,
    vcompute=True,
    cmap="RdBu",
):
    """Show a DataFrame as a colour-coded image with a colorbar.

    The frame is transposed before drawing, so the index runs along the
    x-axis and the columns along the y-axis.

    Parameters
    ----------
    df : DataFrame
        Values to display.
    ax : matplotlib axes, optional
        Axis to draw on; a new figure is created when None.
    fig : matplotlib figure, optional
        Figure that receives the colorbar. When omitted while ``ax`` is given,
        it is looked up with ``ax.get_figure()``.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    title : str, optional
        Axis title.
    vmin, vmax : float, optional
        Colour-scale limits. They are only honoured when ``vcompute`` is False.
    vcompute : bool
        When True (the default), ``vmax`` is set to the largest absolute value
        in ``df`` and ``vmin`` to its negative, overriding any values passed
        in, so that the colour scale is symmetric around zero.
    cmap : str
        Colormap name forwarded to ``imshow``.

    Returns
    -------
    None
    """
    labels_x = df.index
    x = np.arange(len(labels_x))
    labels_y = df.columns
    y = np.arange(len(labels_y))
    if vcompute:
        vmax = df.abs().max().max()
        vmin = -vmax
    fig, ax = set_axis(ax=ax, figsize=figsize, title=title, fig=fig)
    if fig is None:
        # An axis passed without its figure: recover the figure so the
        # colorbar call below does not fail on None.
        fig = ax.get_figure()
    pos = ax.imshow(
        df.T.values, cmap=cmap, interpolation="nearest", vmax=vmax, vmin=vmin
    )
    ax.set_xticks(x)
    ax.set_yticks(y)
    ax.set_xticklabels(labels_x, rotation=90)
    ax.set_yticklabels(labels_y)
    ax.grid(True)
    fig.colorbar(pos, ax=ax)


def scatter(
    df,
    ax=None,
    xscale=None,
    yscale=None,
    xlabel=None,
    ylabel=None,
    xticks=None,
    yticks=None,
    figsize=(8, 5),
    title=None,
):
    """Scatter the values of ``df`` (x-axis) against its index (y-axis).

    Markers are hollow with blue edges. Every optional argument left at None
    is skipped; tick arguments set both the tick positions and their labels.

    Parameters
    ----------
    df : Series-like
        Object whose values and ``index`` are plotted.
    ax : matplotlib axes, optional
        Axis to draw on; a new figure is created when None.
    xscale, yscale : str, optional
        Axis scales forwarded to ``set_xscale`` / ``set_yscale``.
    xlabel, ylabel : str, optional
        Axis labels.
    xticks, yticks : sequence, optional
        Tick positions, also used verbatim as tick labels.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    title : str, optional
        Axis title.

    Returns
    -------
    None
    """
    _, axis = set_axis(ax=ax, figsize=figsize, title=title)
    axis.scatter(df, df.index, facecolors="none", edgecolors="b", s=50)

    # Single-valued settings: apply each setter only when a value was given.
    optional_settings = (
        (axis.set_xlabel, xlabel),
        (axis.set_ylabel, ylabel),
        (axis.set_xscale, xscale),
        (axis.set_yscale, yscale),
    )
    for apply_setting, value in optional_settings:
        if value is not None:
            apply_setting(value)

    # Ticks: the same sequence gives both the positions and the labels.
    tick_settings = (
        (axis.set_yticks, axis.set_yticklabels, yticks),
        (axis.set_xticks, axis.set_xticklabels, xticks),
    )
    for place_ticks, label_ticks, ticks in tick_settings:
        if ticks is not None:
            place_ticks(ticks)
            label_ticks(ticks)

In [None]:
from skfin.plot import bar, heatmap, line

In [None]:
# Demo: cumulative sum of 50 random normal draws. `legend_sharpe_ratio=False`
# disables the Sharpe-ratio suffix that `cumsum=True` would otherwise add to
# the legend label.
line(
    pd.Series(np.random.normal(size=50)),
    cumsum=True,
    title="This is a graph",
    legend_sharpe_ratio=False,
)

In [None]:
# Demo: horizontal bars sorted by value; the bar whose index label equals
# `baseline` (here 10) is overdrawn in light green.
bar(pd.Series(np.random.normal(size=50)), baseline=10, horizontal=True)

## Dates and mappings

In [None]:
%%writefile ../skfin/dataset_mappings.py
# Maps a ticker symbol to a short, human-readable company name.
symbol_dict = {
    "TOT": "Total",
    "XOM": "Exxon",
    "CVX": "Chevron",
    "COP": "ConocoPhillips",
    "VLO": "Valero Energy",
    "MSFT": "Microsoft",
    "IBM": "IBM",
    "TWX": "Time Warner",
    "CMCSA": "Comcast",
    "CVC": "Cablevision",
    "YHOO": "Yahoo",
    "DELL": "Dell",
    "HPQ": "HP",
    "AMZN": "Amazon",
    "TM": "Toyota",
    "CAJ": "Canon",
    "SNE": "Sony",
    "F": "Ford",
    "HMC": "Honda",
    "NAV": "Navistar",
    "NOC": "Northrop Grumman",
    "BA": "Boeing",
    "KO": "Coca Cola",
    "MMM": "3M",
    "MCD": "McDonald's",
    "PEP": "Pepsi",
    "K": "Kellogg",
    "UN": "Unilever",
    "MAR": "Marriott",
    "PG": "Procter Gamble",
    "CL": "Colgate-Palmolive",
    "GE": "General Electrics",
    "WFC": "Wells Fargo",
    "JPM": "JPMorgan Chase",
    "AIG": "AIG",
    "AXP": "American express",
    "BAC": "Bank of America",
    "GS": "Goldman Sachs",
    "AAPL": "Apple",
    "SAP": "SAP",
    "CSCO": "Cisco",
    "TXN": "Texas Instruments",
    "XRX": "Xerox",
    "WMT": "Wal-Mart",
    "HD": "Home Depot",
    "GSK": "GlaxoSmithKline",
    "PFE": "Pfizer",
    "SNY": "Sanofi-Aventis",
    "NVS": "Novartis",
    "KMB": "Kimberly-Clark",
    "R": "Ryder",
    "GD": "General Dynamics",
    "RTN": "Raytheon",
    "CVS": "CVS",
    "CAT": "Caterpillar",
    "DD": "DuPont de Nemours",
}

# Maps a ticker symbol to upper-case company name(s).
# NOTE: values are either a single string or a list of strings (several name
# variants for one ticker), so consumers must handle both types.
# NOTE(review): the "10X" suffix presumably refers to SEC 10-K/10-Q filings,
# with names spelled as they appear there - confirm against the caller.
mapping_10X = {
    "AAPL": ["APPLE COMPUTER INC", "APPLE INC"],
    "AIG": "AMERICAN INTERNATIONAL GROUP INC",
    "AMZN": "AMAZON COM INC",
    "AXP": "AMERICAN EXPRESS CO",
    "BA": "BOEING CO",
    "BAC": "BANK OF AMERICA CORP /DE/",
    "CAT": "CATERPILLAR INC",
    "CL": "COLGATE PALMOLIVE CO",
    "CMCSA": "COMCAST CORP",
    "COP": "CONOCOPHILLIPS",
    "CSCO": "CISCO SYSTEMS INC",
    "CVC": "CABLEVISION SYSTEMS CORP /NY",
    "CVS": ["CVS CORP", "CVS/CAREMARK CORP", "CVS CAREMARK CORP"],
    "CVX": ["CHEVRONTEXACO CORP", "CHEVRON CORP"],
    "DD": "DUPONT E I DE NEMOURS & CO",
    "DELL": ["DELL COMPUTER CORP", "DELL INC"],
    "F": "FORD MOTOR CO",
    "GD": "GENERAL DYNAMICS CORP",
    "GE": "GENERAL ELECTRIC CO",
    "GS": "GOLDMAN SACHS GROUP INC/",
    "HD": "HOME DEPOT INC",
    "HPQ": "HEWLETT PACKARD CO",
    "IBM": "INTERNATIONAL BUSINESS MACHINES CORP",
    "JPM": "J P MORGAN CHASE & CO",
    "K": "KELLOGG CO",
    "KMB": "KIMBERLY CLARK CORP",
    "KO": "COCA COLA CO",
    "MAR": "MARRIOTT INTERNATIONAL INC /MD/",
    "MCD": "MCDONALDS CORP",
    "MMM": "3M CO",
    "MSFT": "MICROSOFT CORP",
    "NAV": "NAVISTAR INTERNATIONAL CORP",
    "NOC": "NORTHROP GRUMMAN CORP /DE/",
    "PEP": "PEPSI BOTTLING GROUP INC",
    "PFE": "PFIZER INC",
    "PG": "PROCTER & GAMBLE CO",
    "R": "RYDER SYSTEM INC",
    "RTN": "RAYTHEON CO/",
    "TWX": ["AOL TIME WARNER INC", "TIME WARNER INC"],
    "TXN": "TEXAS INSTRUMENTS INC",
    "VLO": "VALERO ENERGY CORP/TX",
    "WFC": "WELLS FARGO & CO/MN",
    "WMT": "WAL MART STORES INC",
    "XOM": "EXXON MOBIL CORP",
    "XRX": "XEROX CORP",
    "YHOO": "YAHOO INC",
}

In [None]:
%%writefile ../skfin/dataset_dates.py
import pandas as pd


def load_fomc_change_date(as_datetime=True):
    """Return the two hard-coded date collections ``(change_up, change_dw)``.

    Going by the names, these are presumably the dates of FOMC rate increases
    and decreases respectively (to be confirmed against the data source).

    Parameters
    ----------
    as_datetime : bool
        When truthy, each collection is converted with ``pd.to_datetime``;
        otherwise the plain lists of ``"YYYY-MM-DD"`` strings are returned.

    Returns
    -------
    tuple
        ``(change_up, change_dw)``.
    """
    # One row per calendar year to keep the tables compact.
    # fmt: off
    change_up = [
        "1999-06-30", "1999-08-24", "1999-11-16",
        "2000-02-02", "2000-03-21", "2000-05-16",
        "2004-06-30", "2004-08-10", "2004-09-21", "2004-11-10", "2004-12-14",
        "2005-02-02", "2005-03-22", "2005-05-03", "2005-06-30",
        "2005-08-09", "2005-09-20", "2005-11-01", "2005-12-13",
        "2006-01-31", "2006-03-28", "2006-05-10", "2006-06-29",
        "2015-12-16",
        "2016-12-14",
        "2017-03-15", "2017-06-14", "2017-12-13",
        "2018-03-21", "2018-06-13", "2018-09-26", "2018-12-19",
        "2022-03-16", "2022-05-04", "2022-06-15", "2022-07-27",
    ]
    change_dw = [
        "2001-01-03", "2001-01-31", "2001-03-20", "2001-04-18",
        "2001-05-15", "2001-06-27", "2001-08-21", "2001-09-17",
        "2001-10-02", "2001-11-06", "2001-12-11",
        "2002-11-06",
        "2003-06-25",
        "2007-09-18", "2007-10-31", "2007-12-11",
        "2008-01-22", "2008-01-30", "2008-03-18", "2008-04-30",
        "2008-10-08", "2008-10-29", "2008-12-16",
        "2019-07-31", "2019-09-18", "2019-10-30",
        "2020-03-03", "2020-03-15",
    ]
    # fmt: on

    if not as_datetime:
        return change_up, change_dw
    return pd.to_datetime(change_up), pd.to_datetime(change_dw)

## Data utils

In [None]:
%%writefile ../skfin/data_utils.py
import os
from pathlib import Path

import pandas as pd


def clean_directory_path(cache_dir, default_dir="data"):
    """Return ``cache_dir`` as a ``Path`` to an existing directory.

    Parameters
    ----------
    cache_dir : str, os.PathLike or None
        Directory to use. When None, ``default_dir`` under the current working
        directory is used instead.
    default_dir : str
        Sub-directory of the current working directory used when ``cache_dir``
        is None.

    Returns
    -------
    pathlib.Path
        The directory, created (with any missing parents) if it did not exist.

    Raises
    ------
    FileExistsError
        If the path exists but is not a directory.
    """
    if cache_dir is None:
        cache_dir = Path(os.getcwd()) / default_dir
    else:
        # Accept any path-like input, not only str and Path.
        cache_dir = Path(cache_dir)
    # exist_ok avoids failing when the directory appears between a check and
    # the creation (e.g. two processes sharing the same cache).
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


def save_dict(data, output_dir):
    """Write a (possibly nested) dict of DataFrames to a directory tree.

    Each DataFrame value is saved as ``<key>.parquet`` inside ``output_dir``;
    every other value is treated as a nested dict and saved recursively into
    the sub-directory ``<key>``.

    Parameters
    ----------
    data : dict
        Mapping from key to DataFrame or to a nested dict of the same form.
    output_dir : str or os.PathLike
        Target directory, created together with any missing parents.

    Raises
    ------
    TypeError
        If ``data`` (or a nested non-DataFrame value) is not a dict.
    """
    # An explicit raise keeps the check active under ``python -O``.
    if not isinstance(data, dict):
        raise TypeError(
            f"save_dict expects a dict of DataFrames, got {type(data).__name__}"
        )
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for k, v in data.items():
        if isinstance(v, pd.DataFrame):
            v.to_parquet(output_dir / f"{k}.parquet")
        else:
            # str(k) mirrors the f-string used for file names, so non-string
            # keys work for sub-directories as well.
            save_dict(v, output_dir=output_dir / str(k))


def load_dict(input_dir):
    """Load a directory tree of parquet files into a nested dict.

    Every entry whose name ends in ``.parquet`` is read with
    ``pd.read_parquet`` and stored under its name without that suffix; every
    other directory is loaded recursively under its own name. Any other file
    is ignored.

    Parameters
    ----------
    input_dir : str or os.PathLike
        Directory to read.

    Returns
    -------
    dict
        Mapping from name to DataFrame or nested dict, with keys in
        alphabetical order of the directory entries.
    """
    suffix = ".parquet"
    # Materialise the listing so the scandir handle is closed before recursing;
    # sorting makes the key order independent of the file system.
    with os.scandir(input_dir) as entries:
        entries = sorted(entries, key=lambda entry: entry.name)
    data = {}
    for entry in entries:
        if entry.name.endswith(suffix):
            # Strip only the trailing suffix, not every ".parquet" occurrence.
            data[entry.name[: -len(suffix)]] = pd.read_parquet(entry.path)
        elif entry.is_dir():
            data[entry.name] = load_dict(entry.path)
    return data