In [1]:
%matplotlib inline
import cv2
import numpy as np
import os
from tqdm.notebook import tqdm

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib_inline import backend_inline
from IPython.display import HTML, clear_output


# Render inline matplotlib figures as SVG (vector) rather than a raster format.
backend_inline.set_matplotlib_formats('svg')

In [2]:
import sys
# print(sys.path)

from rl_agents import ppo, mult_processor
from models import example
import logging
import params
import torch
import os
import utils
from tqdm import tqdm
import torch.nn.functional as F
import shutil
from torch.utils.tensorboard import SummaryWriter
import time
import json

pygame 2.5.2 (SDL 2.28.2, Python 3.10.0)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Global definition of the training environment.
# NOTE: the `from env.g_env import ...` line below must name the same package as ENV_NAME.
ENV_NAME = "g_env"
ENV_DIR = f"./env/{ENV_NAME}"

# Make the environment directory importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if ENV_DIR not in sys.path:
    sys.path.append(ENV_DIR)
from env.g_env import mask_env

ENV = mask_env.GuidewireEnv

train_param = params.TrainParams()
env_param   = params.EnvParams()
run_param   = params.RuntimeParams()

# All three parameter groups are loaded from the same params location.
PARAMS_PATH = f"{ENV_DIR}/params"
for _param_group in (train_param, env_param, run_param):
    _param_group.load_from_json(PARAMS_PATH)

In [4]:
# Build the network on the training device and wrap it in a PPO agent.
model = example.CNN_FC(env_param.input_dense, env_param.actions).to(train_param.device)
agent = ppo.Agent(model)

# Restore the best checkpoint saved for this task.
weight_path = os.path.join("./weights", run_param.task_name)
agent.load(os.path.join(weight_path, "best.pth"))

# Copy every non-dunder, non-callable attribute of train_param onto the agent.
for attr_name in dir(train_param):
    if attr_name.startswith('__'):
        continue
    attr_value = getattr(train_param, attr_name)
    if not callable(attr_value):
        setattr(agent, attr_name, attr_value)

cuda:0


In [5]:
# A second agent built on its own freshly constructed model (no checkpoint is
# loaded into it), configured with the same training hyper-parameters.
model2 = example.CNN_FC(env_param.input_dense, env_param.actions)
# Move model2 itself: calling `.to` on `model` would make agent2 share the
# trained agent's network instead of having its own.
model2 = model2.to(train_param.device)
agent2 = ppo.Agent(model2)
# Copy every non-dunder, non-callable attribute of train_param onto agent2.
for key in dir(train_param):
    if not callable(getattr(train_param, key)) and not key.startswith('__'):
        setattr(agent2, key, getattr(train_param, key))

cuda:0


In [32]:
def render(agent: ppo.Agent, Env,
           env_param: params.EnvParams, index=None):
    """Run one episode with ``agent`` and record the frames and actions.

    A fresh environment is created by calling ``Env()``. It is configured
    with ``env_param`` and reset with ``index``, which is forwarded as-is to
    ``env.reset`` (presumably a map index; None lets the env choose — verify).
    The agent's ``ac_model`` is switched to eval mode.

    At each step the current frame is rendered *before* the agent acts. The
    episode stops when ``env.step`` reports done, or after
    ``env_param.max_steps`` steps.

    Returns:
        (frames, actions): one rendered frame (with axes 0 and 2 swapped) and
        one chosen action per step. Both lists have the same length. That
        length equals ``env_param.max_steps`` when done was never reported
        (or was reported on the very last step).
    """
    env = Env()
    env.set_params(env_param)
    state = env.reset(index)
    agent.ac_model.eval()

    frames = []
    actions = []
    for _ in tqdm(range(env_param.max_steps)):
        img: np.ndarray = env.render()
        # NOTE(review): the axis layout returned by env.render() is not visible
        # here; display_video/save_frames expect (height, width, channels)
        # after this swap — confirm.
        frames.append(img.swapaxes(0, 2))
        action, _, _ = agent.desision(state)
        actions.append(action)
        state, _, done, _ = env.step(action)
        if done:
            break
    return frames, actions

