#### ------------------------------------------------------------
# Replication of Aiyagari (1994, QJE)
### Youngdoo Choi (lenatics@snu.ac.kr), January 2021
#### ------------------------------------------------------------

## 1. Solve the Household's Problem

In [1]:
using Parameters, Interpolations, Optim, Plots;
using QuantEcon, StatsBase;
# using Statistics, Distributions;
# using NLsolve, Roots, Random;

### Parameters

In [2]:
# Model-wide constants shared by every cell below.
const beta = 0.96;   # household discount factor (multiplies the continuation value in TV)
const amin = 0;      # lowest point of the asset grid (borrowing limit)
const alpha = 0.36;  # capital share in the Cobb-Douglas technology
const delta = 0.08;  # depreciation rate: rental rate of capital is r + delta

In [3]:
# Household parameters.
# `HH(; kwargs...)` (Parameters.@with_kw) returns a NamedTuple with the defaults below.
# Later fields are computed from earlier ones, so e.g. `HH(rho = 0.6, mu = 1)` also
# rebuilds `u`, `MC`, `Pi`, `theta` and `lgrid` for the new values.
HH = @with_kw (
    # preference: CRRA coefficient mu; period utility u(c) (log utility when mu == 1)
    mu = 3,
    u = c -> mu == 1 ? log(c) : (c^(1-mu)-1)/(1-mu),
    # asset: agridsize evenly spaced grid points on [amin, amax]
    amax = 100, agridsize = 201,
    agrid = range(amin, amax, length=agridsize),
    # productivity: lgridsize-state Tauchen approximation of an AR(1) in logs with
    # persistence rho. The innovation std is sigma*sqrt(1-rho^2), so the unconditional
    # std of log productivity is sigma.
    lgridsize = 7, 
    sigma = 0.2, rho = 0.9, 
    # tauchen(N, rho, innovation std, mean 0, 3 std devs); Pi[j, k] is the probability of
    # moving from productivity state j today to state k next period
    MC = tauchen(lgridsize, rho, sigma*sqrt(1-rho^2), 0, 3), Pi = MC.p,
    # stationary distribution of the productivity chain
    theta = stationary_distributions(MC)[1],
    # productivity levels: exp of an evenly spaced log grid on [-3*sigma, 3*sigma] ...
    lgrid_temp = exp.(range(-3*sigma, 3*sigma, length=lgridsize)),
    # ... rescaled so that mean productivity under theta equals 1
    lgrid = lgrid_temp / sum(lgrid_temp .* theta),
    # position of level l in lgrid (first exact match; assumes the levels are distinct)
    l_idx = l -> findfirst(x -> x == l, lgrid),
);

# Equilibrium prices.
# `Eq(; r)` returns a NamedTuple with the interest rate r, the rental rate of capital
# MPK = r + delta, and the wage implied by MPK under Cobb-Douglas technology with
# capital share alpha: w = (1-alpha) * (MPK/alpha)^(alpha/(alpha-1)).
Eq = @with_kw (
    r = 0.04,
    MPK = r + delta,
    w = (1-alpha)*MPK^(alpha/(alpha-1))/alpha^(alpha/(alpha-1)),
);

### Value function iteration

In [4]:
"""
    TV(V_old, HH, Eq)

Apply the Bellman operator once.

`V_old[j]` is the value function on `HH.agrid` for productivity state `j`
(level `HH.lgrid[j]`). For every grid state `(a, l_j)` with cash on hand
`coh = w*l_j + (1+r)*a`, the household chooses next-period assets
`a′ ∈ [agrid[1], min(max(agrid[1], coh), agrid[end])]` to maximize

    u(coh - a′) + beta * Σ_k Pi[j, k] * Ṽ_k(a′),

where `Ṽ_k` is the linear interpolant of `V_old[k]` over `agrid`.

Returns `(V_new, Pol)`. Both are vectors with one entry per productivity state; each
entry is a vector over the asset grid holding the maximized value and the maximizer a′.
"""
function TV(V_old, HH, Eq)
    @unpack u, agrid, lgrid, Pi = HH
    @unpack r, w = Eq

    # Interpolate so the continuation value can be evaluated between grid points.
    V_tilde = [LinearInterpolation(agrid, V) for V in V_old]
    # Expected continuation value of a′ given today's productivity index j.
    # States are identified by position rather than by looking up the value of l in
    # lgrid, which stays correct even if two productivity levels coincide.
    EV(a′, j) = sum(Pi[j, k] * V_tilde[k](a′) for k in eachindex(V_tilde))

    V_new = [zeros(length(agrid)) for _ in lgrid]
    Pol = [zeros(length(agrid)) for _ in lgrid]
    for (j, l) in enumerate(lgrid), (i, a) in enumerate(agrid)
        coh = w*l + (1 + r)*a
        # Keep consumption non-negative and the choice inside the asset grid.
        a_hi = min(max(agrid[1], coh), agrid[end])
        res = maximize(a′ -> u(coh - a′) + beta*EV(a′, j), agrid[1], a_hi)
        V_new[j][i] = Optim.maximum(res)
        Pol[j][i] = Optim.maximizer(res)
    end

    return V_new, Pol
