In [None]:
# !pip install gymnasium
# !pip install swig
# !pip install gymnasium[box2d]

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0
Collecting swig
  Downloading swig-4.3.0-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.5 kB)
Downloading swig-4.3.0-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.3.0
Collecti

# import 必要套件

In [None]:
import random
import time,math
import numpy as np
import gymnasium as gym
import gymnasium.wrappers as gym_wrap
import matplotlib.pyplot as plt
import matplotlib.animation as animation #輸出動畫影片
from IPython import display
from tqdm import tqdm

In [None]:
import torch
import torch.nn.functional as F
import collections
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

In [None]:
class ImageEnv(gym.Wrapper):
  """Turn 2-D grayscale observations into a stack of cropped, scaled frames.

  Each frame is cropped to ``[:84, 6:90]`` (84x84) and scaled from [0, 255]
  to [0, 1]. The wrapper keeps the most recent ``stack_frames`` frames as an
  array of shape ``(stack_frames, 84, 84)``.

  Args:
    env: environment to wrap; assumed to emit 2-D grayscale images (the
      notebook wraps with ``GrayscaleObservation`` first).
    stack_frames: number of frames in the stacked observation; also the number
      of underlying env steps each ``step`` call repeats the chosen action.
    delay_op: number of no-op steps (action 0) taken right after ``reset``.
  """

  def __init__(self, env, stack_frames=4, delay_op=50):
    super().__init__(env)
    self.delay_op = delay_op
    self.stack_frames = stack_frames

  @staticmethod
  def _preprocess(s):
    # Crop to 84x84 and scale pixel values into [0, 1].
    return s[:84, 6:90] / 255.0

  def reset(self, *, seed=None, options=None):
    """Reset the env, take ``delay_op`` no-op steps, and return the first stack.

    Args:
      seed, options: forwarded to the wrapped env's ``reset``.

    Returns:
      (stacked_state, info) where stacked_state repeats the last observed
      frame ``stack_frames`` times.
    """
    s, info = self.env.reset(seed=seed, options=options)
    for _ in range(self.delay_op):
      s, _, _, _, info = self.env.step(0)
    # Only the final frame matters, so preprocess once after the loop. This
    # also covers delay_op == 0, where the frame returned by reset() is used.
    self.stacked_state = np.tile(self._preprocess(s), (self.stack_frames, 1, 1))
    return self.stacked_state, info

  def step(self, action):
    """Repeat ``action`` up to ``stack_frames`` times, pushing each new frame.

    A per-step reward of exactly -100 is treated as termination (presumably
    CarRacing's off-playfield penalty). The loop stops early when the episode
    ends, but the frame that ended it is pushed first so the returned stack is
    the true post-transition state.

    Returns:
      (stacked_state, summed_reward, terminated, truncated, info)
    """
    reward = 0
    for _ in range(self.stack_frames):
      s, r, terminated, truncated, info = self.env.step(action)
      if r == -100:
        terminated = True
      reward += r
      # Push before checking for episode end so the final frame is included.
      self.stacked_state = np.concatenate(
          (self.stacked_state[1:], self._preprocess(s)[np.newaxis]), axis=0)
      if terminated or truncated:
        break
    return self.stacked_state, reward, terminated, truncated, info

# 建立Replay Buffer類別

