In [6]:
from my_model.DQN_net import DQN, ReplayMemory, Transition
import torch
import torch.optim as optim
import torch.nn as nn
from itertools import count
import matplotlib
import matplotlib.pyplot as plt
import os
import my_env
import random
import math
import my_tool_func
import numpy as np

# Set up matplotlib: detect whether an inline (notebook) backend is active so
# that plots can be refreshed through IPython's display helpers, then switch
# pyplot to interactive mode.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # Only needed (and only imported) when running under an inline backend.
    from IPython import display
plt.ion()

<contextlib.ExitStack at 0x1efc47c1b10>

In [7]:
def init_weights(m):
    """Initialise the parameters of a single layer in place.

    Only ``nn.Conv2d`` and ``nn.Linear`` layers (including subclasses) are
    touched; any other module is left unchanged. Because the function looks
    at exactly one module, it is meant to be used as
    ``model.apply(init_weights)`` so that every sub-module is visited.

    Args:
        m: the module to (possibly) initialise.

    Effects:
        ``m.weight`` is filled with values drawn uniformly from
        [-0.01, 0.01] and ``m.bias`` (when the layer has one) is set to 0.01.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # ``uniform_`` is the supported in-place initialiser; the un-suffixed
        # ``uniform`` is deprecated.
        nn.init.uniform_(m.weight, -0.01, 0.01)
        # Layers created with ``bias=False`` have ``bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

In [8]:
# --- Training configuration -------------------------------------------------
# True: re-initialise the networks and train from scratch.
# False: try to load previously saved weights and continue training from them.
RE_TRAIN_FLAG = False  # False: resume from the existing model
BATCH_SIZE = 32  # transitions sampled from replay memory per optimisation step
GAMMA = 0.99  # discount factor applied to next-state values
# Epsilon-greedy schedule: epsilon decays exponentially from EPS_START towards
# EPS_END; EPS_DECAY is the decay constant, measured in action-selection steps.
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 3000
TAU = 0.005  # soft-update rate of the target network
LR = 1e-4  # learning rate
# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment comes from the project-local my_env module.
env = my_env.make("FlappyBird-rgb-v0")
# Get number of actions from gym action space
n_actions = env.action_space.n
# Number of stacked frames that make up one state.
# NOTE(review): passed as the first DQN argument, presumably the number of
# input channels -- confirm against my_model.DQN_net.DQN.
n_observations = 4  # each state uses four frames

# policy_net is the network being optimised; target_net provides the
# bootstrap targets and is soft-updated from policy_net during training.
# Both start with independent weights here; they are synchronised or loaded
# in the model-loading cell.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# NOTE(review): 10000 is presumably the buffer capacity in transitions --
# confirm against ReplayMemory.
memory = ReplayMemory(10000)

# Global count of actions selected so far; drives the epsilon decay.
steps_done = 0

  logger.warn(


# Helper functions

In [9]:
def select_action(state):
    """Pick an action for ``state`` using an epsilon-greedy rule.

    The exploration probability decays exponentially from ``EPS_START``
    towards ``EPS_END`` as the global ``steps_done`` counter grows; the
    counter is incremented on every call.

    Args:
        state: batched input forwarded to ``policy_net``.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if roll <= eps_threshold:
        # Explore: draw a random action from the environment's action space.
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

    # Exploit: take the action with the highest predicted value. max(1)
    # reduces over the action dimension; ``indices`` is the arg-max per row.
    with torch.no_grad():
        q_values = policy_net(state)
        return q_values.max(1).indices.view(1, 1)


episode_durations = []  # one entry per finished episode: its length in steps


