In [1]:
# pyright: basic, reportUnknownVariableType=false, reportUnknownMemberType=false

import glob
import json
import os
from collections import Counter
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import seaborn as sns
from scipy import stats

# Configure how polars renders tables in cell output (display only; no effect on performance)
pl.Config.set_fmt_str_lengths(50)  # show up to 50 characters of each string value
pl.Config.set_tbl_rows(20)  # rows shown when a frame is displayed
pl.Config.set_tbl_cols(20);  # columns shown; trailing `;` suppresses the Config-class repr

polars.config.Config

In [None]:
def extract_rtts_native(n=3):
    """
    Build Polars expressions that pull per-packet RTTs out of the JSON `result` column.

    For each index ``idx`` in ``range(n)`` the expression evaluates the JSONPath
    ``$[idx].rtt`` against the `result` string column, casts the match to Float32
    with ``strict=False`` (so a missing match or a non-numeric value becomes null
    instead of raising) and names the output column ``rtt_{idx + 1}``.

    Args:
        n: How many leading RTT values to extract.

    Returns:
        A list of ``n`` expressions, suitable for ``with_columns``.
    """
    return [
        pl.col("result")
        .str.json_path_match(f"$[{idx}].rtt")
        # Float32 needs half the bytes of Float64 per value
        .cast(pl.Float32, strict=False)
        .alias(f"rtt_{idx + 1}")
        for idx in range(n)
    ]

def optimize_data_types():
    """
    Return cast expressions that narrow the ping table's column types to save space.

    Per-row savings (about 22 bytes in total), assuming the source types shown:
    - prb_id: i64 -> u32 (4 bytes saved)
    - sent:   i64 -> u8  (7 bytes saved)
    - rcvd:   i64 -> u8  (7 bytes saved)
    - avg:    f64 -> f32 (4 bytes saved)
    `ts` is passed through unchanged.

    The casts use polars' default strict mode, so a value that does not fit the
    narrower type raises rather than being silently truncated.

    NOTE(review): this helper is not applied in any of the notebook cells shown.
    """
    narrow_types = {
        "prb_id": pl.UInt32,
        "sent": pl.UInt8,
        "rcvd": pl.UInt8,
        "avg": pl.Float32,
    }
    expressions = [pl.col(name).cast(dtype) for name, dtype in narrow_types.items()]
    # Keep as i64 (unix timestamp)
    expressions.append(pl.col("ts"))
    return expressions

In [3]:
SOURCE_GLOB = "data/ping/**/*.parquet"
OUTPUT_PATH = "data/ping_parsed.parquet"
BATCH_SIZE = 100  # Number of input files combined into each output part file
N_RTTS = 3  # Number of leading RTT values extracted from each `result` JSON array

# `**` only recurses when recursive=True; without it glob treats `**` like `*`
# (exactly one directory level). Sorting keeps the batch -> part-file mapping
# stable across runs, since glob's order is filesystem-dependent.
all_files = sorted(glob.glob(SOURCE_GLOB, recursive=True))
print(f"Found {len(all_files)} files to process.")

Found 96562 files to process.


In [4]:
# Parse the raw ping files in batches and write one Parquet part file per batch
# into OUTPUT_DIR; memory stays bounded by the batch size.

# Create output directory for partitioned dataset
OUTPUT_DIR = "data/ping_parsed_parts"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Remove parts from a previous run so stale files are not mixed into this dataset
for old_part in Path(OUTPUT_DIR).glob("*.parquet"):
    old_part.unlink()

processed_batches = 0  # batches written successfully
total_batches = (len(all_files) + BATCH_SIZE - 1) // BATCH_SIZE  # ceiling division

print(f"Starting processing of {len(all_files)} files in {total_batches} batches...")
print(f"Each batch: up to {BATCH_SIZE} files per output file")
print(f"Output directory: {OUTPUT_DIR}")

