In [1]:
!pip install gymnasium[mujoco]
!pip install torch

Collecting gymnasium[mujoco]
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium[mujoco])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting mujoco>=2.3.3 (from gymnasium[mujoco])
  Downloading mujoco-3.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco>=2.3.3->gymnasium[mujoco])
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.8/211.8 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: glfw, farama-notifications, gymnasium, mujoco
Successfully insta

In [14]:
import gymnasium as gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import MultivariateNormal

import sys

# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [15]:
# Model definitions
class PolicyNetwork(nn.Module):
    """Gaussian policy: a shared 64-64 ReLU trunk with two linear heads.

    For an observation tensor the network returns two tensors whose last
    dimension is ``action_dim``: ``mean`` and ``log_std``. The module moves
    itself to the notebook-level ``device`` when it is constructed.

    Args:
        obs_dim: Size of the observation vector fed to the first layer.
        action_dim: Size of each output head.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        hidden_units = 64
        # Layers are created in a fixed order so seeded initialisation is stable.
        self.fc1 = nn.Linear(obs_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc_mean = nn.Linear(hidden_units, action_dim)
        self.fc_log_std = nn.Linear(hidden_units, action_dim)
        self.to(device)

    def forward(self, x):
        """Return ``(mean, log_std)`` computed from observation tensor ``x``."""
        features = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return self.fc_mean(features), self.fc_log_std(features)

class ValueNetwork(nn.Module):
    """State-value estimator: a 64-64 ReLU MLP with a single linear output.

    The module moves itself to the notebook-level ``device`` when it is
    constructed.

    Args:
        obs_dim: Size of the observation vector fed to the first layer.
    """

    def __init__(self, obs_dim):
        super().__init__()
        hidden_units = 64
        # Layers are created in a fixed order so seeded initialisation is stable.
        self.fc1 = nn.Linear(obs_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc_value = nn.Linear(hidden_units, 1)
        self.to(device)

    def forward(self, x):
        """Return the value estimate for ``x``; the last dimension has size 1."""
        features = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return self.fc_value(features)

In [16]:
# Gaussian action sampling
def get_action_and_log_prob(state, policy):
    """Sample an action from the policy's Gaussian and return it with its log-probability.

    ``policy(state)`` must return a ``(mean, log_std)`` pair. ``exp(log_std)``
    is placed on the diagonal of the matrix given to ``MultivariateNormal`` as
    its covariance matrix, so it acts as the per-dimension variance of the
    sampling distribution. ``ppo_update`` builds its distribution the same way.

    No gradient mode is set here; the caller decides whether autograd tracks
    the policy forward pass.

    Args:
        state: Observation tensor passed straight to ``policy``.
        policy: Callable returning ``(mean, log_std)`` for ``state``.

    Returns:
        Tuple ``(action, log_prob)``: the sampled action and its log-density
        under the distribution it was drawn from.
    """
    mean, log_std = policy(state)
    covariance = torch.diag_embed(torch.exp(log_std))
    gaussian = MultivariateNormal(loc=mean, covariance_matrix=covariance)
    sampled = gaussian.sample()
    return sampled, gaussian.log_prob(sampled)

In [17]:
# PPO update
def ppo_update(policy, value_net, optimizer_policy, optimizer_value, states, actions, rewards, old_log_probs, advantages, k_epochs=None, clip_eps=None):
    """Run several epochs of clipped-surrogate PPO on one batch of rollout data.

    Each epoch recomputes the action log-probabilities under the current
    policy, takes one optimizer step on the clipped surrogate loss, and takes
    one optimizer step on the mean squared error between the value estimates
    and ``rewards``.

    Args:
        policy: Module returning ``(mean, log_std)`` for a batch of states.
        value_net: Module returning one value per state (trailing dim of 1).
        optimizer_policy: Optimizer over the policy parameters.
        optimizer_value: Optimizer over the value-network parameters.
        states: Batch of observations collected during the rollout.
        actions: Actions taken in ``states``.
        rewards: Regression targets for the value network, one per state
            (the training loop passes discounted returns here).
        old_log_probs: Log-probabilities of ``actions`` recorded at sampling time.
        advantages: Advantage estimate for each step.
        k_epochs: Number of optimisation epochs. Defaults to the notebook-level
            ``K_epochs`` when omitted.
        clip_eps: Clipping range for the probability ratio. Defaults to the
            notebook-level ``epsilon_clip`` when omitted.

    Returns:
        None. The networks are updated in place through their optimizers.
    """
    # Fall back to the notebook-level hyperparameters when not given explicitly.
    if k_epochs is None:
        k_epochs = K_epochs
    if clip_eps is None:
        clip_eps = epsilon_clip

    # Rollout statistics are constants for the update. Detaching guards against
    # a caller passing tensors still attached to a graph, which would fail on
    # the second epoch's backward pass.
    old_log_probs = old_log_probs.detach()
    advantages = advantages.detach()
    rewards = rewards.detach()

    for _ in range(k_epochs):
        mean, log_std = policy(states)
        std = log_std.exp()
        # Same parametrisation as get_action_and_log_prob: exp(log_std) is the
        # diagonal of the covariance matrix, so the ratio starts at exactly 1.
        dist = MultivariateNormal(mean, torch.diag_embed(std))
        new_log_probs = dist.log_prob(actions)
        ratio = (new_log_probs - old_log_probs).exp()

        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages

        policy_loss = -torch.min(surrogate1, surrogate2).mean()

        # Value loss: drop only the trailing singleton so a one-sample batch
        # keeps its batch dimension.
        values = value_net(states).squeeze(-1)
        value_loss = ((values - rewards) ** 2).mean()

        optimizer_policy.zero_grad()
        policy_loss.backward()
        optimizer_policy.step()

        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

In [26]:
# Environment setup
env = gym.make("Ant-v4")
# Only the first dimension of each space's shape is read; these sizes are
# later passed to the network constructors.
obs_dim, action_dim = (
    space.shape[0] for space in (env.observation_space, env.action_space)
)

In [19]:
# Hyperparameters
gamma = 0.99          # discount factor used when accumulating returns
learning_rate = 3e-4  # Adam step size shared by the policy and value optimizers
epsilon_clip = 0.2    # half-width of the PPO probability-ratio clipping range
K_epochs = 10         # optimisation epochs per call to ppo_update
T_horizon = 2_048     # upper bound on environment steps collected per rollout

In [20]:
# Instantiate the policy and value networks, then give each its own Adam optimizer.
# The policy is built first, then the value network, so seeded weight
# initialisation draws random numbers in a fixed order.
policy = PolicyNetwork(obs_dim=obs_dim, action_dim=action_dim).to(device)
value_net = ValueNetwork(obs_dim=obs_dim).to(device)
optimizer_policy, optimizer_value = (
    optim.Adam(network.parameters(), lr=learning_rate)
    for network in (policy, value_net)
)

In [28]:
# Main training loop
def _discounted_returns(step_rewards, discount, bootstrap_value=0.0):
    """Compute the discounted return-to-go for every step of one rollout.

    Args:
        step_rewards: Per-step rewards in time order.
        discount: Discount factor applied once per step.
        bootstrap_value: Estimated value of the state reached after the last
            step; 0.0 when the rollout ended in a terminal state.

    Returns:
        A list with one return per reward, in time order.
    """
    returns = []
    running = bootstrap_value
    for step_reward in reversed(step_rewards):
        running = step_reward + discount * running
        returns.append(running)
    # Built back-to-front; one reverse avoids the quadratic cost of insert(0, ...).
    returns.reverse()
    return returns


num_episodes = 1000  # one rollout followed by one PPO update per iteration

try:
    for episode in range(num_episodes):
        state, _ = env.reset()
        terminated, truncated = False, False
        rewards, log_probs, states, actions, values = [], [], [], [], []
        episode_reward = 0

        for t in range(T_horizon):
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
            # Rollout quantities are used as constants by the update, so no
            # autograd graph needs to be recorded while collecting them.
            with torch.no_grad():
                action, log_prob = get_action_and_log_prob(state, policy)
                value = value_net(state)

            next_state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())

            states.append(state)
            actions.append(action)
            rewards.append(float(reward))
            log_probs.append(log_prob)
            values.append(value)

            state = next_state
            episode_reward += reward

            if terminated or truncated:
                break

        # Advantages and discounted returns.
        # A rollout that stopped without reaching a terminal state (time-limit
        # truncation, or the T_horizon cap) still has reward ahead of it, so the
        # tail is estimated with the value network instead of being treated as 0.
        if terminated:
            bootstrap_value = 0.0
        else:
            with torch.no_grad():
                final_obs = torch.as_tensor(state, dtype=torch.float32, device=device)
                bootstrap_value = value_net(final_obs).item()

        # Drop only the trailing singleton so a one-step rollout stays 1-D.
        values = torch.stack(values).squeeze(-1)
        discounted_rewards = torch.tensor(
            _discounted_returns(rewards, gamma, bootstrap_value),
            dtype=torch.float32,
            device=device,
        )
        advantages = discounted_rewards - values

        # Advantage normalisation. The standard deviation of a single sample is
        # NaN, so a one-step rollout is left unnormalised.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        # PPO update
        states = torch.stack(states)
        actions = torch.stack(actions)
        old_log_probs = torch.stack(log_probs)
        ppo_update(policy, value_net, optimizer_policy, optimizer_value, states, actions, discounted_rewards, old_log_probs, advantages)

        if episode % 10 == 0:
            print(f"Episode {episode}: Reward {episode_reward}")
finally:
    # Runs on normal completion, on errors and on manual interruption alike,
    # so the environment is always closed.
    env.close()

Episode 0: Reward 641.9183357919492


  logger.warn("Unable to save last video! Did you call close()?")


Episode 10: Reward 692.5380592047582
Episode 20: Reward 771.1388686969024
Episode 30: Reward 786.9480332789469
Episode 40: Reward 781.4883237282395


KeyboardInterrupt: 

In [22]:
import os

# Ask MuJoCo to use the EGL rendering backend.
# NOTE(review): in notebook order this cell comes after the cell that creates
# the first MuJoCo environment. If the backend is picked when MuJoCo is first
# imported, a fresh top-to-bottom run would set this too late — confirm, and
# consider moving it above the first gym.make call.
os.environ.update({"MUJOCO_GL": "egl"})

In [23]:
# Build an environment that renders frames to arrays and wrap it so episodes
# are written as video files under ./videos.
# This rebinds `env`, replacing the environment created in the setup cell.
env = gym.make(id="Ant-v4", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env=env, video_folder="./videos")

In [24]:
# Roll out one episode with the current policy so the video wrapper can record it.
terminated, truncated = False, False
try:
    s, _ = env.reset()
    while not (terminated or truncated):
        s = torch.as_tensor(s, dtype=torch.float32, device=device)
        # Inference only: no autograd graph is needed while acting.
        with torch.no_grad():
            a, _ = get_action_and_log_prob(s, policy)
        s, r, terminated, truncated, _ = env.step(a.cpu().numpy())
finally:
    # Close even when the rollout fails or is interrupted; the recorded output
    # of this notebook shows a warning that the last video could not be saved
    # when close() had not been called.
    env.close()

Moviepy - Building video /content/videos/rl-video-episode-0.mp4.
Moviepy - Writing video /content/videos/rl-video-episode-0.mp4





Moviepy - Done !
Moviepy - video ready /content/videos/rl-video-episode-0.mp4
