In [1]:
import numpy as np
from numpy import radians, pi
from lib.cartpolesystem import CartPoleSystem
from lib.cartpoleenv import CartPoleEnv
from lib.colors import Colors
from lib.direct_collocation import DirectCollocation
import time
from lib.controllers import LQR
from time import perf_counter
from matplotlib import pyplot as plt
from scipy.io import savemat

In [2]:
# Shared numeric constants: `dt` and `g` are handed to both the system and the
# environment below.
dt = 0.01
g = 9.81

# `n` matches the number of entries in the list given to CartPoleSystem and
# sizes the square noise matrix as (2 + 2*n) x (2 + 2*n).
n = 2
noise_dim = 2 + 2 * n
# All-zero matrix: every noise entry passed to the system is 0.
system_noise = np.zeros((noise_dim, noise_dim))

# NOTE(review): the meaning of each positional field is defined by
# CartPoleSystem's constructor, which is not visible in this notebook.
# Presumably the first tuple describes the cart, the second the motor, and the
# list holds one tuple per pole -- confirm against lib/cartpolesystem.
cart_spec = (0.0, 0.5, 0.005, -0.8, 0.8, Colors.red)
motor_spec = (0.2, 8.7e-5, 8.7e-5, 0.02, 0.05, 2400.0, Colors.black)
pole_specs = [
    (0, 0.2, 0.2, 0.001, Colors.green),
    (0, 0.1, 0.2, 0.001, Colors.blue),
]

system = CartPoleSystem(
    cart_spec,
    motor_spec,
    pole_specs,
    g,
    dt,
    "rk4",
    "nonlinear",
    system_noise,
)

env = CartPoleEnv(system, dt, g)

# Last expression of the cell: shows the observation-space shape as cell output.
env.observation_space.shape

(6, 1)

In [3]:
# Build the planner from the system dynamics and the environment's state and
# action spaces (dimension taken from shape[0], bounds from low/high).
# NOTE(review): the leading 7 and the trailing 0.1 are positional settings
# whose meaning is defined by DirectCollocation -- confirm against
# lib/direct_collocation before tuning them.
obs_space = env.observation_space
act_space = env.action_space

direct_collocation = DirectCollocation(
    7,
    system.differentiate,
    obs_space.shape[0],
    act_space.shape[0],
    obs_space.low,
    obs_space.high,
    act_space.low,
    act_space.high,
    0.1,
)

# Initial state as a 6x1 column vector; the third and fifth entries are
# 180 degrees converted to radians.
# NOTE(review): presumably ordered [cart position, cart velocity, then
# (angle, angular velocity) per pole] -- confirm against the environment.
x0 = np.array([
    [-0.1],
    [0],
    [radians(180)],
    [0],
    [radians(180)],
    [0],
])

# Reference/target state as a 6x1 column vector: 0.1 in the first entry,
# zeros elsewhere.
r = np.array([
    [0.1],
    [0],
    [0],
    [0],
    [0],
    [0],
])


In [4]:
# Closed-loop run: repeatedly re-plan from the latest observation, apply the
# first `follow_curve` planned controls in real time, then re-plan again.
last_update = perf_counter()

obs, _ = env.reset(x0)

time_longer = 10   # total simulated duration of the run (in the units of dt)
# round() instead of int(): float division such as 0.3 / 0.1 lands just below
# the integer, and truncation would then drop a step.
N_longer = round(time_longer / dt)

end_time = 1       # time at which the shrinking planning horizon runs out
follow_curve = 1   # number of planned controls applied before re-planning

# The planner must always receive a strictly positive horizon. Once `end_time`
# has passed, `end_time - elapsed` becomes zero and then negative, so clamp to
# the shortest horizon that still covers the steps executed per plan. For every
# iteration before `end_time` this leaves the horizon unchanged.
min_horizon = follow_curve * dt

try:
    # `i` counts simulation steps already executed, so `i * dt` is the elapsed
    # simulated time even when more than one step is applied per plan.
    for i in range(0, N_longer, follow_curve):
        time_left = max(end_time - i * dt, min_horizon)
        state, control = direct_collocation.make_controller(time_left, None, obs, r)

        for k in range(follow_curve):
            # Busy-wait until `dt` of wall-clock time has passed since the
            # previous step; if planning already took longer, continue at once.
            while perf_counter() < last_update + dt:
                pass
            last_update = perf_counter()

            # k-th sample of the first control row, as a column vector.
            u = np.vstack([control[0][k]])

            # NOTE(review): `done` is not acted on, so the run continues even
            # if the environment reports termination -- confirm this is intended.
            obs, reward, done, msg, _ = env.step(u)
            env.render()
finally:
    # Always release the environment, even if the solver raises or the kernel
    # is interrupted mid-run.
    env.close()