In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import gymnasium as gym

In [3]:
# ---- Environment ----
env = gym.make("Pendulum-v1", render_mode="rgb_array")  # Lightweight continuous control environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = float(env.action_space.high[0])

# ---- Preprocessing ----
def preprocess_state(state):
    return torch.tensor(state, dtype=torch.float32)

# ---- Actor and Critic Networks ----
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_bound):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Tanh()
        )
        self.action_bound = action_bound

    def forward(self, state):
        return self.fc(state) * self.action_bound

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        return self.fc(torch.cat([state, action], dim=-1))

# ---- Replay Buffer ----
class ReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        samples = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)
        return (
            torch.tensor(np.array(states), dtype=torch.float32),
            torch.tensor(np.array(actions), dtype=torch.float32),
            torch.tensor(np.array(rewards), dtype=torch.float32).unsqueeze(-1),
            torch.tensor(np.array(next_states), dtype=torch.float32),
            torch.tensor(np.array(dones), dtype=torch.float32).unsqueeze(-1)
        )

# ---- Train-Test Split ----
train_ratio = 0.8
train_episodes = int(200 * train_ratio)
val_episodes = int(200 * 0.1)
test_episodes = 200 - train_episodes - val_episodes

# ---- TD3 Hyperparameters ----
actor_lr = 1e-3
critic_lr = 1e-3
discount = 0.99
tau = 0.005
policy_noise = 0.2
noise_clip = 0.5
policy_freq = 2
batch_size = 64

# ---- Initialize Networks and Optimizers ----
actor = Actor(state_dim, action_dim, action_bound)
actor_target = Actor(state_dim, action_dim, action_bound)
critic1 = Critic(state_dim, action_dim)
critic1_target = Critic(state_dim, action_dim)
critic2 = Critic(state_dim, action_dim)
critic2_target = Critic(state_dim, action_dim)

actor_target.load_state_dict(actor.state_dict())
critic1_target.load_state_dict(critic1.state_dict())
critic2_target.load_state_dict(critic2.state_dict())

actor_optimizer = optim.Adam(actor.parameters(), lr=actor_lr)
critic1_optimizer = optim.Adam(critic1.parameters(), lr=critic_lr)
critic2_optimizer = optim.Adam(critic2.parameters(), lr=critic_lr)

replay_buffer = ReplayBuffer()

# ---- Utilities ----
def soft_update(target, source):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

# ---- Train Base Model (TD3) ----
def train_td3(episodes):
    for episode in range(episodes):
        state = preprocess_state(env.reset()[0])
        episode_reward = 0.0
        done = False

        while not done:
            action = actor(state).detach().cpu().numpy()
            noise = np.random.normal(0, policy_noise, size=action_dim)
            action = np.clip(action + noise, -action_bound, action_bound)

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            replay_buffer.add(state.numpy(), action, reward, next_state, done)
            state = preprocess_state(next_state)
            episode_reward += reward

            if len(replay_buffer.buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

                with torch.no_grad():
                    noise = (torch.randn_like(actions) * policy_noise).clamp(-noise_clip, noise_clip)
                    next_actions = (actor_target(next_states) + noise).clamp(-action_bound, action_bound)
                    target_q1 = critic1_target(next_states, next_actions)
                    target_q2 = critic2_target(next_states, next_actions)
                    target_q = rewards + discount * (1 - dones) * torch.min(target_q1, target_q2)

                current_q1 = critic1(states, actions)
                current_q2 = critic2(states, actions)

                critic1_loss = nn.MSELoss()(current_q1, target_q)
                critic2_loss = nn.MSELoss()(current_q2, target_q)

                critic1_optimizer.zero_grad()
                critic1_loss.backward()
                critic1_optimizer.step()

                critic2_optimizer.zero_grad()
                critic2_loss.backward()
                critic2_optimizer.step()

                if episode % policy_freq == 0:
                    actor_loss = -critic1(states, actor(states)).mean()
                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    soft_update(actor_target, actor)
                    soft_update(critic1_target, critic1)
                    soft_update(critic2_target, critic2)

        print(f"Episode {episode + 1}/{episodes}: Reward = {episode_reward:.2f}")

train_td3(train_episodes)

# ---- Fine-Tune Model ----
def fine_tune_model(episodes):
    print("Fine-tuning model...")
    train_td3(episodes)

fine_tune_model(val_episodes)

# ---- Evaluate ----
def evaluate_model(episodes):
    total_reward = 0.0
    for episode in range(episodes):
        state = preprocess_state(env.reset()[0])
        episode_reward = 0.0
        done = False

        while not done:
            action = actor(state).detach().cpu().numpy()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = preprocess_state(next_state)
            episode_reward += reward

        total_reward += episode_reward
        print(f"Test Episode {episode + 1}/{episodes}: Reward = {episode_reward:.2f}")

    avg_reward = total_reward / episodes
    print(f"Average Test Reward: {avg_reward:.2f}")

evaluate_model(test_episodes)

# ---- Deploy Policy ----
def deploy_policy():
    print("Deploying policy...")
    for episode in range(3):
        state = preprocess_state(env.reset()[0])
        done = False
        while not done:
            env.render()
            action = actor(state).detach().cpu().numpy()
            next_state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = preprocess_state(next_state)
    env.close()

deploy_policy()

Episode 1/160: Reward = -1146.41
Episode 2/160: Reward = -1072.43
Episode 3/160: Reward = -1398.38
Episode 4/160: Reward = -1858.73
Episode 5/160: Reward = -1566.06
Episode 6/160: Reward = -1588.78
Episode 7/160: Reward = -1498.01
Episode 8/160: Reward = -1484.39
Episode 9/160: Reward = -1565.92
Episode 10/160: Reward = -1306.10
Episode 11/160: Reward = -1478.86
Episode 12/160: Reward = -1402.98
Episode 13/160: Reward = -1501.75
Episode 14/160: Reward = -1334.18
Episode 15/160: Reward = -1463.20
Episode 16/160: Reward = -1580.51
Episode 17/160: Reward = -1551.30
Episode 18/160: Reward = -1538.05
Episode 19/160: Reward = -1574.26
Episode 20/160: Reward = -1139.00
Episode 21/160: Reward = -1154.60
Episode 22/160: Reward = -1163.12
Episode 23/160: Reward = -1294.75
Episode 24/160: Reward = -1211.27
Episode 25/160: Reward = -1113.14
Episode 26/160: Reward = -1002.39
Episode 27/160: Reward = -1020.69
Episode 28/160: Reward = -900.82
Episode 29/160: Reward = -934.16
Episode 30/160: Reward = 