In [1]:
!pip install gymnasium torch numpy matplotlib tqdm panda_gym

Collecting panda_gym
  Downloading panda_gym-3.0.7-py3-none-any.whl.metadata (4.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecti

In [None]:
import os
import gymnasium as gym
import panda_gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
from collections import deque
import time
import warnings
from typing import Dict, List, Tuple, Type, Union, Optional, Any, Callable

warnings.filterwarnings("ignore", category=DeprecationWarning)

def set_seeds(seed=42):
    """Seed the global RNGs this notebook draws from so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU generator plus every visible CUDA device). Gym environments are not
    seeded here; pass ``seed=`` to ``env.reset`` for that.

    Args:
        seed: integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers multi-GPU machines, not only the current device.
        torch.cuda.manual_seed_all(seed)

set_seeds()

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class RunningMeanStd:
    """Streaming estimate of the mean and variance of a data source.

    Batches are merged with the parallel (Chan et al.) combination of moments.
    ``count`` starts at a small pseudo-count ``epsilon`` so the first merge
    never divides by zero.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.count = epsilon
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)

    def update(self, x):
        """Fold the samples in ``x`` (samples along axis 0) into the running moments."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch mean/variance/count into the running moments."""
        old_count = self.count
        combined = old_count + batch_count
        shift = batch_mean - self.mean

        # Total sum of squared deviations: both groups' own spread plus a cross
        # term accounting for the gap between their means.
        sum_sq = (
            self.var * old_count
            + batch_var * batch_count
            + np.square(shift) * old_count * batch_count / combined
        )

        self.mean = self.mean + shift * batch_count / combined
        self.var = sum_sq / combined
        self.count = combined

class VecNormalize:
    """Observation standardization and return-based reward scaling for one environment.

    Observations: the "observation" and "desired_goal" entries of a dict
    observation (or a whole non-dict observation) are shifted and scaled by the
    running statistics in ``obs_rms``, then clipped to ``[-clipob, clipob]``.
    Other dict entries pass through untouched.

    Rewards: a discounted running return is tracked. Each reward is divided by
    the running standard deviation of that return and then clipped to
    ``[-cliprew, cliprew]``.

    NOTE(review): ``obs_rms`` defaults to shape ``(1,)`` and is shared by both
    dict entries. When a single 1-D observation is normalized, ``update``
    treats each of its elements as a separate sample, so all features end up
    standardized by one scalar mean/variance. Per-feature statistics would
    also require changing how these statistics are checkpointed.
    """

    def __init__(self, obs_rms=None, ret_rms=None, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        self.obs_rms = obs_rms if obs_rms is not None else RunningMeanStd(shape=(1,))
        self.ret_rms = ret_rms if ret_rms is not None else RunningMeanStd(shape=())
        self.clipob = clipob
        self.cliprew = cliprew
        self.gamma = gamma
        self.epsilon = epsilon
        # Discounted running return used to estimate the reward scale.
        self.returns = 0.0

    def normalize_obs(self, obs, update=True):
        """Normalize an observation (dict or array).

        Args:
            obs: dict observation (only "observation" and "desired_goal" are
                normalized) or an array-like observation.
            update: when False, apply the current statistics without folding
                ``obs`` into them (e.g. during evaluation).

        Returns:
            A new dict (for dict input) or a normalized ``np.ndarray``.
        """
        if isinstance(obs, dict):
            return {
                key: self._normalize_obs(value, update) if key in ["observation", "desired_goal"] else value
                for key, value in obs.items()
            }
        return self._normalize_obs(obs, update)

    def _normalize_obs(self, obs, update=True):
        """Optionally update ``obs_rms`` with ``obs``, then standardize and clip it."""
        obs_array = np.array(obs)
        if update:
            self.obs_rms.update(obs_array)
        return np.clip((obs_array - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon),
                       -self.clipob, self.clipob)

    def normalize_reward(self, reward):
        """Scale ``reward`` by the running std of the discounted return, then clip it."""
        self.returns = self.returns * self.gamma + reward
        self.ret_rms.update(np.array([self.returns]))
        return np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)

    def reset_returns(self):
        """Restart the discounted return (call at episode boundaries)."""
        self.returns = 0.0

