In [17]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import torch.nn.functional as F
import time

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available and set device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using {device} device")

# Wrapper to limit the number of steps per episode (Gymnasium API)
class LimitEpisodeSteps(gym.Wrapper):
    def __init__(self, env, max_steps):
        super(LimitEpisodeSteps, self).__init__(env)
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        done = terminated or truncated
        if self.current_step >= self.max_steps:
            done = True  # Force done when max steps are reached
            terminated = False  # Indicate truncation
            truncated = True
        return obs, reward, terminated, truncated, info

# Heuristic for determining the correct action based on pole angle
def heuristic_action(state):
    pole_angle = state[2]
    return 1 if pole_angle > 0 else 0

# Function to evaluate the model's accuracy and mean reward
def evaluate_model(agent, eval_env, num_episodes=5):
    total_rewards = []
    correct_actions = 0
    total_actions = 0
    agent.policy_network.eval()
    with torch.no_grad():
        for episode in range(num_episodes):
            obs, _ = eval_env.reset()
            terminated, truncated = False, False
            episode_reward = 0
            while not (terminated or truncated):
                obs_tensor = torch.tensor(obs, dtype=torch.float32).to(device)
                action_probs = agent.policy_network(obs_tensor)
                action = torch.argmax(action_probs).item()
                heuristic = heuristic_action(obs)
                if action == heuristic:
                    correct_actions += 1
                total_actions += 1
                obs, reward, terminated, truncated, _ = eval_env.step(action)
                episode_reward += reward
            total_rewards.append(episode_reward)
    mean_reward = np.mean(total_rewards)
    accuracy = (correct_actions / total_actions) * 100 if total_actions > 0 else 0.0
    agent.policy_network.train()
    return mean_reward, accuracy

# Function to create and wrap the environment with sequence length
def create_env(env_id, sequence_length):
    env = gym.make(env_id)
    env = LimitEpisodeSteps(env, max_steps=sequence_length)
    return env

# Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.action_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action_logits = self.action_head(x)
        action_probs = F.softmax(action_logits, dim=-1)
        return action_probs

# Value Network
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim=64):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        state_value = self.value_head(x)
        return state_value

# PPO Agent
class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, K_epochs=4, eps_clip=0.2):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy_network = PolicyNetwork(state_dim, action_dim).to(device)
        self.value_network = ValueNetwork(state_dim).to(device)
        self.optimizer = optim.Adam([
            {'params': self.policy_network.parameters(), 'lr': lr},
            {'params': self.value_network.parameters(), 'lr': lr}
        ])

        self.policy_network.train()
        self.value_network.train()

        # Memory
        self.memory = {
            'states': [],
            'actions': [],
            'log_probs': [],
            'rewards': [],
            'is_terminals': []
        }

    def select_action(self, state):
        state = torch.tensor(state, dtype=torch.float32).to(device)
        action_probs = self.policy_network(state)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action).item()

    def store_transition(self, state, action, log_prob, reward, is_terminal):
        self.memory['states'].append(state)
        self.memory['actions'].append(action)
        self.memory['log_probs'].append(log_prob)
        self.memory['rewards'].append(reward)
        self.memory['is_terminals'].append(is_terminal)

    def clear_memory(self):
        for key in self.memory:
            self.memory[key] = []

    def compute_returns_and_advantages(self, next_value, dones):
        rewards = self.memory['rewards']
        returns = []
        Gt = next_value
        for reward, done in zip(reversed(rewards), reversed(dones)):
            Gt = reward + self.gamma * Gt * (1 - done)
            returns.insert(0, Gt)
        # Normalize returns
        returns = np.array(returns)
        if returns.std() != 0:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        else:
            returns = returns - returns.mean()
        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        return returns

    def update(self):
        # Convert lists to tensors using numpy array for efficiency
        if len(self.memory['states']) == 0:
            print("No transitions to update.")
            return

        states = torch.from_numpy(np.array(self.memory['states'], dtype=np.float32)).to(device)
        actions = torch.tensor(self.memory['actions'], dtype=torch.long).to(device)
        old_log_probs = torch.tensor(self.memory['log_probs'], dtype=torch.float32).to(device)
        rewards = self.memory['rewards']
        is_terminals = self.memory['is_terminals']

        # Compute returns
        with torch.no_grad():
            if len(self.memory['states']) > 0:
                last_state = torch.tensor(self.memory['states'][-1], dtype=torch.float32).to(device)
                next_value = self.value_network(last_state).item()
            else:
                next_value = 0
            returns = self.compute_returns_and_advantages(next_value, is_terminals)

        # Compute advantages
        values = self.value_network(states).squeeze(-1)  # [batch_size]
        advantages = returns - values  # [batch_size]

        # Ensure no NaNs in advantages
        if torch.isnan(advantages).any():
            print("NaNs detected in advantages. Skipping update.")
            self.clear_memory()
            return

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):
            # Get action probabilities
            action_probs = self.policy_network(states)  # [batch_size, action_dim]
            dist = torch.distributions.Categorical(action_probs)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(actions)  # [batch_size]

            # Ratio for clipping
            ratios = torch.exp(new_log_probs - old_log_probs)  # [batch_size]

            # Surrogate loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, returns)
            loss = policy_loss + value_loss - 0.01 * entropy

            # Check for NaNs in loss
            if torch.isnan(loss):
                print("NaNs detected in loss. Skipping update.")
                self.clear_memory()
                return

            # Take gradient step with gradient clipping
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_network.parameters(), max_norm=0.5)
            self.optimizer.step()

        # Clear memory
        self.clear_memory()

