In [1]:
from src.imitation.dataset import MyDataset, TransformAction


# Training and validation splits are read from separate parquet files; each
# dataset is given its own TransformAction instance as its `transform`.
train_ds = MyDataset(
    "FCFS_training.parquet.gzip",
    transform=TransformAction(),
)
validation_ds = MyDataset(
    "FCFS_validation.parquet.gzip",
    transform=TransformAction(),
)

In [2]:
from torch.utils.data import DataLoader
from src.imitation.utils import collate_to_float32


# Options common to both loaders: same batch size and the same collate
# function for training and validation.
loader_options = dict(batch_size=256, collate_fn=collate_to_float32)

# Only the training data is reshuffled; validation keeps a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
validation_loader = DataLoader(validation_ds, shuffle=False, **loader_options)

In [None]:
from src.cleanRL.agent import Agent
import torch.optim as optim
import torch
import numpy as np

learning_rate = 1e-5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single sample is used to read off the flattened observation/action sizes.
sample = train_ds[0]

# np.prod returns a NumPy scalar -- and the float 1.0 for an empty shape
# (a 0-d sample) -- so convert to a plain Python int before passing the sizes
# to the agent.
# NOTE(review): assumes Agent treats both arguments as flat element counts;
# confirm against src.cleanRL.agent.
agent = Agent(
    observation_shape=int(np.prod(sample["observation"].shape)),
    action_shape=int(np.prod(sample["action"].shape)),
).to(device)

optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

In [6]:
import torch.nn as nn
# Mean squared error criterion ("mean" is nn.MSELoss's default reduction,
# spelled out here for clarity).
actor_loss = nn.MSELoss(reduction="mean")

In [7]:
def calc_val_loss(agent, validation_loader):
    """Compute the agent's action loss over an entire validation set.

    Every batch from ``validation_loader`` is passed through
    ``agent.get_action_and_value``. The produced actions and the reference
    actions are stacked across batches and compared once with the
    module-level ``actor_loss``, giving a single loss over all samples.

    Args:
        agent: Module exposing ``parameters()`` and
            ``get_action_and_value(observation)``; the first of the four
            values returned by the latter is used as the action.
        validation_loader: Iterable of batches, each a mapping with
            ``"observation"`` and ``"action"`` tensors.

    Returns:
        The validation loss as a Python float.
    """
    # Run on whichever device holds the agent's weights, so CPU batches do not
    # clash with an agent that lives on the GPU. `.to` is a no-op for tensors
    # that are already there.
    first_param = next(iter(agent.parameters()), None)
    target_device = first_param.device if first_param is not None else None

    predicted_actions = []
    reference_actions = []

    # Validation never back-propagates. Without no_grad every batch's autograd
    # graph would stay alive (referenced from `predicted_actions`) until the
    # whole pass finishes.
    with torch.no_grad():
        for batch in validation_loader:
            observation = batch["observation"]
            reference = batch["action"]
            if target_device is not None:
                observation = observation.to(target_device)
                reference = reference.to(target_device)

            action, _, _, _ = agent.get_action_and_value(observation)

            predicted_actions.append(action)
            reference_actions.append(reference)

    loss = actor_loss(
        torch.vstack(predicted_actions), torch.vstack(reference_actions))

    return loss.item()

In [8]:
losses = []             # training loss of every optimisation step
validation_losses = []  # validation loss at each validation point
val_steps = []          # index into `losses` at which each validation ran

epochs = 40
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        # The agent was moved to `device`; move the batch there as well
        # (a no-op when the tensors are already on that device).
        observation = batch["observation"].to(device)
        target_action = batch["action"].to(device)

        # NOTE(review): assumes the returned action is differentiable with
        # respect to the agent's parameters -- confirm in Agent.
        action, _, _, _ = agent.get_action_and_value(observation)

        loss = actor_loss(action, target_action)
        losses.append(loss.item())

        # Validate at the start of every epoch and every 100 steps within it.
        if step % 100 == 0:
            agent.eval()
            validation_losses.append(calc_val_loss(agent, validation_loader))
            agent.train()
            # Record the global step (position of the current entry in
            # `losses`). The per-epoch `step` restarts at 0 every epoch and
            # would make validation points from different epochs collide.
            val_steps.append(len(losses) - 1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
import matplotlib.pyplot as plt
N = 20  # window length of the running averages, in recorded points
kernel = np.ones(N) / N

fig, ax = plt.subplots()

# Training loss: one point per optimisation step.
ax.plot(losses, label="loss")
if len(losses) >= N:
    # A 'valid' convolution yields len - N + 1 points. Start its x axis at
    # N - 1 so each average is drawn at the last sample of its window.
    ax.plot(
        np.arange(N - 1, len(losses)),
        np.convolve(losses, kernel, mode="valid"),
        label="running avg",
    )

# Validation loss: the recorded step goes on the x axis, the loss on the y axis.
ax.plot(val_steps, validation_losses, label="val loss")
if len(validation_losses) >= N:
    # Drop the first N - 1 steps so x and y have the same length.
    ax.plot(
        val_steps[N - 1:],
        np.convolve(validation_losses, kernel, mode="valid"),
        label="val running avg",
    )

ax.set_xlabel("step")
ax.set_ylabel("loss")
ax.legend()

fig.savefig("loss.png", dpi=1200)