In [None]:
import os
import gymnasium as gym
import numpy as np
import torch
from stable_baselines3 import DQN, A2C, PPO, SAC
from stable_baselines3.common.callbacks import BaseCallback
import numpy
import matplotlib.pyplot as plt

class RewardTracker(BaseCallback):
    """Callback that records the total reward of each finished episode.

    Only the first sub-environment (index 0) of the training env is tracked;
    rewards from any additional sub-environments are ignored.
    """

    def __init__(self):
        super().__init__()
        # Total reward of every completed episode, in completion order.
        self.episode_rewards = []
        # Running reward sum for the episode currently in progress.
        # Kept as a Python float so the total does not inherit the (lower)
        # precision of the numpy scalars it accumulates.
        self.current_ep_reward = 0.0

    def _on_step(self):
        """Accumulate the latest reward and close out the episode when it ends.

        Reads ``rewards`` and ``dones`` from ``self.locals`` (one entry per
        sub-environment) and uses entry 0 of each.

        Returns:
            bool: always ``True`` (the callback never requests a stop).
        """
        self.current_ep_reward += float(self.locals["rewards"][0])

        if self.locals["dones"][0]:
            self.episode_rewards.append(self.current_ep_reward)
            self.current_ep_reward = 0.0
        return True



# Torch device handle; it is not passed to any model created in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Environment configuration: Kinematics observations (the variable name is
# historical; no lidar observation is configured here).
lidarconfig = {
    "observation": {
        "type": "Kinematics",
    }
}

env = gym.make("merge-v0", config=lidarconfig, render_mode=None)
print("Created environment")

# The PPO "teacher" was trained beforehand and is loaded from disk ("model").
# To retrain it from scratch instead of loading:
#   ppo_model = PPO("MlpPolicy", env, verbose=1)
#   ppo_model.learn(total_timesteps=1_000_000)
#   ppo_model.save("onem_ppo_model_finetuned")
ppo_model = PPO.load("model", env)

# NOTE(review): SAC only supports continuous (Box) action spaces and the
# config above does not set an action type -- confirm the env's action space.
sac_model = SAC("MlpPolicy", env, verbose=1)

# Number of PPO-driven transitions used to pre-fill the SAC replay buffer,
# and how often (in steps) to report progress.
WARMUP_STEPS = 100_000
WARMUP_LOG_EVERY = 1000

obs, info = env.reset()

for i in range(WARMUP_STEPS):

    if i % WARMUP_LOG_EVERY == 0:
        print(f"Warmup step {i}/{WARMUP_STEPS}")

    action, _ = ppo_model.predict(obs, deterministic=False)

    # Step with Gymnasium API
    next_obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    # The replay buffer distinguishes a time-limit cut-off from a real
    # termination through the "TimeLimit.truncated" info key (this is what
    # SB3's own VecEnv wrapper sets). Without it every truncation is stored
    # as a terminal state and the critic stops bootstrapping there.
    buffer_info = {**info, "TimeLimit.truncated": bool(truncated and not terminated)}

    # Add single transition to SAC replay buffer. SAC stores actions scaled
    # to [-1, 1], so the env-space action predicted by PPO is scaled first;
    # reward/done are wrapped as length-1 arrays (one entry per env).
    sac_model.replay_buffer.add(
        obs,
        next_obs,
        sac_model.policy.scale_action(action),
        np.array([reward]),
        np.array([done]),
        [buffer_info],
    )

    # Advance state
    if done:
        obs, info = env.reset()
    else:
        obs = next_obs

In [None]:
# Environment steps used for each of the two SAC training runs below.
TRAIN_TIMESTEPS = 100_000

# Run 1: SAC whose replay buffer was pre-filled in the previous cell.
warm_cb = RewardTracker()
sac_model.learn(total_timesteps=TRAIN_TIMESTEPS, callback=warm_cb)
sac_model.save("sac_highway_warm")
warm_rewards = np.array(warm_cb.episode_rewards)

# Run 2 (baseline): a fresh SAC model with an empty replay buffer.
sac_model2 = SAC("MlpPolicy", env, verbose=1)
# Gymnasium's reset() returns an (observation, info) pair; unpack it so
# `obs` holds an observation rather than the whole tuple.
obs, info = env.reset()


no_warm_cb = RewardTracker()
sac_model2.learn(total_timesteps=TRAIN_TIMESTEPS, callback=no_warm_cb)
sac_model2.save("sac_highway_no_warm")
no_warm_rewards = np.array(no_warm_cb.episode_rewards)
print("Finished training")

In [None]:
from multi_scenario_env import MultiScenarioHighwayEnv
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np
import matplotlib.pyplot as plt

# Instantiate the multi-scenario environment with render_mode="human".
# NOTE(review): this rebinds the module-level name `env`, which earlier cells
# use for the "merge-v0" training environment; nothing visible in this
# notebook uses or closes the new environment -- confirm this cell is needed.
env = MultiScenarioHighwayEnv(render_mode="human")




In [None]:
# Per-episode total reward for both SAC runs, drawn on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
for reward_series, series_label in (
    (warm_rewards, "SAC (warm start)"),
    (no_warm_rewards, "SAC (no warm start)"),
):
    ax.plot(reward_series, label=series_label)
ax.set_xlabel("Episode")
ax.set_ylabel("Total Reward")
ax.set_title("Training Reward Comparison")
ax.legend()
ax.grid(True)
plt.show()