In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
GAMMA = 0.99
LEARNING_RATE = 0.002  # Increased learning rate for faster convergence
NUM_WORKERS = 1         # Reduced parallel processes
HIDDEN_DIM = 32         # Smaller network
MAX_EPISODES = 50       # More episodes with simpler computation
MAX_STEPS = 200         # Increased steps per episode
GRAD_CLIP = 1.0         # More aggressive gradient clipping
# =============================================================

device = torch.device("cpu")

class LightweightActorCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super().__init__()
        # Simplified network architecture
        self.actor = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, action_dim)
        )
        self.critic = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, 1)
        )

    def forward(self, x):
        policy = torch.softmax(self.actor(x), dim=-1)
        value = self.critic(x).squeeze(-1)
        return policy, value

def optimized_worker(env):
    state, _ = env.reset()
    while True:
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(device)
            policy, _ = model(state_tensor.unsqueeze(0))
            action = Categorical(policy).sample().item()
        
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        
        yield (state, action, reward, next_state, done)
        
        if done:
            state, _ = env.reset()
        else:
            state = next_state

def efficient_training(global_model):
    env = gym.make("CartPole-v1")
    optimizer = optim.RMSprop(global_model.parameters(), lr=LEARNING_RATE)
    
    episode_rewards = []
    state_batch = []
    action_batch = []
    reward_batch = []

    worker_gen = optimized_worker(env)
    
    for episode in range(MAX_EPISODES):
        total_reward = 0
        for step in range(MAX_STEPS):
            state, action, reward, next_state, done = next(worker_gen)
            
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            total_reward += reward

            if done or (step == MAX_STEPS - 1):
                # Batch processing for efficiency
                states = torch.FloatTensor(state_batch).to(device)
                actions = torch.LongTensor(action_batch).to(device)
                rewards = torch.FloatTensor(reward_batch).to(device)

                # Calculate returns
                returns = torch.zeros_like(rewards)
                running_return = 0
                for i in reversed(range(len(rewards))):
                    running_return = rewards[i] + GAMMA * running_return
                    returns[i] = running_return

                # Normalize returns
                returns = (returns - returns.mean()) / (returns.std() + 1e-8)

                # Forward pass
                policies, values = global_model(states)
                dist = Categorical(policies)
                log_probs = dist.log_prob(actions)
                
                # Calculate losses
                advantage = returns - values.detach()
                actor_loss = -(log_probs * advantage).mean()
                critic_loss = advantage.pow(2).mean()
                loss = actor_loss + 0.5 * critic_loss

                # Optimize
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(global_model.parameters(), GRAD_CLIP)
                optimizer.step()

                # Reset batches
                state_batch.clear()
                action_batch.clear()
                reward_batch.clear()

                episode_rewards.append(total_reward)
                print(f"Ep {episode+1} | Avg Reward: {sum(episode_rewards[-10:])/10:.1f} | Last: {total_reward}")
                break

def quick_evaluate(model, episodes=3):
    env = gym.make("CartPole-v1")
    total_reward = 0
    for _ in range(episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0
        while not done:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).to(device)
                policy, _ = model(state_tensor.unsqueeze(0))
                action = policy.argmax().item()
            state, reward, terminated, truncated, _ = env.step(action)
            episode_reward += reward
            done = terminated or truncated
        total_reward += episode_reward
    print(f"Evaluation Average: {total_reward/episodes:.1f}")
    env.close()

if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    input_dim, action_dim = env.observation_space.shape[0], env.action_space.n
    env.close()

    model = LightweightActorCritic(input_dim, action_dim).to(device)
    
    print("Starting optimized training...")
    efficient_training(model)
    
    print("\nFinal evaluation:")
    quick_evaluate(model)

Starting optimized training...
Ep 1 | Avg Reward: 2.6 | Last: 26.0
Ep 2 | Avg Reward: 4.0 | Last: 14.0
Ep 3 | Avg Reward: 5.6 | Last: 16.0
Ep 4 | Avg Reward: 8.0 | Last: 24.0
Ep 5 | Avg Reward: 10.2 | Last: 22.0
Ep 6 | Avg Reward: 12.2 | Last: 20.0
Ep 7 | Avg Reward: 13.5 | Last: 13.0
Ep 8 | Avg Reward: 15.7 | Last: 22.0
Ep 9 | Avg Reward: 17.1 | Last: 14.0
Ep 10 | Avg Reward: 18.7 | Last: 16.0
Ep 11 | Avg Reward: 17.9 | Last: 18.0
Ep 12 | Avg Reward: 17.6 | Last: 11.0
Ep 13 | Avg Reward: 17.2 | Last: 12.0
Ep 14 | Avg Reward: 16.0 | Last: 12.0
Ep 15 | Avg Reward: 15.4 | Last: 16.0
Ep 16 | Avg Reward: 15.4 | Last: 20.0
Ep 17 | Avg Reward: 15.9 | Last: 18.0
Ep 18 | Avg Reward: 15.3 | Last: 16.0
Ep 19 | Avg Reward: 15.1 | Last: 12.0
Ep 20 | Avg Reward: 14.7 | Last: 12.0
Ep 21 | Avg Reward: 14.0 | Last: 11.0
Ep 22 | Avg Reward: 14.1 | Last: 12.0
Ep 23 | Avg Reward: 13.8 | Last: 9.0
Ep 24 | Avg Reward: 13.8 | Last: 12.0
Ep 25 | Avg Reward: 14.0 | Last: 18.0
Ep 26 | Avg Reward: 13.4 | Last: 