In [1]:
import pygame
from jax import random
from functools import partial
import darkdetect
import jax.numpy as jnp
from chex import dataclass
import jaxmarl
from typing import Tuple, List, Any, Optional, Dict as DictType
import parabellum as pb

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


2024-06-23 18:48:38.939 python[38248:351853] ApplePersistence=NO


In [2]:
# Drawing colours follow the OS theme: white-on-black in dark mode,
# black-on-white otherwise (theme queried once for both colours).
if darkdetect.isDark():
    fg, bg = (255, 255, 255), (0, 0, 0)
else:
    fg, bg = (0, 0, 0), (255, 255, 255)

In [3]:
# parabellum setup
SCENARIO = "default"  # key into pb.scenarios; change here to load another scenario
env = pb.Parabellum(pb.scenarios[SCENARIO])

In [4]:
# types
State = jaxmarl.environments.smax.smax_env.State  # SMAX environment state
Dict = DictType[str, jnp.ndarray]  # agent-name -> array mapping (used for obs and actions)
# One recorded transition: (step PRNG key, pre-step state, actions taken).
# jnp.ndarray is the array *type*; jnp.array is a constructor function.
StateSeq = List[Tuple[jnp.ndarray, State, Dict]]

In [5]:
@dataclass
class Control:
    """Mutable user-input flags updated by the event handler and read by the main loop."""

    running: bool = True  # cleared when the window is closed; ends the main loop
    paused: bool = False  # toggled by any key press; step/render are skipped while set
    # (x, y) screen position of a mouse click seen this frame, else None.
    click: Optional[Tuple[int, int]] = None


@dataclass
class Game:
    """All state threaded through control_fn / step_fn / render_fn each frame."""

    clock: pygame.time.Clock  # frame-rate limiter ticked while rendering
    state: State  # current environment state
    obs: Dict  # latest observations from env.reset / env.step
    state_seq: StateSeq  # recorded (step_key, pre-step state, action) transitions
    control: Control  # user-input flags (running / paused / click)
    env: pb.Parabellum  # environment being simulated
    rng: jnp.ndarray  # JAX PRNG key array, split on every step

In [6]:
def handle_event(event, control_state):
    """Apply one pygame ``event`` to ``control_state`` and return it.

    - ``QUIT`` (window close) clears ``running`` so the main loop exits.
    - ``MOUSEBUTTONDOWN`` stores the click position in ``click``.
    - ``KEYDOWN`` (any key) toggles ``paused``.

    Args:
        event: an event object from ``pygame.event.get()``.
        control_state: the ``Control`` instance to update; mutated in place.

    Returns:
        The same ``control_state`` object.
    """
    if event.type == pygame.QUIT:
        control_state.running = False
    elif event.type == pygame.MOUSEBUTTONDOWN:
        # Use the position carried by the event itself; pygame.mouse.get_pos()
        # reports the *current* cursor position, which may have moved since
        # the click was queued.
        control_state.click = event.pos
    elif event.type == pygame.KEYDOWN:  # any key press toggles pause
        control_state.paused = not control_state.paused
    return control_state

In [7]:
def control_fn(game):
    """Drain the pygame event queue into ``game.control`` and return ``game``.

    The per-frame ``click`` is cleared first, so it only reflects clicks
    received during this call.
    """
    ctrl = game.control
    ctrl.click = None  # clicks are per-frame; forget last frame's
    for ev in pygame.event.get():
        ctrl = handle_event(ev, ctrl)
    game.control = ctrl
    return game

In [8]:
def render_fn(screen, game, fps=24, frames_per_step=8, radius=5):
    """Draw the most recent transitions of ``game`` onto ``screen``.

    The last two entries of ``game.state_seq`` are passed to
    ``game.env.expand_state_seq`` (presumably producing interpolated
    sub-step states — confirm), and the final ``frames_per_step`` of the
    result are drawn as unit circles, flipping the display after each one.

    Args:
        screen: pygame display surface to draw on.
        game: ``Game`` providing ``state_seq``, ``env`` and ``clock``.
        fps: frame-rate cap passed to ``game.clock.tick`` after each frame.
        frames_per_step: number of trailing expanded frames to draw.
        radius: circle radius in pixels for each unit.

    Returns:
        The same ``game`` object.
    """
    # Wait until enough transitions are recorded before drawing.
    if len(game.state_seq) < 3:
        return game
    env = game.env
    # Map coordinates are scaled to the largest square that fits the window;
    # a fixed 800 px scale overflowed the 720 px-tall display.
    scale = min(screen.get_size()) / env.map_width
    frames = env.expand_state_seq(game.state_seq[-2:])[-frames_per_step:]
    for _key, state, _action in frames:
        screen.fill(bg)
        for pos in state.unit_positions:
            pygame.draw.circle(screen, fg, (pos * scale).tolist(), radius)
        pygame.display.flip()
        game.clock.tick(fps)  # limits FPS
    return game

In [9]:
def step_fn(game):
    """Advance ``game`` by one environment step using randomly sampled actions.

    One action per agent is drawn from ``game.env.action_space(agent)``.
    The transition (step key, pre-step state, actions) is appended to
    ``game.state_seq`` before stepping. ``game.state``, ``game.obs`` and
    ``game.rng`` are then replaced with the post-step values.

    Args:
        game: ``Game`` whose ``env``, ``rng``, ``state`` and ``state_seq`` are used.

    Returns:
        The same (mutated) ``game`` object.
    """
    env = game.env  # use the game's own env rather than the notebook global
    rng, act_rng, step_key = random.split(game.rng, 3)
    act_keys = random.split(act_rng, env.num_agents)
    action = {
        agent: env.action_space(agent).sample(act_keys[i])
        for i, agent in enumerate(env.agents)
    }
    # Record the pre-step state with the key/actions that advance it so the
    # transition can be replayed later.
    # NOTE(review): this list grows by one entry per step and is never trimmed,
    # while render_fn only reads the last two entries — consider bounding it.
    game.state_seq.append((step_key, game.state, action))
    obs, state, _reward, _done, _info = env.step(step_key, game.state, action)
    game.state = state
    game.obs = obs
    game.rng = rng
    return game

In [10]:
# pygame setup: open the window, create the frame clock, reset the env and
# bundle everything into a Game.
pygame.init()
screen = pygame.display.set_mode((1280, 720))
clock = pygame.time.Clock()
render = partial(render_fn, screen)  # render(game) draws onto this window
rng, key = random.split(random.PRNGKey(0))
obs, state = env.reset(key)
kwargs = {
    "control": Control(),
    "env": env,
    "rng": rng,
    "state_seq": [],  # [(key, state, action)], appended to by step_fn
    "clock": clock,
    "state": state,
    "obs": obs,
}
game = Game(**kwargs)

In [11]:
# Main loop: poll input, then (unless paused) advance one env step and draw it.
# try/finally guarantees the window is closed even when the cell is
# interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
try:
    while game.control.running:
        game = control_fn(game)
        if game.control.paused:
            # Throttle event polling while paused instead of busy-spinning a core.
            game.clock.tick(24)
            continue
        game = step_fn(game)
        game = render(game)
finally:
    pygame.quit()

KeyboardInterrupt: 