In [5]:
!pip install tensorboard
!pip install gymnasium torch numpy pandas matplotlib

import os, json, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium import spaces

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device =", device)

device = cpu


In [6]:
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from itertools import count

import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class PolicyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc2 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        # x: (B, T, 2)
        x = self.relu(self.fc1(x))
        x, hidden = self.lstm(x, hidden)
        x = self.relu(x)
        x = self.sigmoid(self.fc2(x))  # (B, T, 1)
        return x, hidden

    def select_action(self, state, hidden):
        # state: (1, 1, 2)
        with torch.no_grad():
            prob, hidden = self.forward(state, hidden)  # (1,1,1)
            b = Bernoulli(prob)
            action = b.sample()  # 0/1
        return int(action.item()), hidden


class ValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.lstm = nn.LSTM(64, 256, batch_first=True)
        self.fc2 = nn.Linear(256, 1)
        self.relu = nn.ReLU()

    def forward(self, x, hidden):
        x = self.relu(self.fc1(x))
        x, hidden = self.lstm(x, hidden)
        x = self.relu(x)
        x = self.fc2(x)  # (B,T,1)
        return x, hidden


def obs_to_partial(obs):
    # CartPole obs: [x, x_dot, theta, theta_dot]
    # keep only x, theta
    return np.array([obs[0], obs[2]], dtype=np.float32)


if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    policy = PolicyNetwork().to(device)
    value = ValueNetwork().to(device)

    optim = torch.optim.Adam(policy.parameters(), lr=1e-4)
    value_optim = torch.optim.Adam(value.parameters(), lr=3e-4)

    gamma = 0.99
    writer = SummaryWriter("./lstm_logs")

    for epoch in count():
        obs, info = env.reset(seed=None)
        state = obs_to_partial(obs)
        episode_reward = 0.0

        # LSTM hidden init
        a_hx = torch.zeros((1, 1, 128), device=device)
        a_cx = torch.zeros((1, 1, 128), device=device)

        rewards = []
        actions = []
        states = []

        for t in range(500):  # CartPole-v1 max is typically 500
            states.append(state.copy())

            state_t = torch.tensor(state, dtype=torch.float32, device=device).view(1, 1, 2)
            action, (a_hx, a_cx) = policy.select_action(state_t, (a_hx, a_cx))
            actions.append(action)

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            next_state = obs_to_partial(next_obs)
            episode_reward += float(reward)

            rewards.append(float(reward))
            state = next_state

            if done:
                break

        # returns
        returns = np.zeros(len(rewards), dtype=np.float32)
        R = 0.0
        for i in reversed(range(len(rewards))):
            R = gamma * R + rewards[i]
            returns[i] = R

        # normalize returns (stability)
        mean, std = returns.mean(), returns.std()
        std = std if std > 1e-8 else 1.0
        returns = (returns - mean) / std

        # tensors
        states_tensor = torch.tensor(np.array(states), dtype=torch.float32, device=device).unsqueeze(0)  # (1,T,2)
        actions_tensor = torch.tensor(np.array(actions), dtype=torch.float32, device=device).view(-1, 1)  # (T,1)
        returns_tensor = torch.tensor(returns, dtype=torch.float32, device=device).view(-1, 1)  # (T,1)

        # critic to get baseline
        with torch.no_grad():
            c_hx = torch.zeros((1, 1, 256), device=device)
            c_cx = torch.zeros((1, 1, 256), device=device)
            v, _ = value(states_tensor, (c_hx, c_cx))  # (1,T,1)
            v = v.squeeze(0)  # (T,1)
            advantage = returns_tensor - v  # (T,1)

        # actor update (re-run policy on full sequence)
        a_hx = torch.zeros((1, 1, 128), device=device)
        a_cx = torch.zeros((1, 1, 128), device=device)
        prob, _ = policy(states_tensor, (a_hx, a_cx))  # (1,T,1)
        prob = prob.squeeze(0)  # (T,1)

        b = Bernoulli(prob)
        log_prob = b.log_prob(actions_tensor)  # (T,1)

        actor_loss = -(log_prob * advantage.detach()).mean()

        optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        optim.step()
        writer.add_scalar("loss/actor", actor_loss.item(), epoch)

        # critic update
        c_hx = torch.zeros((1, 1, 256), device=device)
        c_cx = torch.zeros((1, 1, 256), device=device)
        v, _ = value(states_tensor, (c_hx, c_cx))
        v = v.squeeze(0)
        value_loss = F.mse_loss(v, returns_tensor)

        value_optim.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(value.parameters(), 1.0)
        value_optim.step()
        writer.add_scalar("loss/value", value_loss.item(), epoch)

        writer.add_scalar("episode_reward", episode_reward, epoch)

        if epoch % 10 == 0:
            print(f"Epoch {epoch:05d} | ep_reward {episode_reward:.1f} | T={len(rewards)}")
            torch.save(policy.state_dict(), "lstm-policy.pt")


      

Epoch 00000 | ep_reward 12.0 | T=12
Epoch 00010 | ep_reward 15.0 | T=15
Epoch 00020 | ep_reward 13.0 | T=13
Epoch 00030 | ep_reward 28.0 | T=28
Epoch 00040 | ep_reward 40.0 | T=40
Epoch 00050 | ep_reward 23.0 | T=23
Epoch 00060 | ep_reward 32.0 | T=32
Epoch 00070 | ep_reward 48.0 | T=48
Epoch 00080 | ep_reward 13.0 | T=13
Epoch 00090 | ep_reward 19.0 | T=19
Epoch 00100 | ep_reward 12.0 | T=12
Epoch 00110 | ep_reward 17.0 | T=17
Epoch 00120 | ep_reward 22.0 | T=22
Epoch 00130 | ep_reward 27.0 | T=27
Epoch 00140 | ep_reward 23.0 | T=23
Epoch 00150 | ep_reward 27.0 | T=27
Epoch 00160 | ep_reward 9.0 | T=9
Epoch 00170 | ep_reward 11.0 | T=11
Epoch 00180 | ep_reward 10.0 | T=10
Epoch 00190 | ep_reward 16.0 | T=16
Epoch 00200 | ep_reward 20.0 | T=20
Epoch 00210 | ep_reward 58.0 | T=58
Epoch 00220 | ep_reward 20.0 | T=20
Epoch 00230 | ep_reward 31.0 | T=31
Epoch 00240 | ep_reward 14.0 | T=14
Epoch 00250 | ep_reward 36.0 | T=36
Epoch 00260 | ep_reward 22.0 | T=22
Epoch 00270 | ep_reward 19.0 |

KeyboardInterrupt: 