## Expert Data

In [2]:
import gymnasium as gym
import numpy as np



def generate_expert_data(num_trajectories=10):
    env = gym.make("CartPole-v1", render_mode="human")
    expert_states = []
    expert_actions = []
    
    for _ in range(num_trajectories):
        state, _ = env.reset()
        done = False
        while not done:
            # Heuristic expert: move left if pole leans left, else right
            angle = state[2]  # Pole angle
            action = 0 if angle < 0 else 1  # Simplified expert policy
            expert_states.append(state)
            expert_actions.append(action)
            state, _, done, _, _ = env.step(action)
    
    return np.array(expert_states), np.array(expert_actions)



expert_states, expert_actions = generate_expert_data()
expert_states, expert_actions


(array([[-0.0451982 ,  0.02113391, -0.01404097, -0.0441334 ],
        [-0.04477552, -0.17378391, -0.01492364,  0.24408661],
        [-0.0482512 , -0.36868957, -0.0100419 ,  0.53202516],
        ...,
        [ 0.18665619,  1.9400018 , -0.09427878, -2.3084128 ],
        [ 0.22545622,  1.7458605 , -0.14044704, -2.046171  ],
        [ 0.26037344,  1.552432  , -0.18137045, -1.8000408 ]],
       dtype=float32),
 array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,

## Define Policy and Discriminator Networks
The discriminator networks Concatenate state and action, This allows the network to learn a joint representation of the state and action

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

class Policy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, state):
        return self.net(state)

class Discriminator(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    # Concatenate state and action, This allows the network to learn a joint representation of the state and action
    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.net(sa)

# Initialize networks
state_dim = 4  # CartPole state dimension
action_dim = 2 # CartPole action space (0 or 1)
policy = Policy(state_dim, action_dim)
discriminator = Discriminator(state_dim, 1)  # Action is 0/1, so we encode as 0 or 1

# # Test networks
# state = torch.randn(1, state_dim)
# action = torch.randint(0, 2, (1, 1))
# print("Policy output:", policy(state).shape)
# print("Discriminator output:", discriminator(state, action).shape)

## Train GAIL

#### Loss in GAN
The discriminator's loss in GANs is usually something like maximizing log(D(real)) + log(1 - D(fake)), where D(real) is the probability that real data is real, and D(fake) is the probability that fake data is real

Entropy Term
$$
H(\pi(\cdot \mid s)) = - \sum_{a \in \mathcal{A}} \pi(a \mid s) \log (\pi(a \mid s))
$$


In [None]:
def train_gail(expert_states, expert_actions, num_epochs=100, batch_size=32, λ=0.1):
    # Initialize optimizers
    policy_optim = optim.Adam(policy.parameters(), lr=1e-3)
    disc_optim = optim.Adam(discriminator.parameters(), lr=1e-3)
    
    for epoch in range(num_epochs):
        # --- Step 1: Sample trajectories from current policy ---
        env = gym.make("CartPole-v1")
        policy_states, policy_actions = [], []
        state, _ = env.reset()
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state)
            action_probs = policy(state_tensor)  # Get action probabilities from policy
            action = torch.distributions.Categorical(action_probs).sample().item()  # Sample action
            policy_states.append(state)
            policy_actions.append(action)
            state, _, done, _, _ = env.step(action)  # Execute action in environment
        
        # --- Step 2: Update Discriminator ---
        # Prepare expert and policy data as state-action pairs
        expert_sa = torch.cat([
            torch.FloatTensor(expert_states),
            torch.FloatTensor(expert_actions).unsqueeze(1)  # Shape: (N, state_dim + 1)
        ], dim=1)
        policy_sa = torch.cat([
            torch.FloatTensor(policy_states),
            torch.FloatTensor(policy_actions).unsqueeze(1)  # Shape: (M, state_dim + 1)
        ], dim=1)
        
        # Discriminator loss components:
        # - Maximize log(D(expert_sa)): Expert labeled as "real"
        # - Maximize log(1 - D(policy_sa)): Policy labeled as "fake"
        real_output = discriminator(expert_sa[:, :4], expert_sa[:, 4:])  # D(expert_sa)
        fake_output = discriminator(policy_sa[:, :4], policy_sa[:, 4:])   # D(policy_sa)
        real_loss = -torch.log(real_output).mean()  # -E[log(D(expert))]
        fake_loss = -torch.log(1 - fake_output).mean()  # -E[log(1 - D(policy))]
        disc_loss = real_loss + fake_loss  # Total loss
        
        # Update discriminator
        disc_optim.zero_grad()
        disc_loss.backward()
        disc_optim.step()
        
        # --- Step 3: Update Policy using Discriminator as Reward ---
        # Convert policy data to tensors
        policy_states_tensor = torch.FloatTensor(policy_states)
        policy_actions_tensor = torch.FloatTensor(policy_actions).unsqueeze(1)
        
        # Compute rewards: log(D(s,a)) 
        # the discriminator is not exactly a reward function, 
        # but rather a way to estimate the likelihood of a state-action pair being from the expert's policy.
        # If the discriminator thinks it is more likely from expert then it will give a higher reward
        # Thus we can use this as a reward signal to reinforce the actions that can fool the discriminator better
        with torch.no_grad():
            rewards = torch.log(discriminator(policy_states_tensor, policy_actions_tensor))
        
        # Compute policy loss:
        # L = -E[log(π(a|s)) * reward] - λ * entropy(π)
        action_probs = policy(policy_states_tensor)
        entropy = -torch.sum(action_probs * torch.log(action_probs + 1e-10), dim=-1).mean()  # Avoid log(0)
        log_probs = torch.log(action_probs.gather(1, policy_actions_tensor.long()))  # log(π(a|s))
        # TODO mind the sign of entropy here
        # Here we flip the sign cuz we wanna minimizing
        # We wanna maximizing/reinforce the action with high reward, that equals to minimizing the negative reward
        # The entropy is positive, thus minimizing the negative entropy is maximizing the entropy
        policy_loss = -(log_probs * rewards).mean() - λ * entropy
        
        # Update policy
        policy_optim.zero_grad()
        policy_loss.backward()
        policy_optim.step()
        
        print(f"Epoch {epoch}: Disc Loss = {disc_loss.item():.3f}, Policy Loss = {policy_loss.item():.3f}")

