In [11]:
import random
from collections import deque, namedtuple

import gym
import matplotlib
import numpy as np
import torch
from torch import nn, optim

from environment import Environment
from generic_agent import GenericAgent



In [12]:
# Gym task id, wrapped by the project's Environment class
# (constructing it prints the action/observation space summary shown below).
env_name: str = "CartPole-v1"
env = Environment(env_name)

Environment Name:  CartPole
Action Space Type:  DISCRETE
Observation Space Type:  CONTINUOUS
Observation Space:  Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


In [13]:
# Device utility: prefer Apple-silicon MPS, then CUDA, otherwise fall back to CPU.
# getattr guards torch builds that predate the `torch.backends.mps` module.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device : ", device)

Device :  cpu


In [18]:
# One step of experience. QAgent.train_model treats a `None` next_state as terminal.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward", "terminate"))

class ReplayMemory:
    """Bounded FIFO buffer of `Transition`s used for experience replay.

    Once `capacity` transitions are stored, every new push evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        # A deque with maxlen discards from the left automatically when full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition; positional args follow the `Transition` field order."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Raises ValueError (from `random.sample`) if fewer than `batch_size` are stored.
        """
        drawn = random.sample(self.memory, batch_size)
        return drawn

    def __len__(self):
        stored = len(self.memory)
        return stored

In [19]:
# Shared replay buffer (ReplayMemory's default capacity of 10000); QAgent.train_model samples from this global.
memory = ReplayMemory(capacity=10000)

### Constants

In [16]:
BATCH_SIZE: int = 16  # transitions sampled from `memory` per optimisation step


In [5]:
class DQN(nn.Module):
    """Multilayer perceptron mapping an observation vector to one Q-value per action.

    Layout: Linear(in_feature, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, out_feature). The output layer has no activation, so Q-values
    are unbounded.

    Args:
        in_feature: size of the last dimension of the observation tensor.
        out_feature: number of outputs (one per discrete action).
        hidden: width of both hidden layers (default 128).
    """

    def __init__(self, in_feature, out_feature, hidden=128):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.inp = nn.Linear(in_feature, hidden)
        self.dense1 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, out_feature)

    def forward(self, observation):
        """Return Q-values of shape ``(*observation.shape[:-1], out_feature)``."""
        resp = nn.functional.relu(self.inp(observation))
        resp = nn.functional.relu(self.dense1(resp))
        # The output head is applied exactly once.
        return self.out(resp)
    

