<a href="https://colab.research.google.com/github/toanpt74/Workflow/blob/main/agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.nn.functional as F
from utils import normalized_columns_initializer, weights_init
import numpy as np
import os
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
import torch.distributions


class ActorCritic(nn.Module):
    """Convolutional actor-critic network with an LSTM cell on top.

    Four stride-2 convolutions produce a feature map that is flattened to
    ``32 * 553`` values per sample and fed to an ``nn.LSTMCell`` with 256
    hidden units. Two linear heads read the LSTM hidden state: one emits a
    single state value, the other emits ``action_space`` logits.
    """

    def __init__(self, num_inputs, action_space, chkpt_dir='./A2C'):
        """Create the layers and initialise their weights.

        Args:
            num_inputs: Number of input channels of the first convolution.
            action_space: Number of logits produced by the actor head.
            chkpt_dir: Directory that holds the checkpoint file
                ``a2c_torch``.
        """
        super(ActorCritic, self).__init__()
        self.checkpoint_file = os.path.join(chkpt_dir, 'a2c_torch')
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # The flattened conv output must hold exactly 32 * 553 features per
        # sample; forward() reshapes to this width.
        self.lstm = nn.LSTMCell(32 * 553, 256)

        num_outputs = action_space

        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)

        # weights_init / normalized_columns_initializer come from the
        # project's `utils` module (not visible here).
        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: Tuple ``(image, (hx, cx))`` holding the convolution input
                and the previous LSTM hidden and cell state.

        Returns:
            Tuple ``(value, logits, (hx, cx))`` with the critic output, the
            actor logits and the updated LSTM state.
        """
        inputs, (hx, cx) = inputs
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(-1, 32 * 553)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        return self.critic_linear(x), self.actor_linear(x), (hx, cx)

    def choose_action(self, inputs, action_dim):
        """Sample ``action_dim`` action indices for a single observation.

        Args:
            inputs: Tuple ``(s, (hx, cx))``; ``s`` gets a leading batch
                dimension added before it is passed to ``forward``.
            action_dim: Number of independent draws taken from the policy
                distribution.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``action`` is
            an int64 CPU tensor of shape ``(1, action_dim)``, ``log_prob`` is
            the log-softmax over all logits, ``entropy`` is the distribution
            entropy with shape ``(batch, 1)`` and ``value`` is the critic
            output.
        """
        s, (hx, cx) = inputs
        # NOTE(review): the updated (hx, cx) is computed here but not
        # returned, so callers cannot carry the recurrent state forward via
        # this method. Left as is because returning it would change the
        # return arity.
        value, logit, (hx, cx) = self.forward((s.unsqueeze(0), (hx, cx)))
        prob = F.softmax(logit, dim=-1)
        log_prob = F.log_softmax(logit, dim=-1)
        entropy = -(log_prob * prob).sum(1, keepdim=True)

        # Draw all samples in one call (independent draws, i.e. with
        # replacement) from the first row. Staying in torch avoids the NumPy
        # round-trip, which fails for tensors that live on a GPU.
        action = prob[:1].multinomial(num_samples=action_dim, replacement=True)
        action = action.detach().cpu()
        return action, log_prob, entropy, value

    def save_checkpoint(self):
        """Write the state dict to ``self.checkpoint_file``.

        The parent directory is created first so that saving into a fresh
        checkpoint directory does not fail.
        """
        directory = os.path.dirname(self.checkpoint_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``self.checkpoint_file``.

        Tensors are first mapped to the CPU so a checkpoint written on a GPU
        machine can be read on a CPU-only one; ``load_state_dict`` then
        copies them into the existing parameters.
        """
        state = T.load(self.checkpoint_file, map_location='cpu')
        self.load_state_dict(state)


