# You can use this file to train the model on a cloud platform.

## utils

In [4]:
import os
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional

from pygame import image as pyg_image
from pygame import mixer as pyg_mixer
from pygame import Rect
from pygame.transform import flip as img_flip
from pygame.transform import smoothscale


# Root directory of the project (the folder that contains "assets/").
# It can be overridden without editing this file by setting the
# FLAPPY_BIRD_AI_DIR environment variable, which is convenient when the code
# runs on a cloud machine; otherwise the original local default is used.
_BASE_DIR = os.environ.get("FLAPPY_BIRD_AI_DIR",
                           r"E:\MyCode\pythonCode\flappy_bird_ai")

# Folders holding the image and sound assets.
SPRITES_PATH = os.path.join(_BASE_DIR, "assets/images")
AUDIO_PATH = os.path.join(_BASE_DIR, "assets/sounds")

# Size (in pixels) the bird sprites are scaled to.
PLAYER_WIDTH = 30
PLAYER_HEIGHT = 25

# Size (in pixels) the background image is scaled to.
BACKGROUND_WIDTH = 360
BACKGROUND_HEIGHT = 450

# Size (in pixels) the pipe sprite is scaled to.
PIPE_WIDTH = PLAYER_WIDTH
PIPE_HEIGHT = int(BACKGROUND_HEIGHT * 0.7)

class Utils:
    """ Stateless helpers for pixel-level collision tests and asset loading.

    Every helper is a static method, so the class works both through the
    module-level ``utils`` instance and directly as ``Utils.<helper>()``.
    """

    def __init__(self):
        # Nothing to initialise. (This was previously misspelled `__int__`,
        # which defined an integer-conversion hook instead of a constructor.)
        pass

    @staticmethod
    def pixel_collision(rect1: Rect,
                        rect2: Rect,
                        hitmask1: List[List[bool]],
                        hitmask2: List[List[bool]]) -> bool:
        """ Checks if two objects collide and not just their rects.

        Args:
            rect1: Bounding rect of the first object.
            rect2: Bounding rect of the second object.
            hitmask1: Mask of the first object, indexed as ``[x][y]``.
            hitmask2: Mask of the second object, indexed as ``[x][y]``.

        Returns:
            `True` if at least one pixel inside the overlap of both rects is
            set in both hitmasks, `False` otherwise.
        """
        overlap = rect1.clip(rect2)

        if overlap.width == 0 or overlap.height == 0:
            return False

        # Offsets of the overlapping area inside each object's own mask.
        x1, y1 = overlap.x - rect1.x, overlap.y - rect1.y
        x2, y2 = overlap.x - rect2.x, overlap.y - rect2.y

        return any(
            hitmask1[x1 + x][y1 + y] and hitmask2[x2 + x][y2 + y]
            for x in range(overlap.width)
            for y in range(overlap.height)
        )

    @staticmethod
    def get_hitmask(image) -> List[List[bool]]:
        """ Returns a hitmask using an image's alpha.

        The mask is indexed as ``mask[x][y]`` and is `True` wherever the
        pixel's alpha component (index 3 of ``image.get_at``) is non-zero.
        """
        return [
            [bool(image.get_at((x, y))[3]) for y in range(image.get_height())]
            for x in range(image.get_width())
        ]

    @staticmethod
    def _load_sprite(filename, convert, alpha=True):
        """ Loads `filename` from SPRITES_PATH, optionally converting it.

        Args:
            filename: File name of the sprite inside SPRITES_PATH.
            convert: Whether to convert the loaded image at all.
            alpha: When converting, use ``convert_alpha()`` if `True` and
                ``convert()`` otherwise.
        """
        img = pyg_image.load(os.path.join(SPRITES_PATH, filename))
        if not convert:
            return img
        return img.convert_alpha() if alpha else img.convert()

    @staticmethod
    def _load_scaled(filename, size, convert, alpha=True):
        """ Loads a sprite and smooth-scales it to `size` (width, height). """
        return smoothscale(
            Utils._load_sprite(filename, convert=convert, alpha=alpha), size)

    @staticmethod
    def load_images(convert: bool = True,
                    bg_type: Optional[str] = None) -> Dict[str, Any]:
        """ Loads and returns the image assets of the game.

        The sprites are scaled to the fixed sizes declared at the top of this
        file, which are chosen to match the game logic.

        Args:
            convert: Whether to convert the images after loading them.
            bg_type: Name (without ".png") of the background image, or `None`
                to load no background.

        Returns:
            A dict with the keys "background" (an image or `None`), "player"
            (tuple of the up/middle/down bird images) and "pipe" (tuple of the
            upper, vertically flipped pipe and the lower pipe).

        Raises:
            FileNotFoundError: If a sprite file cannot be found.
        """
        images = {}

        try:
            if bg_type is None:
                images["background"] = None
            else:
                images["background"] = Utils._load_scaled(
                    f"{bg_type}.png", (BACKGROUND_WIDTH, BACKGROUND_HEIGHT),
                    convert=convert, alpha=False)

            # Bird sprites:
            images["player"] = tuple(
                Utils._load_scaled(name, (PLAYER_WIDTH, PLAYER_HEIGHT),
                                   convert=convert, alpha=True)
                for name in ("bird_up.png", "bird_middle.png", "bird_down.png")
            )

            # Pipe sprites:
            pipe_sprite = Utils._load_scaled("pipe.png",
                                             (PIPE_WIDTH, PIPE_HEIGHT),
                                             convert=convert, alpha=True)
            images["pipe"] = (img_flip(pipe_sprite, False, True),
                              pipe_sprite)  # up_pipe and low_pipe
        except FileNotFoundError as ex:
            raise FileNotFoundError("Can't find the sprites folder! No such file or"
                                    f" directory: {SPRITES_PATH}") from ex
        return images

    @staticmethod
    def load_sounds() -> Dict[str, pyg_mixer.Sound]:
        """ Loads and returns the audio assets of the game.

        Initialises the pygame mixer before loading.

        Returns:
            A dict mapping each sound name to its loaded sound object.

        Raises:
            FileNotFoundError: If a sound file cannot be found.
        """
        pyg_mixer.init()
        sound_files = {
            "game_over": "game_over.wav",
            "hit": "hit.wav",
            "score": "score.mp3",
            "jump": "jump.wav",
            "btn_click": "btn_click.wav",
            "main_theme": "main_theme.mp3",
            "world_clear": "world_clear.wav",
        }
        sounds = {}
        try:
            for name, filename in sound_files.items():
                sounds[name] = pyg_mixer.Sound(
                    os.path.join(AUDIO_PATH, filename))
        except FileNotFoundError as ex:
            raise FileNotFoundError("Can't find the audio folder! No such file or "
                                    f"directory: {AUDIO_PATH}") from ex

        return sounds


