## Обучение агента

In [1]:
import os
import random
import numpy as np
import collections
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
from vizdoom import gymnasium_wrapper
from gymnasium.wrappers import RecordVideo
from copy import deepcopy
from torch.distributions import Normal
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

In [None]:
# transition = np.dtype([
#                        ('state_s', np.float32, (4, 240, 320)),
#                        ('state_gv', np.float32, (5,)),
#                        ('action_b', np.int16),
#                        ('action_c', np.float32, (3,)),
#                        ('reward', np.float32),
#                        ('done', np.bool),
#                        ('state_s_next', np.float32, (4, 240, 320)),
#                        ('state_gv_next', np.float32, (5,)),
#                       ])

In [3]:
class ActorNet(nn.Module):

    def __init__(self):
        super(ActorNet, self).__init__()
        self.img_stack = 4
        self.cnn_base = nn.Sequential(  # input shape (4, 240, 320)
            nn.Conv2d(self.img_stack, 8, kernel_size=4, stride=2),
            nn.ReLU(),  # activation
            nn.Conv2d(8, 16, kernel_size=3, stride=2),  # (8, 119, 159)
            nn.ReLU(),  # activation
            nn.Conv2d(16, 32, kernel_size=3, stride=2),  # (16, 59, 79)
            nn.ReLU(),  # activation
            nn.Conv2d(32, 64, kernel_size=3, stride=2),  # (32, 29, 39)
            nn.ReLU(),  # activation
            nn.Conv2d(64, 128, kernel_size=3, stride=1),  # (64, 14, 19)
            nn.ReLU(),  # activation
            nn.Conv2d(128, 256, kernel_size=3, stride=1),  # (128, 12, 17)
            nn.ReLU(),  # activation
            nn.AdaptiveAvgPool2d(output_size=1) # (256, 10, 15)
        )  # output shape (256, 1, 1)
        self.fc_base = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU()
        )
        self.binary = nn.Sequential(nn.Linear(320, 18), nn.Softmax(dim=1))
        self.conts = nn.Sequential(nn.Linear(320, 3), nn.Tanh())

    def forward(self, x, y):
        x = self.cnn_base(x)
        x = x.view(-1, 256)
        y = self.fc_base(y)
        z = torch.cat((x, y), dim=1)
        binary = self.binary(z)
        conts = self.conts(z)
        log_probs = self.conts(z)

        return conts, log_probs, binary

In [4]:
class ValueNet(nn.Module):

    def __init__(self):
        super(ValueNet, self).__init__()
        self.img_stack = 4
        self.cnn_state = nn.Sequential(  # input shape (4, 240, 320)
            nn.Conv2d(self.img_stack, 8, kernel_size=4, stride=2),
            nn.ReLU(),  # activation
            nn.Conv2d(8, 16, kernel_size=3, stride=2),  # (8, 119, 159)
            nn.ReLU(),  # activation
            nn.Conv2d(16, 32, kernel_size=3, stride=2),  # (16, 59, 79)
            nn.ReLU(),  # activation
            nn.Conv2d(32, 64, kernel_size=3, stride=2),  # (32, 29, 39)
            nn.ReLU(),  # activation
            nn.Conv2d(64, 128, kernel_size=3, stride=1),  # (64, 14, 19)
            nn.ReLU(),  # activation
            nn.Conv2d(128, 256, kernel_size=3, stride=1),  # (128, 12, 17)
            nn.ReLU(),  # activation
            nn.AdaptiveAvgPool2d(output_size=1) # (256, 10, 15)
        )  # output shape (256, 1, 1)
        self.fc_state = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU()
        )
        self.fc_binary = nn.Sequential(
            nn.Linear(1, 18),
            nn.ReLU(),
        )
        self.fc_continiue = nn.Sequential(
            nn.Linear(3, 32),
            nn.ReLU(),
        )
        self.value = nn.Sequential(nn.Linear(370, 1))

    def forward(self, x, y, a, b):
        x = self.cnn_state(x)
        x = x.view(-1, 256)
        y = self.fc_state(y)
        a = self.fc_binary(a)
        b = self.fc_continiue(b)
        v = torch.cat((x, y, a, b), dim=1)
        v = self.value(v)

        return v