In [17]:
class QAgent(GenericAgent):
    """DQN agent with epsilon-greedy action selection and one-step TD training.

    Keeps an online network (``self.dqn``) that is optimised, and a target network
    (``self.target_dqn``) that supplies bootstrap targets. The target network tracks
    the online weights by Polyak averaging after every training step.

    Args:
        env: environment handed to ``GenericAgent.__init__``.
        ep: accepted for interface compatibility; not used by this class.
        gamma: discount factor applied to the bootstrap term.
        tau: Polyak-averaging rate for the target-network update.
        lr: AdamW learning rate (1e-3 matches AdamW's own default).
    """

    # Exploration schedule: the probability of taking a random action decays
    # exponentially from EPS_START towards EPS_END, with time constant EPS_DECAY
    # counted in calls to get_action.
    EPS_START, EPS_END, EPS_DECAY = 0.95, 0.1, 1000

    def __init__(self, env, ep, gamma=0.99, tau=0.005, lr=1e-3):
        super().__init__(env)
        # observation_size / action_size are provided by GenericAgent.
        n_obs, n_actions = self.observation_size[0], self.action_size
        self.dqn = DQN(n_obs, n_actions).to(device)
        # The target network starts as an exact copy of the online network.
        self.target_dqn = DQN(n_obs, n_actions).to(device)
        self.target_dqn.load_state_dict(self.dqn.state_dict())
        self.optim = optim.AdamW(self.dqn.parameters(), lr=lr, amsgrad=True)
        self.criteria = nn.SmoothL1Loss()
        self.gamma = gamma
        self.tau = tau
        self.steps_done = 0

    def get_action(self, observation):
        """Choose an action epsilon-greedily and advance the exploration schedule.

        Args:
            observation: flat array-like observation of length n_obs
                (assumed, e.g. a NumPy array from gym; TODO confirm with caller).

        Returns:
            A 0-dim ``torch.long`` tensor on ``device`` holding the action index.
        """
        eps = self.EPS_END + (self.EPS_START - self.EPS_END) * np.exp(-self.steps_done / self.EPS_DECAY)
        self.steps_done += 1

        if np.random.random() < eps:
            return torch.as_tensor(self.get_random_action(), dtype=torch.long, device=device)

        with torch.no_grad():
            # Add a batch dimension: (n_obs,) -> (1, n_obs).
            obs = torch.as_tensor(np.asarray(observation, dtype=np.float32), device=device).reshape(1, -1)
            # argmax over the action dimension, then drop the batch dimension.
            return self.dqn(obs).argmax(dim=1)[0]

    @staticmethod
    def _to_float_batch(items):
        """Stack array-like observations into a float32 ``(batch, features)`` tensor on ``device``."""
        # NOTE(review): assumes items are NumPy arrays / lists / CPU tensors; CUDA tensors
        # would need torch.stack instead of np.asarray.
        arr = np.asarray(items, dtype=np.float32)
        return torch.as_tensor(arr, device=device).reshape(len(arr), -1)

    def soft_update_target(self, tau=None):
        """Polyak update of the target network: theta_target <- tau*theta + (1 - tau)*theta_target."""
        tau = self.tau if tau is None else tau
        with torch.no_grad():
            for target_param, param in zip(self.target_dqn.parameters(), self.dqn.parameters()):
                target_param.mul_(1.0 - tau).add_(param, alpha=tau)

    def train_model(self, replay=None, batch_size=None):
        """Run one DQN optimisation step on a minibatch drawn from the replay buffer.

        Args:
            replay: buffer with ``sample(n)`` and ``len()``; defaults to the global ``memory``.
            batch_size: minibatch size; defaults to the global ``BATCH_SIZE``.

        Returns:
            The scalar loss as a float, or ``None`` if the buffer holds fewer than
            ``batch_size`` transitions (no update is performed).
        """
        replay = memory if replay is None else replay
        batch_size = BATCH_SIZE if batch_size is None else batch_size
        if batch_size <= 0 or len(replay) < batch_size:
            return None

        # List of Transitions -> Transition of per-field tuples.
        batch = Transition(*zip(*replay.sample(batch_size)))

        # A None next_state marks a terminal transition, which contributes no bootstrap term.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state], dtype=torch.bool, device=device
        )
        states = self._to_float_batch(batch.state)
        actions = torch.tensor(
            [int(a) for a in batch.action], dtype=torch.long, device=device
        ).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(batch.reward, dtype=np.float32), device=device).reshape(-1)

        # Q(s, a) for the actions actually taken: (batch, 1).
        state_action_values = self.dqn(states).gather(1, actions)

        # max_a' Q_target(s', a'), left at zero for terminal transitions: (batch,).
        next_state_values = torch.zeros(batch_size, device=device)
        if non_final_mask.any():
            non_final_next_states = self._to_float_batch(
                [s for s in batch.next_state if s is not None]
            )
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_dqn(non_final_next_states).max(1).values
        expected_state_action_values = rewards + self.gamma * next_state_values

        # Match shapes explicitly: (batch, 1) vs (batch, 1); avoids (batch, batch) broadcasting.
        loss = self.criteria(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optim.zero_grad()
        loss.backward()
        # Gradient clipping keeps single noisy TD targets from blowing up the update.
        nn.utils.clip_grad_norm_(self.dqn.parameters(), 100)
        self.optim.step()

        self.soft_update_target()
        return loss.item()


In [9]:
# Scratch check: argmax() without `dim` reduces over the flattened tensor to a 0-dim tensor; .item() unwraps it.
torch.argmax(torch.tensor([1, 2, 3])).item()

2