# Задача о Маятнике

Наша цель — обучить несколько моделей для этой [задачи](https://gymnasium.farama.org/environments/classic_control/pendulum/).

## Подготовка данных

In [None]:
import gymnasium as gym

In [None]:
# Create the Pendulum-v1 environment with rgb_array rendering and g overridden to 9.81.
env = gym.make("Pendulum-v1", render_mode="rgb_array", g=9.81)

In [None]:
# Display the environment object; its repr shows the wrapper stack around PendulumEnv.
env

<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>

In [None]:
# Reset with a fixed seed; returns (observation, info).
# NOTE(review): the observation printed by this cell has a third component
# (angular velocity) of about -0.89, which lies outside [-0.7, 0.5], so the
# "low"/"high" option keys do not seem to bound the initial state. PendulumEnv
# presumably reads "x_init"/"y_init" instead — confirm against the Gymnasium source.
env.reset(seed=123, options={"low": -0.7, "high": 0.5})

(array([ 0.4123625 ,  0.91101986, -0.89235795], dtype=float32), {})

In [None]:
!pip install stable_baselines3

Collecting stable_baselines3
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable_baselines3)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable_baselines3)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable_baselines3)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable_baselines3)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3.0,>=2.3->stable_baselines3)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (

## Генеративная модель (предсказание оптимального действия)

### Обучение

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# ⚙️ Настройки
ENV_NAME = "Pendulum-v1"  # Gymnasium environment id
LATENT_SIZE = 32  # Enlarged for a richer latent representation
EPOCHS = 80  # Number of passes over the collected dataset
BATCH_SIZE = 64  # Mini-batch size used by the training loop
LR = 3e-4  # Learning rate passed to the Adam optimizer
# Use the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🎯 Create the environment and read the state/action dimensionality from its spaces
env = gym.make(ENV_NAME)
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]

