In [1]:
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
import tqdm
import ipyparallel
import torch
import torch.nn.functional as F

# Connect to the IPython cluster. Client() is called with no arguments, so the
# connection details come from ipyparallel's defaults; a cluster must already
# be running for this to succeed.
client = ipyparallel.Client()

# Build a view over every engine the client currently knows about (client[:]).
# Episodes are mapped across this view further down.
dview = client[:]

def run_episode(arg):
    """Roll out one AI-driven SuperTuxKart race and return its transitions.

    This function is shipped to ipyparallel engines, so it is deliberately
    self-contained: every import and helper it needs lives inside the body.

    Args:
        arg: ``(track, difficulty, lap)`` tuple; the three values are passed to
            ``gym.make`` as ``track``, ``difficulty`` and ``laps``.

    Returns:
        A list with one dict per environment step. Keys: ``prev_obs``, ``obs``
        and ``next_obs`` (flattened observation tensors), ``actions``,
        ``reward``, ``done`` (1.0 when the step terminated or truncated the
        episode), ``track`` and ``step`` (0-based step index).
    """
    track, difficulty, lap = arg

    # Local imports: the engines execute this body in their own process.
    import gymnasium as gym
    import torch
    import torch.nn.functional as F
    from pystk2_gymnasium import AgentSpec

    env = gym.make(
        "supertuxkart/flattened_multidiscrete-v0",
        render_mode=None,
        agent=AgentSpec(use_ai=True),
        track=track,
        difficulty=difficulty,
        laps=lap,
    )

    def preprocess_observation(obs):
        """Convert mixed observation space to flat tensor"""
        continuous_obs, discrete_obs = obs['continuous'], obs['discrete']
        continuous_tensor = torch.FloatTensor(continuous_obs)
        # One-hot every discrete component using the size of its sub-space.
        discrete_tensors = [
            F.one_hot(torch.tensor(x), num_classes=num_classes.n)
            for x, num_classes in zip(discrete_obs, env.observation_space['discrete'])
        ]
        return torch.cat([continuous_tensor] + discrete_tensors)

    records = []
    try:
        obs, *_ = env.reset()
        # On the very first step there is no earlier frame, so prev_obs == obs.
        prev_obs = obs
        step = 0
        terminated = truncated = False

        # Stop on EITHER end-of-episode signal. Looping on `terminated` alone
        # would keep stepping an environment that has already been truncated.
        while not (terminated or truncated):
            # The agent is created with use_ai=True; the action that is stored
            # is the one reported back in next_obs['action'], not this sample.
            # NOTE(review): presumably the env ignores the sampled action in
            # AI mode -- confirm against the pystk2_gymnasium documentation.
            sampled_action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = env.step(sampled_action)
            action = next_obs['action']

            records.append(
                {
                    'prev_obs': preprocess_observation(prev_obs),
                    'obs': preprocess_observation(obs),
                    'actions': torch.tensor(action),
                    'reward': torch.tensor(reward),
                    'next_obs': preprocess_observation(next_obs),
                    'done': torch.tensor(float(terminated or truncated)),
                    'track': track,
                    'step': step,
                }
            )
            prev_obs, obs = obs, next_obs
            step += 1
    finally:
        # Always release the simulator, even if a step raised.
        env.close()

    return records

# Push the run_episode function to the engines: it is stored in each engine's
# namespace under the name 'run_episode'.
dview.push({'run_episode': run_episode})

def parallel_run_episodes(num_episodes, view=None, episode_fn=None):
    """Run one episode per job on the cluster and pool all transitions.

    Args:
        num_episodes: despite the name (kept for backward compatibility), this
            is the iterable of job arguments; each element is handed to the
            episode function as its single argument.
        view: object exposing ``map(fn, iterable)``. Defaults to the
            notebook-level ``dview``.
        episode_fn: callable run for every job. Defaults to the notebook-level
            ``run_episode``.

    Returns:
        A flat list containing the records of every episode, in job order.
    """
    mapper = dview if view is None else view
    worker = run_episode if episode_fn is None else episode_fn

    # Use `map` to run the function on the cluster. The jobs come from the
    # argument, not from a notebook global, so the call is self-contained.
    results = mapper.map(worker, num_episodes)

    # Flatten the per-episode lists into one list of records.
    records = []
    for rec in results:
        records.extend(rec)
        # Report each episode's length as its result is consumed.
        print(len(rec),)

    return records

