# In [None]:
import gymnasium as gym # type: ignore
import numpy as np # type: ignore
import torch # type: ignore
from torch import nn # type: ignore
import matplotlib.pyplot as plt # type: ignore
import random

class Agent(nn.Module):
    """Actor-critic network with separate (unshared) policy and value MLPs.

    The defaults (8 observation features, 4 discrete actions, 64-unit hidden
    layers) reproduce the sizes this notebook uses for LunarLander; they can be
    overridden to reuse the architecture for other discrete-action tasks.

    Args:
        obs_dim: Number of features in one observation vector.
        n_actions: Number of discrete actions (width of the logit output).
        hidden_dim: Width of both hidden layers in each MLP.
    """

    def __init__(self, obs_dim: int = 8, n_actions: int = 4, hidden_dim: int = 64):
        super().__init__()
        # Policy head: observation -> unnormalized action logits.
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )
        # Value head: observation -> scalar state-value estimate.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Evaluate both heads on one observation or a batch of observations.

        Args:
            x: Observation(s) as a tensor, numpy array, or nested sequence whose
                last dimension is ``obs_dim``.

        Returns:
            ``(action_logits, value)`` with trailing sizes ``n_actions`` and 1.
        """
        # as_tensor avoids a copy when x is already a float32 array/tensor, and
        # casts float64 input, which the float32 Linear weights would reject.
        x = torch.as_tensor(x, dtype=torch.float32)
        action_logits = self.actor(x)
        value = self.critic(x)
        return action_logits, value
    
# ---- PPO hyperparameters ----
TOTAL_RUNS = 2000        # number of training episodes
GAMMA = 0.99             # discount factor in the TD error
EPSILON = 0.2            # clipping range of the PPO probability ratio
C1 = 0.5                 # weight of the value-function loss
C2 = 0.01                # weight of the entropy bonus
NUM_BATCHES = 10         # optimization passes over each episode's data
MINI_BATCH_SIZE = 64     # transitions per gradient step
LAMBDA = 0.95            # GAE decay factor

# ---- Model, optimizer and loss ----
agent = Agent()
optimizer = torch.optim.Adam(agent.parameters(), lr=3e-4)
# Per-element squared error; reduced with .mean() where it is used.
MSELoss = nn.MSELoss(reduction='none')

# ---- Environment and bookkeeping ----
env = gym.make("LunarLander-v3")
rewards_over_time = [0 for _ in range(TOTAL_RUNS)]  # return of each training episode


def get_action(probs):
    """Draw a single action from the policy without tracking gradients.

    Args:
        probs: Unnormalized action scores. Despite the name, these are passed
            to ``Categorical`` as *logits*, not probabilities.

    Returns:
        ``(action, log_prob)``: the sampled action as a Python int and its
        log-probability as a gradient-free tensor.
    """
    with torch.no_grad():
        policy = torch.distributions.Categorical(logits=probs)
        chosen = policy.sample()
        chosen_log_prob = policy.log_prob(chosen)
    return chosen.item(), chosen_log_prob


for i in range(TOTAL_RUNS):
    # Progress report: mean return over the previous 100 episodes.
    if i % 100 == 0 and i != 0:
        print(f"{i / TOTAL_RUNS * 100:.2f}% done: Score {np.mean(rewards_over_time[i-100:i]):.2f}")

    # ---- 1. Roll out one full episode with the current policy. ----
    state, _ = env.reset()
    states = []
    actions = []
    rewards = []
    next_states = []
    terminals = []  # 1.0 only for a true terminal transition (no bootstrap)
    log_probs = []
    values = []
    total_reward = 0

    while True:
        with torch.no_grad():
            probs, value = agent(state)
        action, log_prob = get_action(probs)
        next_state, reward, terminated, truncated, _ = env.step(action)

        total_reward += reward

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        # A time-limit truncation is not a real terminal state: the value of
        # next_state must still be bootstrapped, so only `terminated` zeroes it.
        terminals.append(1.0 if terminated else 0.0)
        log_probs.append(log_prob)
        values.append(value)

        if terminated or truncated:
            break
        state = next_state

    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    actions = torch.tensor(actions)
    log_probs = torch.stack(log_probs)                     # (T,)
    values = torch.cat(values)                             # (T,)
    rewards = torch.tensor(rewards, dtype=torch.float32)   # (T,)
    terminals = torch.tensor(terminals)                    # (T,)
    num_steps = len(states)

    # ---- 2. GAE advantages and critic targets (no gradients needed). ----
    with torch.no_grad():
        _, next_values = agent(next_states)
        next_values = next_values.squeeze(-1)              # (T, 1) -> (T,)

        # One-step TD errors for the whole trajectory at once.
        td_errors = rewards + GAMMA * next_values * (1 - terminals) - values

        # The trajectory is exactly one episode, so the backward recursion
        # starts from zero at the last step and never crosses an episode end.
        advantages = torch.zeros(num_steps, dtype=torch.float32)
        advantage = 0.0
        for j in reversed(range(num_steps)):
            advantage = td_errors[j] + GAMMA * LAMBDA * advantage
            advantages[j] = advantage

        # Critic targets use the un-normalized advantages.
        targets = advantages + values
        # std() of a single element is NaN (unbiased estimator), which would
        # turn the loss and then every parameter into NaN; skip that case.
        if num_steps > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ---- 3. PPO clipped-objective updates over shuffled mini-batches. ----
    for _epoch in range(NUM_BATCHES):  # each pass visits every transition once
        permutation = torch.randperm(num_steps)
        for start in range(0, num_steps, MINI_BATCH_SIZE):
            index = permutation[start:start + MINI_BATCH_SIZE]
            batch_states = states[index]
            batch_advantages = advantages[index]
            batch_actions = actions[index]
            batch_old_log_probs = log_probs[index]
            batch_targets = targets[index]

            logits, batch_values = agent(batch_states)
            batch_values = batch_values.squeeze(1)
            distribution = torch.distributions.Categorical(logits=logits)

            ratio = torch.exp(distribution.log_prob(batch_actions) - batch_old_log_probs)
            clipped_ratio = torch.clamp(ratio, min=1 - EPSILON, max=1 + EPSILON)
            l_clip = torch.min(ratio * batch_advantages, clipped_ratio * batch_advantages).mean()
            l_vf = MSELoss(batch_values, batch_targets).mean()
            l_s = distribution.entropy().mean()

            # Maximize the clipped objective and entropy; minimize value error.
            loss = -l_clip + C1 * l_vf - C2 * l_s
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            optimizer.step()

    rewards_over_time[i] = total_reward

plt.plot(rewards_over_time)

# Uncomment to watch the evaluation episodes:
#env = gym.make("LunarLander-v3", render_mode = "human")

# Average undiscounted return of the trained (still stochastic) policy.
NUM_EVAL_EPISODES = 100
total_reward = 0
with torch.no_grad():
    for _ in range(NUM_EVAL_EPISODES):
        # Reset every episode; stepping a finished environment would make
        # every episode after the first end immediately with a stale state.
        state, _ = env.reset()
        while True:
            probs, _ = agent(state)
            action, _ = get_action(probs)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

print(total_reward / NUM_EVAL_EPISODES)
env.close()