In [None]:
!pip install wandb



In [None]:
import os
import random
import weakref
from collections import deque
from random import randint

import cv2
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
import wandb
from tqdm import tqdm


# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    
class ConvModel(nn.Module):
    """Small convolutional Q-network: two conv layers followed by two linear layers.

    The layer sizes are fixed for a stack of 4 frames of 84x84 pixels, which
    is what makes the second convolution produce a 32x9x9 feature map.

    Args:
        obs_shape: Shape of one observation. It is only stored on the
            instance; the layer sizes do not depend on it.
        num_action: Number of discrete actions, i.e. the size of the output.
    """

    def __init__(self,obs_shape,num_action):
        super().__init__()
        self.obs_shape = obs_shape
        self.num_action = num_action
        # Attribute names are part of the saved state_dict keys; keep them.
        self.cv1 = torch.nn.Conv2d(4, 16, (8, 8), stride=(4, 4))
        self.cv2 = torch.nn.Conv2d(16, 32, (4, 4), stride=(2, 2))
        self.fc1 = torch.nn.Linear(32 * 9 * 9, 256)
        self.fc2 = torch.nn.Linear(256, num_action)

    def forward(self,inpt):
        """Return one Q-value per action for every observation in the batch.

        Args:
            inpt: Float tensor of shape (batch, 4, 84, 84).

        Returns:
            Tensor of shape (batch, num_action).

        Raises:
            ValueError: If the spatial size of ``inpt`` does not produce the
                feature count expected by the first linear layer.
        """
        hidden = self.cv1(inpt)
        # NOTE(review): there is no activation between cv2 and fc1, so those
        # two layers compose linearly. Kept as-is so that the function
        # computed by existing checkpoints does not change.
        hidden = self.cv2(f.relu(hidden))
        # Flatten per sample. Using view(-1, 32*9*9) here would let the batch
        # dimension absorb a size mismatch and silently return a batch of the
        # wrong length instead of failing.
        hidden = hidden.flatten(start_dim=1)
        if hidden.shape[1] != self.fc1.in_features:
            raise ValueError(
                f"expected {self.fc1.in_features} features per sample after the "
                f"convolutions, got {hidden.shape[1]}; input must be (batch, 4, 84, 84)"
            )
        hidden = self.fc1(hidden)
        return self.fc2(f.relu(hidden))
    
