In [1]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
from stk_actor.replay_buffer import SACRolloutBuffer
from stk_actor.wrappers import PreprocessObservationWrapper
import torch

class PolicyWrapper(torch.nn.Module):
    def __init__(self, policy_stb, dropout):
        super().__init__()
        self.shared = policy_stb
        self.dropout = torch.nn.Dropout(dropout)
    def forward(self, x):
        x = self.shared.features_extractor(x)
        x = self.shared.mlp_extractor.policy_net(x)
        x = self.dropout(x)
        x = self.shared.action_net(x)
        return x

vec_env = make_vec_env(
    "supertuxkart/flattened_multidiscrete-v0", 
    n_envs=1, 
    wrapper_class=lambda x : (PreprocessObservationWrapper(x)), 
    env_kwargs={
        'render_mode':None, 'agent':AgentSpec(use_ai=False, name="walid"), 'difficulty':0,#'track':'abyss', #'num_kart':2, 'difficulty':0
    }
)
vec_env.close()

policy_stb = PPO('MlpPolicy', vec_env, policy_kwargs = dict(
    net_arch=[512,512,512,512],
    activation_fn=torch.nn.SiLU,
)).policy

obs_dim = vec_env.observation_space.shape[0]
action_dims = [space.n for space in vec_env.action_space]

device = 'mps'
policy = PolicyWrapper(
    policy_stb, 0.2
).to(device)

obs_dim, action_dims

..:: Antarctica Rendering Engine 2.0 ::..


(154, [5, 2, 2, 2, 2, 2, 7])

In [2]:
import joblib, random
from torch.optim import AdamW

batch_size = 1024 * 8
start_step_id = 18
num_epochs = 1000
device = 'mps'

policy = policy.to(device)

optimizer = AdamW(policy.parameters(), lr=2e-3)
criterion = torch.nn.CrossEntropyLoss(reduction='none')

buffer1 = joblib.load('all_tracks_buffer_steps_2mil')
buffer2 = joblib.load('all_tracks_buffer_steps_1laps')

batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(buffer1.get_batches(batch_size, start_step_id))]
batches.extend([[torch.tensor(b).to(device) for b in batch] for batch in list(buffer2.get_batches(batch_size, start_step_id))])

  batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(buffer1.get_batches(batch_size, start_step_id))]
  batches.extend([[torch.tensor(b).to(device) for b in batch] for batch in list(buffer2.get_batches(batch_size, start_step_id))])


In [3]:
valid_obs = torch.concatenate(
    [
        buffer1.observations[:buffer1.size],
        buffer2.observations[:buffer2.size],
    ], dim=0
)
mean = valid_obs.mean(dim=0).to(device)
std = valid_obs.std(dim=0).to(device)
mean.shape, std.shape, valid_obs.shape

(torch.Size([154]), torch.Size([154]), torch.Size([1092612, 154]))

In [4]:
policy.dropout = torch.nn.Identity()

In [5]:
for epoch in range(num_epochs):

    policy.train()
    total_loss = 0
    num_batches = 0

    if epoch %10 == 0:
        batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(buffer1.get_batches(batch_size, start_step_id))]
        batches.extend([[torch.tensor(b).to(device) for b in batch] for batch in list(buffer2.get_batches(batch_size, start_step_id))])

    action_correct = [0] * len(action_dims)
    action_total = [0] * len(action_dims)
    
    random.shuffle(batches)
    
    for batch in batches:
        obs, actions, rewards, next_obs, prev_obs, dones, _, _ = batch
        obs = obs 
        # mean = obs.mean(dim=1, keepdim=True)
        # std = obs.std(dim=1, keepdim=True)
        # obs = (obs - mean.unsqueeze(0)) / (std.unsqueeze(0) + 1e-8)
        obs = obs #+ torch.randn_like(obs) * 0.02

        actions = actions.permute(1, 0)

        outputs = policy(obs)
        optimizer.zero_grad()
        split_logits = torch.split(outputs, action_dims, dim=-1)
        
        losses = []
        # loss_weights = ([0,0,0,0,0,0,1]) # only train steering
        loss_weights = ([1,1,1,1,1,1,10]) # train all actions
        loss_weights = [x/sum(loss_weights) for x in loss_weights]
        for i in range(actions.size(1)):
            if loss_weights[i] ==0:
                continue
            loss_i = torch.nn.CrossEntropyLoss()(split_logits[i], actions[:,i])
            losses.append(loss_weights[i] * loss_i)
            
            # Compute accuracy
            predicted = torch.argmax(split_logits[i], dim=1)
            correct = (predicted == actions[:,i]).sum().item()
            total = actions.size(0)
            
            action_correct[i] += correct
            action_total[i] += total

            
        loss = sum(losses)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    accuracies = [correct/(total or 1) * 100 for correct, total in zip(action_correct, action_total)]
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Training Loss: {avg_loss:.4f}")
    for i, acc in enumerate(accuracies):
        print(f"Action {i} Accuracy: {acc:.2f}%")
    print("-" * 40)

print("Training complete!")

  batches = [[torch.tensor(b).to(device) for b in batch] for batch in list(buffer1.get_batches(batch_size, start_step_id))]
  batches.extend([[torch.tensor(b).to(device) for b in batch] for batch in list(buffer2.get_batches(batch_size, start_step_id))])


Epoch 1/1000
Training Loss: 1.4889
Action 0 Accuracy: 97.23%
Action 1 Accuracy: 99.20%
Action 2 Accuracy: 98.25%
Action 3 Accuracy: 99.20%
Action 4 Accuracy: 98.41%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 42.64%
----------------------------------------
Epoch 2/1000
Training Loss: 0.7159
Action 0 Accuracy: 98.56%
Action 1 Accuracy: 99.20%
Action 2 Accuracy: 98.73%
Action 3 Accuracy: 99.96%
Action 4 Accuracy: 98.98%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 59.57%
----------------------------------------
Epoch 3/1000
Training Loss: 0.6344
Action 0 Accuracy: 98.58%
Action 1 Accuracy: 99.21%
Action 2 Accuracy: 98.74%
Action 3 Accuracy: 99.96%
Action 4 Accuracy: 99.04%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 64.02%
----------------------------------------
Epoch 4/1000
Training Loss: 0.6102
Action 0 Accuracy: 98.59%
Action 1 Accuracy: 99.21%
Action 2 Accuracy: 98.73%
Action 3 Accuracy: 99.96%
Action 4 Accuracy: 99.05%
Action 5 Accuracy: 100.00%
Action 6 Accuracy: 65.38%
-----

KeyboardInterrupt: 

In [6]:
torch.save(policy.state_dict(),'policy_notnorm_512_512_512_512_SiLU_statedict')
torch.save(mean,'buffer_mean')
torch.save(std,'buffer_std')