# import 必要套件

In [None]:
import random
import time,math
import numpy as np
import gymnasium as gym
import gymnasium.wrappers as gym_wrap
import matplotlib.pyplot as plt
import matplotlib.animation as animation #輸出動畫影片
from IPython import display
from tqdm import tqdm

In [None]:
import torch
import torch.nn.functional as F
import collections
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
class ImageEnv(gym.Wrapper):
  """Wrapper that crops/rescales image frames, repeats actions and stacks frames.

  Each observation returned by ``reset``/``step`` is ``self.stacked_state``,
  an array holding the ``stack_frames`` most recent preprocessed frames
  along axis 0.

  Args:
    env: Environment to wrap. Assumed to emit 2-D grayscale frames of at
      least 84x90 pixels with values in [0, 255] -- TODO confirm against
      the wrapped environment.
    stack_frames: Number of frames kept in the stack; also the number of
      times each action is repeated inside ``step``.
    delay_op: Number of action-0 steps executed right after every reset
      before the first observation is returned.
  """

  def __init__(self,env,stack_frames=4,delay_op=50):
    super().__init__(env)
    self.delay_op = delay_op
    self.stack_frames = stack_frames

  @staticmethod
  def _preprocess(frame):
    """Crop to rows 0..83 / columns 6..89 and map values via x/255 - 0.5."""
    return (frame[:84, 6:90]/255.0)-0.5

  def reset(self, *, seed=None, options=None):
    """Reset the wrapped env, skip ``delay_op`` steps, return the initial stack.

    ``seed`` and ``options`` are forwarded unchanged to the wrapped
    environment's ``reset`` so seeded evaluation works through this wrapper.

    Returns:
      ``(stacked_state, info)`` where ``stacked_state`` is the last observed
      frame repeated ``stack_frames`` times and ``info`` comes from the last
      call made to the wrapped environment.
    """
    s, info = self.env.reset(seed=seed, options=options)
    # Warm-up: step with action 0; only the last frame seeds the stack.
    for _ in range(self.delay_op):
      s, _, _, _, info = self.env.step(0)
    # Build the stack once, after the warm-up. This also covers delay_op == 0,
    # where the frame returned by reset() itself is used.
    self.stacked_state = np.tile(self._preprocess(s), (self.stack_frames,1,1))  # [stack_frames, 84, 84]
    return self.stacked_state, info

  def step(self, action):
    """Repeat ``action`` up to ``stack_frames`` times and sum the rewards.

    Returns:
      ``(stacked_state, reward, terminated, truncated, info)`` where
      ``reward`` is the sum over the repeated steps and the flags/info come
      from the last wrapped step that was executed.
    """
    reward = 0
    for _ in range(self.stack_frames):
      s, r, terminated, truncated, info = self.env.step(action)
      # A reward of exactly -100 is treated as the end of the episode
      # (presumably the environment's out-of-bounds penalty -- verify).
      if r==-100:terminated=True
      s=self._preprocess(s)
      reward += r
      # On episode end we stop before pushing the final frame, so the
      # returned stack does not contain it.
      if terminated or truncated:break
      # Slide the window: drop the oldest frame, append the newest.
      self.stacked_state = np.concatenate((self.stacked_state[1:], s[np.newaxis]), axis=0)
    return self.stacked_state, reward, terminated, truncated, info

# 搭建DQN神經網路的類別

In [None]:
class DQN(torch.nn.Module):
  """Convolutional Q-network: two conv layers followed by two linear layers.

  Args:
    n_act: Number of outputs of the final linear layer (one per action).
  """

  def __init__(self,n_act):
    super().__init__()
    nn = torch.nn
    # Attribute names double as state_dict keys, so they must stay as they are.
    self.conv1 = nn.Conv2d(4, 16, kernel_size=8, stride=4)   # [N,4,84,84]  -> [N,16,20,20]
    self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)  # [N,16,20,20] -> [N,32,9,9]
    self.fc1 = nn.Linear(32 * 9 * 9, 256)
    self.fc2 = nn.Linear(256, n_act)

  def forward(self,x):
    """Return one value per action for every input in ``x``."""
    features = F.relu(self.conv2(F.relu(self.conv1(x))))
    # Flatten to [N, 32*9*9]; fc1.in_features is exactly that width.
    flat = features.view(-1, self.fc1.in_features)
    # The two linear layers are applied back to back, with no activation
    # in between.
    return self.fc2(self.fc1(flat))