# Shared helper instance used by the renderer.
utils = Utils()

## game_logic

In [5]:
import random
from enum import IntEnum
from itertools import cycle
from typing import Dict, Tuple, Union

import pygame

############################ Speed and Acceleration ############################
PIPE_VEL_X = -4  # horizontal pipe movement per update (negative = leftwards)

PLAYER_MAX_VEL_Y = 10  # max vel along Y, max descend speed
PLAYER_MIN_VEL_Y = -8  # min vel along Y, max ascend speed

PLAYER_ACC_Y = 1  # players downward acceleration
PLAYER_VEL_ROT = 5  # angular speed

PLAYER_FLAP_ACC = -9  # players speed on flapping
################################################################################


################################## Dimensions ##################################
PLAYER_WIDTH = 30
PLAYER_HEIGHT = 25

BACKGROUND_WIDTH = 360
BACKGROUND_HEIGHT = 450

PIPE_WIDTH = PLAYER_WIDTH
PIPE_HEIGHT = int(BACKGROUND_HEIGHT * 0.7)
PIPE_DISTANCE = 120  # horizontal distance between consecutive pipes
# Derived from PIPE_DISTANCE (instead of repeating the literal 120) so the two
# values cannot drift apart.
NUM_PIPE_ON_SCREEN = BACKGROUND_WIDTH // PIPE_DISTANCE


################################################################################