In [None]:
class ReplayBuffer:
  """Fixed-capacity circular buffer of (s, a, r, s_, done) transitions.

  Storage is preallocated numpy arrays. Once ``max_size`` transitions are
  stored, new ones overwrite the oldest.

  Args:
    max_size: capacity in transitions. Both ``s`` and ``s_`` are float32
      arrays of ``state_shape``, so with the default shape each transition
      costs about 226 KB (e.g. ~2.3 GB for 10,000 transitions).
    num_steps: kept for callers (the agent reads it as its learning
      frequency); the buffer itself does not use it.
    state_shape: shape of one observation; defaults to (4, 84, 84).
  """

  def __init__(self, max_size=int(1e5), num_steps=1, state_shape=(4, 84, 84)):
    state_shape = tuple(state_shape)
    self.s = np.zeros((max_size, *state_shape), dtype=np.float32)
    self.a = np.zeros((max_size,), dtype=np.int64)
    self.r = np.zeros((max_size, 1), dtype=np.float32)
    self.s_ = np.zeros((max_size, *state_shape), dtype=np.float32)
    self.done = np.zeros((max_size, 1), dtype=np.float32)
    self.ptr = 0
    self.size = 0
    self.max_size = max_size
    self.num_steps = num_steps

  def append(self, s, a, r, s_, done):
    """Store one transition, overwriting the oldest entry when full."""
    self.s[self.ptr] = s
    self.a[self.ptr] = a
    self.r[self.ptr] = r
    self.s_[self.ptr] = s_
    self.done[self.ptr] = done
    self.ptr = (self.ptr + 1) % self.max_size
    self.size = min(self.size + 1, self.max_size)

  def sample(self, batch_size):
    """Sample ``batch_size`` stored transitions uniformly, with replacement.

    Returns:
      Tensors (s, a, r, s_, done) with dtypes float32, int64, float32,
      float32, float32; ``r`` and ``done`` have shape (batch_size, 1).

    Raises:
      ValueError: if the buffer is empty.
    """
    if self.size == 0:
      raise ValueError("cannot sample from an empty ReplayBuffer")
    ind = np.random.randint(0, self.size, batch_size)
    # Fancy indexing already returns fresh copies, so from_numpy can share
    # that memory instead of copying a second time.
    return (torch.from_numpy(self.s[ind]),
            torch.from_numpy(self.a[ind]),
            torch.from_numpy(self.r[ind]),
            torch.from_numpy(self.s_[ind]),
            torch.from_numpy(self.done[ind]))

# 搭建DQN神經網路的類別

In [None]:
class DQN(torch.nn.Module):
  """Convolutional Q-network mapping 4 stacked 84x84 frames to Q-values.

  Two conv layers feed a 256-unit ReLU hidden layer and a linear output layer
  with one Q-value per action. Parameter names match earlier versions of this
  class, so old checkpoints still load. Checkpoints trained without the
  hidden-layer ReLU will produce different Q-values and should be retrained.

  Args:
    n_act: number of discrete actions (size of the output layer).
  """

  def __init__(self, n_act):
    super().__init__()
    self.conv1 = torch.nn.Conv2d(4, 16, kernel_size=8, stride=4)   # [N,4,84,84] -> [N,16,20,20]
    self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=4, stride=2)  # [N,16,20,20] -> [N,32,9,9]
    self.fc1 = torch.nn.Linear(32 * 9 * 9, 256)
    self.fc2 = torch.nn.Linear(256, n_act)

  def forward(self, x):
    """Return Q-values of shape (N, n_act).

    ``x`` is (N, 4, 84, 84); a single unbatched (4, 84, 84) state is also
    accepted and yields a (1, n_act) result.
    """
    if x.dim() == 3:
      x = x.unsqueeze(0)
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    # Flatten per sample: a wrong spatial size then fails loudly in fc1
    # instead of silently regrouping values across the batch.
    x = torch.flatten(x, 1)
    # Nonlinearity between the fully connected layers; without it fc1 and fc2
    # compose into a single linear map.
    x = F.relu(self.fc1(x))
    return self.fc2(x)

# 設定是否載入模型參數與舊參數檔路徑，並載入或初始化訓練紀錄（新參數檔會在訓練時另存）

In [None]:
# Episode to resume from: 0 starts a fresh run; N > 0 loads Model-N.pt and Log-N.npy.
Load_File = 0
Old_File = f"Model-{Load_File}.pt"
# Training history: per-episode train rewards, periodic test rewards, per-update losses.
# Log-N.npy holds a dict written with np.save, so loading it needs allow_pickle=True;
# only load log files this notebook produced.
Log = (
  np.load(f"Log-{Load_File}.npy", allow_pickle=True).item()
  if Load_File > 0
  else {"TrainReward": [], "TestReward": [], "Loss": []}
)

In [None]:
# Discrete-action CarRacing (domain randomization off), rendered as RGB arrays,
# converted to grayscale, then cropped and frame-stacked by ImageEnv.
env = ImageEnv(
  gym_wrap.GrayscaleObservation(
    gym.make('CarRacing-v3', render_mode="rgb_array", domain_randomize=False, continuous=False)
  )
)

