# CHANGE THESE CONFIGS, THEN UPDATE MODEL INITIALIZATION

In [1]:
MODEL_PATH = "./models/ppo/ppo"
TENSORBOARD_LOG_DIR = "./models/ppo/logs"
IMAGE_TAG = "ppo_rewards"
IMAGE_DIR = "./images/ppo"
RUNS_FILE = "./models/ppo/ppo_success_runs.pkl"

#### ENV CONFIGS ####
CONFIG = {
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,
        "features": ["presence", "x", "y", "vx", "vy"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-10, 10],
            "vy": [-10, 10],
        },
        "absolute": False,
        "clip": False,
        "normalize": False,
    },
    # ─────────────── Switch to continuous actions ───────────────
    "action": {
        "type": "ContinuousAction",
    },
    # ───────────────────────── Other settings ─────────────────────────
    "duration": 15,
    "simulation_frequency": 10,
    "policy_frequency": 10,
    "destination": "o1",
    "initial_vehicle_count": 20,
    "spawn_probability": 0.8,
    "ego_spacing": 25,
    "initial_lane_id": None,
    "controlled_vehicles": 1,
    "duration": 15,
    "vehicles_density": 1.0,
    "screen_width": 600,
    "screen_height": 600,
    "centering_position": [0.5, 0.6],
    "scaling": 5.5 * 1.3,
    "normalize_reward": False,
}

## Imports

In [3]:
import time
import torch
import pickle
import random
import gymnasium as gym

from fvcore.nn import FlopCountAnalysis

from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import ProgressBarCallback

from custom_intersection_env import CustomIntersectionEnv
from simple_intersection_env import SimpleIntersectionEnv
from custom_training_callback import RewardTrackingCallback

## Register Env with Gymnasium

In [4]:
gym.envs.registration.register(
    id="simple-intersection-v0",
    entry_point="simple_intersection_env:SimpleIntersectionEnv",
)

## Create and Wrap Env

In [5]:
env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=CONFIG)
env = DummyVecEnv([lambda: env])

## UPDATE HERE: Set Up Correct Model

In [6]:
model = PPO(
    policy="MlpPolicy",
    env=env,
    learning_rate=3e-4,
    n_steps=1000,           # on‐policy rollout length
    batch_size=64,
    n_epochs=10,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    use_sde=False,          # you can set True for state‐dependent noise
    sde_sample_freq=-1,
    tensorboard_log="./ppo_cont_tb/",
    verbose=0,
    device="cpu",
)

We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=1000 and n_envs=1)


## Training

In [None]:
reward_callback = RewardTrackingCallback(
    tag=IMAGE_TAG,
    path_dir=IMAGE_DIR
)

destinations = ["o1", "o2", "o3"]
steps = [250000 for _ in destinations]

for dest, steps in zip(destinations, steps):
    config = CONFIG.copy()
    config["destination"] = dest  # Change destination for each training phase
    env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=config)
    env = DummyVecEnv([lambda: env])
    model.set_env(env)  # Update the model with the new environment
    model.learn(
        total_timesteps=steps,
        callback=[ProgressBarCallback(), reward_callback]
    )
    reward_callback.start_new_phase()
reward_callback.save_all_plot()

Output()

### Save Model

In [None]:
model.save(MODEL_PATH)

## Evaluation

### Load Model

In [None]:
model = PPO.load(MODEL_PATH)



### Run Evaluation

In [None]:
import contextlib, io

collisions = 0
destination_arrivals = 0
success_count = 0
successful_flopcount = 0
episodes = 100

# Store successful runs for rendering
successful_runs = []

for eps in range(100):
    config = CONFIG.copy()
    config["destination"] = "o" + str(random.randint(1, 3))
    env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=config)

    seed = random.randint(0, 10000)

    obs, _ = env.reset(seed=seed)
    episode_flops = 0
    done = False
    truncated = False
    episode_reward = 0
    trajectory = []

    while not (done or truncated):
        # === FLOP COUNTING (silenced) ===
        f = io.StringIO()
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
            # Anything printed by FlopCountAnalysis—whether “unused submodules” or other messages—goes into `f`
            input_tensor, _ = model.policy.obs_to_tensor(obs)
            flops = FlopCountAnalysis(model.policy, input_tensor)
            flops.unsupported_ops_warnings(False)
            flops = flops.total()
        
        episode_flops += flops
        action, _states = model.predict(obs, deterministic=True)
        trajectory.append((obs, action))
        obs, reward, done, truncated, info = env.step(action)
        episode_reward += reward
    
    crashed = info.get("crashed", False)
    arrived = info.get("arrived", False)
    if crashed:
        collisions += 1
    if arrived:
        destination_arrivals += 1
    if (not crashed) and arrived:
        success_count += 1
        successful_flopcount += episode_flops
        successful_runs.append((seed, config.copy(), trajectory))

    print(f"Episode {eps + 1} finished, total reward: {episode_reward}, destination: {config['destination']}, arrived: {arrived}, crashed: {crashed}")
    time.sleep(1)

print(f"Total collisions: {collisions} out of {episodes} episodes")
print(f"Total destination arrivals: {destination_arrivals} out of {episodes} episodes")
if success_count > 0:
    print(f"FLOPS per successful episode: {successful_flopcount / success_count:.2}")
else:
    print("No successful episodes.")



Episode 1 finished, total reward: 21.0, destination: o1, arrived: True, crashed: False
Episode 2 finished, total reward: -20.0, destination: o2, arrived: False, crashed: True
Episode 3 finished, total reward: 164.0, destination: o2, arrived: False, crashed: False
Episode 4 finished, total reward: -21.0, destination: o2, arrived: False, crashed: True
Episode 5 finished, total reward: -8.0, destination: o2, arrived: False, crashed: True


KeyboardInterrupt: 

### Save Successful Runs

In [None]:
with open(RUNS_FILE, "wb") as f:
    pickle.dump(successful_runs, f)

### Load Successful Runs File

In [11]:
with open(RUNS_FILE, "rb") as f:
    successful_runs = pickle.load(f)

### Render Successful Episodes

In [None]:
for i, (seed, config, trajectory) in enumerate(successful_runs[20:]):
    print(f"\nRendering successful episode {i + 1}")
    env = gym.make("custom-intersection-v0", render_mode='human', config=config)
    obs, _ = env.reset(seed=seed)
    for obs, action in trajectory:
        env.step(action)
        env.render()
        time.sleep(0.05)
    env.close()


Rendering successful episode 1

Rendering successful episode 2

Rendering successful episode 3

Rendering successful episode 4

Rendering successful episode 5

Rendering successful episode 6

Rendering successful episode 7

Rendering successful episode 8

Rendering successful episode 9

Rendering successful episode 10


KeyboardInterrupt: 

: 