Step5 — Filter Following Segments, Summarize, and Plot
This notebook processes lane car-following Excel files (e.g., *following_parts*.xlsx) to:

- identify valid segments using the centralized speed, distance, and minimum-duration thresholds
- save per-segment row data and per-segment summary statistics
- generate per-pair plots (original + filtered)

Adjust the thresholds and paths in the Parameters cell below.

In [1]:

# ===================== Parameters =====================
# Each entry may be an .xlsx file, a directory to search, or a glob pattern
# (see _expand_inputs).
INPUTS = [
    r"/Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior",
]
# Root directory for all generated CSVs and plots; created if it does not exist.
OUTPUT_ROOT = r"/Volumes/weishanshan/Geo trax tool results/DJI_0031/step5 the result of 1st filter outliers"
# Glob pattern applied inside every directory listed in INPUTS.
PATTERN      = r"*following_parts*.xlsx"
# When True, sub-directories of the input directories are searched as well.
RECURSIVE    = True

# ===================== Tunables =====================
# A row is valid only while follower_v is strictly greater than this value.
# NOTE(review): the name implies km/h, so follower_v is assumed to be in km/h - confirm.
SPEED_THRESHOLD_KMH  = 3.6
# A row is valid only while headway_distance_m is less than or equal to this value.
DISTANCE_THRESHOLD_M = 100.0
# Minimum time_s span (last row minus first row) for a run to be kept as a segment.
MIN_SEG_DURATION_S   = 3.0




import os
import re
import gc
import glob
from typing import List, Tuple, Iterable

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Columns loaded from each workbook; any other column is skipped at read time
# (used as the `usecols` filter in process_dataset).
NEEDED_COLS = [
    'part','frame_id','time_s',
    'follower_uid','leader_uid',
    'headway_distance_m','net_headway_distance_m',
    'time_headway_s','net_time_headway_s',
    'rel_v_kph','rel_a_mps2',
    'TTC_s',
    'leader_v','leader_a',
    'follower_v','follower_a'
]

# Columns that process_dataset coerces to numeric and stores as float32;
# values that cannot be parsed become NaN.
DTYPES_NUMERIC_FLOAT32 = [
    'time_s','headway_distance_m','net_headway_distance_m',
    'time_headway_s','net_time_headway_s',
    'rel_v_kph','rel_a_mps2','TTC_s',
    'leader_v','leader_a','follower_v','follower_a'
]




def infer_lane_name(file_path: str, df: pd.DataFrame) -> str:
    """Return a lower-cased lane label for one input file.

    Lookup order:

    1. ``lane<digits>`` or ``lane-<digits>`` (case-insensitive) in the base name;
    2. the same pattern anywhere in the full path;
    3. the most frequent value of the first usable column among
       ``lane``, ``lane_smoothed`` and ``lane_name`` in *df*;
    4. the literal ``'lane'`` when nothing else matches.

    Args:
        file_path: Path of the workbook being processed.
        df: Data loaded from that workbook.

    Returns:
        The inferred lane label, always lower case.
    """
    lane_pattern = r'(lane-?\d+)'
    basename = os.path.basename(file_path)
    m = (re.search(lane_pattern, basename, flags=re.IGNORECASE)
         or re.search(lane_pattern, file_path, flags=re.IGNORECASE))
    if m:
        return m.group(1).lower()

    for col in ('lane', 'lane_smoothed', 'lane_name'):
        if col not in df.columns:
            continue
        try:
            modes = df[col].mode()
        except (TypeError, ValueError):
            # Column content cannot be tallied; try the next candidate column.
            continue
        # mode() is empty when the column holds only missing values.
        if not modes.empty:
            return str(modes.iloc[0]).lower()
    return 'lane'