class FlappyBirdLogic:
    """ Pure game logic of Flappy Bird (no rendering).

    Keeps the player's position, velocity and rotation, the pipe positions and
    the score, and advances them one step at a time in `update_state()`.

    Args:
        screen_size (Tuple[int, int]): The screen's width and height.
        pipe_gap_size (int): Vertical gap between an upper and a lower pipe.
    """

    class Actions(IntEnum):
        """ Possible actions for the player to take. """
        IDLE, FLAP = 0, 1

    def __init__(self,
                 screen_size: Tuple[int, int],
                 pipe_gap_size: int = 100) -> None:
        self._screen_width = screen_size[0]
        self._screen_height = screen_size[1]

        self.player_x = int(self._screen_width * 0.2)
        self.player_y = int((self._screen_height - PLAYER_HEIGHT) / 2)

        self.score = 0
        self._pipe_gap_size = pipe_gap_size

        # Initial pipes: the first one starts at the right edge of the screen
        # and each following one PIPE_DISTANCE further to the right. Upper and
        # lower pipes are kept in two parallel lists of {"x", "y"} dicts.
        self.upper_pipes = []
        self.lower_pipes = []
        for i in range(NUM_PIPE_ON_SCREEN + 1):
            upper, lower = self._get_random_pipe()
            pipe_x = self._screen_width + i * PIPE_DISTANCE
            self.upper_pipes.append({"x": pipe_x, "y": upper["y"]})
            self.lower_pipes.append({"x": pipe_x, "y": lower["y"]})

        # Player's info:
        self.player_vel_y = -9  # player's velocity along Y
        self.player_rot = 45  # player's rotation

        self.last_action = None
        self.sound_cache = None

        self._player_flapped = False
        self.player_idx = 0
        self._player_idx_gen = cycle([0, 1, 2, 1])
        self._loop_iter = 0

    def _get_random_pipe(self) -> List[Dict[str, int]]:
        """ Returns a randomly generated pipe pair.

        Returns:
            A two-element list ``[upper_pipe, lower_pipe]``; each element is a
            dict with the keys "x" and "y". Both share the default spawn x of
            ``screen_width + PIPE_DISTANCE``.
        """
        # y of gap between upper and lower pipe
        gap_y = random.randrange(0,
                                 int(self._screen_height * 0.6 - self._pipe_gap_size))
        gap_y += int(self._screen_height * 0.2)

        pipe_x = self._screen_width + PIPE_DISTANCE
        return [
            {"x": pipe_x, "y": gap_y - PIPE_HEIGHT},  # upper pipe
            {"x": pipe_x, "y": gap_y + self._pipe_gap_size},  # lower pipe
        ]

    def check_crash(self) -> bool:
        """ Returns True if player collides with the ground (base), the
        ceiling or a pipe.
        """
        # Hit the floor or the ceiling.
        if (self.player_y + PLAYER_HEIGHT >= self._screen_height
                or self.player_y < -1):
            return True

        player_rect = pygame.Rect(self.player_x, self.player_y,
                                  PLAYER_WIDTH, PLAYER_HEIGHT)

        # Rect-vs-rect test against every upper and lower pipe.
        for up_pipe, low_pipe in zip(self.upper_pipes, self.lower_pipes):
            up_pipe_rect = pygame.Rect(up_pipe['x'], up_pipe['y'],
                                       PIPE_WIDTH, PIPE_HEIGHT)
            low_pipe_rect = pygame.Rect(low_pipe['x'], low_pipe['y'],
                                        PIPE_WIDTH, PIPE_HEIGHT)

            if (player_rect.colliderect(up_pipe_rect)
                    or player_rect.colliderect(low_pipe_rect)):
                return True

        return False

    def update_state(self, action: Union[Actions, int]) -> bool:
        """ Given an action taken by the player, updates the game's state.

        Updates, among others, `sound_cache` (name of the sound to play for
        this step, or `None`), `player_y`, `score` and the pipe lists.

        Args:
            action (Union[FlappyBirdLogic.Actions, int]): The action taken by
                the player.

        Returns:
            `True` if the player is alive and `False` otherwise.
        """
        self.sound_cache = None  # name of the sound to play for this step
        if action == FlappyBirdLogic.Actions.FLAP:
            if self.player_y > -2 * PLAYER_HEIGHT:
                # Flapping immediately switches to the upward velocity.
                self.player_vel_y = PLAYER_FLAP_ACC
                self._player_flapped = True
                self.sound_cache = "jump"

        self.last_action = action
        # The crash test runs on the positions from the previous step, before
        # anything is moved.
        if self.check_crash():
            self.sound_cache = "hit"
            return False  # die

        # Check for score: has the bird's centre passed a pipe's centre?
        # The half-open interval is one pipe step wide, so every pipe pair
        # scores exactly once.
        player_mid_pos = self.player_x + PLAYER_WIDTH / 2
        for pipe in self.upper_pipes:
            pipe_mid_pos = pipe['x'] + PIPE_WIDTH / 2
            if pipe_mid_pos <= player_mid_pos < pipe_mid_pos + (-PIPE_VEL_X):
                self.score += 1
                self.sound_cache = "score"

        # player_index change (wing-flapping animation frame)
        if (self._loop_iter + 1) % 3 == 0:
            self.player_idx = next(self._player_idx_gen)

        self._loop_iter = (self._loop_iter + 1) % 30

        # rotate the player
        if self.player_rot > -70:
            self.player_rot -= PLAYER_VEL_ROT

        # player's movement: accelerate downwards unless a flap just happened
        if self.player_vel_y < PLAYER_MAX_VEL_Y and not self._player_flapped:
            self.player_vel_y += PLAYER_ACC_Y

        if self._player_flapped:
            self._player_flapped = False

            # more rotation to cover the threshold
            # (calculated in visible rotation)
            self.player_rot = 45

        # The player cannot fall below the floor.
        self.player_y += min(self.player_vel_y,
                             self._screen_height - self.player_y - PLAYER_HEIGHT)

        # move pipes to left
        for up_pipe, low_pipe in zip(self.upper_pipes, self.lower_pipes):
            up_pipe['x'] += PIPE_VEL_X
            low_pipe['x'] += PIPE_VEL_X

        # Recycle the first pipe pair once it has left the screen.
        if self.upper_pipes and self.upper_pipes[0]['x'] < -PIPE_WIDTH:
            self.upper_pipes.pop(0)
            self.lower_pipes.pop(0)
            new_upper, new_lower = self._get_random_pipe()
            if self.upper_pipes:
                # Spawn exactly PIPE_DISTANCE behind the last pipe. Using the
                # fixed default x (screen_width + PIPE_DISTANCE) left a wider
                # gap than PIPE_DISTANCE after some recycles, because the last
                # pipe is not at the screen's right edge at this moment.
                spawn_x = self.upper_pipes[-1]['x'] + PIPE_DISTANCE
                new_upper['x'] = spawn_x
                new_lower['x'] = spawn_x
            self.upper_pipes.append(new_upper)
            self.lower_pipes.append(new_lower)

        return True