# Tracks on which rollouts are collected.
tracks = [
    'abyss',
    'black_forest',
    'candela_city',
    'cocoa_temple',
    'cornfield_crossing',
    'fortmagma',
    'gran_paradiso_island',
    'hacienda',
    'lighthouse',
    'mines',
    'minigolf',
    'olivermath',
    'ravenbridge_mansion',
    'sandtrack',
    'scotland',
    'snowmountain',
    'snowtuxpeak',
    'stk_enterprise',
    'volcano_island',
    'xr591',
    'zengarden',
]

# One (track, difficulty, lap) job per episode: every track is repeated
# nb_runs times at each difficulty level, for single-lap races.
nb_runs = 10
args = []
for lap in (1,):
    for difficulty in (0, 1, 2):
        for track in tracks:
            args += [(track, difficulty, lap)] * nb_runs

# Run all jobs on the cluster and report the total number of transitions.
records = parallel_run_episodes(args)
print(len(records))

962
989
989
994
975
1007
989
989
986
970
1982
1932
1926
1998
1983
1983
1978
1978
1964
1932
810
824
824
814
830
818
818
818
818
823
880
919
901
925
919
922
899
906
914
869
996
990
952
976
985
998
998
1019
997
988
827
815
799
830
856
828
845
880
796
822
1085
1076
1074
1046
1077
1051
1069
991
1046
1093
915
847
880
888
852
884
888
907
892
903
617
607
611
635
640
659
651
652
651
663
915
865
881
888
872
927
930
894
870
881
575
561
587
550
551
541
551
575
576
596
394
385
368
355
385
386
351
385
374
390
1006
1017
1041
1012
1044
1054
1055
1031
1032
1031
940
904
952
937
937
931
937
931
956
930
879
917
907
913
913
912
914
911
887
925
860
861
856
859
870
868
880
865
868
858
739
765
686
794
728
737
704
657
736
750
1078
1055
1054
1047
1031
1062
1072
1069
1047
1022
1584
1594
1588
1573
1582
1569
1582
1598
1555
1580
1001
918
918
941
912
994
945
933
972
962
523
542
488
534
531
533
553
567
537
525
660
651
647
658
662
672
663
626
635
659
1256
1276
1308
1356
1318
1319
1315
1336
1285
1368
579
516
606
544
62

In [2]:
# records.extend(parallel_run_episodes(args))
# print(len(records))

In [3]:
from stk_actor.replay_buffer import SACRolloutBuffer, calculate_total_obs_dim

# Size the buffer to hold exactly the transitions collected above.
buffer_size = len(records)

# A throw-away environment, created only to read the observation and action
# spaces needed to size the buffer.
env = gym.make(
    "supertuxkart/flattened_multidiscrete-v0",
    render_mode=None,
    agent=AgentSpec(use_ai=False, name="walid"),
    track='abyss',
    num_kart=2,
    difficulty=0
)
try:
    obs_dim = calculate_total_obs_dim(env.observation_space)
    action_dims = [space.n for space in env.action_space]
finally:
    # Release the simulator even if reading the spaces fails.
    env.close()

# Reuse the dimensions computed above instead of recomputing them inline.
buffer = SACRolloutBuffer(
    buffer_size,
    obs_dim=obs_dim,
    action_dims=action_dims
)

buffer_size


..:: Antarctica Rendering Engine 2.0 ::..


444252

In [4]:
import tqdm 
# Copy the collected transitions into the buffer, never adding more than
# buffer_size of them; each record's keys are passed as keyword arguments.
n_transitions = min(len(records), buffer_size)
for idx in tqdm.tqdm(range(n_transitions)):
    transition = records[idx]
    buffer.add(**transition)

100%|██████████| 444252/444252 [00:15<00:00, 28715.79it/s]


In [5]:
import joblib
# Persist the filled buffer to disk with compression level 4. The path is
# relative, so the file lands in the notebook's working directory.
# NOTE(review): joblib files are pickle-based -- only load them from sources
# you trust.
joblib.dump(buffer,'all_tracks_buffer_steps_1laps', compress=4)

['all_tracks_buffer_steps_1laps']