Import the Dependencies

In [1]:
# Importing Packages
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from imitation.algorithms.adversarial.airl import AIRL
from imitation.util import util
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.rewards.reward_nets import RewardNet
from imitation.util.networks import RunningNorm
from imitation.util import networks, util
import matplotlib.pyplot as plt
import pandas as pd
import torch as th
import torch.nn as nn
import torch.nn.init as init

pygame 2.5.0 (SDL 2.28.0, Python 3.9.20)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
SEED = 42
np.random.seed(SEED)
th.manual_seed(SEED)
if th.cuda.is_available():
    th.cuda.manual_seed_all(SEED)

Define the true reward function's weights

In [3]:
num_features = 16
num_actions = 4
weights = np.random.uniform(0, 100, num_features)

In [4]:
weights

array([37.45401188, 95.07143064, 73.19939418, 59.86584842, 15.60186404,
       15.59945203,  5.80836122, 86.61761458, 60.11150117, 70.80725778,
        2.05844943, 96.99098522, 83.24426408, 21.23391107, 18.18249672,
       18.34045099])

In [5]:
class CustomEnv(gym.Env):
    def __init__(self, num_features: int = num_features, num_actions: int = num_actions, weights=None):
        super().__init__()

        self.num_features = num_features
        self.num_actions = num_actions

        # Observation space is discrete 
        self.observation_space = spaces.Discrete(num_features)
        self.action_space = spaces.Discrete(num_actions) 

        # Define transition matrix: shape (num_features, num_actions, num_features)
        self.transition_matrix = np.random.rand(num_features, num_actions, num_features)
        self.transition_matrix /= self.transition_matrix.sum(axis=2, keepdims=True)  # Normalize to ensure probabilities

        self.state = None
        self.max_steps = 1000
        self.current_step = 0

        # Initialize weights for reward calculation and normalize them
        self.weights = weights
        self.weights = self.weights

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        if seed is not None:
            np.random.seed(seed)
        # Sample a discrete state (integer)
        self.state = np.random.choice(self.num_features)
        self.current_step = 0
        return self.state, {}

    def step(self, action):
        # Calculate reward based on the normalized weights for the current state
        reward = self.weights[self.state]

        next_state_probs = self.transition_matrix[self.state, action]
        next_state = np.random.choice(range(self.num_features), p=next_state_probs)

        if next_state != self.state:
            reward += 0.1  # Small bonus for transitioning to a different state

        # Update the state to the new discrete state
        self.state = next_state

        # Increment step counter
        self.current_step += 1

        # Check if the episode is done based on the step limit
        done = self.current_step >= self.max_steps
        truncated = False

        info = {"obs": self.state, "rews": reward}

        # Return the next state (as a discrete integer), the reward, and whether the episode is done
        return self.state, reward, done, truncated, info

    def render(self, mode='human'):
        pass



Register the custom environment to use OpenAI's Gym API (intializes the environment with the true reward function's weights to determine state->action behavior)

In [6]:
gym.register(id='CustomEnv-v0', entry_point=lambda: CustomEnv(weights=weights), max_episode_steps=100)

Create a vectorized environment for efficient training (train multiple instances of same environment simultaneously)

In [7]:
venv = util.make_vec_env("CustomEnv-v0", rng=np.random.default_rng(SEED), n_envs=4, post_wrappers=[lambda env, _: RolloutInfoWrapper(env)])

Initalize a PPO agent to learn the environment over 10000 steps using the "MlpPolicy"
PPO (Proximal Policy Optimization):
    1. collects experiences (states, actions, rewards over some episodes)
    2. try to find advantageous actions
    3. Update policy

Using the trained policy, collect data of the trajectories (interactions of the trained agent with the environment from some episodes)

In [8]:
expert_policy = PPO('MlpPolicy', venv, verbose=1, seed=SEED)
expert_policy.learn(total_timesteps=20000)

# Collect rollouts
rollouts = rollout.rollout(
    expert_policy,
    venv,
    rollout.make_sample_until(min_episodes=500),
    rng=np.random.default_rng(SEED),
)