for batch_no, start in enumerate(range(0, len(all_files), BATCH_SIZE), start=1):
    batch_files = all_files[start : start + BATCH_SIZE]
    batch_output = f"{OUTPUT_DIR}/part_{batch_no:04d}.parquet"
    print(f"Processing batch {batch_no}/{total_batches} ({len(batch_files)} files)...")

    try:
        # Replace the raw JSON `result` column with rtt_1..rtt_N float columns
        transformed_lazy = (
            pl.scan_parquet(batch_files)
            .with_columns(extract_rtts_native(n=N_RTTS))
            .drop("result")
        )
        transformed_lazy.sink_parquet(batch_output)
    except Exception as e:
        print(f"✗ Error in batch {batch_no}: {e}")
        print(f"  Sample files: {batch_files[:3]}")
        # Don't leave a half-written part behind: the verification below reads every part file
        Path(batch_output).unlink(missing_ok=True)
        break  # Stop on first error to debug

    # Only count batches whose output was actually written
    processed_batches += 1
    file_size_mb = os.path.getsize(batch_output) / (1024**2)
    print(f"✓ Saved batch {batch_no} ({file_size_mb:.0f}MB)")

print(f"\nProcessing complete: {processed_batches}/{total_batches} batches")

# Show total output stats
parquet_files = sorted(Path(OUTPUT_DIR).glob("*.parquet"))
if not parquet_files:
    # Guards the average-size division below (no input files, or the first batch failed)
    print("No output files were written.")
else:
    total_size_gb = sum(part.stat().st_size for part in parquet_files) / (1024**3)
    print(f"Total output: {len(parquet_files)} files, {total_size_gb:.2f} GB")
    print(f"Average file size: {total_size_gb / len(parquet_files) * 1024:.0f} MB")

    print(f"\n📁 Partitioned dataset ready at: {OUTPUT_DIR}/")
    print(f"💡 To read entire dataset: pl.scan_parquet('{OUTPUT_DIR}/*.parquet')")

    # Verify we can read the dataset
    try:
        total_rows = pl.scan_parquet(f"{OUTPUT_DIR}/*.parquet").select(pl.len()).collect().item()
        print(f"✅ Verified: {total_rows:,} total rows across all files")
    except Exception as e:
        print(f"⚠️  Verification failed: {e}")

Starting processing of 96562 files in 966 batches...
Each batch: ~100 files = ~768MB per output file
Output directory: data/ping_parsed_parts
Processing batch 1/966 (100 files)...
✓ Saved batch 1 (733MB)
Processing batch 2/966 (100 files)...
✓ Saved batch 2 (733MB)
Processing batch 3/966 (100 files)...
✓ Saved batch 3 (723MB)
Processing batch 4/966 (100 files)...
✓ Saved batch 4 (667MB)
Processing batch 5/966 (100 files)...
✓ Saved batch 5 (709MB)
Processing batch 6/966 (100 files)...
✓ Saved batch 6 (651MB)
Processing batch 7/966 (100 files)...
✓ Saved batch 7 (666MB)
Processing batch 8/966 (100 files)...
✓ Saved batch 8 (735MB)
Processing batch 9/966 (100 files)...
✓ Saved batch 9 (735MB)
Processing batch 10/966 (100 files)...
✓ Saved batch 10 (735MB)
Processing batch 11/966 (100 files)...
✓ Saved batch 11 (736MB)
Processing batch 12/966 (100 files)...
✓ Saved batch 12 (736MB)
Processing batch 13/966 (100 files)...
✓ Saved batch 13 (732MB)
Processing batch 14/966 (100 files)...
✓ Sav

In [7]:
# Load the probe-id -> IP mapping; its id column is renamed to serve as the destination key
ip_map = pl.read_csv("probe_ip_map.csv").rename({"prb_id": "dst_prb_id"})
print("IP map sample:")
display(ip_map.head(10))

In [None]:
# NOTE(review): repeats the preview already shown by the cell that loads ip_map
print("IP map head:")
display(ip_map.limit(10))

