In [None]:
from matplotlib.pylab import *
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.lines import Line2D
π = pi

In [None]:
style.use(['dark_background', 'bmh'])
#%matplotlib notebook
%matplotlib inline

Car-trailer diagram (inverted image `res/car-trailer-k.png` available as well):
![car-trailer](res/car-trailer-w.png)

Car-trailer equation:
\begin{align}
\dot x &= s \cos \theta_0 \\
\dot y &= s \sin \theta_0 \\
\dot \theta_0 &= \frac{s}{L} \tan \phi \\
\dot \theta_1 &= \frac{s}{d_1} \sin(\theta_0 - \theta_1)
\end{align}
where $s$ is the signed speed, $\phi$ is the negative steering angle, $L$ is the car length, and $d_1$ is the trailer length.

In [None]:
ENVIRONMENT_BBOX = [0, 40, -10, 10]  # [x_min, x_max, y_min, y_max]
STEERING_ANGLE_RANGE_rad = pi / 4  # steering amplitude, in radians


class Truck:
    """Kinematic simulator of a car towing (or backing up) a trailer.

    The state is made of the hitch point ``(x, y)`` (the trailer is drawn
    behind it, the car ahead of it), the car heading ``θ0``, the trailer
    heading ``θ1`` and the current steering angle ``ϕ``. All angles are in
    radians.

    When ``display`` is true a matplotlib figure is created and the truck is
    redrawn by ``draw()``; otherwise every drawing call is a no-op.
    """

    def __init__(self, display=False):
        self.W = 1  # car and trailer width, for drawing only
        self.L = 1 * self.W  # car length
        self.d = 4 * self.L  # d_1
        self.s = -0.1  # speed
        self.display = display

        self.box = ENVIRONMENT_BBOX
        if self.display:
            self.f = figure(figsize=(10, 5), num='The truck backer-upper', facecolor='none')
            self.ax = self.f.add_axes([0.01, 0.01, 0.98, 0.98], facecolor='black')
            self.patches = list()

            self.ax.axis('equal')
            b = self.box
            self.ax.axis([b[0] - 1, b[1], b[2], b[3]])
            self.ax.set_xticks([]); self.ax.set_yticks([])
            self.ax.axhline(); self.ax.axvline()

        self.reset()

    def reset(self, ϕ=0):
        """Place the truck in a random, valid configuration.

        ``ϕ`` is stored as the current steering angle. Configurations are
        re-sampled in a loop (rather than recursively) until ``valid()``
        accepts one, so the truck is drawn exactly once and a long run of
        rejected samples cannot exhaust the recursion limit.
        """
        self.ϕ = ϕ  # car initial steering angle

        # self.θ0 = deg2rad(30)  # car initial direction
        # self.θ1 = deg2rad(-30)  # trailer initial direction
        # self.x, self.y = 20, -5  # initial car coordinates

        while True:
            self.θ0 = random() * 2 * π  # 0 <= ϑ₀ < 2π
            self.θ1 = (random() - 0.5) * π / 2 + self.θ0  # -π/4 <= ϑ₁ - ϑ₀ < π/4
            self.x = (random() * .75 + 0.25) * self.box[1]
            self.y = (random() - 0.5) * (self.box[3] - self.box[2])
            # If poorly initialised, sample again
            if self.valid():
                break

        # Draw, if display is True
        if self.display:
            self.draw()

    def step(self, ϕ=0, dt=1):
        """Advance the simulation by ``dt`` using steering angle ``ϕ``.

        Returns the new state tuple (same layout as ``state()``), or ``None``
        (after printing a message) if the truck is already jackknifed or off
        screen, in which case the state is left untouched.
        """
        # Check for illegal conditions
        if self.is_jackknifed():
            print('The truck is jackknifed!')
            return

        if self.is_offscreen():
            print('The car or trailer is off screen')
            return

        self.ϕ = ϕ
        # Snapshot the pre-update values so every equation below uses the
        # state at the beginning of the step.
        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()

        # Perform state update
        self.x += s * cos(θ0) * dt
        self.y += s * sin(θ0) * dt
        self.θ0 += s / L * tan(ϕ) * dt
        self.θ1 += s / d * sin(θ0 - θ1) * dt

        return self.state()

    def state(self):
        """Return ``(x, y, θ0, trailer_x, trailer_y, θ1)``."""
        return (self.x, self.y, self.θ0, *self._traler_xy(), self.θ1)

    def _get_atributes(self):
        """Return ``(x, y, W, L, d, s, θ0, θ1, ϕ)`` for convenient unpacking."""
        return (
            self.x, self.y, self.W, self.L, self.d, self.s,
            self.θ0, self.θ1, self.ϕ
        )

    def _traler_xy(self):
        """Return the trailer reference point, ``d`` behind the hitch along ``θ1``."""
        return self.x - self.d * cos(self.θ1), self.y - self.d * sin(self.θ1)

    def is_jackknifed(self):
        """True when car and trailer headings differ by more than 90 degrees."""
        return abs(self.θ0 - self.θ1) * 180 / π > 90

    def is_offscreen(self):
        """True when the car front or the trailer reference point leaves the box."""
        # Point 1.5 car lengths ahead of the hitch, along the car heading
        x1 = self.x + 1.5 * self.L * cos(self.θ0)
        y1 = self.y + 1.5 * self.L * sin(self.θ0)
        x2, y2 = self._traler_xy()

        b = self.box
        return not (
            b[0] <= x1 <= b[1] and b[2] <= y1 <= b[3] and
            b[0] <= x2 <= b[1] and b[2] <= y2 <= b[3]
        )

    def valid(self):
        """True when the truck is neither jackknifed nor off screen."""
        return not self.is_jackknifed() and not self.is_offscreen()

    def draw(self):
        """Redraw car and trailer; does nothing when ``display`` is false."""
        if not self.display: return
        if self.patches: self.clear()
        self._draw_car()
        self._draw_trailer()
        self.f.canvas.draw()

    def clear(self):
        """Remove every artist added by the previous ``draw()`` call."""
        for p in self.patches:
            p.remove()
        self.patches = list()

    def _draw_car(self):
        """Add the tow bar, the car body and the two steered wheels to the axes."""
        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()
        ax = self.ax

        # Tow bar from the hitch to the rear of the car body
        x1, y1 = x + L / 2 * cos(θ0), y + L / 2 * sin(θ0)
        bar = Line2D((x, x1), (y, y1), lw=5, color='C2', alpha=0.8)
        ax.add_line(bar)

        # Axis-aligned rectangle rotated by θ0 around the rear-centre point
        car = Rectangle(
            (x1, y1 - W / 2), L, W, angle=0, color='C2', alpha=0.8, transform=
            matplotlib.transforms.Affine2D().rotate_deg_around(x1, y1, θ0 * 180 / π) +
            ax.transData
        )
        ax.add_patch(car)

        # Wheels are short segments centred on points at ±45° from the
        # heading and oriented along the steering direction θ0 + ϕ
        x2, y2 = x1 + L / 2 ** 0.5 * cos(θ0 + π / 4), y1 + L / 2 ** 0.5 * sin(θ0 + π / 4)
        left_wheel = Line2D(
            (x2 - L / 4 * cos(θ0 + ϕ), x2 + L / 4 * cos(θ0 + ϕ)),
            (y2 - L / 4 * sin(θ0 + ϕ), y2 + L / 4 * sin(θ0 + ϕ)),
            lw=3, color='C5', alpha=1)
        ax.add_line(left_wheel)

        x3, y3 = x1 + L / 2 ** 0.5 * cos(π / 4 - θ0), y1 - L / 2 ** 0.5 * sin(π / 4 - θ0)
        right_wheel = Line2D(
            (x3 - L / 4 * cos(θ0 + ϕ), x3 + L / 4 * cos(θ0 + ϕ)),
            (y3 - L / 4 * sin(θ0 + ϕ), y3 + L / 4 * sin(θ0 + ϕ)),
            lw=3, color='C5', alpha=1)
        ax.add_line(right_wheel)

        self.patches += [car, bar, left_wheel, right_wheel]

    def _draw_trailer(self):
        """Add the trailer rectangle, rotated by θ1 around its rear-centre point."""
        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()
        ax = self.ax

        x, y = x - d * cos(θ1), y - d * sin(θ1) - W / 2
        trailer = Rectangle(
            (x, y), d, W, angle=0, color='C0', alpha=0.8, transform=
            matplotlib.transforms.Affine2D().rotate_deg_around(x, y + W/2, θ1 * 180 / π) +
            ax.transData
        )
        ax.add_patch(trailer)

        self.patches += [trailer]