end;

In [5]:
"""
    VFI(V_0, HH, Eq, tol, max_iter; TV=TV)

Value function iteration: apply `TV(V, HH, Eq)` starting from `V_0` until the largest
absolute change over all productivity states and asset grid points is below `tol`.

Returns `(V, Pol)` from the last application of `TV`. Throws an error if convergence is
not reached within `max_iter` iterations; previously this case silently returned
`nothing`.
"""
function VFI(V_0, HH, Eq, tol, max_iter; TV=TV)
    V_old = V_0
    for _ in 1:max_iter
        V_new, Pol = TV(V_old, HH, Eq)
        # sup-norm distance across every state's value vector
        dist = maximum(maximum(abs.(vn .- vo)) for (vn, vo) in zip(V_new, V_old))
        if dist < tol
            return V_new, Pol
        end
        V_old = V_new
    end
    error("VFI: value function did not converge within $max_iter iterations")
end;

In [7]:
@unpack agridsize, lgridsize = HH();
# Initial guess: V ≡ 0 on the asset grid for every productivity state.
V_0 = [zeros(agridsize) for _ in 1:lgridsize]
tol, max_iter = 1e-6, 10000;

@time V_new, Pol = VFI(V_0, HH(), Eq(), tol, max_iter);

  4.427473 seconds (24.15 M allocations: 2.432 GiB, 9.36% gc time)


## 2. Compute the Invariant Distribution

### Iteration

In [7]:
"""
    Tμ(μ, Pol, HH)

Push the distribution `μ` (an `agridsize × lgridsize` matrix over asset and
productivity grid points) forward one period.

The mass at `(agrid[i], state j)` moves to next-period assets `a′ = Pol[j][i]`.
It is split linearly between the two grid nodes that bracket `a′`. It is then spread
over next-period productivity states with the transition row `Pi[j, :]`.

A policy at or above `agrid[end]` puts all of its mass on the top node. A policy below
`agrid[1]` (which the bounded optimizer should not produce) puts all of its mass on the
bottom node. Returns the new matrix.
"""
function Tμ(μ, Pol, HH)
    agridsize, agrid, lgridsize, Pi = HH.agridsize, HH.agrid, HH.lgridsize, HH.Pi
    μ′ = zeros(agridsize, lgridsize)

    for (y_idx, P) in enumerate(Pol)
        trans = @view Pi[y_idx, :]          # transition probabilities out of state y_idx
        for a_idx in 1:agridsize
            a′ = P[a_idx]
            mass = μ[a_idx, y_idx]
            # last node with agrid[k] <= a′ (binary search; agrid is sorted)
            ia′l = searchsortedlast(agrid, a′)
            if ia′l == 0
                μ′[1, :] .+= mass .* trans
            elseif ia′l == agridsize
                μ′[end, :] .+= mass .* trans
            else
                ia′h = ia′l + 1
                gap = agrid[ia′h] - agrid[ia′l]
                μ′[ia′l, :] .+= mass .* trans .* (agrid[ia′h] - a′) ./ gap
                μ′[ia′h, :] .+= mass .* trans .* (a′ - agrid[ia′l]) ./ gap
            end
        end
    end

    return μ′
end;

In [8]:
"""
    stat_dist(μ_0, Pol, HH, Eq, tol, max_iter; Tμ=Tμ)

Iterate `μ ↦ Tμ(μ, Pol, HH)` from `μ_0` until the largest absolute change is below `tol`.

Returns `(μ, Ks, Kd)`:
- `μ` is the stationary distribution over asset × productivity grid points.
- `Ks = Σ μ(a, l) * a` is aggregate capital supplied by households.
- `Kd` is capital demanded at rate `Eq.r`. It solves
  `r + delta = alpha * Kd^(alpha-1) * L^(1-alpha)`, where `L = Σ μ(a, l) * l` is
  effective labor.

`Ks` and `Kd` are plain scalars. Throws an error if `max_iter` iterations do not suffice.
"""
function stat_dist(μ_0, Pol, HH, Eq, tol, max_iter; Tμ=Tμ)
    @unpack agrid, lgrid = HH
    @unpack r = Eq
    μ_old = μ_0
    for _ in 1:max_iter
        μ_new = Tμ(μ_old, Pol, HH)
        if maximum(abs.(μ_new .- μ_old)) < tol
            a_marg = vec(sum(μ_new, dims=2))   # marginal distribution over assets
            l_marg = vec(sum(μ_new, dims=1))   # marginal distribution over productivity
            Ks = sum(a_marg .* agrid)
            L = sum(l_marg .* lgrid)
            Kd = ((r + delta)/(alpha*L^(1 - alpha)))^(1/(alpha - 1))
            return μ_new, Ks, Kd
        end
        μ_old = μ_new
    end
    error("stat_dist: distribution did not converge within $max_iter iterations")
