<a href="https://colab.research.google.com/github/srikarraju/GridWorld/blob/main/Natural_Actor_Critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import gym
from collections import deque
import numpy as np

In [None]:
def generate_state_action_features(state_vec, d, num_actions):
  """Build block-diagonal state-action features for a linear policy.

  Returns a float array of shape ``(num_actions, d * num_actions)``. Row
  ``a`` holds ``state_vec[0:d]`` in columns ``[a*d, (a+1)*d)`` and zeros
  everywhere else, so ``features[a] . w`` only touches the slice of ``w``
  that belongs to action ``a``.
  """
  features = np.zeros(shape=(num_actions, d * num_actions))
  for row in range(num_actions):
    # Copy the first d state components into this action's column band.
    features[row, row * d:(row + 1) * d] = [state_vec[k] for k in range(d)]
  return features

#print(generate_state_action_features([1,0,0,1,0,0,1,1],8,4))

In [None]:
# Create the environment and initialise the linear critic and actor weights.
env = gym.make('CartPole-v0')

m = env.action_space.n  # number of discrete actions
# Feature dimension = length of the observation vector. Derived from the
# environment rather than hard-coded so the weight shapes always match the
# observations fed to generate_state_action_features.
d = env.observation_space.shape[0]

weights_v = np.zeros(d, dtype=float)  # critic weights: V(s) = weights_v . s
weights_p = np.zeros(d * m, dtype=float)  # actor weights: one d-sized slice per action
print(weights_v)
print(weights_p)

returns = deque(maxlen=100)  # total rewards of the most recent 100 episodes

[0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]


In [None]:
# Online actor-critic training on the environment created above.
#   Critic: linear state value V(s) = weights_v . s, updated with the TD error.
#   Actor:  softmax (Boltzmann) policy with temperature `epsilon` over the
#           block-diagonal features from generate_state_action_features,
#           moved along the policy's score function scaled by the TD error.
# Step sizes decay with the global step counter `t`; the critic's alpha
# (~ t^(-2/3)) decays more slowly than the actor's beta (~ t^(-1)).
alpha_0, beta_0, gamma, epsilon = 0.1, 0.01, 0.95, 0.1
alpha_c, beta_c = 1000, 100000
t = 0  # global step counter across ALL episodes; drives the step-size schedules
n_episode = 1
actions_list = []  # every action taken during training, one entry per step
avg_reward = 0  # running estimate of the average per-step reward
# NOTE(review): `gamma` is unused by this average-reward formulation. If the env
# pays a constant +1 per step (as CartPole does), avg_reward tends to 1 whatever
# the policy and the learning signal is weak; a discounted TD target
# r + gamma * V(s') - V(s) may learn better — worth comparing.

while n_episode <= 50000:
  rewards, states, actions = [], [], []
  state = env.reset()
  if isinstance(state, tuple):  # gym >= 0.26 returns (observation, info)
    state = state[0]
  state = np.asarray(state, dtype=float)

  while True:
    t += 1
    state_action_features = generate_state_action_features(state, d, m)

    # Softmax policy pi(.|s) with temperature epsilon; logits are shifted by
    # their max (and clipped) so np.exp cannot overflow.
    logits = np.dot(state_action_features, weights_p)
    logits -= logits.max()
    probs = np.exp(np.clip(logits / epsilon, -500, 500))
    probs /= probs.sum()

    # np.random.choice cannot fall off the end of the CDF, unlike
    # `np.where(cumsum >= u)[0][0]`, which raises IndexError when rounding
    # leaves cumsum[-1] slightly below u.
    action = int(np.random.choice(m, p=probs))

    step_result = env.step(action)
    if len(step_result) == 5:
      # gym >= 0.26: (observation, reward, terminated, truncated, info)
      new_state, reward, terminated, truncated, info = step_result
      done = terminated or truncated
    else:
      # Older gym: (observation, reward, done, info); a time-limit cut-off
      # is reported via info['TimeLimit.truncated'].
      new_state, reward, done, info = step_result
      terminated = done and not info.get('TimeLimit.truncated', False)
    new_state = np.asarray(new_state, dtype=float)

    value_curr = np.dot(weights_v, state)
    # A genuinely terminal state has value 0; only a time-limit cut-off
    # keeps bootstrapping from the observed next state.
    value_next = 0.0 if terminated else np.dot(weights_v, new_state)

    beta = (beta_0 * beta_c) / (beta_c + t)
    alpha = (alpha_0 * alpha_c) / (alpha_c + t**(2/3))

    # Average-reward estimate, tracked with step size 0.95 * alpha.
    avg_reward = (1 - 0.95 * alpha) * avg_reward + 0.95 * alpha * reward
    # Average-reward TD error: delta = r - avg_reward + V(s') - V(s).
    td_error = reward - avg_reward + value_next - value_curr

    weights_v += alpha * td_error * state

    # Score function of the softmax policy:
    #   grad log pi(a|s) = (phi(s,a) - sum_b pi(b|s) phi(s,b)) / epsilon
    # Using phi(s,a) alone would drop the expectation term and bias the
    # actor update.
    score = (state_action_features[action] - probs @ state_action_features) / epsilon
    weights_p += beta * td_error * score

    states.append(state)
    actions.append(action)
    rewards.append(reward)
    actions_list.append(action)

    state = new_state
    if done:
      break

  returns.append(np.sum(rewards))
  if n_episode % 100 == 0:
    print("Episode: {:6d}\tAvg. Return: {:6.2f}".format(n_episode, np.mean(returns)))
  n_episode += 1

env.close()

Episode:    100	Avg. Return:  15.81
Episode:    200	Avg. Return:  16.79
Episode:    300	Avg. Return:  15.45
Episode:    400	Avg. Return:  14.79
Episode:    500	Avg. Return:  16.76
Episode:    600	Avg. Return:  15.49
Episode:    700	Avg. Return:  15.63
Episode:    800	Avg. Return:  15.41
Episode:    900	Avg. Return:  15.85
Episode:   1000	Avg. Return:  15.18
Episode:   1100	Avg. Return:  15.72
Episode:   1200	Avg. Return:  14.95
Episode:   1300	Avg. Return:  15.24
Episode:   1400	Avg. Return:  17.10
Episode:   1500	Avg. Return:  16.59
Episode:   1600	Avg. Return:  16.34
Episode:   1700	Avg. Return:  15.28
Episode:   1800	Avg. Return:  15.55
Episode:   1900	Avg. Return:  16.94
Episode:   2000	Avg. Return:  15.04
Episode:   2100	Avg. Return:  15.48
Episode:   2200	Avg. Return:  16.82
Episode:   2300	Avg. Return:  16.66
Episode:   2400	Avg. Return:  16.54
Episode:   2500	Avg. Return:  15.46
Episode:   2600	Avg. Return:  15.20
Episode:   2700	Avg. Return:  17.12
Episode:   2800	Avg. Return:

In [None]:
# Dump every action sampled during training (one entry per environment step).
print(f"{actions_list}")

[0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 