In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

# import two functions out of the retirement_model.py file
from retirement_model import calc_savings, calc_retirement_account

In [None]:
def get_curve(currect_income, remaining_years, saving_ratio, burn_rate_per_month,
              income_increase, interest, cash_unit=1000):
    """Build one (year, balance) curve: the saving phase followed by retirement.

    Parameters
    ----------
    currect_income : number
        Current income, forwarded to ``calc_savings``. The misspelled name is
        kept so existing keyword callers keep working.
    remaining_years : int
        Years left until retirement. It is forwarded to ``calc_savings`` and
        also used to shift the retirement phase along the year axis so that
        phase starts after the saving phase.
    saving_ratio, income_increase : number
        Forwarded to ``calc_savings``.
    burn_rate_per_month : number
        Forwarded to ``calc_retirement_account``.
    interest : number
        Forwarded to both ``calc_savings`` and ``calc_retirement_account``.
    cash_unit : number, optional
        Divisor applied to the balance column. The default of 1000 expresses
        balances in thousands.

    Returns
    -------
    numpy.ndarray
        Float array with two columns. Column 0 is the year value returned by
        the helpers, shifted by +1 (the retirement rows are additionally
        shifted by ``remaining_years``). Column 1 is the balance divided by
        ``cash_unit``.
    """
    # np.array copies the helper output, so mutating it below never touches the
    # caller's data. Casting to float keeps the in-place scaling further down
    # from raising a casting error when the helpers return integer values.
    savings = np.array(
        calc_savings(currect_income, saving_ratio, income_increase,
                     remaining_years, interest),
        dtype=float)

    # Turn column 1 into a running total so the curve shows the balance
    # growing over the saving years.
    savings[:, 1] = np.cumsum(savings[:, 1])

    # The balance at the end of the saving phase seeds the retirement account.
    total_savings = savings[-1, 1]
    retirement = np.array(
        calc_retirement_account(total_savings, burn_rate_per_month, interest),
        dtype=float)

    # Only one curve is returned: shift the retirement years past the saving
    # phase, then stack both phases.
    curve = np.concatenate(
        (savings, retirement + (remaining_years, 0)), axis=0)

    # Scale the cash column into the requested unit.
    curve[:, 1] /= cash_unit
    # Shift the years right so the axis does not start at 0.
    curve[:, 0] += 1
    return curve

In [None]:
# Three scenarios to compare. All share income_increase=1.03 and interest=1.01.
# They differ in income, monthly burn rate ("br"), saving ratio and years
# until retirement.
curve_1 = get_curve(currect_income=3500, remaining_years=40, saving_ratio=0.15,
                    burn_rate_per_month=1500, income_increase=1.03, interest=1.01)

curve_2 = get_curve(currect_income=4000, remaining_years=40, saving_ratio=0.15,
                    burn_rate_per_month=2300, income_increase=1.03, interest=1.01)

curve_3 = get_curve(currect_income=3500, remaining_years=20, saving_ratio=0.5,
                    burn_rate_per_month=2300, income_increase=1.03, interest=1.01)

# Each entry is (curve, line colour, legend label). Each label must mirror
# the arguments of the get_curve call that produced its curve.
curves = [
    (curve_1, 'red', 'income 3500, br 1500'),
    (curve_2, 'blue', 'income 4000, br 2300'),
    (curve_3, 'green', 'income 3500, br 2300, saving 50%, 20 yrs'),
]

In [None]:
# Plot every scenario on a fresh figure. Creating the figure explicitly,
# rather than drawing on pyplot's implicit "current" figure, keeps a re-run
# of this cell from adding duplicate lines to the previous figure under the
# interactive `%matplotlib notebook` backend.
fig, ax = plt.subplots()
for curve, color, label in curves:
    ax.plot(curve[:, 0], curve[:, 1], color=color,
            linestyle='solid', label=label)

ax.set_title('Savings and retirement balance per scenario')
ax.set_xlabel('Years')
ax.set_ylabel('Cash Balance [1.000 \u20AC]')
ax.legend(numpoints=1, fontsize=10)

# Leave 10 % head-room above the highest balance reached by any scenario.
max_y = max(curve[:, 1].max() for curve, _, _ in curves)
ax.set_ylim(0, max_y * 1.1)
ax.grid(True)