In [13]:
train_gail(expert_states, expert_actions, num_epochs=100, batch_size=32, λ=0.1)

Epoch 0: Disc Loss = 1.343, Policy Loss = -0.412
Epoch 1: Disc Loss = 1.345, Policy Loss = -0.430
Epoch 2: Disc Loss = 1.389, Policy Loss = -0.479
Epoch 3: Disc Loss = 1.222, Policy Loss = -0.492
Epoch 4: Disc Loss = 1.252, Policy Loss = -0.436
Epoch 5: Disc Loss = 1.256, Policy Loss = -0.431
Epoch 6: Disc Loss = 1.286, Policy Loss = -0.435
Epoch 7: Disc Loss = 1.215, Policy Loss = -0.476
Epoch 8: Disc Loss = 1.186, Policy Loss = -0.486
Epoch 9: Disc Loss = 1.199, Policy Loss = -0.490
Epoch 10: Disc Loss = 1.310, Policy Loss = -0.481
Epoch 11: Disc Loss = 1.235, Policy Loss = -0.439
Epoch 12: Disc Loss = 1.223, Policy Loss = -0.455
Epoch 13: Disc Loss = 1.226, Policy Loss = -0.479
Epoch 14: Disc Loss = 1.289, Policy Loss = -0.486
Epoch 15: Disc Loss = 1.321, Policy Loss = -0.431
Epoch 16: Disc Loss = 1.271, Policy Loss = -0.479
Epoch 17: Disc Loss = 1.274, Policy Loss = -0.420
Epoch 18: Disc Loss = 1.235, Policy Loss = -0.505
Epoch 19: Disc Loss = 1.217, Policy Loss = -0.483
Epoch 20: 

In [14]:
def test_policy(policy, num_episodes=10):
    env = gym.make("CartPole-v1", render_mode="human")
    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                action_probs = policy(torch.FloatTensor(state))
            action = torch.argmax(action_probs).item()  # Take most probable action
            state, _, done, _, _ = env.step(action)
    env.close()

test_policy(policy)  # Visualize the learned policy!

KeyboardInterrupt: 