In [8]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

In [9]:
# Single environment instance shared by the training, fine-tuning, evaluation
# and deployment code in this file.
# NOTE(review): with render_mode="rgb_array", env.render() is expected to
# return image frames rather than open a window - confirm against the
# Gymnasium docs.
env = gym.make("Pendulum-v1", render_mode="rgb_array")

def preprocess_state(state):
    """Convert a single observation into a float32 tensor with a batch axis.

    Args:
        state: Array-like observation (anything ``torch.tensor`` accepts).

    Returns:
        A ``torch.float32`` tensor with a new leading dimension of size 1.
    """
    state_tensor = torch.tensor(state, dtype=torch.float32)
    # Indexing with None inserts the leading batch dimension.
    return state_tensor[None]

# Episode budget: 100 episodes split 80% / 10% for training / validation,
# with whatever remains assigned to testing.
train_episodes, val_episodes = int(100 * 0.8), int(100 * 0.1)
test_episodes = 100 - (train_episodes + val_episodes)

class Actor(nn.Module):
    """Deterministic policy network mapping a state to a bounded action.

    The network is three fully connected layers with ReLU activations and a
    final ``tanh``, so the raw output lies in ``[-1, 1]``. That output is
    multiplied by ``max_action`` so the policy can cover action spaces whose
    bounds are not ``[-1, 1]``.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dim: Size of the last dimension of the produced action tensor.
        hidden_dim: Width of the two hidden layers (default 128).
        max_action: Scale applied to the tanh output; actions lie in
            ``[-max_action, max_action]``. The default of 1.0 leaves the
            tanh output unchanged.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, max_action=1.0):
        super().__init__()
        self.max_action = max_action
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim), nn.Tanh()
        )

    def forward(self, state):
        """Return the action for ``state``, scaled to ``±max_action``."""
        return self.max_action * self.net(state)

class Critic(nn.Module):
    """Q-value network that scores a (state, action) pair with one scalar.

    Args:
        state_dim: Size of the last dimension of the state tensor.
        action_dim: Size of the last dimension of the action tensor.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state, action):
        """Concatenate state and action on the last axis and return Q."""
        return self.net(torch.cat((state, action), dim=-1))

# Problem dimensions are read from the environment's observation and action
# spaces (first entry of each shape).
state_dim, action_dim = (
    env.observation_space.shape[0],
    env.action_space.shape[0],
)

# One policy network plus two independent Q-networks, built in this order.
actor = Actor(state_dim, action_dim)
critic_1, critic_2 = Critic(state_dim, action_dim), Critic(state_dim, action_dim)

# The actor has its own optimizer; both critics share a single Adam instance.
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(
    [*critic_1.parameters(), *critic_2.parameters()],
    lr=1e-3,
)

# Bounded FIFO of (state, action, reward, next_state, done) transitions;
# once 5000 entries are held, the oldest are dropped automatically.
replay_buffer = deque(maxlen=5_000)

def sample_action(state):
    """Return the actor's action for one observation as a NumPy array.

    The observation is batched with ``preprocess_state``, passed through the
    global ``actor`` without tracking gradients, and the batch axis is removed
    again before converting to NumPy. No exploration noise is added.
    """
    observation = preprocess_state(state)
    with torch.no_grad():
        batched_action = actor(observation)
    return batched_action.squeeze(0).numpy()

