In [1]:
# !pip install gymnasium
# !pip install swig
# !pip install gymnasium[box2d]

# import 必要套件

In [2]:
import random
import time,math
import numpy as np
import gymnasium as gym
import gymnasium.wrappers as gym_wrap
import matplotlib.pyplot as plt
import matplotlib.animation as animation #輸出動畫影片
from IPython import display
from tqdm import tqdm

In [3]:
import torch
import torch.nn.functional as F
import collections
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
class ImageEnv(gym.Wrapper):
  def __init__(self,env,stack_frames=4,delay_op=50):
    super(ImageEnv, self).__init__(env)
    self.delay_op = delay_op
    self.stack_frames = stack_frames
  def reset(self):
    s, info = self.env.reset()
    for i in range(self.delay_op):
      s, r, terminated, truncated, info = self.env.step(0)
      s=s[:84, 6:90]/255.0
      self.stacked_state = np.tile( s , (self.stack_frames,1,1) )  # [4, 84, 84]
    return self.stacked_state, info

  def step(self, action):
    reward = 0
    for _ in range(self.stack_frames):
      s, r, terminated, truncated, info = self.env.step(action)
      if r==-100:terminated=True
      s=s[:84, 6:90]/255.0
      reward += r
      if terminated or truncated:break
      self.stacked_state = np.concatenate((self.stacked_state[1:], s[np.newaxis]), axis=0)
    return self.stacked_state, reward, terminated, truncated, info

# 建立Replay Buffer類別

In [5]:
class ReplayBuffer:
  def __init__(self,max_size=int(1e5), num_steps=1):
    self.s = np.zeros((max_size,4,84,84), dtype=np.float32)
    self.a = np.zeros((max_size,), dtype=np.int64)
    self.r = np.zeros((max_size, 1), dtype=np.float32)
    self.s_ = np.zeros((max_size,4,84,84), dtype=np.float32)
    self.done = np.zeros((max_size, 1), dtype=np.float32)
    self.ptr = 0
    self.size = 0
    self.max_size = max_size
    self.num_steps = num_steps

  def append(self,s,a,r,s_,done):
    self.s[self.ptr] = s
    self.a[self.ptr] = a
    self.r[self.ptr] = r
    self.s_[self.ptr] = s_
    self.done[self.ptr] = done
    self.ptr = (self.ptr + 1) % self.max_size
    self.size = min(self.size+1,self.max_size)
  def sample(self, batch_size):
    ind = np.random.randint(0, self.size, batch_size)
    return torch.FloatTensor(self.s[ind]),torch.LongTensor(self.a[ind]),torch.FloatTensor(self.r[ind]),torch.FloatTensor(self.s_[ind]),torch.FloatTensor(self.done[ind])

# 搭建DQN神經網路的類別