In [None]:
class SAC(nn.Module):
    
    def __init__(self, gamma=0.99, alpha=1e-3, tau=1e-2, 
                 buffer_capacity=256, batch_size=64, pi_lr=3e-4, q_lr=3e-4):
        super().__init__()

        self.pi_model = ActorNet()
        self.q1_model = ValueNet() 
        self.q2_model = ValueNet()
        self.gamma = gamma
        self.alpha = alpha
        self.tau = tau
        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size
        self.memory = [None] * self.buffer_capacity
        self.counter = 0
        self.pi_optimizer = torch.optim.Adam(self.pi_model.parameters(), pi_lr)
        self.q1_optimizer = torch.optim.Adam(self.q1_model.parameters(), q_lr)
        self.q2_optimizer = torch.optim.Adam(self.q2_model.parameters(), q_lr)
        self.q1_target_model = deepcopy(self.q1_model)
        self.q2_target_model = deepcopy(self.q2_model)

    def get_action(self, state_s, state_gv):
        state_s = torch.FloatTensor(state_s)
        state_gv = torch.FloatTensor(state_gv)
        action_c, log_probs, action_b = self.predict_actions(state_s, state_gv)
        return action_b, action_c

    def fit(self, state_s, state_gv, action_b, action_c, reward, done, state_s_next, state_gv_next):
        self.memory[self.counter] = [state_s, state_gv, action_b, action_c, reward, done, state_s_next, state_gv_next]
        self.counter += 1
        if self.counter == self.buffer_capacity:
            self.counter = 0
        
        if self.counter > self.batch_size:
            batch = random.sample(self.memory[:self.counter], self.batch_size)
            states_s, states_gv, actions_b, actions_c, rewards, dones, states_s_next, states_gv_next = zip(*batch)
            
            states_s = torch.tensor(np.stack(states_s)).float()
            states_gv = torch.tensor(np.stack(states_gv)).float().squeeze(1)
            actions_b = torch.tensor(np.stack(actions_b)).float()
            actions_c = torch.tensor(np.stack(actions_c)).float()
            rewards = torch.tensor(np.stack(rewards)).float()
            dones = torch.tensor(np.stack(dones)).float()
            states_s_next = torch.tensor(np.stack(states_s_next)).float()
            states_gv_next = torch.tensor(np.stack(states_gv_next)).float().squeeze(1)
            rewards, dones = rewards.unsqueeze(1), dones.unsqueeze(1)
            
            next_continious, next_log_probs, next_binarys = self.predict_actions(states_s_next, states_gv_next)
            next_q1_values = self.q1_target_model(states_s_next, states_gv_next, next_binarys.float().unsqueeze(1), next_continious.float().squeeze(1))
            next_q2_values = self.q2_target_model(states_s_next, states_gv_next, next_binarys.float().unsqueeze(1), next_continious.float().squeeze(1))
            next_min_q_values = torch.min(next_q1_values, next_q2_values)
            targets = rewards.detach().clone() + self.gamma * (1 - dones.detach().clone()) * (next_min_q_values.repeat(1, 3) - self.alpha * next_log_probs.squeeze(1))

            q1_loss = torch.mean((self.q1_model(states_s, states_gv, actions_b.unsqueeze(1), actions_c) - targets) ** 2)
            q2_loss = torch.mean((self.q2_model(states_s, states_gv, actions_b.unsqueeze(1), actions_c) - targets) ** 2)
            self.update_model(q1_loss, self.q1_optimizer, self.q1_model, self.q1_target_model)
            self.update_model(q2_loss, self.q2_optimizer, self.q2_model, self.q2_target_model)

            pred_continious, log_probs, pred_binarys = self.predict_actions(states_s, states_gv)
            q1_values = self.q1_model(states_s, states_gv, pred_binarys.float().unsqueeze(1), pred_continious.float().squeeze(1))
            q2_values = self.q2_model(states_s, states_gv, pred_binarys.float().unsqueeze(1), pred_continious.float().squeeze(1))
            
            min_q_values = torch.min(q1_values, q2_values)
            pi_loss = - torch.mean(min_q_values - self.alpha * log_probs)
            self.update_model(pi_loss, self.pi_optimizer)
            
    def update_model(self, loss, optimizer, model=None, target_model=None):
        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()
        if model != None and target_model != None:
            for param, target_param in zip(model.parameters(), target_model.parameters()):
                # new_target_param = (1 - self.tau) * target_param + self.tau * param
                target_param.data = (1 - self.tau) * target_param + self.tau * param

    def predict_actions(self, states_s, states_gv):
        means, log_stds, binary = self.pi_model(states_s, states_gv)
        means, log_stds = means.unsqueeze(1), log_stds.unsqueeze(1)
        dists = Normal(means, torch.exp(log_stds))
        continious = dists.rsample()
        log_probs = dists.log_prob(continious)
        binary = torch.argmax(binary, dim=1)
        return continious, log_probs, binary
    
    def save_param(self):
        torch.save(self.pi_model.state_dict(), 'param/pi_model_params.pkl')
        torch.save(self.q1_model.state_dict(), 'param/q1_model_params.pkl')
        torch.save(self.q2_model.state_dict(), 'param/q2_model_params.pkl')

    def load_param(self):
        self.pi_model.load_state_dict(torch.load('param/pi_model_params.pkl'))
        self.q1_model.load_state_dict(torch.load('param/q1_model_params.pkl'))
        self.q2_model.load_state_dict(torch.load('param/q2_model_params.pkl'))

