In [1]:
import matplotlib.pyplot as plt

from simulators.grid_world.grid_world import Maze, simulate_policy, set_granular_reward
from simulators.grid_world import GAMMA, HORIZON, COMMANDS

# Grid-world environment configuration; `env` is consumed by later cells.
# Alternative feature encoding tried before: "one_hot".
# NOTE(review): "simularity" is passed verbatim to Maze — presumably Maze keys on
# this exact spelling; confirm before correcting it.
grid_type, feature_type = "simple", "simularity"
dimensions, sigma = 100, 0.25

env = Maze(grid_type, feature_type, sigma=sigma, dimensions=dimensions)

In [11]:
import copy 

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical


class SmallNAC(nn.Module):
    """Small Normalized Actor-Critic model for a grid-world environment.

    A single MLP estimates Q(s, a) from the (x, y) coordinates of a state and an
    action id. The value and the policy are derived from Q through the entropy
    temperature ``entropy_weight`` (alpha):

        V(s)      = alpha * log sum_a exp(Q(s, a) / alpha)
        pi(a | s) = exp((Q(s, a) - V(s)) / alpha)

    Args:
        env: environment exposing ``index2coord`` (state index -> (y, x)),
            ``_actions`` (iterable of integer action ids) and ``Na`` (size of the
            action distribution). Presumably ``_actions`` covers ``range(Na)`` —
            TODO confirm.
        entropy_weight: temperature alpha of the soft value / policy; must be > 0.
    """

    def __init__(self, env, entropy_weight):
        super(SmallNAC, self).__init__()
        self.env = env
        self.entropy_weight = entropy_weight

        # Input: 2 features for the state coordinates + 1 for the action.
        # Output: a single scalar, Q(s, a).
        self.net = nn.Sequential(nn.Linear(3, 50), nn.ReLU(), nn.Linear(50, 1))

    def Q(self, state, action):
        """Return Q(state, action) as a tensor of shape (1,)."""
        (y, x) = self.env.index2coord[state]

        return self.net(torch.tensor([x, y, action], dtype=torch.float32))

    def _q_values(self, state):
        """Return Q(state, a) for each a in ``env._actions``, concatenated in that order."""
        return torch.cat([self.Q(state, action) for action in self.env._actions])

    def V(self, state):
        """Return the soft state value V(state) as a tensor of shape (1,).

        ``torch.logsumexp`` keeps this finite when Q / alpha is large in magnitude,
        which would make a naive sum of exponentials overflow to inf / nan.
        """
        scaled_q = self._q_values(state) / self.entropy_weight
        return self.entropy_weight * torch.logsumexp(scaled_q, dim=0, keepdim=True)

    def pi(self, state):
        """Sample an action from the policy at ``state``."""
        return self.pi_distribution(state).sample()

    def pi_distribution(self, state):
        """Return the policy pi(. | state) as a ``Categorical`` over ``env.Na`` slots.

        exp((Q - V) / alpha) with V = alpha * logsumexp(Q / alpha) equals
        softmax(Q / alpha), so Q is evaluated once per action instead of
        recomputing V (itself one Q evaluation per action) for every action.
        Slots of ``range(Na)`` that are not in ``env._actions`` keep probability 0.
        """
        action_probs = torch.softmax(self._q_values(state) / self.entropy_weight, dim=0)
        action_ids = torch.as_tensor(list(self.env._actions), dtype=torch.long)
        # Out-of-place scatter keeps the graph differentiable w.r.t. action_probs.
        pi_actions = torch.zeros(self.env.Na).scatter(0, action_ids, action_probs)

        return Categorical(probs=pi_actions)

    def copy(self):
        """Return an independent deep copy of the model, e.g. for a target network.

        The environment is shared instead of duplicated. The model only reads from
        it, so deep-copying it on every target update would be wasted work.
        """
        return copy.deepcopy(self, memo={id(self.env): self.env})