## renderer

In [6]:
#: Player's rotation threshold. The renderer never draws the player rotated
#: by more than this value, even when the logical rotation is larger.
PLAYER_ROT_THR = 20

#: Color to fill the surface's background when no background image was loaded.
FILL_BACKGROUND_COLOR = (200, 200, 200)
# The font module is initialised before the font object is created.
pygame.font.init()
# NOTE(review): "fangsong" is looked up as a system font by name; presumably
# pygame substitutes another font when it is not installed - confirm on the
# target machine.
FONT_NAME = "fangsong"
#: Bold, size-20 font used to draw the score.
FONT = pygame.font.SysFont(FONT_NAME, 20)
FONT.bold = True

class FlappyBirdRenderer:
    """ Handles the rendering of the game.

    This class implements the game's renderer, responsible from drawing the game
    on the screen.

    Args:
        screen_size (Tuple[int, int]): The screen's width and height.
        audio_on (bool): Whether the game's audio is ON or OFF.
        background (str): Type of background image.
    """

    def __init__(self,
                 screen_size: Tuple[int, int] = (288, 512),
                 audio_on: bool = True,
                 background: Optional[str] = "day") -> None:
        self._screen_width = screen_size[0]
        self._screen_height = screen_size[1]

        self.display = None
        self.surface = pygame.Surface(screen_size)
        # Images are loaded unconverted; they are converted in make_display(),
        # after a display has been created.
        self.images = utils.load_images(convert=False,
                                        bg_type=background)
        self.audio_on = audio_on
        self._audio_queue = []
        # Always define `sounds`, so that switching `audio_on` on later does
        # not fail with an AttributeError; the sounds are then loaded lazily
        # in update_display().
        self.sounds = utils.load_sounds() if audio_on else {}

        self.game = None  # the FlappyBirdLogic instance to draw
        self._clock = pygame.time.Clock()  # FPS

    def make_display(self) -> None:
        """ Initializes the pygame's display.

        Required for drawing images on the screen. All loaded images are
        converted once the display exists.
        """
        self.display = pygame.display.set_mode((self._screen_width,
                                                self._screen_height))
        for name, value in self.images.items():
            if value is None:
                continue

            if isinstance(value, (tuple, list)):
                # Groups of sprites (player frames, pipe pair) become tuples
                # of converted images.
                self.images[name] = tuple(img.convert_alpha()
                                          for img in value)
            else:
                self.images[name] = (value.convert() if name == "background"
                                     else value.convert_alpha())

    def _draw_score(self) -> None:
        """ Draws the score at the top of the surface, starting at its
        horizontal center. """
        score_surface = FONT.render(f"{self.game.score}", True, (0, 0, 0))
        self.surface.blit(score_surface, (self._screen_width // 2, 10))

    def draw_surface(self, show_score: bool = True) -> None:
        """ Re-draws the renderer's surface.

        This method updates the renderer's surface by re-drawing it according to
        the current state of the game.

        Args:
            show_score (bool): Whether to draw the player's score or not.

        Raises:
            ValueError: If no game logic has been assigned to `self.game`.
        """
        if self.game is None:
            raise ValueError("A game logic must be assigned to the renderer!")

        # Background
        if self.images['background'] is not None:
            self.surface.blit(self.images['background'], (0, 0))
        else:
            self.surface.fill(FILL_BACKGROUND_COLOR)

        # Pipes
        for up_pipe, low_pipe in zip(self.game.upper_pipes,
                                     self.game.lower_pipes):
            self.surface.blit(self.images['pipe'][0],
                              (up_pipe['x'], up_pipe['y']))
            self.surface.blit(self.images['pipe'][1],
                              (low_pipe['x'], low_pipe['y']))

        # Score
        # (must be drawn before the player, so the player overlaps it)
        if show_score:
            self._draw_score()

        # Getting player's rotation (capped at PLAYER_ROT_THR)
        visible_rot = min(self.game.player_rot, PLAYER_ROT_THR)

        # Player
        player_surface = pygame.transform.rotate(
            self.images['player'][self.game.player_idx],
            visible_rot,
        )

        self.surface.blit(player_surface, (self.game.player_x,
                                           self.game.player_y))

    def update_display(self) -> None:
        """ Updates the display with the current surface of the renderer.

        A call to this method is usually preceded by a call to
        :meth:`.draw_surface()`. This method simply updates the display by
        showing the current state of the renderer's surface on it, it doesn't
        make any change to the surface. If audio is on, the sound named by the
        game's `sound_cache` is played as well.

        Raises:
            RuntimeError: If no display has been created yet.
        """
        if self.display is None:
            raise RuntimeError(
                "Tried to update the display, but a display hasn't been "
                "created yet! To create a display for the renderer, you must "
                "call the `make_display()` method."
            )

        self.display.blit(self.surface, [0, 0])
        pygame.display.update()

        # Sounds:
        if self.audio_on and self.game is not None:
            sound_name = self.game.sound_cache
            if sound_name is not None:
                if not self.sounds:
                    # Audio was switched on after construction.
                    self.sounds = utils.load_sounds()
                self.sounds[sound_name].play()


## flappy_bird_env

In [7]:
from typing import Dict, Tuple, Optional, Union

import gym
import numpy as np
import pygame

class FlappyBirdEnv(gym.Env):  # custom env using gym
    """ Gym environment for Flappy Bird whose observations are screen images.

    Args:
        screen_size (Tuple[int, int]): The screen's width and height.
        pipe_gap (int): Vertical gap between an upper and a lower pipe.
        background (Optional[str]): Type of background image, or `None` to
            use a plain fill color.
        audio_on (bool): Whether the renderer loads and plays sounds. Pass
            `False` where no audio assets or audio device are available.
    """

    # "human" and "rgb_array" only differ in what render() returns.
    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self,
                 screen_size: Tuple[int, int] = (360, 450),
                 pipe_gap: int = 100,
                 background: Optional[str] = None,
                 audio_on: bool = True,
                 ):
        self.action_space = gym.spaces.Discrete(2)
        # Observations are 8-bit images, so the space is declared as uint8
        # rather than left at the float default.
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(*screen_size, 3),
                                                dtype=np.uint8)
        self._screen_size = screen_size
        self._pipe_gap = pipe_gap

        self._game = None  # FlappyBirdLogic instance, created by reset()
        self._renderer = FlappyBirdRenderer(screen_size=self._screen_size,
                                            audio_on=audio_on,
                                            background=background)

    def _get_observation(self):
        """ Draws the current frame (without score) and returns it as an
        image array. """
        self._renderer.draw_surface(show_score=False)
        return pygame.surfarray.array3d(self._renderer.surface)

    def reset(self):
        """ Resets the environment (starts a new game).

        Returns:
            The first observation (a screenshot of the new game). No info
            dict is returned.
        """
        self._game = FlappyBirdLogic(screen_size=self._screen_size,
                                     pipe_gap_size=self._pipe_gap)

        self._renderer.game = self._game
        return self._get_observation()

    def step(self,
             action: Union[FlappyBirdLogic.Actions, int],
             ) -> Tuple[np.ndarray, float, bool, Dict]:
        """ Advances the game by one step.

        :param
            action(Union[FlappyBirdLogic.Actions, int]): The action taken by
                the agent. Zero (0) means "do nothing" and one (1) means "flap".
        :return:
         A tuple containing, respectively:

                * an observation (RGB-array representing the game's screen);
                * a reward (-2.8 when the player dies, 2.8 when a pipe is
                  passed on this step, 0.1 otherwise);
                * a status report (`True` if the game is over and `False`
                  otherwise);
                * an info dictionary with the current "score".
        :raises RuntimeError: if called before `reset()`.
        """
        if self._game is None:
            raise RuntimeError("reset() must be called before step()!")

        alive = self._game.update_state(action)
        obs = self._get_observation()

        # Custom reward shaping; you may need to redefine it for your own game.
        done = not alive
        if done:
            reward = -2.8
        elif self._game.sound_cache == 'score':
            reward = 2.8
        else:
            reward = 0.1

        info = {"score": self._game.score}

        return obs, reward, done, info

    def render(self, mode="human") -> Optional[np.ndarray]:
        """
        If ``mode`` is:

            - human: render to the current display. Usually for human
              consumption.
            - rgb_array: Return an numpy.ndarray with shape (x, y, 3),
              representing RGB values for an x-by-y pixel image, suitable
              for turning into a video.
        :return:
            `None` if ``mode`` is "human" and a numpy.ndarray with RGB values if
            it's "rgb_array"
        :raises ValueError: if ``mode`` is not a supported render mode.
        """
        if mode not in FlappyBirdEnv.metadata["render.modes"]:
            raise ValueError("Invalid render mode!")

        self._renderer.draw_surface(show_score=True)
        if mode == "rgb_array":
            # Not strictly needed: step() already returns the observation.
            return pygame.surfarray.array3d(self._renderer.surface)

        # The display is created lazily on the first "human" render.
        if self._renderer.display is None:
            self._renderer.make_display()

        self._renderer.update_display()
        return None

    def close(self):
        """ Closes the environment. """
        if self._renderer is not None:
            pygame.display.quit()
            self._renderer = None
        super().close()