def plot_durations(show_result=False):
    """Plot the episode lengths collected in ``episode_durations``.

    Draws the raw durations and, once at least 100 episodes are recorded, a
    100-episode moving average (left-padded with 99 zeros so that it lines up
    with the episode axis).

    Args:
        show_result: when False the figure is cleared and titled
            'Training...' (live update); when True it is titled 'Result'
            and the existing figure content is kept.
    """
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)

    if not show_result:
        plt.clf()
    plt.title('Result' if show_result else 'Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    # Overlay the 100-episode moving average when there is enough data.
    if len(durations_t) >= 100:
        window_means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        padded_means = torch.cat((torch.zeros(99), window_means))
        plt.plot(padded_means.numpy())

    plt.pause(0.001)  # give the GUI event loop a moment to redraw
    if not is_ipython:
        return

    display.display(plt.gcf())
    if not show_result:
        # While training, replace the previous frame instead of stacking outputs.
        display.clear_output(wait=True)

def optimize_model():
    """Run one optimisation step of ``policy_net`` on a replay-memory batch.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise it samples a batch, computes Q(s_t, a_t) with ``policy_net``,
    builds the targets ``reward + GAMMA * max_a Q_target(s_{t+1}, a)`` using
    ``target_net`` (the bootstrap term is 0 for transitions whose next state
    is ``None``, i.e. terminal ones), and takes one optimiser step on the
    smooth-L1 (Huber-style) loss between the two.

    Returns:
        None. ``policy_net``'s parameters are updated through ``optimizer``.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Mask of transitions that did not end the episode, plus their next states.
    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t): evaluate every action, then pick the column of the action
    # that was actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the target network; stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Guard: if every sampled transition is terminal there is nothing to
    # concatenate, and torch.cat on an empty list raises a RuntimeError.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = \
                target_net(torch.cat(non_final_next_states)).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Gradient clipping is intentionally left disabled, as in the original:
    # torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [19]:
save_dir = r"./my_model"
# Create the checkpoint directory if it does not exist yet.
os.makedirs(save_dir, exist_ok=True)

# Decide whether the networks have to be initialised from scratch: either
# because retraining was requested, or because loading a checkpoint failed.
need_fresh_init = RE_TRAIN_FLAG
if RE_TRAIN_FLAG:
    print("retrain model")
else:
    print("use exist model")
    try:
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        policy_net.load_state_dict(
            torch.load(os.path.join(save_dir, 'policy_net.pkl'), map_location=device))
        target_net.load_state_dict(
            torch.load(os.path.join(save_dir, 'target_net.pkl'), map_location=device))
    except Exception as e:
        # Deliberately broad: any failure to load falls back to a fresh start.
        print(e)
        print(f"model files not found: {os.path.join(save_dir, 'policy_net.pkl')} or {os.path.join(save_dir, 'target_net.pkl')}")
        print("You can set RE_TRAIN_FLAG=True to retrain the model")
        print("automatically retrain model")
        need_fresh_init = True

if need_fresh_init:
    # init_weights only acts on individual Conv2d/Linear layers, so it must be
    # applied to every sub-module; calling init_weights(policy_net) directly
    # would leave the network untouched.
    policy_net.apply(init_weights)
    # Start the target network as an exact copy of the policy network.
    target_net.load_state_dict(policy_net.state_dict())

use exist model


# Training loop

In [20]:
# Train for many episodes only when a GPU is available; otherwise do a short run.
num_episodes = 26000 if torch.cuda.is_available() else 50

max_score = 1  # best score seen so far; only improvements beyond it are printed
policy_net.train()
for i_episode in range(num_episodes):
    # Reset the environment; the raw observation is an RGB image.
    first_frame = my_tool_func.process_state(env.reset())
    # At the start of an episode all four stacked frames are the first frame.
    stacked_frames = np.repeat(first_frame, 4, axis=0)
    # Add a leading batch dimension.
    state = torch.tensor(stacked_frames, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        action = select_action(state)
        observation, reward, done, info = env.step(action.item())  # observation is an RGB image
        score = info['score']
        if score > max_score:
            max_score = score
            print(f"{i_episode}:{score}")
        reward = torch.tensor([reward], device=device)

        # A terminal transition has no next state.
        next_state = None
        if not done:
            new_frame = torch.tensor(my_tool_func.process_state(observation),
                                     dtype=torch.float32, device=device)
            # Slide the frame window: drop the oldest frame, append the newest.
            next_state = torch.cat((state.squeeze(0)[1:, :, :], new_frame)).unsqueeze(0)

        # Store the transition in memory, then advance to the next state.
        memory.push(state, action, next_state, reward)
        state = next_state

        # Perform one step of the optimization (on the policy network).
        optimize_model()

        # Soft update of the target network's weights:
        # theta' <- tau * theta + (1 - tau) * theta'
        online_weights = policy_net.state_dict()
        blended_weights = target_net.state_dict()
        for name, online_param in online_weights.items():
            blended_weights[name] = online_param * TAU + blended_weights[name] * (1 - TAU)
        target_net.load_state_dict(blended_weights)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # Periodic checkpoint every 1000 episodes.
    if (i_episode + 1) % 1000 == 0:
        print(f"save,max_score:{max_score}")
        torch.save(target_net.state_dict(), os.path.join(save_dir, f'target_net_{i_episode}.pkl'))
        torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_net_{i_episode}.pkl'))

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
# Final checkpoint, tagged with the total number of episodes.
torch.save(target_net.state_dict(), os.path.join(save_dir, f'target_net_{num_episodes}.pkl'))
torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_net_{num_episodes}.pkl'))

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [None]:
# Optional manual checkpoint: uncomment to write the current weights under the
# fixed names 'target_net.pkl' / 'policy_net.pkl', which are the file names the
# model-loading cell reads when RE_TRAIN_FLAG is False. The training loop only
# writes episode-suffixed files.
# torch.save(target_net.state_dict(), os.path.join(save_dir, f'target_net.pkl'))
# torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_net.pkl'))