# Load packages

In [None]:
from pathlib import Path
import pandas as pd
from environment import Santa2022Environment
from utils import *

import torch as th
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from stable_baselines3 import PPO

import matplotlib.pyplot as plt

# Load Image of Christmas card

In [None]:
# Read the card's pixel table from CSV and convert it to an image with the
# project helper `df_to_image` (presumably provided by `from utils import *`).
# `image` is displayed and passed to the environment further down.
df_image = pd.read_csv("image.csv")
image = df_to_image(df_image)

In [None]:
# Visual sanity check: render the image loaded above.
plt.imshow(image)
plt.show()

# Load submission files

In [None]:
def parse_configuration(configuration):
    """Parse one submission row into a list of integer link vectors.

    ``configuration`` is a string such as ``"64 0;-32 0"``: links are
    separated by ``;`` and the integer components of a link by whitespace.
    The example parses to ``[[64, 0], [-32, 0]]``.
    """
    return [
        [int(component) for component in link.split()]
        for link in configuration.split(";")
    ]


# One entry per submission file; each entry is that file's ordered list of
# parsed configurations (one per CSV row).
all_confs = []
# Sort the paths: glob yields them in filesystem order, which would make the
# order of the collected trajectories differ between machines and runs.
for sub_file in sorted(Path("./submissions").glob("*.csv")):
    s = pd.read_csv(sub_file)
    # Iterate over the column directly rather than using a row-wise
    # DataFrame.apply: same result for populated files, cheaper, and it yields
    # an empty list (instead of failing on .tolist()) for a header-only file.
    list_of_confs = [parse_configuration(row) for row in s["configuration"]]
    all_confs.append(list_of_confs)

# Define env

In [None]:
# `max_iter` is passed to the environment (presumably its per-episode step
# limit — confirm in environment.py). The dataset-building loop below also
# uses it as the chunk length after which it computes returns and resets.
max_iter = 1000
env = Santa2022Environment(image, max_iter=max_iter)

In [None]:
# Replay every submitted trajectory through the environment to build three
# index-aligned lists for supervised training:
#   observations[i] - observation the environment returned before action i
#   actions[i]      - index of the submitted next configuration in env.new_confs
#   values[i]       - discounted (gamma = 0.99) sum of the rewards from step i
#                     to the end of its chunk of at most `max_iter` steps
observations, values, actions = [], [], []
for confs in all_confs:
    rewards = []
    obs = env.reset()
    # Move the last image axis to the front (presumably HWC -> CHW for the
    # network — confirm against CustomNetwork).
    obs["image"] = obs["image"].transpose([2, 0, 1])
    observations.append(obs)
    # confs[0] is skipped: it is treated as the state the environment starts in.
    for conf in confs[1:]:
        # The action is the position of the submitted configuration among the
        # candidate next configurations the environment currently offers.
        action = env.new_confs.index(conf)
        actions.append(action)
        obs, reward, done, info = env.step(action)
        obs["image"] = obs["image"].transpose([2, 0, 1])
        observations.append(obs)

        rewards.append(reward)

        if len(rewards) == max_iter:
            # A full chunk is complete: turn its rewards into value targets and
            # restart the environment from the configuration just reached.
            values_array = discounted_cumulative_sums(rewards, 0.99)
            values.extend(values_array.tolist())
            rewards = []
            obs = env.reset(conf)
    # Flush the final, shorter chunk. Without this the trailing rewards never
    # become value targets, `values` ends up shorter than `actions` and
    # `observations`, and every following trajectory is paired with the wrong
    # targets.
    if rewards:
        values_array = discounted_cumulative_sums(rewards, 0.99)
        values.extend(values_array.tolist())
    # The last observation has no action taken from it, so it is not a sample.
    del observations[-1]

In [None]:
BATCH_SIZE = 32
# Trim the sample count down to a whole number of batches so that every batch
# the loader yields has exactly BATCH_SIZE elements.
limit = len(values) - len(values) % BATCH_SIZE
santa_dataest = SantaDataset(
    observations[:limit],
    actions[:limit],
    values[:limit],
)
# Shuffled, single-process loading (num_workers=0 keeps it in the main process).
dataloader = DataLoader(
    dataset=santa_dataest,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Two-headed network (CustomNetwork is a project class, presumably from utils).
# The policy head has 3**8 = 6561 outputs, which the training loop treats as
# logits over action indices; the value head has a single output.
net = CustomNetwork(feature_dim=128, last_layer_dim_pi=3**8, last_layer_dim_vf=1)

# Cross-entropy for the action head (classification over action indices) and
# mean squared error for the value head (regression onto the stored returns).
criterion_a = nn.CrossEntropyLoss()
criterion_v = nn.MSELoss()
# AdamW over all network parameters, with the optimizer's default settings.
optimizer = optim.AdamW(net.parameters())

In [None]:
# Device that holds the model; the training loop moves every batch to it too.
DEVICE = "cpu" # or "cuda" if available
net.to(DEVICE)

In [None]:
# Supervised training of both heads on the replayed submissions: the action
# head imitates the submitted actions, the value head regresses the stored
# discounted returns. Running statistics are printed every 200 mini-batches.
for epoch in range(2):  # loop over the dataset multiple times

    # Accumulators for the current 200-batch reporting window.
    running_loss = 0.0
    running_loss_a = 0.0
    running_loss_v = 0.0
    total = 0.0
    correct = 0.0
    for i, data in enumerate(dataloader, 0):
        # Each batch unpacks into four tensors: two network inputs (`im`, `c`),
        # the target action indices `a` and the target values `v`.
        im, c, a, v = data

        im = im.to(DEVICE).float()
        c = c.to(DEVICE).float()
        a = a.to(DEVICE)
        v = v.to(DEVICE).float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output_a, output_v = net(im, c)
        loss_a = criterion_a(output_a, a)
        # Give the target the same shape as the value output before MSELoss.
        # If the head returns (batch, 1) while the collated targets are
        # (batch,), MSELoss would broadcast them to (batch, batch) and optimise
        # the wrong quantity. The reshape is a no-op when shapes already match.
        loss_v = criterion_v(output_v, v.reshape(output_v.shape))
        loss = loss_a + loss_v
        loss.backward()
        optimizer.step()

        # Top-1 accuracy of the action head on this batch.
        _, predicted = th.max(output_a, 1)
        total += a.size(0)
        correct += (predicted == a).sum().item()

        # print statistics
        running_loss += loss.item()
        running_loss_a += loss_a.item()
        running_loss_v += loss_v.item()
        if i % 200 == 199:    # print every 200 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}, a_loss: {running_loss_a / 200:.3f}, v_loss: {running_loss_v / 200:.3f}, accuracy: {100 * correct // total}')
            # Start a fresh reporting window.
            running_loss = 0.0
            running_loss_a = 0.0
            running_loss_v = 0.0
            total = 0.0
            correct = 0.0

print('Finished Training')