class FrameStacker():
    """Environment wrapper that returns the last ``num_stack`` processed frames.

    Every raw frame is cropped, resized to ``(h, w)`` and converted to
    grayscale. The newest frame is kept at index 0 of the stack and older
    frames follow it.
    """

    def __init__(self,env,w,h,num_stack=4):
        stack_shape = (num_stack, h, w)
        self.env = env
        self.n = num_stack
        self.w = w
        self.h = h
        # Only used by callers as a carrier for the observation shape.
        self.observation_space = np.zeros(stack_shape)
        self.action_space = self.env.action_space
        # Most recently processed single frame (see render_processed).
        self.frame = None
        self.buffer = np.zeros(stack_shape, 'uint8')

    def preprocess_frame(self,frame):
        """Crop, resize and grayscale ``frame``; remember and return the result."""
        cropped = frame[30:195, 10:152]
        resized = cv2.resize(cropped, (self.w, self.h))
        self.frame = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
        return self.frame

    def render(self):
        """Delegate rendering to the wrapped environment."""
        return self.env.render()

    def render_processed(self):
        """Show the last processed frame in an OpenCV window until a key is pressed."""
        cv2.imshow('processed', self.frame)
        cv2.waitKey(0)

    def destroyWindows(self):
        """Close every OpenCV window."""
        cv2.destroyAllWindows()

    def close(self):
        """Close the wrapped environment."""
        return self.env.close()

    def step(self,action):
        """Advance the environment one step and return the updated frame stack.

        Returns:
            ``(stack_copy, reward, done, info)`` where the last three values
            come straight from the wrapped environment.
        """
        raw, reward, done, info = self.env.step(action)
        newest = self.preprocess_frame(raw)
        # Age every stored frame by one slot, then put the new one in front.
        older = self.buffer[:self.n - 1].copy()
        self.buffer[1:self.n] = older
        self.buffer[0] = newest
        return self.buffer.copy(), reward, done, info

    def reset(self):
        """Reset the environment and fill the whole stack with its first frame."""
        first = self.preprocess_frame(self.env.reset())
        self.buffer = np.stack([first for _ in range(self.n)])
        return self.buffer.copy()
                    
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Once ``buffer_size`` items have been inserted, each new item overwrites
    the oldest one.

    Args:
        buffer_size: Maximum number of transitions kept. Must be positive.

    Raises:
        ValueError: If ``buffer_size`` is not positive.
    """

    def __init__(self, buffer_size=1000000):
        if buffer_size <= 0:
            raise ValueError(f"buffer_size must be positive, got {buffer_size}")
        self.buffer_size = buffer_size
        self.buffer = [None] * buffer_size
        # Total number of insertions so far; it keeps growing past the
        # capacity, so callers can use it as a running counter.
        self.idx = 0

    def __len__(self):
        """Return how many stored transitions are currently available."""
        return min(self.idx, self.buffer_size)

    def insert(self,sars):
        """Store one transition, overwriting the oldest entry when full."""
        self.buffer[self.idx % self.buffer_size] = sars
        self.idx += 1

    def sample(self,num_samples):
        """Return ``num_samples`` distinct stored transitions chosen uniformly.

        Raises:
            ValueError: If fewer than ``num_samples`` transitions are stored.
        """
        # Draw indices from a range rather than slicing the list: slicing
        # copies every stored reference on each call while the buffer fills.
        # random.sample picks the same positions either way.
        picked = random.sample(range(len(self)), num_samples)
        return [self.buffer[i] for i in picked]
    
def update_tgt_model(m,tgt):
    """Overwrite the state of ``tgt`` with a copy of the state of ``m``."""
    online_state = m.state_dict()
    tgt.load_state_dict(online_state)

    
# One optimizer per online network, kept between calls so that Adam's running
# moment estimates survive from one training step to the next. Keys are held
# weakly, so a discarded model does not keep its optimizer alive.
_OPTIMIZER_CACHE = weakref.WeakKeyDictionary()


def _optimizer_for(model, lr):
    """Return the cached Adam optimizer for ``model``, creating it on first use."""
    opt = _OPTIMIZER_CACHE.get(model)
    if opt is None:
        opt = optim.Adam(model.parameters(), lr=lr)
        _OPTIMIZER_CACHE[model] = opt
    else:
        # Honour a learning rate that changed since the previous call.
        for group in opt.param_groups:
            group['lr'] = lr
    return opt


def _stack_states(states, dev):
    """Stack per-transition observations into one float32 batch on ``dev``."""
    batch = torch.as_tensor(np.stack(states), dtype=torch.float32, device=dev)
    if batch.dim() > 4:
        # Collapse extra leading dimensions into the batch dimension.
        batch = batch.reshape(-1, *batch.shape[-3:])
    return batch


def train_step(model,state_trans,tgt,loss_fn,lr=0.0001,opt=None,gamma=0.99):
    """Run one Q-learning gradient step on a batch of transitions.

    Args:
        model: Online Q-network; its parameters are updated.
        state_trans: Sequence of ``(state, action, reward, next_state, done)``
            tuples. States are array-likes of identical shape.
        tgt: Target Q-network, used without gradients for the bootstrap value.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        lr: Learning rate for the optimizer created when ``opt`` is omitted.
        opt: Optional optimizer to use. When omitted, an Adam optimizer is
            created once per ``model`` and reused on later calls.
        gamma: Discount factor applied to the bootstrap value.

    Returns:
        The scalar loss tensor, detached from the autograd graph.

    Raises:
        ValueError: If ``state_trans`` is empty.
    """
    if len(state_trans) == 0:
        raise ValueError("state_trans must contain at least one transition")

    # Put the batch wherever the online network lives.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')
    if opt is None:
        opt = _optimizer_for(model, lr)

    states, actions, rewards, next_states, dones = zip(*state_trans)
    cur_states = _stack_states(states, dev)
    next_states = _stack_states(next_states, dev)
    actions = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=dev)
    rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=dev)
    # 0 for terminal transitions so they do not bootstrap from next_state.
    masks = torch.tensor([0.0 if d else 1.0 for d in dones],
                         dtype=torch.float32, device=dev)

    with torch.no_grad():
        qvals_next = tgt(next_states).max(-1)[0]
        td_target = rewards + gamma * masks * qvals_next

    opt.zero_grad()
    qvals = model(cur_states)
    # Q-value of the action actually taken, for any number of actions.
    act_qval = qvals.gather(1, actions.unsqueeze(1)).squeeze(1)

    # PyTorch losses take (input, target); only the input carries gradients.
    loss = loss_fn(act_qval, td_target)
    loss.backward()
    opt.step()
    return loss.detach()

    
# ---------------------------------------------------------------------------
# Training script: collect transitions from Breakout with the online network,
# store them in a replay buffer and periodically run a Q-learning step.
# ---------------------------------------------------------------------------
wandb.init(project="dqn",name="breakout")

loss_fn = torch.nn.SmoothL1Loss()

min_rb_size=200000          # transitions collected before training starts
sample_size=32              # minibatch size for each training step
eps=1.0
eps_min=0.1                 # the loop stops once eps decays to this value

eps_decay=0.9999999

env_steps_before_train = 10 # environment steps between two training steps
tgt_model_update = 1500     # training steps between two target-network syncs

env = gym.make('Breakout-v0')
env = FrameStacker(env,84,84)

last_obs=env.reset()

lr=0.0001
m=ConvModel(env.observation_space.shape,env.action_space.n).to(device)
tgt=ConvModel(env.observation_space.shape,env.action_space.n).to(device)
# Start the target network from the same weights as the online network;
# otherwise the first bootstrap targets come from an unrelated random net.
update_tgt_model(m,tgt)

# NOTE(review): the default capacity is 1,000,000 transitions, each holding
# two (4, 84, 84) frame stacks. If the frames are uint8 that is roughly 56 GB
# when full -- confirm the host has that much memory or pass a smaller size.
rb=ReplayBuffer()
steps_since_train=0
epochs_since_tgt=0
# Negative while the replay buffer is still being pre-filled.
step_num=-1*min_rb_size
episode_reward=[]           # returns of episodes finished since the last log
rolling_reward=0            # unscaled return of the episode in progress
tq=tqdm()

# torch.save does not create missing directories.
os.makedirs("models", exist_ok=True)

try:
    while eps>eps_min:
        tq.update(1)
        # With Boltzmann exploration eps does not influence the action; it
        # only serves as the stopping criterion and as a logged value.
        eps=eps_decay**step_num

        # Boltzmann exploration: sample an action from softmax(Q-values).
        # (Epsilon-greedy would instead take a random action with
        # probability eps and the argmax of the Q-values otherwise.)
        # No gradients are needed to act, so skip building the graph.
        with torch.no_grad():
            obs_batch=torch.tensor(last_obs,dtype=torch.float32).view(-1,4,84,84).to(device)
            logits=m(obs_batch)[0]
        action = torch.distributions.Categorical(logits=logits).sample().item()

        obs,rew,done,info=env.step(action)
        rolling_reward+=rew
        rew=rew*100         # the stored reward is scaled; the logged one is not
        rb.insert((last_obs,action,rew,obs,done))
        if done:
            episode_reward.append(rolling_reward)
            rolling_reward=0
            # Begin the next episode from the fresh observation rather than
            # from the terminal frame of the episode that just ended.
            last_obs = env.reset()
        else:
            last_obs = obs

        steps_since_train += 1
        step_num+=1

        if (rb.idx) > min_rb_size and steps_since_train>env_steps_before_train:

            epochs_since_tgt+=1
            loss=train_step(m,rb.sample(sample_size),tgt,loss_fn,lr)

            metrics={'loss':float(loss),'eps':eps}
            # Only report an average when at least one episode has finished
            # since the last log; the mean of an empty list is NaN.
            if episode_reward:
                metrics['avg_rew']=float(np.mean(episode_reward))
                episode_reward=[]
            wandb.log(metrics,step=step_num)

            if epochs_since_tgt > tgt_model_update:
                print('updating tgt model')
                update_tgt_model(m,tgt)
                epochs_since_tgt=0
                torch.save(tgt.state_dict(),f"models/{step_num}.pth")
            steps_since_train=0

except KeyboardInterrupt:
    # Interrupting the cell is the intended way to stop training early.
    pass
finally:
    # Release the progress bar and the emulator on every exit path.
    tq.close()
    env.close()
        

[34m[1mwandb[0m: Currently logged in as: [33mrachiteagles[0m (use `wandb login --relogin` to force relogin)


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
216554it [06:13, 318.28it/s]

updating tgt model


233046it [07:00, 322.69it/s]

updating tgt model


249573it [07:48, 342.99it/s]

updating tgt model


266100it [08:37, 324.08it/s]

updating tgt model


282595it [09:27, 325.91it/s]

updating tgt model


299105it [10:18, 312.26it/s]

updating tgt model


304721it [29:08,  8.86s/it]

Exception: ignored