In [2]:
"""
一文读懂强化学习：RL全面解析与Pytorch实战
https://blog.csdn.net/magicyangjay111/article/details/132645347
"""

import torch
import torch.nn as nn
import torch.optim as optim 

In [4]:
# 定义策略网络
class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to a categorical action distribution.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dim: Number of discrete actions (size of the output distribution).
        hidden_dim: Width of the single hidden layer. Defaults to 128, the value
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.fc = nn.Linear(state_dim, hidden_dim)
        self.policy_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: Tensor whose last dimension has size ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last dimension
            of size ``action_dim``; softmax is taken over that last dimension, so
            each distribution sums to 1.
        """
        hidden = torch.relu(self.fc(x))
        return torch.softmax(self.policy_head(hidden), dim=-1)


# Initialization
# Seed the global RNG so the mock data and the initial network weights are the
# same on every Restart & Run All.
torch.manual_seed(0)

state_dim = 4  # state dimension
action_dim = 2  # action dimension
epsilon = 0.2  # clipping range of the PPO surrogate (only used by loss A)
count = 20  # number of mock samples
num_epochs = 50  # optimisation steps performed on the same batch
use_clipped_loss = False  # True: loss A (clipped PPO surrogate); False: loss B

# Alternative mock data (uniform random), kept for experimentation:
# states = torch.rand(count, state_dim)
# actions = torch.randint(0, action_dim, (count,))
# rewards = torch.rand(count)

# Mock data (sin): states are sin() of random reals in [-5, 5); the recorded
# action is 0 when the first state feature is non-negative and 1 otherwise.
# The width is tied to state_dim so the data always matches the network input.
data_a = torch.randint(-5, 5, (count, state_dim)) + torch.rand(count, state_dim)
data_b = torch.sin(data_a)
data_c = torch.where(data_b[:, 0] >= 0, 0, 1)
states = data_b
actions = data_c
rewards = torch.rand(count)

print("states:", states)
print("actions:", actions)
print("rewards:", rewards)

# Network and optimizer
policy_net = PolicyNetwork(state_dim, action_dim)
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)

# Probability that the old (pre-update) policy assigns to each recorded action.
# squeeze(-1) removes only the gathered column, so the result stays 1-D even
# when count == 1.
with torch.no_grad():
    old_probs = policy_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

# PPO update: typically we run multiple epochs over the same batch.
for epoch in range(num_epochs):
    action_probs = policy_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
    ratio = action_probs / old_probs

    if use_clipped_loss:
        # Loss A: clipped surrogate objective, with the mock rewards standing in
        # for the advantages.
        surr1 = ratio * rewards
        surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * rewards
        loss = -torch.min(surr1, surr2).mean()
    else:
        # Loss B: directly maximise the probability of the recorded actions
        # (ignores rewards and the probability ratio).
        loss = -action_probs.mean()

    print(f"Loss: {loss}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("PPO Update Done!")

states: tensor([[ 0.8431,  0.6183, -0.6976,  0.9636],
        [-0.9992, -0.9788,  0.6931, -0.2512],
        [ 0.8475, -0.2918,  0.8291,  0.1092],
        [ 0.5724,  0.0162,  0.8681,  0.1696],
        [-0.5243,  0.5001,  0.6195,  0.6132],
        [-0.4970,  0.5142,  0.1246,  0.9916],
        [-0.3402,  0.5892,  0.9689,  0.2393],
        [ 0.0043, -0.7968, -0.4943, -0.9505],
        [-0.8316, -0.2831,  0.9955, -0.3262],
        [ 0.3890,  0.1835, -0.9353, -0.2703],
        [-0.5330,  0.9403, -0.9971,  0.4100],
        [ 0.5093, -0.9723, -0.5617, -0.4294],
        [ 0.8732, -0.4998, -0.4779,  0.8237],
        [-0.1981, -0.9468,  0.2812,  0.2855],
        [ 0.5168, -0.9835,  0.9296, -0.2432],
        [ 0.2869,  0.2392, -0.6645,  0.0677],
        [ 0.9342, -0.9985, -0.9611,  0.8807],
        [ 0.2883, -0.2654, -0.9639,  0.9757],
        [ 0.9791,  0.3236,  0.0508, -0.9896],
        [ 0.9720, -0.5688, -0.7570,  0.9588]])
actions: tensor([0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 

In [5]:
# Inspect the updated policy on the training states; no gradients are needed here.
with torch.no_grad():
    new_probs = policy_net(states)
    action_column = actions[:, None]
    # One row per sample: [first sin feature, recorded action, P(action 0), P(action 1)].
    comparison = torch.cat([data_b[:, :1], action_column, new_probs], dim=1)
    print(comparison)
    # Probability the updated policy assigns to each recorded action.
    chosen_probs = torch.gather(new_probs, 1, action_column).squeeze()
    print("new_probs:", chosen_probs)

tensor([[ 0.8431,  0.0000,  0.8339,  0.1661],
        [-0.9992,  1.0000,  0.2530,  0.7470],
        [ 0.8475,  0.0000,  0.7661,  0.2339],
        [ 0.5724,  0.0000,  0.6293,  0.3707],
        [-0.5243,  1.0000,  0.2059,  0.7941],
        [-0.4970,  1.0000,  0.2671,  0.7329],
        [-0.3402,  1.0000,  0.2074,  0.7926],
        [ 0.0043,  0.0000,  0.8600,  0.1400],
        [-0.8316,  1.0000,  0.1684,  0.8316],
        [ 0.3890,  0.0000,  0.8494,  0.1506],
        [-0.5330,  1.0000,  0.4939,  0.5061],
        [ 0.5093,  0.0000,  0.9259,  0.0741],
        [ 0.8732,  0.0000,  0.9183,  0.0817],
        [-0.1981,  1.0000,  0.6656,  0.3344],
        [ 0.5168,  0.0000,  0.7695,  0.2305],
        [ 0.2869,  0.0000,  0.7787,  0.2213],
        [ 0.9342,  0.0000,  0.9670,  0.0330],
        [ 0.2883,  0.0000,  0.8683,  0.1317],
        [ 0.9791,  0.0000,  0.8404,  0.1596],
        [ 0.9720,  0.0000,  0.9454,  0.0546]])
new_probs: tensor([0.8339, 0.7470, 0.7661, 0.6293, 0.7941, 0.7329, 0.7926, 0.86