<a href="https://colab.research.google.com/github/srikarraju/GridWorld/blob/main/TRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
  """Actor: a one-hidden-layer MLP mapping a state to action probabilities.

  Args:
    state_dim: size of the input state vector.
    hidden_dim: width of the hidden ReLU layer.
    action_dim: number of discrete actions (size of the output).
  """

  def __init__(self,state_dim,hidden_dim,action_dim):
    super(PolicyNetwork,self).__init__()
    self.hidden = nn.Linear(state_dim,hidden_dim)
    self.out = nn.Linear(hidden_dim,action_dim)

  def forward(self,x):
    """Return action probabilities for ``x``.

    ``x`` may be a single state of shape (state_dim,) or a batch of shape
    (batch, state_dim). Probabilities are normalised over the last axis, so
    each output row sums to 1.
    """
    x = F.relu(self.hidden(x))
    # dim=-1 rather than dim=1: an unbatched 1-D state has no dim 1, and for
    # batched 2-D input the two are identical.
    return F.softmax(self.out(x),dim=-1)

class ValueNetwork(nn.Module):
  """Critic: a one-hidden-layer MLP producing one scalar value per state.

  Args:
    state_dim: size of the input state vector.
    hidden_dim: width of the hidden ReLU layer.
  """

  def __init__(self, state_dim, hidden_dim):
    super().__init__()
    # Attribute names stay `hidden` / `out` so state_dict keys are unchanged.
    self.hidden = nn.Linear(state_dim, hidden_dim)
    self.out = nn.Linear(hidden_dim, 1)

  def forward(self, x):
    """Return the value estimate for ``x``; the last dimension has size 1."""
    features = F.relu(self.hidden(x))
    return self.out(features)

In [None]:
import gym
import torch

# Scalar hyperparameters.
hidden_dim = 50  # hidden-layer width shared by actor and critic
gamma = 0.99  # discount factor

env = gym.make('CartPole-v0')

# Network input/output sizes are read from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Actor is built before critic, as originally, so RNG-driven init is unchanged.
policy = PolicyNetwork(state_dim, hidden_dim, action_dim)
value_fn = ValueNetwork(state_dim, hidden_dim)

# One Adam optimizer per network, using Adam's default hyperparameters.
optimizer_pol = torch.optim.Adam(policy.parameters())
optimizer_val_fn = torch.optim.Adam(value_fn.parameters())

In [None]:
from torch.distributions import Categorical

# Episodic one-step actor-critic training loop.
# For each episode: roll out the current policy, then compute TD(0) advantages
#   A_t = r_t + gamma * V(s_{t+1}) * (1 - terminal_t) - V(s_t)
# and take one Adam step on the actor (policy-gradient loss weighted by A_t) and
# one on the critic (MSE to the TD target).
# Uses env, policy, value_fn, optimizer_pol, optimizer_val_fn and gamma from the
# previous cell.
# NOTE: despite the notebook's name, there is no trust-region/KL constraint here.
max_episodes = 5000
n_episode = 0
while n_episode < max_episodes:
  # gym < 0.26: reset() -> obs; gym >= 0.26: reset() -> (obs, info).
  reset_out = env.reset()
  state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
  rewards, states, actions = [], [], []
  next_states, terminals = [], []

  while True:
    # Sampling only; the graph is rebuilt for the whole episode further below.
    with torch.no_grad():
      # float32 to match the network weights; add a batch dim of 1.
      state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
      action_probs = policy(state_t)
      # .item() gives a plain int, which the env's action space accepts.
      curr_action = Categorical(action_probs).sample().item()

    step_out = env.step(curr_action)
    if len(step_out) == 5:
      # gym >= 0.26: (obs, reward, terminated, truncated, info)
      new_state, reward, terminated, truncated, info = step_out
      done = terminated or truncated
    else:
      # gym < 0.26: (obs, reward, done, info)
      new_state, reward, done, info = step_out
      terminated = done

    states.append(state)
    actions.append(curr_action)
    rewards.append(reward)
    next_states.append(new_state)
    # Only true termination zeroes the bootstrap term; truncation still bootstraps.
    terminals.append(float(terminated))

    state = new_state
    if done:
      break

  states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in states])
  next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in next_states])
  actions = torch.tensor(actions)
  rewards = torch.tensor(rewards, dtype=torch.float32)
  terminals = torch.tensor(terminals, dtype=torch.float32)

  V_curr_state = value_fn(states).squeeze(-1)
  with torch.no_grad():
    V_next_state = value_fn(next_states).squeeze(-1)
    td_targets = rewards + gamma * V_next_state * (1.0 - terminals)
  # Detached: the advantage is a fixed weight for the actor loss, not trained through.
  advantage_estimates = (td_targets - V_curr_state).detach()

  probs = policy(states)
  sampler = Categorical(probs)
  logprobs = -sampler.log_prob(actions)
  actor_loss = torch.sum(logprobs*advantage_estimates)
  critic_loss = F.mse_loss(V_curr_state, td_targets)

  optimizer_pol.zero_grad()
  actor_loss.backward()
  optimizer_pol.step()

  optimizer_val_fn.zero_grad()
  critic_loss.backward()
  optimizer_val_fn.step()

  n_episode += 1