def display_video(frames: list, framerate: int = 30, dpi: int = 70):
    '''
    Render a sequence of frames as an HTML5 video inside the notebook.

    Args:
        frames: non-empty sequence of images. Each frame must have a 3-D shape
            whose first two axes are (height, width).
        framerate: playback speed in frames per second.
        dpi: figure resolution. The figure is sized ``width / dpi`` x
            ``height / dpi`` inches *at this dpi*, so the video has the same
            pixel size as the frames.

    Returns:
        IPython.display.HTML wrapping the encoded video. Encoding goes through
        matplotlib's configured movie writer (ffmpeg by default).

    Raises:
        ValueError: if ``frames`` is empty.
    '''
    if len(frames) == 0:
        raise ValueError("display_video() needs at least one frame")
    height, width, _ = frames[0].shape

    # Create the figure under the headless 'Agg' backend (presumably so the
    # notebook does not also show a static copy of it). Always restore the
    # original backend, even if figure creation fails.
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')
    try:
        fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    finally:
        matplotlib.use(orig_backend)

    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0], cmap='gray')

    def update(frame):
        # Swap in the next frame; blitting only redraws this artist.
        im.set_data(frame)
        return [im]

    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=1000 / framerate, blit=True, repeat=False)
    try:
        return HTML(anim.to_html5_video())
    finally:
        # The video is fully encoded at this point. Close the figure so that
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def save_frames(frames: list, out_dir: str, dpi: int = 70):
    '''
    Write each frame to ``out_dir`` as ``0.png``, ``1.png``, ...

    Frames are saved at their native pixel resolution with ``plt.imsave``.
    2-D frames are colour-mapped with ``'gray'``; RGB(A) frames are written
    as-is. No figures are created, so the active matplotlib backend is
    untouched and no figures are left open.

    Args:
        frames: sequence of images, either 2-D or 3-D with a trailing colour
            axis of size 3 or 4.
        out_dir: destination directory; created if it does not exist.
        dpi: resolution stored in the PNG metadata only. It does not change
            the pixel size of the written images.
    '''
    os.makedirs(out_dir, exist_ok=True)
    for k, frame in enumerate(frames):
        plt.imsave(os.path.join(out_dir, f"{k}.png"), frame, cmap='gray', dpi=dpi)

def save_dataset(out_dir, _agent, _env, _env_param, index=None, max_steps=180,
                 num_maps=6, max_attempts=None):
    '''
    Record successful episodes of ``_agent`` and save them under ``out_dir``.

    Each saved episode ``i`` produces two outputs:
    - its frames, written to ``out_dir/test_{i}/<k>.png``;
    - its actions, written as JSON to ``out_dir/test_{i}.json``.

    An episode counts as successful when it produced fewer than ``max_steps``
    frames. An episode that reports done exactly on its last step cannot be
    told apart from a timeout, so it is treated as a failure.

    Args:
        out_dir: root output directory.
        _agent, _env, _env_param: forwarded to ``render``. ``_env_param`` is
            copied, not mutated, so the caller's ``max_steps`` is preserved.
        index: if given, run a single episode with this index and save it only
            if it succeeds. If None, each index ``0 .. num_maps - 1`` is
            retried until one of its episodes succeeds.
        max_steps: step limit for every episode.
        num_maps: number of indices to cover when ``index`` is None.
        max_attempts: per-index retry limit when ``index`` is None. None means
            retry forever (the original behaviour). Indices that never succeed
            within the limit are skipped.
    '''
    import copy

    # Work on a copy: assigning max_steps on the caller's object would silently
    # change the step limit of every later render()/evaluation in the notebook.
    param = copy.copy(_env_param)
    param.max_steps = max_steps

    def _try_save(i):
        # Run one episode for index i. Save it and return True on success.
        frames, actions = render(_agent, _env, param, i)
        if len(frames) >= max_steps:
            return False
        episode_dir = os.path.join(out_dir, f"test_{i}")
        os.makedirs(episode_dir, exist_ok=True)
        save_frames(frames, episode_dir, 30)
        with open(os.path.join(out_dir, f"test_{i}.json"), "w") as f:
            json.dump(actions, f)
        return True

    if index is not None:
        _try_save(index)
        return

    for i in range(num_maps):
        attempts = 0
        while max_attempts is None or attempts < max_attempts:
            attempts += 1
            if _try_save(i):
                break



