In [1]:
import jax
import jax.numpy as jnp
import optax
from flax import nnx

from types import SimpleNamespace
import numpy as np

In [2]:
import flax

# record the Flax version in the output: the NNX API changes between releases
flax_version = flax.__version__
print("Flax version:", flax_version)

Flax version: 0.12.0


## 1. Create neural network

In [3]:
class Policy(nnx.Module):
  """Feed-forward policy network: ReLU hidden layers followed by a sigmoid output.

  Layers are stored as attributes ``layer1`` ... ``layer{len(neurons)}`` plus
  ``layer_out`` so that NNX tracks their parameters.

  Args:
    din: number of input features.
    dout: number of outputs (each squashed into (0, 1) by a sigmoid).
    neurons: sizes of the hidden layers, e.g. ``[100, 100]``; must be non-empty.
    rngs: NNX random-number streams used to initialise the weights.

  Raises:
    ValueError: if ``neurons`` is empty.
  """

  def __init__(self, din: int, dout: int, neurons: list, rngs: nnx.Rngs):

    # copy into a tuple: __call__ loops over this sequence, so it must not change
    # if the caller later mutates the list it passed in (e.g. train.neurons)
    neurons = tuple(neurons)
    if len(neurons) == 0:
      raise ValueError("Policy needs at least one hidden layer (neurons must be non-empty)")

    # 1st layer: input -> first hidden layer
    self.layer1 = nnx.Linear(din, neurons[0], rngs=rngs)

    # remaining hidden layers: layer{k+2} maps neurons[k] -> neurons[k+1]
    for k, (n_in, n_out) in enumerate(zip(neurons[:-1], neurons[1:])):
      setattr(self, f"layer{k+2}", nnx.Linear(n_in, n_out, rngs=rngs))

    # last layer: last hidden layer -> output
    self.layer_out = nnx.Linear(neurons[-1], dout, rngs=rngs)

    self.din, self.dout, self.hidden_layers = din, dout, neurons

  def __call__(self, x: jax.Array):
    """Evaluate the network on features ``x`` (last axis of size ``din``)."""

    # 1st + hidden layers, each followed by a ReLU
    for i in range(len(self.hidden_layers)):
      layer = getattr(self, f"layer{i+1}")
      x = nnx.relu(layer(x))

    # output layer, squashed into (0, 1)
    return jax.nn.sigmoid(self.layer_out(x))

In [None]:
def setup_nn(model):
  """Build the policy network for ``model``.

  Inputs are the ``par.Nstates`` state variables plus ``par.T`` one-hot time
  dummies; outputs are the ``par.Nactions`` actions. Hidden-layer sizes come
  from ``model.train.neurons``; weights are initialised from seed 0.
  """

  n_inputs = model.par.Nstates + model.par.T
  n_outputs = model.par.Nactions
  hidden = model.train.neurons

  return Policy(n_inputs, n_outputs, hidden, rngs=nnx.Rngs(params=0))

def eval_policy(model,nn,x,t):
  """Evaluate the policy network at states ``x`` in period ``t``.

  The states are augmented with a one-hot encoding of ``t`` (length
  ``model.par.T``) before being passed to ``nn``.

  Args:
    model: object exposing ``par.T`` (number of periods).
    nn: callable network taking the augmented ``(..., Nstates + T)`` array.
    x: states, shape ``(..., Nstates)``; any number of leading (batch) axes.
    t: period index (Python int or a traced integer scalar under jit/scan).

  Returns:
    ``nn`` applied to the concatenation of ``x`` and the time dummies.
  """

  T = model.par.T

  # one-hot time dummies, broadcast over every leading axis of x
  time_dummies = jax.nn.one_hot(t, T)           # shape (T,)
  time_dummies = jnp.broadcast_to(time_dummies, x.shape[:-1] + (T,))

  # concatenate states and time dummies along the feature axis
  x = jnp.concatenate((x, time_dummies), axis=-1)

  return nn(x)

In [5]:
# directly evaluate the Policy network on a single all-ones input
hidden_layers = [5, 5, 10]
din, dout = 10, 1

