In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from shark.utils import nb_init

# Project-level notebook initialisation. What it configures is defined in
# shark.utils and is not visible from this notebook.
nb_init()

# PPO agent on Inverted Double Pendulum

In [None]:
from loguru import logger
import os
import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
import pandas as pd

from shark.models import PPOPendulum
from shark.utils import plot_metrics

## Create model

We choose the gymnasium environment here.

In [None]:
# Both values are handed to the model unchanged; the collector check further
# down expects frames_per_batch // frame_skip entries per batch.
# NOTE(review): frame_skip is presumably the number of simulator frames per
# policy action - confirm against PPOPendulum.
frame_skip = 1
frames_per_batch = 100 * frame_skip

model = PPOPendulum(frame_skip=frame_skip, frames_per_batch=frames_per_batch)

## Rollout

We can immediately roll out a few environment steps and inspect what the rollout, the policy module, and the value module return.

In [None]:
# Step the environment three times and inspect the resulting TensorDict.
rollout = model.env.rollout(3)
logger.info(f"Rollout of three steps: {rollout}")
logger.info(f"Shape of the rollout TensorDict: {rollout.batch_size}")

# Run the policy module, then the value module, each on its own freshly reset
# environment state, and log what they return.
policy_output = model.policy_module(model.env.reset())
logger.info(f"Running policy: {policy_output}")
value_output = model.value_module(model.env.reset())
logger.info(f"Running value: {value_output}")

## Data

This is what a batch of data looks like.

In [None]:
# Collector
collector = model.train_dataloader()

# Only the first batch is inspected, so pull it directly rather than looping
# with a discarded enumerate() index and an immediate break.
tensordict_data = next(iter(collector), None)
# Fail loudly if the collector yields nothing, instead of silently skipping
# the batch-size check below.
assert tensordict_data is not None, "collector yielded no data"
logger.info(f"Tensordict data:\n{tensordict_data}")

# The leading batch dimension should equal the number of collected frames
# divided by the frame skip.
batch_size = int(tensordict_data.batch_size[0])
expected_batch_size = frames_per_batch // frame_skip
assert batch_size == expected_batch_size, (
    f"expected a batch of {expected_batch_size} entries, got {batch_size}"
)

In [None]:
# Training
# Metrics are written as CSV below "pytest_artifacts", in a sub-directory named
# after the model class, so they can be read back once fitting is done.
csv_logger = CSVLogger(save_dir="pytest_artifacts", name=model.__class__.__name__)

# Short CPU-only run: 16 optimisation steps, validation every 2 training
# batches, and metrics logged at every step.
trainer = pl.Trainer(
    accelerator="cpu",
    max_steps=16,
    val_check_interval=2,
    log_every_n_steps=1,
    logger=csv_logger,
)
trainer.fit(model)

In [None]:
# Get logged stuff
log_dir, logs = trainer.log_dir, trainer.logged_metrics
# Sanity-check the types before using them (log_dir is joined into a path below).
assert isinstance(log_dir, str)
assert isinstance(logs, dict)
for logged_item in (log_dir, logs):
    logger.info(logged_item)

# Load the metrics file from the run's log directory and preview it.
filename = os.path.join(log_dir, "metrics.csv")
df = pd.read_csv(filename)
logger.info(df.head())

In [None]:
import matplotlib.pyplot as plt

# Plot
# Hand the metrics DataFrame read from the CSV log to the project's plotting
# helper; show=True is forwarded to it as-is.
plot_metrics(df, show=True)