In [6]:
class DQN(torch.nn.Module):
  def __init__(self,n_act):
    super(DQN,self).__init__()
    self.conv1 = torch.nn.Conv2d(4, 16, kernel_size=8, stride=4)  #[N,4,84,84]->[N,16,20,20]
    self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=4, stride=2)  #[N,16,20,20]->[N,32,9,9]
    self.fc1 = torch.nn.Linear(32 * 9 * 9, 256)

    # 分兩路
    self.value1 = torch.nn.Linear(256, 128)
    self.value2 = torch.nn.Linear(128, 1)

    self.adv1 = torch.nn.Linear(256 , 64)
    self.adv2 = torch.nn.Linear(64, n_act)
  def forward(self,x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = x.view((-1, 32 * 9 * 9))
    x = F.relu(self.fc1(x))
    value = self.value2(F.relu(self.value1(x)))
    adv = self.adv2(F.relu(self.adv1(x)))
    return value + ( adv - adv.mean(dim=-1, keepdim=True))

# 設定是否載入模型參數，舊參數檔路徑，新參數檔路徑

In [7]:
Load_File=0
Old_File=f"Model-{Load_File}.pt"
if Load_File>0:
  Log= np.load(f"Log-{Load_File}.npy", allow_pickle=True).item()
else:
  Log={"TrainReward":[],"TestReward":[],"Loss":[]}

In [8]:
env=gym.make('CarRacing-v3',render_mode="rgb_array",domain_randomize=False, continuous=False)
env = gym_wrap.GrayscaleObservation(env)
env = ImageEnv(env)

# 搭建智能體Agent的類別

In [9]:
class DQNAgent():
  def __init__(self,gamma=0.9,eps_low=0.1,lr=0.00025):
    self.env = env
    self.n_act=self.env.action_space.n
    self.PredictDQN= DQN(self.n_act)
    self.TargetDQN= DQN(self.n_act)
    if Load_File>0:
      self.PredictDQN.load_state_dict(torch.load(Old_File))
      self.TargetDQN.load_state_dict(torch.load(Old_File))
    self.PredictDQN.to(device)
    self.TargetDQN.to(device)
    self.LossFun=torch.nn.SmoothL1Loss()
    self.optimizer=torch.optim.Adam(self.PredictDQN.parameters(),lr=lr)
    self.gamma=gamma
    self.eps_low=eps_low
    self.rb=ReplayBuffer(max_size=10000, num_steps=1)
  def PredictA(self,s):
    with torch.no_grad():
      return torch.argmax(self.PredictDQN(torch.FloatTensor(s).to(device))).item()
  def SelectA(self,a):
    return self.env.action_space.sample() if np.random.random()<self.EPS else a
  def Train(self,N_EPISODES):
    for i in tqdm(range(Load_File,N_EPISODES)):
      self.EPS=self.eps_low+(1-self.eps_low)*math.exp(-i*5/(N_EPISODES))
      total_reward=0
      s,_=self.env.reset()
      while True:
        a=self.SelectA(self.PredictA(s))
        s_,r,done,stop,_=self.env.step(a)
        self.rb.append(s,a,r,s_,done)
        if self.rb.size > 200 and i%self.rb.num_steps==0:self.Learn()
        if i % 20==0:  self.TargetDQN.load_state_dict(self.PredictDQN.state_dict())
        s=s_
        total_reward+=r
        if done or stop:break
      # print(f"\n{total_reward}")
      Log["TrainReward"].append(total_reward)
      if i % 10 == 9:
        test_reward=self.Test()
        print(f"\n訓練次數{i+1}，總回報{test_reward}")
        Log["TestReward"].append(test_reward)
        torch.save(self.PredictDQN.state_dict(), f"Model-{i+1}.pt")
        np.save(f"Log-{i+1}.npy", Log)
  def Learn(self):
    self.optimizer.zero_grad()
    batch_s, batch_a, batch_r, batch_s_, batch_done=self.rb.sample(32)
    predict_Q = (self.PredictDQN(batch_s.to(device))*F.one_hot(batch_a.long().to(device),self.n_act)).sum(1,keepdims=True)
    with torch.no_grad():
      target_Q = batch_r.to(device)+(1-batch_done.to(device))*self.gamma*self.TargetDQN(batch_s_.to(device)).max(1,keepdims=True)[0]
    loss = self.LossFun(predict_Q, target_Q)
    Log["Loss"].append(float(loss))
    loss.backward()
    self.optimizer.step()
  def Test(self,VIDEO=False):
    total_reward=0
    video=[]
    s,_=self.env.reset()
    while True:
      video.append(self.env.render())
      a=self.PredictA(s)
      s,r,done,stop,_=self.env.step(a)
      total_reward+=r
      if done or stop:break
    if VIDEO:
      patch = plt.imshow(video[0]) #產生展示圖形物件
      plt.axis('off') #關閉坐標軸
      def animate(i): #設定更換影格的函數
        patch.set_data(video[i])
        #plt.gcf()=>建新繪圖區 animate=>更換影格函數 frames=>影格數 interval=>影隔間距(毫秒)
      anim = animation.FuncAnimation(plt.gcf(),animate,frames=len(video),interval=200)
      anim.save('Car_Racing.mp4') #儲存為mp4擋
    return total_reward
  def Record(self):
    total_reward=0
    s,_=self.env.reset()
    while True:
      image=self.env.render()
      plt.imshow(image)
      #plt.imsave(f"/content/drive/MyDrive/recording/{str(int(time.time()))}.png", image)
      a=self.PredictA(s)
      s,r,done,stop,_=self.env.step(a)
      print(r)
      total_reward+=r
      plt.pause(0.1)
      #清除目前的顯示
      display.clear_output(wait=True)
      if done or stop:break
    print(total_reward)

In [10]:
Agent=DQNAgent(gamma=0.95,eps_low=0.05,lr=0.00025)
Agent.Train(N_EPISODES=1000)

  1%|          | 10/1000 [02:13<4:20:34, 15.79s/it]


訓練次數10，總回報-94.99999999999898


  2%|▏         | 20/1000 [04:21<3:47:44, 13.94s/it]


訓練次數20，總回報-69.06643356643399


  3%|▎         | 30/1000 [06:35<4:16:59, 15.90s/it]


訓練次數30，總回報-94.99999999999895


  4%|▍         | 40/1000 [08:40<3:33:37, 13.35s/it]


訓練次數40，總回報71.57012448132883


  5%|▌         | 50/1000 [10:54<4:12:46, 15.96s/it]


訓練次數50，總回報-94.99999999999898


  6%|▌         | 60/1000 [13:01<3:33:43, 13.64s/it]


訓練次數60，總回報-65.8029304029308


  7%|▋         | 70/1000 [15:17<4:08:26, 16.03s/it]


訓練次數70，總回報-94.99999999999899


  8%|▊         | 80/1000 [17:33<4:07:29, 16.14s/it]


訓練次數80，總回報-94.999999999999


  9%|▉         | 90/1000 [19:46<4:03:39, 16.07s/it]


訓練次數90，總回報-94.99999999999895


 10%|█         | 100/1000 [21:58<3:52:13, 15.48s/it]


訓練次數100，總回報-94.999999999999


 11%|█         | 110/1000 [24:08<3:48:37, 15.41s/it]


訓練次數110，總回報-94.99999999999903


 12%|█▏        | 120/1000 [26:11<3:12:57, 13.16s/it]


訓練次數120，總回報127.24276527331367


 13%|█▎        | 130/1000 [28:00<2:56:56, 12.20s/it]


訓練次數130，總回報-37.49090909090829


 14%|█▍        | 140/1000 [30:05<3:35:11, 15.01s/it]


訓練次數140，總回報-18.87543252595227


 15%|█▌        | 150/1000 [32:10<3:36:35, 15.29s/it]


訓練次數150，總回報-57.3287671232884


 16%|█▌        | 160/1000 [34:16<3:30:12, 15.02s/it]


訓練次數160，總回報18.970588235294905


 17%|█▋        | 170/1000 [36:20<3:28:27, 15.07s/it]


訓練次數170，總回報108.06513409962133


 18%|█▊        | 180/1000 [38:22<3:05:30, 13.57s/it]


訓練次數180，總回報24.16666666666822


 19%|█▉        | 190/1000 [40:26<3:22:49, 15.02s/it]


訓練次數190，總回報98.44262295082399


 20%|██        | 200/1000 [42:31<3:19:56, 15.00s/it]


訓練次數200，總回報130.58922558923032


 21%|██        | 210/1000 [44:30<3:10:37, 14.48s/it]


訓練次數210，總回報-22.27272727272798


 22%|██▏       | 220/1000 [46:38<3:20:41, 15.44s/it]


訓練次數220，總回報305.69686411149496


 23%|██▎       | 230/1000 [48:41<3:20:02, 15.59s/it]


訓練次數230，總回報68.93442622951254


 24%|██▍       | 240/1000 [50:51<3:15:19, 15.42s/it]


訓練次數240，總回報98.54838709677682


 25%|██▌       | 250/1000 [52:53<3:07:11, 14.97s/it]


訓練次數250，總回報498.4426229508107


 26%|██▌       | 260/1000 [54:59<3:06:59, 15.16s/it]


訓練次數260，總回報559.485049833876


 27%|██▋       | 270/1000 [57:06<3:06:19, 15.31s/it]


訓練次數270，總回報67.42038216560884


 28%|██▊       | 280/1000 [59:14<3:05:08, 15.43s/it]


訓練次數280，總回報115.00000000000414


 29%|██▉       | 290/1000 [1:01:20<3:01:33, 15.34s/it]


訓練次數290，總回報-94.99999999999898


 30%|███       | 300/1000 [1:03:29<3:02:27, 15.64s/it]


訓練次數300，總回報303.71382636655017


 31%|███       | 310/1000 [1:05:40<3:02:04, 15.83s/it]


訓練次數310，總回報504.99999999998823


 32%|███▏      | 320/1000 [1:07:47<2:55:23, 15.48s/it]


訓練次數320，總回報617.8712871287062


 33%|███▎      | 330/1000 [1:09:54<2:53:00, 15.49s/it]


訓練次數330，總回報272.00336700337107


 34%|███▍      | 340/1000 [1:12:01<2:50:33, 15.50s/it]


訓練次數340，總回報630.4901960784204


 35%|███▌      | 350/1000 [1:14:08<2:48:29, 15.55s/it]


訓練次數350，總回報637.3943661971723


 36%|███▌      | 360/1000 [1:16:15<2:44:03, 15.38s/it]


訓練次數360，總回報708.5714285714154


 37%|███▋      | 370/1000 [1:18:23<2:37:50, 15.03s/it]


訓練次數370，總回報71.66666666667135


 38%|███▊      | 380/1000 [1:20:28<2:35:47, 15.08s/it]


訓練次數380，總回報453.4949832775821


 39%|███▉      | 390/1000 [1:22:37<2:37:44, 15.52s/it]


訓練次數390，總回報259.23197492162285


 40%|████      | 400/1000 [1:24:45<2:34:44, 15.47s/it]


訓練次數400，總回報621.0493827160416


 41%|████      | 410/1000 [1:26:54<2:34:22, 15.70s/it]


訓練次數410，總回報529.242424242412


 42%|████▏     | 420/1000 [1:29:04<2:32:09, 15.74s/it]


訓練次數420，總回報243.19241982506566


 43%|████▎     | 430/1000 [1:31:12<2:25:53, 15.36s/it]


訓練次數430，總回報732.8388278388171


 44%|████▍     | 440/1000 [1:33:20<2:24:44, 15.51s/it]


訓練次數440，總回報584.4871794871717


 45%|████▌     | 450/1000 [1:35:31<2:22:55, 15.59s/it]


訓練次數450，總回報677.0588235293989


 46%|████▌     | 460/1000 [1:37:40<2:21:41, 15.74s/it]


訓練次數460，總回報589.2105263157781


 47%|████▋     | 470/1000 [1:39:50<2:16:43, 15.48s/it]


訓練次數470，總回報240.61643835616587


 48%|████▊     | 480/1000 [1:41:59<2:15:44, 15.66s/it]


訓練次數480，總回報324.45288753798434


 49%|████▉     | 490/1000 [1:44:02<2:12:32, 15.59s/it]


訓練次數490，總回報515.7142857142724


 50%|█████     | 500/1000 [1:46:11<2:09:57, 15.60s/it]


訓練次數500，總回報601.4856230031854


 51%|█████     | 510/1000 [1:48:20<2:06:24, 15.48s/it]


訓練次數510，總回報57.343750000004384


 52%|█████▏    | 520/1000 [1:50:29<2:06:22, 15.80s/it]


訓練次數520，總回報617.933753943204


 53%|█████▎    | 530/1000 [1:52:37<2:01:05, 15.46s/it]


訓練次數530，總回報462.69230769229705


 54%|█████▍    | 540/1000 [1:54:45<1:57:23, 15.31s/it]


訓練次數540，總回報235.09708737863386


 55%|█████▌    | 550/1000 [1:56:52<1:54:18, 15.24s/it]


訓練次數550，總回報337.7272727272583


 56%|█████▌    | 560/1000 [1:59:01<1:54:02, 15.55s/it]


訓練次數560，總回報756.8518518518362


 57%|█████▋    | 570/1000 [2:01:10<1:51:35, 15.57s/it]


訓練次數570，總回報464.13978494622324


 58%|█████▊    | 580/1000 [2:03:19<1:48:36, 15.52s/it]


訓練次數580，總回報371.25766871164524


 59%|█████▉    | 590/1000 [2:05:26<1:43:56, 15.21s/it]


訓練次數590，總回報408.5460992907675


 60%|██████    | 600/1000 [2:07:34<1:41:08, 15.17s/it]


訓練次數600，總回報906.5078014184303


 61%|██████    | 610/1000 [2:09:41<1:40:14, 15.42s/it]


訓練次數610，總回報804.0228013029152


 62%|██████▏   | 620/1000 [2:11:50<1:40:13, 15.83s/it]


訓練次數620，總回報372.26190476189566


 63%|██████▎   | 630/1000 [2:13:57<1:35:28, 15.48s/it]


訓練次數630，總回報730.2427184465881


 64%|██████▍   | 640/1000 [2:16:04<1:31:10, 15.19s/it]


訓練次數640，總回報810.8823529411663


 65%|██████▌   | 650/1000 [2:18:14<1:30:55, 15.59s/it]


訓練次數650，總回報660.1724137930881


 66%|██████▌   | 660/1000 [2:20:24<1:29:35, 15.81s/it]


訓練次數660，總回報859.3726235741261


 67%|██████▋   | 670/1000 [2:22:32<1:25:56, 15.63s/it]


訓練次數670，總回報728.9436619718157


 68%|██████▊   | 680/1000 [2:24:43<1:23:47, 15.71s/it]


訓練次數680，總回報608.2967032966912


 69%|██████▉   | 690/1000 [2:26:54<1:22:48, 16.03s/it]


訓練次數690，總回報487.32931726906304


 70%|███████   | 700/1000 [2:29:05<1:18:41, 15.74s/it]


訓練次數700，總回報687.6086956521649


 71%|███████   | 710/1000 [2:31:15<1:16:03, 15.74s/it]


訓練次數710，總回報440.94771241828784


 72%|███████▏  | 720/1000 [2:33:24<1:13:39, 15.78s/it]


訓練次數720，總回報385.44692737429074


 73%|███████▎  | 730/1000 [2:35:31<1:09:47, 15.51s/it]


訓練次數730，總回報695.4761904761759


 74%|███████▍  | 740/1000 [2:37:40<1:07:30, 15.58s/it]


訓練次數740，總回報339.5047923322564


 75%|███████▌  | 750/1000 [2:39:48<1:05:05, 15.62s/it]


訓練次數750，總回報832.0833333333157


 76%|███████▌  | 760/1000 [2:41:59<1:04:06, 16.03s/it]


訓練次數760，總回報710.280528052795


 77%|███████▋  | 770/1000 [2:44:14<1:02:40, 16.35s/it]


訓練次數770，總回報837.8859060402522


 78%|███████▊  | 780/1000 [2:46:25<58:33, 15.97s/it]


訓練次數780，總回報883.798586572423


 79%|███████▉  | 790/1000 [2:48:31<54:48, 15.66s/it]


訓練次數790，總回報842.9562043795453


 80%|████████  | 800/1000 [2:50:46<55:10, 16.55s/it]


訓練次數800，總回報857.2292993630393


 81%|████████  | 810/1000 [2:53:01<51:33, 16.28s/it]


訓練次數810，總回報886.4126394051904


 82%|████████▏ | 820/1000 [2:55:14<47:34, 15.86s/it]


訓練次數820，總回報802.4358974358855


 83%|████████▎ | 830/1000 [2:57:29<46:08, 16.28s/it]


訓練次數830，總回報823.2156133828861


 84%|████████▍ | 840/1000 [2:59:40<42:21, 15.88s/it]


訓練次數840，總回報447.16867469878355


 85%|████████▌ | 850/1000 [3:01:52<39:28, 15.79s/it]


訓練次數850，總回報-94.999999999999


 86%|████████▌ | 860/1000 [3:03:59<34:49, 14.92s/it]


訓練次數860，總回報854.6644295301849


 87%|████████▋ | 870/1000 [3:06:08<33:28, 15.45s/it]


訓練次數870，總回報905.6824561403365


 88%|████████▊ | 880/1000 [3:08:18<31:12, 15.61s/it]


訓練次數880，總回報880.8620689655015


 89%|████████▉ | 890/1000 [3:10:26<27:40, 15.10s/it]


訓練次數890，總回報914.4470588235152


 90%|█████████ | 900/1000 [3:12:40<26:49, 16.10s/it]


訓練次數900，總回報860.7823129251608


 91%|█████████ | 910/1000 [3:14:52<24:06, 16.07s/it]


訓練次數910，總回報555.8875739644858


 92%|█████████▏| 920/1000 [3:17:03<20:57, 15.72s/it]


訓練次數920，總回報-94.99999999999898


 93%|█████████▎| 930/1000 [3:19:10<17:26, 14.95s/it]


訓練次數930，總回報914.62727272726


 94%|█████████▍| 940/1000 [3:21:21<15:37, 15.62s/it]


訓練次數940，總回報842.4999999999857


 95%|█████████▌| 950/1000 [3:23:30<13:06, 15.73s/it]


訓練次數950，總回報823.819188191871


 96%|█████████▌| 960/1000 [3:25:40<10:16, 15.42s/it]


訓練次數960，總回報873.6411149825647


 97%|█████████▋| 970/1000 [3:27:52<07:50, 15.69s/it]


訓練次數970，總回報894.5470383275119


 98%|█████████▊| 980/1000 [3:30:04<05:16, 15.84s/it]


訓練次數980，總回報835.6569343065529


 99%|█████████▉| 990/1000 [3:32:16<02:38, 15.89s/it]


訓練次數990，總回報864.0443686006647


100%|██████████| 1000/1000 [3:34:28<00:00, 12.87s/it]


訓練次數1000，總回報888.8709677419273





In [11]:
Agent.Record()

649.3365695792743
