In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import softmax
import numpy as np
import os

import matplotlib.pyplot as plt
from tqdm import tqdm

import draw
import utils

In [2]:
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym==0.25.2
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pygame==2.5.2
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple moviepy

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [2]:
from base64 import b64encode
from IPython.display import display, HTML
from moviepy.editor import ImageSequenceClip

# Record a video of one episode in a Gym environment (CartPole by default).
def record_video(agent, env_name='CartPole-v0', video_dir='video'):
    """Roll out one episode with ``agent`` and record it via ``RecordVideo``.

    The wrapper is configured to record every episode. The episode length
    (number of environment steps) is printed and also returned.

    Args:
        agent: object exposing ``take_action(state)`` that returns an action.
        env_name: Gym environment id to create.
        video_dir: output folder passed to ``gym.wrappers.RecordVideo``.

    Returns:
        int: number of environment steps in the recorded episode.
    """
    env = gym.make(env_name)
    env = gym.wrappers.RecordVideo(env, video_dir, episode_trigger=lambda episode_id: True)
    try:
        state = env.reset()
        done = False
        steps = 0
        while not done:
            action = agent.take_action(state)
            state, _, done, _ = env.step(action)
            steps += 1
        print(steps)
    finally:
        # Always close the env, even if the rollout raises. The RecordVideo
        # wrapper presumably flushes its video file on close.
        env.close()
    return steps

def display_video(file_path, width=640, height=480):
    """Embed an MP4 file inline in the notebook as a base64 data URI.

    Args:
        file_path: path to the ``.mp4`` file to show.
        width: player width in pixels (default 640).
        height: player height in pixels (default 480).
    """
    # Context manager so the file handle is closed promptly (the original leaked it).
    with open(file_path, "rb") as video_file:
        encoded_video = b64encode(video_file.read()).decode("ascii")
    display(HTML(data=f"""
        <video width="{width}" height="{height}" controls>
            <source src="data:video/mp4;base64,{encoded_video}" type="video/mp4" />
        </video>
    """))

In [3]:
class PolicyNet(torch.nn.Module):
    """Two-layer MLP policy that maps a state vector to action probabilities.

    Args:
        state_dim: size of the last dimension of the input state.
        hidden_dim: width of the hidden layer.
        action_dim: number of discrete actions (size of the output distribution).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # The hidden width comes from hidden_dim rather than a hard-coded constant.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return softmax action probabilities over the last dimension of ``x``."""
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)

In [4]:
class REINFORCE:
    """Monte-Carlo policy-gradient (REINFORCE) agent backed by a ``PolicyNet``.

    Args:
        state_dim: dimensionality of the state vector fed to the policy.
        hidden_dim: hidden-layer width of the policy network.
        action_dim: number of discrete actions.
        learning_rate: Adam learning rate.
        gamma: discount factor used when computing returns.
        device: torch device the policy network lives on.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device=torch.device("cpu")):
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.device = device
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)  # Adam optimizer
        self.gamma = gamma  # discount factor
        # Probability that take_action picks a uniformly random action.
        # Defaults to 1 (fully random); set it to 0 to sample purely from the policy.
        self.epsilon = 1

    def save_model(self, path):
        """Save the policy network's parameters (its ``state_dict``) to ``path``."""
        torch.save(self.policy_net.state_dict(), path)

    def load_model(self, path):
        """Load policy parameters from ``path`` into the existing network in place.

        Accepts a ``state_dict`` checkpoint, or (for older checkpoints) a pickled
        whole policy module, whose parameters are then copied over. Loading in
        place keeps ``self.optimizer`` tied to the live parameters.

        Security: ``torch.load`` unpickles data, so only load trusted files.
        Recent PyTorch releases default to ``weights_only=True``, which rejects
        legacy whole-module pickles.
        """
        checkpoint = torch.load(path, map_location=self.device)
        if isinstance(checkpoint, torch.nn.Module):
            checkpoint = checkpoint.state_dict()
        self.policy_net.load_state_dict(checkpoint)

    def take_action(self, state):
        """Pick an action for ``state``.

        With probability ``self.epsilon`` the action is uniform over
        ``range(action_dim)``; otherwise it is sampled from the policy's
        output distribution.
        """
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.action_dim)
        state_tensor = torch.tensor(state, dtype=torch.float, device=self.device)
        # Acting does not need gradients; avoid building an autograd graph.
        with torch.no_grad():
            probs = self.policy_net(state_tensor).cpu().numpy()
        return np.random.choice(self.action_dim, p=probs)

    def update(self, transition_dict):
        """Apply one REINFORCE update from a single complete episode.

        Args:
            transition_dict: dict holding equal-length lists under 'states',
                'actions' and 'rewards' (any other keys are ignored).

        The loss is ``-sum_t log pi(a_t | s_t) * G_t``, where ``G_t`` is the
        discounted return from step t to the end of the episode. All steps go
        through one batched forward/backward pass. The gradient equals the sum
        of the per-step gradients.
        """
        rewards = transition_dict['rewards']
        if len(rewards) == 0:
            return  # nothing to learn from an empty episode

        # Discounted return-to-go, computed from the last step backwards:
        # G_t = r_t + gamma * G_{t+1}.
        returns = np.empty(len(rewards), dtype=np.float64)
        running_return = 0.0
        for t in reversed(range(len(rewards))):
            running_return = self.gamma * running_return + rewards[t]
            returns[t] = running_return

        states = torch.tensor(np.asarray(transition_dict['states']), dtype=torch.float, device=self.device)
        actions = torch.as_tensor(np.asarray(transition_dict['actions']), dtype=torch.long, device=self.device)
        returns_t = torch.tensor(returns, dtype=torch.float, device=self.device)

        probs = self.policy_net(states)                               # one row per step
        chosen_probs = probs.gather(1, actions.unsqueeze(1)).squeeze(1)  # pi(a_t | s_t)
        # Optimizers minimise, so negate to perform gradient ascent on the return.
        loss = -(torch.log(chosen_probs) * returns_t).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        