In [12]:
def get_losses(batch, model, target_model):
    """Compute the batch-averaged NAC actor and critic losses.

    For each transition, with Q_hat = reward + gamma * V_target(next_state):

    * actor loss  : (Q(s, a) - V(s)) * (Q(s, a) - Q_hat), second factor detached.
      Its gradient is (grad Q - grad V)(Q - Q_hat), so a descent step moves
      Q(s, a) toward its TD target.
    * critic loss : V(s) * (V(s) - Q_hat + alpha * entropy), second factor detached.
      This is the gradient of 1/2 (V(s) - V_hat)^2 with V_hat = Q_hat - alpha * entropy.

    Args:
        batch: non-empty sequence of (state, action, reward, next_state, entropy) tuples.
        model: model being trained; ``Q``, ``V``, ``env.gamma`` and ``entropy_weight`` are used.
        target_model: frozen copy of the model; only ``V`` is used, without gradients.

    Returns:
        (actor_loss, critic_loss), each averaged over ``batch``.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("get_losses needs a non-empty batch")

    actor_loss = 0
    critic_loss = 0

    for (state, action, reward, next_state, entropy) in batch:
        q_value = model.Q(state, action)
        v_value = model.V(state)

        with torch.no_grad():
            v_target = target_model.V(next_state)
        q_target = reward + model.env.gamma * v_target

        # Actor: (Q - Q_hat), not (Q_hat - Q). Since the loss is minimized, the
        # latter would push Q away from its target.
        actor_loss += (q_value - v_value) * (q_value.detach() - q_target)

        # Critic. NOTE(review): NAC's value target is usually Q_hat + alpha * H. Here
        # the entropy term enters with the opposite sign. Confirm how the replay
        # buffer defines `entropy` before relying on this.
        critic_loss += v_value * (v_value.detach() - q_target + model.entropy_weight * entropy)

    return actor_loss / len(batch), critic_loss / len(batch)

In [None]:
class ReplayBuffer:
    """Replay buffer for NAC mixing expert demonstrations and RL samples.

    Work in progress: only the constructor is implemented. The sampling methods
    raise ``NotImplementedError`` rather than silently returning ``None``, which
    would only fail later, when the caller tries to iterate over the batch.

    Args:
        n_expert_samples: presumably the number of expert transitions to hold —
            TODO confirm once expert data loading is implemented.
    """

    def __init__(self, n_expert_samples):
        self.n_expert_samples = n_expert_samples

    def collect_rl_sample(self, model):
        """Collect RL transition(s) with ``model``'s policy and store them. Not implemented yet."""
        raise NotImplementedError("ReplayBuffer.collect_rl_sample is not implemented yet")

    def get_batch(self, from_expert=True):
        """Return a batch of (state, action, reward, next_state, entropy) tuples. Not implemented yet."""
        raise NotImplementedError("ReplayBuffer.get_batch is not implemented yet")

In [10]:
import numpy as np


def nac(env, n_expert_samples, n_expert_iterations, n_rl_iterations, update_target_frequecy,
        entropy_weight=1.0, lr=1e-3):
    """Train a SmallNAC model: expert-only updates first, then RL updates.

    Args:
        env: environment passed to ``SmallNAC``.
        n_expert_samples: forwarded to ``ReplayBuffer``.
        n_expert_iterations: number of initial updates on expert batches only.
        n_rl_iterations: number of subsequent updates, each preceded by collecting
            an RL sample with the current model.
        update_target_frequecy: the target network is refreshed from the model
            whenever ``iteration % update_target_frequecy == 0`` (including
            iteration 0). Must be positive.
        entropy_weight: temperature alpha passed to ``SmallNAC``. The default 1.0
            is a placeholder — TODO tune.
        lr: Adam learning rate.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``update_target_frequecy`` is not positive.
    """
    if update_target_frequecy <= 0:
        raise ValueError("update_target_frequecy must be a positive integer")

    # Once stable, ReplayBuffer and SmallNAC are meant to be imported from
    # algorithms.NAC.replay_buffer and algorithms.NAC.model.
    # NOTE(review): no expert policy is wired in yet. The visible ReplayBuffer
    # constructor only takes n_expert_samples.
    replay_buffer = ReplayBuffer(n_expert_samples)
    # SmallNAC's constructor requires the entropy temperature.
    model = SmallNAC(env, entropy_weight)
    target_model = model.copy()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for iteration in range(n_expert_iterations + n_rl_iterations):
        from_expert = iteration < n_expert_iterations

        if not from_expert:
            replay_buffer.collect_rl_sample(model)

        batch = replay_buffer.get_batch(from_expert=from_expert)

        actor_loss, critic_loss = get_losses(batch, model, target_model)
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % update_target_frequecy == 0:
            target_model = model.copy()

    return model