In [21]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 11 2020

@author: kalyan
"""

## SOLOW GROWTH MODEL 
from sympy import *
import numpy as np
import scipy.stats as stats
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Constant returns to scale SOLOW growth model

def solow_growth(init_cap=1, init_lab=2, alpha=0.5, sav_rate=0.25, dep_rate=0.1, pop_rate=0.1, show_ss=True, show_gold=False):
    """Plot the per-capita Solow growth model with Cobb-Douglas output y = k**alpha.

    Draws output per capita, savings per capita (s * y) and required
    investment ((n + delta) * k) against capital per capita k. It can optionally
    overlay the steady state and the golden-rule steady state, together with the
    savings curve that would support the golden rule.

    Parameters
    ----------
    init_cap, init_lab : float
        Initial aggregate capital K and labour L. Accepted for backward
        compatibility; they do not currently influence the figure.
    alpha : float
        Capital share of output; must satisfy 0 < alpha < 1.
    sav_rate : float
        Exogenous savings rate s; must satisfy 0 < sav_rate <= 1.
    dep_rate : float
        Exogenous depreciation rate (delta).
    pop_rate : float
        Exogenous population growth rate (n). ``pop_rate + dep_rate`` must be
        strictly positive.
    show_ss : bool
        If truthy, mark the steady-state capital, output and consumption.
    show_gold : bool
        If truthy, mark the golden-rule capital, output and consumption and
        draw the golden-rule savings curve.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure; it is also displayed via ``fig.show()``.

    Raises
    ------
    ValueError
        If alpha, sav_rate or pop_rate + dep_rate fall outside the ranges above.
    """
    # Parameters defining the structure of an economy
    α = alpha     # Production share of capital
    s = sav_rate  # Exogenous savings rate
    δ = dep_rate  # Exogenous depreciation rate
    n = pop_rate  # Exogenous population growth rate

    # Validate before any math: 1/(1-α) and division by n+δ below are undefined
    # for α == 1 or n+δ == 0, and the model is degenerate outside these ranges.
    if not 0 < α < 1:
        raise ValueError("alpha must satisfy 0 < alpha < 1, got {}".format(alpha))
    if not 0 < s <= 1:
        raise ValueError("sav_rate must satisfy 0 < sav_rate <= 1, got {}".format(sav_rate))
    if n + δ <= 0:
        raise ValueError("pop_rate + dep_rate must be positive, got {}".format(n + δ))

    # Steady state: s*k**α = (n+δ)*k  =>  k_ss = (s/(n+δ))**(1/(1-α))
    kss = (s/(n+δ))**(1/(1-α))
    # Golden rule: f'(k) = α*k**(α-1) = n+δ  =>  k_gold = (α/(n+δ))**(1/(1-α)).
    # The exponent must be parenthesised; `1/1-α` would evaluate to `1-α`.
    k_gold = (α/(n+δ))**(1/(1-α))

    k = symbols('k')

    # Structure of economy: production, savings & required investment
    f = k**α        # Output per capita
    sf = s*f        # Savings per capita
    ri = (n+δ)*k    # Required (break-even) investment per capita

    # Numeric (numpy) versions of the symbolic curves for plotting
    f_x = lambdify(k, f, modules=['numpy'])
    s_x = lambdify(k, sf, modules=['numpy'])
    ri_x = lambdify(k, ri, modules=['numpy'])

    # Plot up to 1.5x the largest capital level that will be marked, so the
    # golden-rule point is not left beyond the end of the drawn curves.
    k_max = 1.5 * (max(kss, k_gold) if show_gold else kss)
    xvals = np.linspace(0, k_max, 100)

    trace_f = go.Scatter(x=xvals, y=f_x(xvals), name='Output per capita (y)')
    trace_s = go.Scatter(x=xvals, y=s_x(xvals), name='Savings/capita@ rate (s={0:.2f})'.format(s))
    trace_ri = go.Scatter(x=xvals, y=ri_x(xvals), name='Required Inv. per capita (n+δ)*(k)')

    # Steady-state output and consumption (consumption is the unsaved share)
    yss = f_x(kss)
    css = (1-s)*yss

    # Figure layout & other components
    solow_layout = go.Layout(
        title=dict(text='Solow Growth model for (y={}) production'.format(f)),
        xaxis=dict(title="Per capita units of capital (k=K/L)"),
        yaxis=dict(title="Output per capita (y=Y/L)"),
        grid=None
    )

    fig1 = go.Figure(data=[trace_f, trace_s, trace_ri], layout=solow_layout)

    def get_steady_state():
        """Overlay steady-state guide lines, labels and a consumption annotation on fig1."""
        # Vertical dashed line from the x-axis up to steady-state output
        fig1.add_shape(
            dict(type="line", x0=kss, y0=0, x1=kss, y1=yss, line=dict(color="Brown", width=3, dash="dash"))
        )
        # Horizontal dashed line from the y-axis across to steady-state output
        fig1.add_shape(
            dict(type="line", x0=0, y0=yss, x1=kss, y1=yss, line=dict(color="Brown", width=3, dash="dash"))
        )
        # Solid segment spanning consumption: from output down to savings
        fig1.add_shape(
            dict(type="line", x0=kss, y0=yss, x1=kss, y1=yss-css, line=dict(color="Black", width=3))
        )
        fig1.add_trace(
            go.Scatter(
                x=[kss, 0],
                y=[0, yss],
                mode="markers+text",
                text=["SS Capital:{0:.2f}".format(kss), "SS Output:{0:.2f}".format(yss)],
                textposition=["bottom center", "top right"],
                textfont=dict(family="Courier New, monospace", size=10, color="Brown"),
                showlegend=False,
            )
        )
        # Annotation anchored at the midpoint of the consumption segment
        fig1.add_annotation(
            x=kss,
            y=(yss-css/2),
            xref="x",
            yref="y",
            text="SS Consumption: {0:.2f}".format(css),
            showarrow=True,
            font=dict(family="Courier New, monospace", size=10, color="#ffffff"),
            align="center",
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor="#636363",
            ax=20,
            ay=-50,
            bordercolor="#c7c7c7",
            borderwidth=2,
            borderpad=4,
            bgcolor="Brown",
            opacity=0.8
        )
        return fig1

    def get_gold_state():
        """Overlay the golden-rule savings curve, guide lines, labels and consumption on fig1."""
        y_gold = f_x(k_gold)  # Golden-rule output
        # Savings rate that makes k_gold the steady state:
        # s_gold*k**α = (n+δ)*k at k_gold. For Cobb-Douglas this equals α.
        s_gold = (n+δ)*k_gold**(1-α)
        c_gold = (1-s_gold)*y_gold  # Golden-rule consumption

        s_gx = lambdify(k, s_gold*f, modules=['numpy'])

        trace_sgold = go.Scatter(
            x=xvals,
            y=s_gx(xvals),
            name='Golden rate savings/capita @ rate (s_gold ={0:.2f})'.format(s_gold),
            line=dict(color="Gold", dash="dot"),
        )
        fig1.add_trace(trace_sgold)

        # Vertical dotted line from the x-axis up to golden-rule output
        fig1.add_shape(
            dict(type="line", x0=k_gold, y0=0, x1=k_gold, y1=y_gold, line=dict(color="Green", width=2, dash="dot"))
        )
        # Horizontal dotted line from the y-axis across to golden-rule output
        fig1.add_shape(
            dict(type="line", x0=0, y0=y_gold, x1=k_gold, y1=y_gold, line=dict(color="Green", width=2, dash="dot"))
        )
        # Solid segment spanning golden-rule consumption
        fig1.add_shape(
            dict(type="line", x0=k_gold, y0=y_gold, x1=k_gold, y1=y_gold-c_gold, line=dict(color="Green", width=3))
        )
        fig1.add_trace(
            go.Scatter(
                x=[k_gold, 0],
                y=[0, y_gold],
                mode="markers+text",
                # Labelled "Golden" so these markers are not confused with the steady state
                text=["Golden Capital:{0:.2f}".format(k_gold), "Golden Output:{0:.2f}".format(y_gold)],
                textposition=["bottom center", "top right"],
                textfont=dict(family="Courier New, monospace", size=10, color="Green"),
                showlegend=False,
            )
        )
        # Annotation anchored at the midpoint of the consumption segment
        fig1.add_annotation(
            x=k_gold,
            y=(y_gold-c_gold/2),
            xref="x",
            yref="y",
            text="Golden Consumption: {0:.2f}".format(c_gold),
            showarrow=True,
            font=dict(family="Courier New, monospace", size=10, color="#ffffff"),
            align="center",
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor="#656563",
            ax=20,
            ay=-70,
            bordercolor="#c8c8c8",
            borderwidth=2,
            borderpad=4,
            bgcolor="Green",
            opacity=0.9
        )
        return fig1

    if show_ss:
        get_steady_state()
    else:
        print("No steady state request")

    if show_gold:
        get_gold_state()
    else:
        print("No golden state request")

    fig1.update_layout(hovermode='x unified')
    fig1.show()
    return fig1


In [23]:
# Example economy: capital share 0.25, savings rate 0.25, depreciation 0.5,
# population growth 0.70, overlaying both the steady state and golden rule.
x = solow_growth(
    init_cap=1,
    init_lab=2,
    alpha=0.25,
    sav_rate=0.25,
    dep_rate=0.5,
    pop_rate=0.70,
    show_ss=True,
    show_gold=True,
)