In [None]:
truck = Truck(display=True)

In [None]:
truck.reset()

In [None]:
import torch
import torch.nn as nn
from torch.optim import SGD
from tqdm import tqdm

In [None]:
# Build expert data set
# Each sample pairs (steering angle, current state) with the state returned by
# truck.step() for that steering angle.

episodes = 10
inputs = list()
outputs = list()
truck = Truck(); episodes = 10_000  # uncomment for creating the data set

for episode in tqdm(range(episodes)):
    
    truck.reset()
    
    # Drive with random steering until the truck jackknifes or leaves the box
    while truck.valid():
        initial_state = truck.state()
        ϕ = (random() - 0.5) * π / 2  # random steering angle in [-π/4, π/4)
        inputs.append((ϕ, *initial_state))
        outputs.append(truck.step(ϕ))
        truck.draw()  # no-op unless the truck was created with display=True

In [None]:
len(inputs), len(outputs)

In [None]:
state_size = 6  # (x, y, θ0, trailer_x, trailer_y, θ1), as returned by Truck.state()
steering_size = 1  # the steering angle ϕ
hidden_units_e = 45  # width of the emulator's single hidden layer

# The emulator maps (steering angle, current state) to the next state.
emulator = nn.Sequential(
    nn.Linear(steering_size + state_size, hidden_units_e),
    nn.ReLU(),
    nn.Linear(hidden_units_e, state_size)
)