# 建立測試環境（CarRacing-v3，轉灰階後以 ImageEnv 包裝）

In [None]:
# Build the evaluation environment in one expression:
# CarRacing-v3 (discrete actions) -> grayscale frames -> ImageEnv wrapper.
env = ImageEnv(
    gym_wrap.GrayscaleObservation(
        gym.make(
            "CarRacing-v3",
            render_mode="rgb_array",
            domain_randomize=False,
            continuous=False,
        )
    )
)

# 搭建智能體Agent的類別

In [None]:
class DQNAgent():
  """Greedy evaluation agent that plays episodes with a saved DQN checkpoint.

  Uses the module-level ``env`` as its environment and the module-level
  ``device`` for inference.

  Args:
    OldFile: Checkpoint index; weights are read from ``Model-{OldFile}.pt``.
  """

  def __init__(self,OldFile=50):
    self.env = env
    self.n_act=self.env.action_space.n
    self.PredictDQN= DQN(self.n_act)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): weights_only=False makes torch.load unpickle arbitrary
    # objects; only load checkpoints you trust, or switch to weights_only=True
    # once the files are confirmed to contain a plain state_dict.
    state = torch.load(f"Model-{OldFile}.pt",map_location=device,weights_only=False)
    self.PredictDQN.load_state_dict(state)
    self.PredictDQN.to(device)
    # Inference only: put the network in evaluation mode.
    self.PredictDQN.eval()

  def PredictA(self,s):
    """Return the index of the highest-valued action for observation ``s``."""
    with torch.no_grad():
      obs = torch.as_tensor(s, dtype=torch.float32, device=device)
      return torch.argmax(self.PredictDQN(obs)).item()

  def Test(self, episodes=20):
    """Play ``episodes`` full episodes greedily and collect their returns.

    Args:
      episodes: Number of episodes to run (defaults to 20).

    Returns:
      List with the summed reward of each episode, in play order.
    """
    RewardList=[]
    for _ in range(episodes):
      total_reward=0
      s,_=self.env.reset()
      while True:
        a=self.PredictA(s)
        s,r,done,stop,_=self.env.step(a)
        total_reward+=r
        if done or stop:break
      RewardList.append(total_reward)
    return RewardList

In [None]:
# Evaluate every saved checkpoint (Model-50.pt ... Model-5000.pt, step 50)
# and record mean / max / min / std of the episode returns.
Log = {"Mean": [], "Max": [], "Min": [], "Std": []}
for OldFile in range(50, 5050, 50):
    Agent = DQNAgent(OldFile)
    RewardList = np.array(Agent.Test())
    # Compute each statistic once, in the same order as the Log keys.
    stats = {
        "Mean": float(RewardList.mean()),
        "Max": float(RewardList.max()),
        "Min": float(RewardList.min()),
        "Std": float(RewardList.std()),
    }
    for key, value in stats.items():
        Log[key].append(value)
    print(*stats.values())
# The dict is stored as a pickled object array inside the .npy file.
np.save("TestReward.npy", Log)

-89.75057695644136 -64.0721649484543 -94.99999999999903 9.54725288008995
-72.56482494718405 26.323529411765506 -94.999999999999 27.570870640270027
-47.2584064757102 1.0854092526682528 -84.83050847457582 27.45782548914961
-12.542018805476356 59.12186379928315 -84.3238434163697 38.95489426834188
-79.35517573916361 -55.99290780141888 -94.99999999999923 7.539454708366761
-20.417320817832543 57.67175572519017 -94.99999999999903 38.10367241748896
144.34237207144508 254.31506849315392 -36.36807817589646 75.85678709526054
24.58405273294023 155.0000000000025 -94.999999999999 82.95345653963001
4.744865271045133 116.35646687697425 -69.52229299363091 56.175332766309666
-43.419930728630945 -13.088737201365564 -70.6097560975613 18.493912536925347
71.95387808840941 229.59016393442943 -59.66431095406432 86.25130511125397
-52.15729243151111 -8.494809688581181 -88.48534201954331 19.11118086054896
-81.83698759006292 -37.372881355932336 -88.67088607594869 10.965389670537807
-19.52404545203744 21.465863453