In [30]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random

from itertools import count
from collections import namedtuple


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Replay Buffer of Transitions

In [24]:
# One step of experience: (state, action, next_state, reward).
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayBuffer(object):
    """Fixed-capacity store of Transition tuples with uniform random sampling.

    The newest transition is kept at index 0 of `memory`; once the buffer holds
    `capacity` items, the oldest one (at the end of the list) is discarded
    before a new one is inserted.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions kept in the buffer.
        """
        self.capacity = capacity
        self.memory = []

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def push(self, *args):
        """Store one transition.

        Args:
            *args: the fields of a Transition, in order
                (state, action, next_state, reward).
        """
        if len(self.memory) == self.capacity:
            # Buffer is full: drop the oldest transition (last in the list).
            self.memory.pop()

        # The original referenced the undefined name `arg` here, which raised
        # NameError on every call.
        self.memory.insert(0, Transition(*args))

    def sample(self, batch_size):
        """Return `batch_size` distinct transitions chosen uniformly at random.

        Returns:
            A list of Transition tuples taken from `memory`.

        Raises:
            ValueError: if `batch_size` exceeds the number of stored
                transitions (raised by `random.sample`).
        """
        return random.sample(self.memory, batch_size)

# Critic - Q Function

In [22]:
class Critic(nn.Module):
    """Action-value network Q(state, action).

    Two hidden layers of 400 and 300 units. The state alone feeds the first
    hidden layer; the action is concatenated to that layer's activations and
    enters at the second hidden layer. The output is one scalar per sample.
    """

    def __init__(self, state_dim, action_dim):
        """
        Args:
            state_dim: number of features in a state vector.
            action_dim: number of components in an action vector.
        """
        super().__init__()
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400 + action_dim, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action):
        """Return Q-values for a batch of (state, action) pairs.

        `state` and `action` are concatenated along dim 1, so both must be
        2-D with the same number of rows. The result has one column.
        """
        state_features = F.relu(self.l1(state))
        joint_input = torch.cat([state_features, action], dim=1)
        hidden = F.relu(self.l2(joint_input))
        return self.l3(hidden)

# Actor - Policy