optimiser_e = SGD(emulator.parameters(), lr=0.005)
criterion = nn.MSELoss()

In [None]:
tensor_inputs = torch.Tensor(inputs)
tensor_outputs = torch.Tensor(outputs)

In [None]:
# Standardise every input column (steering angle followed by the 6 state
# values) to zero mean and unit standard deviation.
mean = tensor_inputs.mean(0)
std = tensor_inputs.std(0)
tensor_inputs = (tensor_inputs - mean) / std
# The outputs hold only the 6 state values, so reuse the statistics of the
# state columns (index 0, the steering angle, is dropped).
tensor_outputs = (tensor_outputs - mean[1:]) / std[1:]
# NOTE(review): this cell overwrites its own inputs; re-running it without
# re-running the previous cell normalises the data a second time.

In [None]:
# Split the data into 80:20 for train:test.
# NOTE: despite its name, `test_size` is the number of *training* samples,
# i.e. the index at which the held-out test part begins.
test_size = int(len(tensor_inputs) * 0.8)
print(len(tensor_inputs), test_size)

train_inputs = tensor_inputs[:test_size]
train_outputs = tensor_outputs[:test_size]
test_inputs = tensor_inputs[test_size:]
test_outputs = tensor_outputs[test_size:]

In [None]:
len(train_inputs)

In [None]:
# Emulator training
# A single pass of per-sample SGD over the training split, visiting the
# samples in a random order.
visit_order = torch.randperm(len(train_inputs))
for cnt, i in enumerate(visit_order):
    ϕ_state, next_state = train_inputs[i], train_outputs[i]
    next_state_prediction = emulator(ϕ_state)
    loss = criterion(next_state_prediction, next_state)

    optimiser_e.zero_grad()
    loss.backward()
    optimiser_e.step()

    # Report progress on the first sample and then every 1000 samples
    if cnt == 0 or (cnt + 1) % 1000 == 0:
        print(f'{cnt + 1:4d} / {len(train_inputs)}, {loss.item():.10f}', end='\r')
# Leave the counter at the number of processed samples
cnt = len(visit_order)

In [None]:
# Test
# Evaluate the emulator on the held-out samples, one sample at a time and
# without tracking gradients.
total_loss = 0
with torch.no_grad():
    for idx, ϕ_state in enumerate(test_inputs):
        next_state_prediction = emulator(ϕ_state)

        next_state = test_outputs[idx]
        total_loss += criterion(next_state_prediction, next_state).item()

# Average over the number of held-out samples. `test_size` must not be used
# here: it is the split index, i.e. the number of training samples.
ave_test_loss = total_loss / len(test_inputs)
print(f'Test loss: {ave_test_loss:.10f}')

In [None]:
# Freeze emulator
for param in emulator.parameters():
    param.requires_grad = False

In [None]:
# def pick_action(controller, state, epsilon):
#     if torch.rand(1) < epsilon:
#         return torch.rand(1)
#     # FIXME: the model returns the expected total reward for each action.
#     # We need to return the action with the highest expected reward.
#     return controller(state)


# # Unit tests for pick_action
# controller = lambda s: 2.0
# action_random = pick_action(controller, state=(1, 2, 3), epsilon=1.0)
# assert 0 <= action_random < 1, f"{action_random} not in range [0, 1)"
# action_controller = pick_action(controller, state=(1, 2, 3), epsilon=0.0)
# assert action_controller == 2

In [None]:
# def episode(truck, emulator, controller, max_steps=100):
#     truck.reset()
#     steps = 0
#     while truck.valid() and steps < max_steps:
#         steering_angle = controller(torch.tensor(
#             truck.state(), requires_grad=True
#         ).to(torch.float))
#         # controller output varies between -1 and 1.
#         # Scale that to +/- pi/4.
#         steering_angle_rad = steering_angle * pi / 4
#         emulator_input = torch.tensor(
#             truck.state() + (steering_angle,),
#         ).to(torch.float)
#         with torch.no_grad():
#             emulated_state = emulator(emulator_input).detach()
#         truck.x = emulated_state[0]
#         truck.y = emulated_state[1]
#         truck.θ0 = emulated_state[2]
#         truck.θ1 = emulated_state[5]
#         steps += 1
#     return truck.state()