class PPOMemory:
    """Rollout buffer: stores transitions and returns them as NumPy arrays."""

    def __init__(self, batch_size):
        """Create an empty buffer that yields index batches of ``batch_size``."""
        self.batch_size = batch_size

        self.states, self.probs, self.vals = [], [], []
        self.actions, self.rewards, self.dones = [], [], []

    @staticmethod
    def _to_array(items, detach):
        """Stack ``items`` into an array, moving tensors to CPU NumPy first.

        When ``detach`` is true, tensors are detached from the autograd graph
        before conversion; non-tensor items are passed through unchanged.
        """
        converted = []
        for item in items:
            if isinstance(item, T.Tensor):
                item = item.cpu()
                if detach:
                    item = item.detach()
                item = item.numpy()
            converted.append(item)
        return np.array(converted)

    def generate_batches(self):
        """Return the stored data plus shuffled index batches.

        Returns:
            Tuple ``(states, actions, probs, vals, rewards, dones, batches)``.
            The first six entries are NumPy arrays built from the stored
            lists; ``batches`` is a list of int64 index arrays that together
            form a random permutation of all stored positions, each at most
            ``batch_size`` long.
        """
        total = len(self.states)
        starts = np.arange(0, total, self.batch_size)
        order = np.arange(total, dtype=np.int64)
        np.random.shuffle(order)
        batches = []
        for start in starts:
            batches.append(order[start:start + self.batch_size])

        # States and actions are converted without detaching; probabilities
        # and values are detached from the graph first.
        return (
            self._to_array(self.states, detach=False),
            self._to_array(self.actions, detach=False),
            self._to_array(self.probs, detach=True),
            self._to_array(self.vals, detach=True),
            np.array(self.rewards),
            np.array(self.dones),
            batches,
        )

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition to the buffer."""
        pairs = (
            (self.states, state),
            (self.actions, action),
            (self.probs, probs),
            (self.vals, vals),
            (self.rewards, reward),
            (self.dones, done),
        )
        for buffer, item in pairs:
            buffer.append(item)

    def clear_memory(self):
        """Drop every stored transition by rebinding fresh empty lists."""
        for name in ('states', 'probs', 'actions', 'rewards', 'dones', 'vals'):
            setattr(self, name, [])


class ActorNetwork(nn.Module):
    """Convolutional policy network with its own Adam optimizer.

    Four stride-2 convolutions are followed by two linear layers that map
    the flattened ``32 * 553`` features to ``n_actions`` outputs. The module
    moves itself to ``cuda:0`` when CUDA is available, otherwise to the CPU.
    """

    def __init__(self, n_actions, input_dims, alpha, chkpt_dir='./PPO'):
        """Build the layers, the optimizer and select the device.

        Args:
            n_actions: Number of outputs of the final linear layer.
            input_dims: Number of input channels of the first convolution.
            alpha: Learning rate of the Adam optimizer.
            chkpt_dir: Directory that holds the checkpoint file
                ``actor_torch_ppo``.
        """
        super(ActorNetwork, self).__init__()
        self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')

        self.conv1 = nn.Conv2d(input_dims, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # The flattened conv output must hold exactly 32 * 553 features per
        # sample; forward() reshapes to this width.
        self.fc1 = nn.Linear(32 * 553, 256)
        self.fc2 = nn.Linear(256, n_actions)
        # Not used by forward(); kept so code that references `.relu` on
        # this module keeps working.
        self.relu = nn.ReLU()

        # weights_init / normalized_columns_initializer come from the
        # project's `utils` module (not visible here).
        self.apply(weights_init)
        self.fc2.weight.data = normalized_columns_initializer(
            self.fc2.weight.data, 0.01)
        self.fc2.bias.data.fill_(0)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return the non-negative action scores for ``state``.

        Args:
            state: Convolution input whose conv output flattens to
                ``32 * 553`` features per sample.

        Returns:
            Tensor of shape ``(-1, n_actions)``; callers in this file apply
            softmax / log-softmax to it.
        """
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(-1, 32 * 553)

        x = F.relu(self.fc1(x))
        # NOTE(review): the ReLU clamps every negative logit to 0 before the
        # caller's softmax, which removes gradient for those actions. Kept
        # unchanged so existing checkpoints behave the same; confirm whether
        # raw logits were intended.
        dist = F.relu(self.fc2(x))
        return dist

    def save_checkpoint(self):
        """Write the state dict to ``self.checkpoint_file``.

        The parent directory is created first so that saving into a fresh
        checkpoint directory does not fail.
        """
        directory = os.path.dirname(self.checkpoint_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``self.checkpoint_file``.

        ``map_location`` remaps stored tensors to this module's device, so a
        checkpoint written on a GPU machine can be read on a CPU-only one.
        """
        state = T.load(self.checkpoint_file, map_location=self.device)
        self.load_state_dict(state)


class CriticNetwork(nn.Module):
    """Convolutional state-value network with its own Adam optimizer.

    Four stride-2 convolutions are followed by two linear layers that map
    the flattened ``32 * 553`` features to a single value. The module moves
    itself to ``cuda:0`` when CUDA is available, otherwise to the CPU.
    """

    def __init__(self, input_dims, beta, chkpt_dir='./PPO'):
        """Build the layers, the optimizer and select the device.

        Args:
            input_dims: Number of input channels of the first convolution.
            beta: Learning rate of the Adam optimizer.
            chkpt_dir: Directory that holds the checkpoint file
                ``critic_torch_ppo``.
        """
        super(CriticNetwork, self).__init__()
        self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')

        self.conv1 = nn.Conv2d(input_dims, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # The flattened conv output must hold exactly 32 * 553 features per
        # sample; forward() reshapes to this width.
        self.fc1 = nn.Linear(32 * 553, 256)
        self.fc2 = nn.Linear(256, 1)

        # weights_init / normalized_columns_initializer come from the
        # project's `utils` module (not visible here).
        self.apply(weights_init)
        self.fc2.weight.data = normalized_columns_initializer(
            self.fc2.weight.data, 0.01)
        self.fc2.bias.data.fill_(0)

        self.optimizer = optim.Adam(self.parameters(), lr=beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return the value estimate for ``state``.

        Args:
            state: Convolution input whose conv output flattens to
                ``32 * 553`` features per sample.

        Returns:
            Tensor of shape ``(-1, 1)`` holding one value per sample.
        """
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(-1, 32 * 553)

        x = F.relu(self.fc1(x))
        value = self.fc2(x)

        return value

    def save_checkpoint(self):
        """Write the state dict to ``self.checkpoint_file``.

        The parent directory is created first so that saving into a fresh
        checkpoint directory does not fail.
        """
        directory = os.path.dirname(self.checkpoint_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``self.checkpoint_file``.

        ``map_location`` remaps stored tensors to this module's device, so a
        checkpoint written on a GPU machine can be read on a CPU-only one.
        """
        state = T.load(self.checkpoint_file, map_location=self.device)
        self.load_state_dict(state)


class Agent:
    """PPO agent that owns an actor, a critic and a rollout memory."""

    def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
                 policy_clip=0.2, batch_size=32, n_epochs=4, beta=0.01):
        """Create the networks and the memory.

        Args:
            n_actions: Number of actor outputs.
            input_dims: Number of input channels of both networks.
            gamma: Discount factor used in the advantage computation.
            alpha: Learning rate of the actor optimizer.
            gae_lambda: Decay applied together with ``gamma`` when summing
                temporal-difference errors into advantages.
            policy_clip: Half-width of the clipping interval around 1 for the
                probability ratio.
            batch_size: Size of the index batches produced by the memory.
            n_epochs: Number of passes over the stored data per ``learn``.
            beta: Learning rate of the critic optimizer.
        """
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda

        self.actor = ActorNetwork(n_actions, input_dims, alpha)
        self.critic = CriticNetwork(input_dims, beta)
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition in the rollout memory."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self):
        """Save the actor and critic checkpoints."""
        print('... saving models ...')
        self.actor.save_checkpoint()
        self.critic.save_checkpoint()

    def load_models(self):
        """Load the actor and critic checkpoints."""
        print('... loading models ...')
        self.actor.load_checkpoint()
        self.critic.load_checkpoint()

    def choose_action(self, inputs, action_dim):
        """Sample ``action_dim`` action indices for one observation.

        The forward passes run under ``no_grad``: ``learn`` recomputes the
        network outputs itself, so the tensors returned here would otherwise
        keep an unused autograd graph alive for every stored step.

        Args:
            inputs: Observation convertible to a float tensor; it is passed
                to both networks unchanged.
            action_dim: Number of independent draws taken from the policy
                distribution.

        Returns:
            Tuple ``(action, log_prob, value)``: an int64 CPU tensor of shape
            ``(1, action_dim)``, the log-softmax over all actor outputs, and
            the critic output as a Python float.
        """
        device = self.actor.device
        with T.no_grad():
            s = T.as_tensor(inputs, dtype=T.float, device=device)
            value = self.critic.forward(s)
            logit = self.actor.forward(s)
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)

            # One call draws all samples (independent draws, i.e. with
            # replacement) from the first row; no NumPy round-trip needed.
            action = prob[:1].multinomial(num_samples=action_dim, replacement=True)
        action = action.cpu()
        value = T.squeeze(value).item()
        return action, log_prob, value

    def _compute_advantages(self, rewards, values, dones):
        """Return the discounted sum of TD errors for every stored step.

        For step ``t`` this is ``sum_k (gamma * gae_lambda)**(k - t) * delta_k``
        over ``k = t .. n - 2`` with
        ``delta_k = r_k + gamma * V_{k+1} * (1 - done_k) - V_k``; the last
        step keeps advantage 0. A backward recursion computes it in O(n).
        """
        rewards = np.asarray(rewards, dtype=np.float64).reshape(-1)
        values = np.asarray(values, dtype=np.float64).reshape(-1)
        not_done = 1 - np.asarray(dones).reshape(-1).astype(np.int64)

        n_steps = len(rewards)
        advantage = np.zeros(n_steps, dtype=np.float32)
        decay = self.gamma * self.gae_lambda
        running = 0.0
        # NOTE(review): `done` only masks the bootstrap value; the running
        # sum is not reset at episode ends. Confirm whether advantages are
        # meant to accumulate across episode boundaries.
        for t in range(n_steps - 2, -1, -1):
            delta = rewards[t] + self.gamma * values[t + 1] * not_done[t] - values[t]
            running = delta + decay * running
            advantage[t] = running
        return advantage

    def learn(self):
        """Run ``n_epochs`` of clipped-surrogate updates, then clear memory.

        Stored ``probs`` are expected to be the ``log_prob`` returned by
        ``choose_action`` (log-probabilities over all actions); they are
        gathered at the stored action indices. If the stored width differs
        from the number of actor outputs, the entries are assumed to already
        be per-action log-probabilities and are used as they are.
        """
        state_arr, action_arr, old_prob_arr, vals_arr, \
            reward_arr, dones_arr, batches = \
            self.memory.generate_batches()

        device = self.actor.device
        advantage = T.tensor(
            self._compute_advantages(reward_arr, vals_arr, dones_arr)).to(device)
        # float32 keeps the returns in the same dtype as the critic output.
        values = T.tensor(
            np.asarray(vals_arr, dtype=np.float32).reshape(-1)).to(device)

        for _ in range(self.n_epochs):

            for batch in batches:
                n_samples = len(batch)
                states = T.as_tensor(state_arr[batch], dtype=T.float, device=device)
                actions = T.as_tensor(action_arr[batch], dtype=T.long,
                                      device=device).reshape(n_samples, -1)
                old_log_probs = T.as_tensor(old_prob_arr[batch], dtype=T.float,
                                            device=device).reshape(n_samples, -1)

                logits = self.actor.forward(states)
                critic_value = self.critic.forward(states).reshape(-1)

                # Log-probability of every taken action, per sample.
                new_log_probs = F.log_softmax(logits, dim=-1).gather(1, actions)
                if old_log_probs.shape[1] == logits.shape[1]:
                    old_log_probs = old_log_probs.gather(1, actions)

                batch_advantage = advantage[batch]
                advantage_col = batch_advantage.view(-1, 1)

                prob_ratio = (new_log_probs - old_log_probs).exp()
                weighted_probs = advantage_col * prob_ratio
                weighted_clipped_probs = T.clamp(prob_ratio, 1 - self.policy_clip,
                                                 1 + self.policy_clip) * advantage_col
                actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

                # Both operands are 1-D so the subtraction is element-wise,
                # not a (B, B) broadcast.
                returns = batch_advantage + values[batch]
                critic_loss = ((returns - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()