def find_valid_segments(group: pd.DataFrame,
                        speed_threshold: float = SPEED_THRESHOLD_KMH,
                        distance_threshold: float = DISTANCE_THRESHOLD_M,
                        min_duration: float = MIN_SEG_DURATION_S) -> List[pd.DataFrame]:
    """Extract the valid car-following runs of one follower/leader pair.

    A row qualifies when ``follower_v > speed_threshold`` and
    ``headway_distance_m <= distance_threshold``.  Qualifying rows are chained
    into a run while each ``frame_id`` is exactly one more than the previous
    row's.  A run is kept only if its ``time_s`` span (last minus first row)
    is at least ``min_duration``.

    Args:
        group: Rows of a single pair, in the order they should be scanned.
        speed_threshold: Exclusive lower bound on ``follower_v``.
        distance_threshold: Inclusive upper bound on ``headway_distance_m``.
        min_duration: Minimum ``time_s`` span of an accepted run.

    Returns:
        Independent copies of the accepted runs, in scan order.
    """
    accepted: List[pd.DataFrame] = []
    run_labels: List[int] = []

    def close_run() -> None:
        # Judge the run collected so far, then start over with an empty one.
        if not run_labels:
            return
        candidate = group.loc[run_labels]
        span = candidate['time_s'].iloc[-1] - candidate['time_s'].iloc[0]
        if span >= min_duration:
            accepted.append(candidate.copy())
        run_labels.clear()

    last_frame = None
    for label, rec in group.iterrows():
        fast_enough = rec['follower_v'] > speed_threshold
        close_enough = rec['headway_distance_m'] <= distance_threshold
        qualifies = fast_enough and close_enough
        if qualifies and (last_frame is None or rec['frame_id'] == last_frame + 1):
            run_labels.append(label)
        else:
            # Either the row fails the thresholds or the frame sequence broke.
            close_run()
            if qualifies:
                # A qualifying row after a frame gap opens the next run.
                run_labels.append(label)
        last_frame = rec['frame_id']

    close_run()
    return accepted


def compute_segment_statistics(segment: pd.DataFrame) -> dict:
    """Summarise one segment with min / max / mean of each metric column.

    For every metric column the result holds three entries,
    ``<column>_min``, ``<column>_max`` and ``<column>_mean``, computed over
    the non-missing values.  A column that is absent from *segment* or holds
    no usable value yields NaN for all three entries, so every summary row
    has the same keys regardless of which optional columns the workbook had.

    Args:
        segment: Rows belonging to one segment.

    Returns:
        Mapping of statistic name to value, in a fixed column order.
    """
    metric_columns = (
        'headway_distance_m', 'net_headway_distance_m',
        'time_headway_s', 'net_time_headway_s',
        'rel_v_kph', 'rel_a_mps2',
        'TTC_s',
        'leader_v', 'leader_a',
        'follower_v', 'follower_a',
    )
    out = {}
    for name in metric_columns:
        clean = segment[name].dropna() if name in segment.columns else None
        if clean is None or clean.empty:
            lowest = highest = average = np.nan
        else:
            lowest, highest, average = clean.min(), clean.max(), clean.mean()
        out[f'{name}_min'] = lowest
        out[f'{name}_max'] = highest
        out[f'{name}_mean'] = average
    return out


def _pair_spans(group: pd.DataFrame,
                speed_threshold: float,
                distance_threshold: float,
                min_duration: float) -> List[Tuple[int, int, bool, bool]]:
    """Partition *group* into runs as ``(start, end, rows_valid, accepted)``.

    ``start``/``end`` are inclusive row positions.  Valid rows (speed above and
    distance within the thresholds) are split wherever ``frame_id`` does not
    advance by exactly one, so valid runs line up with the runs built by
    find_valid_segments.  Invalid rows are merged regardless of frame gaps.
    A run is ``accepted`` when its rows are valid and its ``time_s`` span is
    at least ``min_duration``.
    """
    n = len(group)
    if n == 0:
        return []

    row_valid = ((group['follower_v'] > speed_threshold)
                 & (group['headway_distance_m'] <= distance_threshold)).to_numpy(dtype=bool)
    # diff() is NaN on the first row, which also compares unequal to 1.
    frame_gap = (group['frame_id'].diff() != 1).to_numpy(dtype=bool)

    starts_run = np.ones(n, dtype=bool)
    starts_run[1:] = (row_valid[1:] != row_valid[:-1]) | (row_valid[1:] & frame_gap[1:])
    starts = np.flatnonzero(starts_run)
    ends = np.append(starts[1:] - 1, n - 1)

    t = group['time_s']
    spans: List[Tuple[int, int, bool, bool]] = []
    for s, e in zip(starts.tolist(), ends.tolist()):
        rows_valid = bool(row_valid[s])
        accepted = rows_valid and bool(t.iloc[e] - t.iloc[s] >= min_duration)
        spans.append((s, e, rows_valid, accepted))
    return spans