In [37]:
save_dataset("./test", agent, ENV, env_param, index=2, max_steps=180)

Load 6 maps from ./datas/minimal/envmsgs


 96%|█████████▌| 172/180 [00:14<00:00, 11.97it/s]


In [21]:
save_dataset(out_dir="./test", _agent=agent, _env=ENV, _env_param=env_param)

Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.24it/s]


Load 6 maps from ./datas/minimal/envmsgs


 25%|██▌       | 45/180 [00:03<00:11, 11.48it/s]


Load 6 maps from ./datas/minimal/envmsgs


 21%|██        | 38/180 [00:03<00:12, 11.32it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.72it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:14<00:00, 12.23it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.92it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.82it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.54it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.25it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.88it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.44it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.65it/s]


Load 6 maps from ./datas/minimal/envmsgs


 76%|███████▌  | 136/180 [00:11<00:03, 11.58it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:17<00:00, 10.46it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:16<00:00, 10.69it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.25it/s]


Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:15<00:00, 11.52it/s]


Load 6 maps from ./datas/minimal/envmsgs


 31%|███       | 55/180 [00:04<00:10, 11.68it/s]


KeyboardInterrupt: 

In [9]:
frames, actions = render(agent, ENV, env_param, index=0)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 37%|███▋      | 67/180 [00:05<00:09, 11.62it/s]


In [10]:
frames, actions = render(agent, ENV, env_param, index=1)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 22%|██▏       | 39/180 [00:03<00:12, 11.44it/s]


In [11]:
frames, actions = render(agent, ENV, env_param, index=2)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 69%|██████▉   | 124/180 [00:10<00:04, 11.67it/s]


In [12]:
frames, actions = render(agent, ENV, env_param, index=3)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 33%|███▎      | 59/180 [00:05<00:10, 11.37it/s]


In [19]:
frames, actions = render(agent, ENV, env_param, index=4)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 23%|██▎       | 41/180 [00:03<00:12, 11.56it/s]


In [16]:
frames, actions = render(agent, ENV, env_param, index=5)
display_video(frames, framerate=8, dpi=30)

Load 6 maps from ./datas/minimal/envmsgs


 18%|█▊        | 32/180 [00:02<00:13, 11.29it/s]


# 成功率计算

In [45]:
def calc_succ_rate(_agent, _env, _env_param, index=None, succ_steps=170, max_steps=180,
                   rounds=500, verbose=True):
    '''
    Evaluate ``_agent`` over ``rounds`` episodes and count the successes.

    An episode is a success when ``render`` returned fewer than
    ``succ_steps`` frames.

    Args:
        _agent, _env, _env_param, index: forwarded to ``render``.
            ``_env_param`` is copied, not mutated, so the caller's
            ``max_steps`` is preserved.
        succ_steps: an episode must finish in fewer frames than this to count.
        max_steps: step limit for every episode.
        rounds: number of episodes to run.
        verbose: after every episode, print the running
            ``(total successful steps, successes)``.

    Returns:
        (avg_step, num_succs):
        - ``avg_step`` is the mean length of the successful episodes, or
          ``nan`` when there were none.
        - ``num_succs`` is the number of successes.
        The success rate is ``num_succs / rounds``.
    '''
    import copy

    # Copy so the caller's env_param keeps its own max_steps.
    param = copy.copy(_env_param)
    param.max_steps = max_steps

    total_steps = 0
    num_succs = 0
    for _ in range(rounds):
        frames, _ = render(_agent, _env, param, index)
        if len(frames) < succ_steps:
            total_steps += len(frames)
            num_succs += 1
        if verbose:
            print(total_steps, num_succs)
    # Avoid ZeroDivisionError when no episode succeeded.
    avg_step = total_steps / num_succs if num_succs else float('nan')
    return avg_step, num_succs

In [47]:
calc_succ_rate(agent, ENV, env_param, index=0, rounds=150)

Load 6 maps from ./datas/minimal/envmsgs


 99%|█████████▉| 179/180 [00:16<00:00, 10.79it/s]


Load 6 maps from ./datas/minimal/envmsgs


 30%|███       | 54/180 [00:04<00:10, 11.62it/s]


Load 6 maps from ./datas/minimal/envmsgs


 42%|████▏     | 76/180 [00:06<00:09, 10.85it/s]