# PBO against LQR
## Define environment

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

from pbo.environment.linear_quadratic import LinearQuadraticEnv

# Bounds of the uniform distribution from which every LQR matrix entry is drawn.
low_value = -10
high_value = 10
# Limits forwarded to the environment as `max_pos` / `max_action`.
state_range = 10
action_range = 10
# Scalar wrapped into a 1-element array and passed as `initial_state`.
initial_state = 4

# Draw the five 1x1 matrices in the order A, B, Q, R, S so that the global
# NumPy random stream is consumed in the same order as separate calls would.
lqr_matrices = {
    name: np.random.uniform(low_value, high_value, size=(1, 1))
    for name in ("A", "B", "Q", "R", "S")
}

env = LinearQuadraticEnv(
    **lqr_matrices,
    max_pos=state_range,
    max_action=action_range,
    initial_state=np.array([initial_state]),
)

## Data collection

In [2]:
from pbo.data_collection.replay_buffer import ReplayBuffer


# Number of transitions to collect with a uniformly random policy.
n_samples = 100

replay_buffer = ReplayBuffer()

state = env.reset()
terminal = False
idx_sample = 0

while idx_sample < n_samples:
    # Sample one action uniformly in [-action_range, action_range).
    action = np.random.uniform(-action_range, action_range, size=1)
    next_state, reward, terminal, _ = env.step(action)

    # Store the transition (state, action, reward, next_state); the reward is
    # wrapped in a 1-element array before being added.
    replay_buffer.add(state, action, np.array([reward]), next_state)

    if terminal:
        # Episode ended: the next transition must start from the reset state.
        # (Previously this value was immediately overwritten by `next_state`,
        # so the stored `state` did not match the environment's real state.)
        state = env.reset()
        terminal = False
    else:
        state = next_state

    idx_sample += 1

In [3]:
from pbo.data_collection.dataloader import DataLoader


# Batch size handed to the DataLoader as its second argument.
batch_size = 8
    
# Wrap the collected replay buffer; the training cell iterates over this
# object with `for batch in data_loader`.
data_loader = DataLoader(replay_buffer, batch_size)

## Building networks

In [4]:
from pbo.agents.q_networks import QFullyConnectedNet
from pbo.agents.pbo_networks import LinearPBONet

# Discretisation resolutions. Only the action one is used in the code visible
# here; NOTE(review): `n_discretisation_step_state` is presumably consumed by a
# later cell (e.g. plotting) -- confirm, otherwise it is unused.
n_discretisation_step_state = 100
n_discretisation_step_action = 100

# Hyper-parameters forwarded to the network constructors below.
layer_dimension = 3
random_range = 10
# NOTE(review): presumably the discount factor of the control problem -- confirm.
gamma = 0.99

# Q-function approximator; it receives the same `action_range` that was given
# to the environment and used for random action sampling.
Q_network = QFullyConnectedNet(
    layer_dimension=layer_dimension,
    random_range=random_range,
    action_range=action_range,
    n_discretisation_step_action=n_discretisation_step_action,
)
# The PBO network is sized from the Q-network's `q_weights_dimensions`.
PBO = LinearPBONet(gamma=gamma, q_weights_dimensions=Q_network.q_weights_dimensions)


## Training

In [6]:
# Number of outer training iterations (each one is a full pass over data_loader).
n_iteration = 10


for iteration in range(n_iteration):
    # Fresh random Q-weights are drawn once per iteration and reused for every
    # batch of that iteration.
    random_weights = Q_network.get_random_weights()

    for batch in data_loader:
        PBO.learn_on_batch(batch, random_weights, Q_network)
    
    # monitor the training with:
    # after each iteration, load PBO's current fixed point into the Q-network
    # so that the Q-network can be inspected.
    Q_network.set_weights(PBO.get_fixed_point())