\setcounter{secnumdepth}{0}

# Assignment 6 - Control Theory and System Identification

    Name: First, Last
    Student #: s...

In [None]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.random as jr

from typing import Callable
from jax.random import PRNGKey
from tqdm import tqdm

# Exercise 1: Linear Quadratic Gaussian control

Last week we implemented the LQR. However, under partial observability or noisy measurements, the controller can be made more robust by estimating the current state (also known as filtering). In this exercise, we will implement the linear quadratic Gaussian (LQG) controller, which extends the LQR with optimal filtering.

## Exercise 1.1

Extend the discrete-time stochastic double integrator from last week so that its observations are noisy and only the position is observed. You can either show the true state and the observations in separate plots, or plot the observations on top of the true state in the same plot.

## Solution 1.1

In [None]:
# --- Simulation grid ---
dt = 0.01  # step size of the discrete-time model
T = 10  # end of the simulated time window
ts = jnp.arange(0, T, dt)  # sample times 0, dt, ..., T - dt

# --- Noise levels and cost weights ---
var_x = 0.01  # process-noise variance; scaled by dt when building W below
var_y = 0.1  # presumably the observation-noise variance used to build V -- confirm
var0 = 0.1  # variance of the initial state; used for Sigma0 below
q = r = 0.5  # presumably the state (q) and control (r) cost weights -- confirm

# --- Double-integrator dynamics ---
A = jnp.array([[1, dt], [0, 1]])  # state-transition matrix
B = jnp.array([[0], [dt]])  # control-input matrix; nonzero only in the second row
W = dt*var_x*jnp.eye(2)  # process-noise covariance

# TODO (Exercise 1.1): observation model. Per the exercise text only the
# position is observed, and the observations are noisy.
C = None
V = None

# Mean and covariance of the initial state distribution.
mu0 = jnp.array([1.0,0.0])
Sigma0 = var0 * jnp.eye(2)

key = jr.PRNGKey(0)
key, subkey = jr.split(key)

# TODO: initial state (presumably sampled with `subkey` from N(mu0, Sigma0) -- confirm).
x_init = None

key, subkey = jr.split(key)
# TODO: initial observation (presumably a noisy observation of x_init using the fresh `subkey`).
y_init = None

def simulate_step(carry, t):
    """Advance the uncontrolled system by one step (step function for ``jax.lax.scan``).

    The noise, dynamics and observation expressions are ``None``
    placeholders that still have to be filled in for Exercise 1.1.

    Args:
        carry: Tuple ``(x, key)`` holding the current state and a PRNG key.
        t: Current element of the scanned time array; not used in the body.

    Returns:
        ``((next_x, key), (next_x, y))``: the carry for the next step and the
        per-step outputs that ``scan`` stacks along the leading axis.
    """
    x, key = carry
    # One key to carry forward plus two subkeys (presumably one for the
    # process noise and one for the observation noise -- confirm).
    key, subkey1, subkey2 = jr.split(key, 3)
    noise = None  # TODO: process-noise sample
    next_x = None  # TODO: next state
    y = None  # TODO: noisy observation
    return (next_x, key), (next_x, y)

# Scan over all but the last time point: each step emits the *next* state,
# and the initial state/observation are prepended afterwards.
init_carry = (x_init, key)
(_, _), (no_control_xs, no_control_ys) = jax.lax.scan(simulate_step, init_carry, ts[:-1])

# Prepend initial state
no_control_xs = jnp.vstack([x_init, no_control_xs])
no_control_ys = jnp.vstack([y_init, no_control_ys])

## Exercise 1.2

Compute the optimal control solution with the LQG for the partially observed discrete-time stochastic double integrator.

## Solution 1.2

In [None]:
## Control theory utilities
from scipy.linalg import solve_discrete_are