In [None]:
# NOTE(review): `df` is not assigned in any earlier cell of this notebook, so this cell
# fails on a fresh kernel. It must be a LazyFrame (it is .collect()-ed) that still has
# the raw `result` column (selected in the join cell below) — define it explicitly.
print("Main dataframe head:")
display(df.limit(10).collect())

In [None]:
# Join data with IP mapping: resolve each measurement's dst_addr to a destination probe id
df1 = df.rename({"prb_id": "src"})
dfl2 = df1.join(
    # `df` is a LazyFrame (it is .collect()-ed below) and LazyFrame.join only accepts
    # another LazyFrame, so the eager ip_map is converted with .lazy().
    # NOTE(review): if an ip appears more than once in probe_ip_map.csv, matching
    # measurements are duplicated by this join — confirm ips are unique.
    ip_map.select(["ip", "dst_prb_id"]).lazy(),
    left_on="dst_addr",
    right_on="ip",
    how="left",
)

# Unmatched addresses become the sentinel -1 so they can be dropped explicitly below
dfl2 = dfl2.with_columns(pl.col("dst_prb_id").fill_null(-1).cast(pl.Int64))

print("After join:")
display(dfl2.head(10).collect())

# Keep only measurements whose destination resolved to a known probe
dfl3 = dfl2.filter(pl.col("dst_prb_id") != -1)
print("\nAfter filtering:")
display(dfl3.head(10).collect())

# Select columns and rename
dfl6 = dfl3.select(["src", "dst_prb_id", "ts", "avg", "result", "sent", "rcvd"]).rename(
    {"dst_prb_id": "dst"}
)

print("\nFinal dataframe:")
display(dfl6.head(10).collect())
# Count rows inside the query engine instead of materializing the whole frame
print(f"Total rows: {dfl6.select(pl.len()).collect().item()}")

In [None]:
# Distinct directed (src, dst) probe pairs with at least one measurement
edges = dfl6.select("src", "dst").unique()
print("Unique edges:")
display(edges.limit(10).collect())

In [None]:
# Count inside the query engine rather than collecting every edge just to read .height
print(f"Total unique edges: {edges.select(pl.len()).collect().item()}")

In [None]:
# Number of distinct destinations per source (edges is already unique on src, dst)
conn_counts = (
    edges.group_by("src")
    .agg(pl.len().alias("connection_count"))  # pl.len() replaces the deprecated pl.count()
    # Break ties on src so head(N) picks the same sources on every run
    .sort(["connection_count", "src"], descending=[True, False])
)

print("Top 1000 sources by connection count:")
display(conn_counts.head(1000).collect())

In [None]:
# Get top 100 sources. conn_counts is a LazyFrame (it is .collect()-ed when displayed),
# and LazyFrame has no .to_series(), so collect the small result first.
TOP_N_SOURCES = 100
top_src = conn_counts.head(TOP_N_SOURCES).select("src").collect().to_series().to_list()

In [None]:
# Keep only measurements where both endpoints are among the top sources
ddf = dfl6.filter(pl.col("src").is_in(top_src) & pl.col("dst").is_in(top_src))

print("Filtered to top sources:")
display(ddf.head(10).collect())
# Count without materializing the filtered frame
print(f"Total rows: {ddf.select(pl.len()).collect().item()}")

In [None]:
# Partial packet loss: at least one reply, but fewer replies than packets sent.
# Comparing against `sent` (instead of a hard-coded 3) stays correct if the
# number of packets per measurement varies.
partial_loss = ddf.filter((pl.col("rcvd") > 0) & (pl.col("rcvd") < pl.col("sent")))
display(partial_loss.head(10).collect())

In [None]:
# Temporal investigation
print("=== Temporal Investigation ===")

# Summary statistics of the timestamp column. Collect first: LazyFrame.describe()
# already returns an eager DataFrame, which has no .collect().
ts_stats = ddf.select("ts").collect().describe()
print("Timestamp statistics:")
display(ts_stats)