model = Policy(din, dout, hidden_layers, rngs=nnx.Rngs(params=0))
model(x=jnp.ones((1, din)))

Array([[0.5598746]], dtype=float32)

## 2. Create neural network in model

In [6]:
class Model:
    """Skeleton for a finite-horizon model whose policy is a neural network.

    ``__init__`` creates three namespaces: ``par`` (model parameters),
    ``train`` (training settings) and ``sim`` (simulation arrays). It then
    calls the hooks ``setup``, ``allocate``, ``setup_train`` and
    ``allocate_train`` in that order. The hooks defined here are no-op
    placeholders, meant to be replaced by assigning functions to the class
    (e.g. ``Model.setup = setup``).
    """

    def __init__(self):

        self.par = SimpleNamespace()
        self.train = SimpleNamespace()
        self.sim = SimpleNamespace()

        self.dtype = jnp.float32
        self.device = jax.devices("cpu")[0]

        self.setup() # Setup model parameters
        self.allocate() # Allocate model objects
        self.setup_train()  # Setup training parameters
        self.allocate_train() # Allocate training objects

    # Placeholder hooks, to be overwritten. They accept and ignore any
    # arguments so that calling one before it is replaced does not raise a
    # TypeError (the real implementations take e.g. (model, N) or
    # (model, states, actions, outcomes, ...)).

    # Setup and allocate
    def setup(self, *args, **kwargs): pass
    def allocate(self, *args, **kwargs): pass
    def setup_train(self, *args, **kwargs): pass
    def allocate_train(self, *args, **kwargs): pass

    # Draw
    def draw_initial_states(self, *args, **kwargs): pass
    def draw_shocks(self, *args, **kwargs): pass

    # Transition
    def state_trans(self, *args, **kwargs): pass  # states, actions, outcomes, shocks -> next-period states

    # Outcomes and reward
    def outcomes(self, *args, **kwargs): pass
    def reward(self, *args, **kwargs): pass  # Utility

In [7]:
def setup(self,full=None):
    """ choose parameters """

    par = self.par
    sim = self.sim
    
    # a. model
    par.T = 5 # number of periods

    # preferences
    par.beta = 1/1.01 # discount factor

    # income
    par.kappa_base = 1.0
    par.rho_p= 0.95 # shock, persistenc
    par.sigma_xi = 0.1 # shock, permanent
    par.sigma_psi = 0.1 # shock, transitory std

    # return
    par.R = 1.01 # gross return

    # b. solver settings
    par.Nstates = 2 # number of states variables
    par.Nactions = 1 # number of actions variables
    par.Nshocks = 2 # number of shocks

    # c. simulation 
    sim.N = 50_000 # number of agents

    # initial states
    par.mu_m0 = 1.0 # initial cash-on-hand, mean
    par.sigma_m0 = 0.1 # initial cash-on-hand, std

    # initial permanent income
    par.mu_p0 = 1.0 # initial durable, mean
    par.sigma_p0 = 0.1 # initial durable, std


    sim.N = 10_000 # number of simulated agents

def setup_train(model):
    """Fill ``model.train`` with the default training settings."""

    defaults = dict(
        neurons=[100,100],          # number of neurons in hidden layers
        N=3000,                     # number of agents for training
        seed=0,                     # seed for numpy's global RNG
        learning_rate_policy=1e-3,  # learning rate for policy
    )
    for name, value in defaults.items():
        setattr(model.train, name, value)

def allocate_train(model):
    """Create the policy network and its optimizer on ``model``.

    Sets ``model.Policy_NN`` (via ``setup_nn``) first, because
    ``model.create_opt()`` reads it.
    """

    # a. neural network
    model.Policy_NN = setup_nn(model) # policy neural network
    model.create_opt() # create optimizer for policy neural network

def create_opt(model):
    """Create the Adam optimizer for the policy network.

    Stores an ``nnx.ModelAndOptimizer`` wrapping ``model.Policy_NN`` in
    ``model.policy_opt``, using learning rate ``model.train.learning_rate_policy``.
    """

    lr = model.train.learning_rate_policy

    model.policy_opt = nnx.ModelAndOptimizer(model.Policy_NN, optax.adam(learning_rate=lr))