In [None]:
# The actor picks an action (steering angle) from a normal distribution.
# The mean and standard deviation of the distribution are learned.
class Actor(nn.Module):
    """Policy network that outputs the parameters of a normal distribution.

    Args:
        state_size: number of features in the input state.
        hidden_sizes: output width of each hidden block (Linear -> ReLU ->
            LayerNorm). May be empty, in which case the two output heads
            read the state directly.
    """

    def __init__(self, state_size, hidden_sizes):
        super().__init__()
        modules = []
        hidden_input = state_size
        for hidden_output in hidden_sizes:
            modules.append(nn.Linear(hidden_input, hidden_output))
            modules.append(torch.nn.ReLU())
            modules.append(torch.nn.LayerNorm(hidden_output))
            hidden_input = hidden_output
        self.layers = torch.nn.Sequential(*modules)
        # `hidden_input` is the width of the last hidden block, or
        # `state_size` when there are no hidden blocks (the loop variable
        # would be undefined in that case).
        self.mean_layer = nn.Linear(hidden_input, 1)
        self.std_layer = nn.Linear(hidden_input, 1)

    def forward(self, state):
        """Return ``(mean, std)``, each with a trailing dimension of size 1."""
        hidden = self.layers(state)
        # Apply tanh so mean varies from -1 to 1
        mean = torch.tanh(self.mean_layer(hidden))
        # Apply sigmoid so STD varies from 0 to 1
        std = torch.sigmoid(self.std_layer(hidden))
        return mean, std

In [None]:
# The critic estimates the future reward from the given state.
class Critic(nn.Module):
    """Value network mapping a state to a single scalar estimate.

    Args:
        state_size: number of features in the input state.
        hidden_sizes: output width of each hidden block (Linear -> ReLU ->
            LayerNorm). May be empty, in which case the network is a single
            linear layer.
    """

    def __init__(self, state_size, hidden_sizes):
        super().__init__()
        modules = []
        hidden_input = state_size
        for hidden_output in hidden_sizes:
            modules.append(nn.Linear(hidden_input, hidden_output))
            modules.append(torch.nn.ReLU())
            modules.append(torch.nn.LayerNorm(hidden_output))
            hidden_input = hidden_output
        # `hidden_input` (not the loop variable) is always defined, even
        # when `hidden_sizes` is empty.
        modules.append(torch.nn.Linear(hidden_input, 1))
        self.layers = torch.nn.Sequential(*modules)

    def forward(self, state):
        """Return the value estimate, with a trailing dimension of size 1."""
        return self.layers(state)

In [None]:
# class ActorCritic(nn.Module):
#     def __init__(self, state_size, hidden_sizes):
#         super().__init__()
#         hidden_input = state_size
#         self.hidden_layers = torch.nn.Sequential()
#         for hidden_output in hidden_sizes:
#             self.hidden_layers.append(nn.Linear(hidden_input, hidden_output))
#             self.hidden_layers.append(torch.nn.ReLU())
#             self.hidden_layers.append(torch.nn.LayerNorm(hidden_output))
#             hidden_input = hidden_output
#         self.actor_mean = nn.Linear(hidden_output, 1)
#         self.actor_std = nn.Linear(hidden_output, 1)
#         self.critic = nn.Linear(hidden_output, 1)

#     def forward(self, state):
#         hidden = self.hidden_layers(state)
#         # Apply tanh so mean varies from -1 to 1
#         mean = torch.tanh(self.actor_mean(hidden))
#         # Apply sigmoid so STD varies from 0 to 1
#         std = torch.sigmoid(self.actor_std(hidden))
#         value = self.critic(hidden)
#         return mean, std, value

In [None]:
# Target for the trailer part of the state (the last three state entries).
desired_state = torch.tensor([
    0.,  # trailer_x
    0.,  # trailer_y
    0.,  # trailer_theta
])
# Reference "worst" trailer state used to offset the reward: the largest
# absolute x and y of the bounding box, and the steering range for the angle.
# NOTE(review): the angle entry uses the steering range, not a bound on the
# trailer angle, so larger trailer angles can produce an error above
# `worst_error` (a negative reward) — confirm this is intended.
worst_state = torch.tensor([
    max(abs(x) for x in ENVIRONMENT_BBOX[:2]),
    max(abs(x) for x in ENVIRONMENT_BBOX[2:]),
    STEERING_ANGLE_RANGE_rad,
])
# Use smooth L1 (~Huber) loss instead of MSE to avoid exploding gradients
loss_func = torch.nn.SmoothL1Loss(reduction="none")
# Mean element-wise error between the desired and the reference worst state
worst_error = loss_func(desired_state, worst_state).mean()

def compute_reward(truck_state):
    """ Score one truck state, or a batch of them

    truck_state: 1D tensor with 6 entries, or 2D tensor of shape Bx6 (B is the
    batch size). Only the last three entries of each state (the trailer x, y
    and angle) contribute to the reward.

    Returns a 1D tensor with B rewards; a 1D input is handled as a batch of
    one, so the result then has a single element.
    """
    batch = truck_state if truck_state.ndim >= 2 else truck_state.unsqueeze(0)
    trailer_part = batch[:, 3:]
    target = desired_state.repeat(batch.size(0), 1)
    per_state_error = loss_func(target, trailer_part).mean(dim=1)
    # The reward is the gap to the worst-case error, so it grows as the
    # trailer gets closer to the goal
    return worst_error - per_state_error

