In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import gymnasium as gym

In [3]:
# --- Hyperparameters ---
ENV_NAME = 'Pendulum-v1'
SEED = 42
EPISODES = 5
MAX_STEPS = 200
BATCH_SIZE = 64
GAMMA = 0.99
TAU = 0.005
LR_ACTOR = 3e-4
LR_CRITIC = 3e-4
TRAIN_TEST_SPLIT = 0.8

# Set random seeds
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

# --- Preprocessing ---
def preprocess_state(state):
    return torch.FloatTensor(np.array(state)).unsqueeze(0)

# --- Actor and Critic Networks ---
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU()
        )
        self.mean = nn.Linear(128, action_dim)
        self.log_std = nn.Linear(128, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = self.net(state)
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)
        std = log_std.exp()
        return mean, std

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        def critic_net():
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, 128), nn.ReLU(),
                nn.Linear(128, 128), nn.ReLU(),
                nn.Linear(128, 1)
            )
        self.q1 = critic_net()
        self.q2 = critic_net()

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)

# --- Replay Buffer ---
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.FloatTensor(np.array(actions)),
            torch.FloatTensor(np.array(rewards)).unsqueeze(1),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(np.array(dones)).unsqueeze(1)
        )

# --- SAC Agent ---
class SACAgent:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic = Critic(state_dim, action_dim)
        self.target_critic = Critic(state_dim, action_dim)
        self.target_critic.load_state_dict(self.critic.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)
        self.replay_buffer = ReplayBuffer()
        self.max_action = max_action

    def select_action(self, state):
        with torch.no_grad():
            mean, std = self.actor(state)
            action = mean + std * torch.randn_like(std)
            return action.clamp(-self.max_action, self.max_action).cpu().numpy().flatten()

    def train(self):
        if len(self.replay_buffer.buffer) < BATCH_SIZE:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)

        with torch.no_grad():
            next_mean, next_std = self.actor(next_states)
            next_actions = next_mean + next_std * torch.randn_like(next_std)
            q1_next, q2_next = self.target_critic(next_states, next_actions)
            target_q = rewards + GAMMA * (1 - dones) * torch.min(q1_next, q2_next)

        q1, q2 = self.critic(states, actions)
        critic_loss = nn.MSELoss()(q1, target_q) + nn.MSELoss()(q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        mean, std = self.actor(states)
        new_actions = mean + std * torch.randn_like(std)
        q1_new, _ = self.critic(states, new_actions)
        actor_loss = -q1_new.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
            target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)

# --- Main Workflow ---
env = gym.make(ENV_NAME)
env.reset(seed=SEED)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = SACAgent(state_dim, action_dim, max_action)
train_episodes = int(EPISODES * TRAIN_TEST_SPLIT)
test_episodes = EPISODES - train_episodes

# Train Base Model
for episode in range(train_episodes):
    state, _ = env.reset()
    for _ in range(MAX_STEPS):
        state_tensor = preprocess_state(state)
        action = agent.select_action(state_tensor)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.replay_buffer.add(state, action, reward, next_state, done)
        agent.train()
        state = next_state
        if done:
            break

# Fine-Tune Model
LR_ACTOR /= 2
LR_CRITIC /= 2
agent.actor_optimizer = optim.Adam(agent.actor.parameters(), lr=LR_ACTOR)
agent.critic_optimizer = optim.Adam(agent.critic.parameters(), lr=LR_CRITIC)

# Evaluate
rewards = []
for episode in range(test_episodes):
    state, _ = env.reset()
    total_reward = 0
    for _ in range(MAX_STEPS):
        state_tensor = preprocess_state(state)
        action = agent.select_action(state_tensor)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        state = next_state
        if terminated or truncated:
            break
    rewards.append(total_reward)

print(f"Average Test Reward: {np.mean(rewards):.2f}")

torch.save(agent.actor.state_dict(), 'sac_actor.pth')
print("Policy deployed and saved as 'sac_actor.pth'")

Average Test Reward: -1375.92
Policy deployed and saved as 'sac_actor.pth'
