# Reinforcement Learning: Cake Eating Problem

In [None]:
import jax
import jax.numpy as jnp
import haiku as hk
from time import sleep
# Model primitives shared by the cells below.
γ = 2.0   # curvature exponent of the utility function U
β = 0.95  # per-period discount factor (applied as β**t)
x0 = 1    # presumably the initial stock of cake -- confirm against the model

def U(c):
    """Return the CRRA (isoelastic) utility of consuming ``c``.

    Uses the module-level curvature parameter ``γ``:
    ``c**(1 - γ) / (1 - γ)`` for ``γ != 1``, and ``log(c)`` in the limiting
    case ``γ == 1``, where the power form would divide by zero.

    Args:
        c: Consumption level(s); a scalar or an array.

    Returns:
        Utility with the same shape as ``c``.
    """
    if γ == 1:
        # Limit of the power form as γ -> 1 (up to an additive constant).
        return jnp.log(c)
    return c**(1 - γ) / (1 - γ)

In [32]:
# Text progress bar used to visualize training progress
def process_visualization(iteration, total):
    """Redraw a one-line, 20-segment text progress bar on stdout.

    Args:
        iteration: Zero-based index of the current step; the bar reports
            ``iteration + 1`` completed steps.
        total: Total number of steps. Must be non-zero.
    """
    done = iteration + 1
    filled = int(done * 20 / total)
    bar = '███' * filled + '   ' * (20 - filled)
    percent = float(done / total * 100)
    # '\r' returns to the start of the line and end='' suppresses the
    # newline, so successive calls overwrite the same line.
    print(f'\r[Progress]:[{bar}]{percent:.2f}%;', end='')
    # Short pause after each redraw.
    sleep(0.01)

In [33]:
import optax

# Optimizer factory; it is instantiated as ``optimizer(lr)`` wherever an
# optimizer object is needed.
optimizer = optax.adam
lr = 1e-3  # learning rate passed to the optimizer factory
T = 50 * 12  # number of periods simulated by the loss (600)
rng = jax.random.PRNGKey(0)  # fixed seed used for parameter initialization


def nnet(x):
    """Policy network: one 32-unit ReLU hidden layer and a single output.

    Haiku modules are created inline, so this function is meant to be wrapped
    with ``hk.transform`` rather than called directly.

    Args:
        x: Network input; it is stacked into a column before the first layer.

    Returns:
        The network output with singleton dimensions squeezed away.
    """
    features = jnp.column_stack([x])
    hidden = jax.nn.relu(hk.Linear(32)(features))
    output = hk.Linear(1)(hidden)
    return jnp.squeeze(output)

# Turn the Haiku function into a pure (init, apply) pair. NOTE: this rebinds
# `nnet` to the apply function, so from here on it is called as nnet(Θ, x).
init, nnet = hk.without_apply_rng(hk.transform(nnet))
# Initialize the network parameters from the seed, using a scalar sample input.
Θ = init(rng, jnp.array(1.))

# Build the optimizer state for the freshly initialized parameters.
opt_state = optimizer(lr).init(Θ)

@jax.jit
def L(Θ):
    """Return the negative discounted lifetime utility of the policy Θ.

    Starting from a cake of size ``x0``, the policy is rolled out for ``T``
    periods. Each period the network output is mapped to a consumption share
    in (0, 1), that share of the remaining cake is eaten, and the resulting
    utility is discounted by ``β**t``. The sign is flipped so the result can
    be minimized.

    Args:
        Θ: Parameter tree for the policy network, as produced by ``init``.

    Returns:
        Scalar loss: minus the sum over periods of ``β**t * U(c_t)``.
    """

    def core(state, inputs):
        """One period: consume a share of the cake, return (new cake, utility)."""
        t = inputs
        xt = state

        # `nnet` is the transformed apply function. The sigmoid keeps the
        # consumption share strictly between 0 and 1; the -4 offset lowers
        # the share for a given network output (sigmoid(-4) is about 0.018).
        ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
        ut = U(ct)
        xt = xt - ct
        discounted_utility = β**t * ut
        return xt, discounted_utility

    # Cast to float so the scan carry has the same dtype before and after a
    # step even though x0 is declared as an integer.
    initial_cake = float(x0)
    periods = jnp.arange(T)

    _, discounted_utility = jax.lax.scan(core, initial_cake, periods)
    return -discounted_utility.sum()


@jax.jit
def evaluation(Θ):
    """Return the discounted lifetime utility achieved by parameters Θ.

    ``L`` returns the negated value, so the sign is flipped back here.
    """
    loss = L(Θ)
    return -loss


@jax.jit
def update_gradient_descent(Θ, opt_state):
    """Apply one optimizer step to the loss ``L``.

    Args:
        Θ: Current network parameters.
        opt_state: Optimizer state matching ``Θ``.

    Returns:
        A ``(Θ, opt_state)`` tuple holding the updated parameters and the
        updated optimizer state.
    """
    gradients = jax.grad(L)(Θ)
    transform = optimizer(lr)
    deltas, next_state = transform.update(gradients, opt_state)
    return optax.apply_updates(Θ, deltas), next_state

In [38]:
# Training loop: run a fixed number of optimizer steps, redrawing the
# progress bar before each one.
total = 100
for iteration in range(total):
    process_visualization(iteration, total)
    Θ, opt_state = update_gradient_descent(Θ, opt_state)
# Move past the progress-bar line, then report the value reached by the
# trained parameters.
print("\n")
print(evaluation(Θ))

[Progress]:[████████████████████████████████████████████████████████████]100.00%;

-1559.7428