# def compute_reward(truck_state):
#     relevant_truck_state = truck_state[3:]
#     return -loss_func(desired_state, relevant_truck_state)


# Unit tests for compute_reward()
def test_reward(test_state, expected_reward):
    """ Check that compute_reward(test_state) is close to expected_reward

    Raises AssertionError, reporting both values, when they differ.
    """
    reward = compute_reward(test_state)
    assert torch.allclose(reward, expected_reward), f"{reward=} != {expected_reward=}"

# With the default beta of 1, smooth L1 gives 0.5 * e**2 for |e| < 1 and
# |e| - 0.5 otherwise. Each expected value below is `worst_error` minus the
# mean of the three trailer errors.

# Trailer exactly at the goal: zero error, for a single state and for a batch
test_reward(torch.zeros(state_size), worst_error)
test_reward(torch.zeros(4, state_size), worst_error.repeat(4))

# Trailer at the reference worst state: the reward drops to zero
test_state = torch.cat((torch.zeros(3), worst_state))
test_reward(test_state, torch.tensor(0.0))

# The first three entries (car x, y, heading) do not affect the reward
test_state = torch.tensor((0, 0, 1, 0, 0, 0))
test_reward(test_state, worst_error)

# Errors (0.5, 0, 0) -> mean 1/6
test_state = torch.tensor((0, 0, 0, 1, 0, 0))
test_reward(test_state, worst_error - 1/6)

# Errors (0.5, 0.5, 0) -> mean 1/3
test_state = torch.tensor((0, 0, 0, 1, 1, 0))
test_reward(test_state, worst_error - 1/3)

# Errors (1.5, 0.5, 0) -> mean 2/3
test_state = torch.tensor((0, 0, 0, 2, 1, 0))
test_reward(test_state, worst_error - 2/3)

# Errors (1.5, 0.5, 2.5) -> mean 3/2; the sign of each entry is irrelevant
test_state = torch.tensor((0, 0, 0, 2, -1, -3))
test_reward(test_state, worst_error - 3/2)

In [None]:
# Same math as the Truck class but in a batch-friendly, functional form rather
# than an object-oriented one

# Constants copied from Truck class
W = 1  # car and trailer width
L = 1 * W  # car length
d = 4 * L  # trailer length (d_1 in the Truck class)
s = -0.1  # signed speed; negative, as in the Truck class
dt = 1  # time step, matching the default of Truck.step()


def random_truck_state(batch_size):
    """ Sample `batch_size` random truck states as a Bx6 tensor

    Columns are (x0, y0, θ0, x1, y1, θ1). The trailer position (x1, y1) is
    derived from the car position and the trailer angle.
    """
    uniform = torch.rand(batch_size, state_size)
    # The positional ranges are shrunk by the trailer length compared to the
    # Truck class so that every sampled state is valid
    x_span = ENVIRONMENT_BBOX[1] - d
    y_span = ENVIRONMENT_BBOX[3] - ENVIRONMENT_BBOX[2] - 2 * d

    x0 = (uniform[:, 0] * 0.75 + 0.25) * x_span
    y0 = (uniform[:, 1] - 0.5) * y_span
    θ0 = uniform[:, 2] * (2 * π)  # 0 <= θ0 < 2π
    θ1 = (uniform[:, 5] - 0.5) * π/2 + θ0  # -π/4 <= θ1 - θ0 < π/4
    x1 = x0 - d * torch.cos(θ1)
    y1 = y0 - d * torch.sin(θ1)
    return torch.stack((x0, y0, θ0, x1, y1, θ1), dim=1)


def is_valid(truck_state):
    """ Tell which truck states are neither jackknifed nor out of bounds

    truck_state: 1D tensor with 6 entries, or 2D tensor of shape Bx6 (B is the
    batch size).

    Returns a 1D boolean tensor with B entries; a 1D input is handled as a
    batch of one, so the result then has a single element.
    """
    batch = truck_state if truck_state.ndim >= 2 else truck_state.unsqueeze(0)
    # One tensor per state component
    x0, y0, θ0, trailer_x, trailer_y, θ1 = batch.unbind(dim=1)

    x_min, x_max, y_min, y_max = ENVIRONMENT_BBOX

    def inside(px, py):
        # Chained comparisons (a <= x <= b) do not work element-wise on
        # tensors, so each bound is tested separately and combined with &
        return (x_min <= px) & (px <= x_max) & (y_min <= py) & (py <= y_max)

    # Point 1.5 car lengths ahead of (x0, y0), along the car heading
    cab_x = x0 + 1.5 * L * torch.cos(θ0)
    cab_y = y0 + 1.5 * L * torch.sin(θ0)

    jackknifed = torch.abs(θ0 - θ1) > π / 2
    return ~jackknifed & (inside(cab_x, cab_y) & inside(trailer_x, trailer_y))
    