# Get unique timestamps (sorted ascending)
ts_values = ddf.select("ts").unique().sort("ts").collect()
ts_values_list = ts_values["ts"].to_list()
unique_ts = len(ts_values_list)
print(f"\nNumber of unique timestamps: {unique_ts}")

# Gaps between consecutive distinct timestamps
if unique_ts > 1:
    ts_diffs = [later - earlier for earlier, later in zip(ts_values_list, ts_values_list[1:])]
    print(f"Time step differences (first 10): {ts_diffs[:10]}")
    print(f"Min time step: {min(ts_diffs)}")
    print(f"Max time step: {max(ts_diffs)}")
    # Counter is O(n); max(set(d), key=d.count) was O(n * distinct values)
    print(f"Most common time step: {Counter(ts_diffs).most_common(1)[0][0]}")

print("\n=== Data Completeness Investigation ===")

# Distinct sources/destinations and the row count in one query instead of several full collects
completeness = (
    ddf.select(
        pl.col("src").n_unique().alias("unique_src"),
        pl.col("dst").n_unique().alias("unique_dst"),
        pl.len().alias("actual_points"),
    )
    .collect()
    .row(0, named=True)
)
unique_src = completeness["unique_src"]
unique_dst = completeness["unique_dst"]
actual_points = completeness["actual_points"]

print(f"Unique sources: {unique_src}")
print(f"Unique destinations: {unique_dst}")
print(f"Unique timestamps: {unique_ts}")

# Theoretical count assumes every (src, dst) pair is measured at every timestamp
theoretical_points = unique_src * unique_dst * unique_ts
print(f"Theoretical data points: {theoretical_points:,}")
print(f"Actual data points: {actual_points:,}")
if theoretical_points:
    print(f"Data completeness: {actual_points/theoretical_points*100:.2f}%")
else:
    print("Data completeness: n/a (no data)")

# Check for self-loops
self_loop_count = ddf.filter(pl.col("src") == pl.col("dst")).select(pl.len()).collect().item()
print(f"Self-loops (src == dst): {self_loop_count}")

# Sample data structure
print("\n=== Sample Data Structure ===")
sample_data = ddf.select(["src", "dst", "ts", "avg"]).head(20).collect()
print("Sample data:")
display(sample_data)

# Check for multiple measurements per src-dst-ts combination
duplicates = ddf.group_by(["src", "dst", "ts"]).agg(pl.len().alias("count")).collect()

print("\nMultiple measurements per src-dst-ts combination:")
if duplicates.height:
    print(f"Max measurements per combination: {duplicates['count'].max()}")
    print(f"Mean measurements per combination: {duplicates['count'].mean():.2f}")
else:
    print("No measurements.")

In [None]:
# Measurements per (src, dst, ts) combination, most repeated first
# (group_by output order is otherwise unspecified, so sorting makes the display stable)
grouped_counts = (
    ddf.group_by(["src", "dst", "ts"])
    .agg(pl.len().alias("count"))  # pl.len() replaces the deprecated pl.count()
    .sort("count", descending=True)
    .collect()
)
display(grouped_counts)

