In [None]:
using LinearAlgebra, Statistics
using Distributions, Plots, QuantEcon, Interpolations, Parameters
using NLsolve, Random
using Expectations
using DataFrames
using StatsBase
using JLD

In [None]:
"""
    originalmodel(; kwargs...)

Build the parameter, grid and storage bundle of the search model as a `NamedTuple`.

# Keywords
- `λ`, `α`, `β`, `τ`: scalar model parameters (see `one_step_update!` for how each one
  enters the Bellman updates).
- `w_min`, `w_max`, `w_grid_size`: the wage grid holds `w_grid_size` sorted random draws
  from `Normal(0.5, sqrt(0.1))` truncated to `[w_min, w_max]`.
- `I_grid_size`: number of benefit-income points, evenly spaced on `[2/15, 2.0]`.
  NOTE(review): only the default `I_grid_size = 15` gives the step `2/15` that
  `one_step_update!` hard-codes when mapping wage income to a benefit index.
- `h_min`, `h_max`, `h_grid_size`: evenly spaced `h` grid.
- `s_grid_size`: size of the effort grids `s0_grid` and `sb_grid`, both on `[0, 1]`.
- `μ_unemployed`, `μ_employed`, `μ_laidoff`: weight vectors over next-period `h` moves.
- `rng`: random number generator used to draw the wage grid. Pass a seeded RNG
  (e.g. `MersenneTwister(1234)`) for a reproducible grid. The default draws from the
  default global RNG, as before this keyword existed.

# Returns
A `NamedTuple` with the parameters, the grids, the wage-offer weights `F_probs`, the
value arrays `V0` (1×h), `Vb` (I×h), `V`, `V1`, `Vvb` (w×h), all initialised to 100.0,
and the policy arrays `policy_v0`, `policy_vb`, `policy_v`, all initialised to 0.5.
"""
function originalmodel(; λ = 0.009, α = 0.0009, β = 0.9985,
        τ = 0.0285,
        w_min = 0, w_max = 1.0, w_grid_size = 150, I_grid_size = 15,
        h_min = 1, h_max = 2, h_grid_size = 21, s_grid_size = 150,
        μ_unemployed = [0.8, 0.2], μ_employed = [0.9, 0.1], μ_laidoff = [1, 0],
        rng = Random.default_rng())

    dist_F = truncated(Normal(0.5, sqrt(0.1)), w_min, w_max)
    # The wage grid is random; `rng` is what makes it reproducible.
    w_grid = sort!(rand(rng, dist_F, w_grid_size))
    I_grid = range(2/15, 2.0, length = I_grid_size)
    h_grid = range(h_min, h_max, length = h_grid_size)
    s0_grid = range(0, 1.0, length = s_grid_size)
    sb_grid = range(0, 1.0, length = s_grid_size)

    # Mass of F on (w_min, w_1], (w_1, w_2], ..., (w_{n-1}, w_n]. The tail above w_n is
    # dropped and the weights are renormalised so that they sum to one.
    F_w = cdf.(dist_F, w_grid)
    F_p = diff([0.0; F_w])
    F_probs = F_p ./ sum(F_p)

    # Initial guesses for value functions and policies.
    V0 = fill(100.0, (1, h_grid_size))
    Vb = fill(100.0, (I_grid_size, h_grid_size))
    V = fill(100.0, (w_grid_size, h_grid_size))
    V1 = fill(100.0, (w_grid_size, h_grid_size))
    Vvb = fill(100.0, (w_grid_size, h_grid_size))
    policy_v0 = fill(0.5, (1, h_grid_size))
    policy_vb = fill(0.5, (I_grid_size, h_grid_size))
    policy_v = fill(0.5, (w_grid_size, h_grid_size))

    return (λ = λ, α = α, β = β, τ = τ, w_grid_size = w_grid_size, h_grid_size = h_grid_size,
        I_grid_size = I_grid_size, s_grid_size = s_grid_size,
        μ_unemployed = μ_unemployed, μ_employed = μ_employed, μ_laidoff = μ_laidoff,
        w_grid = w_grid, I_grid = I_grid, h_grid = h_grid, dist_F = dist_F,
        s0_grid = s0_grid, sb_grid = sb_grid, F_probs = F_probs, V0 = V0, Vb = Vb, V = V, V1 = V1,
        Vvb = Vvb, policy_v0 = policy_v0, policy_vb = policy_vb, policy_v = policy_v)