class LQG:
    """
    Discrete-time linear quadratic Gaussian (LQG) controller (skeleton).

    state equation:
    x_{n+1} = A x_n + B u_n + e_n with e_n ~ N(0, W)

    observation equation (presumably; C and V are only stored here -- confirm):
    y_n = C x_n + v_n with v_n ~ N(0, V)

    initial state x(0) ~ N(mu0, Sigma0)

    cost function
    J = sum_n [x_n' Q x_n + u_n' R u_n]

    The DARE solution ``S``, the feedback gain ``L`` and the bodies of
    ``control``, ``reward`` and ``kalman_filter`` are placeholders that
    still have to be implemented.
    """

    def __init__(self, A: jnp.ndarray, B: jnp.ndarray, W: jnp.ndarray, 
                 Q: jnp.ndarray, R: jnp.ndarray, C: jnp.ndarray, V: jnp.ndarray) -> None:
        """Store the model matrices and set up the (placeholder) controller terms.

        Args:
            A: State-transition matrix; its first dimension sets ``n_state``.
            B: Control-input matrix; its second dimension sets ``n_control``.
            W: Process-noise covariance.
            Q: State cost matrix.
            R: Control cost matrix.
            C: Observation matrix.
            V: Observation-noise covariance.
        """
        
        self.n_state = A.shape[0]
        self.n_control = B.shape[1]

        self.A = A
        self.B = B
        self.W = W

        self.C = C
        self.V = V

        self.Q = Q
        self.R = R

        # solution to the DARE
        # TODO: not computed yet (presumably via solve_discrete_are -- confirm).
        self.S = None

        # Feedback gain matrix
        # TODO: not computed yet.
        self.L = None

    def control(self, x):
        """Return the control signal for state (estimate) ``x``.

        Placeholder: the bare ``return`` currently yields ``None``.
        """
        return

    def reward(self, x, u):
        """Return the reward for state ``x`` and control ``u``.

        Placeholder: the bare ``return`` currently yields ``None``.
        """
        return
    
    def kalman_filter(self, mu, Sigma, y, uk=0.0):
        """
        Kalman Filter function for state estimation

        All intermediate quantities are ``None`` placeholders, so the
        function currently returns ``(None, None)``.
        
        Input
        mu: jnp.ndarray     current mean state estimate
        Sigma: jnp.ndarray  current uncertainty estimate
        y: jnp.ndarray      last observation
        uk: jnp.ndarray     last control signal (defaults to 0.0)
        
        Return
        muk: jnp.ndarray    updated mean state estimate
        Sigmak: jnp.ndarray updated uncertainty estimate
        
        """

        # KF prediction step

        # TODO: predicted mean (propagate mu through the dynamics with control uk)
        muk1 = None

        # TODO: predicted covariance
        Sigmak1 = None

        # KF update step
        r = None # innovation
        K = None # optimal Kalman gain
        
        # TODO: corrected mean and covariance after incorporating y
        muk = None
        Sigmak = None
        
        return muk, Sigmak
    
# TODO (Exercise 1.2): quadratic cost matrices of the LQG objective.
Q = None
R = None

lqg = LQG(A, B, W, Q, R, C, V)

# Restart from the same seed and follow the same split sequence as in
# Solution 1.1, so the controlled and uncontrolled runs are comparable.
key = jr.PRNGKey(0)
key, subkey = jr.split(key)
# TODO: initial state (presumably sampled with `subkey`).
x_init = None

# Split the carried `key`, never the `subkey` reserved for x_init: a key that
# is used for sampling must not be split or used again.
key, subkey = jr.split(key)
# TODO: initial observation and initial control signal.
y_init = None
u_init = None

def simulate_step(carry, t):
    """Advance the controlled system and its state estimate by one step.

    Step function for ``jax.lax.scan``. The control, dynamics, observation
    and filter expressions are ``None`` placeholders that still have to be
    filled in for Exercise 1.2.

    Args:
        carry: Tuple ``(x, mu, Sigma, key)`` with the true state, the current
            mean and covariance estimate, and a PRNG key.
        t: Current element of the scanned time array; not used in the body.

    Returns:
        ``((next_x, next_mu, next_Sigma, key), (next_x, y, u, next_mu, next_Sigma))``:
        the carry for the next step and the per-step outputs stacked by ``scan``.
    """
    x, mu, Sigma, key = carry
    # One key to carry forward plus two subkeys for this step's noise draws.
    key, subkey1, subkey2 = jr.split(key, 3)

    u = None  # TODO: control signal

    next_x = None  # TODO: next true state
    y = None  # TODO: noisy observation

    # TODO: filter update. The placeholder is a pair so that the tuple
    # unpacking itself does not raise before the step is implemented.
    next_mu, next_Sigma = None, None

    return (next_x, next_mu, next_Sigma, key), (next_x, y, u, next_mu, next_Sigma)

