In [1]:
import os
import pandas as pd
import yaml
from collections import defaultdict
from ruamel.yaml import YAML
from typing import Any, Dict, List, Optional, Tuple

In [2]:
def load_scenario_filter_from_yaml(yaml_file):
    """
    Read the ``scenario_types`` entry from a YAML config file.

    :param yaml_file: Path to the YAML file.
    :return: The value stored under ``scenario_types``; an empty list when the
        file is empty, the key is missing, or the key is present but null.
    """
    with open(yaml_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document, and `.get(key, [])` still
    # returns None for an explicit `key: null`; normalise both to an empty list
    # so callers can always iterate the result.
    return (config or {}).get("scenario_types") or []

def get_train_data_from_yaml(yaml_file):
    """
    Read the ``log_splits.train`` entry from a YAML config file.

    :param yaml_file: Path to the YAML file.
    :return: The value stored under ``log_splits`` -> ``train``; an empty list
        when the file is empty or either key is missing or null.
    """
    with open(yaml_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # Both levels may be absent or explicitly null; fall back to empty
    # containers instead of raising AttributeError on None.
    log_splits = (config or {}).get("log_splits") or {}
    return log_splits.get("train") or []

In [3]:
def get_log_names_from_cache(cache_dir):
    """
    List the names of the directories located directly inside ``cache_dir``.

    Each sub-directory name is treated as a log name; plain files are ignored.

    :param cache_dir: Path to the cache directory.
    :return: List of sub-directory names, in ``os.listdir`` order.
    """
    subdirectory_names = []
    for entry in os.listdir(cache_dir):
        if not os.path.isdir(os.path.join(cache_dir, entry)):
            continue
        subdirectory_names.append(entry)
    return subdirectory_names

def get_all_scenario_tokens(cache_dir):
    """
    Collect the names of every third-level directory under ``cache_dir``.

    The walked layout is ``cache_dir/<log_name>/<scenario_type>/<scenario_token>/``.
    Entries that are not directories are skipped at every level. Tokens are
    appended once per occurrence; no de-duplication is performed.

    :param cache_dir: Path to the cache directory.
    :return: List of scenario-token directory names.
    """
    def _subdirs(parent):
        # Yield (name, path) for each directory directly inside `parent`.
        for name in os.listdir(parent):
            path = os.path.join(parent, name)
            if os.path.isdir(path):
                yield name, path

    collected = []
    for _, log_path in _subdirs(cache_dir):
        for _, type_path in _subdirs(log_path):
            collected.extend(token for token, _ in _subdirs(type_path))
    return collected

def get_scenario_type_counts(cache_dir):
    """
    Count scenario-token directories per scenario type across all logs.

    The walked layout is ``cache_dir/<log_name>/<scenario_type>/<scenario_token>/``.
    A scenario-type directory with no token sub-directories still produces an
    entry with count 0.

    :param cache_dir: Path to the cache directory.
    :return: ``defaultdict(int)`` mapping scenario type -> number of token directories.
    """
    counts = defaultdict(int)

    for log_name in os.listdir(cache_dir):
        log_dir = os.path.join(cache_dir, log_name)
        if not os.path.isdir(log_dir):
            continue

        for scenario_type in os.listdir(log_dir):
            type_dir = os.path.join(log_dir, scenario_type)
            if not os.path.isdir(type_dir):
                continue

            # Only directories count as tokens; stray files are ignored.
            counts[scenario_type] += sum(
                1
                for token in os.listdir(type_dir)
                if os.path.isdir(os.path.join(type_dir, token))
            )

    return counts

def diff_scenario_types(scenario_filter_types, csv_scenario_types):
    """
    Print the two one-sided differences between the scenario types listed in
    the scenario filter and those listed in the CSV.

    :param scenario_filter_types: Iterable of scenario types from the scenario filter.
    :param csv_scenario_types: Iterable of scenario types from the CSV.
    :return: Set of scenario types present in the filter but absent from the CSV.
    """
    # Materialise each input exactly once so that one-shot iterables
    # (generators, iterators) are not already exhausted when the second
    # difference is computed.
    filter_types = set(scenario_filter_types)
    csv_types = set(csv_scenario_types)

    in_filter_not_in_csv = filter_types - csv_types
    in_csv_not_in_filter = csv_types - filter_types

    print("Scenario types in scenario_filter but not in CSV:")
    print(in_filter_not_in_csv)
    print("\nScenario types in CSV but not in scenario_filter:")
    print(in_csv_not_in_filter)
    return in_filter_not_in_csv

def get_resample_scenarios(df, target_count=1000, count_divisor=2):
    """
    Find scenario types whose scaled count is below a target and compute how
    many samples each one is short of that target.

    A row is selected when ``count / count_divisor < target_count``; its value
    in the result is ``target_count - count / count_divisor``.

    :param df: DataFrame with a ``scenario_type`` column and a numeric ``count`` column.
    :param target_count: Desired number of samples per scenario type (default 1000).
    :param count_divisor: Divisor applied to ``count`` before comparing with the
        target (default 2).  NOTE(review): the reason the raw count is halved is
        not visible in this notebook -- confirm against how the CSV was produced.
    :return: Dict mapping scenario type -> shortfall (float).
    """
    below_target = df[df["count"] / count_divisor < target_count]
    # Vectorised shortfall: target minus the scaled count.
    shortfall = target_count - below_target.set_index("scenario_type")["count"] / count_divisor
    scenario_dict = shortfall.to_dict()
    # With the default arguments this header is identical to the historical one.
    print(
        f"Scenario types with count/{count_divisor} less than {target_count} "
        f"(counts divided by {count_divisor}):"
    )
    print(scenario_dict)
    return scenario_dict

In [4]:
# Count the token directories per scenario type found under exp/cache_pdm_open.
scenario_type_counts = get_scenario_type_counts("exp/cache_pdm_open")
print(scenario_type_counts)

defaultdict(<class 'int'>, {'stationary': 107944, 'high_magnitude_speed': 30841, 'traversing_intersection': 15792, 'medium_magnitude_speed': 27579, 'traversing_traffic_light_intersection': 45028, 'near_pedestrian_on_crosswalk': 2672, 'on_traffic_light_intersection': 2071, 'near_long_vehicle': 2973, 'low_magnitude_speed': 4902, 'stationary_in_traffic': 36936, 'near_high_speed_vehicle': 3850, 'near_pedestrian_at_pickup_dropoff': 5413, 'on_all_way_stop_intersection': 487, 'starting_protected_noncross_turn': 167, 'stationary_at_traffic_light_without_lead': 7139, 'following_lane_with_slow_lead': 780, 'near_construction_zone_sign': 974, 'following_lane_without_lead': 1418, 'stationary_at_traffic_light_with_lead': 2531, 'stopping_with_lead': 125, 'waiting_for_pedestrian_to_cross': 192, 'following_lane_with_lead': 57, 'near_multiple_pedestrians': 78, 'accelerating_at_traffic_light': 13, 'stopping_at_traffic_light_without_lead': 32, 'near_pedestrian_on_crosswalk_with_ego': 6})


In [44]:
# Count token directories per scenario type in the resampled cache directory.
# NOTE(review): this cell's execution count (44) is out of sequence with its
# position, `cache_dir` is reassigned to 'exp/cache_pdm_open' in a later cell,
# and `scenario_type_count` is not read anywhere else in the visible notebook.
cache_dir='exp/resample_cache_pdm_open'
scenario_type_count=get_scenario_type_counts(cache_dir)

In [5]:
# Compare the scenario types listed in the YAML scenario filter against the
# scenario types present in the counts CSV.
yaml_file = "InD.yaml"
scenario_filter_types = load_scenario_filter_from_yaml(yaml_file)
csv_file = "open_scenario_type_counts.csv"  # must provide a `scenario_type` column (read below)
df = pd.read_csv(csv_file)
csv_scenario_types = df["scenario_type"].tolist()
# `scenario_to_add` holds the types that are in the filter but missing from the CSV.
scenario_to_add = diff_scenario_types(scenario_filter_types, csv_scenario_types)
print(scenario_to_add)

Scenario types in scenario_filter but not in CSV:
{'behind_long_vehicle'}

Scenario types in CSV but not in scenario_filter:
set()
{'behind_long_vehicle'}


In [7]:
# Build the per-scenario-type shortfall dict from the CSV counts.
resample_scenarios =get_resample_scenarios(df)
for scenario in scenario_to_add:
    # Types listed in the filter but absent from the CSV get a fixed target of 1000.
    resample_scenarios[scenario] =1000

Scenario types with count/2 less than 1000 (counts divided by 2):
{'near_construction_zone_sign': 26.0, 'following_lane_with_slow_lead': 220.0, 'on_all_way_stop_intersection': 513.0, 'waiting_for_pedestrian_to_cross': 808.0, 'starting_protected_noncross_turn': 833.0, 'stopping_with_lead': 875.0, 'near_multiple_pedestrians': 922.0, 'following_lane_with_lead': 943.0, 'stopping_at_traffic_light_without_lead': 968.0, 'accelerating_at_traffic_light': 987.0, 'near_pedestrian_on_crosswalk_with_ego': 994.0}


In [10]:
def get_resample_scenario_tokens(cache_dir: str, resample_scenarios: Dict[str, float]) -> Dict[str, List[str]]:
    """
    Gather the cached scenario tokens of every scenario type named in `resample_scenarios`.

    The walked layout is ``cache_dir/<log_name>/<scenario_type>/<scenario_token>/``;
    only directories are considered at each level.

    :param cache_dir: Path to the cache directory.
    :param resample_scenarios: Dictionary whose keys are the scenario types to include.
    :return: A dictionary with one entry per key of `resample_scenarios`; each value is the
        list of token directory names found for that type (empty if none were found).
    """
    # Pre-seed every requested type so that types without cached tokens map to [].
    tokens_by_type: Dict[str, List[str]] = {wanted_type: [] for wanted_type in resample_scenarios}

    for log_name in os.listdir(cache_dir):
        log_dir = os.path.join(cache_dir, log_name)
        if not os.path.isdir(log_dir):
            continue

        for scenario_type in os.listdir(log_dir):
            # Skip scenario types that were not requested.
            if scenario_type not in resample_scenarios:
                continue

            type_dir = os.path.join(log_dir, scenario_type)
            if not os.path.isdir(type_dir):
                continue

            tokens_by_type[scenario_type].extend(
                token
                for token in os.listdir(type_dir)
                if os.path.isdir(os.path.join(type_dir, token))
            )

    return tokens_by_type

# Collect, per under-represented scenario type, the tokens already present in the cache.
cache_dir = 'exp/cache_pdm_open'
all_exist_tokens=get_resample_scenario_tokens(cache_dir, resample_scenarios)

In [11]:
# Number of scenario types (dict keys) tracked, not the number of tokens.
print(len(all_exist_tokens))

12


In [29]:
# Spot-check the cached tokens of a single scenario type.
print(all_exist_tokens.get('near_pedestrian_on_crosswalk_with_ego'))

['a21b982406725ebd', 'f9c21012e8f65fc4', '2ad5a764d3b85c8f', 'c1763508ac2b5f85', 'd30a370fca7a5d24', '8fae7fcb6fc4581e']


In [25]:
import sqlite3
import os
from typing import Dict, List, Generator
from nuplan.database.nuplan_db.nuplan_scenario_queries import (
    get_lidarpc_tokens_with_scenario_tag_from_db,
    get_sensor_data_token_timestamp_from_db,
    get_sensor_token_map_name_from_db,
)
from collections import defaultdict

def execute_many(query: str, params: Tuple, db_file: str) -> Generator[sqlite3.Row, None, None]:
    """
    Execute a SQL query on a SQLite database and yield the results row by row.

    The connection is closed before the first row is yielded; all rows are
    fetched up front.

    :param query: The SQL query string.
    :param params: Parameters for the query.
    :param db_file: Path to the SQLite database file.
    :yield: Rows from the query result (``sqlite3.Row``, addressable by column name).
    """
    from contextlib import closing

    # `with connection:` only commits/rolls back the transaction -- it does NOT
    # close the connection.  `closing(...)` guarantees the handle is released.
    with closing(sqlite3.connect(db_file)) as conn:
        with conn:  # keep the commit-on-success / rollback-on-error behaviour
            conn.row_factory = sqlite3.Row  # Enable dictionary-like row access
            cursor = conn.cursor()
            cursor.execute(query, params)
            rows = cursor.fetchall()

    yield from rows

def get_scenario_info_from_db(db_file: str, resample_scenarios: Dict[str, float]) -> Dict[str, List[str]]:
    """
    Read every row of the ``scenario_tag`` table of one database file and group the
    ``lidar_pc_token`` values by scenario type, keeping only the types named in
    `resample_scenarios`.

    :param db_file: Path to the SQLite database file.
    :param resample_scenarios: A dictionary whose keys are the scenario types to keep.
    :return: A dictionary where keys are scenario types and values are lists of tokens;
        types with no matching row are absent from the result.
    """
    query = """
    SELECT st.type, st.lidar_pc_token
    FROM scenario_tag AS st;
    """
    tokens_by_type: Dict[str, List[str]] = {}

    for record in execute_many(query, (), db_file):
        tag_type = record["type"]
        raw_token = record["lidar_pc_token"]

        # Filtering happens in Python; the query itself selects all tags.
        if tag_type not in resample_scenarios:
            continue

        # Tokens stored as bytes are rendered as hex strings; anything else is kept as-is.
        if isinstance(raw_token, bytes):
            raw_token = raw_token.hex()
        tokens_by_type.setdefault(tag_type, []).append(raw_token)

    return tokens_by_type


def get_scenario_tokens_from_all_dbs(db_directory: str, db_files: List[str], resample_scenarios: Dict[str, float]) -> Dict[str, List[str]]:
    """
    Get scenario tokens for specific scenario types from multiple database files.

    Tokens are de-duplicated per scenario type and returned in first-seen order
    (following the order of `db_files` and of the rows within each file), so the
    result is reproducible across interpreter runs.

    :param db_directory: Directory containing the `.db` files.
    :param db_files: List of `.db` file names to process; names that do not resolve
        to an existing file inside `db_directory` are skipped.
    :param resample_scenarios: A dictionary where keys are scenario types and values are desired counts.
    :return: A dictionary where keys are scenario types and values are lists of unique scenario tokens.
    """
    # A dict is used as an insertion-ordered set: a plain `set` would yield a
    # hash-randomised order, making any later truncation of the list
    # non-deterministic between runs.
    unique_tokens: Dict[str, Dict[str, None]] = {}

    for db_file in db_files:
        db_path = os.path.join(db_directory, db_file)
        if not os.path.isfile(db_path):
            continue
        scenario_info = get_scenario_info_from_db(db_path, resample_scenarios)

        # Merge results from the current database file
        for scenario_type, tokens in scenario_info.items():
            unique_tokens.setdefault(scenario_type, {}).update(dict.fromkeys(tokens))

    return {scenario_type: list(tokens) for scenario_type, tokens in unique_tokens.items()}

In [22]:
import sqlite3
from typing import Dict, List

def get_tokens_by_scenario_type(db_file: str, scenario_type: str) -> List[str]:
    """
    Get the tokens for a specified scenario type from a single SQLite database file.

    Reads ``lidar_pc_token`` from the ``scenario_tag`` table for rows whose
    ``type`` equals `scenario_type`.

    :param db_file: Path to the SQLite database file.
    :param scenario_type: The scenario type to filter tokens by.
    :return: A list of tokens corresponding to the specified scenario type;
        tokens stored as bytes are returned as hex strings.
    """
    from contextlib import closing

    query = """
    SELECT st.lidar_pc_token
    FROM scenario_tag AS st
    WHERE st.type = ?;
    """

    # `with sqlite3.connect(...)` alone would only manage the transaction and
    # leave the connection open; `closing` releases it deterministically.
    with closing(sqlite3.connect(db_file)) as conn:
        conn.row_factory = sqlite3.Row  # Enable dictionary-like row access
        rows = conn.execute(query, (scenario_type,)).fetchall()

    tokens = []
    for row in rows:
        token = row["lidar_pc_token"]
        tokens.append(token.hex() if isinstance(token, bytes) else token)

    return tokens

In [17]:
# Directory that is scanned for `.db` files below; raises KeyError if the
# NUPLAN_DATA_ROOT environment variable is not set.
db_directory = os.path.join(os.environ["NUPLAN_DATA_ROOT"], "nuplan-v1.1/trainval")

In [18]:
# Keep only the `.db` files whose base name (without extension) is listed in the
# `log_splits.train` section of the YAML file.
yaml_file = "nuplan.yaml"   # NOTE: rebinds `yaml_file`, previously "InD.yaml"
train_logs = get_train_data_from_yaml(yaml_file)
db_files = [
    file for file in os.listdir(db_directory)
    if file.endswith(".db") and os.path.splitext(file)[0] in train_logs
]

In [26]:
# Gather, per under-represented scenario type, the unique tokens tagged in the training databases.
all_scenario_tokens = get_scenario_tokens_from_all_dbs(db_directory, db_files, resample_scenarios)

In [41]:
def filter_and_limit_scenarios(
    all_scenario_tokens: Dict[str, List[str]],
    all_exist_tokens: Dict[str, List[str]],
    resample_scenarios: Dict[str, float]
) -> Dict[str, List[str]]:
    """
    Filter tokens from `all_scenario_tokens` that do NOT exist in `all_exist_tokens` for the same scenario type,
    and limit the number of tokens for each `scenario_type` to the count specified in `resample_scenarios`.

    If the target_count is greater than the length of valid_tokens, take all valid_tokens and print the difference.
    Scenario types of `all_scenario_tokens` that are not keys of `resample_scenarios` are dropped; scenario types
    of `resample_scenarios` that are not keys of `all_scenario_tokens` produce no entry.

    :param all_scenario_tokens: Dictionary of scenario types and their corresponding tokens.
    :param all_exist_tokens: Dictionary of scenario types and their corresponding tokens to exclude.
    :param resample_scenarios: Dictionary of scenario types and their target counts.
    :return: A dictionary with filtered and limited tokens for each scenario type.
    """
    filtered_scenarios = {}

    for scenario_type, tokens in all_scenario_tokens.items():
        # Only scenario types that were requested for resampling are processed.
        if scenario_type not in resample_scenarios:
            continue

        # Target counts may be floats (e.g. 26.0); truncate to an int for slicing.
        target_count = int(resample_scenarios[scenario_type])

        # Use a set for the exclusion test: membership in a list is O(len(list))
        # per token, which becomes quadratic for large token lists.
        exclude_tokens = frozenset(all_exist_tokens.get(scenario_type, ()))

        # Keep the incoming order of the tokens that are not excluded.
        valid_tokens = [token for token in tokens if token not in exclude_tokens]

        if target_count > len(valid_tokens):
            # Not enough new tokens available: report the gap and take everything.
            print(f"Warning: Scenario '{scenario_type}' - target_count ({target_count}) "
                  f"is greater than valid_tokens length ({len(valid_tokens)}). "
                  f"Missing {target_count - len(valid_tokens)} tokens.")
            filtered_scenarios[scenario_type] = valid_tokens
        else:
            # Limit the number of tokens to the target count
            filtered_scenarios[scenario_type] = valid_tokens[:target_count]

    return filtered_scenarios

In [42]:
# Select, per scenario type, database tokens that are not yet cached, capped at the requested count.
filtered_scenarios = filter_and_limit_scenarios(all_scenario_tokens, all_exist_tokens, resample_scenarios)

In [48]:
# Number of scenario types (dict keys) that received a selection, not the number of tokens.
print(len(filtered_scenarios))

12


In [33]:
from ruamel.yaml import YAML

def generate_scenario_token_yaml(filtered_scenarios, template_path, output_path):
    """
    Generates a scenario_filter config YAML file with all filtered tokens written under `scenario_tokens`.

    The template is loaded and dumped with ruamel.yaml's round-trip `YAML()` object so that the
    template's formatting (comments, key order, quoting) is carried over to the output.
    Every token is emitted as a single-quoted YAML scalar.

    :param filtered_scenarios: A dictionary where keys are scenario types and values are lists of tokens.
    :param template_path: Path to the template YAML file.
    :param output_path: Path to save the generated YAML file.
    """
    # Local import: the file only imports `YAML` from ruamel.yaml at module level.
    from ruamel.yaml.scalarstring import SingleQuotedScalarString

    # Named `yaml_rt` so it does not shadow the PyYAML `yaml` module imported by the notebook.
    yaml_rt = YAML()
    yaml_rt.preserve_quotes = True  # Preserve quotes and formatting

    # Load the template YAML file
    with open(template_path, 'r', encoding='utf-8') as template_file:
        scenario_filter_config = yaml_rt.load(template_file)

    # Flatten all tokens into one list.  SingleQuotedScalarString asks the dumper
    # for the 'single-quoted' style; wrapping the text in literal quote characters
    # (f"'{token}'") would instead make the quotes part of the token value.
    scenario_filter_config['scenario_tokens'] = [
        SingleQuotedScalarString(token)
        for tokens in filtered_scenarios.values()
        for token in tokens
    ]

    # Write the updated YAML back to the output file
    with open(output_path, 'w', encoding='utf-8') as output_file:
        yaml_rt.dump(scenario_filter_config, output_file)

    print(f"YAML file successfully generated at: {output_path}")

In [51]:
from ruamel.yaml import YAML

def generate_scenario_type_yaml(filtered_scenarios, template_path, output_path):
    """
    Generates a scenario_filter config YAML file with the scenario types of `filtered_scenarios`
    written under `scenario_types`.

    The template is loaded and dumped with ruamel.yaml's round-trip `YAML()` object so that the
    template's formatting (comments, key order, quoting) is carried over to the output.
    Every scenario type is emitted as a single-quoted YAML scalar.

    :param filtered_scenarios: A dictionary where keys are scenario types and values are lists of tokens.
    :param template_path: Path to the template YAML file.
    :param output_path: Path to save the generated YAML file.
    """
    # Local import: the file only imports `YAML` from ruamel.yaml at module level.
    from ruamel.yaml.scalarstring import SingleQuotedScalarString

    # Named `yaml_rt` so it does not shadow the PyYAML `yaml` module imported by the notebook.
    yaml_rt = YAML()
    yaml_rt.preserve_quotes = True  # Preserve quotes and formatting

    # Load the template YAML file
    with open(template_path, 'r', encoding='utf-8') as template_file:
        scenario_filter_config = yaml_rt.load(template_file)

    # Dict keys are already unique and ordered; collect them directly.
    # (Calling `.add` on a list raises AttributeError.)
    all_types = list(filtered_scenarios.keys())

    # SingleQuotedScalarString asks the dumper for the 'single-quoted' style;
    # wrapping the text in literal quote characters would make the quotes part
    # of the scenario type value.
    scenario_filter_config['scenario_types'] = [
        SingleQuotedScalarString(scenario_type) for scenario_type in all_types
    ]

    # Write the updated YAML back to the output file
    with open(output_path, 'w', encoding='utf-8') as output_file:
        yaml_rt.dump(scenario_filter_config, output_file)

    print(f"YAML file successfully generated at: {output_path}")

In [53]:
from ruamel.yaml import YAML

def generate_scenario_filter_yaml(filtered_scenarios, template_path, output_path):
    """
    Generates a scenario_filter config YAML file with the scenario types of `filtered_scenarios`
    written under `scenario_types`.

    The template is loaded and dumped with ruamel.yaml's round-trip `YAML()` object so that the
    template's formatting (comments, key order, quoting) is carried over to the output.
    Scenario types are written in the key order of `filtered_scenarios`, each as a
    single-quoted YAML scalar.

    :param filtered_scenarios: A dictionary where keys are scenario types and values are lists of tokens.
    :param template_path: Path to the template YAML file.
    :param output_path: Path to save the generated YAML file.
    """
    # Local import: the file only imports `YAML` from ruamel.yaml at module level.
    from ruamel.yaml.scalarstring import SingleQuotedScalarString

    # Named `yaml_rt` so it does not shadow the PyYAML `yaml` module imported by the notebook.
    yaml_rt = YAML()
    yaml_rt.preserve_quotes = True  # Preserve quotes and formatting

    # Load the template YAML file
    with open(template_path, 'r', encoding='utf-8') as template_file:
        scenario_filter_config = yaml_rt.load(template_file)

    # Dict keys are already unique; iterating them directly (instead of going
    # through a set) keeps the output order stable from run to run.
    # SingleQuotedScalarString requests the 'single-quoted' YAML style without
    # making quote characters part of the value.
    scenario_filter_config['scenario_types'] = [
        SingleQuotedScalarString(scenario_type) for scenario_type in filtered_scenarios.keys()
    ]

    # Write the updated YAML back to the output file
    with open(output_path, 'w', encoding='utf-8') as output_file:
        yaml_rt.dump(scenario_filter_config, output_file)

    print(f"YAML file successfully generated at: {output_path}")

In [54]:
template_path = 'template.yaml'           # Template YAML that is loaded and then extended
output_path = 'resample.yaml'             # Path to save the generated YAML file
# Generate the YAML: writes the scenario types of `filtered_scenarios` under `scenario_types`.
generate_scenario_filter_yaml(filtered_scenarios, template_path, output_path)

YAML file successfully generated at: resample.yaml
