# Load packages

In [None]:
from pathlib import Path
from typing import Tuple

from environment import Santa2022Environment
from utils import *

import pandas as pd
from stable_baselines3 import PPO

import matplotlib.pyplot as plt

# Load Image of Christmas card

In [None]:
# Read the card's CSV table and turn it into an image array with the
# star-imported helper `df_to_image` (presumably defined in utils).
df_image = pd.read_csv("image.csv")
image = df_to_image(df_image)

In [None]:
# Visual sanity check of the loaded image.
plt.imshow(image)
plt.show()

# Load submission files

In [None]:
def parse_configuration(configuration):
    """Parse one submission ``configuration`` string into a list of links.

    Links are separated by ``;`` and every link is a whitespace-separated
    sequence of integers, e.g. ``"64 0;-32 0"`` -> ``[[64, 0], [-32, 0]]``.

    :param configuration: (str) raw value of the ``configuration`` column
    :return: (list of list of int) one inner list per link
    """
    return [[int(value) for value in link.split()] for link in configuration.split(";")]


# One entry per submission file; each entry holds the parsed configuration
# of every row of that file, in row order.
all_confs = []
# glob() yields files in arbitrary order, so sort them to make the resulting
# dataset order reproducible between runs and machines.
for sub_file in sorted(Path("./submissions").glob("*.csv")):
    submission = pd.read_csv(sub_file)
    # Parse the column directly instead of a row-wise DataFrame.apply: it is
    # faster and always produces a plain list, even for a file with no rows.
    all_confs.append([parse_configuration(raw) for raw in submission["configuration"]])

# Define env

In [None]:
# `max_iter` is handed to the environment (presumably its per-episode step
# limit -- confirm in environment.py) and is also the chunk length used when
# value targets are computed during data collection.
max_iter = 1000
env = Santa2022Environment(image, max_iter=max_iter)

In [None]:
# Replay every submission through the environment and record, step by step,
# the observation seen, the action taken and the discounted return obtained.
# The three lists are index-aligned: observations[k] is the state in which
# actions[k] was taken and values[k] is its value target.
observations, values, actions = [], [], []
gamma = 0.99  # discount factor used for the value targets
for confs in all_confs:
    rewards = []
    obs = env.reset()
    # Move the last image axis to the front (presumably HWC -> CHW, the
    # layout expected by torch convolutions).
    obs["image"] = obs["image"].transpose([2, 0, 1])
    observations.append(obs)
    for conf in confs[1:]:
        # The action is the position of the next configuration among the
        # configurations currently offered by the environment.
        action = env.new_confs.index(conf)
        actions.append(action)
        obs, reward, done, info = env.step(action)
        obs["image"] = obs["image"].transpose([2, 0, 1])
        observations.append(obs)

        rewards.append(reward)

        if len(rewards) == max_iter:
            # A full chunk is complete: emit its value targets and restart
            # the environment from the configuration just reached.
            values_array = discounted_cumulative_sums(rewards, gamma)
            values.extend(values_array.tolist())
            rewards = []
            obs = env.reset(conf)

    # Flush the rewards of the last, incomplete chunk. Without this the tail
    # of every submission has observations and actions but no value targets,
    # which leaves `values` shorter than, and misaligned with, the other lists.
    if rewards:
        values_array = discounted_cumulative_sums(rewards, gamma)
        values.extend(values_array.tolist())

    # The final observation has no action following it, so drop it to keep
    # observations and actions the same length.
    del observations[-1]

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
# Mini-batch size; the dataset built below is truncated to a multiple of it.
BATCH_SIZE = 32

In [None]:
class SantaDataset(Dataset):
    """Map-style dataset over recorded (observation, action, value) triples.

    The three sequences are expected to be index-aligned. Every observation
    is a mapping with the keys ``"image"`` and ``"conf"``.
    """

    def __init__(self, observations, actions, values):
        self.observations, self.actions, self.values = observations, actions, values

    def __len__(self):
        # The observation sequence defines the dataset size.
        return len(self.observations)

    def __getitem__(self, idx):
        """Return ``(image, conf, action, value)`` for sample ``idx``."""
        sample = self.observations[idx]
        return sample["image"], sample["conf"], self.actions[idx], self.values[idx]

# Keep only as many samples as fill complete mini-batches; the remainder at
# the end of the recording is discarded.
limit = len(observations) - len(observations) % BATCH_SIZE
santa_dataest = SantaDataset(
    observations[:limit],
    actions[:limit],
    values[:limit],
)