# 搭建智能體Agent的類別

In [None]:
class DQNAgent():
  """DQN agent with a periodically synced target network and experience replay.

  Reads the notebook-level ``env``, ``device``, ``Load_File``/``Old_File``
  (checkpoint resume) and ``Log`` (training history dict).

  Args:
    gamma: discount factor.
    eps_low: floor of the exponentially decaying epsilon-greedy schedule.
    lr: Adam learning rate.
    batch_size: minibatch size for each gradient step.
    warmup: learning starts once the buffer holds more than this many transitions.
    target_update: copy PredictDQN into TargetDQN once every this many episodes.
  """
  def __init__(self, gamma=0.9, eps_low=0.1, lr=0.00025,
               batch_size=32, warmup=200, target_update=20):
    self.env = env
    self.n_act = self.env.action_space.n
    self.PredictDQN = DQN(self.n_act)
    self.TargetDQN = DQN(self.n_act)
    if Load_File > 0:
      # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
      # torch.load unpickles the file, so only load checkpoints you trust.
      state = torch.load(Old_File, map_location=device)
      self.PredictDQN.load_state_dict(state)
      self.TargetDQN.load_state_dict(state)
    self.PredictDQN.to(device)
    self.TargetDQN.to(device)
    self.LossFun = torch.nn.SmoothL1Loss()
    self.optimizer = torch.optim.Adam(self.PredictDQN.parameters(), lr=lr)
    self.gamma = gamma
    self.eps_low = eps_low
    self.batch_size = batch_size
    self.warmup = warmup
    self.target_update = target_update
    # Fully exploratory until Train sets the schedule, so SelectA always works.
    self.EPS = 1.0
    self.rb = ReplayBuffer(max_size=10000, num_steps=1)

  def PredictA(self, s):
    """Return the greedy action index for a single stacked state ``s``."""
    with torch.no_grad():
      x = torch.as_tensor(s, dtype=torch.float32, device=device).unsqueeze(0)
      return torch.argmax(self.PredictDQN(x)).item()

  def SelectA(self, a):
    """Epsilon-greedy: a random action with probability EPS, otherwise ``a``."""
    return self.env.action_space.sample() if np.random.random() < self.EPS else a

  def Train(self, N_EPISODES):
    """Train on episodes ``Load_File .. N_EPISODES-1``.

    After every 10th episode (i % 10 == 9) a greedy Test episode is run, its
    return is printed and logged, and ``Model-{i+1}.pt`` / ``Log-{i+1}.npy``
    are saved.
    """
    step_count = 0
    for i in tqdm(range(Load_File, N_EPISODES)):
      # Epsilon decays exponentially from 1 toward eps_low over the run.
      self.EPS = self.eps_low + (1 - self.eps_low) * math.exp(-i * 5 / N_EPISODES)
      # Sync the target network once per target_update episodes, not on every
      # step of those episodes (which would disable the target network there).
      if i % self.target_update == 0:
        self.TargetDQN.load_state_dict(self.PredictDQN.state_dict())
      total_reward = 0
      s, _ = self.env.reset()
      while True:
        a = self.SelectA(self.PredictA(s))
        s_, r, done, stop, _ = self.env.step(a)
        # Only termination is stored as `done`; truncated episodes still bootstrap.
        self.rb.append(s, a, r, s_, done)
        step_count += 1
        # num_steps is the learning frequency in environment steps.
        if self.rb.size > self.warmup and step_count % self.rb.num_steps == 0:
          self.Learn()
        s = s_
        total_reward += r
        if done or stop:
          break
      Log["TrainReward"].append(total_reward)
      if i % 10 == 9:
        test_reward = self.Test()
        print(f"\n訓練次數{i+1}，總回報{test_reward}")
        Log["TestReward"].append(test_reward)
        torch.save(self.PredictDQN.state_dict(), f"Model-{i+1}.pt")
        np.save(f"Log-{i+1}.npy", Log)

  def Learn(self):
    """Take one optimizer step on a replay minibatch with the SmoothL1 TD loss."""
    batch_s, batch_a, batch_r, batch_s_, batch_done = self.rb.sample(self.batch_size)
    batch_s = batch_s.to(device)
    batch_a = batch_a.long().to(device)
    batch_r = batch_r.to(device)
    batch_s_ = batch_s_.to(device)
    batch_done = batch_done.to(device)
    # Q(s, a) for the actions actually taken, shape (B, 1).
    predict_Q = self.PredictDQN(batch_s).gather(1, batch_a.unsqueeze(1))
    with torch.no_grad():
      # r + gamma * max_a' Q_target(s_, a'), with the bootstrap zeroed at termination.
      next_Q = self.TargetDQN(batch_s_).max(1, keepdim=True)[0]
      target_Q = batch_r + (1 - batch_done) * self.gamma * next_Q
    loss = self.LossFun(predict_Q, target_Q)
    Log["Loss"].append(loss.item())
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

  def Test(self, VIDEO=False):
    """Play one greedy episode and return its total reward.

    When VIDEO is True, rendered frames are collected and written to
    'Car_Racing.mp4' (presumably needs an mp4-capable writer such as ffmpeg).
    """
    total_reward = 0
    video = []
    s, _ = self.env.reset()
    while True:
      if VIDEO:
        # Render only when recording; otherwise the frames are never used.
        video.append(self.env.render())
      a = self.PredictA(s)
      s, r, done, stop, _ = self.env.step(a)
      total_reward += r
      if done or stop:
        break
    if VIDEO:
      self._save_video(video)
    return total_reward

  @staticmethod
  def _save_video(video, path='Car_Racing.mp4'):
    # Animate a list of RGB frames and save it, closing the figure afterwards.
    fig = plt.figure()
    patch = plt.imshow(video[0])
    plt.axis('off')
    def animate(k):
      patch.set_data(video[k])
      return (patch,)
    # interval is the delay between frames in milliseconds.
    anim = animation.FuncAnimation(fig, animate, frames=len(video), interval=200)
    anim.save(path)
    plt.close(fig)

  def Record(self):
    """Play one greedy episode, showing each rendered frame inline, then print its return."""
    total_reward = 0
    s, _ = self.env.reset()
    while True:
      plt.imshow(self.env.render())
      a = self.PredictA(s)
      s, r, done, stop, _ = self.env.step(a)
      total_reward += r
      plt.pause(0.1)
      # Clear the previous frame from the cell output before the next one is drawn.
      display.clear_output(wait=True)
      if done or stop:
        break
    print(total_reward)