end

# Default model instance. Its `w_grid` is drawn at random inside `originalmodel`,
# so grids (and hence solutions) differ between sessions unless the RNG is seeded first.
ls = originalmodel()

"""
    c(x)

Linear cost of search effort `x`: return half of `x`. It enters the Bellman
updates as the `-c(s)` term.
"""
function c(x)
    return x * 0.5
end

"""
    pi(x)

Return `x^0.3`. This is used as the probability that search effort `x` yields a wage
offer: it weights the offer-draw term against the no-offer term in the Bellman
updates. For `x` in `[0, 1]` the result lies in `[0, 1]`, with `pi(0) == 0` and
`pi(1) == 1`.

NOTE(review): this definition shadows `Base.pi` (the constant π) in this module, and
it may fail if `pi` was already referenced here before it is defined. Renaming it
(e.g. to `offer_prob`) would require updating every caller.
"""
function pi(x)
    return x^0.3
end

"""
    a(x, n_h = 21)

Return `[h_new_u, h_new_e]`, the `h`-grid indices reached from index `x` by the
unemployed move (one step down) and the employed move (one step up). Moves are
clamped to `1:n_h`: the lowest index stays put when moving down, and the highest
stays put when moving up.

`n_h` defaults to 21, the default `h_grid_size` of `originalmodel`. Pass the actual
grid size when using a different `h` grid.

Throws `ArgumentError` if `x` is outside `1:n_h`.
"""
function a(x, n_h = 21)
    1 <= x <= n_h || throw(ArgumentError("h index $x is outside 1:$n_h"))
    h_new_u = max(x - 1, 1)
    h_new_e = min(x + 1, n_h)
    return [h_new_u, h_new_e]
end


In [None]:
# Expected value over next-period wage offers for a benefit recipient at h-index `jh`.
# The lowest `ip` grid wages are valued with `EVvb` and the remaining ones with `EV`,
# each weighted by `F_probs`. The cases ip == 0 and ip == length(F_probs) need no
# special branch because a sum over an empty range is exactly zero.
function _benefit_offer_value(EV, EVvb, F_probs, ip, jh)
    above = sum(EV[ip+1:end, jh] .* F_probs[ip+1:end])
    below = sum(EVvb[1:ip, jh] .* F_probs[1:ip])
    return above + below
end

