In [25]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import cv2

from matplotlib import pyplot as plt

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cpu"
if torch.cuda.is_available():
    _device_name = "cuda"
device = torch.device(_device_name)

print(f"Working with {device}")

Working with cpu


In [26]:
class FCDAP(nn.Module):
    """Fully connected discrete-action policy network.

    The network is a stack of ``Linear -> activation`` pairs (one per entry of
    ``hidden_dims``) followed by a final ``Linear`` layer that emits one logit
    per action. The logits parameterise a ``Categorical`` distribution.

    Args:
        input_dim: Size of the last dimension of the state input.
        output_dim: Number of discrete actions (number of logits produced).
        hidden_dims: Widths of the hidden layers. An empty sequence yields a
            single linear layer from ``input_dim`` to ``output_dim``.
        init_std: Currently unused; kept so existing callers keep working.
        activation_fc: Activation module inserted after every hidden layer.
            The same module instance is reused at every position.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(32, 32), init_std=1, activation_fc = nn.ReLU()):
        super(FCDAP, self).__init__()
        self.activation_fc = activation_fc

        # Collect modules in a local list and register them once, instead of
        # binding a plain Python list to ``self.layers`` and rebinding it later.
        modules = []
        in_features = input_dim
        for width in hidden_dims:
            modules.append(nn.Linear(in_features, width))
            modules.append(self.activation_fc)
            in_features = width
        modules.append(nn.Linear(in_features, output_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, state):
        """Return the action logits for ``state``.

        Args:
            state: A ``torch.Tensor`` or anything ``torch.tensor`` accepts
                (list, NumPy array, ...). Non-tensor inputs are converted to
                ``float32``.

        Returns:
            Tensor of logits whose last dimension has size ``output_dim``.
        """
        # Follow the device the weights actually live on rather than a
        # module-level global, so the model works wherever it has been moved.
        target = next(self.parameters()).device
        if isinstance(state, torch.Tensor):
            x = state.to(target)
        else:
            x = torch.tensor(state, dtype=torch.float32, device=target)

        return self.layers(x)

    def full_pass(self, state):
        """Sample an action and return everything needed for a policy-gradient step.

        Returns:
            Tuple ``(action, is_exploratory, logpa, entropy)`` where ``action``
            is a Python int, ``is_exploratory`` is ``True`` when the sampled
            action differs from the arg-max action, and ``logpa`` / ``entropy``
            are tensors that stay attached to the autograd graph.
        """
        logits = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        logpa = dist.log_prob(action)
        entropy = dist.entropy()
        # Compare on the tensor's own device; no NumPy round trip required.
        is_exploratory = action != logits.detach().argmax(dim=-1)

        return action.item(), is_exploratory.item(), logpa, entropy

    def select_action(self, state):
        """Sample an action from the policy without tracking gradients.

        Returns:
            NumPy array holding the sampled action index (0-d for a single state).
        """
        with torch.no_grad():
            logits = self.forward(state)
            action = torch.distributions.Categorical(logits=logits).sample()
        # ``.cpu()`` is required before ``.numpy()`` when the model is on a GPU.
        return action.cpu().numpy()

    def select_greedy(self, state):
        """Return the index of the highest-logit action without tracking gradients."""
        with torch.no_grad():
            logits = self.forward(state)
        return np.argmax(logits.cpu().numpy())


In [32]:
class REINFORCE:
    """Monte-Carlo policy-gradient (REINFORCE) agent built on an ``FCDAP`` policy.

    Args:
        env: Environment exposing ``observation_space.shape``, ``action_space.n``,
            ``reset()``, ``step(action)`` and ``render()`` (the calls used below).
        gamma: Discount factor applied to future rewards.
        hidden_dims: Hidden-layer widths forwarded to ``FCDAP``.
        optimizer: Optimizer factory called as ``optimizer(params, lr)``.
        lr: Learning rate passed to the optimizer factory.
    """

    def __init__(self, env, gamma, hidden_dims, optimizer, lr = 1e-4):
        self.env = env
        self.gamma = gamma
        self.model = FCDAP(env.observation_space.shape[0], env.action_space.n, hidden_dims).to(device)
        env.reset()

        self.optimizer = optimizer(self.model.parameters(), lr)

        # Per-episode buffers; created here so the step/optimise methods can be
        # called even before ``train`` has run.
        self.rewards, self.logpas = [], []
        self.reward_means = []

    def optimize_model(self):
        """Run one policy-gradient update from the current episode buffers.

        Reads ``self.rewards`` (list of scalars) and ``self.logpas`` (list of
        log-probability tensors) and leaves both untouched. Does nothing when
        the episode buffer is empty.
        """
        T = len(self.rewards)
        if T == 0:
            return

        rewards = np.asarray(self.rewards, dtype=np.float64)
        # gamma^0, gamma^1, ..., gamma^(T-1)
        discounts = np.power(float(self.gamma), np.arange(T))

        # Discounted return from every time step, accumulated backwards in O(T):
        # G_t = r_t + gamma * G_{t+1}.
        returns = np.empty(T, dtype=np.float64)
        running = 0.0
        for t in reversed(range(T)):
            running = rewards[t] + self.gamma * running
            returns[t] = running

        # Flatten so per-step log-probs of shape (1,) cannot broadcast the
        # product below into a (T, T) matrix.
        logpas = torch.stack(self.logpas).reshape(-1)
        weights = torch.tensor(discounts * returns, dtype=torch.float32, device=logpas.device)

        policy_loss = -(weights * logpas).mean()
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

    def interactive_step(self, state):
        """Sample one action, step the environment and record reward and log-prob.

        Returns:
            Tuple ``(new_state, done)`` where ``done`` is true when the
            environment reported termination or truncation.
        """
        action, _, logpa, _ = self.model.full_pass(state)
        new_state, reward, terminated, truncated, _ = self.env.step(action)
        self.rewards.append(reward)
        self.logpas.append(logpa)
        return new_state, terminated or truncated

    def train(self, num_episodes, step):
        """Train for ``num_episodes`` episodes, reporting every ``step`` episodes.

        Every ``step`` episodes the undiscounted episode returns collected since
        the previous report are plotted, the buffer is cleared, and one greedy
        evaluation episode is run.
        """
        self.reward_means = []
        for i in range(1, num_episodes + 1):
            self.rewards, self.logpas = [], []
            done = False
            state = self.env.reset()[0]

            while not done:
                state, done = self.interactive_step(state)

            self.optimize_model()

            # Store a plain float (undiscounted episode return) so plotting
            # does not have to deal with 0-d tensors.
            self.reward_means.append(float(np.sum(self.rewards)))

            if i % step == 0:
                plt.plot(range(len(self.reward_means)), self.reward_means)
                plt.show()
                self.reward_means.clear()
                self.evaluate()

                # torch.save(self.model.state_dict(), f'./{i/step}_iteration.pth')

    def evaluate(self):
        """Play one greedy episode, showing rendered frames in an OpenCV window.

        Returns:
            The undiscounted return of the episode (also printed).
        """
        done = False
        state = self.env.reset()[0]
        returns = 0
        try:
            while not done:
                frame = self.env.render()
                if frame is not None:
                    # Gymnasium's 'rgb_array' frames are RGB while OpenCV
                    # displays BGR, so swap channels before showing.
                    cv2.imshow("Environment", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                    # imshow only draws once the HighGUI event loop is pumped.
                    cv2.waitKey(1)
                action = self.model.select_greedy(state)
                state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                returns += reward
        finally:
            # Close the window even if rendering or stepping raised.
            cv2.destroyAllWindows()

        print(f"Return Obtained: {returns}")
        return returns


In [None]:
# Build the CartPole environment with off-screen RGB rendering, then train a
# REINFORCE agent on it, reporting progress every 500 episodes.
env = gym.make('CartPole-v1', render_mode='rgb_array')
agent = REINFORCE(
    env,
    gamma=1,
    hidden_dims=(5, 10),
    optimizer=optim.Adam,
    lr=5e-5,
)
agent.train(num_episodes=10000, step=500)