In [2]:
import numpy as np
from stable_baselines3.common.vec_env import VecEnv, VecEnvStepReturn
from poke_env.player.random_player import RandomPlayer
import asyncio
import gymnasium as gym

class PokeVecEnv(VecEnv):
    def __init__(self, num_envs, embed_battle):
        self.players = [RandomPlayer(battle_format="gen1randombattle") for _ in range(num_envs)]
        self.opponents = [RandomPlayer(battle_format="gen1randombattle") for _ in range(num_envs)]
        self.embed_battle = embed_battle

        # Define action_space and observation_space
        action_space = gym.spaces.Discrete(9)  # Assuming 4 moves + 5 switches
        observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32)

        super().__init__(num_envs, observation_space, action_space)

    def step_async(self, actions):
        self.actions = actions
        for player, action in zip(self.players, actions):
            asyncio.run(player.choose_move(action))

    def step_wait(self):
        observations, rewards, dones, infos = [], [], [], []
        for player in self.players:
            asyncio.run(player.complete_turn())
            observation = self.embed_battle(player.current_battle)
            reward = 1 if player.current_battle.won else 0
            done = player.current_battle.finished
            info = {}

            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)

        return np.array(observations), np.array(rewards), np.array(dones), infos

    def reset(self):
        observations = []
        for player, opponent in zip(self.players, self.opponents):
            asyncio.run(player.battle_against(opponent, n_battles=1))
            observations.append(self.embed_battle(player.current_battle))
        return np.array(observations)

    def close(self):
        pass  # Optional: Implement closing procedures

    def render(self, mode='human'):
        pass  # Optional: Implement rendering if needed

ImportError: cannot import name 'VecEnvStepReturn' from 'stable_baselines3.common.vec_env' (/Users/tonysun/.pyenv/versions/3.11.5/lib/python3.11/site-packages/stable_baselines3/common/vec_env/__init__.py)

In [3]:
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData

def embed_battle(self, battle: AbstractBattle):
    # -1 indicates that the move does not have a base power
    # or is not available
    moves_base_power = -np.ones(4)
    moves_dmg_multiplier = np.ones(4)
    for i, move in enumerate(battle.available_moves):
        moves_base_power[i] = (
            move.base_power / 100
        )  # Simple rescaling to facilitate learning
        if move.type:
            moves_dmg_multiplier[i] = move.type.damage_multiplier(
                battle.opponent_active_pokemon.type_1,
                battle.opponent_active_pokemon.type_2,
                type_chart=GenData.from_gen(8).type_chart,
            )

    # We count how many pokemons have fainted in each team
    fainted_mon_team = len([mon for mon in battle.team.values() if mon.fainted]) / 6
    fainted_mon_opponent = (
        len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
    )

    # Final vector with 10 components
    final_vector = np.concatenate(
        [
            moves_base_power,
            moves_dmg_multiplier,
            [fainted_mon_team, fainted_mon_opponent],
        ]
    )
    return np.float32(final_vector)