In [59]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
import quantecon as qe

import matplotlib.pyplot as plt
import seaborn as sns

from collections import namedtuple

In [60]:
# JAX works in 32-bit floats unless told otherwise; switch every array
# created from here on to 64-bit for extra numerical precision.
jax.config.update("jax_enable_x64", True)

In [61]:
# Lightweight immutable containers for the model:
#   Parameters – structural parameters
#   Grids      – state grids plus Monte Carlo draws
#   Model      – bundles a Parameters and a Grids instance
# NOTE(review): `Parameters` is rebound later in this notebook with a
# different field list ('θ', 'cf', 'β', 'tol', 'f'), which shadows this one.
Parameters = namedtuple(
    'Parameters',
    ('β', 'θ', 'cf', 'ce', 'A'),
)
Grids = namedtuple(
    'Grids',
    ('sgrid', 'ngrid', 'Edraws', 'Sdraws'),
)
Model = namedtuple(
    'Model',
    ('params', 'grids'),
)

In [3]:
# Convergence tolerances for the different solution stages:
#   tol1 – precision in finding the equilibrium price
#   tol2 – precision in the first-stage value function
#   tol3 – precision in the second-stage value function
#   tol4 – precision in the stationary distribution
tol1, tol2, tol3, tol4 = 0.05, 0.01, 0.01, 1e-4

In [7]:
# Employment grid: `nnum` points spaced evenly in logs, from 1 up to 5000.
nnum = 250
nmin, nmax = jnp.log(1), jnp.log(5000)   # log of min. / max. employment
ngrid = jnp.exp(jnp.linspace(nmin, nmax, nnum))

In [8]:
# Parameters of the firm problem.
# NOTE(review): the interpretations below follow common notation and are
# not stated anywhere in this notebook — confirm before relying on them.
θ = 0.64                 # presumably the curvature of the production function
β = 0.8                  # presumably the discount factor
cf = 17                  # presumably a per-period fixed cost
ce = -8.26953813830592   # presumably the entry cost; origin of this value not shown
f = 0                    # meaning not documented here

In [9]:
A = 1.2459  # demand parameter

In [10]:
# Parameters of the productivity-shock process.
# The shock grid has `snum` points spaced evenly in logs between log(1/0.64)
# and log(32). ρ (persistence) and σ (innovation std. dev.) enter the
# transition matrix, and `means` puts the process's long-run mean a
# fraction `smean` of the way up the grid.
snum = 20
smin, smax = jnp.log(1/0.64), jnp.log(32)
sgrid = jnp.linspace(smin, smax, snum)
ρ, σ = 0.93, 0.26
smean = 0.208
means = smin + smean * (smax - smin)

In [13]:
# Distribution for entrants: uniform over the first `νnum` points of the
# s grid, with zero mass on every remaining point.
νnum = 8
νgrid = jnp.zeros_like(sgrid).at[:νnum].set(jnp.full((νnum,), 1 / νnum))

In [17]:
# Solution choices

# Toggle for how the equilibrium price is obtained.
# If True, the price is solved for endogenously (the notebook later calls a
# Newton-Raphson search, with `p_init` presumably the starting guess).
# If False, the price is held fixed at the exogenous value `p_exog`.
find_p = True
p_exog = 1
p_init = 1

# Toggle for an exogenous exit rule.
# If True, firms exit whenever their shock is one of the lowest `s_exit`
# points of the s grid (grid index < s_exit), regardless of employment.
# `exexog` (shape (snum, nnum)) encodes that rule: 1 marks an exit, 0 a stay.
exog_exit = False
s_exit = 6
if not 0 <= s_exit <= snum:
    # Fail with a clear message instead of an opaque shape-mismatch error.
    raise ValueError(f"s_exit must lie in [0, {snum}], got {s_exit}")
exexog = jnp.zeros((snum, nnum))
exexog = exexog.at[:s_exit, :].set(1.0)

In [42]:
# Make the transition matrix with a Tauchen-style discretization.
# trans[i, j] is the probability of moving from sgrid[i] (row = old state)
# to sgrid[j] (column = new state). It equals the mass of
# N(intercept + ρ·sgrid[i], σ²) between the midpoints bracketing sgrid[j];
# the first and last columns absorb the lower and upper tails.
intercept = (1 - ρ) * means
midpoints = jnp.linspace((sgrid[0] + sgrid[1])/2, (sgrid[-2] + sgrid[-1])/2, snum-1)
displacements = midpoints[None, :] - intercept - ρ * sgrid[:, None]
probs = norm.cdf(displacements/σ)
# Pad the CDF values with 0 on the left and 1 on the right, then take
# successive differences to get the per-cell probabilities.
trans = jnp.diff(
    jnp.hstack([jnp.zeros((snum, 1)), probs, jnp.ones((snum, 1))]),
    axis=1,
)

In [44]:
# Approximate the stationary distribution of the shock chain.
# Start from the uniform distribution and iterate state <- state @ trans
# until successive iterates differ by less than `stat_tol` in the sup norm.
# NOTE(review): `tol4` is documented as the stationary-distribution
# tolerance but is not used here — confirm which threshold is intended.
stat_tol = 1e-5        # sup-norm convergence threshold
stat_max_iter = 1000   # cap on the number of iterations

state = jnp.ones_like(sgrid) / snum
error = jnp.inf

for i in range(stat_max_iter):
    state_new = state @ trans
    error = jnp.max(jnp.abs(state_new - state))
    state = state_new
    if error < stat_tol:
        print(f'Convergence achieved at loop {i}')
        break
else:
    # Without this branch a non-converged `state` would be used silently.
    print(f'Warning: no convergence after {stat_max_iter} loops '
          f'(error = {float(error):.2e})')

Convergence achieved at loop 76


In [33]:
# Parameter container used below together with the value-function tolerances.
# NOTE(review): this rebinds `Parameters`, shadowing the earlier definition
# whose fields are ('β', 'θ', 'cf', 'ce', 'A'). Any later positional call
# written for that layout will silently map values to the wrong fields;
# consider a distinct name.
Parameters = namedtuple(
    'Parameters',
    ('θ', 'cf', 'β', 'tol', 'f'),
)

In [34]:
# Two parameter bundles that differ only in their convergence tolerance:
# `params` uses tol2 and `params_fs` uses tol3.
# NOTE(review): tol2 is documented as the first-stage tolerance and tol3 as
# the second-stage one. If the `_fs` suffix means "first stage", these look
# swapped (harmless while both equal 0.01) — confirm.
params = Parameters(θ=θ, cf=cf, β=β, tol=tol2, f=f)
params_fs = params._replace(tol=tol3)

In [55]:
# n20, n100 and n500 are the indices of the largest employment-grid points
# not exceeding 20, 100 and 500 workers; they mark the firm-size cutoffs.
# Zeroing out the points above a cutoff and taking argmax picks the last
# point at or below it. This relies on `ngrid` being positive and
# increasing, which holds because it is the exp of a linspace.
n20, n100, n500 = (
    jnp.argmax(ngrid * (ngrid <= cutoff)) for cutoff in (20, 100, 500)
)

In [None]:
# Do you need to find the price?

def compute_p_star(params, grids, p_init, tol=tol1):

    
if find_p:
    Ve = 10
    # Use Newton-Raphson method to find p*.