# NN for Actor
class Actor(nn.Module):
    """Goal-conditioned Gaussian policy with a state-independent standard deviation.

    The input is the concatenation of observation and goal (``obs_dim + goal_dim``
    features). The trunk alternates ``nn.Linear`` and ``activation_fn`` layers with
    widths taken from ``net_arch``. The mean is squashed with ``tanh``. The
    standard deviation is ``exp(clamp(log_std, -20, 2))``, broadcast to the
    mean's shape.
    """

    def __init__(self, obs_dim, goal_dim, action_dim, net_arch=[256, 256], activation_fn=nn.ReLU):
        super().__init__()
        self.input_dim = obs_dim + goal_dim

        # Consecutive (in, out) width pairs, e.g. [in, 256, 256] -> (in, 256), (256, 256).
        widths = [self.input_dim, *net_arch]
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.extend([nn.Linear(fan_in, fan_out), activation_fn()])

        # Gaussian mean head plus one learnable log-std per action dimension.
        self.mean_layer = nn.Linear(widths[-1], action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain 1) and zero biases for every ``nn.Linear``."""
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=1)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return ``(mean, std)`` of the action distribution for input ``x``."""
        features = x
        for module in self.layers:
            features = module(features)

        mean = torch.tanh(self.mean_layer(features))
        std = self.log_std.expand_as(mean).clamp(-20, 2).exp()
        return mean, std

