In [None]:
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
import tqdm
import ipyparallel
import torch
import torch.nn.functional as F

# Connect to the running ipyparallel cluster. The constructor is called with
# no arguments, so whatever connection defaults ipyparallel applies are used.
client = ipyparallel.Client()

# Slice the client to obtain one view addressing all engines; the episode
# workers defined below are dispatched through this view.
dview = client[:]

def run_episode(arg):
    """Roll out one SuperTuxKart episode and record every transition.

    The function is executed on ipyparallel engines, so every module it needs
    is imported inside the body instead of relying on the notebook namespace.

    Args:
        arg: ``(track, difficulty)`` tuple; both values are forwarded
            unchanged to ``gym.make``.

    Returns:
        list[dict]: one record per environment step with the keys
        ``prev_obs``, ``obs``, ``next_obs`` (flattened observation tensors),
        ``actions``, ``reward``, ``done`` (1.0 when the step terminated or
        truncated the episode), ``track`` and ``step`` (0-based index).
        Consecutive records share observation tensor objects (the
        ``next_obs`` of one step is the ``obs`` of the next), so treat the
        tensors as read-only.
    """
    track, difficulty = arg
    import torch
    import torch.nn.functional as F
    import gymnasium as gym
    from pystk2_gymnasium import AgentSpec

    env = gym.make(
            "supertuxkart/flattened_multidiscrete-v0", 
            render_mode=None, 
            agent=AgentSpec(use_ai=True), 
            track=track, 
            difficulty=difficulty,
            laps=3,
    )

    # Looked up once; iterating it yields one sub-space (with ``.n``) per
    # discrete observation component.
    discrete_space = env.observation_space['discrete']

    def preprocess_observation(obs):
        """Flatten a mixed observation dict into a single 1-D tensor.

        The ``continuous`` part is kept as floats and every entry of the
        ``discrete`` part is one-hot encoded with the size of its sub-space.
        """
        continuous_tensor = torch.FloatTensor(obs['continuous'])
        discrete_tensors = [
            F.one_hot(torch.tensor(value), num_classes=space.n)
            for value, space in zip(obs['discrete'], discrete_space)
        ]
        return torch.cat([continuous_tensor] + discrete_tensors)

    records = []
    try:
        obs, *_ = env.reset()
        # Each observation is preprocessed exactly once and the tensor is
        # reused as next_obs -> obs -> prev_obs on subsequent steps.
        obs_tensor = preprocess_observation(obs)
        prev_tensor = obs_tensor

        step = 0
        done = False
        while not done:
            # The sampled action is only a placeholder for `env.step`; the
            # action that gets recorded is read back from the observation
            # (presumably the one chosen by the built-in AI enabled through
            # `use_ai=True` -- TODO confirm against pystk2_gymnasium).
            next_obs, reward, terminated, truncated, _ = env.step(
                env.action_space.sample()
            )
            action = next_obs['action']
            next_tensor = preprocess_observation(next_obs)

            # Stop on truncation as well as termination: stepping an
            # environment after it reported the end of the episode is invalid.
            done = bool(terminated or truncated)

            records.append(
                {
                    'prev_obs': prev_tensor,
                    'obs': obs_tensor,
                    'actions': torch.tensor(action),
                    'reward': torch.tensor(reward),
                    'next_obs': next_tensor,
                    'done': torch.tensor(float(done)),
                    'track': track,
                    'step': step,
                }
            )
            prev_tensor, obs_tensor = obs_tensor, next_tensor
            step += 1
    finally:
        # Release the simulator even when a step or preprocessing call fails.
        env.close()

    return records

# Make `run_episode` available under that name in every engine's namespace.
# NOTE(review): `dview.map` below is handed the function object directly, so
# this explicit push may be redundant -- confirm before removing it.
dview.push({'run_episode': run_episode})

def parallel_run_episodes(num_episodes):
    """Run ``run_episode`` on the cluster for every task and merge the results.

    Args:
        num_episodes: despite its historical name, the iterable of
            ``(track, difficulty)`` tuples to process, one per episode (the
            notebook calls this function with its ``args`` list). The tasks
            are taken from this argument rather than from a module-level
            variable, so the function no longer depends on hidden notebook
            state.

    Returns:
        list: the transition records of all episodes, concatenated in the
        order the per-episode results are iterated.

    Raises:
        TypeError: if ``num_episodes`` is not iterable (e.g. an integer).
    """
    tasks = list(num_episodes)

    # Distribute the tasks over the engines through the direct view.
    results = dview.map(run_episode, tasks)

    # Flatten the per-episode lists into one list, reporting each episode's
    # length as it is collected.
    records = []
    for episode_records in results:
        records.extend(episode_records)
        print(len(episode_records))

    return records

# Tracks to collect data on. Every track is raced `nb_runs` times at each
# difficulty level listed below.
tracks = [
    'abyss',
    'black_forest',
    'candela_city',
    'cocoa_temple',
    'cornfield_crossing',
    'fortmagma',
    'gran_paradiso_island',
    'hacienda',
    'lighthouse',
    'mines',
    'minigolf',
    'olivermath',
    'ravenbridge_mansion',
    'sandtrack',
    'scotland',
    'snowmountain',
    'snowtuxpeak',
    'stk_enterprise',
    'volcano_island',
    'xr591',
    'zengarden'
]

# Number of repetitions of each (track, difficulty) pair.
nb_runs = 1
difficulties = [0, 1, 2]

# One (track, difficulty) task per episode, ordered by difficulty, then track,
# then repetition. A comprehension keeps the loop variables out of the
# notebook namespace; in particular it avoids rebinding `_`, which IPython
# uses for the last cell output.
args = [
    (track_name, level)
    for level in difficulties
    for track_name in tracks
    for _repeat in range(nb_runs)
]

# Run every (track, difficulty) task on the cluster and gather the transition
# records of all episodes into one flat list.
records = parallel_run_episodes(args)
# Total number of transitions collected across all episodes.
print(len(records))

In [None]:
from stk_actor.replay_buffer import SACRolloutBuffer, calculate_total_obs_dim

# Size the buffer to hold exactly the transitions collected above.
buffer_size = len(records)

# A throwaway environment is created only to read the observation and action
# spaces that determine the buffer's dimensions.
env = gym.make(
    "supertuxkart/flattened_multidiscrete-v0",
    render_mode=None,
    agent=AgentSpec(use_ai=False, name="walid"),
    track='abyss',
    num_kart=2,
    difficulty=0
)
try:
    obs_dim = calculate_total_obs_dim(env.observation_space)
    action_dims = [space.n for space in env.action_space]
finally:
    # Close the probe environment even if reading the spaces fails.
    env.close()

# Reuse the dimensions computed above instead of deriving them a second time.
buffer = SACRolloutBuffer(
    buffer_size,
    obs_dim=obs_dim,
    action_dims=action_dims,
)

# Last expression of the cell: displays the buffer capacity.
buffer_size


In [None]:
import tqdm 
# Copy the collected transitions into the buffer, never adding more than
# `buffer_size` of them; each record's keys are passed to `buffer.add` as
# keyword arguments.
n_to_store = min(len(records), buffer_size)
for i in tqdm.tqdm(range(n_to_store)):
    buffer.add(**records[i])

In [None]:
import joblib
# Persist the filled buffer to disk with joblib at compression level 3.
# NOTE(review): the file name suggests about 2 million steps, but the buffer
# actually holds len(records) transitions -- confirm the name is still accurate.
joblib.dump(buffer,'all_tracks_buffer_steps_2mil', compress=3)