# CartPole environment setup
env_id = "CartPole-v1"

# Training parameters
initial_sequence_length = 1  # Start with 1 step
max_sequence_length = 500  # Maximum sequence length to avoid infinite loop
timesteps_per_sequence = 1000  # Timesteps to train per sequence length
num_eval_episodes = 5  # Number of episodes to evaluate after each training
accuracy_threshold = 95.0  # Percentage accuracy required to increase sequence length

# Initialize environment
env = create_env(env_id, initial_sequence_length)

# Get state and action dimensions
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Initialize PPO Agent
agent = PPOAgent(state_dim, action_dim)

# Training loop with strict continual learning
sequence_length = initial_sequence_length

start_time = time.time()

while sequence_length <= max_sequence_length:
    print(f"\nTraining on sequence length: {sequence_length}")
    timesteps = 0
    episodes = 0

    while timesteps < timesteps_per_sequence:
        obs, _ = env.reset()
        terminated, truncated = False, False
        episode_reward = 0
        episodes += 1

        while not (terminated or truncated):
            action, log_prob = agent.select_action(obs)
            heuristic = heuristic_action(obs)
            reward = 1.0 if action == heuristic else 0.0  # Modify reward based on action correctness
            obs, reward_env, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_transition(obs, action, log_prob, reward, done)
            episode_reward += reward
            timesteps += 1

            if done:
                break

        # Update PPO agent after each episode
        agent.update()

        # Optional: Print progress every N episodes
        if episodes % 100 == 0:
            elapsed = time.time() - start_time
            print(f"  Episodes: {episodes}, Timesteps: {timesteps}, Elapsed Time: {elapsed:.2f}s")

    # Create evaluation environment
    eval_env = create_env(env_id, sequence_length)

    # Evaluate the model's performance
    mean_reward, accuracy = evaluate_model(agent, eval_env, num_eval_episodes)

    # Print the required metrics on a single line
    print(f"Seq Length: {sequence_length}, Mean Reward: {mean_reward:.2f}, Accuracy: {accuracy:.2f}%")

    # Check if accuracy threshold is met to increase sequence length
    if accuracy >= accuracy_threshold:
        sequence_length += 1
        print(f"Reward threshold achieved! Increasing sequence length to {sequence_length}")

        if sequence_length > max_sequence_length:
            break

        # Update the training environment with increased sequence length
        env = create_env(env_id, sequence_length)
    else:
        print("Accuracy threshold not met, continuing training on the same sequence length.")

    # Clean up evaluation environment
    eval_env.close()

# Save the final model parameters
torch.save({
    'policy_network_state_dict': agent.policy_network.state_dict(),
    'value_network_state_dict': agent.value_network.state_dict(),
    'optimizer_state_dict': agent.optimizer.state_dict(),
}, "ppo_cartpole_cl.pth")

# Load the model
checkpoint = torch.load("ppo_cartpole_cl.pth", map_location=device)
agent.policy_network.load_state_dict(checkpoint['policy_network_state_dict'])
agent.value_network.load_state_dict(checkpoint['value_network_state_dict'])
agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
agent.policy_network.eval()
agent.value_network.eval()

# Test the model with printouts
test_env = create_env(env_id, sequence_length)
obs, _ = test_env.reset()
test_episodes = 0

print("\nStarting Testing Phase...\n")

for i in range(1000):
    obs_tensor = torch.tensor(obs, dtype=torch.float32).to(device)
    with torch.no_grad():
        action_probs = agent.policy_network(obs_tensor)
    action = torch.argmax(action_probs).item()
    heuristic = heuristic_action(obs)
    accuracy_step = "Correct" if action == heuristic else "Incorrect"
    print(f"Step: {i}, Seq Length: {sequence_length}, Model Action: {action}, Heuristic Action: {heuristic}, Accuracy: {accuracy_step}")
    obs, _, terminated, truncated, _ = test_env.step(action)
    test_episodes += 1
    if terminated or truncated:
        obs, _ = test_env.reset()

test_env.close()


Using cuda device

Training on sequence length: 1


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.