Model.allocate_train = allocate_train

In [8]:
# attach the model-specific setup functions to the Model class
Model.setup, Model.setup_train = setup, setup_train
Model.allocate_train, Model.create_opt = allocate_train, create_opt

# build a model; __init__ runs setup -> allocate -> setup_train -> allocate_train
model = Model()

# evaluate the policy at a constant state (all entries 20) in period t=1
x = jnp.full((1, model.par.Nstates), 20.0)
eval_policy(model, model.Policy_NN, x, t=1)

Array([[0.9248638]], dtype=float32)

In [9]:
# jit-compiled eval_policy. `model` (argument 0) is static, so it must be hashable;
# Model does not define __hash__/__eq__, so it is keyed by object identity.
# NOTE(review): a compiled version is reused for the same Model object, so later
# changes to model.par (e.g. par.T) are not picked up.
eval_policy_jit = jax.jit(eval_policy, static_argnums=(0,))

eval_policy_jit(model,model.Policy_NN,x,t=1)

Array([[0.9248638]], dtype=float32)

## 3. Allocate

In [10]:
def allocate(model):
    """Allocate zero-initialised simulation arrays in ``model.sim``.

    Shapes, with T = par.T and N = sim.N:
        states   (T, N, Nstates)
        shocks   (T, N, Nshocks)
        actions  (T, N, Nactions)
        reward   (T, N)
        outcomes (T, N, 1)     -- consumption only
    ``sim.R`` (average discounted utility) starts as NaN. Arrays use
    ``model.dtype`` and ``model.device``, like the draw functions.
    """

    # unpack
    par = model.par
    sim = model.sim
    T, N = par.T, sim.N

    def zeros(*shape):
        # one place for the model's dtype/device
        return jnp.zeros(shape, dtype=model.dtype, device=model.device)

    # simulation arrays
    sim.states = zeros(T, N, par.Nstates) # State-vector
    sim.shocks = zeros(T, N, par.Nshocks) # Shock-vector
    sim.actions = zeros(T, N, par.Nactions)  # actions array
    sim.reward = zeros(T, N) # array for utility rewards
    sim.outcomes = zeros(T, N, 1) # array for outcomes - will just be consumption
    sim.R = jnp.nan # initialize average discounted utility

# install allocate as the Model.allocate hook (called from Model.__init__)
Model.allocate = allocate

def draw_initial_states(model,N):
    """ draw initial state (m,p,t) """

    par = model.par
    train = model.train

    sigma_m0 = par.sigma_m0
    sigma_p0 = par.sigma_p0

    # a. draw cash-on-hand:	
    m0 = par.mu_m0*np.exp(np.random.normal(-0.5*sigma_m0**2,sigma_m0,size=(N,)))
    m0 = jnp.array(m0,dtype=model.dtype,device=model.device)
    
    # b. draw permanent income
    p0 = par.mu_p0*np.exp(np.random.normal(-0.5*sigma_p0**2,sigma_p0,size=(N,)))
    p0 = jnp.array(p0,dtype=model.dtype,device=model.device)

    # c. store
    return jnp.stack((m0,p0),axis=1)

# install as a method: model.draw_initial_states(N) -> (N, 2) array of (m0, p0)
Model.draw_initial_states = draw_initial_states

def draw_shocks(model,N):
    """ draw shocks """

    par = model.par

    # xi 
    xi_loc = -0.5*par.sigma_xi**2
    xi = np.exp(np.random.normal(xi_loc,par.sigma_xi,size=(par.T,N,)))
    xi = jnp.array(xi,dtype=model.dtype,device=model.device)

    # psi
    psi_loc = -0.5*par.sigma_psi**2
    psi = np.exp(np.random.normal(psi_loc,par.sigma_psi,size=(par.T,N,)))
    psi = jnp.array(psi,dtype=model.dtype,device=model.device)


    return jnp.stack((xi,psi),axis=-1)

