<a href="https://colab.research.google.com/github/worthlessFella/deep_learning/blob/main/deep_rl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Three-layer fully connected network producing one value per action.

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of outputs (one per action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Layer names and creation order are part of the checkpoint format
        # (state_dict keys) and of seeded initialisation; keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return the raw (un-activated) output of the last layer for ``state``."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

class DQNAgent:
    """Epsilon-greedy Q-learning agent with a bounded replay memory.

    The agent owns a single ``QNetwork`` that is used both for action
    selection and for computing bootstrap targets (no separate target
    network).

    Args:
        state_dim: Number of features in a state.
        action_dim: Number of discrete actions.
        gamma: Discount factor applied to the bootstrap value.
        epsilon: Initial probability of picking a random action.
        epsilon_min: Lower bound for ``epsilon``.
        epsilon_decay: Multiplicative decay applied to ``epsilon`` on every
            call to ``remember``.
        lr: Learning rate for the Adam optimizer.
        batch_size: Number of transitions sampled per ``replay`` call.
        memory_size: Maximum number of transitions kept in memory. Once
            full, the oldest transition is overwritten. Raised to
            ``batch_size`` if smaller so ``replay`` can always fill a batch.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        lr=0.001,
        batch_size=64,
        memory_size=10000,
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.lr = lr
        self.batch_size = batch_size
        # A capacity below one batch would make replay() a permanent no-op.
        self.memory_size = max(memory_size, batch_size)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_network = QNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

        # Ring buffer: a plain list keeps random.sample() O(batch_size);
        # _next_slot is the index overwritten once the list is full.
        self.memory = []
        self._next_slot = 0
        self.steps = 0

    def act(self, state):
        """Return an action index for ``state`` using epsilon-greedy selection.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the largest network output is chosen.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        state = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state)
        return torch.argmax(q_values, dim=1).item()

    def remember(self, state, action, reward, next_state, done):
        """Store one transition, count the step and decay ``epsilon``.

        When the memory already holds ``memory_size`` transitions the oldest
        one is overwritten instead of growing the list.
        """
        transition = (state, action, reward, next_state, done)
        if len(self.memory) < self.memory_size:
            self.memory.append(transition)
        else:
            self.memory[self._next_slot] = transition
            self._next_slot = (self._next_slot + 1) % self.memory_size
        self.steps += 1
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def _stack_states(self, states):
        """Convert a sequence of per-step states into one float32 batch tensor."""
        # Stacking per-state tensors avoids handing torch.tensor() a sequence
        # of array-like objects, which PyTorch warns is very slow for NumPy.
        rows = [torch.as_tensor(s, dtype=torch.float32) for s in states]
        return torch.stack(rows).to(self.device)

    def replay(self):
        """Run one gradient step on a random minibatch from memory.

        Returns:
            The minibatch loss as a Python float, or ``None`` when fewer than
            ``batch_size`` transitions are stored and no update was made.
        """
        if len(self.memory) < self.batch_size:
            return None
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = self._stack_states(states)
        next_states = self._stack_states(next_states)
        # Column vectors of shape (batch_size, 1) so they line up with
        # gather() and with max(..., keepdim=True) below.
        actions = torch.tensor(
            actions, dtype=torch.long, device=self.device
        ).unsqueeze(1)
        rewards = torch.tensor(
            rewards, dtype=torch.float32, device=self.device
        ).unsqueeze(1)
        dones = torch.tensor(
            dones, dtype=torch.float32, device=self.device
        ).unsqueeze(1)

        q_values = self.q_network(states).gather(1, actions)
        # The target is treated as a constant, so skip building its graph.
        with torch.no_grad():
            next_q_values = self.q_network(next_states).max(dim=1, keepdim=True).values
        # (1 - dones) removes the bootstrap term for terminal transitions.
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        loss = self.loss_fn(q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save(self, path):
        """Write the network weights (``state_dict``) to ``path``."""
        torch.save(self.q_network.state_dict(), path)

    def load(self, path):
        """Load network weights from ``path`` onto this agent's device."""
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(state_dict)



In [4]:
import gym
import torch

# Train a DQNAgent on CartPole for up to 100 episodes, printing each
# episode's total reward.
env = gym.make('CartPole-v0')

# Seed every random source in play. The agent draws exploration actions and
# replay minibatches from Python's `random` module, so seeding only the
# environment and torch would leave runs non-reproducible.
# NOTE(review): `env.seed` and the 4-tuple `env.step` result below follow the
# classic gym API this script was run with; newer gym releases changed both.
env.seed(0)
torch.manual_seed(0)
random.seed(0)

agent = DQNAgent(env.observation_space.shape[0], env.action_space.n)

try:
    for episode in range(100):
        state = env.reset()
        total_reward = 0

        # Cap each episode at 200 steps.
        for step in range(200):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            agent.remember(state, action, reward, next_state, done)
            agent.replay()

            if done:
                break

            state = next_state

        print(f'Episode {episode}: Total reward = {total_reward}')
finally:
    # Release the environment even if training is interrupted
    # (e.g. KeyboardInterrupt from stopping the cell).
    env.close()


Episode 0: Total reward = 9.0
Episode 1: Total reward = 25.0
Episode 2: Total reward = 19.0
Episode 3: Total reward = 11.0
Episode 4: Total reward = 16.0
Episode 5: Total reward = 16.0
Episode 6: Total reward = 13.0
Episode 7: Total reward = 18.0
Episode 8: Total reward = 11.0
Episode 9: Total reward = 19.0
Episode 10: Total reward = 14.0
Episode 11: Total reward = 13.0
Episode 12: Total reward = 9.0
Episode 13: Total reward = 29.0
Episode 14: Total reward = 9.0
Episode 15: Total reward = 21.0
Episode 16: Total reward = 10.0
Episode 17: Total reward = 10.0
Episode 18: Total reward = 11.0
Episode 19: Total reward = 10.0
Episode 20: Total reward = 11.0
Episode 21: Total reward = 12.0
Episode 22: Total reward = 11.0
Episode 23: Total reward = 11.0
Episode 24: Total reward = 14.0
Episode 25: Total reward = 73.0
Episode 26: Total reward = 21.0
Episode 27: Total reward = 59.0
Episode 28: Total reward = 84.0
Episode 29: Total reward = 145.0
Episode 30: Total reward = 47.0
Episode 31: Total re

KeyboardInterrupt: ignored