---

# <center> GFootball Stable-Baselines3 </center>

---
<center><img src="https://raw.githubusercontent.com/DLR-RM/stable-baselines3/master/docs/_static/img/logo.png" width="308" height="268" alt="Stable-Baselines3"></center>
<center><small>Image from Stable-Baselines3 repository</small></center>

---
This notebook uses the [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) library to train a [PPO](https://openai.com/blog/openai-baselines-ppo/) reinforcement learning agent on [GFootball Academy](https://github.com/google-research/football/tree/master/gfootball/scenarios) scenarios, applying the architecture from the paper "[Google Research Football: A Novel Reinforcement Learning Environment](https://arxiv.org/abs/1907.11180)".

In [None]:
import sys
sys.path.append("..")
# sys.path.append("../imitation_learning")
import os
import base64
import pickle
import zlib
import gym
import numpy as np
import pandas as pd
import torch as th
from torch import nn, tensor
from collections import deque
from gym.spaces import Box, Discrete
# from kaggle_environments import make
# from kaggle_environments.envs.football.helpers import *
from gfootball.env import create_environment, observation_preprocessing, wrappers
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy
from stable_baselines3.common import results_plotter
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from stable_baselines3.common.policies import BasePolicy, register_policy
from IPython.display import HTML
import time
from datetime import date
# from visualizer import visualize
from matplotlib import pyplot as plt
from stable_baselines3 import DQN
import torch
from models.MlpClassifierModel import MlpClassifierModel
%matplotlib inline

In [None]:
# Seed PyTorch's global random number generator so repeated runs of the
# notebook draw the same random numbers. Seeding once is sufficient:
# re-seeding with `torch.initial_seed()` would just apply 42 again.
torch.manual_seed(42)

---
# Football Gym
> [Stable-Baselines3: Custom Environments](https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html)<br/>
> [SEED RL Agent](https://www.kaggle.com/piotrstanczyk/gfootball-train-seed-rl-agent): stacked observations

In [None]:
class FootballGym(gym.Env):
    """Gym environment that wraps a single-agent GFootball environment.

    The wrapped environment is built with ``create_environment`` using the
    ``simple115v2`` representation, no frame stacking, and one controlled
    player on the left team.

    Args:
        config: Optional mapping. The keys ``"env_name"`` and ``"rewards"``
            override the default scenario and the ``rewards`` argument.
        render: Forwarded to ``create_environment(render=...)``.
        rewards: Reward specification forwarded to ``create_environment``
            unless ``config`` supplies its own ``"rewards"`` entry.
    """

    spec = None
    metadata = None
    # metadata = {'render.modes': ['human']}

    def __init__(self, config=None, render=False, rewards='scoring'):
        super().__init__()
        env_name = "academy_empty_goal_close"
        if config is not None:
            # Entries in `config` take precedence over the defaults above.
            env_name = config.get("env_name", env_name)
            rewards = config.get("rewards", rewards)
        self.env = create_environment(
            env_name=env_name,
            stacked=False,
            representation="simple115v2",
            rewards=rewards,
            write_goal_dumps=False,
            write_full_episode_dumps=False,
            render=render,
            write_video=False,
            dump_frequency=1,
            logdir=".",
            extra_players=None,
            number_of_left_players_agent_controls=1,
            number_of_right_players_agent_controls=0)
        # Expose the wrapped environment's spaces unchanged.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.reward_range = (-1, 1)
        # Buffer for up to four observations; it is only cleared on reset and
        # is not filled anywhere in this class.
        self.obs_stack = deque([], maxlen=4)

    def reset(self):
        """Reset the wrapped environment and return its first observation."""
        self.obs_stack.clear()
        return self.env.reset()

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        The action is passed to the wrapped environment as a one-element
        list, and the reward is converted to a built-in ``float``.
        """
        obs, reward, done, info = self.env.step([action])
        return obs, float(reward), done, info

    def close(self):
        """Close the wrapped environment so its resources are released."""
        self.env.close()
    
# check_env(env=FootballGym(), warn=True)

In [None]:
import pytorch_lightning as pl

In [None]:
class FootballMLP(BaseFeaturesExtractor):
    """Features extractor that delegates to an ``MlpClassifierModel``.

    The model is built from the notebook-level ``hparams`` dict, which is
    looked up when the extractor is instantiated (so ``hparams`` must be
    defined by then).

    NOTE(review): the model is created with ``num_classes=19`` while
    ``features_dim`` defaults to 115 -- confirm that the declared feature
    size matches what ``MlpClassifierModel`` actually outputs.
    """

    def __init__(self, observation_space, features_dim=115):
        super().__init__(
            observation_space=observation_space,
            features_dim=features_dim,
        )
        self.mlp = MlpClassifierModel(
            hparams,
            input_size=115,
            p_dropout=0.25,
            num_classes=19,
        )

    def forward(self, input_tensor):
        """Run ``input_tensor`` through the wrapped MLP and return its output."""
        mlp_output = self.mlp(input_tensor)
        return mlp_output

In [None]:
# Hyperparameter dictionary handed to MlpClassifierModel.
hparams = dict(
    hidden_size=1024,
    lr=2e-3,
    lr_decay_rate=0.25,
    batch_size=256,
    activation='GELU',
)
# hparams['activation'] = 'ReLU'
# model = MLPModel(hparams).to('cuda')
# model = MlpClassifierModel(hparams).to('cuda')

## DDQN Model

In [None]:
from torch.nn import functional as F

class DDQN(DQN):
    """DQN subclass whose training step uses a Double-DQN style TD target.

    The greedy next action is chosen by the online network (``q_net``) and
    its value is read from the target network (``q_net_target``).
    """

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` optimisation steps on replay-buffer batches.

        Args:
            gradient_steps: Number of batches to sample and optimise on.
            batch_size: Number of transitions sampled per step.
        """
        # Training mode affects layers such as batch norm / dropout.
        self.policy.set_training_mode(True)
        # Bring the optimizer's learning rate in line with the schedule.
        self._update_learning_rate(self.policy.optimizer)

        step_losses = []
        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # The TD target is a constant for the loss: no gradients here.
            with th.no_grad():
                # Q-values of the next observations from the target network.
                target_net_q = self.q_net_target(batch.next_observations)
                # Q-values of the same observations from the online network,
                # used only to pick the greedy action.
                online_net_q = self.q_net(batch.next_observations)
                greedy_actions = online_net_q.argmax(dim=1).unsqueeze(-1)
                # Evaluate the online network's choice with the target network.
                bootstrap_q = target_net_q.gather(1, greedy_actions)

                # 1-step TD target; terminal transitions get no bootstrap term.
                not_done = 1 - batch.dones
                target_q_values = batch.rewards + not_done * self.gamma * bootstrap_q

            # Online Q-values restricted to the actions stored in the buffer.
            all_q_values = self.q_net(batch.observations)
            current_q_values = all_q_values.gather(1, batch.actions.long())

            # Prediction and target must line up element-wise.
            assert current_q_values.shape == target_q_values.shape

            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            step_losses.append(loss.item())

            # Gradient step on the online network with gradient-norm clipping.
            optimizer = self.policy.optimizer
            optimizer.zero_grad()
            loss.backward()
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            optimizer.step()

        # Bookkeeping and logging.
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(step_losses))

### Initializing the Vanilla DQN model

In [None]:
# # This initialization concentrates most of the layers into the feature extractor
# # and then leaves only a single layer for the prediction part

# policy_kwargs = dict(features_extractor_class=FootballMLP,
#                      features_extractor_kwargs=dict(features_dim=1024),
#                     net_arch = [],
#                     )
# model_name = "dqn"
# model = DQN(policy="MlpPolicy", 
#             env=train_env, 
#             policy_kwargs=policy_kwargs, 
#             verbose=1,
#             exploration_initial_eps=0.00,
#             exploration_final_eps=0.0,
#             target_update_interval=15000,
# #             learning_rate=hparams['lr'],
# #             batch_size=hparams['batch_size'],
#             seed=42,
#             tensorboard_log='tb_logs_DQN',
#             learning_starts=100000,
#            )
# model.policy

In [None]:
# # Configuring the DQN net like SB3's network structure
# # Only one layer for the feature extraction and the rest for the Q value computation
# from stable_baselines3 import DQN

# policy_kwargs = dict(
#     net_arch = [1024, 1024, 1024],
#     activation_fn = torch.nn.GELU
# )
# model_name = "dqn"
# model = DQN(policy="MlpPolicy", 
#             env=train_env, 
#             policy_kwargs=policy_kwargs, 
#             verbose=1,
#             exploration_initial_eps=0.00,
#             exploration_final_eps=0.0,
#             target_update_interval=15000,
# #             learning_rate=hparams['lr'],
# #             batch_size=hparams['batch_size'],
#             seed=42,
#             tensorboard_log='tb_logs_DQN',
#             learning_starts=100000,
#            )
# model.policy
           

### Double DQN

In [None]:
# # DDQN

# policy_kwargs = dict(
#     net_arch = [1024, 1024, 1024],
#     activation_fn = torch.nn.ReLU
# )

# # policy_kwargs = dict(features_extractor_class=FootballMLP,
# #                      features_extractor_kwargs=dict(features_dim=1024),
# #                     net_arch = [],
# #                     )
# model_name = "ddqn"
# model = DDQN(policy="MlpPolicy", 
#             env=train_env, 
#             policy_kwargs=policy_kwargs, 
#             verbose=1,
#             exploration_initial_eps=0.05,
#             exploration_final_eps=0.05,
#             target_update_interval=150000,
# #             learning_rate=0.0000001,
# #             batch_size=hparams['batch_size'],
#             seed=42,
#             tensorboard_log='tb_logs_DDQN',
#             train_freq=3002,
# #             learning_starts=100000,
#            )

# # With cartpole initialization
# # model = DDQN(policy="MlpPolicy", 
# #             env=train_env, 
# #             learning_rate=2.3e-3,
# #             batch_size=64,
# #             buffer_size=100000,
# #             learning_starts=1000,
# #             gamma=0.99,
# #             target_update_interval=10,
# #             train_freq=256,
# #             gradient_steps=128,
# #             exploration_fraction=0.16,
# #             exploration_final_eps=0.04,
# #             policy_kwargs=policy_kwargs,
# #              tensorboard_log='tb_logs_DDQN',
# #              seed=42,
# #              verbose=1,
# #            )
# model.policy
           

In [None]:
%reload_ext autoreload
%autoreload 2

### Testing the Agent

In [None]:
# Load a previously saved DDQN model from disk; the evaluation helper below
# reads this notebook-level `model` when choosing actions.
model = DDQN.load("../models/ddqn/ddqn_gfootball_8_20-05-2022-23-37-05.zip")

In [None]:
# Lookup table: integer index -> GFootball scenario name.
scenarios = dict(enumerate([
    "academy_empty_goal_close",
    "academy_empty_goal",
    "academy_run_to_score",
    "academy_run_to_score_with_keeper",
    "academy_pass_and_shoot_with_keeper",
    "academy_run_pass_and_shoot_with_keeper",
    "academy_3_vs_1_with_keeper",
    "academy_corner",
    "academy_counterattack_easy",
    "academy_counterattack_hard",
    "academy_single_goal_versus_lazy",
    "11_vs_11_kaggle",
    "11_vs_11_stochastic",
    "11_vs_11_easy_stochastic",
    "11_vs_11_hard_stochastic",
]))

# Scenario selected by default.
scenario_name = scenarios[13]

In [None]:
# Lookup table: integer action id -> human-readable action name.
action_set = dict(enumerate((
    "idle",
    "left",
    "top_left",
    "top",
    "top_right",
    "right",
    "bottom_right",
    "bottom",
    "bottom_left",
    "long_pass",
    "high_pass",
    "short_pass",
    "shot",
    "sprint",
    "release_direction",
    "release_sprint",
    "sliding",
    "dribble",
    "release_dribble",
)))

In [None]:
def play_match_with_dqn_agent(test_env, agent=None, max_steps=3000):
    """Play one match and return ``(my_agent_goals, ai_agent_goals)``.

    Each step with a positive reward is counted as a goal for the agent and
    each step with a negative reward as a goal for the opponent.
    NOTE(review): this assumes a reward scheme where only goals produce a
    non-zero reward (e.g. ``rewards='scoring'``) -- confirm before using it
    with shaped rewards.

    Args:
        test_env: Environment exposing ``reset()`` and
            ``step(action) -> (obs, reward, done, info)``.
        agent: Object exposing ``predict(obs, deterministic=True)``. When
            omitted, the notebook-level ``model`` is used.
        max_steps: Upper bound on the number of environment steps played.

    Returns:
        Tuple ``(my_agent_goals, ai_agent_goals)``. The tuple is returned both
        when the step cap is reached and when the environment reports
        ``done`` earlier, so callers can always unpack the result.
    """
    if agent is None:
        # Backward-compatible default: use the globally loaded model.
        agent = model

    obs = test_env.reset()
    my_agent = 0
    ai_agent = 0
    env_steps = 0
    done = False
    while not done and env_steps < max_steps:
        action, _ = agent.predict(obs, deterministic=True)
        obs, reward, done, _ = test_env.step(action)
        if reward > 0:
            my_agent += 1
        elif reward < 0:
            ai_agent += 1
        env_steps += 1
    return my_agent, ai_agent
    

In [None]:
# Evaluation driver: for every selected scenario, play `total_matches`
# matches with the loaded agent, print one line per match and then an
# aggregate summary.
# model = PPO.load("ppo_gfootball")

# Alternative scenario selections:
# scenario_names = [scenarios[11], scenarios[12], scenarios[13], scenarios[14]]
# scenario_names = [scenarios[12], scenarios[13]]
scenario_names = [scenarios[13]]
total_matches = 10
with torch.no_grad():
    for scenario_name in scenario_names:
        test_env = FootballGym({"env_name": scenario_name}, render=False)
        win_count = draw_count = 0
        total_goals_scored = total_goal_conceded = 0
        print("-------------------------------------------------")
        print(f"Playing in the {scenario_name} scenario:")
        print("-------------------------------------------------")

        for i in range(total_matches):
            my_agent_goals, ai_agent_goals = play_match_with_dqn_agent(test_env)
            total_goals_scored += my_agent_goals
            total_goal_conceded += ai_agent_goals
            # Classify the match; only wins and draws are counted explicitly,
            # losses are derived from the totals in the summary line.
            if my_agent_goals == ai_agent_goals:
                match_result = "DRAW"
                draw_count += 1
            elif my_agent_goals > ai_agent_goals:
                match_result = "WIN!"
                win_count += 1
            else:
                match_result = "DEFEAT!"
            print(f"Match {i+1}: {match_result.ljust(10, ' ')} | MY AGENT {my_agent_goals} - {ai_agent_goals} AI")
        print(f"Results of {total_matches} Matches in {scenario_name} scenario:")
        print(f"WON: {win_count}, LOST: {total_matches-win_count-draw_count}, DREW: {draw_count} | total goals scored: {total_goals_scored}, total goals conceded: {total_goal_conceded}\n\n")