In [7]:
# Available buttons                           # Доступные кнопки
# available_buttons =                         # доступные_кнопки =
# {                                           # {
#     ATTACK                                  #     АТАКА
#     SPEED                                   #     СКОРОСТЬ
#     STRAFE                                  #     ОБХОД

#     MOVE_RIGHT                              #     ДВИЖЕНИЕ_ВПРАВО
#     MOVE_LEFT                               #     ДВИЖЕНИЕ_ВЛЕВО
#     MOVE_BACKWARD                           #     ДВИЖЕНИЕ_НАЗАД
#     MOVE_FORWARD                            #     ДВИЖЕНИЕ_ВПЕРЕД
#     TURN_RIGHT                              #     ПОВЕРНИ_ВПРАВО
#     TURN_LEFT                               #     ПОВЕРНИ_ВЛЕВО

#     SELECT_WEAPON1                          #     ВЫБЕРИ_ОРУЖИЕ 1
#     SELECT_WEAPON2                          #     ВЫБЕРИ_ОРУЖИЕ 2
#     SELECT_WEAPON3                          #     ВЫБЕРИ_ОРУЖИЕ 3
#     SELECT_WEAPON4                          #     ВЫБЕРИ_ОРУЖИЕ 4
#     SELECT_WEAPON5                          #     ВЫБЕРИ_ОРУЖИЕ 5
#     SELECT_WEAPON6                          #     ВЫБЕРИ_ОРУЖИЕ 6

#     SELECT_NEXT_WEAPON                      #     ВЫБОР_СЛЕДУЮЩЕГО_ОРУЖИЯ
#     SELECT_PREV_WEAPON                      #     ВЫБОР_ПРЕДЫДУЩЕГО_ОРУЖИЯ

#     LOOK_UP_DOWN_DELTA                      #     СМОТРЕТЬ_ВВЕРХ_ВНИЗ_ПО_ДЕЛЬТЕ
#     TURN_LEFT_RIGHT_DELTA                   #     ПОВЕРНУТЬ_НАЛЕВО_НАПРАВО_ТРЕУГОЛЬНИКОМ
#     MOVE_LEFT_RIGHT_DELTA                   #     ПЕРЕМЕЩЕНИЕ_ВЛЕВО_ВПРАВО_ДЕЛЬТА

# }                                           # }

# Game variables that will be in the state    # Игровые переменные, которые будут находиться в состоянии
# available_game_variables =                  # доступные игровые переменные =
# {                                           # {
#     KILLCOUNT                               #     КОЛИЧЕСТВО УБИТЫХ
#     HEALTH                                  #     ЗДОРОВЬЕ
#     ARMOR                                   #     БРОНЯ
#     SELECTED_WEAPON                         #     ВЫБРАННОЕ_ОРУЖИЕ
#     SELECTED_WEAPON_AMMO                    #     ВЫБРАННОЕ_ОРУЖИЕ_АММО
# }                                           # }
# mode = PLAYER                               # режим = ИГРОК