In [None]:
def calculate_temporal_variation_basic_polars(df, ts_unit="s"):
    """
    Print dataset-wide temporal coverage statistics and return time-bucket counts.

    Args:
        df: LazyFrame with an integer epoch-timestamp column ``ts``.
        ts_unit: Unit of ``ts`` as accepted by ``pl.from_epoch`` ("s", "ms", "us",
            "ns"). Defaults to seconds — NOTE(review): assumes ``ts`` holds unix
            seconds; confirm against the raw data.

    Returns:
        Tuple ``(df_with_time, hourly_counts, daily_counts)``:
        - df_with_time: the collected frame plus ``datetime``, ``hour`` and
          ``day_of_week`` columns;
        - hourly_counts: rows per hour of day (columns ``hour``, ``count``);
        - daily_counts: rows per weekday (columns ``day_of_week``, ``count``).
    """
    print("=== Basic Temporal Variation Analysis ===")

    def _fmt(value):
        # mean()/std() can return None (empty input; std of a single value)
        return "n/a" if value is None else f"{value:.2f}"

    # A bare cast(pl.Datetime) reads the integers as microseconds (polars' default
    # Datetime unit), putting epoch-second data in January 1970; from_epoch applies
    # the real unit.
    when = pl.from_epoch("ts", time_unit=ts_unit)
    df_with_time = df.with_columns(
        when.alias("datetime"),
        when.dt.hour().alias("hour"),
        when.dt.weekday().alias("day_of_week"),
    ).collect()

    # Temporal distribution of measurements
    hourly_counts = df_with_time.group_by("hour").agg(pl.len().alias("count")).sort("hour")
    daily_counts = (
        df_with_time.group_by("day_of_week").agg(pl.len().alias("count")).sort("day_of_week")
    )

    if df_with_time.height == 0:
        print("No measurements to analyse.")
        return df_with_time, hourly_counts, daily_counts

    # Overall temporal statistics
    min_time = df_with_time["datetime"].min()
    max_time = df_with_time["datetime"].max()
    print(f"Time span: {min_time} to {max_time}")
    print(f"Total duration: {max_time - min_time}")
    print(f"Number of unique timestamps: {df_with_time['ts'].n_unique()}")

    print(f"\nMeasurements per hour (mean): {_fmt(hourly_counts['count'].mean())}")
    print(f"Measurements per hour (std): {_fmt(hourly_counts['count'].std())}")
    print(f"Measurements per day (mean): {_fmt(daily_counts['count'].mean())}")

    return df_with_time, hourly_counts, daily_counts

def _lag1_autocorrelation(values):
    """Lag-1 Pearson autocorrelation of a sequence; NaN for < 3 values or zero variance."""
    series = np.asarray(values, dtype=float)  # None entries become NaN
    if series.size <= 2 or not np.var(series) > 0:
        return np.nan
    return np.corrcoef(series[:-1], series[1:])[0, 1]


def _pair_latency_stats(pair_data):
    """Latency-variation statistics for one collected pair frame (needs `datetime`, `ts`, `avg`)."""
    # Mean, over hours of day / weekdays, of the latency std within that bucket
    hourly_var = (
        pair_data.group_by(pl.col("datetime").dt.hour())
        .agg(pl.col("avg").std().alias("std"))["std"]
        .mean()
    )
    daily_var = (
        pair_data.group_by(pl.col("datetime").dt.weekday())
        .agg(pl.col("avg").std().alias("std"))["std"]
        .mean()
    )

    avg_std = pair_data["avg"].std()
    avg_mean = pair_data["avg"].mean()
    if avg_std is None or avg_mean is None:
        cv = np.nan  # std is undefined for a single (or no non-null) measurement
    elif avg_mean == 0:
        cv = 0.0
    else:
        cv = avg_std / avg_mean

    return {
        "hourly_variation": hourly_var,
        "daily_variation": daily_var,
        "coefficient_of_variation": cv,
        "autocorrelation": _lag1_autocorrelation(pair_data.sort("ts")["avg"].to_list()),
        "mean_latency": avg_mean,
        "std_latency": avg_std,
    }