# The filter starts from the prior (mu0, Sigma0); the true state starts at x_init.
init_carry = (x_init, mu0, Sigma0, key)
(_, _, _, _), (control_xs, control_ys, control_us, control_mus, control_Sigmas) = jax.lax.scan(simulate_step, init_carry, ts[:-1])

# Prepend initial state
# NOTE(review): only xs, ys and us get an initial row; control_mus and
# control_Sigmas keep one row fewer -- confirm this is intended.
control_xs = jnp.vstack([x_init, control_xs])
control_ys = jnp.vstack([y_init, control_ys])
control_us = jnp.vstack([u_init, control_us])

## Exercise 1.3

Show the cumulative reward over time for different levels of observation noise.

## Solution 1.3

In [None]:
# YOUR ANSWER HERE

# Exercise 2: Simulating Continuous time RNNs

In this exercise, we will simulate a continuous-time Recurrent Neural Network (CTRNN) to analyze its behaviour over time. 

\begin{equation} f_{\text{CTRNN}}(\mathbf{x}) = \frac{1}{\tau} \cdot \left( -\mathbf{x} + W \sigma(\mathbf{x}) + b \right), \end{equation}

where: 
\begin{align*}
    W &= \begin{pmatrix} 4.5 & 1.0 \\ -1.0 & 4.5 \end{pmatrix}\\
    b &= \begin{pmatrix} -2.75 \\ -1.75  \end{pmatrix} \\
    \tau &= 1.0
\end{align*}

Use jax.vmap to compute the vectors in the phase plane. Plot several trajectories of the CTRNN with different initial conditions in the phase plane.

## Solution 2

In [None]:
# set parameters
w = None  # TODO: weight matrix W from the exercise statement
b = None  # TODO: bias vector b from the exercise statement
tau = None  # TODO: time constant tau from the exercise statement

dt = 0.1  # integration step size (rebinds the dt used in Exercise 1)
T = 10000  # number of time steps here, not an end time as in Exercise 1
ts = jnp.arange(0, T) * dt  # sample times 0, dt, ..., (T - 1) * dt

In [None]:
def sigmoid(x):
    """Sigmoid nonlinearity (the sigma in the CTRNN equation).

    Placeholder: the bare ``return`` currently yields ``None``.
    """
    return

def CTRNN(x, args):
    """Vector field of the continuous-time RNN.

    Per the exercise statement this should evaluate
    ``(1 / tau) * (-x + W sigma(x) + b)``.

    Args:
        x: Current network state.
        args: Pair ``(w, b)`` with the weight matrix and the bias.

    Returns:
        Placeholder: the bare ``return`` currently yields ``None``.
    """
    # NOTE(review): tau is not part of `args`; presumably it is read from the
    # enclosing scope -- confirm.
    w, b = args
    return

In [None]:
# Time grid for the trajectory simulations; this replaces the longer grid
# defined together with the parameters.
ts = jnp.arange(0, 100, dt)
solver = None  # TODO: numerical integration step (presumably Euler or RK4 -- confirm)
"""Simulate trajectories """


""" visualize trajectories and phase plane"""



# Exercise 3: System identification with neural networks

Besides simulating a neural network with a given set of weights, one can learn the weights such that the network, integrated over time, resembles observed data as closely as possible. In the next part, we will create a neural network that learns a dynamical system from a given set of noisy measurements.

## Exercise 3.1
i) Create a function 'initialize_mlp()' that initializes a set of network parameters according to the documentation. 

ii) Then, create a function 'network()' that takes a state x and a set of parameters as input, and returns the output of the neural network according to the documentation.

## Solution 3.1

In [None]:
### ANSWER:
def initialize_mlp(layer_sizes, key:PRNGKey, scale:float=1e-2):
    """Initialize the weights and biases of an MLP.

    Placeholder: the parameter list is never filled and the bare ``return``
    currently yields ``None``.

    Inputs:
        layer_sizes (tuple) Sizes of the network layers, including the input,
                            hidden and output layers: (input_dim, hidden_dim, ..., output_dim).
        key (PRNGKey) Random key used to draw the initial parameters.
        scale (float) Standard deviation of the initial weights and biases.

    Return: 
        params (list) List of (weights, biases) tuples, one per layer -
                      [ (weights_1, biases_1), ..., (weights_n, biases_n) ]
    """
    # NOTE(review): this creates 2 * len(layer_sizes) keys; presumably one key
    # per weight matrix and one per bias vector is intended -- confirm.
    keys = jr.split(key, 2*len(layer_sizes))
    params = []

    return 

