In [None]:
import pandas as pd 
import numpy as np

In [2]:
from math import floor

def merge_intervals(intervals):
    """
    Collapse overlapping or touching (start, end) intervals into disjoint ones.

    The input may be in any order: it is sorted by start value here before
    merging, and the caller's sequence is left untouched. Two intervals are
    fused when the later one starts at or before the end of the earlier one.

    Returns a new list ordered by start; an empty (or falsy) input gives [].
    """
    if not intervals:
        return []

    result = []
    for interval in sorted(intervals, key=lambda pair: pair[0]):
        if result:
            last_start, last_end = result[-1]
            start, end = interval
            if start <= last_end:
                # Extend the most recent merged interval instead of adding one.
                result[-1] = (last_start, max(last_end, end))
                continue
        result.append(interval)
    return result

def _crossing_fractions(start, end, boundaries):
    """
    Return the fractions (between 0 and 1, exclusive) of a straight move from
    `start` to `end` at which the value passes strictly through a boundary.

    Boundaries equal to either endpoint are not reported, and a move with no
    change in value yields an empty list.
    """
    if end == start:
        return []
    low, high = min(start, end), max(start, end)
    delta = end - start
    return [(b - start) / delta for b in boundaries if low < b < high]


def _max_concurrent(flights):
    """
    Return the largest number of flights present at the same instant.

    `flights` maps a flight id to its list of (t_start, t_end) intervals.
    Intervals of one flight are merged first so a flight never counts twice.
    """
    events = []
    for interval_list in flights.values():
        for start, end in merge_intervals(interval_list):
            events.append((start, 1))
            events.append((end, -1))

    # Tuples sort by time first and delta second, so at equal times the end
    # events (-1) are processed before the start events (+1): an interval that
    # ends exactly when another begins is not counted as an overlap.
    events.sort()

    current = 0
    peak = 0
    for _, delta in events:
        current += delta
        if current > peak:
            peak = current
    return peak


