In [1]:
# pyright: basic, reportUnknownVariableType=false, reportUnknownMemberType=false

import polars as pl
import json
import numpy as np
import pyarrow as pa
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import glob
import os

# Configure polars for better performance
pl.Config.set_fmt_str_lengths(50)
pl.Config.set_tbl_rows(20)
pl.Config.set_tbl_cols(20)

polars.config.Config

In [5]:
def extract_rtts_native(n=3):
    """
    Returns a list of Polars expressions to parse a JSON 'result' column
    and extract the first 'n' RTT values into separate columns.

    This version uses str.json_path_match to robustly handle mixed JSON schemas
    (e.g., some with 'rtt', others with 'x') without schema inference errors.
    """
    # 1. Use json_path_match to find all 'rtt' values in the list.
    #    The JSONPath `r"$[*].rtt"` means:
    #    - `$`     : Start at the root.
    #    - `[*]`   : Get all elements in the top-level array.
    #    - `.rtt`  : From each element, extract the value of the "rtt" key.
    #    This returns a List of Strings, e.g., ["25.519", "25.671", "25.801"].
    #    If no 'rtt' keys are found, it correctly returns an empty list or null.
    #    THIS IS THE KEY FIX, as it bypasses schema validation.
    rtt_list_expr = pl.col("result").str.json_path_match(r"$[*].rtt")

    # 2. The result of json_path_match is a list of STRINGS. We must cast them to floats.
    #    We use list.eval to run an expression on each element of the list.
    #    `strict=False` will turn any non-numeric values into null instead of erroring.
    rtt_floats_expr = rtt_list_expr.list.eval(
        pl.element().cast(pl.Float64, strict=False)
    )

    # 3. Create 'n' new columns by getting elements from the list of RTTs.
    #    This part remains the same.
    rtt_columns = [
        rtt_floats_expr.list.get(i).alias(f"rtt_{i+1}")
        for i in range(n)
    ]

    return rtt_columns

In [3]:
SOURCE_GLOB = "data/ping/**/*.parquet"
OUTPUT_PATH = "data/ping_parsed.parquet"
BATCH_SIZE = 100  # Process 1000 files at a time. Adjust as needed.
N_RTTS = 3

# Get a list of all file paths
all_files = glob.glob(SOURCE_GLOB)
print(f"Found {len(all_files)} files to process.")

Found 96562 files to process.


In [6]:
# Ensure the output directory exists and the old file is gone
if os.path.exists(OUTPUT_PATH):
    os.remove(OUTPUT_PATH)

first_batch = True
for i in range(0, len(all_files), BATCH_SIZE):
    batch_files = all_files[i : i + BATCH_SIZE]
    print(f"Processing batch {i // BATCH_SIZE + 1} ({len(batch_files)} files)...")

    # Scan only the files in the current batch
    lazy_df = pl.scan_parquet(batch_files)
    display(lazy_df.head(10).collect())

    # Apply the same transformation
    transformed_lazy = lazy_df.with_columns(
        extract_rtts_native(n=N_RTTS)
    ).drop("result")

    # For the first batch, create the file. For subsequent batches, append to it.
    if first_batch:
        transformed_lazy.sink_parquet(OUTPUT_PATH)
        first_batch = False
    else:
        # To append, we collect the result of the small scan and write_parquet
        # This is safe because each batch is small
        df_batch_result = transformed_lazy.collect(streaming=True)
        df_batch_result.write_parquet(OUTPUT_PATH, use_pyarrow=True, append=True)

Processing batch 1 (100 files)...


