# Load packages

In [None]:
from pathlib import Path
from environment import Santa2022Environment
from utils import *

import pandas as pd
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback
from gym.envs.registration import register

import matplotlib.pyplot as plt

# Register Env

In [None]:
# Register the custom environment under a namespaced id so that it can later
# be constructed by name. The step limit is deliberately enormous so that
# episodes are, in practice, never cut short by the time limit.
register(
    id="kaggle_santa/Santa2022-v0",
    entry_point="environment:Santa2022Environment",
    # gym declares max_episode_steps as an integer step count; the previous
    # literal 1e09 was a float, so spell the same value as an int.
    max_episode_steps=1_000_000_000,
)

# Load Image of Christmas card

In [None]:
# Read the Christmas card table from CSV, then hand it to the project helper
# df_to_image (presumably provided by utils) to obtain the image used below.
df_image = pd.read_csv(filepath_or_buffer="image.csv")
image = df_to_image(df_image)

In [None]:
# Visual sanity check: draw the loaded card image and display the figure.
plt.imshow(image)
plt.show()

# Load submission confs

In [None]:
all_confs = []
for sub_file in Path("./submissions").glob("*.csv"):
    s = pd.read_csv(sub_file.as_posix())
    list_of_confs = s.apply(lambda x: [list(map(int, link.split())) for link in x.configuration.split(";")], axis=1).tolist()
    all_confs.extend(list_of_confs)

# Load Gym Env

In [None]:
# Vectorized training environment: four copies of the registered Santa env,
# each constructed with the same image and pool of starting configurations.
santa_env_kwargs = {"image": image, "starting_confs": all_confs}
env = make_vec_env(
    "kaggle_santa/Santa2022-v0",
    n_envs=4,
    env_kwargs=santa_env_kwargs,
)

# Create instance of PPO model

In [None]:
# PPO agent trained on the vectorized environment, using the
# "MultiInputPolicy" policy.
model = PPO(policy="MultiInputPolicy", env=env, verbose=1)

# Periodically write model snapshots to models/rl_model_*.
# NOTE(review): stable-baselines3 counts save_freq in calls to the vectorized
# env's step(), so with several parallel envs the interval in total timesteps
# is presumably save_freq * n_envs -- confirm this is the intended cadence.
checkpoint_callback = CheckpointCallback(
    save_freq=10_000,
    save_path="models/",
    name_prefix="rl_model",
)

# Run Training

In [None]:
# Train for 300 million timesteps; the callback is invoked during training
# to save checkpoints.
model.learn(
    total_timesteps=300_000_000,
    callback=checkpoint_callback,
)

# Show video of agent

In [None]:
# Output location and filename prefix shared by the recorder and the viewer.
video_path = "./videos"
video_prefix = "ppo-santa"

# Fresh, non-vectorized environment instance used only for the recording.
new_env = Santa2022Environment(image)
new_env.reset()

# Record the model acting in the environment, then display the result.
# NOTE(review): record_video / show_videos are project helpers (presumably
# from utils); video_length=1000 is assumed to be a step count -- confirm.
record_video(
    new_env,
    model,
    video_length=1000,
    prefix=video_prefix,
    video_folder=video_path,
)
show_videos(video_path, prefix=video_prefix)