# CHANGE THESE CONFIGS, THEN UPDATE MODEL INITIALIZATION

In [1]:
MODEL_PATH = "./models/a2c/a2c_simple"
TENSORBOARD_LOG_DIR = "./models/a2c_simple/logs"
IMAGE_TAG = "a2c_simple_rewards"
IMAGE_DIR = "./images/a2c_simple"
RUNS_FILE = "./models/a2c_simple/a2c_simple_success_runs.pkl"

#### ENV CONFIGS ####
CONFIG = {
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,  # Number of other vehicles to observe
        "features": ["presence", "x", "y", "vx", "vy"],  # Observe position and velocity
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-10, 10],
            "vy": [-10, 10]
        },
        "absolute": False,
        "clip": False,
        "normalize": False
    },
    "action": {
        "type": "DiscreteMetaAction",  # Keep simple, 5 discrete actions
    },
    "simulation_frequency": 10,
    "policy_frequency": 10,
    "destination": "o1",
    "initial_vehicle_count": 20,
    "spawn_probability": 0.8,
    "ego_spacing": 25,
    "initial_lane_id": None,
    "controlled_vehicles": 1,
    "duration": 15,  # seconds
    "vehicles_density": 1.0,
    "screen_width": 600,
    "screen_height": 600,
    "centering_position": [0.5, 0.6],
    "scaling": 5.5 * 1.3,
    "normalize_reward": False
}

## Imports

In [None]:
import time
import torch
import pickle
import random
import gymnasium as gym

from fvcore.nn import FlopCountAnalysis

from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import ProgressBarCallback

from custom_intersection_env import CustomIntersectionEnv
from simple_intersection_env import SimpleIntersectionEnv
from custom_training_callback import RewardTrackingCallback

  from pkg_resources import resource_stream, resource_exists


## Register Env with Gymnasium

In [None]:
gym.envs.registration.register(
    id="simple-intersection-v0",
    entry_point="simple_intersection_env:SimpleIntersectionEnv",
)

## Create and Wrap Env

In [None]:
env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=CONFIG)
env = DummyVecEnv([lambda: env])

## UPDATE HERE: Set Up Correct Model

In [5]:
model = A2C(
    "MlpPolicy",
    env,
    n_steps=5,
    learning_rate=7e-4,
    gamma=0.99,
    gae_lambda=1.0,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    rms_prop_eps=1e-5,
    use_rms_prop=True,
    normalize_advantage=False,
    tensorboard_log=TENSORBOARD_LOG_DIR,
    verbose=0,
    device='cpu'
)

## Training

In [None]:
reward_callback = RewardTrackingCallback(
    tag=IMAGE_TAG,
    path_dir=IMAGE_DIR
)

destinations = ["o1", "o2", "o3"]
steps = [50000 for _ in destinations]

for dest, steps in zip(destinations, steps):
    config = CONFIG.copy()
    config["destination"] = dest  # Change destination for each training phase
    env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=CONFIG)
    env = DummyVecEnv([lambda: env])
    model.set_env(env)  # Update the model with the new environment
    model.learn(
        total_timesteps=steps,
        callback=[ProgressBarCallback(), reward_callback]
    )
    reward_callback.start_new_phase()
reward_callback.save_all_plot()

Output()

KeyboardInterrupt: 

### Save Model

In [None]:
model.save(MODEL_PATH)

## Evaluation

### Load Model

In [None]:
model = A2C.load(MODEL_PATH)

### Run Evaluation

In [None]:
collisions = 0
destination_arrivals = 0
success_count = 0
successful_flopcount = 0
episodes = 100

# Store successful runs for rendering
successful_runs = []

for eps in range(100):
    config = CONFIG.copy()
    config["destination"] = "o" + str(random.randint(1, 3))
    env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=CONFIG)

    seed = random.randint(0, 10000)

    obs, _ = env.reset(seed=seed)
    episode_flops = 0
    done = False
    truncated = False
    episode_reward = 0
    trajectory = []

    while not (done or truncated):
        # Flop Counting
        input_tensor, _ = model.policy.obs_to_tensor(obs)
        flops = FlopCountAnalysis(model.policy, input_tensor)
        flops.unsupported_ops_warnings(False)
        flops = flops.total()
        episode_flops += flops

        action, _states = model.predict(obs, deterministic=True)
        trajectory.append((obs, action))  # Save for later render if successful
        obs, reward, done, truncated, info = env.step(action)
        episode_reward += reward
    
    crashed = info.get("crashed", False)
    arrived = info.get("arrived", False)
    if crashed:
        collisions += 1
    if arrived:
        destination_arrivals += 1
    if (not crashed) and arrived:
        success_count += 1
        successful_flopcount += episode_flops
        successful_runs.append((seed, config.copy(), trajectory))

    print(f"Episode {eps + 1} finished, total reward: {episode_reward}, destination: {config['destination']}")
    time.sleep(1)

print(f"Total collisions: {collisions} out of {episodes} episodes")
print(f"Total destination arrivals: {destination_arrivals} out of {episodes} episodes")
if success_count > 0:
    print(f"FLOPS per successful episode: {successful_flopcount / success_count:.2}")
else:
    print("No successful episodes.")



Episode 1 finished, total reward: 149.21755712170813, destination: o1
Episode 2 finished, total reward: 200.09040963587552, destination: o1
Episode 3 finished, total reward: 326.3554000085889, destination: o3
Episode 4 finished, total reward: 182.2736896877958, destination: o3
Episode 5 finished, total reward: 338.46485623563393, destination: o1
Episode 6 finished, total reward: 222.7795726068492, destination: o3
Episode 7 finished, total reward: 137.2186726623264, destination: o3
Episode 8 finished, total reward: 192.3416738182873, destination: o1
Episode 9 finished, total reward: 160.12450495919896, destination: o1
Episode 10 finished, total reward: 118.32974049458173, destination: o3
Episode 11 finished, total reward: 482.06703192309106, destination: o2
Episode 12 finished, total reward: 377.329142567498, destination: o2
Episode 13 finished, total reward: 727.722003647451, destination: o3
Episode 14 finished, total reward: 346.78955677067586, destination: o2
Episode 15 finished, tot

### Save Successful Runs

In [None]:
with open(RUNS_FILE, "wb") as f:
    pickle.dump(successful_runs, f)

### Load Successful Runs File

In [None]:
with open(RUNS_FILE, "rb") as f:
    successful_runs = pickle.load(f)

### Render Successful Episodes

In [None]:
for i, (seed, config, trajectory) in enumerate(successful_runs):
    print(f"\nRendering successful episode {i + 1}")
    env = gym.make("simple-intersection-v0", render_mode='rgb_array', config=CONFIG)
    obs, _ = env.reset(seed=seed)
    for obs, action in trajectory:
        env.step(action)
        env.render()
        time.sleep(0.05)
    env.close()


Rendering successful episode 1

Rendering successful episode 2

Rendering successful episode 3

Rendering successful episode 4

Rendering successful episode 5

Rendering successful episode 6


KeyboardInterrupt: 