def max_aircraft_count(df, trenches, lat_min, lat_max, lon_min, lon_max, cell_size, window_size):
    """
    Compute, per time window and 3-D grid cell, the peak number of flights
    that are inside the cell at the same instant.

    Each row of `df` is treated as one flight segment that moves linearly in
    latitude, longitude and altitude between its start and end time. The
    segment is cut at every time-window edge, grid-line crossing and altitude
    band crossing, and each resulting piece is assigned to the cell that
    contains its mid-point.

    Parameters:
    - df: DataFrame with the columns 'id', 'from_time', 'to_time', 'from_lat',
      'to_lat', 'from_lon', 'to_lon', 'from_alt' and 'to_alt'.
    - trenches: sequence of (lower, upper) altitude bands. An altitude belongs
      to the first band with lower <= altitude < upper; pieces outside every
      band are ignored.
    - lat_min, lat_max, lon_min, lon_max: extent of the grid (inclusive);
      pieces whose mid-point lies outside are ignored.
    - cell_size: edge length of a grid cell, in the same units as lat/lon.
    - window_size: length of a time window, in the same units as the time
      columns (the callers in this file pass seconds).

    Returns:
    - dict mapping (window_start, alt_trench_idx, lat_idx, lon_idx) to the
      maximum number of simultaneously present flights. Only cells that were
      visited by at least one flight appear.

    Raises:
    - ValueError if `cell_size` or `window_size` is not strictly positive
      (a non-positive window would otherwise make the window loop endless).

    Note: relies on `tqdm` and `merge_intervals` being available in the
    enclosing module scope.
    """
    if not window_size > 0:
        raise ValueError(f"window_size must be positive, got {window_size!r}")
    if not cell_size > 0:
        raise ValueError(f"cell_size must be positive, got {cell_size!r}")

    lat_bounds = np.arange(lat_min, lat_max + cell_size, cell_size)
    lon_bounds = np.arange(lon_min, lon_max + cell_size, cell_size)

    # Key: (window_start, alt_trench_idx, lat_idx, lon_idx)
    # Value: dict mapping flight id -> list of (t_start, t_end) intervals
    intervals_dict = {}

    # Process each flight segment
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Counting aircraft"):
        flight_id = row['id']
        seg_start = float(row['from_time'])
        seg_end = float(row['to_time'])
        seg_duration = seg_end - seg_start
        # Skip empty/backwards segments, and segments with NaN or infinite
        # times: NaN would crash floor() below and an infinite end time would
        # keep the window loop running forever.
        if not np.isfinite(seg_duration) or seg_duration <= 0:
            continue

        lat0, lat1 = float(row['from_lat']), float(row['to_lat'])
        lon0, lon1 = float(row['from_lon']), float(row['to_lon'])
        alt0, alt1 = float(row['from_alt']), float(row['to_alt'])

        # Boundary crossings depend only on the segment, not on the time
        # window, so they are computed once here and filtered per window.
        crossings = _crossing_fractions(lat0, lat1, lat_bounds)
        crossings.extend(_crossing_fractions(lon0, lon1, lon_bounds))
        if alt1 != alt0:
            for alt_lower, alt_upper in trenches:
                for boundary in (alt_lower, alt_upper):
                    # A negative product means the boundary lies strictly
                    # between the two end altitudes.
                    if (boundary - alt0) * (boundary - alt1) < 0:
                        crossings.append((boundary - alt0) / (alt1 - alt0))

        # Walk over every time window the segment overlaps.
        window_start = floor(seg_start / window_size) * window_size
        while window_start < seg_end:
            window_end = window_start + window_size
            t0 = max(seg_start, window_start)
            t1 = min(seg_end, window_end)
            if t1 <= t0:
                window_start += window_size
                continue

            # Fractions of the segment covered by this window.
            tau0 = (t0 - seg_start) / seg_duration
            tau1 = (t1 - seg_start) / seg_duration
            inside = [t for t in crossings if tau0 < t < tau1]
            # The set removes duplicates (e.g. shared band edges) and also
            # collapses tau0 == tau1 into a single point, i.e. no piece.
            taus = sorted({tau0, tau1, *inside})

            # Each consecutive pair of cut points is one piece that stays in
            # a single cell; classify it by its mid-point.
            for tau_sub0, tau_sub1 in zip(taus, taus[1:]):
                tau_mid = (tau_sub0 + tau_sub1) / 2.0

                lat_mid = lat0 + tau_mid * (lat1 - lat0)
                lon_mid = lon0 + tau_mid * (lon1 - lon0)
                alt_mid = alt0 + tau_mid * (alt1 - alt0)

                if not (lat_min <= lat_mid <= lat_max and lon_min <= lon_mid <= lon_max):
                    continue

                # Determine grid cell indices
                lat_idx = int((lat_mid - lat_min) / cell_size)
                lon_idx = int((lon_mid - lon_min) / cell_size)

                # Determine altitude trench index (first matching band wins)
                alt_trench_idx = None
                for idx_trench, (alt_low, alt_high) in enumerate(trenches):
                    if alt_low <= alt_mid < alt_high:
                        alt_trench_idx = idx_trench
                        break
                if alt_trench_idx is None:
                    continue

                # Absolute start and end times of this piece
                t_sub_start = seg_start + tau_sub0 * seg_duration
                t_sub_end = seg_start + tau_sub1 * seg_duration

                key = (window_start, alt_trench_idx, lat_idx, lon_idx)
                per_flight = intervals_dict.setdefault(key, {})
                per_flight.setdefault(flight_id, []).append((t_sub_start, t_sub_end))

            window_start += window_size

    # For each cell in each window, compute the maximum instantaneous count
    # (i.e. the maximum number of flights concurrently present).
    return {key: _max_concurrent(flights) for key, flights in intervals_dict.items()}


# Multiprocessing Program Entry

In [None]:
import pandas as pd 
import numpy as np
import os
import glob
from multiprocessing import Pool
from tqdm import tqdm
import time
from functools import partial

def process_file(file_path, trenches, lat_min, lat_max, lon_min, lon_max, cell_size, window_size):
    """
    Process a single CSV file and save the results.

    The file is read with pandas, passed to `max_aircraft_count`, and the
    per-cell maxima are written as CSV to the path obtained by replacing
    'routes' with 'counts' in `file_path`.

    Parameters:
    - file_path: path of the input CSV; must contain 'routes'.
    - trenches, lat_min, lat_max, lon_min, lon_max, cell_size, window_size:
      forwarded unchanged to `max_aircraft_count`.

    Returns:
    - (file_path, True) on success, or (file_path, "Error: <message>") if
      anything went wrong. Exceptions are reported through the return value
      rather than raised, so one bad file does not abort a whole batch.
    """
    # Fixed column order, so that a file without any counted cell still gets
    # a header row and can be read back with pd.read_csv.
    columns = ['window_start', 'alt_trench_idx', 'lat_idx', 'lon_idx', 'max_count']
    try:
        # Work out the destination first and fail fast: if the path contains
        # no 'routes', the replacement is a no-op and writing the results
        # would overwrite the input file.
        output_path = file_path.replace('routes', 'counts')
        if output_path == file_path:
            raise ValueError(
                f"cannot derive an output path from {file_path!r}: it does not "
                "contain 'routes', so writing would overwrite the input file"
            )

        # Read the CSV file
        df = pd.read_csv(file_path)

        # Apply the max_aircraft_count function
        results = max_aircraft_count(df, trenches, lat_min, lat_max, lon_min, lon_max, cell_size, window_size)

        # Convert results to DataFrame
        results_df = pd.DataFrame(
            [
                {
                    'window_start': key[0],
                    'alt_trench_idx': key[1],
                    'lat_idx': key[2],
                    'lon_idx': key[3],
                    'max_count': count,
                }
                for key, count in results.items()
            ],
            columns=columns,
        )

        # Create the output directory; dirname is '' when the output lands in
        # the current directory, and os.makedirs('') would raise.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save the results
        results_df.to_csv(output_path, index=False)

        return file_path, True
    except Exception as e:
        # Worker boundary: report the failure to the caller as a status string.
        return file_path, f"Error: {str(e)}"