prb_id,dst_addr,ts,sent,rcvd,avg,result
i64,str,i64,i64,i64,f64,str
6878,"""91.243.43.19""",1749256784,3,0,-1.0,"""[{""x"": ""*""}, {""x"": ""*""}, {""x"": ""*""}]"""
6878,"""2001:4ba0:ffe0:ffff::4""",1749256785,3,3,25.62511,"""[{""rtt"": 25.519027}, {""rtt"": 25.671107}, {""rtt"": 2…"
6878,"""213.91.165.187""",1749256786,3,3,65.974445,"""[{""rtt"": 66.019935}, {""rtt"": 65.976451}, {""rtt"": 6…"
6878,"""146.185.219.73""",1749256787,3,3,86.520995,"""[{""rtt"": 86.613995}, {""rtt"": 86.530067}, {""rtt"": 8…"
6878,"""2a01:9e01:4d05:3333::a""",1749256789,3,0,-1.0,"""[{""x"": ""*""}, {""x"": ""*""}, {""x"": ""*""}]"""
6878,"""69.30.249.206""",1749256789,3,3,148.248946,"""[{""rtt"": 148.23621}, {""rtt"": 148.301316}, {""rtt"": …"
6878,"""45.41.55.165""",1749256789,3,0,-1.0,"""[{""x"": ""*""}, {""x"": ""*""}, {""x"": ""*""}]"""
6878,"""92.38.176.25""",1749256790,3,3,126.80081,"""[{""rtt"": 126.767651}, {""rtt"": 126.777332}, {""rtt"":…"
6878,"""212.62.68.29""",1749256793,3,3,29.075086,"""[{""rtt"": 29.087843}, {""rtt"": 29.118559}, {""rtt"": 2…"
6878,"""34.89.240.90""",1749256797,3,0,-1.0,"""[{""x"": ""*""}, {""x"": ""*""}, {""x"": ""*""}]"""


InvalidOperationError: list.eval operation not supported for dtype `str`

Resolved plan until failure:

	---> FAILED HERE RESOLVING 'sink' <---
Parquet SCAN [data/ping/ping-2025-06-07T0000.parquet/part.121.parquet, ... 99 other sources] [id: 129519447274064]
PROJECT */7 COLUMNS

In [7]:
# Read probe IP mapping CSV
ip_map = pl.read_csv("probe_ip_map.csv")
ip_map = ip_map.rename({"prb_id": "dst_prb_id"})
print("IP map sample:")
display(ip_map.head(10))

In [None]:
print("IP map head:")
display(ip_map.head(10))

In [None]:
print("Main dataframe head:")
display(df.head(10).collect())

In [None]:
# Join data with IP mapping
df1 = df.rename({"prb_id": "src"})
dfl2 = df1.join(
    ip_map.select(["ip", "dst_prb_id"]),
    left_on="dst_addr", right_on="ip",
    how="left"
)

# Fill nulls and cast to int
dfl2 = dfl2.with_columns(
    pl.col("dst_prb_id").fill_null(-1).cast(pl.Int64)
)

print("After join:")
display(dfl2.head(10).collect())

# Filter where dst_prb_id is not -1
dfl3 = dfl2.filter(pl.col("dst_prb_id") != -1)
print("\nAfter filtering:")
display(dfl3.head(10).collect())

# Select columns and rename
dfl6 = dfl3.select(["src", "dst_prb_id", "ts", "avg", "result", "sent", "rcvd"])
dfl6 = dfl6.rename({"dst_prb_id": "dst"})

print("\nFinal dataframe:")
display(dfl6.head(10).collect())
print(f"Total rows: {dfl6.collect().height}")

In [None]:
# Get unique edges
edges = dfl6.select(["src", "dst"]).unique()
print("Unique edges:")
display(edges.head(10).collect())

In [None]:
print(f"Total unique edges: {edges.collect().height}")

In [None]:
# Count connections per source
conn_counts = edges.group_by("src").agg(
    pl.count().alias("connection_count")
).sort("connection_count", descending=True)

print("Top 1000 sources by connection count:")
display(conn_counts.head(1000).collect())

In [None]:
# Get top 100 sources
top_src = conn_counts.head(100).select("src").to_series().to_list()

In [None]:
# Filter to top sources only
ddf = dfl6.filter(
    pl.col("src").is_in(top_src) & pl.col("dst").is_in(top_src)
)

print("Filtered to top sources:")
display(ddf.head(10).collect())
print(f"Total rows: {ddf.collect().height}")

In [None]:
# Filter for partial packet loss
partial_loss = ddf.filter(
    (pl.col("rcvd") < 3) & (pl.col("rcvd") > 0)
)
display(partial_loss.head(10).collect())

In [None]:
# Temporal investigation
print("=== Temporal Investigation ===")

# Get timestamp statistics
ts_stats = ddf.select("ts").describe()
print("Timestamp statistics:")
display(ts_stats.collect())

# Get unique timestamps
ts_values = ddf.select("ts").unique().sort("ts").collect()
ts_values_list = ts_values["ts"].to_list()
print(f"\nNumber of unique timestamps: {len(ts_values_list)}")

# Calculate time differences
if len(ts_values_list) > 1:
    ts_diffs = [ts_values_list[i+1] - ts_values_list[i] for i in range(len(ts_values_list)-1)]
    print(f"Time step differences (first 10): {ts_diffs[:10]}")
    print(f"Min time step: {min(ts_diffs)}")
    print(f"Max time step: {max(ts_diffs)}")
    print(f"Most common time step: {max(set(ts_diffs), key=ts_diffs.count)}")

print("\n=== Data Completeness Investigation ===")

# Get unique values for each dimension
unique_src = ddf.select("src").unique().collect().height
unique_dst = ddf.select("dst").unique().collect().height
unique_ts = ddf.select("ts").unique().collect().height

print(f"Unique sources: {unique_src}")
print(f"Unique destinations: {unique_dst}")
print(f"Unique timestamps: {unique_ts}")

# Calculate theoretical vs actual data points
theoretical_points = unique_src * unique_dst * unique_ts
actual_points = ddf.collect().height
print(f"Theoretical data points: {theoretical_points:,}")
print(f"Actual data points: {actual_points:,}")
print(f"Data completeness: {actual_points/theoretical_points*100:.2f}%")

# Check for self-loops
self_loops = ddf.filter(pl.col("src") == pl.col("dst")).collect()
print(f"Self-loops (src == dst): {self_loops.height}")

# Sample data structure
print("\n=== Sample Data Structure ===")
sample_data = ddf.select(["src", "dst", "ts", "avg"]).head(20).collect()
print("Sample data:")
display(sample_data)

# Check for multiple measurements per src-dst-ts combination
duplicates = ddf.group_by(["src", "dst", "ts"]).agg(
    pl.count().alias("count")
).collect()

print(f"\nMultiple measurements per src-dst-ts combination:")
print(f"Max measurements per combination: {duplicates['count'].max()}")
print(f"Mean measurements per combination: {duplicates['count'].mean():.2f}")

In [None]:
# Group by src, dst, ts and count
grouped_counts = ddf.group_by(["src", "dst", "ts"]).agg(
    pl.count().alias("count")
).collect()
display(grouped_counts)

In [None]:
def calculate_temporal_variation_basic_polars(df):
    """
    Calculate basic temporal variation statistics using polars
    """
    print("=== Basic Temporal Variation Analysis ===")
    
    # Convert timestamp to datetime for easier analysis
    df_with_time = df.with_columns([
        pl.col("ts").cast(pl.Datetime).alias("datetime"),
        pl.col("ts").cast(pl.Datetime).dt.hour().alias("hour"),
        pl.col("ts").cast(pl.Datetime).dt.weekday().alias("day_of_week")
    ]).collect()
    
    # Overall temporal statistics
    min_time = df_with_time["datetime"].min()
    max_time = df_with_time["datetime"].max()
    print(f"Time span: {min_time} to {max_time}")
    print(f"Total duration: {max_time - min_time}")
    print(f"Number of unique timestamps: {df_with_time['ts'].n_unique()}")
    
    # Temporal distribution of measurements
    hourly_counts = df_with_time.group_by("hour").agg(
        pl.count().alias("count")
    )
    daily_counts = df_with_time.group_by("day_of_week").agg(
        pl.count().alias("count")
    )
    
    print(f"\nMeasurements per hour (mean): {hourly_counts['count'].mean():.2f}")
    print(f"Measurements per hour (std): {hourly_counts['count'].std():.2f}")
    print(f"Measurements per day (mean): {daily_counts['count'].mean():.2f}")
    
    return df_with_time, hourly_counts, daily_counts

def analyze_node_pair_temporal_variation_polars(df, top_n_pairs=10):
    """
    Analyze temporal variation for specific node pairs using polars
    """
    print("=== Node Pair Temporal Variation Analysis ===")
    
    # Find most active node pairs
    pair_counts = df.group_by(["src", "dst"]).agg(
        pl.count().alias("count")
    ).sort("count", descending=True).collect()
    
    top_pairs = pair_counts.head(top_n_pairs)
    temporal_variations = {}
    
    for row in top_pairs.iter_rows(named=True):
        src, dst, count = row["src"], row["dst"], row["count"]
        
        # Filter data for this pair
        pair_data = df.filter(
            (pl.col("src") == src) & (pl.col("dst") == dst)
        ).with_columns(
            pl.col("ts").cast(pl.Datetime).alias("datetime")
        ).collect()
        
        if pair_data.height > 0:
            # Calculate temporal statistics
            hourly_var = pair_data.group_by(
                pl.col("datetime").dt.hour()
            ).agg(
                pl.col("avg").std().alias("std")
            )["std"].mean()
            
            daily_var = pair_data.group_by(
                pl.col("datetime").dt.weekday()
            ).agg(
                pl.col("avg").std().alias("std")
            )["std"].mean()
            
            # Calculate coefficient of variation
            avg_std = pair_data["avg"].std()
            avg_mean = pair_data["avg"].mean()
            cv = avg_std / avg_mean if avg_mean != 0 else 0
            
            # Calculate temporal autocorrelation (simplified)
            if pair_data.height > 1:
                sorted_data = pair_data.sort("ts")
                latencies = sorted_data["avg"].to_list()
                if len(latencies) > 2:
                    # Simple autocorrelation calculation
                    mean_lat = np.mean(latencies)
                    var_lat = np.var(latencies)
                    if var_lat > 0:
                        autocorr = np.corrcoef(latencies[:-1], latencies[1:])[0, 1]
                    else:
                        autocorr = np.nan
                else:
                    autocorr = np.nan
            else:
                autocorr = np.nan
            
            temporal_variations[(src, dst)] = {
                'count': count,
                'hourly_variation': hourly_var,
                'daily_variation': daily_var,
                'coefficient_of_variation': cv,
                'autocorrelation': autocorr,
                'mean_latency': avg_mean,
                'std_latency': avg_std
            }
    
    # Create summary dataframe
    variation_data = []
    for (src, dst), stats in temporal_variations.items():
        stats['src'] = src
        stats['dst'] = dst
        variation_data.append(stats)
    
    variation_df = pl.DataFrame(variation_data)
    
    print("Temporal variation for top node pairs:")
    display(variation_df.round(3))
    
    return variation_df

# Usage
df_with_time, hourly_counts, daily_counts = calculate_temporal_variation_basic_polars(ddf)
variation_df = analyze_node_pair_temporal_variation_polars(ddf, top_n_pairs=10)

In [None]:
def pairwise_time_variance_heatmaps_polars(df):
    # Convert to pandas for matplotlib compatibility
    df_pandas = df.collect().to_pandas()
    
    # Get all unique src and dst
    srcs = sorted(df_pandas['src'].unique())
    dsts = sorted(df_pandas['dst'].unique())
    src_idx = {s: i for i, s in enumerate(srcs)}
    dst_idx = {d: i for i, d in enumerate(dsts)}
    
    # Initialize matrices
    stddev_matrix = np.full((len(srcs), len(dsts)), np.nan)
    autocorr_matrix = np.full((len(srcs), len(dsts)), np.nan)
    
    # Group by pair and compute stats
    for (src, dst), group in df_pandas.groupby(['src', 'dst']):
        if len(group) > 1:
            latencies = group.sort_values('ts')['avg'].values
            stddev_matrix[src_idx[src], dst_idx[dst]] = np.std(latencies)
            # Autocorrelation (lag-1)
            if len(latencies) > 2:
                autocorr_matrix[src_idx[src], dst_idx[dst]] = pd.Series(latencies).autocorr()
    
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    im0 = axes[0].imshow(stddev_matrix, aspect='auto', cmap='magma')
    axes[0].set_title('Per-Pair Latency Stddev Over Time')
    axes[0].set_xlabel('Destination Index')
    axes[0].set_ylabel('Source Index')
    plt.colorbar(im0, ax=axes[0], label='Stddev (ms)')
    
    im1 = axes[1].imshow(autocorr_matrix, aspect='auto', cmap='coolwarm', vmin=-1, vmax=1)
    axes[1].set_title('Per-Pair Latency Autocorrelation (Lag-1)')
    axes[1].set_xlabel('Destination Index')
    axes[1].set_ylabel('Source Index')
    plt.colorbar(im1, ax=axes[1], label='Autocorrelation')
    
    plt.tight_layout()
    plt.show()

# Usage
pairwise_time_variance_heatmaps_polars(ddf)

In [None]:
# Generate latency distribution visualization
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Find node pairs where src != dst (no self-loops) with good data coverage
node_pairs = ddf.group_by(["src", "dst"]).agg(
    pl.count().alias("count")
).sort("count", descending=True).collect()

# Filter out self-loops (where src == dst)
valid_pairs = node_pairs.filter(pl.col("src") != pl.col("dst"))
print("Top 10 most active node pairs (excluding self-loops):")
display(valid_pairs.head(10))

# Select a single node pair with good data coverage
if valid_pairs.height > 0:
    # Get the most active pair with at least 20 measurements
    active_pairs = valid_pairs.filter(pl.col("count") >= 20)
    if active_pairs.height > 0:
        selected_row = active_pairs.head(1).collect()
        src, dst = selected_row["src"][0], selected_row["dst"][0]
        count = selected_row["count"][0]
        print(f"\nSelected node pair: {src} → {dst} with {count} measurements")
        
        # Get data for this specific pair
        pair_data = ddf.filter(
            (pl.col("src") == src) & 
            (pl.col("dst") == dst) & 
            (pl.col("avg") > 0)
        ).collect()
        
        if pair_data.height > 0:
            # Create simple histogram
            plt.figure(figsize=(10, 6))
            plt.hist(pair_data["avg"].to_list(), bins=30, alpha=0.7, edgecolor='black', color='skyblue')
            plt.xlabel('Average Latency (ms)')
            plt.ylabel('Frequency')
            plt.title(f'Latency Distribution: Node {src} → Node {dst}\n({pair_data.height} measurements)')
            plt.grid(True, alpha=0.3)
            
            # Add statistics as text
            mean_latency = pair_data["avg"].mean()
            median_latency = pair_data["avg"].median()
            std_latency = pair_data["avg"].std()
            min_latency = pair_data["avg"].min()
            max_latency = pair_data["avg"].max()
            
            stats_text = f'Mean: {mean_latency:.2f} ms\n'
            stats_text += f'Median: {median_latency:.2f} ms\n'
            stats_text += f'Std Dev: {std_latency:.2f} ms\n'
            stats_text += f'Range: {min_latency:.2f} - {max_latency:.2f} ms'
            
            plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes, 
                    verticalalignment='top', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
            plt.tight_layout()
            plt.show()
            
            print(f"\nStatistics for Node {src} → Node {dst}:")
            print(f"Number of measurements: {pair_data.height}")
            print(f"Mean latency: {mean_latency:.2f} ms")
            print(f"Median latency: {median_latency:.2f} ms")
            print(f"Standard deviation: {std_latency:.2f} ms")
            print(f"Min latency: {min_latency:.2f} ms")
            print(f"Max latency: {max_latency:.2f} ms")
        else:
            print("No successful measurements found for this pair")
    else:
        print("No node pairs found with sufficient data (>= 20 measurements)")
else:
    print("No valid node pairs found (all are self-loops)")