# install as a method: model.draw_shocks(N) -> (T, N, 2) array of (xi, psi)
Model.draw_shocks = draw_shocks


In [11]:
# draw 10 initial (m0, p0) states; numpy's RNG is not seeded here, so values vary between runs
model.draw_initial_states(N=10)

Array([[0.8558104 , 0.8778358 ],
       [1.0766381 , 1.019488  ],
       [0.9568096 , 0.93885803],
       [1.1297044 , 0.9903299 ],
       [1.0857869 , 1.0999461 ],
       [1.0134012 , 0.8934043 ],
       [0.91807693, 0.9612195 ],
       [1.3216639 , 1.004609  ],
       [1.0485314 , 0.81817067],
       [0.96520007, 1.0397788 ]], dtype=float32)

## 4. Simulate

In [39]:
def outcomes(model,states,actions,t=None):
    """Outcomes from states and actions; here only consumption.

    Consumption is ``c = m * (1 - a)``, with m = ``states[..., 0]`` (cash-on-hand)
    and a = ``actions[..., 0]`` (savings rate). ``model`` and ``t`` are unused.

    Returns:
        c with a trailing axis of size 1 (one outcome).
    """

    cash_on_hand = states[..., 0]
    savings_rate = actions[..., 0]

    consumption = cash_on_hand * (1 - savings_rate)
    return jnp.expand_dims(consumption, -1)

# install as a method: model.outcomes(states, actions, t=...)
Model.outcomes = outcomes

def state_trans(model,states,actions,outcomes,shocks,t=None):
    """Transition from current states to next-period states.

    With m = cash-on-hand, p = permanent income, c = consumption and shocks
    xi (permanent) and psi (transitory):

        p_plus = p**rho_p * xi
        m_plus = R * (m - c) + kappa_base * p_plus * psi

    ``actions`` and ``t`` are unused.

    Returns:
        (m_plus, p_plus) stacked along the last axis.
    """

    par = model.par
    m, p = states[..., 0], states[..., 1]
    xi, psi = shocks[..., 0], shocks[..., 1]
    c = outcomes[..., 0]

    savings = m - c                         # post-decision assets
    p_plus = p**par.rho_p * xi              # next-period permanent income
    income = par.kappa_base * p_plus * psi  # next-period income
    m_plus = par.R * savings + income       # next-period cash-on-hand

    return jnp.stack((m_plus, p_plus), axis=-1)

# install as a method: model.state_trans(states, actions, outcomes, shocks, t=...)
Model.state_trans = state_trans

def utility(par,c):
    """Log utility of consumption ``c``; ``par`` is unused."""
    log_c = jnp.log(c)
    return log_c

def reward(model,states,actions,outcomes,t0=0,t=None):
    """Per-period reward: utility of consumption ``outcomes[..., 0]``.

    ``states``, ``actions``, ``t0`` and ``t`` are unused. Note that a period
    passed as the 5th positional argument fills ``t0``, not ``t``.
    """

    consumption = outcomes[..., 0]
    return utility(model.par, consumption)

# install as a method: model.reward(states, actions, outcomes, t=...)
Model.reward = reward