## DQN

In [8]:
import torch.nn as nn
from collections import deque, namedtuple

import random


# One replay-buffer entry: the state, the action taken in it, the resulting
# next state and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory(object):
    """Bounded buffer of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity):
        # A deque with `maxlen` discards from the left once it is full.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        """Build a Transition from `args` and store it."""
        entry = Transition(*args)
        self.memory.append(entry)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored entries chosen at random."""
        return random.sample(self.memory, k=batch_size)

class DQN(nn.Module):
    """
        Convolutional Q-network: three conv layers followed by two linear
        layers, producing one value per action.

        input: shape(batch_size, 4,84,84)
    """
    def __init__(self, frame_num=4, n_actions=2):
        super().__init__()
        # Number of outputs of the final layer.
        self.number_of_actions = n_actions

        # Spatial sizes for an 84x84 input are given on the right.
        self.conv1 = nn.Conv2d(frame_num, 32, kernel_size=8, stride=4)  # (84-8)/4+1 = 20
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # (20-4)/2+1 = 9
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # (9-3)/1+1 = 7
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(3136, 512)  # 7x7x64 = 3136
        self.relu4 = nn.ReLU()
        self.fc5 = nn.Linear(512, self.number_of_actions)

    def forward(self, x):
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        features = self.relu3(self.conv3(features))
        # Flatten everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        hidden = self.relu4(self.fc4(flat))
        return self.fc5(hidden)


