In [31]:
import gymnasium as gym
import numpy as np

## Expert Data

Use pre-trained DQN as expert policy https://github.com/Wangyuxuan-xuan/DQN-Flappybird/blob/main/runs/cartpole2.pt
Use expert policy to generate expert data

Set a limit with max_steps_per_trajectory to control the length of trajectory

In [32]:
import torch
import torch.nn as nn

class ExpertDQN(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, hidden_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Load pre-trained expert
expert = ExpertDQN(hidden_dim=10)
expert.load_state_dict(torch.load("Expert_data\cartpole2.pt"))
expert.eval()  # Set to evaluation mode

ExpertDQN(
  (fc1): Linear(in_features=4, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=2, bias=True)
)

In [33]:
def generate_expert_data(num_trajectories, max_steps_per_trajectory=200):
    env = gym.make("CartPole-v1")
    expert_states = []
    expert_actions = []
    
    for _ in range(num_trajectories):
        state, _ = env.reset()
        done = False
        step = 0
        while not done and step < max_steps_per_trajectory:
            with torch.no_grad():
                q_values = expert(torch.FloatTensor(state))
                action = torch.argmax(q_values).item()
            
            expert_states.append(state)
            expert_actions.append(action)
            state, _, done, _, _ = env.step(action)

            step += 1

    env.close()
    return np.array(expert_states), np.array(expert_actions)

# Generate expert data using your trained DQN
expert_states, expert_actions = generate_expert_data(num_trajectories=200, max_steps_per_trajectory=100)

In [34]:
expert_states, expert_actions

(array([[-0.04611707,  0.01877279,  0.00495186, -0.0146288 ],
        [-0.04574162, -0.17641982,  0.00465928,  0.27961236],
        [-0.04927001,  0.01863535,  0.01025153, -0.01159739],
        ...,
        [ 0.07934535, -0.16616303, -0.00435626,  0.2865145 ],
        [ 0.07602208,  0.02902077,  0.00137403, -0.00753919],
        [ 0.0766025 ,  0.22412299,  0.00122324, -0.29978827]],
       dtype=float32),
 array([0, 1, 1, ..., 1, 1, 0]))

In [35]:
def test_expert():
    env = gym.make("CartPole-v1", render_mode="human")
    state, _ = env.reset()
    done = False
    total_reward = 0
    
    while not done:
        q_values = expert(torch.FloatTensor(state))
        action = torch.argmax(q_values).item()
        state, reward, done, _, _ = env.step(action)
        total_reward += reward
    
    print(f"Expert reward: {total_reward}")
    env.close()

# test_expert()

## Define Policy and Discriminator Networks
The discriminator networks Concatenate state and action, This allows the network to learn a joint representation of the state and action

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim

class Policy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, state):
        return self.net(state)