def network(x, params):
    """Standard MLP forward pass.

    Placeholder: the bare ``return`` currently yields ``None``.

    Inputs:
        x (D,) Input state, where D is the dimensionality of the state observation.
        params (PyTree) Parameters of the network, represented as a PyTree
                        (presumably the output of ``initialize_mlp`` -- confirm).
    """
    return

## Exercise 3.2
iii) Initialize a network with input and output dimensions of 2 and two hidden layers of 32 hidden units each, using your newly created function 'initialize_mlp()'. 

iv) Initiate an Euler or RK4 step function from the previous exercise, but now instead give your 'network()' as the differential function 'f'.

v) Come up with a reasonable learning rate to train your model with, and a reasonable number of gradient descent steps.

## Solution 3.2

In [None]:
key, subkey = jr.split(jr.PRNGKey(0))

# Observation times and noisy measurements of the unknown system.
ts = jnp.load('ts.npy') # download this from Brightspace
ys = jnp.load('ys.npy') # download this from Brightspace
data = (ts, ys)

layer_sizes = None  # TODO: (input_dim, hidden_dim, ..., output_dim) as asked in Exercise 3.2
solver = None  # TODO: Euler or RK4 step function with `network` as the differential function
params = None  # TODO: presumably initialize_mlp(layer_sizes, subkey) -- confirm
# Refresh the keys so the subkey consumed above is not reused later.
key, subkey = jr.split(key)

In [None]:
import optax

# set training parameters
num_iters = None  # TODO: number of gradient-descent steps
lr = None  # TODO: learning rate passed to Adam below

# Optimization
optim = optax.adam(learning_rate=lr)
# The optimizer state is built from the current parameters.
state = optim.init(params)

## Exercise 3.3
Define a loss function for your optimization problem, that:

- uses the first observation as an initial condition for the system,
- solves the system using your numerical integration method and the scan function from the first exercise,
- returns the mean squared error loss between the predicted trajectories, and the target observations.

You can use the given code below to run gradient descent with the Adam optimizer based on this loss function. 

## Solution 3.3

In [None]:
def loss(params, data):
    """Mean squared error between the simulated trajectory and the observations.

    Per Exercise 3.3 this should use the first observation as the initial
    condition, integrate the network dynamics over ``ts`` and compare the
    result with ``ys``.

    Args:
        params: Network parameters being optimized.
        data: Pair ``(ts, ys)`` of observation times and observations.

    Returns:
        Placeholder: the bare ``return`` currently yields ``None``.
    """
    ts, ys = data
    return

In [None]:
# Optimisation. When running this multiple times, be sure to re-initialize your parameters. 
loss_vals = jnp.zeros((num_iters))
loss = jax.jit(loss)
# Build the value-and-gradient function once and jit it as a whole, instead
# of re-wrapping the loss on every iteration of the loop.
loss_and_grad = jax.jit(jax.value_and_grad(loss))
for i in tqdm(range(num_iters)):
    loss_val, loss_gradient = loss_and_grad(params, data)
    # Adam step: turn the gradient into parameter updates and apply them.
    updates, state = optim.update(loss_gradient, state, params)
    params = optax.apply_updates(params, updates)
    loss_vals = loss_vals.at[i].set(loss_val)

# Training curve: loss value per gradient-descent iteration.
plt.plot(jnp.arange(num_iters), loss_vals)

## Exercise 3.4
Visualize the predictions of your trained neural network and compare them with the observed data. In addition, simulate a slightly longer time horizon of 70 time units and comment on the forecast of the network: is it in line with what we can reasonably expect of the unknown system?

If the network predictions do not resemble the observations, consider changing the learning rate or number of training iterations during training, or inspecting the correctness of your loss function.

## Solution 3.4

In [None]:
# YOUR ANSWER HERE