## model_train

In [9]:
import cv2
import numpy as np
import torch

class MyToolFunc:
    """Static image pre-processing helpers for the training code."""

    @staticmethod
    def resize_and_bgr2bin(image):
        """Resize `image` to 84x84, convert to gray and binarise (0/1).

        With the inverted threshold, gray values above 199 become 0 and all
        others become 1.
        """
        # NOTE(review): the conversion treats the input as BGR, while the
        # callers describe the frames as RGB - confirm the channel order.
        small = cv2.resize(image, (84, 84))
        gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 199, 1, cv2.THRESH_BINARY_INV)
        return binary

    @staticmethod
    def image_to_tensor(image):
        """Turn an image array into a float32 tensor with its last axis first."""
        channels_first = image.transpose(2, 0, 1).astype(np.float32)
        tensor = torch.from_numpy(channels_first)
        # put on GPU if CUDA is available
        return tensor.cuda() if torch.cuda.is_available() else tensor

    @staticmethod
    def process_state(state):
        """Binarise an rgb frame and add a leading axis of size one."""
        binary = MyToolFunc.resize_and_bgr2bin(state)
        return binary[np.newaxis, ...]


my_tool_func = MyToolFunc()

In [10]:
import torch
import torch.optim as optim
import torch.nn as nn
from itertools import count
import matplotlib
import matplotlib.pyplot as plt
import os
import random
import math

