In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
import matplotlib.pyplot as plt
import time

# DQN Model
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    Architecture: Linear(state_dim, 64) -> ReLU -> Linear(64, 64) -> ReLU
    -> Linear(64, action_dim). The output layer has no activation, so the
    values are unbounded action-value estimates.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 64
        # Layers are created in the order fc1, fc2, fc3 so that seeded
        # weight initialisation is reproducible.
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_dim)

    def forward(self, x):
        """Map a batch of states ``x`` to per-action Q-values."""
        first_hidden = F.relu(self.fc1(x))
        second_hidden = F.relu(self.fc2(first_hidden))
        q_values = self.fc3(second_hidden)
        return q_values

# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO memory of transitions used for experience replay.

    Backed by a ``deque`` with ``maxlen=capacity``: once full, every new
    transition silently evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Raises ``ValueError`` (from ``random.sample``) if fewer than
        ``batch_size`` transitions are stored.
        """
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Return how many transitions are currently held."""
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    """DQN agent with an experience-replay buffer and a separate target network.

    Hyper-parameters are class attributes, so a subclass can override them.
    The agent does not run the training loop itself. The caller is expected to:
    push transitions into ``self.memory``, call :meth:`step` after each
    environment step, call :meth:`update_target_network` on its own schedule,
    and call :meth:`decay_epsilon` after setting ``current_episode``.
    """

    BATCH_SIZE = 64
    EPSILON_START = 1.0
    EPSILON_END = 0.01
    # Time constant (in episodes) of the exponential schedule in decay_epsilon.
    # After EPSILON_DECAY episodes epsilon has covered only ~63% of the way
    # from EPSILON_START to EPSILON_END (~0.37 with the values above).
    EPSILON_DECAY = 1000
    GAMMA = 0.99
    LR = 1e-3
    MEMORY_SIZE = 10000

    def __init__(self, env):
        """Build policy/target networks, optimizer and replay memory for ``env``.

        ``env`` must expose ``observation_space.shape`` (a 1-D state) and a
        discrete ``action_space.n``.
        """
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        # Start the target network as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.LR)
        self.memory = ReplayBuffer(self.MEMORY_SIZE)
        # Number of gradient updates performed (not environment steps).
        self.steps_done = 0
        self.current_episode = 0
        self.epsilon = self.EPSILON_START

    def select_e_greedy_action(self, env, state):
        """Return an epsilon-greedy action for ``state``.

        With probability ``self.epsilon`` this returns
        ``env.action_space.sample()``. Otherwise it returns the index of the
        largest Q-value that ``policy_net`` predicts for ``state``.
        """
        if random.random() < self.epsilon:
            return env.action_space.sample()
        else:
            state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            with torch.no_grad():
                return self.policy_net(state).argmax(dim=1).item()

    def decay_epsilon(self):
        """Recompute ``epsilon`` from ``current_episode``.

        epsilon = END + (START - END) * exp(-current_episode / EPSILON_DECAY)
        """
        self.epsilon = self.EPSILON_END + (self.EPSILON_START - self.EPSILON_END) * \
            np.exp(-1. * self.current_episode / self.EPSILON_DECAY)

    def step(self):
        """Perform one optimisation step on a random minibatch from memory.

        This does nothing until memory holds at least ``BATCH_SIZE``
        transitions. Each performed update increments ``steps_done``.
        The loss is the Huber (smooth L1) loss between Q(s, a) and the target
        r + GAMMA * max_a' Q_target(s', a') * (1 - done).
        """
        if len(self.memory) < self.BATCH_SIZE:
            return
        batch = self.memory.sample(self.BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions = torch.tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(1)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)

        # Q-values of the actions that were actually taken.
        current_q = self.policy_net(states).gather(1, actions)
        # The bootstrap target must be a constant for the optimiser. Computing
        # it under no_grad avoids building an autograd graph through target_net.
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1, keepdim=True)[0]
            # A stored done == 1 cancels the future-value term.
            target_q = rewards + (self.GAMMA * next_q * (1 - dones))

        loss = F.smooth_l1_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.steps_done += 1

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

# Training Function
def _shape_mountain_car_reward(reward, next_state):
    """Return ``reward`` plus the MountainCar-specific shaping bonus.

    ``next_state`` is MountainCar's ``[position, velocity]`` observation.
    Adds ``10 * |velocity|`` to reward building momentum, and +100 once the
    position reaches 0.5 (the goal).
    """
    shaped = reward + 10 * abs(next_state[1])
    if next_state[0] >= 0.5:
        shaped += 100
    return shaped


def _save_checkpoint(agent, episode, rewards, env_returns, epsilons, times,
                     path="dqn_mountain_car_checkpoint.pth"):
    """Save network, optimizer and training-history state to ``path``.

    The replay buffer is not included.
    """
    checkpoint = {
        'policy_net_state_dict': agent.policy_net.state_dict(),
        'target_net_state_dict': agent.target_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'steps_done': agent.steps_done,
        'epsilon': agent.epsilon,
        'episode': episode,
        'rewards': rewards,
        'env_rewards': env_returns,
        'epsilons': epsilons,
        'times': times,
    }
    torch.save(checkpoint, path)