In [None]:
# Build the agent (resuming from Load_File when it is > 0) and train for 1000 episodes.
Agent = DQNAgent(lr=0.00025, gamma=0.95, eps_low=0.05)
Agent.Train(1000)

  1%|          | 10/1000 [02:18<4:38:57, 16.91s/it]


訓練次數10，總回報-94.99999999999898


  2%|▏         | 20/1000 [04:38<4:34:47, 16.82s/it]


訓練次數20，總回報-94.99999999999898


  3%|▎         | 30/1000 [07:01<4:35:22, 17.03s/it]


訓練次數30，總回報-94.99999999999898


  4%|▍         | 40/1000 [09:20<4:23:40, 16.48s/it]


訓練次數40，總回報-94.99999999999903


  5%|▌         | 50/1000 [11:37<4:19:14, 16.37s/it]


訓練次數50，總回報-94.99999999999903


  6%|▌         | 60/1000 [13:56<4:23:07, 16.80s/it]


訓練次數60，總回報-62.320261437909195


  7%|▋         | 70/1000 [16:16<4:17:55, 16.64s/it]


訓練次數70，總回報-27.384341637010817


  8%|▊         | 80/1000 [18:37<4:15:18, 16.65s/it]


訓練次數80，總回報-94.999999999999


  9%|▉         | 90/1000 [20:55<4:12:16, 16.63s/it]


訓練次數90，總回報-84.20863309352478


 10%|█         | 100/1000 [23:15<4:10:21, 16.69s/it]


訓練次數100，總回報-69.9103942652334


 11%|█         | 110/1000 [25:31<4:05:36, 16.56s/it]


