<a href="https://colab.research.google.com/github/xpurwar/DeepQ-Network/blob/main/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install gymnasium



In [None]:
from collections import deque
from time import sleep

import gymnasium as gym
import numpy as np
import torch
from gymnasium.envs.registration import register
from IPython.display import clear_output
from torch import nn


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Give Colab access to your Google Drive (mounted at /gdrive).
# NOTE(review): the next cell mounts Drive again at /content/drive, and the %cd
# cell only uses /content/drive — this mount looks redundant; confirm and drop.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [None]:
# Mount Google Drive at /content/drive; the %cd cell below changes into a
# folder under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change the working directory to the Drive folder that holds the
# MiniPacManGymV2 module, so the import in the next cell can find it.
%cd /content/drive/MyDrive/Reinforcement Learning

/content/drive/MyDrive/Reinforcement Learning


In [None]:
#Import MiniPacMan environment class definition
from MiniPacManGymV2 import MiniPacManEnv

In [None]:
# Make the custom environment available to gym.make under the id
# "MiniPacMan-v2". max_episode_steps caps every episode at 20 steps; the cap is
# expected to surface as `truncated` in the episode loops below.
register("MiniPacMan-v2", entry_point=MiniPacManEnv, max_episode_steps=20)

In [None]:
# Build an environment from the id registered above. frozen_ghost=False is
# forwarded to MiniPacManEnv (presumably: the ghost moves).
# NOTE(review): this human-rendering env is also used for the long training
# run; if the env renders on every step, consider a separate env without
# render_mode for training.
env = gym.make(
    id="MiniPacMan-v2",
    render_mode="human",
    frozen_ghost=False,
)

In [None]:
class QNetwork(nn.Module):
    """Small MLP mapping a MiniPacMan board to one Q-value per action.

    Accepts either a flat input of shape ``(..., n_inputs)`` or an un-flattened
    board whose last two dims multiply to ``n_inputs`` (e.g. ``(..., 6, 6)``
    with the default 36); the latter is flattened internally.
    Returns a tensor of shape ``(..., n_actions)``.

    Layer attribute names (linear1, linear3, linear4) are kept unchanged so
    ``state_dict`` keys stay stable.
    """

    def __init__(self, n_inputs: int = 36, n_actions: int = 4):
        super().__init__()
        self.linear1 = nn.Linear(n_inputs, 32)
        self.activation1 = nn.ReLU()
        self.linear3 = nn.Linear(32, 16)
        self.activation3 = nn.ReLU()
        self.linear4 = nn.Linear(16, n_actions)

    def forward(self, x):
        # linear1 needs a flat feature vector; fold the trailing board dims
        # when the caller passes an un-flattened board such as (6, 6).
        if x.dim() >= 2 and x.shape[-1] != self.linear1.in_features:
            x = x.flatten(start_dim=-2)
        x = self.activation1(self.linear1(x))
        x = self.activation3(self.linear3(x))
        return self.linear4(x)

# Smoke test: a batch of one random 36-feature input should yield one
# Q-value per action.
model = QNetwork().to(device)
x = torch.randn(1, 36).to(device)
out = model(x)
assert out.shape == (1, 4), out.shape
out

tensor([[ 0.2932, -0.3092,  0.0383, -0.2417]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once ``capacity`` transitions are stored, each push evicts the oldest one.
    ``sample`` draws a batch uniformly without replacement and moves the
    tensors to the module-level ``device``.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        # deque(maxlen=...) discards the oldest entry in O(1); list.pop(0) is O(n).
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one when full."""
        # Store actions as plain ints: greedy actions can arrive as 0-d (CUDA)
        # tensors from torch.argmax while random ones are Python ints.
        self.buffer.append((state, int(action), reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns (states, actions, rewards, next_states, dones): states and
        next_states stacked with torch.stack, actions as a tuple of ints,
        rewards as a float32 tensor and dones as a bool tensor; all tensors are
        on ``device``.

        Raises ValueError unless 0 < batch_size <= number of stored transitions.
        """
        if not 0 < batch_size <= len(self.buffer):
            raise ValueError(
                f"cannot sample {batch_size} transitions from a buffer of {len(self.buffer)}"
            )
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))
        return (
            torch.stack(states).to(device),
            actions,
            torch.tensor(rewards, dtype=torch.float32, device=device),
            torch.stack(next_states).to(device),
            torch.tensor(dones, dtype=torch.bool, device=device),
        )

In [None]:
# Online network Q (optimized every training step) and a target network
# Q_target that starts as an exact weight copy of Q; the training loop copies
# Q's weights into it again from time to time.
Q, Q_target = (QNetwork().to(device) for _ in range(2))
Q_target.load_state_dict(Q.state_dict())
Q_optimizer = torch.optim.Adam(params=Q.parameters(), lr=1e-3)