In [79]:
class Env():

    def __init__(self):
        self.env = gym.make('VizdoomDeathmatch-v0')
        self.seed = 123
        self.img_stack = 4
        self.game_var = 1
        self.action_repeat = 8
        self.gamevariables_prev = [0., 100., 0., 2., 50.]

    def reset(self):
        self.counter = 0
        self.done = False
        state, _ = self.env.reset()

        img_rgb = state['screen']
        img_gray = self.rgb2gray(img_rgb)
        self.stack = [img_gray] * self.img_stack  # four frames for decision

        gamevariables = state['gamevariables']
        self.gv_stack = [gamevariables] * self.game_var  # four frames for decision
        state_s = np.array(self.stack)
        state_gv = np.array(self.gv_stack)
        # print('Start', state['gamevariables']) # [  0. 100.   0.   2.  50.]
        return state_s, state_gv
        
    #  KILLCOUNT HEALTH   ARMOR SELECTED_WEAPON SELECTED_WEAPON_AMMO
    # [killcount health   armor weapon ammo      ]
    # [кол-во    здоровье броня оружие боеприпасы]
    # [0.        100.     0.    2.     50.       ]

    def step(self, action_b, action_c):
        action = {'binary': action_b, 'continuous': action_c}
        state, reward, done, _, _ = self.env.step(action)
        img_rgb = state['screen']
        gamevariables = state['gamevariables']
        self.gamevariables_prev = gamevariables
        img_gray = self.rgb2gray(img_rgb)
        self.stack.pop(0)
        self.stack.append(img_gray)
        self.gv_stack.pop(0)
        self.gv_stack.append(gamevariables)
        assert len(self.stack) == self.img_stack
        assert len(self.gv_stack) == self.game_var
        return np.array(self.stack), np.array(self.gv_stack), reward, done

    def render(self, *arg):
        self.env.render(*arg)

    @staticmethod
    def rgb2gray(rgb, norm=True):
        gray = np.dot(rgb[..., :], [0.299, 0.587, 0.114]) # rgb image -> gray [0, 1]
        if norm:
            gray = gray / 128. - 1. # normalize
        return gray

In [None]:
env = Env()

In [85]:
agent = SAC()

episode_n = 100
reward_max = 0
total_rewards = []
reward_threshold = 100

In [None]:
for episode in range(episode_n):

    total_reward = 0
    state_s, state_gv = env.reset()
    
    for t in range(2000):
        # _, _ = env.reset()
        action_b, action_c = agent.get_action(state_s, state_gv)
        action_b = action_b.squeeze().detach().numpy()
        action_c = action_c.squeeze().detach().numpy()
        state_s_next, state_gv_next, reward, done = env.step(action_b, action_c)

        agent.fit(state_s, state_gv, action_b, action_c, reward, done, state_s_next, state_gv_next)
        
        total_reward += reward
        state_s = state_s_next
        state_gv = state_gv_next

        # print('Шаг', t+1, end='\r')
        if done:
            break

    total_rewards.append(total_reward)

    if episode < 50:
        reward_50 = np.mean(total_rewards[:episode])
    else:
        reward_50 = np.mean(total_rewards[episode-50:episode])
    if reward_max < reward_50:
        reward_max = reward_50
        agent.save_param()
        print(f'Эпизод {episode+1}\tСохранение агента,\treward-50-max: {reward_max:.2f}                                         ')
    
    print(f'Эпизод {episode+1}\tПоследний reward: {total_reward:.2f}\treward-50: {reward_50:.2f}\tt: {t+1}              ', end='\r')

    if reward_50 > reward_threshold:
        print(f'Эпизод {episode+1}\tОбучение остановлено. \tПоследний score: {total_reward:.2f}\treward-50-max: {reward_max:.2f}')
        break

print('\n', reward_max)

In [146]:
agent.save_param()

In [None]:
# Generate recent 50 interval average
average_reward = []
for idx in range(20, len(total_rewards)):
    avg_list = np.empty(shape=(1,), dtype=int)
    if idx < 50:
        avg_list = total_rewards[:idx+1]
    else:
        avg_list = total_rewards[idx-49:idx+1]
    average_reward.append(np.average(avg_list))
