# CHANGE THESE CONFIGS, THEN UPDATE MODEL INITIALIZATION

In [16]:
MODEL_PATH = "./models/sac/sac"
TENSORBOARD_LOG_DIR = "./models/sac/logs"
IMAGE_TAG = "sac_rewards"
IMAGE_DIR = "./images/sac"
RUNS_FILE = "./models/sac/sac_success_runs.pkl"

#### ENV CONFIGS ####
CONFIG = {
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,
        "features": ["presence", "x", "y", "vx", "vy"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-10, 10],
            "vy": [-10, 10],
        },
        "absolute": False,
        "clip": False,
        "normalize": False,
    },
    "action": {
        "type": "ContinuousAction",
    },
    "duration": 15,
    "simulation_frequency": 10,
    "policy_frequency": 10,
    "destination": "o1",
    "initial_vehicle_count": 20,
    "spawn_probability": 0.8,
    "ego_spacing": 25,
    "initial_lane_id": None,
    "controlled_vehicles": 1,
    "duration": 15,
    "vehicles_density": 1.0,
    "screen_width": 600,
    "screen_height": 600,
    "centering_position": [0.5, 0.6],
    "scaling": 5.5 * 1.3,
    "normalize_reward": False,
}

## Imports

In [17]:
import time
import torch
import pickle
import random
import gymnasium as gym

from fvcore.nn import FlopCountAnalysis

from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import ProgressBarCallback

from custom_intersection_env import CustomIntersectionEnv
from custom_training_callback import RewardTrackingCallback

## Register Env with Gymnasium

In [18]:
gym.envs.registration.register(
    id='custom-intersection-v0',
    entry_point='custom_intersection_env:CustomIntersectionEnv',
)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


## Create and Wrap Env

In [19]:
env = gym.make("custom-intersection-v0", render_mode='rgb_array', config=CONFIG)
env = DummyVecEnv([lambda: env])

## UPDATE HERE: Set Up Correct Model

In [21]:
model = SAC(
    policy="MlpPolicy",
    env=env,
    learning_rate=3e-4,
    buffer_size=1000,   # sufficiently large replay buffer
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef="auto",
    target_entropy="auto",
    learning_starts=1_000,
    use_sde=False,
    verbose=0,
    device="cpu",
)

## Training

In [22]:
reward_callback = RewardTrackingCallback(
    tag=IMAGE_TAG,
    path_dir=IMAGE_DIR
)

destinations = ["o1", "o2", "o3"]
steps = [5000 for _ in destinations]

for dest, steps in zip(destinations, steps):
    config = CONFIG.copy()
    config["destination"] = dest  # Change destination for each training phase
    env = gym.make("custom-intersection-v0", render_mode='rgb_array', config=config)
    env = DummyVecEnv([lambda: env])
    model.set_env(env)  # Update the model with the new environment
    model.learn(
        total_timesteps=steps,
        callback=[ProgressBarCallback(), reward_callback]
    )
    reward_callback.start_new_phase()
reward_callback.save_all_plot()

Output()

Output()

Output()

### Save Model

In [23]:
model.save(MODEL_PATH)

## Evaluation

### Load Model

In [24]:
model = SAC.load(MODEL_PATH)

### Run Evaluation

In [25]:
import contextlib, io

collisions = 0
destination_arrivals = 0
success_count = 0
successful_flopcount = 0
episodes = 100

# Store successful runs for rendering
successful_runs = []

for eps in range(100):
    config = CONFIG.copy()
    config["destination"] = "o" + str(random.randint(1, 3))
    env = gym.make("custom-intersection-v0", render_mode='rgb_array', config=config)

    seed = random.randint(0, 10000)

    obs, _ = env.reset(seed=seed)
    episode_flops = 0
    done = False
    truncated = False
    episode_reward = 0
    trajectory = []

    while not (done or truncated):
        # === FLOP COUNTING (silenced) ===
        f = io.StringIO()
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
            # Anything printed by FlopCountAnalysis—whether “unused submodules” or other messages—goes into `f`
            input_tensor, _ = model.policy.obs_to_tensor(obs)
            flops = FlopCountAnalysis(model.policy, input_tensor)
            flops.unsupported_ops_warnings(False)
            flops = flops.total()
        
        episode_flops += flops
        action, _states = model.predict(obs, deterministic=True)
        trajectory.append((obs, action))
        obs, reward, done, truncated, info = env.step(action)
        episode_reward += reward
    
    crashed = info.get("crashed", False)
    arrived = info.get("arrived", False)
    if crashed:
        collisions += 1
    if arrived:
        destination_arrivals += 1
    if (not crashed) and arrived:
        success_count += 1
        successful_flopcount += episode_flops
        successful_runs.append((seed, config.copy(), trajectory))

    print(f"Episode {eps + 1} finished, total reward: {episode_reward}, destination: {config['destination']}, arrived: {arrived}, crashed: {crashed}")
    time.sleep(1)

print(f"Total collisions: {collisions} out of {episodes} episodes")
print(f"Total destination arrivals: {destination_arrivals} out of {episodes} episodes")
if success_count > 0:
    print(f"FLOPS per successful episode: {successful_flopcount / success_count:.2}")
else:
    print("No successful episodes.")



Episode 1 finished, total reward: -14.80396679710732, destination: o1, arrived: False, crashed: True
Episode 2 finished, total reward: -2.1475817299501756, destination: o2, arrived: False, crashed: True
Episode 3 finished, total reward: 93.76519922138418, destination: o3, arrived: False, crashed: False
Episode 4 finished, total reward: 93.76519922138418, destination: o3, arrived: False, crashed: False
Episode 5 finished, total reward: 97.60661437086566, destination: o3, arrived: False, crashed: False
Episode 6 finished, total reward: 93.91519922138419, destination: o3, arrived: False, crashed: False
Episode 7 finished, total reward: 105.52341906913172, destination: o3, arrived: False, crashed: False
Episode 8 finished, total reward: -12.731371847557856, destination: o3, arrived: False, crashed: True


KeyboardInterrupt: 

### Save Successful Runs

In [None]:
with open(RUNS_FILE, "wb") as f:
    pickle.dump(successful_runs, f)

### Load Successful Runs File

In [None]:
with open(RUNS_FILE, "rb") as f:
    successful_runs = pickle.load(f)

### Render Successful Episodes

In [None]:
for i, (seed, config, trajectory) in enumerate(successful_runs):
    print(f"\nRendering successful episode {i + 1}")
    env = gym.make("custom-intersection-v0", render_mode='human', config=config)
    obs, _ = env.reset(seed=seed)
    for obs, action in trajectory:
        env.step(action)
        env.render()
        time.sleep(0.05)
    env.close()