Using cpu device




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 100      |
|    ep_rew_mean     | 4.75e+03 |
| time/              |          |
|    fps             | 15259    |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 8192     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 100          |
|    ep_rew_mean          | 4.72e+03     |
| time/                   |              |
|    fps                  | 6674         |
|    iterations           | 2            |
|    time_elapsed         | 2            |
|    total_timesteps      | 16384        |
| train/                  |              |
|    approx_kl            | 0.0002125399 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.39        |
|    explained_variance   | -6.93e-05    |
|    learning_r

In [9]:
save_rollouts = pd.DataFrame(rollouts)
save_rollouts.to_csv("rollouts.csv")

Define a Simple Linear Model Reward Function to Learn During Training

In [10]:
class LinearRewardNet(RewardNet):
    def __init__(self, observation_space, action_space):
        super().__init__(observation_space, action_space)

        # Handle different types of observation spaces
        if isinstance(observation_space, spaces.Box):
            self.state_dim = observation_space.shape[0]
        elif isinstance(observation_space, spaces.Discrete):
            self.state_dim = observation_space.n  # Number of possible discrete states
        else:
            raise ValueError(f"Unsupported observation space: {type(observation_space)}")
        
        # Define a linear layer that maps from state (one-hot encoded) to reward
        self.linear = nn.Linear(self.state_dim, 1)  
        init.xavier_uniform_(self.linear.weight)
        init.constant_(self.linear.bias, 0.0)  

    def forward(
        self,
        state: th.Tensor,
        action: th.Tensor,  
        next_state: th.Tensor,
        done: th.Tensor,
    ) -> th.Tensor:
        batch_size = state.shape[0]
        # print(f"BATCH SIZE = {batch_size}")
        # print(f"STATE SHAPE = {state.shape}")
        reward = self.linear(state)
        # print(f"REWARD SHAPE = {reward.shape}")


        reward = reward.squeeze(-1)

        return reward

In [11]:
# Imitation library's nonlinear model
reward_net = BasicShapedRewardNet(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    normalize_input_layer=RunningNorm,
)

In [12]:
reward_net = LinearRewardNet(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
)

Initialize AIRL to be trained on the environment, with the expert data, and the same MlpPolicy as the generator to train the data

In [13]:
learner = PPO(
    env=venv,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0005,
    gamma=0.95,
    clip_range=0.1,
    vf_coef=0.1,
    n_epochs=5,
    seed=SEED,
)

airl_trainer = AIRL(
    venv=venv,
    demonstrations=rollouts,
    demo_batch_size=2048,
    gen_replay_buffer_capacity=512,
    gen_algo= expert_policy,
    reward_net=reward_net,
    n_disc_updates_per_round=16,
    gen_train_timesteps=20000,
)

In [14]:
venv.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(airl_trainer.gen_algo, venv, 100, return_episode_rewards=True)

# Train AIRL
airl_trainer.train(total_timesteps=400000)
venv.seed(SEED)

# Evaluate policy after training
learner_rewards_after_training, _ = evaluate_policy(airl_trainer.gen_algo, venv, 100, return_episode_rewards=True)


# Print results
print("Rewards before training:", learner_rewards_before_training)
print("Rewards after training:", learner_rewards_after_training)

print("Mean Rewards before training:", np.mean(learner_rewards_before_training))
print("Mean Rewards after training:", np.mean(learner_rewards_after_training))

round:   0%|          | 0/20 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 100      |
|    gen/rollout/ep_rew_mean  | 4.78e+03 |
|    gen/time/fps             | 10090    |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 0        |
|    gen/time/total_timesteps | 32768    |
------------------------------------------
----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.78e+03    |
|    gen/rollout/ep_rew_wrapped_mean | 10.8        |
|    gen/time/fps                    | 5588        |
|    gen/time/iterations             | 2           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 40960       |
|    gen/train/approx_kl             | 0.010390077 |
|    gen/train/clip_fraction         | 0.0578      |
|    gen/train/clip_range     