訓練次數110，總回報-80.76512455515993


 12%|█▏        | 120/1000 [27:43<3:45:20, 15.36s/it]


訓練次數120，總回報-72.52808988764073


 13%|█▎        | 130/1000 [29:46<3:17:49, 13.64s/it]


訓練次數130，總回報-129.65849056603784


 14%|█▍        | 140/1000 [31:56<3:45:23, 15.73s/it]


訓練次數140，總回報-88.44262295081901


 15%|█▌        | 150/1000 [34:06<3:43:45, 15.80s/it]


訓練次數150，總回報-36.580756013746424


 16%|█▌        | 160/1000 [36:21<3:51:58, 16.57s/it]


訓練次數160，總回報-10.309446254072432


 17%|█▋        | 170/1000 [38:30<3:37:34, 15.73s/it]


訓練次數170，總回報-1.8965517241389034


 18%|█▊        | 180/1000 [40:41<3:39:24, 16.05s/it]


訓練次數180，總回報53.9361702127659


 19%|█▉        | 190/1000 [42:55<3:43:53, 16.58s/it]


訓練次數190，總回報-94.99999999999899


 20%|██        | 200/1000 [45:14<3:44:10, 16.81s/it]


訓練次數200，總回報-55.57347670250968


 21%|██        | 210/1000 [47:29<3:37:11, 16.50s/it]


訓練次數210，總回報-94.99999999999895


 22%|██▏       | 220/1000 [49:49<3:38:47, 16.83s/it]


訓練次數220，總回報-63.64111498257904


 23%|██▎       | 230/1000 [52:06<3:30:42, 16.42s/it]


訓練次數230，總回報-125.9057971014498


 24%|██▍       | 240/1000 [54:18<3:33:27, 16.85s/it]


訓練次數240，總回報-11.666666666667503


 25%|██▌       | 250/1000 [56:22<3:08:34, 15.09s/it]


訓練次數250，總回報178.04964539007517


 26%|██▌       | 260/1000 [58:41<3:27:41, 16.84s/it]


訓練次數260，總回報107.84697508897156


 27%|██▋       | 270/1000 [1:00:49<3:16:33, 16.16s/it]


訓練次數270，總回報10.431309904153213


 28%|██▊       | 280/1000 [1:02:54<3:16:15, 16.35s/it]


訓練次數280，總回報138.552631578949


 29%|██▉       | 290/1000 [1:05:06<3:09:38, 16.03s/it]


訓練次數290，總回報-94.99999999999895


 30%|███       | 300/1000 [1:07:15<3:06:23, 15.98s/it]


訓練次數300，總回報215.6796116504896


 31%|███       | 310/1000 [1:09:27<3:03:15, 15.94s/it]


訓練次數310，總回報515.1083032490864


 32%|███▏      | 320/1000 [1:11:41<3:02:54, 16.14s/it]


訓練次數320，總回報-94.99999999999898


 33%|███▎      | 330/1000 [1:13:54<3:00:50, 16.19s/it]


訓練次數330，總回報278.417721518984


 34%|███▍      | 340/1000 [1:16:07<2:55:43, 15.98s/it]


訓練次數340，總回報-94.999999999999


 35%|███▌      | 350/1000 [1:18:18<2:52:15, 15.90s/it]


訓練次數350，總回報84.2828685259008


 36%|███▌      | 360/1000 [1:20:30<2:49:50, 15.92s/it]


訓練次數360，總回報298.93939393938575


 37%|███▋      | 370/1000 [1:22:43<2:49:17, 16.12s/it]


訓練次數370，總回報-94.99999999999896


 38%|███▊      | 380/1000 [1:24:56<2:46:10, 16.08s/it]


訓練次數380，總回報536.1475409835956


 39%|███▉      | 390/1000 [1:27:09<2:44:38, 16.19s/it]


訓練次數390，總回報439.2019543973881


 40%|████      | 400/1000 [1:29:23<2:44:12, 16.42s/it]


訓練次數400，總回報281.1755485893355


 41%|████      | 410/1000 [1:31:35<2:35:46, 15.84s/it]


