In [1]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
from stk_actor.replay_buffer import SACRolloutBuffer
from stk_actor.wrappers import PreprocessObservationWrapper
from pystk2_gymnasium.stk_wrappers import ConstantSizedObservations, PolarObservations, DiscreteActionsWrapper
from pystk2_gymnasium.wrappers import FlattenerWrapper
import torch

class PolicyWrapper(torch.nn.Module):
    def __init__(self, policy_stb, dropout):
        super().__init__()
        self.shared = policy_stb
        self.dropout = torch.nn.Dropout(dropout)
    def forward(self, x):
        x = self.shared.features_extractor(x)
        x = self.shared.mlp_extractor.policy_net(x)
        x = self.dropout(x)
        x = self.shared.action_net(x)
        return x

ll = lambda env : FlattenerWrapper(DiscreteActionsWrapper(PolarObservations(ConstantSizedObservations(env, state_items = 10, state_karts = 10, state_paths = 10,))))

vec_env = make_vec_env(
    "supertuxkart/full-v0", 
    n_envs=1, 
    wrapper_class=lambda x : (PreprocessObservationWrapper(ll(x))), 
    env_kwargs={
        'render_mode':None, 'agent':AgentSpec(use_ai=False, name="walid"), 'difficulty':0,#'track':'abyss', #'num_kart':2, 'difficulty':0
    }
)
vec_env.close()

policy_stb = PPO('MlpPolicy', vec_env, policy_kwargs = dict(
    net_arch=[1024,1024,1024],
    activation_fn=torch.nn.Tanh,
)).policy

obs_dim = vec_env.observation_space.shape[0]
action_dims = [space.n for space in vec_env.action_space]

device = 'mps'
policy = PolicyWrapper(
    policy_stb, 0.2
).to(device)

obs_dim, action_dims

..:: Antarctica Rendering Engine 2.0 ::..


  self.mean = torch.load(mod_path/'buffer_mean', map_location='cpu')
  self.std = torch.load(mod_path/'buffer_std', map_location='cpu',)


(264, [5, 2, 2, 2, 2, 2, 7])

In [2]:
import joblib, random
from torch.optim import AdamW, Adam

batch_size = 1024 * 8
num_epochs = 1000
device = 'mps'

policy = policy.to(device)

optimizer = AdamW(policy.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(reduction='none')

buffer1 = joblib.load('all_tracks_buffer_steps_diff2_obs10x10x10')
buffer2 = joblib.load('all_tracks_buffer_steps_diff2_obs10x10x10x')
buffer3 = joblib.load('all_tracks_buffer_steps_diff2_obs10x10x10xx')
buffer4 = joblib.load('all_tracks_buffer_steps_diff2_obs10x10x10xxx')
buffer5 = joblib.load('all_tracks_buffer_steps_diff2_obs10x10x10xhard')




In [3]:
import numpy as np
# Combine buffers

start_step_id = 30

buffers = [buffer1, buffer2, buffer3, buffer4, buffer5]

size = sum([b.size for b in buffers])
observations = torch.cat([b.observations[:b.size] for b in buffers], dim=0)
steps = torch.cat([b.steps[:b.size] for b in buffers], dim=0)
actions = torch.cat([torch.stack([actions for actions in b.actions]) for b in buffers], dim=1)

all_indices = np.arange(0, size)
all_indices = all_indices[steps[all_indices].flatten()>start_step_id]
observations = observations[all_indices]
actions = actions[:,all_indices]

_,unique_indices = np.unique(observations.numpy(), axis=0, return_index=True)
unique_indices = torch.tensor(unique_indices)

unique_observations = observations[unique_indices]
unique_actions = actions[:, unique_indices]

# unique_observations = observations
# unique_actions = actions

size = unique_observations.size(0)

del buffer1
del buffer2
del buffer3
del buffer4
del buffer5

del observations
del actions


def get_batches(batch_size, ):
    all_indices = np.arange(0, size)
    num_batches = int(np.ceil(len(all_indices) / batch_size))
    np.random.shuffle(all_indices)
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(all_indices))
        indices = all_indices[start_idx:end_idx]
        yield (
            unique_observations[indices],
            unique_actions[:,indices],
        )