In [5]:
def train(num_episodes=300, env_name='CartPole-v0', model_dir='./models'):
    """Train a REINFORCE agent on a Gym environment and write checkpoints.

    Args:
        num_episodes: number of training episodes (one policy update each).
        env_name: Gym environment id. Its observation space must expose
            ``shape[0]`` and its action space ``n`` (e.g. CartPole).
        model_dir: checkpoint directory, created if missing. A checkpoint is
            written at episodes 0, 100, 200, ... and a final one named
            ``f"{num_episodes}.pth"``.

    Returns:
        The trained ``REINFORCE`` agent.
    """
    env = gym.make(env_name)
    agent = REINFORCE(state_dim=env.observation_space.shape[0],
                      hidden_dim=128,
                      action_dim=env.action_space.n,
                      learning_rate=0.001,
                      gamma=0.98)
    print(agent.policy_net)

    # save_model() writes into model_dir, which may not exist on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)

    return_list = []
    agent.epsilon = 0  # sample every action from the policy (no uniform exploration)
    pbar = tqdm(range(num_episodes))
    try:
        for i in pbar:
            transition_dict = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': [],
            }
            state = env.reset()
            done = False
            episode_return = 0
            steps = 0
            while not done:
                action = agent.take_action(state)
                next_state, reward, done, _ = env.step(action)
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
                steps += 1
            return_list.append(episode_return)
            agent.update(transition_dict)
            if (i + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode': '%d' % i,
                    'return': '%.3f' % np.mean(return_list[-10:]),
                    'cnt': steps,
                })
            # No manual pbar.update(): iterating the tqdm object already advances it,
            # and an extra update() would double-count progress.
            if i % 100 == 0:
                agent.save_model(os.path.join(model_dir, f"{i}.pth"))
    finally:
        env.close()

    if return_list:  # guard against num_episodes == 0
        print(sum(return_list) / len(return_list))
    agent.save_model(os.path.join(model_dir, f"{num_episodes}.pth"))
    return agent

In [6]:
# Train a REINFORCE agent: prints the policy network, shows a tqdm progress bar,
# prints the mean episode return and saves checkpoints under ./models/.
training_agent = train()

PolicyNet(
  (fc1): Linear(in_features=4, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:12<00:00, 24.10it/s, episode=299, return=155.600, cnt=168]

60.13





In [7]:
# Record one episode of the freshly trained agent (episode length is printed),
# then embed the clip inline.
# NOTE(review): assumes RecordVideo names the first recorded episode
# 'rl-video-episode-0.mp4' inside ./video -- confirm for the installed gym version.
record_video(training_agent)
display_video('./video/rl-video-episode-0.mp4')

144


In [8]:
# Rebuild an agent with the same hyper-parameters used in train() and restore
# the final checkpoint that train() saves after 300 episodes.
agent123 = REINFORCE(state_dim = 4, 
          hidden_dim= 128, 
          action_dim = 2, 
          learning_rate = 0.001, 
          gamma = 0.98)
agent123.epsilon = 0  # act purely from the learned policy (no random actions)
agent123.load_model('./models/300.pth')
# Records into the same ./video folder as before, so this presumably replaces
# the earlier clip at the same file name.
record_video(agent123)
display_video('./video/rl-video-episode-0.mp4')

195