def _shared_legend(ax_left, ax_right) -> None:
    """Draw one legend on *ax_left* listing the labelled artists of both axes once."""
    handles_l, labels_l = ax_left.get_legend_handles_labels()
    handles_r, labels_r = ax_right.get_legend_handles_labels()
    by_label = {}
    for handle, label in zip(handles_l + handles_r, labels_l + labels_r):
        by_label.setdefault(label, handle)
    ax_left.legend(by_label.values(), by_label.keys(), fontsize=7, loc='best')


def plot_pair(group: pd.DataFrame, segments: List[pd.DataFrame], out_dir: str, base_name: str,
              speed_threshold: float = SPEED_THRESHOLD_KMH,
              distance_threshold: float = DISTANCE_THRESHOLD_M,
              min_duration: float = MIN_SEG_DURATION_S) -> Tuple[str, str]:
    """Write the "original" and "filtered" figures for one follower/leader pair.

    The original figure plots follower speed, leader speed and time headway
    over the whole of *group* and shades every run that was not accepted:
    red where the rows fail the speed/distance thresholds, yellow where the
    rows pass but the run does not reach ``min_duration``.  The filtered
    figure plots only *segments*, marking each segment's end points and
    boundaries.

    Args:
        group: All rows of the pair, ordered by frame.
        segments: Accepted segments of the pair (e.g. from find_valid_segments).
        out_dir: Existing directory that receives both PNG files.
        base_name: File-name stem; also used in the figure titles.
        speed_threshold: Exclusive lower bound on ``follower_v``.
        distance_threshold: Inclusive upper bound on ``headway_distance_m``.
        min_duration: Minimum ``time_s`` span of an accepted run.

    Returns:
        ``(original_path, filtered_path)`` of the two PNG files written.
    """
    C_FOLLOWER = 'tab:blue'
    C_LEADER   = 'tab:orange'
    C_THW      = 'tab:green'

    t = group['time_s']
    spans = _pair_spans(group, speed_threshold, distance_threshold, min_duration)

    # ---- original: full trace with rejected runs shaded ----
    original_path = os.path.join(out_dir, f"{base_name}_original.png")
    fig = plt.figure(figsize=(9,4))
    try:
        ax1 = plt.gca()
        ax2 = ax1.twinx()
        ax1.plot(t, group['follower_v'], label='Follower speed', color=C_FOLLOWER, linewidth=1.2)
        ax1.plot(t, group['leader_v'],   label='Leader speed',   color=C_LEADER,   linewidth=1.2)
        ax2.plot(t, group['time_headway_s'], label='Time headway (s)', color=C_THW,
                 linestyle='--', linewidth=1.2)
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Speed (km/h)')
        ax2.set_ylabel('Time headway (s)')
        ax2.tick_params(axis='y')
        for s, e, rows_valid, accepted in spans:
            if not accepted:
                ax1.axvspan(t.iloc[s], t.iloc[e],
                            color='yellow' if rows_valid else 'red', alpha=0.3)
        _shared_legend(ax1, ax2)
        plt.title(f"Original: {base_name}")
        plt.tight_layout()
        fig.savefig(original_path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    # ---- filtered: accepted segments only ----
    filtered_path = os.path.join(out_dir, f"{base_name}_1st_filtered.png")
    fig = plt.figure(figsize=(9,4))
    try:
        ax1 = plt.gca()
        ax2 = ax1.twinx()

        segment_boundaries = []
        first = True
        for seg in segments:
            t_seg = seg['time_s']
            if t_seg.empty:
                continue
            # Only the first segment contributes legend entries.
            ax1.plot(t_seg, seg['follower_v'], color=C_FOLLOWER, linewidth=1.2,
                     label='Follower speed' if first else '_nolegend_')
            ax1.plot(t_seg, seg['leader_v'],   color=C_LEADER, linewidth=1.2,
                     label='Leader speed' if first else '_nolegend_')
            ax2.plot(t_seg, seg['time_headway_s'], color=C_THW, linestyle='--', linewidth=1.2,
                     label='Time headway (s)' if first else '_nolegend_')

            start_t, end_t = t_seg.iloc[0], t_seg.iloc[-1]
            segment_boundaries.append((start_t, end_t))

            ax1.scatter([start_t, end_t],
                        [seg['follower_v'].iloc[0], seg['follower_v'].iloc[-1]],
                        marker='o', s=16, color=C_FOLLOWER)
            ax1.scatter([start_t, end_t],
                        [seg['leader_v'].iloc[0],  seg['leader_v'].iloc[-1]],
                        marker='o', s=16, color=C_LEADER)
            ax2.scatter([start_t, end_t],
                        [seg['time_headway_s'].iloc[0], seg['time_headway_s'].iloc[-1]],
                        marker='o', s=16, color=C_THW)
            first = False

        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Speed (km/h)')
        ax2.set_ylabel('Time headway (s)')
        ax2.tick_params(axis='y')
        for start_t, end_t in segment_boundaries:
            ax1.axvline(start_t, linestyle=':', linewidth=0.8)
            ax1.axvline(end_t,   linestyle=':', linewidth=0.8)
        _shared_legend(ax1, ax2)
        plt.title(f"Filtered: {base_name}")
        plt.tight_layout()
        fig.savefig(filtered_path, dpi=150)
    finally:
        plt.close(fig)

    return original_path, filtered_path





def _append_df_csv(df: pd.DataFrame, csv_path: str) -> None:
    """Append the rows of *df* to *csv_path*.

    The header line is written only when the file does not exist yet; the
    index is never written.
    """
    file_is_new = not os.path.exists(csv_path)
    df.to_csv(csv_path, index=False, header=file_is_new, mode='a')

def _append_dict_csv(row: dict, csv_path: str) -> None:
    """Append *row* as a single CSV record to *csv_path*.

    The dict keys become the columns.  The header line is written only when
    the file does not exist yet; the index is never written.
    """
    header_needed = not os.path.exists(csv_path)
    record = pd.DataFrame([row])
    record.to_csv(csv_path, index=False, header=header_needed, mode='a')

def process_dataset(file_path: str,
                    output_root: str,
                    speed_threshold: float = SPEED_THRESHOLD_KMH,
                    distance_threshold: float = DISTANCE_THRESHOLD_M,
                    min_duration: float = MIN_SEG_DURATION_S) -> None:
    """Filter, summarise and plot every follower/leader pair of one workbook.

    The first sheet of *file_path* is loaded (NEEDED_COLS only) and its column
    types are normalised.  For each ``part`` and each (follower, leader) pair,
    the valid segments are found; pairs without any segment are skipped.
    Per part, the following is written below
    ``<output_root>/<lane>_<file stem>_part<part>/``:

    * ``..._1st_filtered_data.csv``    - all rows of all accepted segments,
    * ``..._1st_segment_summary.csv``  - one statistics row per segment,
    * ``..._original_plots/``          - one full-trace figure per pair,
    * ``..._1st_filtered_plots/``      - one segments-only figure per pair.

    Both CSV files are deleted first if they already exist, so a re-run
    replaces them instead of appending.

    Args:
        file_path: Workbook to process.
        output_root: Directory under which the per-part folders are created.
        speed_threshold: Exclusive lower bound on ``follower_v``.
        distance_threshold: Inclusive upper bound on ``headway_distance_m``.
        min_duration: Minimum ``time_s`` span of an accepted segment.

    Raises:
        ValueError: If a column needed for segmentation or plotting is missing.
    """
    wanted = set(NEEDED_COLS)  # built once rather than once per column
    df = pd.read_excel(file_path, sheet_name=0, usecols=lambda c: c in wanted)

    # Fail early with a readable message instead of a KeyError deep in the loop.
    required = ['follower_uid', 'leader_uid', 'frame_id', 'time_s',
                'follower_v', 'leader_v', 'headway_distance_m', 'time_headway_s']
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"{file_path}: missing required column(s): {', '.join(missing)}")

    has_part = 'part' in df.columns
    if has_part:
        numeric_part = pd.to_numeric(df['part'], errors='coerce')
        if (numeric_part.isna() & df['part'].notna()).any():
            # Some labels are not numeric: keep them as text so that distinct
            # parts are not all collapsed into -1.
            df['part'] = df['part'].astype(str)
        else:
            try:
                # Missing part values are grouped under -1.
                df['part'] = numeric_part.fillna(-1).astype('int16')
            except (ValueError, TypeError, OverflowError):
                df['part'] = df['part'].astype(str)
    df['frame_id'] = pd.to_numeric(df['frame_id'], errors='coerce').astype('Int64').astype('float64')
    for col in DTYPES_NUMERIC_FLOAT32:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce').astype('float32')
    for col in ['follower_uid','leader_uid']:
        df[col] = df[col].astype(str)

    # NOTE(review): df only holds NEEDED_COLS, which contains none of the lane
    # columns infer_lane_name can fall back on, so the label effectively comes
    # from the file path alone.
    lane_name = infer_lane_name(file_path, df)
    file_stub = os.path.splitext(os.path.basename(file_path))[0]
    parts = sorted(pd.unique(df['part'])) if has_part else [0]

    for part in parts:
        part_df = df[df['part'] == part].copy() if has_part else df.copy()
        part_df.sort_values(['follower_uid','leader_uid','frame_id'], inplace=True)

        tag = f"{lane_name}_{file_stub}_part{part}"
        part_dir = os.path.join(output_root, tag)
        filtered_plots_dir  = os.path.join(part_dir, f"{tag}_1st_filtered_plots")
        original_plots_dir  = os.path.join(part_dir, f"{tag}_original_plots")
        os.makedirs(filtered_plots_dir, exist_ok=True)
        os.makedirs(original_plots_dir, exist_ok=True)

        seg_csv_path   = os.path.join(part_dir, f"{tag}_1st_filtered_data.csv")
        summ_csv_path  = os.path.join(part_dir, f"{tag}_1st_segment_summary.csv")
        # The CSVs are written in append mode below, so drop stale results first.
        for p in [seg_csv_path, summ_csv_path]:
            if os.path.exists(p):
                os.remove(p)

        for (follower, leader), group in part_df.groupby(['follower_uid','leader_uid'], sort=False):
            group_sorted = group.sort_values('frame_id').reset_index(drop=True)

            segments = find_valid_segments(group_sorted,
                                           speed_threshold=speed_threshold,
                                           distance_threshold=distance_threshold,
                                           min_duration=min_duration)
            if not segments:
                continue

            # segment_id restarts at 1 for every pair.
            for seg_id, seg in enumerate(segments, start=1):
                # Work on a copy so the frames handed to plot_pair stay untouched.
                seg_out = seg.copy()
                seg_out['segment_id'] = seg_id
                seg_out['part']       = part
                _append_df_csv(seg_out, seg_csv_path)
                stats = compute_segment_statistics(seg_out)
                _append_dict_csv({
                    'segment_id': seg_id,
                    'part': part,
                    'follower_uid': follower,
                    'leader_uid': leader,
                    **stats
                }, summ_csv_path)

            base_name = f"{tag}_{follower}_{leader}"
            # plot_pair writes both figures into one directory; the filtered
            # figure is then moved into its own folder.
            orig_path, filt_path = plot_pair(group_sorted, segments, original_plots_dir, base_name,
                                             speed_threshold=speed_threshold,
                                             distance_threshold=distance_threshold,
                                             min_duration=min_duration)
            os.replace(filt_path, os.path.join(filtered_plots_dir, os.path.basename(filt_path)))
            del group_sorted, segments
            gc.collect()

        del part_df
        gc.collect()

    del df
    gc.collect()