def process_all_files(routes_dir='routes', num_processes=None, 
                     trenches=((0, 10000), (10000, 20000), (20000, 40000)),
                     lat_min=45.0, lat_max=55.0, lon_min=20.0, lon_max=40.0,
                     cell_size=0.1, window_size=900):  # window_size=900 for 15 minutes in seconds
    """
    Process all CSV files in the routes directory and its subdirectories.

    Every file is handed to `process_file` in a pool of worker processes; a
    progress bar is shown and a summary of successes and failures is printed
    at the end. Nothing is returned.

    Parameters:
    - routes_dir: Directory containing the route files (searched recursively
      for '*.csv')
    - num_processes: Number of processes to use (defaults to CPU count)
    - trenches, lat_min, lat_max, lon_min, lon_max, cell_size, window_size: Parameters for max_aircraft_count
      (the default for `trenches` is an immutable tuple so it cannot be
      altered by accident between calls)
    """
    # Find all CSV files in the routes directory and its subdirectories
    csv_files = glob.glob(os.path.join(routes_dir, '**', '*.csv'), recursive=True)
    
    if not csv_files:
        print(f"No CSV files found in {routes_dir}")
        return
    
    print(f"Found {len(csv_files)} CSV files to process")
    
    # Create the counts directory if it doesn't exist
    counts_dir = 'counts'
    os.makedirs(counts_dir, exist_ok=True)
    
    # Bind the shared parameters so the pool only has to supply the file path
    process_func = partial(
        process_file,
        trenches=trenches,
        lat_min=lat_min,
        lat_max=lat_max,
        lon_min=lon_min,
        lon_max=lon_max,
        cell_size=cell_size,
        window_size=window_size
    )
    
    # Process files in parallel; imap yields results in input order, which
    # lets tqdm advance as files complete.
    start_time = time.time()
    with Pool(processes=num_processes) as pool:
        results = list(tqdm(
            pool.imap(process_func, csv_files),
            total=len(csv_files),
            desc="Processing files"
        ))
    
    # Report results: process_file returns True on success and an error
    # string otherwise, hence the identity checks.
    success_count = sum(1 for _, status in results if status is True)
    failed_files = [(file, status) for file, status in results if status is not True]
    
    print(f"Processing completed in {time.time() - start_time:.2f} seconds")
    print(f"Successfully processed {success_count} out of {len(csv_files)} files")
    
    if failed_files:
        print(f"Failed to process {len(failed_files)} files:")
        for file, error in failed_files[:10]:  # Show first 10 errors
            print(f"  - {file}: {error}")
        if len(failed_files) > 10:
            print(f"  ... and {len(failed_files) - 10} more")

# Example usage
if __name__ == "__main__":
    # Geographic extent of the analysis grid, in degrees (example values).
    lat_min = 30.0
    lat_max = 72.0
    lon_min = -15.0
    lon_max = 40.0

    # Edge length of one grid cell, in degrees (example value).
    cell_size = 0.25

    # Length of one counting window: 15 minutes expressed in seconds.
    window_size = 15 * 60

    # Altitude bands as (lower, upper) pairs in meters.
    altitude_trenches = [
        (7315.2, 8839.2),    # Lower En-Route (24,000 ft to 29,000 ft)
        (8839.2, 10668.0),   # Upper En-Route, Lower RVSM (29,000 ft to 35,000 ft)
        (10668.0, 21243.0),  # Upper En-Route, Upper RVSM (35,000 ft to 69,696 ft)
    ]

    process_all_files(
        routes_dir='routes',
        num_processes=None,  # documented as defaulting to the CPU count
        trenches=altitude_trenches,
        lat_min=lat_min,
        lat_max=lat_max,
        lon_min=lon_min,
        lon_max=lon_max,
        cell_size=cell_size,
        window_size=window_size,
    )