訓練次數410，總回報451.6237942122105


 42%|████▏     | 420/1000 [1:33:48<2:37:52, 16.33s/it]


訓練次數420，總回報500.46925566342304


 43%|████▎     | 430/1000 [1:36:01<2:33:07, 16.12s/it]


訓練次數430，總回報632.5985663082336


 44%|████▍     | 440/1000 [1:38:14<2:28:15, 15.89s/it]


訓練次數440，總回報335.2788844621435


 45%|████▌     | 450/1000 [1:40:27<2:28:07, 16.16s/it]


訓練次數450，總回報713.6642599277875


 46%|████▌     | 460/1000 [1:42:41<2:26:35, 16.29s/it]


訓練次數460，總回報153.32214765101028


 47%|████▋     | 470/1000 [1:44:55<2:23:00, 16.19s/it]


訓練次數470，總回報497.46575342465087


 48%|████▊     | 480/1000 [1:47:10<2:20:29, 16.21s/it]


訓練次數480，總回報520.3846153846075


 49%|████▉     | 490/1000 [1:49:23<2:17:11, 16.14s/it]


訓練次數490，總回報859.7169811320648


 50%|█████     | 500/1000 [1:51:46<2:21:53, 17.03s/it]


訓練次數500，總回報240.84905660377848


 51%|█████     | 510/1000 [1:54:08<2:19:58, 17.14s/it]


訓練次數510，總回報751.4285714285576


 52%|█████▏    | 520/1000 [1:56:30<2:14:38, 16.83s/it]


訓練次數520，總回報308.3517241379195


 53%|█████▎    | 530/1000 [1:58:53<2:14:37, 17.19s/it]


訓練次數530，總回報657.6132404181059


 54%|█████▍    | 540/1000 [2:01:15<2:10:05, 16.97s/it]


訓練次數540，總回報889.4961240309992


 55%|█████▌    | 550/1000 [2:03:37<2:07:34, 17.01s/it]


訓練次數550，總回報782.5510204081544


 56%|█████▌    | 560/1000 [2:06:00<2:07:12, 17.35s/it]


訓練次數560，總回報506.69491525422654


 57%|█████▋    | 570/1000 [2:08:22<2:02:10, 17.05s/it]


訓練次數570，總回報823.727915194332


 58%|█████▊    | 580/1000 [2:10:45<2:00:35, 17.23s/it]


訓練次數580，總回報654.1749174917402


 59%|█████▉    | 590/1000 [2:12:54<1:51:07, 16.26s/it]


訓練次數590，總回報514.1549295774632


 60%|██████    | 600/1000 [2:15:18<1:55:05, 17.26s/it]


訓練次數600，總回報683.7878787878635


 61%|██████    | 610/1000 [2:17:40<1:50:25, 16.99s/it]


訓練次數610，總回報450.81673306772313


 62%|██████▏   | 620/1000 [2:20:02<1:48:14, 17.09s/it]


訓練次數620，總回報515.5610561056046


 63%|██████▎   | 630/1000 [2:22:24<1:44:04, 16.88s/it]


訓練次數630，總回報639.3173431734216


 64%|██████▍   | 640/1000 [2:24:47<1:42:08, 17.02s/it]


訓練次數640，總回報640.507246376804


 65%|██████▌   | 650/1000 [2:27:05<1:37:29, 16.71s/it]


訓練次數650，總回報569.6706586826228


 66%|██████▌   | 660/1000 [2:29:27<1:36:30, 17.03s/it]


訓練次數660，總回報665.2523659305859


 67%|██████▋   | 670/1000 [2:31:50<1:34:31, 17.19s/it]


訓練次數670，總回報439.4262295081855


 68%|██████▊   | 680/1000 [2:34:14<1:31:00, 17.06s/it]


訓練次數680，總回報843.5964912280533


 69%|██████▉   | 690/1000 [2:36:37<1:28:05, 17.05s/it]


訓練次數690，總回報654.9999999999905


 70%|███████   | 700/1000 [2:38:56<1:19:15, 15.85s/it]


訓練次數700，總回報179.74046242774205


 71%|███████   | 710/1000 [2:41:18<1:22:14, 17.01s/it]


