In [1]:
import os
import sys

import numpy as np
import pandas as pd

from tqdm import tqdm

import torch
import random

import seaborn as sns
import matplotlib.pyplot as plt

# Make the project root (the parent of this notebook's directory) importable so the
# project-local packages imported below (optimization, utils, models) resolve.
# Inside Jupyter `__file__` is undefined, so fall back to the kernel's working directory.
current_directory = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
parent_directory = os.path.dirname(current_directory)
# Guard the append so re-running this cell does not pile up duplicate sys.path entries.
if parent_directory not in sys.path:
    sys.path.append(parent_directory)

from optimization import functions
from optimization.updater import Updater

from utils import constants, common
from utils.config import Config
from utils.dataset_loader import PolicyDatasetLoader

from models.policy_model import RobotPolicy
from models.reward_model import RewardFunction

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class AgentVPG(nn.Module):
    """Vanilla policy-gradient agent with a categorical policy over discrete actions.

    An MLP maps a flat state vector to one unnormalised logit per action; action
    probabilities are the softmax of those logits over the last dimension.

    Args:
        state_shape: shape of a single observation. Only ``state_shape[0]`` is used as
            the input width, so observations are expected to be 1-D vectors.
        n_actions: number of discrete actions; also the width of the output layer.
        hidden_sizes: widths of the ReLU hidden layers (default ``(128, 64)``).
    """

    def __init__(self, state_shape, n_actions, hidden_sizes=(128, 64)):
        super().__init__()

        self.state_shape = state_shape
        self.n_actions = n_actions

        layers = []
        in_features = state_shape[0]
        for width in hidden_sizes:
            layers += [nn.Linear(in_features=in_features, out_features=width), nn.ReLU()]
            in_features = width
        # Output width must equal n_actions so that predict_probs yields a distribution
        # np.random.choice(self.n_actions, p=...) can sample from. No activation here:
        # logits must be free to go negative before the softmax.
        layers.append(nn.Linear(in_features=in_features, out_features=n_actions))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised action logits of shape ``(..., n_actions)`` for states ``x``."""
        return self.model(x)

    def predict_probs(self, states):
        """Compute action probabilities for a batch of states without tracking gradients.

        Args:
            states: array-like of states, e.g. shape ``(batch, state_shape[0])``.

        Returns:
            float32 ``np.ndarray`` of shape ``(batch, n_actions)``; each row sums to 1.
        """
        # Build the input on whatever device the model's parameters live on.
        device = next(self.parameters()).device
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32, device=device)
        with torch.no_grad():
            logits = self.model(states)
        return F.softmax(logits, dim=-1).cpu().numpy()

    def generate_session(self, env, t_max=1000):
        """Roll out one episode by sampling actions from the current policy.

        Assumes the classic Gym API: ``env.reset()`` returns an observation and
        ``env.step(a)`` returns ``(obs, reward, done, info)``.
        NOTE(review): Gymnasium returns ``(obs, info)`` / a 5-tuple — confirm which API
        the environment in use follows.

        Args:
            env: environment to interact with.
            t_max: maximum number of steps before the rollout is cut off.

        Returns:
            ``(states, actions, rewards)`` lists, aligned so that ``actions[t]`` was taken
            in ``states[t]`` and produced ``rewards[t]``.
        """
        states, actions, rewards = [], [], []
        state = env.reset()

        for _ in range(t_max):
            action_probs = self.predict_probs(np.array([state]))[0]
            action = np.random.choice(self.n_actions, p=action_probs)
            next_state, reward, done, _info = env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)

            state = next_state
            if done:
                break

        return states, actions, rewards