In [23]:
class Actor(nn.Module):
    """Deterministic policy network mapping a state to an action.

    Two hidden layers of 400 and 300 units with ReLU activations. The output
    is passed through tanh and multiplied by `sup_action`, so each action
    component has magnitude at most `sup_action`.
    """

    def __init__(self, state_dim, action_dim, sup_action):
        """
        Args:
            state_dim: number of features in a state vector.
            action_dim: number of components in an action vector.
            sup_action: supremum of the action space, used to scale the
                tanh output.
        """
        super().__init__()
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)
        self.sup_action = sup_action  # supremum of the action space

    def forward(self, state):
        """Return the policy's action for each state in the batch."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))  # each component in [-1, 1]
        return self.sup_action * squashed

# DDPG (mainly everything)

In [None]:
class DDPG(object):
    """Deep Deterministic Policy Gradient agent.

    Owns an actor and a critic, a target copy of each (initialised to the same
    weights and updated by Polyak averaging), their optimizers, and a replay
    buffer of Transition tuples.
    """

    def __init__(self, state_dim, action_dim, sup_action, memory_capacity, batch_size,
                 gamma=0.99, actor_lr=1e-3, polyak=0.99):
        """
        Args:
            state_dim: number of features in a state vector.
            action_dim: number of components in an action vector.
            sup_action: supremum of the action space, forwarded to Actor.
            memory_capacity: maximum number of transitions in the replay buffer.
            batch_size: number of transitions sampled per critic update.
            gamma: discount factor used in the TD target.
            actor_lr: learning rate of the actor optimizer.
            polyak: weight kept from the target parameters in each soft update
                (target <- polyak * target + (1 - polyak) * online).
        """
        # critic
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = optim.RMSprop(self.critic.parameters())

        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_target.eval()
        # actor
        self.actor = Actor(state_dim, action_dim, sup_action).to(device)
        # actor_lr was previously stored but never given to the optimizer.
        self.actor_optimizer = optim.RMSprop(self.actor.parameters(), lr=actor_lr)

        self.actor_target = Actor(state_dim, action_dim, sup_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_target.eval()
        # replay buffer (the original passed the undefined name `capacity`)
        self.replay_buffer = ReplayBuffer(memory_capacity)
        # state batch sampled in update_critic, reused by update_actor
        self.states_batch = None
        # hyper parameters
        self.BATCH_SIZE = batch_size
        self.GAMMA = gamma
        self.actor_lr = actor_lr
        self.polyak = polyak

    def update_critic(self):
        """Sample a minibatch and take one optimizer step on the critic.

        Minimises the mean squared TD error between Q(s, a) and
        r + gamma * Q_target(s', actor_target(s')), using 0 for the bootstrap
        term when next_state is None. The sampled states are kept in
        `self.states_batch` for update_actor.

        Raises:
            ValueError: if the replay buffer holds fewer than BATCH_SIZE
                transitions (raised by the buffer's sampling).
        """
        # sample experience from the replay buffer
        transitions = self.replay_buffer.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        # Boolean mask of transitions whose next_state is not None.
        # torch.bool replaces the deprecated uint8 mask indexing.
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=device, dtype=torch.bool)

        # NOTE(review): assumes each stored state/action is a (1, dim) tensor and
        # each reward a single-element tensor, so torch.cat yields one row per
        # transition -- confirm against the code that pushes transitions.
        self.states_batch = torch.cat(batch.state)
        actions_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward).reshape(-1)

        # compute TD target
        next_state_act_val = torch.zeros(self.BATCH_SIZE, device=device)
        if non_final_mask.any():  # torch.cat fails on an empty list
            non_final_next_states = torch.cat([s for s in batch.next_state
                                               if s is not None])
            # Targets are constants for this update: no graph through the target nets.
            with torch.no_grad():
                next_actions = self.actor_target(non_final_next_states)
                # Critic returns (N, 1); drop the column to match the (N,) slice.
                next_state_act_val[non_final_mask] = self.critic_target(
                    non_final_next_states, next_actions).squeeze(1)
        td_target = reward_batch + self.GAMMA * next_state_act_val

        # Squeeze to (B,) so mse_loss compares element-wise instead of
        # broadcasting (B, 1) against (B,) into a (B, B) matrix.
        q_values = self.critic(self.states_batch, actions_batch).squeeze(1)
        loss = F.mse_loss(q_values, td_target)  # mean squared TD error

        # call optimizer of critic to minimize loss
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

    def update_actor(self):
        """Take one optimizer step on the actor.

        Maximises the critic's value of the actor's actions on the state batch
        stored by the most recent update_critic call.

        Raises:
            RuntimeError: if update_critic has not been called yet.
        """
        if self.states_batch is None:
            raise RuntimeError("update_critic() must be called before update_actor()")

        # The original called the undefined global `actor` here.
        loss = -self.critic(self.states_batch, self.actor(self.states_batch)).mean()

        # backward() also fills the critic's gradients; they are cleared by
        # critic_optimizer.zero_grad() at the next critic update.
        self.actor_optimizer.zero_grad()
        loss.backward()
        self.actor_optimizer.step()

    def update_target_net(self):
        """Soft-update both target networks toward their online networks."""
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_((1 - self.polyak) * param.data + self.polyak * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_((1 - self.polyak) * param.data + self.polyak * target_param.data)

    def select_action(self, state):
        """Return the actor's action for one state as a flat numpy array.

        Args:
            state: array-like with a `reshape` method (e.g. a numpy array);
                it is reshaped to a single row before the forward pass.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    def train(self, _env):
        """Train the agent on `_env`. Not implemented yet."""
        # TODO: implement the training loop. The original had no body here,
        # which made the whole class a syntax error.
        raise NotImplementedError("DDPG.train is not implemented yet")

# Training

In [None]:
# Replay-buffer capacity and minibatch size handed to the agent; tune as needed.
MEMORY_CAPACITY = 1000000
BATCH_SIZE = 64

env = gym.make("LunarLanderContinuous-v2")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Upper bound of the action space as a float tensor on the training device, so the
# actor's output scaling does not rely on mixed numpy/tensor arithmetic.
sup_action = torch.tensor(env.action_space.high, dtype=torch.float32, device=device)
# DDPG requires these five arguments; the original call passed none (TypeError).
algo = DDPG(state_dim, action_dim, sup_action, MEMORY_CAPACITY, BATCH_SIZE)

In [29]:
# Scratch check of the list operations ReplayBuffer uses: pop() to drop the
# last element, insert(0, ...) to put a new element at the front, and
# random.sample to draw without replacement.
a = [0, 1, 2]
# Removes the last element (2).
a.pop()
print(a)
# Puts -12 at the front of the list.
a.insert(0, -12)
# Bare expression that is not the cell's last line, so nothing is displayed for it.
a
# Last expression of the cell: its value is what the cell displays.
random.sample(a, 2)

[0, 1]


[1, 0]