In [1]:
pip install - q optax dm-haiku

Note: you may need to restart the kernel to use updated packages.


Due Date: November 30
# Problem statement

The code below is similar to the Cake Eating problem code we implemented in class. The differences are:
- Each time interval corresponds to one year (instead of one month)
- The consumption policy function is written as a simple sigle-layer neural network, with tanh activation (instead of the usual relu)

We will interpret the size of the cake as being total wealth, and cake consumption as general consumption. The fraction of wealth not consumed today are the *savings* (line 51). The dynamics of wealth are described by line 54. That line is equivalent to assuming that your savings are invested in a risk-free savings account that pays 0 interest, and therefore has a gross return of 1, denoted by *R* (line 53).




In [2]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


γ = 2.
β = 0.95


def U(c):
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50


def nnet(x):
    X = jnp.column_stack([x])
    X = hk.Linear(32)(X)
    X = jnp.tanh(X)
    X = hk.Linear(1)(X)
    X = jnp.squeeze(X)
    return X


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))


opt_state = optimizer(lr).init(Θ)


def L(Θ):

    x = 1.
    G = 0.

    state = x
    inputs = jnp.arange(T)

    def core(state, inputs):
        t = inputs
        xt = state

        ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
        ut = U(ct)
        savings = xt - ct

        R = 1.
        x_tp1 = R * savings

        discounted_utility = β**t * ut
        return x_tp1, discounted_utility

    x, discounted_utility = jax.lax.scan(core, state, inputs)
    G = discounted_utility.sum()
    return -G


@jax.jit
def evaluation(Θ):
    return -L(Θ)


@jax.jit
def update_gradient_descent(Θ, opt_state):
    grad = jax.grad(L)(Θ)
    updates, opt_state = optimizer(lr).update(grad, opt_state)
    Θ = optax.apply_updates(Θ, updates)
    return Θ, opt_state


for iteration in range(1000000):
    Θ, opt_state = update_gradient_descent(Θ, opt_state)

    if iteration % 1000 == 0:
        print(evaluation(Θ))

-1301.506
-869.2983
-846.7147
-839.0068
-833.3893
-830.0215
-828.1758
-826.8768
-825.7528
-824.72284
-823.7802
-822.95996
-822.3214
-821.6879
-821.20917
-820.8489
-820.5056
-820.31714
-819.926
-819.6896
-819.4854
-819.2804
-819.0966
-818.9267
-818.761
-818.6187
-818.4889
-818.3689
-818.206
-818.06696
-817.9734
-817.84827
-817.7638
-817.6399
-817.5612
-817.517
-817.4094
-817.3505
-817.2761
-817.21906
-817.12494
-817.0376
-817.01245
-816.95465
-816.87646
-816.87427
-816.78265
-816.7363
-816.7178
-816.6643
-816.6433
-816.6158
-816.58417
-816.5563
-816.5157
-816.437
-816.4616
-816.41644
-816.34863
-816.3405
-816.2927
-816.2726
-816.25287
-816.21985
-816.20056
-816.203
-816.2161
-816.1862
-816.14526
-816.09863
-816.0674
-816.0504
-816.0377
-816.0652
-816.00604
-815.97614
-815.9579
-815.9514
-815.96045
-815.952
-815.8906
-815.88293
-815.9252
-815.8686
-815.82983
-815.8142
-815.81305
-815.8284
-815.77563
-815.7582
-815.8479
-815.78906
-815.7189
-815.73206
-815.6999
-815.68097
-815.6881
-815.7

-814.50446
-814.5043
-814.5055
-814.505
-814.50476
-814.5044
-814.504
-814.5044
-814.5031
-814.5039
-814.50415
-814.5041
-814.5037
-814.503
-814.5042
-814.5032
-814.5036
-814.5032
-814.50256
-814.50275
-814.503
-814.50354
-814.50323
-814.5035
-814.5028
-814.50305
-814.50354
-814.50195
-814.5025
-814.5032
-814.50165
-814.5025
-814.5023
-814.5017
-814.5024
-814.5023
-814.5021
-814.50116
-814.501
-814.5022
-814.5012
-814.502
-814.50165
-814.50116
-814.5012
-814.50085
-814.5016
-814.5006
-814.50085
-814.501
-814.50024
-814.50085
-814.5004
-814.50073
-814.5
-814.5
-814.4998
-814.5006
-814.5004
-814.49994
-814.4999
-814.5001
-814.50024
-814.4999
-814.5005
-814.4999
-814.4994
-814.5002
-814.4994
-814.4995
-814.4995
-814.5001
-814.4989
-814.4994
-814.49884
-814.4993
-814.4988
-814.4991
-814.4994
-814.49963
-814.49994
-814.4988
-814.49805
-814.4989
-814.49817
-814.4984
-814.4984
-814.4983
-814.4984
-814.4981
-814.49835
-814.49866
-814.4979
-814.4977
-814.49866
-814.49854
-814.4974
-814.49725
-8

Suppose now that your savings are fully invested in the stock market, so the evolution of wealth is now stochastic. The stock market gross return is modeled by the function below:

In [3]:
def stock_return(rng, n):
    μs = 0.06
    σs = 0.2
    ε = jax.random.normal(rng, (1, n))
    log_return = μs + σs * ε
    return jnp.exp(log_return).squeeze()

Write a code to solve for the optimal consumption policy in this environment. 
What is the expceted sum of discounted rewards (value function) resulting from that policy? Use at least 1 million sample paths to estimate that number.

In [4]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


γ = 2.
β = 0.95


def U(c):
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50
batch_size = 256


def nnet(x):
    X = jnp.column_stack([x])
    X = hk.Linear(32)(X)
    X = jnp.tanh(X)
    X = hk.Linear(1)(X)
    X = jnp.squeeze(X)
    return X


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))


opt_state = optimizer(lr).init(Θ)


def L(Θ, n, rng):

    x = jnp.ones(n)  # simulate x "n" times
    G = 0.

    # adding rng_vector
    rng_vector = jax.random.split(rng, T)  # create vector or rng here!

    state = x
    inputs = jnp.arange(T), rng_vector

    def core(state, inputs):
        t, rng = inputs
        xt = state

        ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
        ut = U(ct)
        savings = xt - ct

        stk_ret = stock_return(rng, n)
        # change this line (just add log return instead of R)
        x_tp1 = stk_ret * savings

        state = x_tp1
        discounted_utility = β**t * ut
        # expectation discounted utility of the simulation??
        return state, discounted_utility.mean()

    _, discounted_utility = jax.lax.scan(core, state, inputs)
    G = discounted_utility.sum()
    return -G


@jax.jit
def evaluation(Θ):
    return -L(Θ, 1000000, rng)  # eveluate at 1 mil path


@jax.jit
def update_gradient_descent(Θ, opt_state, rng):
    rng, _ = jax.random.split(rng)  # get new rng here
    grad = jax.grad(L)(Θ, batch_size, rng)
    updates, opt_state = optimizer(lr).update(grad, opt_state)
    Θ = optax.apply_updates(Θ, updates)
    return Θ, opt_state, rng


for iteration in range(10000):
    Θ, opt_state, rng = update_gradient_descent(Θ, opt_state, rng)

print(evaluation(Θ))

-434.0201