# set up matplotlib
# True when the active matplotlib backend's name contains "inline".
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # Only needed for showing/clearing figures in plot_durations().
    from IPython import display
# Interactive mode on.
plt.ion()

<contextlib.ExitStack at 0x1a026beba30>

In [11]:
def init_weights(m):
    """Initialise the weights of every Conv2d/Linear layer inside `m`.

    Weights are drawn uniformly from [-0.01, 0.01] and biases (when present)
    are set to 0.01. `m` may be a single layer or a whole network, so both
    ``init_weights(net)`` and ``net.apply(init_weights)`` initialise all
    layers; previously a direct call on a network did nothing because only
    the type of `m` itself was tested.

    Args:
        m (nn.Module): The layer or network to initialise in place.
    """
    # modules() yields `m` itself as well as all of its sub-modules.
    for layer in m.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            torch.nn.init.uniform_(layer.weight, -0.01, 0.01)
            if layer.bias is not None:
                layer.bias.data.fill_(0.01)

In [12]:
RE_TRAIN_FLAG = False  # False then use the existed model to continue training
BATCH_SIZE = 32  # number of transitions sampled from the replay memory per update
GAMMA = 0.99  # discount factor applied to the next state's value
EPS_START = 0.9  # exploration threshold at step 0
EPS_END = 0.01  # value the exploration threshold decays towards
EPS_DECAY = 3000  # decay constant (in steps) of the exponential decay
TAU = 0.005  # weight of the policy network in the target network's soft update
LR = 1e-4  # learning rate
# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = FlappyBirdEnv()
# Get number of actions from gym action space
n_actions = env.action_space.n
# Number of input channels of the networks
n_observations = 4  # each time use four frames

# Two networks with the same architecture: the policy net is optimised, the
# target net is used to compute the next-state values.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

# Global count of actions selected so far; drives the exploration decay.
steps_done = 0

In [13]:
def select_action(state):
    """Pick an action for `state` with an epsilon-greedy rule.

    The exploration threshold decays exponentially from EPS_START towards
    EPS_END as the global `steps_done` counter grows; the counter is
    incremented on every call.

    Returns:
        A 1x1 tensor holding the chosen action index.
    """
    global steps_done
    roll = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    if roll <= eps_threshold:
        # Explore: random action from the environment's action space.
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

    # Exploit: the action with the largest value predicted by the policy net.
    with torch.no_grad():
        q_values = policy_net(state)
        best_action = q_values.max(1).indices.view(1, 1)
    return best_action


# Number of steps each finished episode lasted; appended to by the training
# loop and plotted by plot_durations().
episode_durations = []