class Discriminator(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    # Concatenate state and action, This allows the network to learn a joint representation of the state and action
    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.net(sa)

# Initialize networks
state_dim = 4  # CartPole state dimension
action_dim = 2 # CartPole action space (0 or 1)
policy = Policy(state_dim, action_dim)
discriminator = Discriminator(state_dim, 1)  # Action is 0/1, so we encode as 0 or 1

# # Test networks
# state = torch.randn(1, state_dim)
# action = torch.randint(0, 2, (1, 1))
# print("Policy output:", policy(state).shape)
# print("Discriminator output:", discriminator(state, action).shape)

## Train GAIL

#### Loss in GAN
The discriminator's loss in GANs is usually something like maximizing log(D(real)) + log(1 - D(fake)), where D(real) is the probability that real data is real, and D(fake) is the probability that fake data is real

Entropy Term
$$
H(\pi(\cdot \mid s)) = - \sum_{a \in \mathcal{A}} \pi(a \mid s) \log (\pi(a \mid s))
$$


In [37]:
def train_gail(expert_states, expert_actions, num_epochs=100, batch_size=32, λ=0.1):
    # Initialize optimizers
    policy_optim = optim.Adam(policy.parameters(), lr=1e-3)
    disc_optim = optim.Adam(discriminator.parameters(), lr=1e-3)
    
    for epoch in range(num_epochs):
        # --- Step 1: Sample trajectories from current policy ---
        env = gym.make("CartPole-v1")
        policy_states, policy_actions = [], []
        state, _ = env.reset()
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state)
            action_probs = policy(state_tensor)  # Get action probabilities from policy
            action = torch.distributions.Categorical(action_probs).sample().item()  # Sample action
            policy_states.append(state)
            policy_actions.append(action)
            state, _, done, _, _ = env.step(action)  # Execute action in environment
        
        # --- Step 2: Update Discriminator ---
        # Prepare expert and policy data as state-action pairs
        expert_sa = torch.cat([
            torch.FloatTensor(expert_states),
            torch.FloatTensor(expert_actions).unsqueeze(1)  # Shape: (N, state_dim + 1)
        ], dim=1)
        policy_sa = torch.cat([
            torch.FloatTensor(policy_states),
            torch.FloatTensor(policy_actions).unsqueeze(1)  # Shape: (M, state_dim + 1)
        ], dim=1)
        
        # Discriminator loss components:
        # - Maximize log(D(expert_sa)): Expert labeled as "real"
        # - Maximize log(1 - D(policy_sa)): Policy labeled as "fake"
        real_output = discriminator(expert_sa[:, :4], expert_sa[:, 4:])  # D(expert_sa)
        fake_output = discriminator(policy_sa[:, :4], policy_sa[:, 4:])   # D(policy_sa)
        real_loss = -torch.log(real_output).mean()  # -E[log(D(expert))]
        fake_loss = -torch.log(1 - fake_output).mean()  # -E[log(1 - D(policy))]
        disc_loss = real_loss + fake_loss  # Total loss
        
        # Update discriminator
        disc_optim.zero_grad()
        disc_loss.backward()
        disc_optim.step()
        
        # --- Step 3: Update Policy using Discriminator as Reward ---
        # Convert policy data to tensors
        policy_states_tensor = torch.FloatTensor(policy_states)
        policy_actions_tensor = torch.FloatTensor(policy_actions).unsqueeze(1)
        
        # Compute rewards: log(D(s,a)) 
        # the discriminator is not exactly a reward function, 
        # but rather a way to estimate the likelihood of a state-action pair being from the expert's policy.
        # If the discriminator thinks it is more likely from expert then it will give a higher reward
        # Thus we can use this as a reward signal to reinforce the actions that can fool the discriminator better
        with torch.no_grad():
            rewards = torch.log(discriminator(policy_states_tensor, policy_actions_tensor))
        
        # Compute policy loss:
        # L = -E[log(π(a|s)) * reward] - λ * entropy(π)
        action_probs = policy(policy_states_tensor)
        entropy = -torch.sum(action_probs * torch.log(action_probs + 1e-10), dim=-1).mean()  # Avoid log(0)
        log_probs = torch.log(action_probs.gather(1, policy_actions_tensor.long()))  # log(π(a|s))
        # TODO mind the sign of entropy here
        # Here we flip the sign cuz we wanna minimizing
        # We wanna maximizing/reinforce the action with high reward, that equals to minimizing the negative reward
        # The entropy is positive, thus minimizing the negative entropy is maximizing the entropy
        policy_loss = -(log_probs * rewards).mean() - λ * entropy
        
        # Update policy
        policy_optim.zero_grad()
        policy_loss.backward()
        policy_optim.step()
        
        print(f"Epoch {epoch}: Disc Loss = {disc_loss.item():.3f}, Policy Loss = {policy_loss.item():.3f}")

In [38]:
train_gail(expert_states, expert_actions, num_epochs=1000, batch_size=32, λ=0.1)

Epoch 0: Disc Loss = 1.422, Policy Loss = -0.508
Epoch 1: Disc Loss = 1.368, Policy Loss = -0.591
Epoch 2: Disc Loss = 1.415, Policy Loss = -0.485
Epoch 3: Disc Loss = 1.382, Policy Loss = -0.507
Epoch 4: Disc Loss = 1.392, Policy Loss = -0.498
Epoch 5: Disc Loss = 1.394, Policy Loss = -0.555
Epoch 6: Disc Loss = 1.391, Policy Loss = -0.512
Epoch 7: Disc Loss = 1.363, Policy Loss = -0.453
Epoch 8: Disc Loss = 1.360, Policy Loss = -0.469
Epoch 9: Disc Loss = 1.362, Policy Loss = -0.475
Epoch 10: Disc Loss = 1.380, Policy Loss = -0.537
Epoch 11: Disc Loss = 1.370, Policy Loss = -0.532
Epoch 12: Disc Loss = 1.356, Policy Loss = -0.534
Epoch 13: Disc Loss = 1.357, Policy Loss = -0.548
Epoch 14: Disc Loss = 1.338, Policy Loss = -0.547
Epoch 15: Disc Loss = 1.358, Policy Loss = -0.545
Epoch 16: Disc Loss = 1.336, Policy Loss = -0.551
Epoch 17: Disc Loss = 1.344, Policy Loss = -0.532
Epoch 18: Disc Loss = 1.347, Policy Loss = -0.561
Epoch 19: Disc Loss = 1.314, Policy Loss = -0.546
Epoch 20: 

In [None]:
def test_policy(policy, num_episodes=10):
    env = gym.make("CartPole-v1", render_mode="human")
    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                action_probs = policy(torch.FloatTensor(state))
            action = torch.argmax(action_probs).item()  # Take most probable action
            state, _, done, _, _ = env.step(action)
    env.close()

test_policy(policy)