def analyze_node_pair_temporal_variation_polars(df, top_n_pairs=10, ts_unit="s"):
    """
    Compute latency-variation statistics for the most-measured (src, dst) pairs.

    Args:
        df: LazyFrame with ``src``, ``dst``, ``ts`` (integer epoch) and ``avg`` columns.
        top_n_pairs: Number of pairs (by measurement count) to analyse.
        ts_unit: Unit of ``ts`` for ``pl.from_epoch``. Defaults to seconds —
            NOTE(review): assumes ``ts`` holds unix seconds; confirm.

    Returns:
        DataFrame with one row per analysed pair: count, hourly_variation,
        daily_variation, coefficient_of_variation, autocorrelation (lag-1, by ts),
        mean_latency, std_latency, src, dst. A 3-decimal rounded copy is displayed.
    """
    print("=== Node Pair Temporal Variation Analysis ===")

    # Most active node pairs; ties broken by (src, dst) so the selection is stable
    pair_counts = (
        df.group_by(["src", "dst"])
        .agg(pl.len().alias("count"))
        .sort(["count", "src", "dst"], descending=[True, False, False])
        .collect()
    )

    rows = []
    for pair in pair_counts.head(top_n_pairs).iter_rows(named=True):
        src, dst = pair["src"], pair["dst"]
        # from_epoch applies the real unit; cast(pl.Datetime) would assume microseconds
        pair_data = (
            df.filter((pl.col("src") == src) & (pl.col("dst") == dst))
            .with_columns(pl.from_epoch("ts", time_unit=ts_unit).alias("datetime"))
            .collect()
        )
        if pair_data.height == 0:
            continue
        rows.append({"count": pair["count"], **_pair_latency_stats(pair_data), "src": src, "dst": dst})

    variation_df = pl.DataFrame(rows)

    print("Temporal variation for top node pairs:")
    # DataFrame.round is not part of the polars API; round float columns via an expression
    display(variation_df.with_columns(pl.col(pl.Float32, pl.Float64).round(3)))

    return variation_df

# Usage: dataset-wide temporal summary, then per-pair variation for the 10 busiest pairs
df_with_time, hourly_counts, daily_counts = calculate_temporal_variation_basic_polars(df=ddf)
variation_df = analyze_node_pair_temporal_variation_polars(df=ddf, top_n_pairs=10)

In [None]:
def pairwise_time_variance_heatmaps_polars(df):
    """
    Draw src x dst heatmaps of per-pair latency spread and lag-1 autocorrelation.

    For every (src, dst) pair in ``df``:
    - left panel: ``np.std`` (population std) of ``avg`` over the pair's
      measurements; pairs with a single measurement stay NaN;
    - right panel: lag-1 autocorrelation of ``avg`` ordered by ``ts``
      (``pandas.Series.autocorr``); needs at least 3 measurements, otherwise NaN.
    Rows/columns are the sorted distinct src / dst ids, shown by index.

    Args:
        df: LazyFrame with ``src``, ``dst``, ``ts`` and ``avg`` columns. It is
            collected in full and converted to pandas.
    """
    # Convert to pandas for the per-group loop below (requires `pd`, imported at the top)
    df_pandas = df.collect().to_pandas()
    if df_pandas.empty:
        # imshow/colorbar cannot render a 0 x 0 matrix
        print("No data to plot.")
        return

    # Map every src / dst id to a matrix row / column
    srcs = sorted(df_pandas["src"].unique())
    dsts = sorted(df_pandas["dst"].unique())
    src_idx = {s: i for i, s in enumerate(srcs)}
    dst_idx = {d: i for i, d in enumerate(dsts)}

    # NaN marks pairs with no (or too few) measurements
    stddev_matrix = np.full((len(srcs), len(dsts)), np.nan)
    autocorr_matrix = np.full((len(srcs), len(dsts)), np.nan)

    for (src, dst), group in df_pandas.groupby(["src", "dst"]):
        if len(group) < 2:
            continue
        latencies = group.sort_values("ts")["avg"].to_numpy()
        row, col = src_idx[src], dst_idx[dst]
        stddev_matrix[row, col] = np.std(latencies)
        if len(latencies) > 2:
            autocorr_matrix[row, col] = pd.Series(latencies).autocorr()  # lag-1

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    im0 = axes[0].imshow(stddev_matrix, aspect="auto", cmap="magma")
    axes[0].set_title("Per-Pair Latency Stddev Over Time")
    axes[0].set_xlabel("Destination Index")
    axes[0].set_ylabel("Source Index")
    fig.colorbar(im0, ax=axes[0], label="Stddev (ms)")

    im1 = axes[1].imshow(autocorr_matrix, aspect="auto", cmap="coolwarm", vmin=-1, vmax=1)
    axes[1].set_title("Per-Pair Latency Autocorrelation (Lag-1)")
    axes[1].set_xlabel("Destination Index")
    axes[1].set_ylabel("Source Index")
    fig.colorbar(im1, ax=axes[1], label="Autocorrelation")

    fig.tight_layout()
    plt.show()