In [4]:

mean = unique_observations.mean(dim=0).to(device)

std = unique_observations.std(dim=0).to(device)


mean.shape, std.shape, unique_observations.shape

(torch.Size([264]), torch.Size([264]), torch.Size([1425899, 264]))

In [5]:
policy.dropout = torch.nn.Identity()
'a'

'a'

In [None]:
for epoch in range(num_epochs):

    policy.train()
    total_loss = 0
    num_batches = 0

    batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(get_batches(batch_size))]
    action_correct = [0] * len(action_dims)
    action_total = [0] * len(action_dims)
    
    random.shuffle(batches)
    
    for batch in batches:
        # obs, actions, rewards, next_obs, prev_obs, dones, _, _ = batch
        obs, actions = batch
        obs = obs 
        # mean = obs.mean(dim=1, keepdim=True)
        # std = obs.std(dim=1, keepdim=True)
        obs = (obs - mean.unsqueeze(0)) / (std.unsqueeze(0) + 1e-8)
        # obs = obs + torch.randn_like(obs) * 0.01

        actions = actions.permute(1, 0)

        outputs = policy(obs)
        optimizer.zero_grad()
        split_logits = torch.split(outputs, action_dims, dim=-1)
        
        losses = []
        # loss_weights = ([0,0,0,0,0,0,1]) # only train steering
        loss_weights = ([1,1,1,1,1,1,1]) # train all actions
        loss_weights = [x/sum(loss_weights) for x in loss_weights]
        for i in range(actions.size(1)):
            if loss_weights[i] ==0:
                continue
            loss_i = torch.nn.CrossEntropyLoss()(split_logits[i], actions[:,i])
            losses.append(loss_weights[i] * loss_i)
            
            # Compute accuracy
            predicted = torch.argmax(split_logits[i], dim=1)
            correct = (predicted == actions[:,i]).sum().item()
            total = actions.size(0)
            
            action_correct[i] += correct
            action_total[i] += total

            
        loss = sum(losses)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    accuracies = [correct/(total or 1) * 100 for correct, total in zip(action_correct, action_total)]
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Training Loss: {avg_loss:.4f}")
    for i, acc in enumerate(accuracies):
        print(f"Action {i} Accuracy: {acc:.2f}%")
    print("-" * 40)


print("Training complete!")

    


  batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(get_batches(batch_size))]


Epoch 1/1000
Training Loss: 0.2081
Action 0 Accuracy: 96.09%
Action 1 Accuracy: 97.44%
Action 2 Accuracy: 97.12%
Action 3 Accuracy: 98.37%
Action 4 Accuracy: 98.40%
Action 5 Accuracy: 98.44%
Action 6 Accuracy: 62.65%
----------------------------------------
Epoch 2/1000
Training Loss: 0.1271
Action 0 Accuracy: 97.98%
Action 1 Accuracy: 99.10%
Action 2 Accuracy: 98.75%
Action 3 Accuracy: 99.93%
Action 4 Accuracy: 99.96%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 73.41%
----------------------------------------
Epoch 3/1000
Training Loss: 0.1130
Action 0 Accuracy: 98.13%
Action 1 Accuracy: 99.22%
Action 2 Accuracy: 98.82%
Action 3 Accuracy: 99.93%
Action 4 Accuracy: 99.96%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 76.33%
----------------------------------------
Epoch 4/1000
Training Loss: 0.1045
Action 0 Accuracy: 98.25%
Action 1 Accuracy: 99.27%
Action 2 Accuracy: 98.86%
Action 3 Accuracy: 99.93%
Action 4 Accuracy: 99.96%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 78.01%
------

In [7]:

torch.save(policy.state_dict(),'policy_normed_1024_1024_1204_Tanh_obs10x10x10_statedict')

torch.save(mean,'buffer_mean_obs10x10x10')
torch.save(std,'buffer_std_obs10x10x10')