In [None]:
# Use BATCH_SIZE rather than a literal so the loader cannot drift out of sync
# with the truncation applied when the dataset was built.
dataloader = DataLoader(santa_dataest, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [None]:
import torch as th
from torch import nn

In [None]:
class CustomNetwork(nn.Module):
    """
    Two-headed network producing policy and value outputs.

    Images go through a small CNN followed by a linear layer, the
    configuration vectors go through their own linear layer, and the two
    feature vectors are concatenated before being fed to both heads.

    :param feature_dim: (int) size of the feature vector produced by each of the
        two branches (image and configuration); the heads receive ``2 * feature_dim`` inputs
    :param n_input_channels: (int) number of channels of the input images
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    :param image_size: (int) height and width of the (square) input images; only used
        to size the linear layer that follows the CNN
    :param conf_dim: (int) length of each configuration vector
    """

    def __init__(
        self,
        feature_dim: int,
        n_input_channels: int = 3,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
        image_size: int = 257,
        conf_dim: int = 16,
    ):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute the flattened CNN output size with one forward pass. The
        # probe must have the configured channel count and image size,
        # otherwise any non-default setting fails (or mis-sizes the layer).
        with th.no_grad():
            probe = th.rand(1, n_input_channels, image_size, image_size).float()
            n_flatten = self.cnn(probe).shape[1]

        # Image branch: CNN features -> feature_dim
        self.linear = nn.Sequential(nn.Linear(n_flatten, feature_dim), nn.ReLU())

        # Configuration branch: conf vector -> feature_dim
        self.conf_linear = nn.Sequential(nn.Linear(conf_dim, feature_dim), nn.ReLU())

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim * 2, last_layer_dim_pi)
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim * 2, last_layer_dim_vf)
        )

    def forward(self, images: th.Tensor, confs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :param images: (th.Tensor) batch of images, ``(batch, n_input_channels, image_size, image_size)``
        :param confs: (th.Tensor) batch of configuration vectors, ``(batch, conf_dim)``
        :return: (th.Tensor, th.Tensor) policy output ``(batch, last_layer_dim_pi)`` and
            value output ``(batch, last_layer_dim_vf)``, both computed from the same features
        """
        image_features = self.linear(self.cnn(images))
        conf_features = self.conf_linear(confs)

        # Concatenate along the feature axis: (batch, 2 * feature_dim)
        features = th.cat((image_features, conf_features), 1)

        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Apply the policy head to already-extracted features."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Apply the value head to already-extracted features."""
        return self.value_net(features)


In [None]:
import torch.optim as optim

# Policy head: 3**8 outputs, trained as class logits against the recorded
# action index. Value head: a single output, regressed onto the value targets.
net = CustomNetwork(feature_dim=128, last_layer_dim_pi=3**8, last_layer_dim_vf=1)

criterion_a = nn.CrossEntropyLoss()  # action (policy) loss
criterion_v = nn.MSELoss()  # value loss
optimizer = optim.AdamW(net.parameters())  # library-default hyper-parameters

In [None]:
# Device holding the model; every batch is moved to it in the training loop.
DEVICE = "cpu" # or "cuda" if available
net.to(DEVICE)

In [None]:
# Supervised training: the policy head imitates the recorded actions
# (cross-entropy) while the value head regresses the value targets (MSE).
for epoch in range(2):  # loop over the dataset multiple times

    # Running statistics, reset after every report.
    running_loss = 0.0
    running_loss_a = 0.0
    running_loss_v = 0.0
    total = 0.0
    correct = 0.0
    for i, data in enumerate(dataloader, 0):
        # each batch is (images, confs, actions, value targets)
        im, c, a, v = data

        im = im.to(DEVICE).float()
        c = c.to(DEVICE).float()
        a = a.to(DEVICE)
        v = v.to(DEVICE).float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output_a, output_v = net(im, c)
        loss_a = criterion_a(output_a, a)
        # The value head outputs (batch, 1) while the targets are (batch,).
        # Passing them as-is makes MSELoss broadcast to (batch, batch) and
        # compare every prediction with every target, so bring the
        # prediction to the target's shape first.
        loss_v = criterion_v(output_v.reshape(v.shape), v)
        loss = loss_a + loss_v
        loss.backward()
        optimizer.step()

        # The predicted action is the index of the largest logit.
        _, predicted = th.max(output_a, 1)
        total += a.size(0)
        correct += (predicted == a).sum().item()

        # print statistics
        running_loss += loss.item()
        running_loss_a += loss_a.item()
        running_loss_v += loss_v.item()
        if i % 200 == 199:    # print every 200 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}, a_loss: {running_loss_a / 200:.3f}, v_loss: {running_loss_v / 200:.3f}, accuracy: {100 * correct // total}')
            running_loss = 0.0
            running_loss_a = 0.0
            running_loss_v = 0.0
            total = 0.0
            correct = 0.0

print('Finished Training')