# Usage: per-pair variance / autocorrelation heatmaps for the top-source subset
pairwise_time_variance_heatmaps_polars(df=ddf)

In [None]:
# Latency distribution for the single busiest node pair (src != dst)
MIN_PAIR_MEASUREMENTS = 20  # a pair needs at least this many measurements to be plotted

# Measurement count per directed node pair; ties broken by (src, dst) for a stable choice
node_pairs = (
    ddf.group_by(["src", "dst"])
    .agg(pl.len().alias("count"))
    .sort(["count", "src", "dst"], descending=[True, False, False])
    .collect()
)

# Filter out self-loops (where src == dst)
valid_pairs = node_pairs.filter(pl.col("src") != pl.col("dst"))
print("Top 10 most active node pairs (excluding self-loops):")
display(valid_pairs.head(10))

if valid_pairs.height == 0:
    print("No valid node pairs found (all are self-loops)")
else:
    active_pairs = valid_pairs.filter(pl.col("count") >= MIN_PAIR_MEASUREMENTS)
    if active_pairs.height == 0:
        print(f"No node pairs found with sufficient data (>= {MIN_PAIR_MEASUREMENTS} measurements)")
    else:
        # active_pairs is already eager (node_pairs was collected), so no .collect() here
        selected = active_pairs.row(0, named=True)
        src, dst, count = selected["src"], selected["dst"], selected["count"]
        print(f"\nSelected node pair: {src} → {dst} with {count} measurements")

        # Only measurements with a positive average latency
        pair_data = ddf.filter(
            (pl.col("src") == src) & (pl.col("dst") == dst) & (pl.col("avg") > 0)
        ).collect()

        if pair_data.height == 0:
            print("No successful measurements found for this pair")
        else:
            latency = pair_data["avg"]
            mean_latency = latency.mean()
            median_latency = latency.median()
            std_latency = latency.std()
            min_latency = latency.min()
            max_latency = latency.max()

            fig, ax = plt.subplots(figsize=(10, 6))
            ax.hist(latency.to_list(), bins=30, alpha=0.7, edgecolor="black", color="skyblue")
            ax.set_xlabel("Average Latency (ms)")
            ax.set_ylabel("Frequency")
            ax.set_title(
                f"Latency Distribution: Node {src} → Node {dst}\n({pair_data.height} measurements)"
            )
            ax.grid(True, alpha=0.3)

            # Summary statistics box in the top-right corner
            stats_text = (
                f"Mean: {mean_latency:.2f} ms\n"
                f"Median: {median_latency:.2f} ms\n"
                f"Std Dev: {std_latency:.2f} ms\n"
                f"Range: {min_latency:.2f} - {max_latency:.2f} ms"
            )
            ax.text(
                0.95, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment="top", horizontalalignment="right",
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
            )

            fig.tight_layout()
            plt.show()

            print(f"\nStatistics for Node {src} → Node {dst}:")
            print(f"Number of measurements: {pair_data.height}")
            print(f"Mean latency: {mean_latency:.2f} ms")
            print(f"Median latency: {median_latency:.2f} ms")
            print(f"Standard deviation: {std_latency:.2f} ms")
            print(f"Min latency: {min_latency:.2f} ms")
            print(f"Max latency: {max_latency:.2f} ms")