def update_sac(batch_size=64, gamma=0.99):
    """Run one critic step and one actor step on a random minibatch.

    Despite the name, the update implemented here is deterministic: the actor
    output is used directly (no sampling, no entropy term) and the bootstrap
    target is computed from the same ``critic_1`` / ``critic_2`` that are
    being trained (no separate target networks).
    NOTE(review): the absence of target networks and exploration noise may
    hurt stability and final performance - consider adding them.

    Args:
        batch_size: Number of transitions sampled from ``replay_buffer``.
            Nothing happens while the buffer holds fewer transitions.
        gamma: Discount factor applied to the bootstrapped next-state value.

    Returns:
        None. The global networks are updated in place via their optimizers.
    """
    if len(replay_buffer) < batch_size:
        return

    batch = random.sample(replay_buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack through NumPy first so each field becomes one tensor per batch.
    states = torch.tensor(np.array(states), dtype=torch.float32)
    actions = torch.tensor(np.array(actions), dtype=torch.float32)
    rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)

    # Bootstrap target: the smaller of the two critics' next-state estimates,
    # masked out wherever the stored flag marks the transition as final.
    with torch.no_grad():
        next_actions = actor(next_states)
        target_q1 = critic_1(next_states, next_actions)
        target_q2 = critic_2(next_states, next_actions)
        target_q = rewards + gamma * torch.min(target_q1, target_q2) * (1 - dones)

    current_q1 = critic_1(states, actions)
    current_q2 = critic_2(states, actions)

    # Sum of the two mean-squared errors; one optimizer steps both critics.
    critic_loss = ((current_q1 - target_q) ** 2).mean() + ((current_q2 - target_q) ** 2).mean()

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # Policy improvement: maximise critic_1's value of the actor's own action.
    actor_loss = -critic_1(states, actor(states)).mean()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

# Main training phase: act with the current policy, store every transition,
# and run one update_sac() step per environment step.
for episode in range(train_episodes):
    state, _ = env.reset()
    done = False
    while not done:
        action = sample_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode loop stops on either signal ...
        done = terminated or truncated
        # ... but only `terminated` is stored as the buffer's "done" flag.
        # update_sac() multiplies the bootstrapped value by (1 - done); a
        # time-limit truncation is not a real terminal state, so the target
        # must still bootstrap from next_state in that case.
        replay_buffer.append((state, action, reward, next_state, terminated))
        state = next_state
        update_sac()

def fine_tune(episodes=10):
    """Continue training the global actor/critics for extra episodes.

    Each environment step appends one transition to ``replay_buffer`` and
    triggers one ``update_sac()`` call.

    Args:
        episodes: Number of additional episodes to run.
    """
    for episode in range(episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action = sample_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # The episode loop stops on either signal ...
            done = terminated or truncated
            # ... but only `terminated` is stored as the buffer's "done" flag.
            # update_sac() multiplies the bootstrapped value by (1 - done); a
            # time-limit truncation is not a real terminal state, so the
            # target must still bootstrap from next_state in that case.
            replay_buffer.append((state, action, reward, next_state, terminated))
            state = next_state
            update_sac()

fine_tune()

def evaluate(episodes=5):
    """Roll out the current policy and print the mean undiscounted return.

    No transitions are stored and no network updates are performed.

    Args:
        episodes: Number of evaluation episodes to average over.
    """
    returns = []
    for _ in range(episodes):
        obs, _ = env.reset()
        episode_return = 0
        finished = False
        while not finished:
            obs, reward, terminated, truncated, _ = env.step(sample_action(obs))
            episode_return += reward
            finished = terminated or truncated
        returns.append(episode_return)
    print("Average Reward:", np.mean(returns))

evaluate()

def deploy_policy(episodes=3):
    """Run the trained policy for a few rendered episodes, then close the env.

    The global ``env`` is closed when this function exits, including when an
    episode raises part-way through, so it cannot be reused afterwards.

    NOTE(review): the value returned by ``env.render()`` is discarded here;
    with ``render_mode="rgb_array"`` that is presumably the frame itself, so
    nothing is displayed - confirm whether frames should be collected.

    Args:
        episodes: Number of episodes to roll out.
    """
    try:
        for _ in range(episodes):
            state, _ = env.reset()
            done = False
            while not done:
                env.render()
                action = sample_action(state)
                next_state, _, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                state = next_state
    finally:
        # Release the environment's resources even if a rollout fails.
        env.close()

deploy_policy()

Average Reward: -1332.191177700093
