# DDQN Runner Notebook

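This notebook is intentionally thin: all parameters live in the config cell below, and the actual work is delegated to the project modules (`env`, `model`, `trainer`, `evaluator`).
Edit the config cell, then restart the kernel and run all cells.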


In [1]:
%matplotlib qt
import os
import torch
import torch.nn as nn
import torch.optim as optim

from env import TaskSchedulingEnv
from model import QNetwork
from trainer import prepare_scenarios, train_ddqn
from evaluator import run_test_and_plot, print_batch_results


In [2]:
# -----------------------------
# Parameters (edit this block)
# -----------------------------
# Single source of truth for the run. Every key is read by name in the
# pipeline cell, so rename keys only together with that cell.
CFG = dict(
    # --- Data source and generation ---
    task_file='dispatch_batches.jsonl',
    train_data_dir=os.path.join(os.getcwd(), 'train_data'),  # resolved against the kernel's cwd
    auto_generate_data=False,
    gen_batches=15,
    gen_size=None,
    gen_min_size=1,
    gen_max_size=25,
    gen_arrival_mean=100,
    gen_seed=None,
    multi_streams=True,
    num_streams=200,
    base_seed=100,
    stream_file_template='dispatch_batches_{i}.jsonl',  # '{i}' is a format placeholder

    # --- Scenario sampling ---
    sampling_mode='full',  # full | window | subset
    window_size=10,        # presumably used only with sampling_mode='window' -- confirm in trainer
    subset_size=10,        # presumably used only with sampling_mode='subset' -- confirm in trainer

    # --- Train/test switches ---
    do_train=True,
    do_test=True,

    # --- Model checkpoint ---
    save_model_path='ddqn_policy_100epoch.pt',
    load_model_path=None,

    # --- DDQN and optimizer ---
    # The Q-network input width is state_dim + action_dim (see pipeline cell).
    state_dim=28,
    action_dim=4,
    lr=1e-3,
    num_episodes=100,
    batch_size=128,
    gamma=0.99,
    epsilon=1.0,
    epsilon_end=0.05,
    epsilon_decay=0.995,
    allow_proactive_replenish=True,  # copied onto the env instance before training

    # --- Training visualization and profiling ---
    show_train_schedule=False,
    train_schedule_every_episodes=1,
    train_schedule_every_steps=1,
    train_schedule_window=120.0,
    train_schedule_window_all_axes=False,
    train_schedule_pause=0.01,
    train_schedule_show_labels=False,
    train_schedule_figsize=(14, 8),
    show_train_route_map=False,
    train_route_map_every_episodes=1,
    train_route_map_every_steps=1,
    train_route_map_pause=0.01,
    train_route_map_figsize=(9, 8),
    train_route_map_animate=False,
    train_route_map_time_step=0.5,
    train_route_map_max_frames_per_update=120,
    train_route_map_delay_seconds=20.0,
    enable_profile=True,
    profile_cuda_sync=True,

    # --- Test and plotting ---
    test_scenario_file='test_scenario_one_time.jsonl',
    show_live=False,
    show_live_stream=False,
    show_interactive=False,
    show_route_map=True,
    show_plotly=True,
    show_test_plots=True,
    live_pause=0.05,
    live_job_file='live_jobs.jsonl',
    live_start_at_end=True,
    live_poll_interval=0.5,
    live_idle_sleep=0.1,
    live_max_steps=100,
    live_max_sim_time=None,
    live_init_scenario=[],
    live_record_dir=None,
    live_record_every=5,
    live_record_dpi=140,
    live_make_gif=False,
    live_gif_path='live_schedule.gif',
    route_play_step=0.5,
    route_play_interval_ms=120,
)


In [3]:
# -----------------------------
# Run pipeline
# -----------------------------
# Let a second copy of the Intel OpenMP runtime load instead of aborting the
# process ("OMP: Error #15"). NOTE(review): the variable is consulted when an
# OpenMP runtime initializes, so it is only guaranteed to take effect if it is
# set before that happens -- consider moving it above `import torch`.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def _load_checkpoint(path, policy_net, target_net, optimizer, device, input_dim, lr):
    """Restore network weights (and optimizer state, if stored) from ``path``.

    Two on-disk formats are accepted:
      * the dict written by the save step in this cell: ``model_state`` plus
        optional ``target_state``, ``optimizer_state`` and ``input_dim``;
      * a bare ``state_dict`` for the policy network.
    When no target state is available the target net is synced to the policy net.

    Args:
        path: checkpoint file to read.
        policy_net, target_net, optimizer: objects updated in place.
        device: passed to ``torch.load`` as ``map_location``.
        input_dim: expected network input width (``state_dim + action_dim``).
        lr: learning rate re-applied after restoring optimizer state, because
            ``Optimizer.load_state_dict`` also restores the saved ``lr``.

    Returns:
        Whatever ``torch.load`` returned.

    Raises:
        ValueError: the checkpoint records an ``input_dim`` that differs from
            ``input_dim`` (e.g. a checkpoint from an older state layout).
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state' in ckpt:
        saved_dim = ckpt.get('input_dim')
        if saved_dim is not None and saved_dim != input_dim:
            raise ValueError(
                f"Checkpoint {path!r} was saved with input_dim={saved_dim}, "
                f"but CFG gives state_dim + action_dim = {input_dim}"
            )
        policy_net.load_state_dict(ckpt['model_state'])
        if 'target_state' in ckpt:
            target_net.load_state_dict(ckpt['target_state'])
        else:
            target_net.load_state_dict(policy_net.state_dict())
        if 'optimizer_state' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer_state'])
            # load_state_dict restored the *saved* hyperparameters; keep the
            # configured learning rate authoritative when resuming.
            for group in optimizer.param_groups:
                group['lr'] = lr
    else:
        # Legacy format: a plain policy-net state_dict.
        policy_net.load_state_dict(ckpt)
        target_net.load_state_dict(policy_net.state_dict())
    return ckpt


scenario_list = prepare_scenarios(
    task_file=CFG['task_file'],
    train_data_dir=CFG['train_data_dir'],
    auto_generate_data=CFG['auto_generate_data'],
    gen_batches=CFG['gen_batches'],
    gen_size=CFG['gen_size'],
    gen_min_size=CFG['gen_min_size'],
    gen_max_size=CFG['gen_max_size'],
    gen_arrival_mean=CFG['gen_arrival_mean'],
    gen_seed=CFG['gen_seed'],
    multi_streams=CFG['multi_streams'],
    num_streams=CFG['num_streams'],
    base_seed=CFG['base_seed'],
    stream_file_template=CFG['stream_file_template'],
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)
if device.type == 'cuda':
    # Trade a little matmul precision for speed on GPUs that support TF32.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision('high')

env = TaskSchedulingEnv()
env.allow_proactive_replenish = CFG['allow_proactive_replenish']
# QNetwork scores a (state, action) pair, so its input is both vectors side by side.
input_dim = CFG['state_dim'] + CFG['action_dim']
policy_net = QNetwork(input_dim).to(device)
target_net = QNetwork(input_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=CFG['lr'])
criterion = nn.MSELoss()

load_path = CFG['load_model_path']
if load_path:
    if os.path.exists(load_path):
        ckpt = _load_checkpoint(
            load_path, policy_net, target_net, optimizer, device, input_dim, CFG['lr']
        )
        print(f"Loaded model from {CFG['load_model_path']}")
    elif CFG['do_train']:
        # Training from scratch is still meaningful, but make the fallback visible.
        print(f"WARNING: load_model_path {load_path!r} not found; training from scratch")
    else:
        # Evaluation-only with a missing checkpoint would test random weights.
        raise FileNotFoundError(
            f"load_model_path {load_path!r} not found and do_train is False; "
            "refusing to evaluate an untrained network"
        )

if CFG['do_train']:
    _train_info = train_ddqn(
        env=env,
        policy_net=policy_net,
        target_net=target_net,
        optimizer=optimizer,
        criterion=criterion,
        scenario_list=scenario_list,
        device=device,
        num_episodes=CFG['num_episodes'],
        batch_size=CFG['batch_size'],
        gamma=CFG['gamma'],
        epsilon=CFG['epsilon'],
        epsilon_end=CFG['epsilon_end'],
        epsilon_decay=CFG['epsilon_decay'],
        sampling_mode=CFG['sampling_mode'],
        window_size=CFG['window_size'],
        subset_size=CFG['subset_size'],
        show_train_schedule=CFG['show_train_schedule'],
        train_schedule_every_episodes=CFG['train_schedule_every_episodes'],
        train_schedule_every_steps=CFG['train_schedule_every_steps'],
        train_schedule_window=CFG['train_schedule_window'],
        train_schedule_window_all_axes=CFG['train_schedule_window_all_axes'],
        train_schedule_pause=CFG['train_schedule_pause'],
        train_schedule_show_labels=CFG['train_schedule_show_labels'],
        train_schedule_figsize=CFG['train_schedule_figsize'],
        show_train_route_map=CFG['show_train_route_map'],
        train_route_map_every_episodes=CFG['train_route_map_every_episodes'],
        train_route_map_every_steps=CFG['train_route_map_every_steps'],
        train_route_map_pause=CFG['train_route_map_pause'],
        train_route_map_figsize=CFG['train_route_map_figsize'],
        train_route_map_animate=CFG['train_route_map_animate'],
        train_route_map_time_step=CFG['train_route_map_time_step'],
        train_route_map_max_frames_per_update=CFG['train_route_map_max_frames_per_update'],
        train_route_map_delay_seconds=CFG['train_route_map_delay_seconds'],
        enable_profile=CFG['enable_profile'],
        profile_cuda_sync=CFG['profile_cuda_sync'],
    )

    save_path = CFG['save_model_path']
    if save_path:
        # Make sure the destination folder exists so the write cannot fail
        # after a long training run.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        ckpt = {
            'model_state': policy_net.state_dict(),
            'target_state': target_net.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            # Dimensions are stored so a later load can detect a layout mismatch.
            'state_dim': CFG['state_dim'],
            'action_dim': CFG['action_dim'],
            'input_dim': input_dim,
        }
        torch.save(ckpt, save_path)
        print(f"Saved model to {CFG['save_model_path']}")

if CFG['do_test']:
    _mk = run_test_and_plot(
        env=env,
        policy_net=policy_net,
        device=device,
        scenario_list=scenario_list,
        test_scenario_file=CFG['test_scenario_file'],
        show_live=CFG['show_live'],
        show_live_stream=CFG['show_live_stream'],
        show_interactive=CFG['show_interactive'],
        show_route_map=CFG['show_route_map'],
        show_plotly=CFG['show_plotly'],
        show_test_plots=CFG['show_test_plots'],
        live_pause=CFG['live_pause'],
        live_job_file=CFG['live_job_file'],
        live_start_at_end=CFG['live_start_at_end'],
        live_poll_interval=CFG['live_poll_interval'],
        live_idle_sleep=CFG['live_idle_sleep'],
        live_max_steps=CFG['live_max_steps'],
        live_max_sim_time=CFG['live_max_sim_time'],
        live_init_scenario=CFG['live_init_scenario'],
        live_record_dir=CFG['live_record_dir'],
        live_record_every=CFG['live_record_every'],
        live_record_dpi=CFG['live_record_dpi'],
        live_make_gif=CFG['live_make_gif'],
        live_gif_path=CFG['live_gif_path'],
        route_play_step=CFG['route_play_step'],
        route_play_interval_ms=CFG['route_play_interval_ms'],
    )
    print_batch_results(env=env, policy_net=policy_net, device=device, scenario_list=scenario_list)


device = cuda
[EP 1] batch=stream:15 dispatch_time=13.45 release_t0=0.00 jobs=212 MA-50 mk: 2926.21, mk/job: 13.803, mk/proc: 0.914, eps=0.995, last_mk=2926.21
[EP 2] batch=stream:15 dispatch_time=16.04 release_t0=0.00 jobs=193 MA-50 mk: 2677.48, mk/job: 12.584, mk/proc: 0.829, eps=0.990, last_mk=2428.74
[EP 3] batch=stream:15 dispatch_time=49.03 release_t0=0.00 jobs=211 MA-50 mk: 2782.63, mk/job: 14.185, mk/proc: 0.961, eps=0.985, last_mk=2992.95
[EP 4] batch=stream:15 dispatch_time=60.16 release_t0=0.00 jobs=151 MA-50 mk: 2669.11, mk/job: 15.421, mk/proc: 1.049, eps=0.980, last_mk=2328.55
[EP 5] batch=stream:15 dispatch_time=39.17 release_t0=0.00 jobs=213 MA-50 mk: 2661.68, mk/job: 12.357, mk/proc: 0.853, eps=0.975, last_mk=2631.94
[EP 6] batch=stream:15 dispatch_time=83.22 release_t0=0.00 jobs=168 MA-50 mk: 2552.40, mk/job: 11.940, mk/proc: 0.785, eps=0.970, last_mk=2006.00
[EP 7] batch=stream:15 dispatch_time=43.47 release_t0=0.00 jobs=191 MA-50 mk: 2552.39, mk/job: 13.363, mk/proc