# NN for Critic
class Critic(nn.Module):
    """Goal-conditioned state-value network V(observation, goal).

    The input is the concatenation of observation and goal (``obs_dim + goal_dim``
    features), followed by an MLP of ``nn.Linear`` + ``activation_fn`` pairs
    with widths from ``net_arch``, and a single-unit linear value head.
    """

    def __init__(self, obs_dim, goal_dim, net_arch=[256, 256], activation_fn=nn.ReLU):
        super().__init__()
        self.input_dim = obs_dim + goal_dim

        hidden = []
        width = self.input_dim
        for next_width in net_arch:
            hidden += [nn.Linear(width, next_width), activation_fn()]
            width = next_width
        self.layers = nn.ModuleList(hidden)

        self.value_layer = nn.Linear(width, 1)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain 1) and zero biases; non-linear modules are skipped."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=1)
        nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Return the value estimate with a trailing dimension of size 1."""
        h = x
        for block in self.layers:
            h = block(h)
        return self.value_layer(h)

class RolloutBuffer:
    """Fixed-size on-policy storage for one PPO rollout plus GAE(lambda) targets.

    ``dones[t]`` records whether the episode ended with transition ``t`` (i.e.
    right after taking ``actions[t]``). It therefore masks the bootstrap from
    step ``t`` to step ``t + 1``.
    """

    def __init__(self, buffer_size, obs_dim, goal_dim, action_dim, gamma=0.99, gae_lambda=0.95):
        self.observations = np.zeros((buffer_size, obs_dim), dtype=np.float32)
        self.goals = np.zeros((buffer_size, goal_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.advantages = np.zeros(buffer_size, dtype=np.float32)
        self.returns = np.zeros(buffer_size, dtype=np.float32)
        self.values = np.zeros(buffer_size, dtype=np.float32)
        self.log_probs = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.float32)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.buffer_size = buffer_size
        self.pos = 0
        self.full = False

    @staticmethod
    def _as_scalar(x):
        """Coerce a plain number or size-1 array to a Python float (raises for larger arrays)."""
        return np.asarray(x, dtype=np.float64).item()

    def add(self, obs, goal, action, reward, done, value, log_prob):
        """Store one transition at the write cursor; the cursor wraps after ``buffer_size`` adds.

        ``value`` and ``log_prob`` may be plain numbers or size-1 arrays.
        """
        self.observations[self.pos] = obs
        self.goals[self.pos] = goal
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        self.values[self.pos] = self._as_scalar(value)
        self.log_probs[self.pos] = self._as_scalar(log_prob)

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def compute_returns_and_advantages(self, last_value=0.0):
        """Fill ``advantages`` with GAE(lambda) estimates and ``returns`` with advantages + values.

        Args:
            last_value: critic estimate for the observation that follows the last
                stored transition. It bootstraps the final step unless that step
                ended an episode. May be a float or a size-1 array.
        """
        next_value = self._as_scalar(last_value)
        last_gae_lam = 0.0
        for step in reversed(range(self.buffer_size)):
            # Bootstrapping from step+1 is valid only if *this* transition did
            # not end the episode; dones[step+1] would describe the next one.
            next_non_terminal = 1.0 - self.dones[step]
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
            next_value = self.values[step]

        self.returns = self.advantages + self.values

    def get(self):
        """Return (observations, goals, actions, rewards, returns, log_probs, advantages)."""
        return (
            self.observations,
            self.goals,
            self.actions,
            self.rewards,
            self.returns,
            self.log_probs,
            self.advantages
        )

    def clear(self):
        """Reset the write cursor; stored arrays are overwritten by later adds."""
        self.pos = 0
        self.full = False

# PPO Agent class
class PPO:
    """Proximal Policy Optimization (clipped surrogate) for goal-conditioned dict-observation envs.

    Each iteration collects ``n_steps`` transitions with the current stochastic
    policy and computes GAE(lambda) advantages in a ``RolloutBuffer``. It then
    runs ``n_epochs`` passes of shuffled minibatch updates. Actor and critic are
    separate networks with separate Adam optimizers, so ``vf_coef`` is stored but
    not used. Observations are standardized and rewards scaled by a
    ``VecNormalize`` instance.

    ``use_sde``, ``sde_sample_freq``, ``tensorboard_log``, ``create_eval_env``,
    ``verbose``, ``seed`` and ``_init_setup_model`` are accepted only for
    Stable-Baselines3-style call compatibility and are ignored.
    """

    def __init__(
        self,
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        clip_range_vf=None,
        ent_coef=0.0,
        vf_coef=0.5,
        max_grad_norm=0.5,
        use_sde=False,
        sde_sample_freq=4,
        target_kl=None,
        tensorboard_log=None,
        create_eval_env=False,
        policy_kwargs=None,
        verbose=0,
        seed=None,
        device=device,
        _init_setup_model=True
    ):
        """Build networks, optimizers, rollout buffer and normalizer for ``env``.

        ``env.observation_space`` must be indexable by "observation" and
        "desired_goal". ``env.action_space`` must expose ``low``, ``high`` and
        ``shape``.
        """
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.env = env
        self.obs_dim = env.observation_space['observation'].shape[0]
        self.goal_dim = env.observation_space['desired_goal'].shape[0]
        self.action_dim = env.action_space.shape[0]
        # Every module and tensor of this agent lives on this device.
        self.device = torch.device(device)

        # Affine map from the policy's nominal [-1, 1] range to the env's action bounds.
        self.action_low = env.action_space.low
        self.action_high = env.action_space.high
        self.action_scale = (env.action_space.high - env.action_space.low) / 2
        self.action_bias = (env.action_space.high + env.action_space.low) / 2

        # Hyperparameters
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.target_kl = target_kl

        policy_kwargs = policy_kwargs or {}
        net_arch = policy_kwargs.get("net_arch", [256, 256])
        self.actor = Actor(self.obs_dim, self.goal_dim, self.action_dim, net_arch=net_arch)
        self.critic = Critic(self.obs_dim, self.goal_dim, net_arch=net_arch)
        self.actor.to(self.device)
        self.critic.to(self.device)

        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)

        # Experience buffer
        self.rollout_buffer = RolloutBuffer(n_steps, self.obs_dim, self.goal_dim, self.action_dim, gamma, gae_lambda)

        # Normalization
        self.normalizer = VecNormalize(gamma=gamma)

        self.timesteps = 0

        # Per-iteration training statistics (appended every `log_interval` iterations).
        self.logger = {
            "timesteps": [],
            "mean_reward": [],
            "episode_length": [],
            "policy_loss": [],
            "value_loss": [],
            "entropy_loss": [],
            "kl_divergence": []
        }

        self.callbacks = []

    # ------------------------------------------------------------------ helpers

    def _to_tensor(self, array):
        """Convert an array-like to a float32 tensor on the agent's device."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def _process_observation(self, obs):
        """Split an observation into ``(observation, desired_goal)`` arrays.

        Dict observations are read by key. Flat observations are split at
        ``obs_dim``.
        """
        if isinstance(obs, dict):
            observation = np.array(obs["observation"], dtype=np.float32)
            goal = np.array(obs["desired_goal"], dtype=np.float32)
        else:
            observation = obs[:self.obs_dim]
            goal = obs[self.obs_dim:]

        return observation, goal

    def _combined_input(self, obs, goal):
        """Concatenate observation and goal along the last axis (NumPy or torch inputs)."""
        if isinstance(obs, np.ndarray) and isinstance(goal, np.ndarray):
            return np.concatenate([obs, goal], axis=-1)
        return torch.cat([obs, goal], dim=-1)

    def _normalize_observation(self, obs, update=True):
        """Normalize a dict observation; non-dict observations are returned unchanged.

        With ``update=False`` the current running statistics are applied without
        being modified, so evaluation does not shift the training normalization.
        The frozen path mirrors ``VecNormalize._normalize_obs``.
        """
        if not isinstance(obs, dict):
            return obs
        if update:
            return self.normalizer.normalize_obs(obs)

        rms = self.normalizer.obs_rms
        scale = np.sqrt(rms.var + self.normalizer.epsilon)
        clip = self.normalizer.clipob
        normalized = dict(obs)
        for key in ("observation", "desired_goal"):
            if key in normalized:
                normalized[key] = np.clip((np.array(normalized[key]) - rms.mean) / scale, -clip, clip)
        return normalized

    def _normalize_reward(self, reward):
        """Scale a reward with the normalizer's running return statistics."""
        return self.normalizer.normalize_reward(reward)

    def _state_value(self, obs):
        """Critic estimate V(obs) as a Python float (the observation updates normalizer stats)."""
        obs_array, goal_array = self._process_observation(self._normalize_observation(obs))
        with torch.no_grad():
            x = self._combined_input(self._to_tensor(obs_array), self._to_tensor(goal_array))
            return self.critic(x).item()

    def _to_env_action(self, action_np):
        """Map a raw policy sample to the env's action range and clip it to the bounds."""
        return np.clip(action_np * self.action_scale + self.action_bias, self.action_low, self.action_high)

    # ------------------------------------------------------------ interaction

    def predict(self, observation, deterministic=False):
        """Choose an action for ``observation``.

        Args:
            observation: dict (or flat) observation from the environment.
            deterministic: use the Gaussian mean instead of sampling.

        Returns:
            ``(action, None)``; the action is scaled and clipped to the env's
            action bounds. Normalizer statistics are not updated.
        """
        with torch.no_grad():
            observation = self._normalize_observation(observation, update=False)
            obs_array, goal_array = self._process_observation(observation)

            x = self._combined_input(self._to_tensor(obs_array), self._to_tensor(goal_array))
            action_mean, action_std = self.actor(x)

            if deterministic:
                action = action_mean
            else:
                action = Normal(action_mean, action_std).sample()

            return self._to_env_action(action.cpu().numpy()), None

    def collect_rollouts(self, env, n_steps=None):
        """Reset ``env`` and fill the rollout buffer with ``n_steps`` on-policy transitions.

        The unclipped action sample is stored, since its log-probability is what
        the PPO ratio uses. The env receives the clipped action. When an episode
        is cut by the time limit (truncated but not terminated), the stored
        reward is bootstrapped with ``gamma * V(final observation)`` so the cut
        is not treated as a true terminal state.

        Returns:
            ``(mean_reward, mean_length)`` over episodes that finished during
            this rollout (0 if none finished).
        """
        n_steps = n_steps or self.n_steps
        self.rollout_buffer.clear()

        self.normalizer.reset_returns()

        obs, _ = env.reset()

        episode_rewards = []
        episode_lengths = []
        current_episode_reward = 0
        current_episode_length = 0

        for _ in range(n_steps):
            current_episode_length += 1

            norm_obs = self._normalize_observation(obs)
            obs_array, goal_array = self._process_observation(norm_obs)
            x = self._combined_input(self._to_tensor(obs_array), self._to_tensor(goal_array))

            with torch.no_grad():
                value = self.critic(x).item()
                action_mean, action_std = self.actor(x)

                dist = Normal(action_mean, action_std)
                action = dist.sample()
                log_prob = dist.log_prob(action).sum(dim=-1).item()

            action_np = action.cpu().numpy()
            next_obs, reward, terminated, truncated, info = env.step(self._to_env_action(action_np))
            done = terminated or truncated

            current_episode_reward += reward
            norm_reward = self._normalize_reward(reward)
            if truncated and not terminated:
                # Time-limit cut: the state is not terminal, so bootstrap from its value.
                norm_reward = norm_reward + self.gamma * self._state_value(next_obs)
            self.rollout_buffer.add(obs_array, goal_array, action_np, norm_reward, done, value, log_prob)
            obs = next_obs

            if done:
                episode_rewards.append(current_episode_reward)
                episode_lengths.append(current_episode_length)
                current_episode_reward = 0
                current_episode_length = 0

                obs, _ = env.reset()
                self.normalizer.reset_returns()

        last_value = self._state_value(obs)
        self.rollout_buffer.compute_returns_and_advantages(last_value)

        mean_reward = np.mean(episode_rewards) if episode_rewards else 0
        mean_length = np.mean(episode_lengths) if episode_lengths else 0

        return mean_reward, mean_length

    # ---------------------------------------------------------------- updates

    def train(self):
        """Run ``n_epochs`` of minibatch PPO updates on the current rollout buffer.

        With ``target_kl`` set, training stops after any epoch whose mean
        approximate KL (to the rollout policy) exceeds ``1.5 * target_kl``.

        Returns:
            Dict with mean ``policy_loss``, ``value_loss``, ``entropy_loss`` and
            ``kl_divergence`` over all minibatches that ran.
        """
        observations, goals, actions, _, returns, old_log_probs, advantages = self.rollout_buffer.get()
        # Rollout-time value predictions; the clipped value loss is measured against these.
        old_values = self.rollout_buffer.values
        # Never larger than the buffer, otherwise no minibatch would run.
        batch_size = min(self.batch_size or self.n_steps, self.n_steps)
        n_batches = self.n_steps // batch_size

        # Normalize
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        policy_losses = []
        value_losses = []
        entropy_losses = []
        kl_divs = []

        for epoch in range(self.n_epochs):
            perm = np.random.permutation(self.n_steps)
            epoch_kls = []

            for batch_idx in range(n_batches):
                batch_indices = perm[batch_idx * batch_size:(batch_idx + 1) * batch_size]

                obs_batch = self._to_tensor(observations[batch_indices])
                goal_batch = self._to_tensor(goals[batch_indices])
                action_batch = self._to_tensor(actions[batch_indices])
                old_log_prob_batch = self._to_tensor(old_log_probs[batch_indices])
                advantage_batch = self._to_tensor(advantages[batch_indices])
                return_batch = self._to_tensor(returns[batch_indices])
                old_value_batch = self._to_tensor(old_values[batch_indices])
                state_batch = self._combined_input(obs_batch, goal_batch)

                # Value loss
                values = self.critic(state_batch).flatten()
                if self.clip_range_vf is None:
                    value_loss = F.mse_loss(values, return_batch)
                else:
                    # PPO2-style clipping: limit how far the prediction moves from
                    # its rollout-time value, then keep the larger per-sample error.
                    values_clipped = old_value_batch + torch.clamp(
                        values - old_value_batch, -self.clip_range_vf, self.clip_range_vf
                    )
                    value_loss = torch.max(
                        (values - return_batch) ** 2,
                        (values_clipped - return_batch) ** 2
                    ).mean()

                # Update critic
                self.critic_optimizer.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_optimizer.step()

                # Clipped surrogate policy loss
                action_mean, action_std = self.actor(state_batch)
                dist = Normal(action_mean, action_std)

                log_prob = dist.log_prob(action_batch).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1).mean()

                ratio = torch.exp(log_prob - old_log_prob_batch)
                policy_loss_1 = advantage_batch * ratio
                policy_loss_2 = advantage_batch * torch.clamp(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range)
                policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

                entropy_loss = -entropy * self.ent_coef

                with torch.no_grad():
                    # Low-variance KL estimator: E[(r - 1) - log r].
                    log_ratio = log_prob - old_log_prob_batch
                    approx_kl = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).item()
                    kl_divs.append(approx_kl)
                    epoch_kls.append(approx_kl)

                # Update actor
                self.actor_optimizer.zero_grad()
                (policy_loss + entropy_loss).backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_optimizer.step()

                policy_losses.append(policy_loss.item())
                value_losses.append(value_loss.item())
                entropy_losses.append(entropy_loss.item())

            # KL to the rollout policy grows with each epoch, so judge the latest epoch alone.
            if self.target_kl is not None and epoch_kls and np.mean(epoch_kls) > 1.5 * self.target_kl:
                print(f"Early stopping at epoch {epoch+1}/{self.n_epochs} due to reaching KL target")
                break

        return {
            "policy_loss": np.mean(policy_losses),
            "value_loss": np.mean(value_losses),
            "entropy_loss": np.mean(entropy_losses),
            "kl_divergence": np.mean(kl_divs)
        }

    def learn(self, total_timesteps, callback=None, log_interval=1, eval_env=None, eval_freq=None, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):
        """Alternate rollout collection and PPO updates until ``total_timesteps`` is reached.

        ``callback`` (if given) is called after every iteration as
        ``callback(locals(), globals())``; returning exactly ``False`` stops
        training. ``tb_log_name`` is accepted for API compatibility and ignored.

        Returns:
            ``self``.
        """
        timesteps_elapsed = 0
        timesteps_since_eval = 0
        iterations = 0

        self.timesteps = 0 if reset_num_timesteps else self.timesteps

        while self.timesteps < total_timesteps:
            mean_reward, mean_len = self.collect_rollouts(self.env, self.n_steps)

            losses = self.train()

            timesteps_elapsed += self.n_steps
            self.timesteps += self.n_steps
            timesteps_since_eval += self.n_steps
            iterations += 1

            if iterations % log_interval == 0:
                print(f"Iteration: {iterations}, timesteps: {self.timesteps}")
                print(f"Mean reward: {mean_reward:.2f}, mean episode length: {mean_len:.2f}")
                print(f"Losses: {losses}")

                self.logger["timesteps"].append(self.timesteps)
                self.logger["mean_reward"].append(mean_reward)
                self.logger["episode_length"].append(mean_len)

                # Storing loss values
                self.logger["policy_loss"].append(losses["policy_loss"])
                self.logger["value_loss"].append(losses["value_loss"])
                self.logger["entropy_loss"].append(losses["entropy_loss"])
                self.logger["kl_divergence"].append(losses["kl_divergence"])

            if eval_env is not None and eval_freq is not None and timesteps_since_eval >= eval_freq:
                timesteps_since_eval = 0
                self.evaluate_policy(eval_env, n_eval_episodes)

            if callback is not None and callback(locals(), globals()) is False:
                break

        return self

    def evaluate_policy(self, env=None, n_eval_episodes=10, deterministic=True):
        """Run ``n_eval_episodes`` full episodes and print/return the mean reward and length.

        Uses ``self.env`` when ``env`` is None. Normalizer statistics are not
        updated during evaluation.
        """
        eval_env = env if env is not None else self.env

        episode_rewards = []
        episode_lengths = []

        for _ in range(n_eval_episodes):
            obs, _ = eval_env.reset()
            done = False
            episode_reward = 0
            episode_length = 0

            while not done:
                action, _ = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = eval_env.step(action)
                done = terminated or truncated

                episode_reward += reward
                episode_length += 1

            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)

        mean_reward = np.mean(episode_rewards)
        mean_length = np.mean(episode_lengths)

        print(f"Evaluation: Mean reward: {mean_reward:.2f}, mean episode length: {mean_length:.2f}")

        return mean_reward, mean_length

    # ------------------------------------------------------------ persistence

    def save(self, path):
        """Write network, optimizer and normalizer state to ``path``.

        Normalizer statistics are stored as plain Python lists/floats rather
        than NumPy arrays, so the file can be read with
        ``torch.load(..., weights_only=True)``.
        """
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises for a bare filename
            os.makedirs(directory, exist_ok=True)
        obs_rms = self.normalizer.obs_rms
        ret_rms = self.normalizer.ret_rms
        torch.save({
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
            "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
            "normalizer_mean": np.asarray(obs_rms.mean).tolist(),
            "normalizer_var": np.asarray(obs_rms.var).tolist(),
            "normalizer_count": float(obs_rms.count),
            "ret_rms_mean": np.asarray(ret_rms.mean).tolist(),
            "ret_rms_var": np.asarray(ret_rms.var).tolist(),
            "ret_rms_count": float(ret_rms.count),
        }, path)
        print(f"Model saved to {path}")

    def load(self, path, weights_only=None):
        """Restore state written by :meth:`save`.

        Args:
            path: checkpoint file.
            weights_only: forwarded to ``torch.load`` when not None (otherwise
                PyTorch's default applies). Checkpoints from earlier versions of
                ``save`` contain NumPy arrays. They need ``weights_only=False``
                wherever PyTorch defaults to True. Only do that for trusted files,
                because it permits arbitrary unpickling.

        Returns:
            ``self``.
        """
        load_kwargs = {"map_location": self.device}
        if weights_only is not None:
            load_kwargs["weights_only"] = weights_only
        checkpoint = torch.load(path, **load_kwargs)

        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer_state_dict"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer_state_dict"])

        # Load normalizer (accepts both list and NumPy-array checkpoints).
        self.normalizer.obs_rms.mean = np.asarray(checkpoint["normalizer_mean"], dtype=np.float64)
        self.normalizer.obs_rms.var = np.asarray(checkpoint["normalizer_var"], dtype=np.float64)
        self.normalizer.obs_rms.count = float(checkpoint["normalizer_count"])
        self.normalizer.ret_rms.mean = np.asarray(checkpoint["ret_rms_mean"], dtype=np.float64)
        self.normalizer.ret_rms.var = np.asarray(checkpoint["ret_rms_var"], dtype=np.float64)
        self.normalizer.ret_rms.count = float(checkpoint["ret_rms_count"])

        print(f"Model loaded from {path}")
        return self