def plot_durations(show_result=False):
    """Plot the duration of every episode recorded in `episode_durations`.

    Once 100 or more episodes exist, the running mean over 100 episodes is
    plotted too (padded with 99 leading zeros).

    Args:
        show_result: `True` for the final "Result" plot; `False` to clear and
            redraw the "Training..." plot.
    """
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if not show_result:
        plt.clf()
        plt.title('Training...')
    else:
        plt.title('Result')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        windows = durations_t.unfold(0, 100, 1)
        running_mean = torch.cat((torch.zeros(99), windows.mean(1).view(-1)))
        plt.plot(running_mean.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.display(plt.gcf())
        if not show_result:
            display.clear_output(wait=True)

def optimize_model():
    """Run one optimisation step of the policy network on a sampled batch.

    Does nothing until the replay memory holds at least BATCH_SIZE
    transitions. Otherwise a batch is sampled, the Huber loss between
    Q(s_t, a) and ``reward + GAMMA * max_a' Q_target(s_{t+1}, a')`` is
    computed (the next-state value is 0 for final states) and one optimizer
    step is taken.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and collect the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Skip the target network when every sampled transition is final:
    # torch.cat() raises on an empty list.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Gradient clipping is intentionally left disabled here:
    # torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [14]:
save_dir = r"./my_model"
# Create the checkpoint directory; no error if it already exists.
os.makedirs(save_dir, exist_ok=True)

if RE_TRAIN_FLAG:
    print("retrain model")
    # apply() visits every sub-module, so the Conv2d/Linear layers are really
    # initialised (calling init_weights on the network object alone is not
    # enough when it only inspects the type of the module it is given).
    policy_net.apply(init_weights)
    target_net.load_state_dict(policy_net.state_dict())
else:
    print("use exist model")
    try:
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine (and the other way round).
        policy_net.load_state_dict(torch.load(os.path.join(save_dir, 'policy_net.pkl'),
                                              map_location=device))
        target_net.load_state_dict(torch.load(os.path.join(save_dir, 'target_net.pkl'),
                                              map_location=device))
    except Exception as e:
        # Deliberate best-effort: any loading problem falls back to training
        # from freshly initialised weights, after reporting the reason.
        print(e)
        print(f"model files not found: {os.path.join(save_dir, 'policy_net.pkl')} or {os.path.join(save_dir, 'target_net.pkl')}")
        print("You can set RE_TRAIN_FLAG=True to retrain the model")
        print("automatically retrain model")
        policy_net.apply(init_weights)
        target_net.load_state_dict(policy_net.state_dict())

use exist model


# Training Loop

In [15]:
# Train for many episodes on a GPU, and only for a short run on CPU.
if torch.cuda.is_available():
    num_episodes = 26000
else:
    num_episodes = 50


# Best score seen so far; a score is only printed once it exceeds this value.
max_score = 1
policy_net.train()
for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state = env.reset()  # rgb img data:shape(360,450,3)
    state = my_tool_func.process_state(state)
    state = np.repeat(state, 4, axis=0)  # initially, all four frames are copies of the first frame
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) # make it a batch
    for t in count():
        action = select_action(state)
        observation, reward, done, info = env.step(action.item())  # observation is rgb img
        if info['score'] > max_score:
            max_score = info['score']
            print(f"{i_episode}:{info['score']}")
        reward = torch.tensor([reward], device=device)

        if done:
            # Final transitions are stored with no next state.
            next_state = None
        else:
            next_state = my_tool_func.process_state(observation)
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device)
            # Update the frame stack: drop the oldest frame, append the new one.
            next_state = torch.cat((state.squeeze(0)[1:, :, :], next_state)).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Periodic checkpoint every 1000 episodes.
    # NOTE(review): checkpoints are written with an episode suffix, while the
    # loading code in this file reads 'policy_net.pkl' / 'target_net.pkl';
    # presumably a checkpoint must be renamed by hand to resume - confirm.
    if (i_episode+1) % 1000 == 0:
        print(f"save,max_score:{max_score}")
        torch.save(target_net.state_dict(), os.path.join(save_dir, f'target_net_{i_episode}.pkl'))
        torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_net_{i_episode}.pkl'))

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
# Final checkpoint, suffixed with the total number of episodes.
torch.save(target_net.state_dict(), os.path.join(save_dir, f'target_net_{num_episodes}.pkl'))
torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_net_{num_episodes}.pkl'))


KeyboardInterrupt



<Figure size 640x480 with 0 Axes>