# <span style='color:blue'> LAB 9: </span>
# <span style='color:blue'> DEEP REINFORCEMENT LEARNING </span>

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random 
from collections import namedtuple, deque 

## Policy/target network

In [None]:
class Q_network(nn.Module):
    """Three-layer fully connected network mapping a state to Q-values.

    The network applies two hidden linear layers with ReLU activations
    followed by a linear output layer that yields one value per action.
    """

    def __init__(self, state_size, action_size, seed, fc1_unit=64,
                 fc2_unit=64):
        """Build the layers and seed PyTorch's global random generator.

        Args:
            state_size: Number of input features (dimension of a state).
            action_size: Number of outputs (one Q-value per action).
            seed: Value passed to ``torch.manual_seed`` before the layers
                are created, so two networks built with the same seed
                start from identical weights.
            fc1_unit: Width of the first hidden layer.
            fc2_unit: Width of the second hidden layer.
        """
        super().__init__()
        # torch.manual_seed returns the global Generator; it is kept on the
        # instance only for reference.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_unit)
        self.fc2 = nn.Linear(fc1_unit, fc2_unit)
        self.fc3 = nn.Linear(fc2_unit, action_size)

    def forward(self, x):
        """Return the Q-values for ``x``.

        Args:
            x: Tensor whose last dimension has ``state_size`` features.

        Returns:
            Tensor with ``action_size`` values in its last dimension. No
            activation is applied to the output layer.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

## Define Agent

In [None]:
class Agent():
    """DQN agent holding a policy network, a target network and a replay buffer.

    The class relies on module-level names that must exist when it is used:
    ``device`` (torch device for networks and tensors) and the
    hyperparameters ``LR``, ``BUFFER_SIZE``, ``BATCH_SIZE``, ``GAMMA``,
    ``TAU`` and ``UPDATE_EVERY``.
    """

    def __init__(self, state_size, action_size, seed):
        """Create the networks, optimizer and replay memory.

        Args:
            state_size: Dimension of a state vector.
            action_size: Number of discrete actions.
            seed: Seed for Python's ``random`` module; also forwarded to
                both networks and to the replay memory.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed first and store the value itself.
        random.seed(seed)
        self.seed = seed

        # Q-networks: the policy network is trained, the target network
        # only tracks it through soft updates.
        self.qnetwork_policy = Q_network(state_size, action_size, seed).to(device)
        self.qnetwork_target = Q_network(state_size, action_size, seed).to(device)

        self.optimizer = torch.optim.Adam(self.qnetwork_policy.parameters(), lr=LR)

        # Replay memory
        self.memory = ExperienceRelay(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Time-step counter, wrapped modulo UPDATE_EVERY in step().
        self.t_step = 0

    def step(self, state, action, reward, next_step, done):
        """Store one transition and learn on every UPDATE_EVERY-th call.

        Args:
            state: State in which ``action`` was taken.
            action: Action that was taken.
            reward: Reward received.
            next_step: State reached after the action (the next state).
            done: Whether the episode ended with this transition.
        """
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_step, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # Only learn once the memory holds more than one batch of
            # samples; then draw a random subset and learn from it.
            if len(self.memory) > BATCH_SIZE:
                experience = self.memory.sample()
                self.learn(experience, GAMMA)

    def act(self, state, eps=0):
        """Choose an action for ``state`` with an epsilon-greedy rule.

        Args:
            state: NumPy array holding a single state.
            eps: Probability threshold for picking a random action; with
                the default of 0 the greedy action is (almost) always chosen.

        Returns:
            Index of the selected action.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Evaluate without tracking gradients, then restore training mode.
        self.qnetwork_policy.eval()
        with torch.no_grad():
            action_values = self.qnetwork_policy(state)
        self.qnetwork_policy.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Run one optimization step on a batch and soft-update the target.

        Args:
            experiences: Tuple ``(states, actions, rewards, next_states,
                dones)`` of tensors. ``actions`` is used as a ``gather``
                index along dimension 1, so it must be an integer tensor
                with one column per row of ``states``.
            gamma: Discount factor applied to the bootstrapped value.
        """
        states, actions, rewards, next_states, dones = experiences
        criterion = torch.nn.MSELoss()
        self.qnetwork_policy.train()
        self.qnetwork_target.eval()
        # Q-value the policy network assigns to the action actually taken.
        predicted_targets = self.qnetwork_policy(states).gather(1, actions)

        # Bootstrapped value: highest target-network Q-value of the next
        # state, computed without gradients.
        with torch.no_grad():
            labels_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)

        # (1 - dones) removes the bootstrapped term for terminal transitions.
        labels = rewards + (gamma * labels_next * (1 - dones))

        loss = criterion(predicted_targets, labels)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_policy, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Blend ``local_model`` weights into ``target_model`` in place.

        Each target parameter becomes
        ``tau * local + (1 - tau) * target``.

        Args:
            local_model: Model whose parameters are read.
            target_model: Model whose parameters are overwritten.
            tau: Interpolation factor; 1 copies the local weights, 0 leaves
                the target unchanged.
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

## Define experience replay buffer

In [None]:
class ExperienceRelay:
    """Bounded buffer of transitions with uniform random batch sampling."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Set up an empty buffer.

        Args:
            action_size: Number of actions; stored but not otherwise used
                by this class.
            buffer_size: Maximum number of transitions kept; once full,
                the oldest entries are dropped by the deque.
            batch_size: Number of transitions returned by ``sample``.
            seed: Seed passed to Python's ``random`` module.
        """
        self.action_size = action_size
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)
        self.experiences = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"],
        )
        # random.seed() returns None, so this attribute is always None.
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.memory.append(
            self.experiences(state, action, reward, next_state, done)
        )

    def sample(self):
        """Draw ``batch_size`` transitions at random and return them as tensors.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``. Each
            element stacks the corresponding field row-wise and is moved
            to the module-level ``device``. ``actions`` is a long tensor;
            the others are float tensors.
        """
        batch = random.sample(self.memory, k=self.batch_size)

        def column(field):
            # Stack one field of every sampled transition into a 2-D array.
            return np.vstack(
                [getattr(item, field) for item in batch if item is not None]
            )

        states = torch.from_numpy(column("state")).float().to(device)
        actions = torch.from_numpy(column("action")).long().to(device)
        rewards = torch.from_numpy(column("reward")).float().to(device)
        next_states = torch.from_numpy(column("next_state")).float().to(device)
        # Booleans are converted to uint8 before becoming floats (0.0 / 1.0).
        done_flags = column("done").astype(np.uint8)
        dones = torch.from_numpy(done_flags).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

## Example hyperparameters

In [None]:
# Replay buffer capacity (maximum number of stored transitions).
BUFFER_SIZE = 100_000
# Number of transitions drawn per learning step.
BATCH_SIZE = 64
# Reward discount factor.
GAMMA = 0.99
# Soft-update factor: target = TAU * policy + (1 - TAU) * target, so
# TAU = 1 is a hard copy of the policy net and TAU = 0 never updates.
TAU = 0.001
# Learning rate of the optimizer.
LR = 0.0005
# The agent learns once every UPDATE_EVERY environment steps.
UPDATE_EVERY = 4

# Also see original publication https://www.nature.com/articles/nature14236