# Callback system
class BaseCallback:
    """Minimal callback for ``PPO.learn``, which calls ``callback(locals(), globals())``.

    Subclasses override :meth:`on_step`. The caller's namespaces are available
    there as ``self.locals`` and ``self.globals``.
    """

    def __init__(self, verbose=0):
        self.verbose = verbose
        self.locals = {}
        self.globals = {}
        self._init_callback()

    def _init_callback(self):
        """One-time setup hook, run at the end of ``__init__``."""

    def __call__(self, locals_dict, globals_dict):
        """Record the caller's namespaces and run :meth:`on_step`.

        Returns:
            False only if ``on_step`` explicitly returned False, otherwise True.
        """
        self.locals = locals_dict
        self.globals = globals_dict
        return self.on_step() is not False

    def on_step(self):
        """Per-iteration hook; returning False makes ``__call__`` return False."""
        return True

class CheckpointCallback(BaseCallback):
    """Save the model whenever ``save_freq`` timesteps have passed since the last save.

    Files are written to ``{save_path}/{name_prefix}_{timesteps}_steps.pth`` via
    the model's ``save`` method; the model is read from the caller's ``self``.
    """

    def __init__(self, save_freq, save_path, name_prefix="model", verbose=0):
        # Set before BaseCallback.__init__, which calls _init_callback() and reads save_path.
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        self.last_save_step = 0
        super(CheckpointCallback, self).__init__(verbose)

    def _init_callback(self):
        """Create the checkpoint directory."""
        os.makedirs(self.save_path, exist_ok=True)

    def on_step(self):
        """Save a checkpoint if enough timesteps have elapsed; always continue training."""
        model = self.locals["self"]
        timesteps = model.timesteps

        if timesteps > 0 and timesteps - self.last_save_step >= self.save_freq:
            path = f"{self.save_path}/{self.name_prefix}_{timesteps}_steps.pth"
            model.save(path)
            self.last_save_step = timesteps
        return True