"""
    one_step_update!(ls, EV0, EVb, EVvb, EV, EV1)

Apply one sweep of the Bellman updates. The results are written into the arrays
stored in `ls` (the `NamedTuple` built by `originalmodel`).

# Arguments
- `ls`: model bundle. Its arrays `V0`, `Vb`, `V`, `V1`, `Vvb`, `policy_v0`,
  `policy_vb` and `policy_v` are overwritten in place.
- `EV0`, indexed `[1, ih]`: continuation value without benefit, used in the `V0`
  effort search.
- `EVb`, indexed `[iI, ih]`: continuation value of a benefit recipient with income
  `I_grid[iI]`.
- `EVvb`, indexed `[iw, ih]`: value used for wage offers below the benefit threshold.
  This function builds `Vvb` as `max(EV1, EVb)`.
- `EV`, indexed `[iw, ih]`: value after a wage offer. This function builds `V` as
  `max(V1, V0)`.
- `EV1`, indexed `[iw, ih]`: continuation value used for `Vvb`.

NOTE(review): results depend on statement order if a caller passes `ls`'s own arrays
as the `EV*` arguments (aliasing). The update order below is kept exactly for that
reason.

# Returns
`(ls.V, ls.V0, ls.Vb, ls.policy_v0, ls.policy_vb, ls.policy_v)`.
"""
function one_step_update!(ls, EV0, EVb, EVvb, EV, EV1)
    @unpack λ, α, β, τ, μ_unemployed, μ_employed = ls
    @unpack w_grid, I_grid, h_grid, s0_grid, sb_grid, F_probs = ls
    @unpack V0, Vb, V, V1, Vvb, policy_v0, policy_vb, policy_v = ls

    disc = (1 - α) * β   # effective discount factor on every continuation value

    for ih in 1:ls.h_grid_size
        h = h_grid[ih]
        ihnu, ihne = a(ih)   # next h index after the unemployed / employed move
        h_nu = h_grid[ihnu]

        # --- Unemployed without benefit: grid search over effort s_0 ---------------
        # The offer-draw expectations do not depend on s_0, so compute them once.
        offer_h = sum(EV[:, ih] .* F_probs)
        offer_nu = sum(EV[:, ihnu] .* F_probs)
        v0_best = -Inf       # -Inf rather than a finite sentinel, so any finite value wins
        s0_best = NaN
        for is_0 in 1:ls.s_grid_size
            s_0 = s0_grid[is_0]
            p = pi(s_0)
            v00 = (1 - p) * EV0[1, ih] + p * offer_h       # case 1: h' = h
            v01 = (1 - p) * EV0[1, ihnu] + p * offer_nu    # case 2: h' = h_nu
            v0 = -c(s_0) + disc * sum(μ_unemployed .* [v00, v01])
            if v0 > v0_best
                v0_best = v0
                s0_best = s_0
            end
        end
        V0[1, ih] = v0_best
        policy_v0[1, ih] = s0_best

        for iw in 1:ls.w_grid_size
            w = w_grid[iw]
            # Benefit index tied to wage income w*h: the smallest k with k*(2/15) >= w*h.
            # NOTE(review): assumes the I_grid points are k*(2/15), which holds only
            # for originalmodel's default I_grid_size = 15. A wage income w*h above
            # I_grid[end] raises a BoundsError here.
            iI_wh = Int(ceil((w * h) / (2 / 15)))

            m = (1 - λ) * sum(μ_employed .* [EV[iw, ih], EV[iw, ihne]]) +
                λ * EVb[iI_wh, ih]
            v1 = (1 - τ) * w * h + disc * m

            V1[iw, ih] = v1
            V[iw, ih] = v1 > v0_best ? v1 : v0_best
            policy_v[iw, ih] = v1 > v0_best ? 1 : 0

            # --- Benefit recipients: grid search over effort sb for every income I --
            # NOTE(review): nothing below depends on `iw`, yet it runs once per wage
            # grid point, and only the pass with iw = w_grid_size survives. This is
            # pure redundancy unless the EV* arguments alias ls arrays. Hoist it out
            # of the iw loop once the caller's aliasing is confirmed.
            for iI in 1:ls.I_grid_size
                I = I_grid[iI]
                # NOTE(review): Vvb is 2-D, so it keeps only the value for the last
                # iI. The EVvb read below is therefore independent of the current
                # income I, although max(EV1, EVb[iI, ·]) depends on it. Confirm this
                # is intended.
                Vvb[iw, ih] = max(EV1[iw, ih], EVb[iI, ih])

                # Number of grid wages w with w*h' < 0.7*I, for h' = h and h' = h_nu.
                # Those offers are valued with EVvb and the rest with EV.
                ip = count(<(0.7 * I / h), w_grid)
                ipp = count(<(0.7 * I / h_nu), w_grid)
                offer1 = _benefit_offer_value(EV, EVvb, F_probs, ip, ih)
                offer2 = _benefit_offer_value(EV, EVvb, F_probs, ipp, ihnu)
                benefit = (1 - τ) * 0.7 * I

                vb_best = -Inf
                sb_best = NaN
                for isb in 1:ls.s_grid_size
                    sb = sb_grid[isb]
                    p = pi(sb)
                    vb1 = (1 - p) * EVb[iI, ih] + p * offer1      # case 1: h' = h
                    vb2 = (1 - p) * EVb[iI, ihnu] + p * offer2    # case 2: h' = h_nu
                    vb = -c(sb) + benefit + disc * sum(μ_unemployed .* [vb1, vb2])
                    if vb > vb_best
                        vb_best = vb
                        sb_best = sb
                    end
                end

                Vb[iI, ih] = vb_best
                Vvb[iw, ih] = max(EV1[iw, ih], EVb[iI, ih])
                policy_vb[iI, ih] = sb_best
            end
        end
    end
    return ls.V, ls.V0, ls.Vb, ls.policy_v0, ls.policy_vb, ls.policy_v
end