end;

In [9]:
@unpack agridsize, lgridsize = HH();
# Initial distribution: all mass at the lowest asset level in productivity state 1.
μ_0 = zeros(agridsize, lgridsize);
μ_0[1] = 1.0;
tol, max_iter = 1e-6, 10000;

@time mu, Ks, Kd = stat_dist(μ_0, Pol, HH(), Eq(), tol, max_iter);

  4.348094 seconds (28.86 M allocations: 3.099 GiB, 9.36% gc time)


## 3. Compute Equilibrium

### Compute the equilibrium interest rate

In [10]:
"""
    Eqm_r(HH, max_bisec; TV=TV, VFI=VFI, Tμ=Tμ, stat_dist=stat_dist, Ktol=0.01)

Find the stationary-equilibrium interest rate by bisection.

For each candidate `r`:
- solve the household problem with `VFI`;
- compute the stationary distribution, capital supply `Ks` and demand `Kd` with
  `stat_dist`;
- update the bracket: excess supply (`Ks > Kd`) lowers the upper end, otherwise the
  lower end rises.

The initial bracket is `[-delta/2, (1-beta)/beta - 1e-6]`. Bisection stops once
`|Ks - Kd| < Ktol` and returns `(r, Ks, Kd, iter)`. Throws an error if `max_bisec`
steps do not suffice.
"""
function Eqm_r(HH, max_bisec; TV=TV, VFI=VFI, Tμ=Tμ, stat_dist=stat_dist, Ktol=0.01)
    # setup: zero initial value function, all mass in the first grid cell
    @unpack agridsize, lgridsize = HH
    V_0 = [zeros(agridsize) for _ in 1:lgridsize]
    μ_0 = zeros(agridsize, lgridsize); μ_0[1, 1] = 1
    tol = 1e-6; max_iter = 10000

    r_lo = -delta/2
    r_hi = (1 - beta)/beta - 1e-6
    for iter in 1:max_bisec
        r_mid = (r_lo + r_hi)/2
        prices = Eq(r = r_mid)
        _, Pol = VFI(V_0, HH, prices, tol, max_iter; TV=TV)
        _, Ks, Kd = stat_dist(μ_0, Pol, HH, prices, tol, max_iter; Tμ=Tμ)
        # `first` accepts scalars as well as 1-element arrays from older stat_dist versions
        excess = first(Ks - Kd)
        if abs(excess) < Ktol
            return r_mid, Ks, Kd, iter
        elseif excess > 0
            r_hi = r_mid
        else
            r_lo = r_mid
        end
    end
    error("Eqm_r: capital market did not clear within $max_bisec bisection steps")
end;

In [11]:
# Baseline equilibrium (default HH parameters), at most max_bisec bisection steps.
max_bisec = 100
@time r_star, Ks_star, Kd_star, iter = Eqm_r(HH(), max_bisec;
                                             TV=TV, VFI=VFI, Tμ=Tμ, stat_dist=stat_dist);

 42.617460 seconds (307.15 M allocations: 33.267 GiB, 9.09% gc time)


In [12]:
hcat(100 * r_star, Ks_star, Kd_star, iter)  # r* in percent, K supply, K demand, steps

1×4 Array{Float64,2}:
 3.32917  6.09254  6.08889  10.0

### Replicate Table II

In [13]:
# Replicate Table II at the default sigma = 0.2: equilibrium r* for each (rho, mu) pair.
# The 12 cases are listed rho-fastest, so the column-major reshape puts rho
# (0, 0.3, 0.6, 0.9) on the rows and mu (1, 3, 5) on the columns.
# map/do keeps the per-case results local, so the notebook's global `r_star` from the
# baseline cell is not overwritten (the old top-level `for` loop rebound it).
rho_vec = [0 0.3 0.6 0.9 0 0.3 0.6 0.9 0 0.3 0.6 0.9]
mu_vec = [1 1 1 1 3 3 3 3 5 5 5 5]

@time Table = reshape(map(zip(rho_vec, mu_vec)) do (rho, mu)
    r_eq, _, _, _ = Eqm_r(HH(rho = rho, mu = mu), max_bisec;
                          TV=TV, VFI=VFI, Tμ=Tμ, stat_dist=stat_dist)
    r_eq
end, 4, 3)

Table

661.869000 seconds (5.03 G allocations: 557.867 GiB, 9.77% gc time)


4×3 Array{Float64,2}:
 0.041242   0.0401653  0.0388245
 0.0410426  0.0394775  0.0375684
 0.0406389  0.0379971  0.034807
 0.0394177  0.0332917  0.0259546

In [14]:
100 .* Table  # Table II in percent

4×3 Array{Float64,2}:
 4.1242   4.01653  3.88245
 4.10426  3.94775  3.75684
 4.06389  3.79971  3.4807
 3.94177  3.32917  2.59546