def _plot_training_curves(rewards, epsilons, path="training_results.png"):
    """Plot per-episode reward and epsilon side by side, save to ``path``, then show."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(rewards)
    plt.title("Rewards per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.subplot(1, 2, 2)
    plt.plot(epsilons)
    plt.title("Epsilon Decay")
    plt.xlabel("Episode")
    plt.ylabel("Epsilon")
    plt.tight_layout()
    plt.savefig(path)
    plt.show()


def train_dqn(env, env_name, target_update_interval=1000, training_episodes=1000, agent_class=DQNAgent,
              solved_threshold=-110, shape_rewards=True):
    """Train a DQN agent on ``env``, checkpointing and plotting along the way.

    Args:
        env: Gymnasium environment with a discrete action space. When
            ``shape_rewards`` is true, its observations must be MountainCar's
            ``[position, velocity]``.
        env_name: Not used by the body; kept for API compatibility.
        target_update_interval: Sync the target network every this many
            gradient updates (``agent.steps_done``).
        training_episodes: Number of episodes to run.
        agent_class: Called as ``agent_class(env)`` to build the agent.
        solved_threshold: The run counts as solved once the mean *unshaped*
            environment return over the last 100 episodes reaches this value.
        shape_rewards: Train on MountainCar-shaped rewards instead of the raw
            environment reward.

    Returns:
        ``(agent, rewards_per_episode, epsilon_values, episode_times)``.
        ``rewards_per_episode`` holds the per-episode totals of the (possibly
        shaped) reward the agent learned from.
    """
    agent = agent_class(env)
    rewards_per_episode = []
    env_returns = []
    epsilon_values = []
    episode_times = []
    solved = False

    for episode in range(training_episodes):
        start_time = time.time()
        state, _ = env.reset()
        total_reward = 0  # reward the agent learns from (possibly shaped)
        env_return = 0    # raw environment reward, used for the solved check
        terminated = False
        truncated = False
        steps = 0

        while not (terminated or truncated):
            action = agent.select_e_greedy_action(env, state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            env_return += reward
            if shape_rewards:
                reward = _shape_mountain_car_reward(reward, next_state)
            # Store only true termination as "done". A time-limit truncation is
            # not a terminal state, so the target must still bootstrap from
            # next_state.
            agent.memory.push(state, action, reward, next_state, terminated)
            state = next_state
            total_reward += reward
            steps += 1
            agent.step()

            # steps_done stays 0 until the buffer holds a full batch; without
            # the > 0 guard every warm-up step would trigger a redundant sync.
            if agent.steps_done > 0 and agent.steps_done % target_update_interval == 0:
                agent.update_target_network()

        agent.current_episode = episode
        agent.decay_epsilon()
        rewards_per_episode.append(total_reward)
        env_returns.append(env_return)
        epsilon_values.append(agent.epsilon)
        episode_time = time.time() - start_time
        episode_times.append(episode_time)

        # Logging
        if episode % 10 == 0:
            print(f"Episode {episode}: Reward: {total_reward:.2f}, Env Return: {env_return:.2f}, "
                  f"Epsilon: {agent.epsilon:.3f}, Steps: {steps}, Time: {episode_time:.2f}s")

        # Solved check on the unshaped return: shaping bonuses would otherwise
        # inflate the average past the threshold prematurely.
        if not solved and len(env_returns) >= 100:
            avg_return = np.mean(env_returns[-100:])
            if avg_return >= solved_threshold:
                print(f"Environment solved at episode {episode}! Average reward: {avg_return:.2f}")
                solved = True
                torch.save(agent.policy_net.state_dict(), "dqn_mountain_car_solved.pth")

        if episode % 100 == 0:
            _save_checkpoint(agent, episode, rewards_per_episode, env_returns,
                             epsilon_values, episode_times)

    _plot_training_curves(rewards_per_episode, epsilon_values)

    return agent, rewards_per_episode, epsilon_values, episode_times

# Run Training
if __name__ == "__main__":
    env_name = "MountainCar-v0"
    # Override the registered episode step limit: episodes now truncate after 1000 steps.
    env = gym.make(env_name, max_episode_steps=1000)
    try:
        trained_agent, rewards, epsilons, times = train_dqn(env, env_name, training_episodes=1000)
        torch.save(trained_agent.policy_net.state_dict(), "dqn_mountain_car_final.pth")
    finally:
        # Release the environment even if training fails or is interrupted.
        env.close()

Episode 0: Reward: -881.39, Epsilon: 1.000, Steps: 1000, Time: 2.35s
Episode 10: Reward: -897.45, Epsilon: 0.990, Steps: 1000, Time: 2.05s
Episode 20: Reward: -932.55, Epsilon: 0.980, Steps: 1000, Time: 1.95s
Episode 30: Reward: -875.83, Epsilon: 0.971, Steps: 1000, Time: 2.00s
Episode 40: Reward: -843.20, Epsilon: 0.961, Steps: 1000, Time: 1.95s
Episode 50: Reward: -794.37, Epsilon: 0.952, Steps: 1000, Time: 2.03s
Episode 60: Reward: -822.31, Epsilon: 0.942, Steps: 1000, Time: 1.95s
Episode 70: Reward: -894.49, Epsilon: 0.933, Steps: 1000, Time: 1.92s
Episode 80: Reward: -818.43, Epsilon: 0.924, Steps: 1000, Time: 1.95s
Episode 90: Reward: -815.76, Epsilon: 0.915, Steps: 1000, Time: 2.37s
Episode 100: Reward: -570.40, Epsilon: 0.906, Steps: 879, Time: 1.76s
Episode 110: Reward: -525.18, Epsilon: 0.897, Steps: 787, Time: 1.92s
Episode 120: Reward: -769.78, Epsilon: 0.888, Steps: 1000, Time: 2.44s
Episode 130: Reward: -758.54, Epsilon: 0.879, Steps: 1000, Time: 1.95s
Episode 140: Reward