plt.plot(total_rewards)
plt.plot(average_reward)

## Тестирование агента

In [27]:
import os
import PIL
import shutil as sh
from glob import glob
from base64 import b64encode
from datetime import datetime

In [28]:
env_dir = './vizdoom-videos'
scenarios_dir = env_dir # os.path.join(env_dir, 'scenarios')
if not os.path.exists(scenarios_dir):
    os.makedirs(scenarios_dir)

In [29]:
def record_episode(eps_frames, records_dir, exp_id, eps_num):
    record_path = os.path.join(records_dir, f'record-{exp_id}_eps-{eps_num}.mp4')
    eps_frame_dir = './vizdoom-videos/episode_frames'
    if not os.path.exists(eps_frame_dir):
        os.mkdir(eps_frame_dir)

    for i, frame  in enumerate(eps_frames):
        PIL.Image.fromarray(frame).save(os.path.join(eps_frame_dir, f'frame-{i+1}.png'))

    os.system(f'ffmpeg -r 30 -i {eps_frame_dir}/frame-%1d.png -vcodec libx264 -b 10M -y "{record_path}"')
    # sh.rmtree(eps_frame_dir)


def show_episode_records(records_dir):
    record_paths = glob(os.path.join(records_dir, "*.mp4"))
    html_str = ''
    for i, record_path in enumerate(record_paths):
        mp4 = open(record_path, 'rb').read()
        data = f"data:video/mp4;base64,{b64encode(mp4).decode()}"
        html_str += f'EPISODE # {i+1}<br><video width=500 controls><source src="{data}" type="video/mp4"></video><br><br>'
    return HTML(html_str)

In [94]:
def run_episode(env, agent):
    state, img_rgb = env.reset()
    step_count = 0
    frames = list()
    done = False
    frames.append(img_rgb)
    while not done:
        # print(type(state['screen']))
        action, _ = agent.select_action(state)
        next_screen, next_gamevariables, reward, done, img_rgb_, = env.step(action) # terminated, truncated,
        frames.append(img_rgb_)
        # done = terminated or truncated
        state = {'screen': next_screen, 'gamevariables': next_gamevariables}
        step_count += 1

    # frames.extend([env.render()] * 3)
    return frames

def evaluate_agent(env, agent, exp_dir, n_eval=1):
    for eps_num in range(1, n_eval + 1):
        eps_frames = run_episode(env, agent)
        print('frames =', len(eps_frames))
        # record_episode(eps_frames, exp_dir, f'exp-{datetime.now().strftime("%d%m-%H%M%S")}', eps_num)

In [None]:
# env = Env() # gym.make('VizdoomDeathmatch-v0', render_mode='rgb_array')
evaluate_agent(env, agent, env_dir)
# env.close()

In [None]:
from moviepy import ImageSequenceClip
eps_frames = run_episode(env, agent)
record_path = os.path.join(env_dir, f'record-exp-{datetime.now().strftime("%d%m-%H%M%S")}.mp4')
# Настройки
fps = 24          # Частота кадров в секунду

# Сборка клипа из изображений
clip = ImageSequenceClip(eps_frames, fps=fps)

# Сохранение видеофайла
clip.write_videofile(record_path)

In [None]:
from IPython import display, HTML

In [None]:
show_episode_records(env_dir) # exp_dirs['evaluation']

## Проверка окружения

In [None]:
input_example = torch.rand(4,240,320)
net_a = ActorNet()
net_a(input_example)

(tensor([[1.6724, 1.5521, 1.6525]], grad_fn=<AddBackward0>),
 tensor([[1.7859, 1.7419, 1.6252]], grad_fn=<AddBackward0>))

In [25]:
input_example = torch.rand(4,240,320)
net_v = ValueNet()
net_v(input_example)

tensor([[0.0629]], grad_fn=<AddmmBackward0>)

In [None]:
env_doom = gym.make('VizdoomDeathmatch-v0', render_mode="rgb_array")
action_doom = env_doom.action_space
obs_doom = env_doom.observation_space
print('action', action_doom)
print('obs', obs_doom)
print(env_doom.spec.reward_threshold)

In [None]:
state_doom = env_doom.reset()
random_action = action_doom.sample()