def simulate(model):
    """Simulate ``sim.N`` agents for ``par.T`` periods with the current policy.

    Initial states and shocks come from the model's draw functions. The move
    from period t to t+1 uses ``shocks[t+1]``, so ``shocks[0]`` is unused. The
    simulated paths are written back to ``model.sim`` (states, shocks, actions,
    outcomes, reward), and the average discounted utility goes to ``model.sim.R``.

    Returns:
        R = (1/N) * sum over agents and periods of beta**t * reward_t.
    """

    # a. unpack
    par = model.par
    sim = model.sim
    T, N = par.T, sim.N

    states = sim.states
    actions = sim.actions
    rewards = sim.reward
    outcomes = sim.outcomes

    # b. draw initial states and shocks
    states = states.at[0].set(model.draw_initial_states(N=N))
    shocks = model.draw_shocks(N=N)

    # c. simulate forward
    for t in range(T):

        # i. compute actions
        actions = actions.at[t].set(eval_policy(model, model.Policy_NN, states[t], t=t))

        # ii. compute outcomes
        outcomes = outcomes.at[t].set(model.outcomes(states[t], actions[t], t=t))

        # iii. compute rewards
        rewards = rewards.at[t].set(model.reward(states[t], actions[t], outcomes[t], t=t))

        # iv. transition
        if t < T - 1:
            states = states.at[t + 1].set(model.state_trans(states[t], actions[t], outcomes[t], shocks[t + 1], t=t))

    # d. compute discounted utility
    discount_factor = par.beta ** jnp.arange(T, dtype=model.dtype)  # (T,)
    R = jnp.sum(discount_factor[:, None] * rewards) / N

    # e. store results: JAX arrays are immutable, so the .at[].set calls above only
    # rebound local names; without this, model.sim would keep the zeros and sim.R NaN
    sim.states, sim.shocks, sim.actions = states, shocks, actions
    sim.outcomes, sim.reward, sim.R = outcomes, rewards, R

    return R

# install as a method: model.simulate() -> average discounted utility
Model.simulate = simulate

def simulate_loss(model,policy_NN,initial_states,shocks):
    """Objective for policy optimization: minus the average discounted utility.

    Simulates every agent in ``initial_states`` (shape (N, Nstates)) for ``par.T``
    periods with ``jax.lax.scan``. The transition out of period t uses ``shocks[t]``.
    The loss is ``-(1/N) * sum over agents and periods of beta**t * reward_t``.

    NOTE(review): a later cell redefines ``simulate_loss``; this version is kept
    correct so the notebook works whichever definition ran last.
    """

    # a. unpack
    par = model.par
    T = par.T

    # initial states and shocks are data, not parameters; stop_gradient returns
    # a new array, so the result must be assigned to have any effect
    initial_states = jax.lax.stop_gradient(initial_states)
    shocks = jax.lax.stop_gradient(shocks)

    # b. one simulation period: policy -> outcomes -> reward -> next states
    def scan_step(states_t, t):

        actions_t = eval_policy(model, policy_NN, states_t, t)
        outcomes_t = model.outcomes(states_t, actions_t, t=t)
        reward_t = model.reward(states_t, actions_t, outcomes_t, t=t)
        next_states = model.state_trans(states_t, actions_t, outcomes_t, shocks[t], t=t)

        return next_states, reward_t

    # c. simulate all periods; rewards stacks the per-period rewards along axis 0
    _, rewards = jax.lax.scan(scan_step, initial_states, jnp.arange(T))

    # d. discounted utility averaged over the N simulated agents
    discount_factor = par.beta ** jnp.arange(T, dtype=rewards.dtype)
    R = jnp.sum(discount_factor[:, None] * rewards) / initial_states.shape[0]
    loss = -R # - because we minimize negative reward

    return loss

In [49]:
def simulate_loss(model, policy_NN, initial_states, shocks):
    """Objective for policy optimization (jit/scan version): minus average discounted utility.

    Args:
        model: supplies ``par`` (``T``, ``beta``) and the outcomes/reward/state_trans methods.
        policy_NN: policy network; the argument differentiated during training.
        initial_states: initial states, shape (N, Nstates).
        shocks: shocks with leading axis of length ``par.T``; ``shocks[t]`` drives
            the transition out of period t.

    Returns:
        scalar ``-mean over agents of sum_t beta**t * reward_t``.
    """

    par = model.par
    T = par.T

    def scan_step(states_t, xs):
        """One period: policy -> outcomes -> reward -> next-period states."""
        t, shocks_t = xs

        actions_t = eval_policy_jit(model, policy_NN, states_t, t)
        outcomes_t = model.outcomes(states_t, actions_t, t=t)
        # pass t by keyword: reward's 5th positional parameter is t0, not t
        reward_t = model.reward(states_t, actions_t, outcomes_t, t=t)
        next_states = model.state_trans(states_t, actions_t, outcomes_t, shocks_t, t=t)

        return next_states, reward_t

    # scan over periods, feeding (t, shocks[t]) as the per-step input rather than
    # indexing shocks with a traced t inside the step
    _, reward_seq = jax.lax.scan(scan_step, initial_states, (jnp.arange(T), shocks))

    # discount each period, sum over periods, average over agents
    discounts = par.beta ** jnp.arange(T)
    discounted_sum = jnp.sum(discounts[:, None] * reward_seq, axis=0)

    return -jnp.mean(discounted_sum)