訓練次數710，總回報292.7551020408033


 72%|███████▏  | 720/1000 [2:43:38<1:19:27, 17.03s/it]


訓練次數720，總回報315.5960264900566


 73%|███████▎  | 730/1000 [2:45:58<1:15:39, 16.81s/it]


訓練次數730，總回報365.75085324230827


 74%|███████▍  | 740/1000 [2:48:21<1:14:11, 17.12s/it]


訓練次數740，總回報782.9527559055023


 75%|███████▌  | 750/1000 [2:50:42<1:09:26, 16.67s/it]


訓練次數750，總回報37.841328413287144


 76%|███████▌  | 760/1000 [2:53:03<1:07:11, 16.80s/it]


訓練次數760，總回報202.2136222910243


 77%|███████▋  | 770/1000 [2:55:26<1:05:10, 17.00s/it]


訓練次數770，總回報528.6933797909273


 78%|███████▊  | 780/1000 [2:57:51<1:05:12, 17.79s/it]


訓練次數780，總回報422.3611111111021


 79%|███████▉  | 790/1000 [3:00:13<58:57, 16.85s/it]


訓練次數790，總回報564.0163934426115


 80%|████████  | 800/1000 [3:02:34<56:57, 17.09s/it]


訓練次數800，總回報590.8974358974253


 81%|████████  | 810/1000 [3:04:56<54:15, 17.13s/it]


訓練次數810，總回報472.27272727271213


 82%|████████▏ | 820/1000 [3:07:18<50:33, 16.85s/it]


訓練次數820，總回報403.36065573769315


 83%|████████▎ | 830/1000 [3:09:39<47:38, 16.81s/it]


訓練次數830，總回報270.71627906975493


 84%|████████▍ | 840/1000 [3:12:00<44:58, 16.86s/it]


訓練次數840，總回報765.544217687065


 85%|████████▌ | 850/1000 [3:14:12<35:52, 14.35s/it]


訓練次數850，總回報22.505405405406265


 86%|████████▌ | 860/1000 [3:16:29<39:37, 16.98s/it]


訓練次數860，總回報667.2149837133416


 87%|████████▋ | 870/1000 [3:18:51<36:55, 17.04s/it]


訓練次數870，總回報503.0066445182596


 88%|████████▊ | 880/1000 [3:21:13<34:28, 17.24s/it]


訓練次數880，總回報592.0967741935383


 89%|████████▉ | 890/1000 [3:23:33<30:33, 16.66s/it]


訓練次數890，總回報300.70552147238084


 90%|█████████ | 900/1000 [3:25:52<27:45, 16.65s/it]


訓練次數900，總回報857.7272727272589


 91%|█████████ | 910/1000 [3:28:14<25:27, 16.97s/it]


訓練次數910，總回報894.9999999999907


 92%|█████████▏| 920/1000 [3:30:35<22:32, 16.91s/it]


訓練次數920，總回報331.38036809814764


 93%|█████████▎| 930/1000 [3:32:53<19:45, 16.93s/it]


訓練次數930，總回報445.8163265306001


 94%|█████████▍| 940/1000 [3:35:14<16:45, 16.76s/it]


訓練次數940，總回報862.8947368420885


 95%|█████████▌| 950/1000 [3:37:35<14:04, 16.89s/it]


訓練次數950，總回報879.9999999999843


 96%|█████████▌| 960/1000 [3:39:57<11:28, 17.22s/it]


訓練次數960，總回報565.1307189542347


 97%|█████████▋| 970/1000 [3:42:19<08:33, 17.12s/it]


訓練次數970，總回報745.148698884746


 98%|█████████▊| 980/1000 [3:44:43<05:43, 17.19s/it]


訓練次數980，總回報301.3414634146237


 99%|█████████▉| 990/1000 [3:47:05<02:49, 16.91s/it]


訓練次數990，總回報459.09836065572574


100%|██████████| 1000/1000 [3:49:27<00:00, 13.77s/it]


訓練次數1000，總回報264.6214511040881





In [None]:
# Watch one greedy episode rendered frame by frame, then print its total reward.
Agent.Record()