In [1]:
pip install -q optax dm-haiku

Note: you may need to restart the kernel to use updated packages.


Due Date: December 7th
# Problem statement

Consider the stochastic cake-eating problem you solved in the previous assignment. Suppose that, instead of investing your wealth entirely in the stock market, you have the option to allocate a fraction $\alpha$ of your wealth to the stock market, while the remainder is invested in a risk-free savings account that pays a gross return of 1.04 (i.e., a 4% net return). Notice that $\alpha$ is bounded below by 0 and bounded above by 1.

Solve for the optimal consumption ($c$) and asset allocation ($\alpha$).

- Print the average sum of discounted rewards (utilities) using 1 million simulations.
 
- Plot the average consumption-wealth ratio ($c / x$) for each time period $t=0, 1, ..., 49$

- Plot the average asset allocation in the risky asset ($\alpha$) for each time period $t=0, 1, ..., 49$

Hint: Starting from the code of the previous assignment, the modifications you have to implement are minimal. Namely:

- The output of the neural network should now be a 2-dimensional vector, corresponding to the consumption-wealth ratio ($c / x$) and $\alpha$, respectively.


In [3]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


# Preference parameters shared by the utility function and the objective.
γ = 2.  # exponent parameter of the utility U(c) = c**(1 - γ) / (1 - γ)
β = 0.95  # per-period discount factor: period-t utility is weighted by β**t

def stock_return(rng, n, μs=0.06, σs=0.2):
    """Draw ``n`` independent log-normal gross stock returns.

    Each return is ``exp(μs + σs * ε)`` with ``ε`` a standard normal draw.

    Args:
        rng: JAX PRNG key consumed by this call.
        n: number of returns to draw.
        μs: mean of the log return (default 0.06).
        σs: standard deviation of the log return (default 0.2).

    Returns:
        Array of shape ``(n,)`` with strictly positive gross returns.
    """
    # Sample with shape (n,) directly: drawing (1, n) and squeezing would
    # collapse the result to a 0-d scalar when n == 1.
    ε = jax.random.normal(rng, (n,))
    log_return = μs + σs * ε
    return jnp.exp(log_return)

def U(c, gamma=None):
    """CRRA (isoelastic) utility of consumption.

    Args:
        c: consumption; a scalar or array (the operation is element-wise).
        gamma: coefficient of relative risk aversion. Defaults to the
            module-level ``γ``. It is compared with ``==``, so it must be a
            concrete Python/NumPy scalar rather than a traced value.

    Returns:
        ``c**(1 - gamma) / (1 - gamma)``, or ``log(c)`` when ``gamma == 1``
        (the standard gamma = 1 member of the CRRA family; the power form
        would divide by zero there).
    """
    if gamma is None:
        gamma = γ
    if gamma == 1:
        return jnp.log(c)
    return c**(1 - gamma) / (1 - gamma)


# Training configuration.
optimizer = optax.adam  # optimizer factory; instantiated as optimizer(lr) at each use
lr = 1e-3  # learning rate handed to the optimizer factory
T = 50  # number of simulated periods, t = 0, ..., T - 1
batch_size = 256  # number of simulated paths per gradient step


def nnet(x):
    """Policy network: one tanh hidden layer of width 32 with a single linear output.

    Must run inside a Haiku transform because it instantiates ``hk.Linear``
    modules.

    Args:
        x: network input; it is stacked into a single column before the first
            layer, so a 1-D input contributes one row per element.

    Returns:
        The output of the final 1-unit linear layer with all size-1 axes
        squeezed away.
    """
    features = jnp.column_stack([x])
    hidden = jnp.tanh(hk.Linear(32)(features))
    output = hk.Linear(1)(hidden)
    return jnp.squeeze(output)


# Convert the Haiku module function into a pure (init, apply) pair.
# NOTE: `nnet` is rebound to the apply function here, so from this point on it
# is called as nnet(Θ, x) rather than nnet(x).
init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)  # fixed seed for the PRNG key used below
Θ = init(rng, jnp.array(1.))  # initialise the parameters using a dummy scalar input


# Optimizer state matching the freshly initialised parameters.
opt_state = optimizer(lr).init(Θ)


def L(Θ, n, rng):
    """Loss: negative Monte Carlo estimate of the discounted sum of utilities.

    Simulates ``n`` wealth paths over ``T`` periods, all starting from wealth 1.
    In each period the network chooses the share of wealth to consume, the
    remainder is multiplied by a freshly drawn stock return, and the
    cross-path mean of the discounted utility is recorded.

    Args:
        Θ: network parameters passed to ``nnet``.
        n: number of simulated paths.
        rng: PRNG key; split into one key per period.

    Returns:
        Minus the sum over periods of the mean discounted utility.
    """
    period_keys = jax.random.split(rng, T)  # one independent key per period
    initial_wealth = jnp.ones(n)

    def step(wealth, period):
        t, key = period

        # Consumption share lies in (0, 1): sigmoid of the network output shifted by -4.
        consumption = jax.nn.sigmoid(nnet(Θ, wealth) - 4.) * wealth
        utility = U(consumption)
        saved = wealth - consumption

        # Everything that is not consumed earns the stock return.
        next_wealth = stock_return(key, n) * saved

        # Emit the cross-path average of this period's discounted utility.
        return next_wealth, (β**t * utility).mean()

    _, per_period = jax.lax.scan(step, initial_wealth, (jnp.arange(T), period_keys))
    return -per_period.sum()


@jax.jit
def evaluation(Θ):
    """Estimate the discounted sum of utilities for parameters ``Θ``.

    Evaluates ``L`` on one million simulated paths using the module-level
    ``rng`` and flips the sign, since ``L`` returns the negative of the
    objective.

    NOTE(review): because this function is jitted, the global ``rng`` is
    presumably fixed at the first trace — confirm if ``rng`` changes later.
    """
    n_paths = 1_000_000  # evaluate on 1 million paths
    loss = L(Θ, n_paths, rng)
    return -loss


@jax.jit
def update_gradient_descent(Θ, opt_state, rng):
    """Perform one optimizer step on the loss ``L``.

    Args:
        Θ: current network parameters.
        opt_state: optimizer state matching ``Θ``.
        rng: PRNG key carried between steps.

    Returns:
        Tuple ``(Θ, opt_state, rng)`` with the updated parameters, the updated
        optimizer state, and the key to pass to the next step.
    """
    # Split into a carry key (returned for the next step) and a subkey consumed
    # by this step only. Using the same key both for sampling and as the parent
    # of the next split would reuse randomness across steps.
    rng, step_rng = jax.random.split(rng)
    grad = jax.grad(L)(Θ, batch_size, step_rng)
    updates, opt_state = optimizer(lr).update(grad, opt_state)
    Θ = optax.apply_updates(Θ, updates)
    return Θ, opt_state, rng


# Run 10000 optimizer steps, threading the parameters, optimizer state and
# PRNG key through every call.
for iteration in range(10000):
    Θ, opt_state, rng = update_gradient_descent(Θ, opt_state, rng)

# Report the objective for the trained parameters (evaluation negates the loss L).
print(evaluation(Θ))

-434.0201