In [50]:
# register both simulation routines on the class, then simulate a fresh model
Model.simulate, Model.simulate_loss = simulate, simulate_loss

model = Model()
model.simulate()

Array(-1.4446996, dtype=float32)

In [57]:
def policy_update_step(model,loss_fn):
    """Take one optimizer step on the policy network.

    Differentiates ``loss_fn`` with respect to ``model.Policy_NN`` and passes the
    gradients to ``model.policy_opt.update``.

    Returns:
        the loss value, evaluated before the update.
    """

    optimizer = model.policy_opt

    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(model.Policy_NN)

    optimizer.update(grads)
    return loss

def train_policy(model, K=100):
    """Train the policy network with the simulation approach.

    Seeds numpy's global RNG with ``model.train.seed``. Each iteration then draws
    fresh initial states and shocks for ``train.N`` agents, evaluates the
    simulated loss and takes one optimizer step.

    Args:
        model: model with ``train``, ``Policy_NN``, ``policy_opt`` and draw functions.
        K: number of iterations to train policy; progress is printed every 10 iterations.
    """

    train = model.train

    np.random.seed(train.seed)

    # jit once, outside the loop: wrapping with jax.jit on every iteration creates
    # a new jitted callable each time and can trigger re-tracing
    simulate_loss_jit = jax.jit(simulate_loss, static_argnums=(0,))

    # training loop
    for k in range(K):

        # i. draw initial states and shocks
        initial_states = model.draw_initial_states(N=train.N)
        shocks = model.draw_shocks(N=train.N)

        # ii. loss as a function of the network only (called immediately below,
        # so capturing the loop variables is safe)
        loss_fn = lambda nn: simulate_loss_jit(model, nn, initial_states, shocks)

        # iii. update policy parameters
        loss = policy_update_step(model,loss_fn)

        # iv. print progress
        if k % 10 == 0:
            print(f"Iteration {k}: Loss {loss.item()}")

# install as a method: model.train_policy(K=...) runs K optimizer steps
Model.train_policy = train_policy

In [58]:
# fresh model: setup_nn re-initialises the policy network (seed 0) before training
model = Model()

In [59]:
# 1000 training iterations; the loss is printed every 10 iterations
model.train_policy(K=1000)

Iteration 0: Loss 1.468570590019226
Iteration 10: Loss 0.319608598947525
Iteration 20: Loss 0.18654939532279968
Iteration 30: Loss 0.12268534302711487
Iteration 40: Loss 0.10773950070142746
Iteration 50: Loss 0.09452074021100998
Iteration 60: Loss 0.09294816851615906
Iteration 70: Loss 0.09860627353191376
Iteration 80: Loss 0.0831630602478981
Iteration 90: Loss 0.08901738375425339
Iteration 100: Loss 0.07833657413721085
Iteration 110: Loss 0.08447569608688354
Iteration 120: Loss 0.08388619124889374
Iteration 130: Loss 0.08444564789533615
Iteration 140: Loss 0.0844244733452797
Iteration 150: Loss 0.07709135860204697
Iteration 160: Loss 0.0725313350558281
Iteration 170: Loss 0.10389026999473572
Iteration 180: Loss 0.09708646684885025
Iteration 190: Loss 0.09744958579540253
Iteration 200: Loss 0.09046687930822372
Iteration 210: Loss 0.09070216864347458
Iteration 220: Loss 0.08374987542629242
Iteration 230: Loss 0.0709710344672203
Iteration 240: Loss 0.07275847345590591
Iteration 250: Loss