def analytical_world_model(steering_angle_and_state):
    """ Update the environment by taking a single step

    steering_angle_and_state is a 1D tensor of size 7 or a 2D tensor of shape
    Bx7, where B is the batch size. It represents the desired steering angle and
    the current truck state.

    The output is a 1D tensor of size 6 (for a 1D input) or a 2D tensor of shape
    Bx6 representing the new truck state. As in Truck.step(), the trailer
    position of the new state is derived from the *updated* car position and
    trailer angle; the trailer position in the input is not used.
    """
    single_state = steering_angle_and_state.ndim == 1
    if single_state:
        steering_angle_and_state = steering_angle_and_state.unsqueeze(0)
    # Make the various positions and angles the first dimension so it's easier
    # to split them into separate tensors.
    steering_angle_and_state = steering_angle_and_state.transpose(0, 1)
    (ϕ, x0, y0, θ0, x1, y1, θ1) = steering_angle_and_state
    # Only torch functions are used so the model also works on tensors that
    # require gradients. Every update reads the pre-step values.
    x0_ = x0 + s * torch.cos(θ0) * dt
    y0_ = y0 + s * torch.sin(θ0) * dt
    θ0_ = θ0 + s / L * torch.tan(ϕ) * dt
    θ1_ = θ1 + s / d * torch.sin(θ0 - θ1) * dt
    # Trailer position follows from the updated car position and trailer angle
    x1_ = x0_ - d * torch.cos(θ1_)
    y1_ = y0_ - d * torch.sin(θ1_)
    next_state = torch.stack((x0_, y0_, θ0_, x1_, y1_, θ1_), dim=1)
    return next_state.squeeze(0) if single_state else next_state


# Unit tests for analytical_world_model
def test_world_model(starting_state, expected_next_state):
    """ Check that one step from starting_state gives expected_next_state """
    next_state = analytical_world_model(starting_state)
    assert torch.allclose(
        torch.as_tensor(next_state), torch.as_tensor(expected_next_state)), (
        f"{next_state=} != {expected_next_state=}"
    )

# Straight reversing along the x axis: car and trailer both move by s * dt
starting_state = torch.tensor([0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0])
expected_next_state = torch.tensor([-0.1, 0.0, 0.0, -4.1, 0.0, 0.0])
test_world_model(starting_state, expected_next_state)
# A single state in gives a single state out
assert analytical_world_model(starting_state).shape == (6,)

starting_state = starting_state.unsqueeze(0).repeat(16, 1)
expected_next_state = expected_next_state.unsqueeze(0).repeat(16, 1)
test_world_model(starting_state, expected_next_state)
assert analytical_world_model(starting_state).shape == (16, 6)

# Straight reversing along the y axis
test_world_model(
    torch.tensor([0.0, 10.0, 0.0, π / 2, 10.0, -4.0, π / 2]),
    torch.tensor([10.0, -0.1, π / 2, 10.0, -4.1, π / 2]),
)

# Test that random starting states are always valid
random_states = random_truck_state(1000)
invalid = ~is_valid(random_states)
invalid_count = torch.count_nonzero(invalid).item()
assert invalid_count == 0, f"{invalid_count=}\n{random_states[invalid]=}"

In [None]:
# columns: (epochs, batch_size, training_time)
training_times = torch.zeros((0, 3))

In [None]:
# Here you need to insert the code for training the controller
# by using the emulator for backpropagation

# If you succeed, feel free to send a PR

# Things to try
#   RL algorithms
#       Q-learning: requires discrete action space?
#       Q actor-critic
#       Advantage actor-critic
#   update every step instead of at end of episode
#   explore/exploit: take random action with decreasing probability
#   save state transitions to memory and take batches from memory
#       better suited to critic since it's basically Q learner?
#   prevent actor from being over-confident. add entropy metric to loss.
#   model params become NaN
#       learning rate
#       smaller loss/reward
#       normalize
#   low episode length limit to start, increase as model trains
#   compare episode losses to number of steps
#   different discount factors and episode length limits

from collections import deque
import time

# Training hyper-parameters
NUM_EPOCHS = int(2**14)  # number of batched episodes to run
MAX_STEPS_PER_EPISODE = 100  # cap on the number of steps in an episode
TIME_DISCOUNT = 0.95  # per-step discount used when accumulating returns
BATCH_SIZE = 8  # episodes simulated in parallel in each epoch

actor = Actor(state_size, hidden_sizes=[25, 50])
critic = Critic(state_size, hidden_sizes=[25, 50])
actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-6)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-2)