def train_ppo(env_name="PandaReach-v3", total_timesteps=1000000, log_interval=1, save_interval=10000, eval_interval=50000, n_eval_episodes=10):
    """Train a PPO agent on ``env_name``, checkpoint it, and save a progress plot.

    Checkpoints go to ``sb3_style_models/`` roughly every ``save_interval``
    timesteps (rounded down to whole iterations, at least every iteration), plus
    a final one. The training-curve figure is written to
    ``sb3_style_logs/training_progress.png``.

    Args:
        env_name: Gymnasium environment id.
        total_timesteps: training budget.
        log_interval: print/log every this many iterations.
        save_interval: approximate timesteps between checkpoints.
        eval_interval: timesteps between evaluations on a separate env.
        n_eval_episodes: episodes per evaluation.

    Returns:
        The trained ``PPO`` agent. Its training env (``ppo.env``) is left open
        for the caller; the evaluation env is closed.
    """
    print(f"Training on {env_name}...")

    env = gym.make(env_name)
    eval_env = gym.make(env_name)  # Separate env for evaluation

    try:
        os.makedirs('sb3_style_models', exist_ok=True)
        os.makedirs('sb3_style_logs', exist_ok=True)

        # Initialize PPO agent
        ppo = PPO(
            env=env,
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=64,
            n_epochs=10,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            clip_range_vf=None,
            ent_coef=0.0,
            vf_coef=0.5,
            max_grad_norm=0.5,
            target_kl=0.01,
            policy_kwargs={"net_arch": [256, 256]}
        )

        # Iterations between checkpoints; max(1, ...) avoids modulo-by-zero
        # when save_interval is smaller than one rollout.
        save_every = max(1, save_interval // ppo.n_steps)

        def checkpoint_callback(local_vars, global_vars):
            agent = local_vars['self']
            if local_vars['iterations'] % save_every == 0:
                agent.save(f"sb3_style_models/{env_name}_{agent.timesteps}.pth")

        # Train agent
        ppo.learn(
            total_timesteps=total_timesteps,
            callback=checkpoint_callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_interval,
            n_eval_episodes=n_eval_episodes
        )

        ppo.save(f"sb3_style_models/{env_name}_final.pth")
    finally:
        eval_env.close()

    _plot_training_progress(ppo.logger, 'sb3_style_logs/training_progress.png')

    return ppo


def _plot_training_progress(logger, path):
    """Save a 2x2 figure (reward, episode length, losses, entropy/KL) from ``logger`` to ``path``."""
    steps = logger["timesteps"]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_reward, ax_length, ax_losses, ax_kl = axes.flat

    # Mean reward
    ax_reward.plot(steps, logger["mean_reward"])
    ax_reward.set(title='Mean Reward', xlabel='Timesteps', ylabel='Reward')

    # Episode length
    ax_length.plot(steps, logger["episode_length"])
    ax_length.set(title='Episode Length', xlabel='Timesteps', ylabel='Length')

    # Policy and value losses
    ax_losses.plot(steps, logger["policy_loss"], label='Policy Loss')
    ax_losses.plot(steps, logger["value_loss"], label='Value Loss')
    ax_losses.set(title='Policy and Value Losses', xlabel='Timesteps', ylabel='Loss')
    ax_losses.legend()

    # Entropy loss and KL divergence
    ax_kl.plot(steps, logger["entropy_loss"], label='Entropy Loss')
    ax_kl.plot(steps, logger["kl_divergence"], label='KL Divergence')
    ax_kl.set(title='Entropy Loss and KL Divergence', xlabel='Timesteps', ylabel='Value')
    ax_kl.legend()

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)