In [None]:
# Hyperparameters
gamma = 0.95             # discount factor
buffer_size = 1000       # replay buffer capacity (transitions)
batch_size = 32          # transitions per gradient step
num_episodes = 15000
min_e = 0.01             # floor for the exploration rate
target_sync_every = 100  # episodes between target-network syncs

RB = ReplayBuffer(buffer_size)  # initialize Replay Buffer
epsilon = 1                     # start fully exploratory

for e in range(num_episodes):
  new_obs, info = env.reset()
  # Flatten once here so every state fed to Q / stored in RB has the same shape.
  new_obs = torch.flatten(torch.tensor(new_obs, dtype=torch.float32)).to(device)

  done = False
  truncated = False
  steps = 0

  while not done and not truncated:  # loop for one episode
    obs = new_obs

    # Epsilon-greedy action. Always a plain Python int, so env.step and the
    # replay buffer never receive a (CUDA) tensor.
    if np.random.random() > epsilon:
      with torch.no_grad():
        action = torch.argmax(Q(obs)).item()
    else:
      action = torch.randint(4, (1,)).item()

    # take a step
    new_obs, reward, done, truncated, info = env.step(action)
    new_obs = torch.flatten(torch.tensor(new_obs, dtype=torch.float32)).to(device)

    # Store (s, a, r, s', done). Only the terminal flag is stored, so
    # time-limit truncation still bootstraps from s'.
    RB.push(obs, action, reward, new_obs, done)
    steps += 1

    if len(RB.buffer) >= batch_size:
      states, actions, rewards, next_states, dones = RB.sample(batch_size)
      actions_t = torch.as_tensor(actions, dtype=torch.long, device=device)
      # Q(s, a) for the actions that were actually taken
      q_taken = Q(states).gather(1, actions_t.unsqueeze(1)).squeeze(1)
      with torch.no_grad():
        next_q_max = Q_target(next_states).max(dim=1).values
        targets = rewards + (~dones) * gamma * next_q_max
      loss = torch.mean((q_taken - targets) ** 2)
      Q_optimizer.zero_grad()
      loss.backward()
      Q_optimizer.step()

  # Sync the target network every `target_sync_every` episodes. The `== 0` is
  # essential: a bare `e % 100` is truthy for every episode EXCEPT multiples of 100.
  if e % target_sync_every == 0:
    Q_target.load_state_dict(Q.state_dict())

  # Linear epsilon decay, clamped so it never drops below min_e.
  epsilon = max(min_e, epsilon - 1 / num_episodes)

  # periodic reporting (reward of the episode's last step; 20 is treated as a win)
  if e > 0 and e % 100 == 0:
    print(f'episode: {e}, steps: {steps}, epislon: {epsilon},win: {reward==20}')


  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


episode: 100, steps: 1, epislon: 0.9932666666666674,win: False
episode: 200, steps: 3, epislon: 0.9866000000000015,win: False
episode: 300, steps: 1, epislon: 0.9799333333333355,win: False
episode: 400, steps: 1, epislon: 0.9732666666666696,win: False
episode: 500, steps: 1, epislon: 0.9666000000000037,win: False
episode: 600, steps: 3, epislon: 0.9599333333333377,win: False
episode: 700, steps: 1, epislon: 0.9532666666666718,win: False
episode: 800, steps: 5, epislon: 0.9466000000000059,win: False
episode: 900, steps: 1, epislon: 0.93993333333334,win: False
episode: 1000, steps: 3, epislon: 0.933266666666674,win: False
episode: 1100, steps: 6, epislon: 0.9266000000000081,win: False
episode: 1200, steps: 3, epislon: 0.9199333333333422,win: False
episode: 1300, steps: 5, epislon: 0.9132666666666762,win: False
episode: 1400, steps: 6, epislon: 0.9066000000000103,win: False
episode: 1500, steps: 4, epislon: 0.8999333333333444,win: False
episode: 1600, steps: 3, epislon: 0.8932666666666784

In [None]:
# Watch the trained agent play one greedy episode (no exploration).
obs, info = env.reset()
done = False
truncated = False

while not done and not truncated:
    env.render()
    obs = torch.flatten(torch.tensor(obs, dtype=torch.float32)).to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        action = torch.argmax(Q(obs)).item()  # plain int, not a (CUDA) tensor
    obs, reward, done, truncated, info = env.step(action)
    sleep(1)  # slow playback so a human can follow the board
    clear_output(wait=True)

env.render()
env.close()

xxxxxx
x····x
x····x
xᗧ···x
x····x
xxxxxx

