In [2]:
#!pip install gym[atari]==0.19.0

## Coding the Paper's Replay Buffer

In [None]:
import numpy as np


class ReplayBuffer():
    """Fixed-capacity circular buffer of (state, action, reward, next_state, done) transitions.

    Args:
        max_size: maximum number of transitions kept; once full, the oldest
            slots are overwritten.
        input_shape: shape of a single observation; each state array has this shape.
        n_actions: number of discrete actions. Not used here, because actions are
            stored as integer indices; kept for interface compatibility.
    """

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        # Total number of transitions ever stored (keeps growing past mem_size).
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.uint8)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next slot, overwriting the oldest once full."""
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Sample ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, states_, dones)`` of numpy arrays,
            each with ``batch_size`` rows.

        Raises:
            ValueError: (from ``np.random.choice``) if fewer than ``batch_size``
                transitions have been stored.
        """
        # Only the first min(mem_cntr, mem_size) slots hold real data.
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, batch_size, replace=False)
        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        states_ = self.new_state_memory[batch]
        dones = self.terminal_memory[batch]
        return states, actions, rewards, states_, dones

## Coding the Deep Q-Learning Agent

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import os

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

class DeepQNetwork(torch.nn.Module):
    """Convolutional Q-network: three conv layers followed by two fully connected layers.

    Maps a batch of observations to one Q-value per action.

    Args:
        lr: learning rate of the RMSprop optimizer attached to this network.
        n_actions: number of discrete actions (size of the output layer).
        name: checkpoint file name, joined onto ``chkpt_dir``.
        input_dims: shape of one observation as (channels, height, width).
        chkpt_dir: directory in which checkpoints are written and read.
    """

    def __init__(self, lr, n_actions, name, input_dims, chkpt_dir):
        super(DeepQNetwork, self).__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)
        self.cnn1 = torch.nn.Conv2d(input_dims[0], 32, 8, stride=4)
        self.cnn2 = torch.nn.Conv2d(32, 64, 4, stride=2)
        self.cnn3 = torch.nn.Conv2d(64, 64, 3, stride=1)
        fc_input_dims = self.calculate_conv_output_dims(input_dims)
        self.fc1 = torch.nn.Linear(fc_input_dims, 512)
        self.fc2 = torch.nn.Linear(512, n_actions)
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.criterion = torch.nn.MSELoss()
        # Put the weights on the same device the inputs are sent to
        # (CUDA when available, else CPU). Module.to() updates parameters in
        # place, so the optimizer created above stays valid.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def calculate_conv_output_dims(self, input_dims):
        """Return the flattened size of the conv stack's output for one observation."""
        with torch.no_grad():
            state = torch.zeros(1, *input_dims)
            dims = self.cnn1(state)
            dims = self.cnn2(dims)
            dims = self.cnn3(dims)
        return int(np.prod(dims.size()))

    def forward(self, state):
        """Compute Q-values for a batch of observations.

        Args:
            state: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, n_actions).
        """
        out = F.relu(self.cnn1(state))
        out = F.relu(self.cnn2(out))
        out = F.relu(self.cnn3(out))
        # Flatten everything except the batch dimension: (B, C, H, W) -> (B, C*H*W).
        out = torch.flatten(out, start_dim=1)
        out = F.relu(self.fc1(out))
        actions = self.fc2(out)
        return actions

    def save_checkpoint(self):
        """Save the state_dict to ``checkpoint_file``, creating the directory if needed."""
        print("...saving checkpoint...")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state_dict from ``checkpoint_file`` onto this network's device."""
        print("...loading checkpoint...")
        self.load_state_dict(torch.load(self.checkpoint_file, map_location=self.device))