# 🧱 Улучшенный автоэнкодер с ветвью для действий
class PendulumVAE(nn.Module):
    """Variational autoencoder over states with an additional action head.

    The encoder maps a state to the mean and log-variance of a diagonal
    Gaussian over the latent code. Two heads read the latent code: one
    reconstructs the state, the other predicts an action squashed to
    [-1, 1] by a final Tanh.

    Args:
        state_size: Number of features in a state vector.
        action_size: Number of features in an action vector.
        latent_size: Dimensionality of the latent code.
    """

    def __init__(self, state_size, action_size, latent_size):
        super().__init__()

        # Encoder: state -> 64-dimensional hidden features
        self.encoder = nn.Sequential(
            nn.Linear(state_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2)
        )

        # Latent distribution parameters. `fc_var` is consumed as the
        # log-variance (see `reparameterize`), not the variance itself.
        self.fc_mu = nn.Linear(64, latent_size)
        self.fc_var = nn.Linear(64, latent_size)

        # Decoder head that reconstructs the state from the latent code
        self.state_decoder = nn.Sequential(
            nn.Linear(latent_size, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, state_size)
        )

        # Action head; the final Tanh bounds the output to [-1, 1]
        self.action_predictor = nn.Sequential(
            nn.Linear(latent_size, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, action_size),
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for input ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, logvar):
        """Draw a latent code with the reparameterization trick.

        In training mode returns ``mu + eps * std`` with ``eps ~ N(0, I)``.
        In eval mode returns ``mu`` directly, so inference after
        ``model.eval()`` is deterministic instead of injecting sampling
        noise into the predicted action.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Return ``(state_recon, action_pred)`` computed from latent ``z``."""
        state_recon = self.state_decoder(z)
        action_pred = self.action_predictor(z)
        return state_recon, action_pred

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode both heads.

        Returns:
            Tuple ``(state_recon, action_pred, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        state_recon, action_pred = self.decode(z)
        return state_recon, action_pred, mu, logvar

# 📦 Сбор данных с оптимальными действиями
def get_optimal_action(state):
    """Simplified PD controller.

    Reads ``state`` as ``[cos(theta), sin(theta), angular_velocity]`` and
    returns the torque ``-1.5 * theta - 0.3 * angular_velocity`` clipped to
    the interval [-2.0, 2.0].
    """
    cos_theta, sin_theta, omega = state[0], state[1], state[2]
    # Recover the angle theta from its cosine/sine encoding
    theta = np.arctan2(sin_theta, cos_theta)
    torque = -1.5 * theta - 0.3 * omega
    return np.clip(torque, -2.0, 2.0)

print("📦 Сбор данных...")
states = []
optimal_actions = []
for _ in range(1000):
    # Each iteration is one episode driven entirely by the PD controller
    state, _ = env.reset()
    while True:
        action = get_optimal_action(state)
        # Store the pair (state, action chosen in that state) before stepping
        states.append(state)
        optimal_actions.append(action)
        state, _, terminated, truncated, _ = env.step([action])
        done = terminated or truncated
        if done:
            break
env.close()

# Stack the collected samples into arrays
states = np.asarray(states)
optimal_actions = np.asarray(optimal_actions)

# ⚖️ Нормализация
# States are standardized per feature (zero mean, unit variance).
state_scaler = StandardScaler()
states_normalized = state_scaler.fit_transform(states)

# Actions are mapped linearly onto [-1, 1], the range of the Tanh that ends the
# model's action head. Standardized targets can exceed that range, which would
# make the largest torques in the dataset unreachable for the network.
# MinMaxScaler exposes the same fit_transform / inverse_transform interface.
action_scaler = MinMaxScaler(feature_range=(-1, 1))
actions_normalized = action_scaler.fit_transform(optimal_actions.reshape(-1, 1))

# 🔧 Обучение
# Instantiate the model on the selected device and optimize all of its parameters with Adam.
model = PendulumVAE(state_size, action_size, LATENT_SIZE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Loss function with three components:
# 1. State reconstruction (MSE)
# 2. Action prediction (MSE)
# 3. KL divergence (regularizes the latent space)
def loss_function(recon_state, state, pred_action, action, mu, logvar,
                  kl_weight=0.1):
    """Return the combined VAE training loss as a scalar tensor.

    Args:
        recon_state: Reconstructed states.
        state: Target states, same shape as ``recon_state``.
        pred_action: Predicted actions.
        action: Target actions, same shape as ``pred_action``.
        mu: Latent means; the last dimension indexes latent units.
        logvar: Latent log-variances, same shape as ``mu``.
        kl_weight: Multiplier applied to the KL term.

    Returns:
        ``recon_mse + action_mse + kl_weight * kl``.
    """
    # Both MSE terms are means over all elements
    recon_loss = nn.functional.mse_loss(recon_state, state)
    action_loss = nn.functional.mse_loss(pred_action, action)

    # KL(q(z|x) || N(0, I)): summed over latent units, then averaged over the
    # batch. Summing over the batch as well would make this term grow with
    # batch size while the MSE terms stay constant, letting KL dominate.
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    kl_loss = kl_per_sample.mean()

    return recon_loss + action_loss + kl_weight * kl_loss

# Преобразуем данные в тензоры
# Build float32 tensors from the normalized arrays directly on the target device
states_tensor = torch.tensor(states_normalized, dtype=torch.float32, device=DEVICE)
actions_tensor = torch.tensor(actions_normalized, dtype=torch.float32, device=DEVICE)

# 🚀 Training
print("🚀 Начало обучения...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Reshuffle the sample order at the start of every epoch
    perm = torch.randperm(len(states_tensor))

    for i in range(0, len(states_tensor), BATCH_SIZE):
        batch_idx = perm[i:i+BATCH_SIZE]
        batch_states = states_tensor[batch_idx]
        batch_actions = actions_tensor[batch_idx]

        # Forward pass
        recon_states, pred_actions, mu, logvar = model(batch_states)

        # Loss computation
        loss = loss_function(recon_states, batch_states,
                           pred_actions, batch_actions,
                           mu, logvar)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Logging
    if epoch % 20 == 0:
        # Divide by the number of batches actually processed: floor division of
        # the dataset size misses the partial last batch and is zero when the
        # dataset is smaller than one batch.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch}/{EPOCHS}, Loss: {avg_loss:.4f}")

# 🧪 Evaluation
def test_model(max_steps=100):
    """Roll out the trained model as a policy for up to ``max_steps`` steps.

    A dedicated environment is created with ``render_mode="rgb_array"``. The
    module-level ``env`` is closed after data collection and was created
    without a render mode, so calling ``render()`` on it only emits a warning
    on every step.

    Args:
        max_steps: Maximum number of environment steps to execute.
    """
    model.eval()
    eval_env = gym.make(ENV_NAME, render_mode="rgb_array")
    action_low = eval_env.action_space.low
    action_high = eval_env.action_space.high

    try:
        test_state, _ = eval_env.reset()

        for _ in range(max_steps):
            # Normalize the state with the scaler fitted on the training data
            test_state_norm = state_scaler.transform([test_state])
            test_state_tensor = torch.FloatTensor(test_state_norm).to(DEVICE)

            with torch.no_grad():
                _, pred_action_norm, _, _ = model(test_state_tensor)

            # Map the prediction back to the original action scale
            pred_action = action_scaler.inverse_transform(
                pred_action_norm.cpu().numpy())[0]
            # Keep the action inside the bounds declared by the action space
            pred_action = np.clip(pred_action, action_low, action_high).astype(np.float32)

            # Apply the action
            test_state, _, terminated, truncated, _ = eval_env.step(pred_action)

            # Render the current frame
            eval_env.render()

            if terminated or truncated:
                break
    finally:
        # Release the environment even if the rollout raises
        eval_env.close()

# Announce and run the evaluation rollout with the trained model.
print("🧪 Тестирование модели...")
test_model()

📦 Сбор данных...
🚀 Начало обучения...
Epoch 0/80, Loss: 2.0209
Epoch 20/80, Loss: 2.0001
Epoch 40/80, Loss: 2.0002
Epoch 60/80, Loss: 2.0001
🧪 Тестирование модели...


  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger.warn(
  gym.logger