In [None]:
import numpy as np
import gymnasium
import asyncio
from poke_env.player.random_player import RandomPlayer

class PokeGymEnv(gymnasium.Env):
    """Gymnasium environment around a Gen 1 random battle played with poke_env.

    Both sides are ``RandomPlayer`` instances. ``reset`` plays one battle
    between them and keeps the resulting battle object; ``step`` reports the
    observation, reward and termination flags of that battle.

    Observations are the 10-component vectors produced by ``embed_battle``
    and the action space is ``Discrete(4)``.

    NOTE(review): a ``RandomPlayer`` picks its own moves and
    ``battle_against`` completes the whole battle before returning, so the
    ``action`` given to ``step`` cannot influence the battle with this
    design. Driving moves from an agent needs a player whose ``choose_move``
    consumes the agent's action -- TODO.
    """

    # Sub-folder (below "replays/") where the player's replays are saved.
    REPLAY_FOLDER = "gen1_replays"

    def __init__(self):
        super().__init__()

        # Four discrete actions; observations are 10 floats, each >= -1.
        self.action_space = gymnasium.spaces.Discrete(4)  # Assuming 4 possible actions
        self.observation_space = gymnasium.spaces.Box(low=-1, high=np.inf, shape=(10,), dtype=np.float32)

        # Both players are random players for Gen 1 random battles.
        # poke_env's Player takes the replay directory through `save_replays`
        # (a bool or a folder path); there is no `replay_folder` keyword.
        self.player = RandomPlayer(
            battle_format="gen1randombattle",
            save_replays=f"replays/{self.REPLAY_FOLDER}",
        )
        self.opponent = RandomPlayer(battle_format="gen1randombattle")

        # Set by reset(); None until the first battle has been played.
        self.current_battle = None

    def embed_battle(self, battle):
        """Return the observation vector for ``battle``.

        Delegates to the module-level ``embed_battle`` function so the class
        actually exposes the method that ``reset`` and ``step`` rely on.
        """
        # Inside a method the bare name resolves to the module-level
        # function, not to this method (class attributes are not in scope).
        return embed_battle(self, battle)

    def reset(self, *, seed=None, options=None):
        """Play a new battle and return ``(observation, info)``.

        Raises:
            RuntimeError: if no battle is registered on the player after
                ``battle_against`` finished. ``asyncio.run`` itself also
                raises ``RuntimeError`` when called from a running event
                loop (for instance inside a notebook cell).
        """
        super().reset(seed=seed)

        # battle_against() returns None; the battle objects are kept in the
        # player's `battles` mapping, most recent last.
        asyncio.run(self.player.battle_against(self.opponent, n_battles=1))
        battles = self.player.battles
        if not battles:
            raise RuntimeError("No battle was recorded after battle_against().")
        self.current_battle = list(battles.values())[-1]

        return self.embed_battle(self.current_battle), {}

    async def async_step(self, action):
        """Return ``(observation, reward, terminated, truncated, info)``.

        ``action`` is accepted for Gymnasium API compatibility but is not
        forwarded to the player: ``choose_move`` expects a battle object,
        not an action index, and its result is not awaitable here.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        battle = self.current_battle
        if battle is None:
            raise RuntimeError("reset() must be called before step().")

        observation = self.embed_battle(battle)
        reward = 1 if battle.won else 0
        terminated = battle.finished
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def step(self, action):
        """Synchronous wrapper that runs ``async_step`` to completion."""
        return asyncio.run(self.async_step(action))

    def close(self):
        """Release resources held by the environment (currently nothing)."""

In [None]:
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData

def embed_battle(self, battle: AbstractBattle, gen: int = 1):
    """Encode the state of ``battle`` as a 10-component float32 vector.

    Layout of the returned array:
        [0:4]  base power of each available move divided by 100
               (-1 for an empty slot),
        [4:8]  type damage multiplier of each available move against the
               opponent's active pokemon (1 when it cannot be computed),
        [8]    fraction (out of 6) of fainted pokemon in the player's team,
        [9]    fraction (out of 6) of fainted pokemon in the opponent's team.

    Args:
        battle: The battle to encode.
        gen: Generation whose type chart is used for the multipliers.
            Defaults to 1 because the environment in this file plays
            ``gen1randombattle``; Gen 1 type matchups differ from later
            generations.

    Returns:
        A ``np.float32`` array of shape ``(10,)``.
    """
    # -1 indicates that the move does not have a base power
    # or is not available
    moves_base_power = -np.ones(4)
    moves_dmg_multiplier = np.ones(4)

    # Loop-invariant lookups, done once.
    type_chart = GenData.from_gen(gen).type_chart
    opponent = battle.opponent_active_pokemon

    # zip(range(4), ...) caps the loop at the 4 slots of the arrays above.
    for slot, move in zip(range(4), battle.available_moves):
        # Simple rescaling to facilitate learning
        moves_base_power[slot] = move.base_power / 100
        # Without an opposing active pokemon the multiplier stays neutral.
        if move.type and opponent is not None:
            moves_dmg_multiplier[slot] = move.type.damage_multiplier(
                opponent.type_1,
                opponent.type_2,
                type_chart=type_chart,
            )

    # We count how many pokemons have fainted in each team
    fainted_mon_team = sum(1 for mon in battle.team.values() if mon.fainted) / 6
    fainted_mon_opponent = (
        sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
    )

    # Final vector with 10 components
    final_vector = np.concatenate(
        [
            moves_base_power,
            moves_dmg_multiplier,
            [fainted_mon_team, fainted_mon_opponent],
        ]
    )
    return np.float32(final_vector)

In [None]:
# Example usage: drive the environment with randomly sampled actions.
if __name__ == "__main__":
    env = PokeGymEnv()
    observation, info = env.reset()

    steps_left = 1000
    while steps_left > 0:
        steps_left -= 1

        sampled_action = env.action_space.sample()
        step_result = env.step(sampled_action)
        observation, reward, terminated, truncated, info = step_result

        # Keep stepping until the current battle is over.
        if not (terminated or truncated):
            continue

        print(f"Battle ended. Reward: {reward}")
        observation, info = env.reset()

    env.close()