world_model = analytical_world_model  # Use exact world model
# world_model = emulator  # Use trained emulator

# Logging: keep roughly NUM_LOSSES_TO_PLOT evenly spaced loss/entropy samples
NUM_LOSSES_TO_PLOT = 100
log_every_n_epochs = max(NUM_EPOCHS // NUM_LOSSES_TO_PLOT, 1)
losses_to_plot = []
entropies_to_plot = []
# Buffer of saved episode trajectories; unused steps stay NaN
NUM_EPISODES_TO_PLOT = 25
longest_episodes = torch.full(
    (NUM_EPISODES_TO_PLOT, MAX_STEPS_PER_EPISODE, state_size),
    torch.nan,
    dtype=torch.float,
)
# Last step index recorded for each saved episode
longest_lengths = torch.zeros(NUM_EPISODES_TO_PLOT, dtype=torch.int)

start_time = time.perf_counter()
for i in range(NUM_EPOCHS):
    current_state = random_truck_state(BATCH_SIZE)
    # Store actor and critic outputs separately since pytorch complains about
    # visiting the same tensor twice when backpropagating (even though the
    # relevant elements are different).
    # Tensor of action log probabilities
    episode_actor = torch.zeros(
        (BATCH_SIZE, MAX_STEPS_PER_EPISODE), dtype=torch.float
    )
    # Tensor of critic values for each state
    episode_critic = torch.zeros_like(episode_actor)
    # Tensor of (reward, done)
    episode_reward_done = torch.zeros(
        (BATCH_SIZE, MAX_STEPS_PER_EPISODE, 2), dtype=torch.float
    )
    # Overwrite steps from shortest saved episode
    saved_episodes_index = torch.min(longest_lengths, dim=0).indices.item()
    longest_episodes[saved_episodes_index, :, :] = torch.nan

    for step in range(MAX_STEPS_PER_EPISODE):
        # Only consider the first episode in the batch. Could instead consider
        # the longest episode in the batch, but that's more complicated without
        # much practical benefit.
        longest_episodes[saved_episodes_index, step, :] = current_state[0,:]

        action_mean, action_std = actor(current_state)
        critic_value = critic(current_state).squeeze(1)
        action_distribution = torch.distributions.normal.Normal(
            action_mean, action_std
        )
        # Draw the action WITHOUT the reparameterisation trick. With
        # rsample() the sample is mean + std * eps, so inside log_prob the
        # term (sample - mean) / std collapses to eps and the gradient of the
        # log-probability with respect to the mean is exactly zero: the
        # policy mean would never be trained.
        action_sample = action_distribution.sample()

        with torch.no_grad():
            steering_angle = STEERING_ANGLE_RANGE_rad * action_sample
            world_model_input = torch.cat((steering_angle, current_state), dim=1)
            assert world_model_input.shape == (
                BATCH_SIZE, steering_size + state_size
            )
            next_state = world_model(world_model_input)
            reward = compute_reward(next_state)

        episode_actor[:, step] = action_distribution.log_prob(
            action_sample
        ).squeeze(1)
        # These values don't need to be calculated on each iteration. They could
        # be calculated after the episode is over. Would that be faster?
        episode_critic[:, step] = critic_value
        episode_reward_done[:, step, 0] = reward
        # "done" flags a step that was taken from an already invalid state
        episode_reward_done[:, step, 1] = ~is_valid(current_state)
        if torch.all(episode_reward_done[:, step, 1]):
            # All states in batch invalid, break out early
            episode_actor = episode_actor[:, :step + 1]
            episode_critic = episode_critic[:, :step + 1]
            episode_reward_done = episode_reward_done[:, :step + 1, :]
            break
        current_state = next_state

    # Episode over. Calculate return (discounted total reward) and loss.
    with torch.no_grad():
        # Same dtype as the critic values and log-probabilities
        returns = torch.zeros(
            (BATCH_SIZE, episode_reward_done.size(1)), dtype=torch.float
        )
        # Use critic to estimate remaining reward
        next_return = critic(next_state).detach().squeeze(1)
        for j in range(returns.size(1) - 1, -1, -1):
            instant_reward = episode_reward_done[:, j, 0]
            not_done = 1.0 - episode_reward_done[:, j, 1]
            # Masking the running return as well keeps rewards collected
            # after an invalid state from leaking into the returns of the
            # earlier, valid steps.
            next_return = (
                instant_reward + TIME_DISCOUNT * next_return
            ) * not_done
            returns[:, j] = next_return
            # Zero out any reward following an invalid state
            returns[:, j:] *= not_done.unsqueeze(1)

    advantage = returns - episode_critic

    # Policy-gradient loss; the advantage is treated as a constant weight
    actor_loss = (-episode_actor * advantage.detach()).mean()
    # Equivalent to RMS error between critic value and return
    critic_loss = torch.sqrt((advantage * advantage).mean())

    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    if i % log_every_n_epochs == 0:
        print(
            f"Epoch {i:5d}:",
            f"length {step:3d}",
            f"actor loss {actor_loss.item():7.3f},",
            f"critic loss {critic_loss.item():7.3f}",
            end="\r",
        )
        losses_to_plot.append((i, actor_loss.item(), critic_loss.item()))
        # Entropy of the action distribution at the last simulated step
        entropies_to_plot.append((i, action_distribution.entropy().mean().item()))

    longest_lengths[saved_episodes_index] = step

stop_time = time.perf_counter()
training_time = stop_time - start_time
print(f"\nTraining took {training_time:.3f} seconds")
training_times = torch.cat(
    (training_times, torch.tensor([[NUM_EPOCHS, BATCH_SIZE, training_time]]))
)

# Plot the logged actor and critic losses against the epoch, each on its own
# y-axis (left: actor, right: critic).
losses_to_plot = torch.as_tensor(losses_to_plot)
_, ax1 = plt.subplots()
ax1.plot(losses_to_plot[:, 0], losses_to_plot[:, 1], color="C0", label="Actor")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Actor Loss")
ax2 = ax1.twinx()
ax2.plot(losses_to_plot[:, 0], losses_to_plot[:, 2], color="C1", label="Critic")
# set_ylabel labels the right-hand axis; set_label would only set the legend
# label of the Axes artist and leave the axis unlabelled.
ax2.set_ylabel("Critic Loss")
# Build a plain list of the lines from both axes for a single shared legend
ax2.legend(handles=[*ax1.lines, *ax2.lines], labelcolor="black")
ax2.grid(False)

In [None]:
# Moving average filter by convolving with window of 1s
filter_window = 11
with torch.no_grad():
    # One channel per loss (actor, critic); groups=2 filters them separately
    filtered_loss = torch.nn.functional.conv1d(
        losses_to_plot[:, 1:].T,
        weight=torch.ones((2, 1, filter_window)),
        groups=2,
    ).T / filter_window
_, ax1 = plt.subplots()
# The filtered series is filter_window - 1 samples shorter than the raw one
ax1.plot(
    losses_to_plot[:-filter_window + 1, 0],
    filtered_loss[:, 0],
    color="C0",
    label="Actor",
)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Actor Loss")
ax2 = ax1.twinx()
ax2.plot(
    losses_to_plot[:-filter_window + 1, 0],
    filtered_loss[:, 1],
    color="C1",
    label="Critic",
)
# set_ylabel labels the right-hand axis; set_label would only set the legend
# label of the Axes artist and leave the axis unlabelled.
ax2.set_ylabel("Critic Loss")
# Build a plain list of the lines from both axes for a single shared legend
ax2.legend(handles=[*ax1.lines, *ax2.lines], labelcolor="black")
ax2.grid(False)
ax1.set_title("Filtered Loss")

In [None]:
# Sum of actor and critic loss, raw and after the moving-average filter
combined_loss = losses_to_plot[:, 1:].sum(1)
filtered_combined_loss = filtered_loss.sum(1)
# The filtered series is shorter; centre it on the raw series' epoch axis
len_diff = len(combined_loss) - len(filtered_combined_loss)
plt.plot(
    losses_to_plot[:, 0],
    combined_loss,
    label="Combined Loss",
)
plt.plot(
    losses_to_plot[len_diff // 2 : len_diff // 2 + len(filtered_combined_loss), 0],
    filtered_combined_loss,
    label="Filtered Combined Loss",
)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(labelcolor="Black")

In [None]:
entropies_to_plot = torch.as_tensor(entropies_to_plot)
plt.plot(entropies_to_plot[:, 0], entropies_to_plot[:, 1])
plt.xlabel("Epoch")
plt.ylabel("Action Distribution Entropy")

In [None]:
def plot_episode(states, axis=None):
    """ Scatter the (x, y) positions of one episode, coloured by step index

    states: tensor whose first two columns are the x and y coordinates, one
    row per step.
    axis: axes to draw on; a new figure is created when it is None.
    """
    points = states.numpy()
    target_axis = plt.subplots()[1] if axis is None else axis
    # Colour points according to order
    step_index = np.arange(len(points))
    target_axis.scatter(
        points[:, 0],
        points[:, 1],
        c=step_index,
        vmin=0,
        vmax=MAX_STEPS_PER_EPISODE,
    )

In [None]:
print(longest_episodes.shape)
print(longest_lengths)
# One panel per saved episode.
# NOTE(review): the 5 x 5 grid assumes NUM_EPISODES_TO_PLOT == 25.
fig, axes = plt.subplots(
    nrows=5, ncols=5, sharex=True, sharey=True, figsize=(12, 8),
)
for k, ax in enumerate(axes.flatten()):
    plot_episode(longest_episodes[k], axis=ax)