class DQNAgent():
    """Deep Q-learning agent with an online network, a target network and experience replay.

    Args:
        gamma: discount factor applied to the bootstrapped next-state value.
        epsilon: initial exploration rate for epsilon-greedy action selection.
        lr: learning rate passed to both Q-networks.
        n_actions: number of discrete actions.
        input_dims: shape of one observation, passed to the networks and the replay buffer.
        mem_size: replay-buffer capacity.
        batch_size: number of transitions sampled per learning step.
        eps_min: lower bound for epsilon.
        eps_dec: amount subtracted from epsilon after every learning step.
        replace: copy the online weights into the target network every
            ``replace`` learning steps.
        algo: algorithm tag used in checkpoint file names.
        env_name: environment tag used in checkpoint file names.
        chkpt_dir: directory for checkpoint files.
    """

    def __init__(self, gamma, epsilon, lr, n_actions, input_dims, mem_size, batch_size,
                 eps_min=0.01, eps_dec=5e-7, replace=1000, algo=None, env_name=None,
                 chkpt_dir='tmp/dqn'):
        self.lr = lr
        self.input_dims = input_dims
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.eps_dec = eps_dec
        self.eps_min = eps_min
        self.replace_target_cnt = replace
        self.algo = algo
        self.env_name = env_name
        self.chkpt_dir = chkpt_dir
        self.action_space = [i for i in range(self.n_actions)]
        self.learn_step_counter = 0
        # DeepQNetwork's signature is (lr, n_actions, name, input_dims, chkpt_dir), so
        # input_dims must be passed by keyword. Passing it positionally would bind it to `name`.
        # .to(DEVICE) keeps the weights on the device that sample_memory/choose_action use.
        self.q_eval = DeepQNetwork(self.lr, self.n_actions, input_dims=self.input_dims,
                                   name=f"{self.env_name}_{self.algo}_q_eval",
                                   chkpt_dir=self.chkpt_dir).to(DEVICE)
        self.q_next = DeepQNetwork(self.lr, self.n_actions, input_dims=self.input_dims,
                                   name=f"{self.env_name}_{self.algo}_q_next",
                                   chkpt_dir=self.chkpt_dir).to(DEVICE)
        self.memory = ReplayBuffer(mem_size, input_dims, n_actions)

    def choose_action(self, observation):
        """Epsilon-greedy action selection.

        Returns the greedy action under ``q_eval`` when a uniform draw exceeds
        epsilon; otherwise returns a uniformly random action.
        """
        if np.random.random() > self.epsilon:
            # Add a batch dimension. np.array avoids torch's slow list-of-ndarrays path.
            state = torch.tensor(np.array([observation]), dtype=torch.float).to(DEVICE)
            with torch.no_grad():
                actions = self.q_eval(state)
            action = torch.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)
        return action

    def store_transition(self, state, action, reward, state_, done):
        """Forward one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def sample_memory(self):
        """Sample a batch from the replay buffer and return it as tensors on DEVICE.

        ``dones`` is returned as a bool tensor so it can be used directly as an index mask.
        """
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)
        states = torch.tensor(state).to(DEVICE)
        rewards = torch.tensor(reward).to(DEVICE)
        dones = torch.tensor(done).to(DEVICE).bool()
        actions = torch.tensor(action).to(DEVICE)
        states_ = torch.tensor(new_state).to(DEVICE)
        return states, actions, rewards, states_, dones

    def replace_target_network(self):
        """Copy the online weights into the target network every ``replace_target_cnt`` steps."""
        if self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def save_models(self):
        """Checkpoint both networks."""
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        """Restore both networks from their checkpoints."""
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def decrease_epsilon(self):
        """Linearly decay epsilon by ``eps_dec``, never going below ``eps_min``."""
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def learn(self, state=None, action=None, reward=None, new_state=None):
        """Perform one DQN update on a batch sampled from the replay buffer.

        The arguments are ignored, because training data comes from the replay buffer.
        They are optional so that callers can use ``agent.learn()``, and existing
        callers that pass them still work. Does nothing until at least
        ``batch_size`` transitions have been stored.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()

        states, actions, rewards, states_, dones = self.sample_memory()
        indices = torch.arange(self.batch_size, device=actions.device)

        # Q(s, a) for the actions actually taken.
        q_pred = self.q_eval(states)[indices, actions]

        # The bootstrapped target is treated as a constant: no gradient flows into q_next.
        with torch.no_grad():
            q_next = self.q_next(states_).max(dim=1)[0]
            q_next[dones] = 0.0  # terminal transitions have no future value
            q_target = rewards + self.gamma * q_next

        loss = self.q_eval.criterion(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrease_epsilon()
        

## Training Loop

In [None]:
from utils import make_env

# Pong environment built by the project's `make_env` helper (imported from utils above).
env = make_env("PongNoFrameskip-v4")

# Run configuration.
n_games = 500            # NOTE(review): presumably the number of episodes the loop below plays
load_checkpoint = False  # NOTE(review): presumably toggles loading saved weights; confirm in the loop
best_score = -np.inf     # starts at -inf so the first real score always counts as the best