round:   5%|▌         | 1/20 [00:06<02:06,  6.67s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.64e+03    |
|    gen/rollout/ep_rew_wrapped_mean | 11.7        |
|    gen/time/fps                    | 10912       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 57344       |
|    gen/train/approx_kl             | 0.010310242 |
|    gen/train/clip_fraction         | 0.0376      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.36       |
|    gen/train/explained_variance    | 0.0176      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 3.74        |
|    gen/train/n_updates             | 60          |
|    gen/train/policy_gradient_loss  | -0.00491    |
|    gen/train/value_loss            | 6.56   

round:  10%|█         | 2/20 [00:13<01:57,  6.54s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.76e+03    |
|    gen/rollout/ep_rew_wrapped_mean | 7.81        |
|    gen/time/fps                    | 10974       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 81920       |
|    gen/train/approx_kl             | 0.011175511 |
|    gen/train/clip_fraction         | 0.0562      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.34       |
|    gen/train/explained_variance    | 0.0241      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.781       |
|    gen/train/n_updates             | 90          |
|    gen/train/policy_gradient_loss  | -0.00565    |
|    gen/train/value_loss            | 3.17   

round:  15%|█▌        | 3/20 [00:20<01:56,  6.83s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.73e+03    |
|    gen/rollout/ep_rew_wrapped_mean | 4.8         |
|    gen/time/fps                    | 11041       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 106496      |
|    gen/train/approx_kl             | 0.010782806 |
|    gen/train/clip_fraction         | 0.0694      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.26       |
|    gen/train/explained_variance    | 0.0414      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.822       |
|    gen/train/n_updates             | 120         |
|    gen/train/policy_gradient_loss  | -0.00535    |
|    gen/train/value_loss            | 1.9    

round:  20%|██        | 4/20 [00:26<01:46,  6.68s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.73e+03    |
|    gen/rollout/ep_rew_wrapped_mean | 1.79        |
|    gen/time/fps                    | 11078       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.008802869 |
|    gen/train/clip_fraction         | 0.0687      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.22       |
|    gen/train/explained_variance    | 0.0656      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.839       |
|    gen/train/n_updates             | 150         |
|    gen/train/policy_gradient_loss  | -0.00568    |
|    gen/train/value_loss            | 1.38   

round:  25%|██▌       | 5/20 [00:33<01:38,  6.59s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.69e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -0.958      |
|    gen/time/fps                    | 11192       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 155648      |
|    gen/train/approx_kl             | 0.008709215 |
|    gen/train/clip_fraction         | 0.0692      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.13       |
|    gen/train/explained_variance    | 0.0659      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.54        |
|    gen/train/n_updates             | 180         |
|    gen/train/policy_gradient_loss  | -0.00506    |
|    gen/train/value_loss            | 1.25   

round:  30%|███       | 6/20 [00:40<01:34,  6.76s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.73e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -3.97       |
|    gen/time/fps                    | 11279       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 180224      |
|    gen/train/approx_kl             | 0.008354871 |
|    gen/train/clip_fraction         | 0.0562      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.06       |
|    gen/train/explained_variance    | 0.091       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.597       |
|    gen/train/n_updates             | 210         |
|    gen/train/policy_gradient_loss  | -0.00436    |
|    gen/train/value_loss            | 1.1    

round:  35%|███▌      | 7/20 [00:46<01:26,  6.65s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.65e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -5.97       |
|    gen/time/fps                    | 11220       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 204800      |
|    gen/train/approx_kl             | 0.008793866 |
|    gen/train/clip_fraction         | 0.0911      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.994      |
|    gen/train/explained_variance    | 0.0891      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.361       |
|    gen/train/n_updates             | 240         |
|    gen/train/policy_gradient_loss  | -0.00659    |
|    gen/train/value_loss            | 0.887  

round:  40%|████      | 8/20 [00:52<01:18,  6.52s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.71e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -9.48       |
|    gen/time/fps                    | 11227       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 229376      |
|    gen/train/approx_kl             | 0.008034572 |
|    gen/train/clip_fraction         | 0.0732      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.957      |
|    gen/train/explained_variance    | 0.0961      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.446       |
|    gen/train/n_updates             | 270         |
|    gen/train/policy_gradient_loss  | -0.00554    |
|    gen/train/value_loss            | 0.886  

round:  45%|████▌     | 9/20 [00:59<01:12,  6.57s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.71e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -11.7       |
|    gen/time/fps                    | 9751        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 253952      |
|    gen/train/approx_kl             | 0.005858802 |
|    gen/train/clip_fraction         | 0.0404      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.887      |
|    gen/train/explained_variance    | 0.0993      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.383       |
|    gen/train/n_updates             | 300         |
|    gen/train/policy_gradient_loss  | -0.00178    |
|    gen/train/value_loss            | 0.771  

round:  50%|█████     | 10/20 [01:07<01:08,  6.84s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.69e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -14.5       |
|    gen/time/fps                    | 11051       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 278528      |
|    gen/train/approx_kl             | 0.005741942 |
|    gen/train/clip_fraction         | 0.0697      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.798      |
|    gen/train/explained_variance    | 0.0928      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.421       |
|    gen/train/n_updates             | 330         |
|    gen/train/policy_gradient_loss  | -0.00498    |
|    gen/train/value_loss            | 0.771  

round:  55%|█████▌    | 11/20 [01:14<01:03,  7.06s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.66e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -16.5       |
|    gen/time/fps                    | 9858        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 303104      |
|    gen/train/approx_kl             | 0.005797805 |
|    gen/train/clip_fraction         | 0.0586      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.775      |
|    gen/train/explained_variance    | 0.105       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.41        |
|    gen/train/n_updates             | 360         |
|    gen/train/policy_gradient_loss  | -0.00414    |
|    gen/train/value_loss            | 0.848  

round:  60%|██████    | 12/20 [01:22<00:58,  7.26s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 100          |
|    gen/rollout/ep_rew_mean         | 4.67e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -19.2        |
|    gen/time/fps                    | 9680         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 327680       |
|    gen/train/approx_kl             | 0.0046292916 |
|    gen/train/clip_fraction         | 0.0384       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.72        |
|    gen/train/explained_variance    | 0.102        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.306        |
|    gen/train/n_updates             | 390          |
|    gen/train/policy_gradient_loss  | -0.0018      |
|    gen/train/value_loss   

round:  65%|██████▌   | 13/20 [01:30<00:53,  7.66s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 100          |
|    gen/rollout/ep_rew_mean         | 4.74e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -21.7        |
|    gen/time/fps                    | 8657         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 352256       |
|    gen/train/approx_kl             | 0.0042950725 |
|    gen/train/clip_fraction         | 0.0472       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.716       |
|    gen/train/explained_variance    | 0.0991       |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.472        |
|    gen/train/n_updates             | 420          |
|    gen/train/policy_gradient_loss  | -0.0028      |
|    gen/train/value_loss   

round:  70%|███████   | 14/20 [01:38<00:45,  7.62s/it]

--------------------------------------------------
| raw/                               |           |
|    gen/rollout/ep_len_mean         | 100       |
|    gen/rollout/ep_rew_mean         | 4.65e+03  |
|    gen/rollout/ep_rew_wrapped_mean | -24.4     |
|    gen/time/fps                    | 10716     |
|    gen/time/iterations             | 1         |
|    gen/time/time_elapsed           | 0         |
|    gen/time/total_timesteps        | 376832    |
|    gen/train/approx_kl             | 0.0040518 |
|    gen/train/clip_fraction         | 0.0448    |
|    gen/train/clip_range            | 0.2       |
|    gen/train/entropy_loss          | -0.661    |
|    gen/train/explained_variance    | 0.11      |
|    gen/train/learning_rate         | 0.0003    |
|    gen/train/loss                  | 0.268     |
|    gen/train/n_updates             | 450       |
|    gen/train/policy_gradient_loss  | -0.00252  |
|    gen/train/value_loss            | 0.671     |
-------------------------------

round:  75%|███████▌  | 15/20 [01:47<00:39,  7.98s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.62e+03    |
|    gen/rollout/ep_rew_wrapped_mean | -26.5       |
|    gen/time/fps                    | 7992        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 401408      |
|    gen/train/approx_kl             | 0.003811107 |
|    gen/train/clip_fraction         | 0.0475      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.607      |
|    gen/train/explained_variance    | 0.11        |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.416       |
|    gen/train/n_updates             | 480         |
|    gen/train/policy_gradient_loss  | -0.00277    |
|    gen/train/value_loss            | 0.825  

round:  80%|████████  | 16/20 [01:54<00:31,  7.88s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 100         |
|    gen/rollout/ep_rew_mean         | 4.7e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -29.5       |
|    gen/time/fps                    | 9795        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 425984      |
|    gen/train/approx_kl             | 0.003427391 |
|    gen/train/clip_fraction         | 0.0443      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.598      |
|    gen/train/explained_variance    | 0.104       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.264       |
|    gen/train/n_updates             | 510         |
|    gen/train/policy_gradient_loss  | -0.00255    |
|    gen/train/value_loss            | 0.767  

round:  85%|████████▌ | 17/20 [02:02<00:23,  7.75s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 100          |
|    gen/rollout/ep_rew_mean         | 4.71e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -31.7        |
|    gen/time/fps                    | 10159        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 450560       |
|    gen/train/approx_kl             | 0.0033179338 |
|    gen/train/clip_fraction         | 0.0414       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.545       |
|    gen/train/explained_variance    | 0.106        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.248        |
|    gen/train/n_updates             | 540          |
|    gen/train/policy_gradient_loss  | -0.00195     |
|    gen/train/value_loss   

round:  90%|█████████ | 18/20 [02:09<00:14,  7.47s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 100          |
|    gen/rollout/ep_rew_mean         | 4.75e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -34.3        |
|    gen/time/fps                    | 10734        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 475136       |
|    gen/train/approx_kl             | 0.0025103237 |
|    gen/train/clip_fraction         | 0.0372       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.494       |
|    gen/train/explained_variance    | 0.108        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.555        |
|    gen/train/n_updates             | 570          |
|    gen/train/policy_gradient_loss  | -0.00214     |
|    gen/train/value_loss   

round:  95%|█████████▌| 19/20 [02:15<00:07,  7.16s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 100          |
|    gen/rollout/ep_rew_mean         | 4.73e+03     |
|    gen/rollout/ep_rew_wrapped_mean | -36.1        |
|    gen/time/fps                    | 11077        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 499712       |
|    gen/train/approx_kl             | 0.0022415079 |
|    gen/train/clip_fraction         | 0.0385       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.446       |
|    gen/train/explained_variance    | 0.107        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.368        |
|    gen/train/n_updates             | 600          |
|    gen/train/policy_gradient_loss  | -0.0022      |
|    gen/train/value_loss   

round: 100%|██████████| 20/20 [02:22<00:00,  7.13s/it]


Rewards before training: [4768.710115, 4692.632661, 4869.663429, 4784.441342, 4545.624197, 4428.667949, 4534.754933, 4713.298339, 4764.056606, 3870.04105, 4911.166673, 4420.958159, 5799.934528, 5004.778598, 5026.814824, 4017.348386, 4773.924581, 4207.680165, 5264.951371, 5377.535353, 4624.131204, 4270.633714, 4603.282079, 4743.720706, 4941.311618, 4891.641531, 4385.264596, 5191.852421, 5083.66841, 4733.089949, 4505.029009, 4300.593913, 4935.688861, 4909.506509, 4542.052414, 5269.351496, 5124.718136, 3979.485401, 5243.085019, 4626.234915, 5466.679829, 4788.53652, 4898.127843, 4685.616203, 4635.893149, 3908.566627, 4695.442003, 4845.485937, 5342.25892, 4394.702881, 4623.211026, 4935.420394, 5208.647175, 4691.422595, 4873.151505, 4317.530833, 5060.209517, 4473.888249, 4759.913742, 4984.560705, 5227.746115, 4352.277946, 4597.445216, 5539.130039, 4658.579698, 4499.066206, 4828.247455, 5087.109284, 5107.076415, 4944.104384, 4679.261463, 4807.64509, 4011.939786, 4426.38132, 4252.59888, 4966.4

Finding MSE and MAE over Individual State-Action Pairs

In [15]:
def evaluate(env, reward_net, weights, num_samples=100):
    true_rewards = []
    learned_rewards = []

    for _ in range(num_samples):
        # Sample a discrete state and action from the environment
        state = env.observation_space.sample()  
        action = env.action_space.sample()   

        # Compute true reward based on the state's corresponding weight
        true_reward = weights[state]  
        print(f"TRUE REWARD = {true_reward}")
        true_rewards.append(true_reward)

        # Convert state to one-hot encoded tensor
        num_states = env.observation_space.n  # Number of discrete states
        state_one_hot = np.zeros(num_states)
        state_one_hot[state] = 1
        state_tensor = th.FloatTensor(state_one_hot).unsqueeze(0)  

        # Convert action to one-hot encoded tensor
        num_actions = env.action_space.n  # Number of discrete actions
        action_one_hot = np.zeros(num_actions)
        action_one_hot[action] = 1
        action_tensor = th.FloatTensor(action_one_hot).unsqueeze(0)  

        # Compute learned reward using the reward network
        learned_reward = reward_net(state_tensor, action_tensor, None, None).item()
        print(f"LEARNED REWARD = {learned_reward}")
        learned_rewards.append(learned_reward)

    # Convert rewards to tensors for calculation
    true_rewards_tensor = th.tensor(true_rewards, dtype=th.float32)
    learned_rewards_tensor = th.tensor(learned_rewards, dtype=th.float32)

    print("True rewards:", true_rewards_tensor)
    print("Learned rewards:", learned_rewards_tensor)
    
    # Calculate Mean Absolute Error (MAE) and Mean Squared Error (MSE)
    mae = th.mean(th.abs(true_rewards_tensor - learned_rewards_tensor)).item()
    mse = th.mean((true_rewards_tensor - learned_rewards_tensor) ** 2).item()
    return mse, mae


In [16]:
mse, mae = evaluate(venv, reward_net, weights, 250)

TRUE REWARD = 15.601864044243651
LEARNED REWARD = 0.02954721450805664
TRUE REWARD = 5.8083612168199465
LEARNED REWARD = -0.3113442659378052
TRUE REWARD = 15.601864044243651
LEARNED REWARD = 0.02954721450805664
TRUE REWARD = 15.599452033620265
LEARNED REWARD = -0.3942369520664215
TRUE REWARD = 70.80725777960456
LEARNED REWARD = -0.31745225191116333
TRUE REWARD = 37.454011884736246
LEARNED REWARD = -1.02402925491333
TRUE REWARD = 95.07143064099162
LEARNED REWARD = -0.1516696661710739
TRUE REWARD = 2.0584494295802447
LEARNED REWARD = -0.45703357458114624
TRUE REWARD = 70.80725777960456
LEARNED REWARD = -0.31745225191116333
TRUE REWARD = 2.0584494295802447
LEARNED REWARD = -0.45703357458114624
TRUE REWARD = 59.86584841970366
LEARNED REWARD = -0.8066972494125366
TRUE REWARD = 2.0584494295802447
LEARNED REWARD = -0.45703357458114624
TRUE REWARD = 2.0584494295802447
LEARNED REWARD = -0.45703357458114624
TRUE REWARD = 2.0584494295802447
LEARNED REWARD = -0.45703357458114624
TRUE REWARD = 60.11

Finding MSE and MAE over Trajectories

In [17]:
print("MSE between true and learned rewards:", mse)
print("MAE between true and learned rewards:", mae)


MSE between true and learned rewards: 3446.799072265625
MAE between true and learned rewards: 48.63444137573242


In [18]:
def evaluate_policy_rewards(policy, env, reward_net, num_episodes=10):
    true_rewards = []
    learned_rewards = []

    for _ in range(num_episodes):
        # Reset the environment and get the initial observation for all environments
        obs = env.reset()  # VecEnv returns a batch of observations
        print(f"OBS = {obs}")
        
        done = [False] * env.num_envs  # Initialize the done flags for all environments

        while not any(done):  # Continue until all environments are done
            actions = []
            for i in range(env.num_envs):
                # Get the action from the policy for each environment
                action, _ = policy.predict(obs[i])  # Predict action based on each environment's observation
                actions.append(action)  # Store the action for each environment

            # print(f"ACTIONS = {actions}")
            
            # Step the vectorized environment with actions for all environments
            next_obs, true_reward, done, info = env.step(actions)  # Step all environments at once

            # print(f"NEXT_OBS = {next_obs}")
            # print(f"TRUE_REWARD = {true_reward}")
            # print(f"DONE = {done}")
            # print(f"INFO = {info}")

            # Process each environment's result separately
            for i in range(env.num_envs):
                print(f"Processing environment {i}:")
                # print(f"  OBS = {obs[i]}")
                # print(f"  NEXT_OBS = {next_obs[i]}")
                print(f"  TRUE_REWARD = {true_reward[i]}")
                # print(f"  DONE = {done[i]}")
                # print(f"  INFO = {info[i]}")

                # Append the true reward for the current environment
                true_rewards.append(true_reward[i])

                if isinstance(env.observation_space, spaces.Discrete):
                    num_states = env.observation_space.n
                    state_one_hot = np.zeros(num_states)
                    state_one_hot[obs[i]] = 1
                    state_tensor = th.FloatTensor(state_one_hot).unsqueeze(0)  
                else:
                    state_tensor = th.FloatTensor(obs[i]).unsqueeze(0)  

                num_actions = env.action_space.n
                action_one_hot = np.zeros(num_actions)
                action_one_hot[actions[i]] = 1
                action_tensor = th.FloatTensor(action_one_hot).unsqueeze(0)  

                learned_reward = reward_net(state_tensor, action_tensor, None, None).item()
                print(f"LEARNED REWARD = {learned_reward}")

                learned_rewards.append(learned_reward)

            # Move to the next observations for all environments
            obs = next_obs

    true_rewards_tensor = th.tensor(true_rewards, dtype=th.float32)
    learned_rewards_tensor = th.tensor(learned_rewards, dtype=th.float32)

    # Calculate Mean Squared Error (MSE)
    mse = th.mean((true_rewards_tensor - learned_rewards_tensor) ** 2).item()

    # Calculate Mean Absolute Error (MAE)
    mae = th.mean(th.abs(true_rewards_tensor - learned_rewards_tensor)).item()

    return mse, mae

mse, mae = evaluate_policy_rewards(expert_policy, venv, reward_net, num_episodes=10)



OBS = [ 2 15 11 14]
Processing environment 0:
  TRUE_REWARD = 73.29939270019531
LEARNED REWARD = -0.4889352023601532
Processing environment 1:
  TRUE_REWARD = 18.44045066833496
LEARNED REWARD = -0.2718842029571533
Processing environment 2:
  TRUE_REWARD = 97.09098815917969
LEARNED REWARD = -0.05955219268798828
Processing environment 3:
  TRUE_REWARD = 18.28249740600586
LEARNED REWARD = -0.29250359535217285
Processing environment 0:
  TRUE_REWARD = 18.28249740600586
LEARNED REWARD = -0.29250359535217285
Processing environment 1:
  TRUE_REWARD = 2.158449411392212
LEARNED REWARD = -0.45703357458114624
Processing environment 2:
  TRUE_REWARD = 83.3442611694336
LEARNED REWARD = -1.0217804908752441
Processing environment 3:
  TRUE_REWARD = 86.71761322021484
LEARNED REWARD = -0.2305210828781128
Processing environment 0:
  TRUE_REWARD = 59.96584701538086
LEARNED REWARD = -0.8066972494125366
Processing environment 1:
  TRUE_REWARD = 5.908361434936523
LEARNED REWARD = -0.3113442659378052
Process

In [19]:
print("MSE between true and learned rewards:", mse)
print("MAE between true and learned rewards:", mae)

MSE between true and learned rewards: 3280.7822265625
MAE between true and learned rewards: 46.88328552246094


For evaluation sample some states and actions, and try running it through both true reward and learned reward (find MAE/MSE)

want to also create a deterministic transition function with a NN

rational set theory
    - if you want to build a model to approximate something, there will be 10s of thousands with similar MSE and MAE
    - add regularization
        - longer you play less likely to quit
        - more likely to purchase games
    - using this we can reduce the model space