def _expand_inputs(paths_or_dirs: Iterable[str], pattern: str, recursive: bool=True) -> List[str]:
    """Resolve files, directories and glob patterns into a sorted list of files.

    Each entry is handled according to what it is:

    * an existing ``.xlsx`` file is taken as is;
    * a directory is searched for *pattern* (in sub-directories too when
      *recursive* is true);
    * anything else is treated as a glob pattern.

    Excel lock files (base name starting with ``~$``), which Excel creates
    next to an open workbook and which are not readable workbooks, are dropped.

    Args:
        paths_or_dirs: Entries to resolve; a single string counts as one entry.
        pattern: Glob pattern applied inside directories.
        recursive: Whether directories are searched recursively.

    Returns:
        The matching paths, de-duplicated and sorted.
    """
    if isinstance(paths_or_dirs, str):
        # A bare string would otherwise be iterated character by character.
        paths_or_dirs = [paths_or_dirs]

    found = []
    for p in paths_or_dirs:
        if os.path.isfile(p) and p.lower().endswith('.xlsx'):
            found.append(p)
        elif os.path.isdir(p):
            glob_pat = os.path.join(p, '**', pattern) if recursive else os.path.join(p, pattern)
            found.extend(glob.glob(glob_pat, recursive=recursive))
        else:
            found.extend(glob.glob(p, recursive=recursive))

    return sorted({f for f in found if not os.path.basename(f).startswith('~$')})


