In [None]:
!pip install gymnasium[mujoco]
!pip install torch

Collecting gymnasium[mujoco]
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium[mujoco])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting mujoco>=2.3.3 (from gymnasium[mujoco])
  Downloading mujoco-3.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco>=2.3.3->gymnasium[mujoco])
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.8/211.8 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: glfw, farama-notifications, gymnasium, mujoco
Successfully instal

In [None]:
import sys

import gymnasium as gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torch.distributions import MultivariateNormal
from torch.distributions import Normal

In [None]:
class PolicyNetwork(nn.Module):
    """Diagonal-Gaussian policy for continuous (Box) action spaces.

    ``forward`` returns the per-dimension action mean and standard deviation.
    The mean is squashed with ``tanh`` into (-1, 1). This assumes action bounds
    of [-1, 1], as for Ant-v4; rescale the output for other environments. The
    std is a state-independent learned parameter stored in log space.

    Args:
        state_dim: size of the observation vector.
        action_dim: number of action dimensions.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        # Softmax was wrong here: it forces the means to be positive and to sum
        # to 1, which makes most of a continuous action space unreachable.
        # tanh keeps each dimension's mean independent and bounded in (-1, 1).
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )
        # log std starts at 0, i.e. std = 1 in every action dimension.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, x):
        """Return ``(mean, std)`` for ``x`` of shape ``(..., state_dim)``.

        Both outputs have shape ``(..., action_dim)``. The std is broadcast
        from the shared ``log_std`` parameter.
        """
        mean = self.fc(x)
        std = self.log_std.exp().expand_as(mean)
        return mean, std

class ValueNetwork(nn.Module):
    """State-value critic mapping an observation to a scalar estimate V(s).

    Args:
        state_dim: size of the observation vector.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return V(x) with shape ``(..., 1)`` for ``x`` of shape ``(..., state_dim)``."""
        estimate = self.fc(x)
        return estimate

In [None]:
def compute_gae(rewards, masks, values, gamma=0.99, lam=0.95):
    """Compute GAE(lambda) returns (advantage + value baseline) per step.

    Args:
        rewards: per-step rewards, length T.
        masks: per-step continuation flags, length T (0 stops bootstrapping
            from the next state's value, 1 continues it).
        values: value estimates, length T + 1; the last entry is the
            bootstrap value of the state following step T - 1.
        gamma: discount factor.
        lam: GAE smoothing parameter.

    Returns:
        List of length T whose entry t is ``A_t + values[t]``.
    """
    n_steps = len(rewards)
    returns = [None] * n_steps
    running_adv = 0
    # Walk backwards so each advantage can reuse the one after it.
    for t in range(n_steps - 1, -1, -1):
        not_done = masks[t]
        td_error = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        running_adv = td_error + gamma * lam * not_done * running_adv
        returns[t] = running_adv + values[t]
    return returns

In [None]:
# Environment, hyperparameters, networks and optimisers.
SEED = 0  # change for a different (but still reproducible) run

# Seed every RNG the notebook relies on so Restart & Run All reproduces the
# results: torch (network init and action sampling) and numpy.
torch.manual_seed(SEED)
np.random.seed(SEED)

env = gym.make('Ant-v4')
# Seeding reset() once seeds the env's RNG; later un-seeded resets continue
# from that seeded stream. action_space.seed covers any action_space.sample().
env.reset(seed=SEED)
env.action_space.seed(SEED)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
hidden_dim = 256
gamma = 0.99  # discount factor
lam = 0.95    # GAE lambda
policy_lr = 1e-3
value_lr = 1e-3

policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim)
value_net = ValueNetwork(state_dim, hidden_dim)
optimizer_policy = optim.Adam(policy_net.parameters(), lr=policy_lr)
optimizer_value = optim.Adam(value_net.parameters(), lr=value_lr)

In [None]:
# One actor-critic update per episode: roll out up to MAX_STEPS_PER_EPISODE
# steps with the current Gaussian policy, compute GAE returns, then take a
# single gradient step on the policy network and on the value network.
NUM_EPISODES = 1000
MAX_STEPS_PER_EPISODE = 200
LOG_EVERY = 10

for episode in range(NUM_EPISODES):
    state, _ = env.reset()
    log_probs = []
    values = []
    rewards = []
    masks = []
    entropy = 0.0  # summed per-step policy entropy, used for logging only

    for _ in range(MAX_STEPS_PER_EPISODE):
        state = torch.as_tensor(state, dtype=torch.float32)
        mean, std = policy_net(state)
        # The action space is continuous, so sample straight from a diagonal
        # Gaussian. Categorical is for discrete actions and cannot wrap a Normal.
        dist = Normal(mean, std)
        action = dist.sample()
        value = value_net(state)

        # env.step needs one value per action dimension, so pass the whole
        # vector (.item() only works on single-element tensors). Clip to the
        # Box bounds; log_prob is still taken on the unclipped sample.
        env_action = np.clip(action.numpy(), env.action_space.low, env.action_space.high)
        next_state, reward, terminated, truncated, info = env.step(env_action)

        # Action dimensions are independent, so the joint log-prob and the
        # joint entropy are sums over the last axis.
        log_probs.append(dist.log_prob(action).sum(-1))
        entropy += dist.entropy().sum(-1).item()

        values.append(value)
        rewards.append(reward)
        # Only a real termination stops bootstrapping. After a time-limit
        # truncation, V(next_state) is still a valid tail estimate.
        masks.append(1.0 - float(terminated))

        state = next_state
        if terminated or truncated:
            break

    # Bootstrap value of the state after the last step. GAE ignores it when
    # the episode terminated, because that step's mask is 0.
    with torch.no_grad():
        next_value = value_net(torch.as_tensor(next_state, dtype=torch.float32))
    values.append(next_value)

    returns = compute_gae(rewards, masks, values, gamma=gamma, lam=lam)
    returns = torch.cat(returns).detach()
    log_probs = torch.stack(log_probs)  # each entry is a 0-d tensor
    values = torch.cat(values[:-1])

    advantage = returns - values

    policy_loss = -(log_probs * advantage.detach()).mean()
    value_loss = advantage.pow(2).mean()

    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    optimizer_value.zero_grad()
    value_loss.backward()
    optimizer_value.step()

    if episode % LOG_EVERY == 0:
        print(f'Episode {episode}, Policy Loss: {policy_loss.item()}, Value Loss: {value_loss.item()}, '
              f'Mean Entropy: {entropy / len(rewards):.3f}')

0


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
