In [6]:
import os
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

# Define the action transformation function
def transform_action_values(n_actions, min_action=-2.0, max_action=2.0):
    """Build a discrete action table spanning [min_action, max_action].

    ``n_actions`` evenly spaced points on [-1, 1] are passed through a signed
    square, so the grid is denser around the midpoint of the range and coarser
    near the extremes. The result is then mapped affinely onto
    [min_action, max_action].

    Returns:
        numpy.ndarray of length ``n_actions``; the first entry is ``min_action``
        and the last is ``max_action`` (for ``n_actions >= 2``).
    """
    grid = np.linspace(-1, 1, n_actions)
    # |x| * x == sign(x) * x**2: keeps the sign, squares the magnitude.
    curved = np.abs(grid) * grid
    return min_action + (curved + 1) * (max_action - min_action) / 2

# Define the replay buffer class
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Each transition is a ``(state, action, reward, next_state, done)`` tuple.
    Once ``buffer_limit`` transitions are stored, appending a new one evicts
    the oldest (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, buffer_limit):
        self.buffer = deque(maxlen=buffer_limit)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done)`` transition."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct stored transitions uniformly and collate them.

        Returns:
            ``(s, a, r, s_prime, done_mask)`` as float32 tensors. ``a``, ``r``
            and ``done_mask`` carry a trailing singleton axis (``(n, 1)`` for
            scalar entries). ``done_mask`` is 0.0 for transitions stored with a
            truthy ``done`` and 1.0 otherwise, so it can multiply the bootstrap
            term of a TD target directly.

        Raises:
            ValueError: if ``n`` is negative or larger than the number of
                stored transitions (propagated from ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for s, a, r, s_prime, done in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([0.0 if done else 1.0])

        return (
            self._to_tensor(s_lst),
            self._to_tensor(a_lst),
            self._to_tensor(r_lst),
            self._to_tensor(s_prime_lst),
            self._to_tensor(done_mask_lst),
        )

    @staticmethod
    def _to_tensor(rows):
        """Collate a list of array-likes into one float32 tensor."""
        # torch.tensor() on a list of ndarrays converts element by element and
        # emits a "extremely slow" UserWarning; stacking into one contiguous
        # ndarray first is a single fast copy with identical values and shape.
        return torch.as_tensor(np.asarray(rows, dtype=np.float32))

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def __len__(self):
        return len(self.buffer)

# Define the Q-network class
class QNetwork(nn.Module):
    """MLP mapping a state vector to one Q-value per discrete action.

    Layout: ``state_dim -> 64 -> 32 -> action_dim``, with leaky-ReLU after each
    hidden layer and a linear output. The network owns an Adam optimizer over
    its own parameters, built with learning rate ``q_lr`` (also kept in
    ``self.lr``).
    """

    def __init__(self, state_dim, action_dim, q_lr):
        super().__init__()
        self.lr = q_lr
        # Layer creation order is fixed so parameter init and state_dict keys
        # (fc_1, fc_2, fc_out) stay stable.
        self.fc_1 = nn.Linear(state_dim, 64)
        self.fc_2 = nn.Linear(64, 32)
        self.fc_out = nn.Linear(32, action_dim)
        self.optimizer = optim.Adam(self.parameters(), lr=q_lr)

    def forward(self, x):
        h = x
        for hidden_layer in (self.fc_1, self.fc_2):
            h = F.leaky_relu(hidden_layer(h))
        # No activation on the head: Q-values are unbounded.
        return self.fc_out(h)

# Define the DQN agent class
class DQNAgent:
    """Epsilon-greedy DQN agent over a discretized continuous action range.

    The range [min_action, max_action] is discretized into ``action_dim``
    values by ``transform_action_values``; the Q-network predicts one value per
    discrete action. A target network follows the online network through soft
    (Polyak) updates with rate ``tau`` after every gradient step.

    Every constructor argument is optional; the defaults reproduce the original
    hard-coded configuration (3-D state, 11 actions in [-2, 2]).
    """

    def __init__(self, state_dim=3, action_dim=11, lr=0.01, gamma=0.995,
                 tau=0.01, epsilon=0.9, epsilon_decay=0.9, epsilon_min=0.001,
                 buffer_size=1000000, batch_size=256,
                 min_action=-2.0, max_action=2.0):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.tau = tau
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.memory = ReplayBuffer(self.buffer_size)
        self.action_list = transform_action_values(
            self.action_dim, min_action=min_action, max_action=max_action)

        self.Q = QNetwork(self.state_dim, self.action_dim, self.lr)
        self.Q_target = QNetwork(self.state_dim, self.action_dim, self.lr)
        self.Q_target.load_state_dict(self.Q.state_dict())

    def select_action(self, state):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: observation tensor, fed directly to ``self.Q``.

        Returns:
            ``(action, real_action, maxQ_action_count)``: the action index as a
            float (the form stored in the replay buffer), the matching entry of
            ``self.action_list`` to send to the environment, and 1 if the greedy
            action was taken or 0 if it was sampled at random.
        """
        # Exploit with probability (1 - epsilon).
        if self.epsilon < np.random.rand():
            with torch.no_grad():
                action_idx = int(torch.argmax(self.Q(state)).item())
            maxQ_action_count = 1
        else:
            action_idx = int(np.random.randint(self.action_dim))
            maxQ_action_count = 0
        return float(action_idx), self.action_list[action_idx], maxQ_action_count

    def calc_target(self, mini_batch):
        """Return the one-step TD target ``r + gamma * mask * max_a' Q_target(s', a')``.

        ``mini_batch`` is the ``(s, a, r, s_prime, done_mask)`` tuple returned by
        ``self.memory.sample``. Its last element is a *not-done* mask (0.0 for
        terminal transitions), which stops bootstrapping past a terminal state.
        """
        _, _, r, s_prime, done_mask = mini_batch
        with torch.no_grad():
            q_next = self.Q_target(s_prime).max(dim=1, keepdim=True).values
        return r + self.gamma * done_mask * q_next

    def train_agent(self):
        """Run one gradient step on a replay minibatch, then soft-update the target.

        Returns:
            The scalar Huber (smooth-L1) loss as a float, or ``None`` when the
            buffer holds fewer than ``batch_size`` transitions; in that case no
            update is performed.
        """
        if self.memory.size() < self.batch_size:
            return None

        mini_batch = self.memory.sample(self.batch_size)
        s_batch, a_batch = mini_batch[0], mini_batch[1]
        td_target = self.calc_target(mini_batch)

        # Actions are stored as floats; gather() needs int64 indices.
        q_a = self.Q(s_batch).gather(1, a_batch.long())
        # smooth_l1_loss already reduces to the batch mean (reduction='mean').
        q_loss = F.smooth_l1_loss(q_a, td_target)
        self.Q.optimizer.zero_grad()
        q_loss.backward()
        self.Q.optimizer.step()

        self.soft_update()
        return q_loss.item()

    def soft_update(self, tau=None):
        """Blend online weights into the target: ``target <- (1 - tau) * target + tau * online``.

        Args:
            tau: blending rate; defaults to ``self.tau``.
        """
        tau = self.tau if tau is None else tau
        # In-place ops under no_grad instead of mutating `.data`.
        with torch.no_grad():
            for param_target, param in zip(self.Q_target.parameters(), self.Q.parameters()):
                param_target.mul_(1.0 - tau).add_(param, alpha=tau)
            
if __name__ == '__main__':
    agent = DQNAgent()
    env = gym.make('Pendulum-v1')

    EPISODES = 100
    # Transitions collected before the first gradient step.
    WARMUP_TRANSITIONS = 4 * agent.batch_size
    score_list = []

    try:
        # Warm-up: fill the replay buffer (no training) so early minibatches
        # are drawn from a reasonably diverse pool.
        while agent.memory.size() < WARMUP_TRANSITIONS:
            state, _ = env.reset()
            done = False

            while not done:
                action, real_action, _ = agent.select_action(torch.FloatTensor(state))
                state_prime, reward, terminated, truncated, _ = env.step([real_action])
                # Store `terminated`, not `truncated`: time-limit cut-offs
                # should still bootstrap from the next state.
                agent.memory.put((state, action, reward, state_prime, terminated))
                done = terminated or truncated
                state = state_prime

        for episode in range(EPISODES):
            state, _ = env.reset()
            score, done = 0.0, False
            maxQ_action_count = 0

            while not done:
                action, real_action, count = agent.select_action(torch.FloatTensor(state))
                state_prime, reward, terminated, truncated, _ = env.step([real_action])
                agent.memory.put((state, action, reward, state_prime, terminated))
                agent.train_agent()

                done = terminated or truncated
                score += reward
                maxQ_action_count += count
                state = state_prime

            # `score` is the undiscounted return of this episode (not an average).
            print("Episode:{}, Score:{:.1f}, MaxQ_Action_Count:{}, Epsilon:{:.5f}".format(
                episode, score, maxQ_action_count, agent.epsilon))
            score_list.append(score)
            agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
    finally:
        # Release the environment even if training is interrupted.
        env.close()

    plt.plot(score_list)
    plt.xlabel('Episode')
    plt.ylabel('Episode return')
    plt.show()


Episode:0, Avg_Score:-1366.1, MaxQ_Action_Count:17, Epsilon:0.90000
Episode:1, Avg_Score:-1812.2, MaxQ_Action_Count:40, Epsilon:0.81000
Episode:2, Avg_Score:-1057.3, MaxQ_Action_Count:55, Epsilon:0.72900
Episode:3, Avg_Score:-1198.3, MaxQ_Action_Count:70, Epsilon:0.65610
Episode:4, Avg_Score:-1350.8, MaxQ_Action_Count:68, Epsilon:0.59049
Episode:5, Avg_Score:-983.5, MaxQ_Action_Count:98, Epsilon:0.53144
Episode:6, Avg_Score:-232.4, MaxQ_Action_Count:98, Epsilon:0.47830
Episode:7, Avg_Score:-487.8, MaxQ_Action_Count:112, Epsilon:0.43047
Episode:8, Avg_Score:-380.0, MaxQ_Action_Count:109, Epsilon:0.38742
Episode:9, Avg_Score:-376.7, MaxQ_Action_Count:139, Epsilon:0.34868
Episode:10, Avg_Score:-380.6, MaxQ_Action_Count:132, Epsilon:0.31381
Episode:11, Avg_Score:-749.9, MaxQ_Action_Count:137, Epsilon:0.28243
Episode:12, Avg_Score:-373.6, MaxQ_Action_Count:154, Epsilon:0.25419
Episode:13, Avg_Score:-395.5, MaxQ_Action_Count:147, Epsilon:0.22877
Episode:14, Avg_Score:-127.0, MaxQ_Action_Coun

KeyboardInterrupt: 