In [None]:
!pip install gymnasium torch numpy matplotlib tqdm panda_gym

Collecting panda_gym
  Downloading panda_gym-3.0.7-py3-none-any.whl.metadata (4.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecti

In [None]:
import os
import gymnasium as gym
import panda_gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
from collections import deque
import time
import warnings
import random
from typing import Dict, List, Tuple, Type, Union, Optional, Any, Callable

warnings.filterwarnings(action="ignore", category=DeprecationWarning)


def set_seeds(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (plus CUDA when present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


set_seeds()

# Prefer CUDA, then Apple MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class RunningMeanStd:
    """Running estimate of the mean and variance of a stream of samples.

    Batches are merged with the parallel (pairwise) variance update, so
    ``update`` may be called with batches of any size.

    Args:
        epsilon: Initial pseudo-count, so the first merge never divides by zero.
        shape: Shape of a single sample, i.e. the trailing dims of a batch.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        """Fold a batch of samples, stacked along axis 0, into the statistics.

        Any array-like is accepted. A 0-d input is treated as a single sample.
        An empty batch is ignored: its NaN moments would otherwise
        permanently poison ``mean`` and ``var``.
        """
        x = np.asarray(x, dtype=np.float64)
        if x.ndim == 0:
            x = x.reshape(1)
        if x.shape[0] == 0:
            return
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments (mean, population variance, count)."""
        if batch_count == 0:
            return
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        # Combine the sums of squared deviations of both parts, plus the
        # correction for the shift between their means.
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        new_var = m_2 / tot_count

        self.mean = new_mean
        self.var = new_var
        self.count = tot_count

class VecNormalize:
    """Running normalizer for goal-conditioned dict observations and rewards.

    Observations are centred and scaled by running statistics and clipped to
    ``[-clipob, clipob]``. Rewards are scaled by the running std of the
    discounted return and clipped to ``[-cliprew, cliprew]``.

    NOTE(review): ``obs_rms`` defaults to shape ``(1,)`` and ``_normalize_obs``
    feeds it one vector at a time, which ``RunningMeanStd.update`` reduces over
    axis 0. Every feature of "observation" and "desired_goal" therefore
    contributes to ONE shared scalar mean/std instead of per-feature
    statistics. Confirm this is intended. The statistics are also updated on
    every call, including calls made during evaluation.
    """

    def __init__(self, obs_rms=None, ret_rms=None, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        # `or` short-circuits, so RunningMeanStd is only built when none is supplied.
        self.obs_rms = obs_rms or RunningMeanStd(shape=(1,))
        self.ret_rms = ret_rms or RunningMeanStd(shape=())
        self.clipob, self.cliprew = clipob, cliprew
        self.gamma = gamma
        self.epsilon = epsilon
        # Discounted running return used to scale rewards.
        self.returns = 0.0

    def normalize_obs(self, obs):
        """Normalize a flat observation, or the "observation"/"desired_goal" entries of a dict.

        Other dict entries (e.g. "achieved_goal") are passed through untouched.
        """
        if not isinstance(obs, dict):
            return self._normalize_obs(obs)
        return {
            key: self._normalize_obs(value) if key in ("observation", "desired_goal") else value
            for key, value in obs.items()
        }

    def _normalize_obs(self, obs):
        """Update the running stats with ``obs``, then return it standardized and clipped."""
        values = np.array(obs)
        self.obs_rms.update(values)
        std = np.sqrt(self.obs_rms.var + self.epsilon)
        return np.clip((values - self.obs_rms.mean) / std, -self.clipob, self.clipob)

    def normalize_reward(self, reward):
        """Accumulate ``reward`` into the discounted return and return it scaled by that return's std."""
        self.returns = self.gamma * self.returns + reward
        self.ret_rms.update(np.array([self.returns]))
        scale = np.sqrt(self.ret_rms.var + self.epsilon)
        return np.clip(reward / scale, -self.cliprew, self.cliprew)

    def reset_returns(self):
        """Zero the discounted return; call at every episode boundary."""
        self.returns = 0.0

# NN for DQN
class QNetwork(nn.Module):
    """MLP mapping a concatenated (observation, goal) vector to one Q-value per discrete action.

    Hidden layers are ``Linear -> activation`` pairs sized by ``net_arch``.
    Every Linear layer gets orthogonal weights (gain 1) and zero biases.
    """

    def __init__(self, obs_dim, goal_dim, action_dim, net_arch=[256, 256], activation_fn=nn.ReLU):
        super().__init__()

        self.input_dim = obs_dim + goal_dim
        self.layers = nn.ModuleList()

        # Consecutive (in, out) widths: input -> hidden_1 -> ... -> hidden_k.
        widths = [self.input_dim, *net_arch]
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            self.layers.append(nn.Linear(in_features, out_features))
            self.layers.append(activation_fn())

        self.q_layer = nn.Linear(widths[-1], action_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights / zero bias for Linear layers; other modules are left as-is."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=1)
        nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Return Q-values with shape ``(..., action_dim)`` for input ``x``."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.q_layer(hidden)

class ReplayBuffer:
    """Fixed-capacity circular buffer of goal-conditioned transitions.

    One discrete action index (int64) is stored per transition. The
    ``action_dim`` argument is accepted but not used.
    """

    def __init__(self, buffer_size, obs_dim, goal_dim, action_dim):
        self.buffer_size = buffer_size
        obs_shape = (buffer_size, obs_dim)
        goal_shape = (buffer_size, goal_dim)

        self.observations = np.zeros(obs_shape, dtype=np.float32)
        self.goals = np.zeros(goal_shape, dtype=np.float32)
        self.actions = np.zeros(buffer_size, dtype=np.int64)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.next_observations = np.zeros(obs_shape, dtype=np.float32)
        self.next_goals = np.zeros(goal_shape, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.float32)

        # Next slot to write; `full` flips once the write head wraps around.
        self.pos = 0
        self.full = False

    def _columns(self):
        """Storage arrays in transition-tuple order."""
        return (
            self.observations,
            self.goals,
            self.actions,
            self.rewards,
            self.next_observations,
            self.next_goals,
            self.dones,
        )

    def add(self, obs, goal, action, reward, next_obs, next_goal, done):
        """Write one transition at the write head, overwriting the oldest once full."""
        slot = self.pos
        values = (obs, goal, action, reward, next_obs, next_goal, done)
        for column, value in zip(self._columns(), values):
            column[slot] = value

        slot += 1
        if slot == self.buffer_size:
            self.full = True
            slot = 0
        self.pos = slot

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` stored transitions (with replacement).

        Returns a 7-tuple ``(obs, goal, action, reward, next_obs, next_goal, done)``
        of NumPy arrays whose leading dimension is ``batch_size``.
        """
        indices = np.random.randint(0, len(self), size=batch_size)
        return tuple(column[indices] for column in self._columns())

    def __len__(self):
        if self.full:
            return self.buffer_size
        return self.pos

# DQN Agent
class DQN:
    """Goal-conditioned (Double) DQN over a discretized continuous action space.

    Each continuous action dimension is split into ``actions_per_dim`` evenly
    spaced values, and every combination becomes one discrete action. The
    Q-network therefore has ``actions_per_dim ** action_space.shape[0]``
    outputs. Its input is the concatenation of the (normalized)
    "observation" and "desired_goal" entries of the env's dict observation.

    Args:
        env: Environment with a dict observation space containing
            "observation" and "desired_goal", and a Box action space.
        learning_rate: Adam learning rate.
        buffer_size: Replay buffer capacity.
        batch_size: Minibatch size. No gradient step is taken until the
            buffer holds at least this many transitions.
        target_update_freq: Environment steps between target-network updates.
        gamma: Discount factor (also used for reward normalization).
        tau: Target update coefficient; ``>= 1.0`` means a hard copy.
        eps_start, eps_end, eps_decay: Epsilon is linearly annealed from
            ``eps_start`` to ``eps_end`` over ``eps_decay`` environment steps.
        double_q: Use the Double-DQN target when True.
        exploration_fraction: Stored but unused; ``eps_decay`` drives the schedule.
        policy_kwargs: Optional dict. "net_arch" (hidden sizes) and
            "activation_fn" are forwarded to ``QNetwork``.
        device: Torch device for the networks and every tensor the agent creates.
    """

    def __init__(
        self,
        env,
        learning_rate=1e-4,
        buffer_size=100000,
        batch_size=64,
        target_update_freq=1000,
        gamma=0.99,
        tau=1.0,
        eps_start=1.0,
        eps_end=0.05,
        eps_decay=50000,
        double_q=True,
        exploration_fraction=0.1,
        policy_kwargs=None,
        device=device
    ):
        # Env info
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.env = env

        # Every tensor is created on self.device (not a module-level global), so
        # a non-default `device` argument is honoured end to end.
        self.device = torch.device(device)

        self.obs_dim = env.observation_space['observation'].shape[0]
        self.goal_dim = env.observation_space['desired_goal'].shape[0]

        # Discretize continuous action space
        self.actions_per_dim = 5
        self.action_dim = self.actions_per_dim ** env.action_space.shape[0]

        # Discrete action index -> continuous action vector
        self.action_map = self._create_action_map(env.action_space.shape[0], self.actions_per_dim,
                                                  env.action_space.low, env.action_space.high)

        # Hyperparameters
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.gamma = gamma
        self.tau = tau
        self.double_q = double_q

        # NOTE: exploration_fraction is kept for API compatibility only;
        # update_epsilon() uses eps_decay.
        self.exploration_fraction = exploration_fraction
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.epsilon = eps_start

        policy_kwargs = policy_kwargs or {}
        net_arch = policy_kwargs.get("net_arch", [256, 256])
        activation_fn = policy_kwargs.get("activation_fn", nn.ReLU)
        self.q_network = QNetwork(
            self.obs_dim,
            self.goal_dim,
            self.action_dim,
            net_arch=net_arch,
            activation_fn=activation_fn
        )
        self.target_q_network = QNetwork(
            self.obs_dim,
            self.goal_dim,
            self.action_dim,
            net_arch=net_arch,
            activation_fn=activation_fn
        )

        self.q_network.to(self.device)
        self.target_q_network.to(self.device)
        self.target_q_network.load_state_dict(self.q_network.state_dict())

        # Optimizer
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.replay_buffer = ReplayBuffer(buffer_size, self.obs_dim, self.goal_dim, 1)
        self.normalizer = VecNormalize(gamma=gamma)

        self.timesteps = 0
        # Value of self.timesteps at the most recent target-network sync.
        self._last_target_update = 0
        self.logger = {
            "timesteps": [],
            "mean_reward": [],
            "episode_length": [],
            "loss": []
        }

        self.callbacks = []
        self.loss_history = deque(maxlen=100)

    def _create_action_map(self, action_dims, actions_per_dim, low, high):
        """Build the list mapping each discrete index to a continuous action.

        Index ``idx`` is decoded in base ``actions_per_dim``, least-significant
        digit first. Digit ``k`` selects one of ``actions_per_dim`` evenly
        spaced values in ``[low[k], high[k]]``.
        """
        dim_values = []
        for dim in range(action_dims):
            values = np.linspace(low[dim], high[dim], actions_per_dim)
            dim_values.append(values)

        action_map = []
        for idx in range(actions_per_dim ** action_dims):
            action = []
            temp_idx = idx
            for dim in range(action_dims):
                action_idx = temp_idx % actions_per_dim
                temp_idx = temp_idx // actions_per_dim
                action.append(dim_values[dim][action_idx])
            action_map.append(np.array(action))

        return action_map

    def _process_observation(self, obs):
        """Split an observation into ``(observation, goal)`` arrays.

        For a dict, returns float32 copies of "observation" and "desired_goal".
        Otherwise, slices a flat array at ``obs_dim``.
        """
        if isinstance(obs, dict):
            observation = np.array(obs["observation"], dtype=np.float32)
            goal = np.array(obs["desired_goal"], dtype=np.float32)
        else:
            observation = obs[:self.obs_dim]
            goal = obs[self.obs_dim:]

        return observation, goal

    def _combined_input(self, obs, goal):
        """Concatenate obs and goal on the last axis (NumPy if both are ndarrays, else torch)."""
        if isinstance(obs, np.ndarray) and isinstance(goal, np.ndarray):
            return np.concatenate([obs, goal], axis=-1)
        else:
            return torch.cat([obs, goal], dim=-1)

    def _normalize_observation(self, obs):
        """Normalize dict observations (updating the running stats); pass others through."""
        if isinstance(obs, dict):
            return self.normalizer.normalize_obs(obs)
        return obs

    def _normalize_reward(self, reward):
        return self.normalizer.normalize_reward(reward)

    def update_epsilon(self):
        """Linearly anneal epsilon from eps_start to eps_end over eps_decay env steps."""
        if self.eps_decay > 0:
            fraction = min(1.0, self.timesteps / self.eps_decay)
        else:
            # No decay window: go straight to the final exploration rate.
            fraction = 1.0
        self.epsilon = self.eps_start + fraction * (self.eps_end - self.eps_start)

    def _select_action(self, obs_array, goal_array, deterministic=False):
        """Epsilon-greedy choice from already-normalized observation/goal arrays.

        Returns:
            ``(continuous_action, discrete_index)``.
        """
        if deterministic or np.random.random() > self.epsilon:
            # Greedy
            with torch.no_grad():
                obs_tensor = torch.FloatTensor(obs_array).unsqueeze(0).to(self.device)
                goal_tensor = torch.FloatTensor(goal_array).unsqueeze(0).to(self.device)
                x = self._combined_input(obs_tensor, goal_tensor)
                action_idx = self.q_network(x).argmax(dim=1).item()
        else:
            # Random
            action_idx = np.random.randint(0, self.action_dim)

        # Discrete action to continuous using action map
        return self.action_map[action_idx], action_idx

    def predict(self, observation, deterministic=False):
        """Pick an action for a raw env observation.

        The observation goes through the normalizer, which also updates its
        running statistics.

        Returns:
            ``(continuous_action, discrete_index)``.
        """
        observation = self._normalize_observation(observation)
        obs_array, goal_array = self._process_observation(observation)
        return self._select_action(obs_array, goal_array, deterministic)

    def train_step(self):
        """Take one gradient step on a replay minibatch.

        Returns:
            The smooth-L1 TD loss as a float. Returns 0.0 (and takes no step)
            while the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return 0.0

        (
            obs_batch,
            goal_batch,
            action_batch,
            reward_batch,
            next_obs_batch,
            next_goal_batch,
            done_batch
        ) = self.replay_buffer.sample(self.batch_size)

        obs_tensor = torch.FloatTensor(obs_batch).to(self.device)
        goal_tensor = torch.FloatTensor(goal_batch).to(self.device)
        action_tensor = torch.LongTensor(action_batch).to(self.device)
        reward_tensor = torch.FloatTensor(reward_batch).to(self.device)
        next_obs_tensor = torch.FloatTensor(next_obs_batch).to(self.device)
        next_goal_tensor = torch.FloatTensor(next_goal_batch).to(self.device)
        done_tensor = torch.FloatTensor(done_batch).to(self.device)

        state_tensor = self._combined_input(obs_tensor, goal_tensor)
        next_state_tensor = self._combined_input(next_obs_tensor, next_goal_tensor)

        current_q_values = self.q_network(state_tensor).gather(1, action_tensor.unsqueeze(1))

        with torch.no_grad():
            if self.double_q:
                # Double DQN: the online net picks the next action, the target net scores it
                next_q_values_online = self.q_network(next_state_tensor)
                next_actions = next_q_values_online.argmax(dim=1, keepdim=True)
                next_q_values = self.target_q_network(next_state_tensor).gather(1, next_actions)
            else:
                # Regular DQN
                next_q_values = self.target_q_network(next_state_tensor).max(1, keepdim=True)[0]

            # Bellman target; a stored done flag of 1 disables bootstrapping
            target_q_values = reward_tensor.unsqueeze(1) + (1 - done_tensor.unsqueeze(1)) * self.gamma * next_q_values

        # Loss
        loss = F.smooth_l1_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10.0)
        self.optimizer.step()

        loss_value = loss.item()
        self.loss_history.append(loss_value)

        # Sync on elapsed env steps rather than `timesteps % target_update_freq == 0`:
        # train_step only runs every `train_freq` steps, so the modulo test never
        # fires when target_update_freq is not a multiple of train_freq.
        if self.timesteps - self._last_target_update >= self.target_update_freq:
            self._update_target_network()
            self._last_target_update = self.timesteps

        return loss_value

    def _update_target_network(self):
        """Hard-copy (tau >= 1) or Polyak-average the online weights into the target net."""
        if self.tau >= 1.0:
            # Hard update
            self.target_q_network.load_state_dict(self.q_network.state_dict())
        else:
            # Soft update
            for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )

    def learn(self, total_timesteps, callback=None, log_interval=100,
              train_freq=4, eval_env=None, eval_freq=10000, n_eval_episodes=5):
        """Collect epsilon-greedy experience and train until ``self.timesteps`` reaches ``total_timesteps``.

        Args:
            total_timesteps: Stop once ``self.timesteps`` reaches this value.
                The counter is not reset between calls.
            callback: Called as ``callback(locals(), globals())`` after every env step.
            log_interval: Print and record statistics every this many episodes.
            train_freq: Call ``train_step`` every this many env steps.
            eval_env, eval_freq, n_eval_episodes: Run ``evaluate_policy`` on
                ``eval_env`` whenever ``self.timesteps % eval_freq == 0``
                (including step 0).

        Returns:
            self
        """
        episode_rewards = []
        episode_lengths = []
        current_episode_reward = 0
        current_episode_length = 0
        all_losses = []

        obs, _ = self.env.reset()
        # Normalize each raw observation exactly once. The same arrays are used
        # to pick the action and to build the stored transition, so the buffer
        # holds exactly what the network acted on and the running statistics
        # count each observation once.
        obs_array, goal_array = self._process_observation(self._normalize_observation(obs))

        while self.timesteps < total_timesteps:
            action, action_idx = self._select_action(obs_array, goal_array)

            next_obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            current_episode_reward += reward
            current_episode_length += 1

            next_obs_array, next_goal_array = self._process_observation(self._normalize_observation(next_obs))
            norm_reward = self._normalize_reward(reward)

            # Store `terminated`, not `done`: a time-limit truncation is not a
            # true terminal state, so its target must still bootstrap.
            self.replay_buffer.add(
                obs_array,
                goal_array,
                action_idx,
                norm_reward,
                next_obs_array,
                next_goal_array,
                float(terminated)
            )

            obs = next_obs
            obs_array, goal_array = next_obs_array, next_goal_array

            if self.timesteps % train_freq == 0:
                # train_step() returns a 0.0 placeholder during warm-up; keep
                # those out of the logged mean loss.
                warmed_up = len(self.replay_buffer) >= self.batch_size
                loss = self.train_step()
                if warmed_up:
                    all_losses.append(loss)

            self.update_epsilon()

            if done:
                episode_rewards.append(current_episode_reward)
                episode_lengths.append(current_episode_length)

                if len(episode_rewards) % log_interval == 0:
                    mean_reward = np.mean(episode_rewards[-log_interval:])
                    mean_length = np.mean(episode_lengths[-log_interval:])
                    mean_loss = np.mean(all_losses[-log_interval * train_freq:]) if all_losses else 0

                    print(f"Timestep: {self.timesteps}, Episodes: {len(episode_rewards)}")
                    print(f"Mean reward: {mean_reward:.2f}, Mean length: {mean_length:.2f}")
                    print(f"Mean loss: {mean_loss:.5f}, Exploration rate: {self.epsilon:.3f}")

                    # Store
                    self.logger["timesteps"].append(self.timesteps)
                    self.logger["mean_reward"].append(mean_reward)
                    self.logger["episode_length"].append(mean_length)
                    self.logger["loss"].append(mean_loss)

                current_episode_reward = 0
                current_episode_length = 0

                obs, _ = self.env.reset()
                obs_array, goal_array = self._process_observation(self._normalize_observation(obs))
                self.normalizer.reset_returns()

            if eval_env is not None and eval_freq is not None and self.timesteps % eval_freq == 0:
                self.evaluate_policy(eval_env, n_eval_episodes)

            self.timesteps += 1

            if callback is not None:
                callback(locals(), globals())

        return self

    def evaluate_policy(self, env=None, n_eval_episodes=10, deterministic=True):
        """Roll out ``n_eval_episodes`` episodes and report the mean return and length.

        NOTE: actions come from ``predict``, which also updates the
        observation normalizer's running statistics.

        Returns:
            ``(mean_reward, mean_length)``.
        """
        eval_env = env or self.env

        episode_rewards = []
        episode_lengths = []

        for _ in range(n_eval_episodes):
            obs, _ = eval_env.reset()
            done = False
            episode_reward = 0
            episode_length = 0

            while not done:
                action, _ = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = eval_env.step(action)
                done = terminated or truncated

                episode_reward += reward
                episode_length += 1

            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)

        mean_reward = np.mean(episode_rewards)
        mean_length = np.mean(episode_lengths)

        print(f"Evaluation: Mean reward: {mean_reward:.2f}, mean episode length: {mean_length:.2f}")

        return mean_reward, mean_length

    def save(self, path):
        """Write networks, optimizer, normalizer statistics, action map and epsilon to ``path``."""
        directory = os.path.dirname(path)
        if directory:
            # os.makedirs("") raises for a bare filename such as "model.pth".
            os.makedirs(directory, exist_ok=True)
        torch.save({
            "q_network_state_dict": self.q_network.state_dict(),
            "target_q_network_state_dict": self.target_q_network.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "normalizer_mean": self.normalizer.obs_rms.mean,
            "normalizer_var": self.normalizer.obs_rms.var,
            "normalizer_count": self.normalizer.obs_rms.count,
            "ret_rms_mean": self.normalizer.ret_rms.mean,
            "ret_rms_var": self.normalizer.ret_rms.var,
            "ret_rms_count": self.normalizer.ret_rms.count,
            "action_map": self.action_map,
            "epsilon": self.epsilon
        }, path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Restore everything written by ``save`` from ``path``; returns self."""
        # The checkpoint holds NumPy arrays/scalars and a list of arrays, which
        # the restricted unpickler (weights_only=True, the default since
        # PyTorch 2.6) rejects. weights_only=False runs full pickle, so only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.q_network.load_state_dict(checkpoint["q_network_state_dict"])
        self.target_q_network.load_state_dict(checkpoint["target_q_network_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.normalizer.obs_rms.mean = checkpoint["normalizer_mean"]
        self.normalizer.obs_rms.var = checkpoint["normalizer_var"]
        self.normalizer.obs_rms.count = checkpoint["normalizer_count"]
        self.normalizer.ret_rms.mean = checkpoint["ret_rms_mean"]
        self.normalizer.ret_rms.var = checkpoint["ret_rms_var"]
        self.normalizer.ret_rms.count = checkpoint["ret_rms_count"]

        self.action_map = checkpoint["action_map"]
        self.epsilon = checkpoint["epsilon"]

        print(f"Model loaded from {path}")
        return self

def _plot_training_progress(logger, path):
    """Save stacked mean-reward / episode-length / loss curves from a DQN ``logger`` dict to ``path``."""
    panels = (
        ("mean_reward", 'Mean Reward', 'Reward'),
        ("episode_length", 'Episode Length', 'Length'),
        ("loss", 'Loss', 'Loss'),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(15, 10))
    for ax, (key, title, ylabel) in zip(axes, panels):
        ax.plot(logger["timesteps"], logger[key])
        ax.set_title(title)
        ax.set_xlabel('Timesteps')
        ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train_dqn(env_name="PandaReach-v3", total_timesteps=1000000, log_interval=10, save_interval=10000, eval_interval=50000, n_eval_episodes=10):
    """Train a DQN agent on ``env_name``, checkpoint it, and plot the training curves.

    Checkpoints are written to ``dqn_models/`` every ``save_interval`` env
    steps, plus a final one. The curves are saved to
    ``dqn_logs/training_progress.png``.

    Returns:
        The trained agent. Its training env (``agent.env``) is left open so
        the agent can still be evaluated.
    """
    print(f"Training on {env_name}...")

    env = gym.make(env_name)
    eval_env = gym.make(env_name)

    os.makedirs('dqn_models', exist_ok=True)
    os.makedirs('dqn_logs', exist_ok=True)

    try:
        # DQN agent
        dqn = DQN(
            env=env,
            learning_rate=5e-4,
            buffer_size=100000,
            batch_size=128,
            target_update_freq=1000,
            gamma=0.99,
            tau=1.0,
            eps_start=1.0,
            eps_end=0.05,
            eps_decay=50000,
            double_q=True,
            policy_kwargs={"net_arch": [256, 256]}
        )

        def checkpoint_callback(local_vars, global_vars):
            """Save a checkpoint whenever the agent's step counter hits a multiple of save_interval."""
            agent = local_vars['self']
            if agent.timesteps % save_interval == 0:
                agent.save(f"dqn_models/{env_name}_{agent.timesteps}.pth")

        dqn.learn(
            total_timesteps=total_timesteps,
            callback=checkpoint_callback,
            log_interval=log_interval,
            train_freq=4,
            eval_env=eval_env,
            eval_freq=eval_interval,
            n_eval_episodes=n_eval_episodes
        )
    finally:
        # eval_env is only needed inside learn(); release it even if training fails.
        eval_env.close()

    dqn.save(f"dqn_models/{env_name}_final.pth")
    _plot_training_progress(dqn.logger, 'dqn_logs/training_progress.png')

    return dqn

def evaluate_dqn(env_name="PandaReach-v3", model_path=None, n_eval_episodes=10, render=True):
    """Evaluate a (optionally loaded) DQN agent greedily on a fresh ``env_name`` env.

    Args:
        model_path: Checkpoint written by ``DQN.save``. When falsy, the
            freshly initialized network is evaluated.
        n_eval_episodes: Number of evaluation episodes.
        render: Open the env with ``render_mode="human"`` when True.

    Returns:
        ``(mean_reward, mean_length)``.
    """
    print(f"Evaluating on {env_name}...")

    env = gym.make(env_name, render_mode="human" if render else None)
    try:
        dqn = DQN(env=env)

        if model_path:
            # DQN.load prints the path itself; no second message here.
            dqn.load(model_path)
        else:
            # evaluate_policy acts greedily, so this is an untrained network, not a random policy.
            print("No model loaded, evaluating an untrained (randomly initialized) network")

        mean_reward, mean_length = dqn.evaluate_policy(n_eval_episodes=n_eval_episodes)
    finally:
        # Release the env even if loading or evaluation fails.
        env.close()

    return mean_reward, mean_length

if __name__ == "__main__":
    env_name = "PandaReach-v3"

    # 500k training steps; log every 10 episodes, checkpoint and evaluate every 50k steps.
    dqn_agent = train_dqn(
        env_name,
        total_timesteps=500_000,
        log_interval=10,
        save_interval=50_000,
        eval_interval=50_000,
    )

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Timestep: 150802, Episodes: 16920
Mean reward: -1.60, Mean length: 2.60
Mean loss: 0.00079, Exploration rate: 0.050
Timestep: 150825, Episodes: 16930
Mean reward: -1.30, Mean length: 2.30
Mean loss: 0.00079, Exploration rate: 0.050
Timestep: 150852, Episodes: 16940
Mean reward: -1.70, Mean length: 2.70
Mean loss: 0.00081, Exploration rate: 0.050
Timestep: 150885, Episodes: 16950
Mean reward: -2.30, Mean length: 3.30
Mean loss: 0.00084, Exploration rate: 0.050
Timestep: 150916, Episodes: 16960
Mean reward: -2.10, Mean length: 3.10
Mean loss: 0.00084, Exploration rate: 0.050
Timestep: 150956, Episodes: 16970
Mean reward: -3.00, Mean length: 4.00
Mean loss: 0.00082, Exploration rate: 0.050
Timestep: 150985, Episodes: 16980
Mean reward: -1.90, Mean length: 2.90
Mean loss: 0.00084, Exploration rate: 0.050
Timestep: 151014, Episodes: 16990
Mean reward: -1.90, Mean length: 2.90
Mean loss: 0.00081, Exploration rate: 0.050
Timeste

In [None]:
#check