In [2]:
growth_rates = {
    'robert_income': 0.0,
    'isabel_income': 0.0,
    'expenses': 0.01,
    'assets': 0.04
}

shocks = {
    2027: {
        'robert_income': (-10000, 'Robert leaves Google'),
        'isabel_income': (-100000, 'Isabel book deals are smaller')
    },
    2030: {
        'expenses': (30000, 'Childcare')
    }
}

In [3]:
import pandas as pd
import nbconvert
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from matplotlib import cm

def financial_projection(robert_income, isabel_income, expenses, assets, growth_rates, shocks, volatility, simulations):
    # set defaults for charts
    pio.templates.default = "plotly_white"

    @np.vectorize
    def calculate_tax(income):
        brackets = [9950, 40525, 86375, 164925, 209425, 523600]
        rates = [0.10, 0.12, 0.22, 0.24, 0.32, 0.35, 0.37]
        tax = 0
        for i in range(len(brackets)):
            if income > brackets[i]:
                if i == 0:
                    tax += rates[i] * brackets[i]
                else:
                    tax += rates[i] * (brackets[i] - brackets[i - 1])
            else:
                if i == 0:
                    tax += rates[i] * income
                else:
                    tax += rates[i] * (income - brackets[i - 1])
                break
        if income > brackets[-1]:
            tax += rates[-1] * (income - brackets[-1])

        return tax

    # parameters
    variables = ['robert_income', 'isabel_income', 'expenses', 'assets']

    # create initial setup DataFrame
    data = pd.DataFrame({
        'year': [2023],
        'robert_income': [robert_income],
        'isabel_income': [isabel_income],
        'expenses': [expenses],
        'assets': [assets]
    }).set_index('year')

    growth_assumptions = growth_rates

    volatility = volatility  # standard deviation of asset growth
    simulations = simulations  # number of simulations

    # create a DataFrame to hold the future projections
    projection = pd.DataFrame(index=range(2023, 2083))

    # initialize a DataFrame with simulations for assets
    asset_simulations = pd.DataFrame(1 + volatility * np.random.standard_normal(size=(60, simulations)),
                                     index=projection.index,
                                     columns=['simulation_' + str(i) for i in range(simulations)]
    )

    # chain all
    asset_simulations = asset_simulations.cumprod()

    # loop over years
    for year in projection.index:
        if year == 2023:
            # handle base year
            for var in variables:
                projection.loc[year, var] = data.loc[2023, var]
                asset_simulations.loc[year] = data.loc[2023, 'assets']
        else:
            # apply growth assumptions and shocks
            for var in variables:
                projection.loc[year, var] = projection.loc[year - 1, var] * (1 + growth_assumptions[var])
                if year in shocks and var in shocks[year]:
                    shock, _ = shocks[year][var]
                    projection.loc[year, var] += shock

    # calculate tax
    projection['tax'] = projection['robert_income'] + projection['isabel_income'] - projection['expenses']
    projection['tax'] = calculate_tax(projection['tax'])

    # plot income, expenses, savings, and tax
    fig = go.Figure(layout=go.Layout(template='plotly_white'))

    for var in ['robert_income', 'isabel_income', 'expenses', 'savings', 'tax']:
        fig.add_trace(go.Scatter(x=projection.index, y=projection[var], mode='lines', name=var))

    fig.show()

    # plot asset simulations as a fan chart
    fig = go.Figure(layout=go.Layout(template='plotly_white'))

    percentiles = [1, 5, 20, 50, 80, 95, 99]
    colors = [cm.Reds(x) for x in np.linspace(0.5, 1, 3)][::-1] + [cm.Greys(0.3)] + [cm.Blues(x) for x in np.linspace(0.5, 1, 3)]

    for i in range(len(percentiles)):
        percentile = percentiles[i]
        color = colors[i]
        asset_percentile = asset_simulations.apply(lambda x: np.percentile(x, percentile), axis=1)
        if i == 0:
            fig.add_trace(
                go.Scatter(x=asset_percentile.index, y=asset_percentile, fill=None, fillcolor=None,
                           line_color='rgba' + str(color), name=str(percentile) + 'th percentile'))
        else:
            fig.add_trace(
                go.Scatter(x=asset_percentile.index, y=asset_percentile, fill='tonexty', fillcolor='rgba' + str(color),
                           line_color='rgba' + str(color), name=str(percentile) + 'th percentile'))

    fig.show()

# Example usage
financial_projection(100000, 200000, 50000, 800000, growth_rates, shocks, 0.08, 1000)


KeyError: 'savings'