def process_many(paths_or_dirs: Iterable[str],
                 output_root: str,
                 pattern: str = '*following_parts*.xlsx',
                 recursive: bool = True,
                 speed_threshold: float = SPEED_THRESHOLD_KMH,
                 distance_threshold: float = DISTANCE_THRESHOLD_M,
                 min_duration: float = MIN_SEG_DURATION_S) -> None:
    """Run process_dataset on every workbook resolved from *paths_or_dirs*.

    The inputs are expanded with _expand_inputs.  If nothing matches, a
    warning is printed and nothing is created.  Otherwise *output_root* is
    created when needed and the files are processed one after another; a
    failure in one file is reported on stdout and does not stop the batch.

    Args:
        paths_or_dirs: Files, directories or glob patterns to process.
        output_root: Directory that receives all results.
        pattern: Glob pattern applied inside input directories.
        recursive: Whether input directories are searched recursively.
        speed_threshold: Forwarded to process_dataset.
        distance_threshold: Forwarded to process_dataset.
        min_duration: Forwarded to process_dataset.
    """
    workbooks = _expand_inputs(paths_or_dirs, pattern=pattern, recursive=recursive)
    total = len(workbooks)
    if total == 0:
        print(f"[WARN] No files matched. Inputs={paths_or_dirs}, pattern='{pattern}'")
        return

    os.makedirs(output_root, exist_ok=True)
    print(f"[INFO] Found {total} files to process.")

    thresholds = {
        'speed_threshold': speed_threshold,
        'distance_threshold': distance_threshold,
        'min_duration': min_duration,
    }
    for position, workbook in enumerate(workbooks, 1):
        print(f"[{position}/{total}] Processing: {workbook}")
        try:
            process_dataset(workbook, output_root=output_root, **thresholds)
        except Exception as err:
            # Batch boundary: report the failure and carry on with the next file.
            print(f"[ERROR] Failed on {workbook}: {err}")
        gc.collect()

        
        

# ===================== Run =====================
# Process every workbook resolved from INPUTS with the thresholds configured
# in the Parameters cell.
process_many(paths_or_dirs=INPUTS,
             output_root=OUTPUT_ROOT,
             pattern=PATTERN,
             recursive=RECURSIVE,
             speed_threshold=SPEED_THRESHOLD_KMH,
             distance_threshold=DISTANCE_THRESHOLD_M,
             min_duration=MIN_SEG_DURATION_S)

# Printed unconditionally: process_many reports per-file failures itself and
# keeps going, so check the log above for "[ERROR]" lines.
print("[DONE] All processing completed.")
print("Output root:", OUTPUT_ROOT)


[INFO] Found 6 files to process.
[1/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane-1_following_parts.xlsx
[2/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane-2_following_parts.xlsx
[3/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane1_following_parts.xlsx
[4/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane2_following_parts.xlsx
[5/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane_middle_LTR_following_parts.xlsx
[6/6] Processing: /Volumes/weishanshan/Geo trax tool results/DJI_0031/step4 Splited car following behavior/lane_middle_RTL_following_parts.xlsx