def evaluate_ppo(env_name="PandaReach-v3", model_path=None, n_eval_episodes=10, render=True):
    """Evaluate a (optionally loaded) PPO agent on ``env_name``.

    Args:
        env_name: Gymnasium environment id.
        model_path: checkpoint written by ``PPO.save``; if falsy, the freshly
            initialized (untrained) networks are evaluated.
        n_eval_episodes: number of episodes to run.
        render: open the environment with ``render_mode="human"``.

    Returns:
        ``(mean_reward, mean_length)`` from ``PPO.evaluate_policy``.
    """
    print(f"Evaluating on {env_name}...")

    env = gym.make(env_name, render_mode="human" if render else None)
    try:
        # Initialize agent
        ppo = PPO(env=env)

        if model_path:
            ppo.load(model_path)  # PPO.load prints its own "Model loaded from ..." line
        else:
            print("No model loaded, using random policy")

        mean_reward, mean_length = ppo.evaluate_policy(n_eval_episodes=n_eval_episodes)
    finally:
        # Close the (possibly rendering) env even if loading or evaluation fails.
        env.close()
    return mean_reward, mean_length

if __name__ == "__main__":
    # Full training run on the reach task; checkpoints and evaluation every 50k steps.
    env_name = "PandaReach-v3"
    training_config = dict(
        total_timesteps=1000000,
        log_interval=1,
        save_interval=50000,
        eval_interval=50000,
    )
    ppo_agent = train_ppo(env_name=env_name, **training_config)


Using device: cuda
Training on PandaReach-v3...
Iteration: 1, timesteps: 2048
Mean reward: -45.13, mean episode length: 45.27
Losses: {'policy_loss': np.float64(-0.01423507425643038), 'value_loss': np.float64(0.3657156912377104), 'entropy_loss': np.float64(0.0), 'kl_divergence': np.float64(0.007575771590927616)}
Iteration: 2, timesteps: 4096
Mean reward: -47.67, mean episode length: 47.71
Losses: {'policy_loss': np.float64(-0.02100983641576022), 'value_loss': np.float64(0.29713847059756515), 'entropy_loss': np.float64(0.0), 'kl_divergence': np.float64(0.009160734160104766)}
Iteration: 3, timesteps: 6144
Mean reward: -43.35, mean episode length: 43.50
Losses: {'policy_loss': np.float64(-0.019710145736462438), 'value_loss': np.float64(0.43364804834127424), 'entropy_loss': np.float64(0.0), 'kl_divergence': np.float64(0.008433567130123266)}
Iteration: 4, timesteps: 8192
Mean reward: -45.98, mean episode length: 46.11
Losses: {'policy_loss': np.float64(-0.02010320747504011), 'value_loss': n

In [4]:
#check