In [2]:
pip install -q optax dm-haiku

[K     |████████████████████████████████| 154 kB 4.1 MB/s 
[K     |████████████████████████████████| 352 kB 8.7 MB/s 
[K     |████████████████████████████████| 85 kB 1.9 MB/s 
[?25h

In [3]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


Solving for the optimal consumption policy and generating 1 million sample paths of expected sum of discounted rewards (value function) resulting from that policy, assuming savings are fully invested in the stock market, so the evolution of wealth is now stochastic.

In [None]:
γ = 2.
β = 0.95

def stock_return(rng):
  μs = 0.06
  σs = 0.2
  ε = jax.random.normal(rng, ())
  log_return = μs + σs * ε
  return jnp.exp(log_return)

def U(c):
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50


def nnet(x):
  X = jnp.column_stack([x])
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(1)(X)
  X = jnp.squeeze(X)
  return X


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))

opt_state = optimizer(lr).init(Θ)

def L(Θ, rng):

  x = 1.
  G = 0.

  state = x
  inputs = jnp.arange(T)
##########################################################

# creating a vector of returns for different time

  def ret(rng, inputs):
    rng, _ = jax.random.split(rng)
    R = stock_return(rng)
    return rng, R
  rng, R = jax.lax.scan(ret, rng, inputs) 
##########################################################
    
  def core(state, inputs):
    t = inputs
    xt = state
    r = R[inputs-1]# selecting a randomly generated return 
    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
    ut = U(ct)
    savings = xt - ct
    x_tp1 = (r) * savings # generating new state with savings growing at rate r

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  x, discounted_utility = jax.lax.scan(core, state, inputs)
  G = discounted_utility.sum()
  return -G


@jax.jit
def evaluation(Θ, rng):
  return -L(Θ, rng)


@jax.jit
def update_gradient_descent(Θ, opt_state, rng):
  rng, _ = jax.random.split(rng)
  grad = jax.grad(L)(Θ,rng)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, rng


for iteration in range(1000000):
  rng = jax.random.PRNGKey(0)
  Θ, opt_state, rng = update_gradient_descent(Θ, opt_state, rng)
  

  if iteration % 1000 == 0:
    print(evaluation(Θ, rng))

-768.27893
-443.21417
-442.118
-439.13965
-437.8162
-437.30887
-437.05707
-436.6582
-436.6275
-436.28903
