This notebook compares the following two methods of obtaining RTT samples:

1. RTTs from TCP timestamps using the method described in "[New Methods for Passive Estimation of TCP Round-Trip Times](http://cobweb.cs.uga.edu/~kangli/src/pam05.pdf)"
2. RTTs from square waves as described in the [Proposal for adding a Spin Bit to QUIC](https://britram.github.io/draft-trammell-quic-spin/draft-trammell-quic-spin.html)

In [None]:
# Path to the HDF5 store holding the datapoints produced by `00_extract_flows.ipynb`.
STORE_PATH = '/tmp/anon-v4.hdf5'

In [None]:
import sys
from collections import namedtuple
from functools import reduce
from itertools import chain
from typing import Callable, Iterable, NamedTuple, Optional, Tuple

%matplotlib inline
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append('..')
from rtt import rtts_from_timestamps, rtts_from_square_wave, Flow

In [None]:
# Load the `tcp_df` frame from the store. The store is opened read-only because
# this notebook only consumes it; HDFStore's default append mode would open the
# file for writing and silently create an empty file if STORE_PATH were wrong.
with pd.HDFStore(STORE_PATH, mode='r') as store:
    tcp_df = store['tcp_df']

In [None]:
# Group by flows and compute RTTs
def compute_rtts(rtt_fn: Callable[[Tuple[str, pd.DataFrame]], pd.DataFrame],
                 flows_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Apply `rtt_fn` to every flow and stack the per-flow results.

    Args:
        rtt_fn: Called once per flow with the `(flow_hash, rows)` pair yielded
            by `groupby('flow_hash')`; must return a DataFrame.
        flows_df: Frame to group by `flow_hash`. Defaults to the
            notebook-level `tcp_df`.

    Returns:
        The concatenation of all per-flow results in group order, or an empty
        DataFrame when there are no flows.
    """
    if flows_df is None:
        flows_df = tcp_df
    # Collect first and concatenate once: growing a frame with `pd.concat`
    # inside the loop re-copies all previously gathered rows on every flow.
    per_flow = [rtt_fn(flow) for flow in flows_df.groupby('flow_hash')]
    if not per_flow:
        return pd.DataFrame()
    return pd.concat(per_flow)

# `tcp_df` is used as loaded: `compute_rtts` groups it by 'flow_hash' itself.
# (`DataFrame.set_index` returns a new frame, so a bare
# `tcp_df.set_index([...])` statement here would only build and discard a copy.)
ts_rtt_df = compute_rtts(rtts_from_timestamps)
sw_rtt_df = compute_rtts(rtts_from_square_wave)

## Evaluation

In [None]:
# First, convert the timedeltas to (float) microseconds
def convert_rtts_to_microseconds(rtt_df: pd.DataFrame) -> None:
    """Convert the `rtt` column of `rtt_df` from timedeltas to microseconds, in place.

    The column is divided by a one-microsecond `np.timedelta64`, which leaves
    plain floating-point numbers rather than integers.

    Args:
        rtt_df: Frame whose `rtt` column holds timedeltas. It is modified in
            place.

    Returns:
        None; the result is written back into `rtt_df['rtt']`.
    """
    rtt_df['rtt'] = rtt_df['rtt'] / np.timedelta64(1, 'us')

# Convert both RTT frames (timestamp-based first, then square-wave-based).
for rtt_frame in (ts_rtt_df, sw_rtt_df):
    convert_rtts_to_microseconds(rtt_frame)

In [None]:
class Metric(NamedTuple):
    """RTT data obtained with one measurement method, plus its per-packet cost."""
    # Frame with the RTT data; the notebook groups it by `flow_hash` and reads its `rtt` column.
    df: pd.DataFrame
    # Human-readable label; used in the chart legend.
    name: str
    # Bits the method occupies per packet (rescaled by `sample_metric` when subsampling).
    bits_per_packet: float
    # Fraction of rows kept by `sample_metric`; 1.0 means the metric has not been sampled.
    sample_rate: float = 1.0
        
def sample_metric(metric: Metric, sample_rate: float) -> Metric:
    """Return a copy of `metric` that keeps a random `sample_rate` fraction of its rows.

    `bits_per_packet` is rescaled by the share of rows that survived the
    sampling. Only metrics that have not been sampled yet
    (`sample_rate == 1.0`) are accepted.
    """
    assert metric.sample_rate == 1.0, 'Metric has been sampled before!'
    subsample = metric.df.sample(frac=sample_rate)
    # Same evaluation order as a single `kept * bits / total` expression.
    kept_bits = len(subsample) * metric.bits_per_packet
    return metric._replace(
        df=subsample,
        sample_rate=sample_rate,
        bits_per_packet=kept_bits / len(metric.df),
    )

TIMESTAMP_HEADER_SIZE = 64.0  # bits
SAMPLE_RATES = [0.25, 0.5, 1]

# Unsampled metric for the timestamp method, then one subsampled copy of it
# per entry in SAMPLE_RATES.
ts_metric = Metric(
    df=ts_rtt_df,
    name='Timestamp RTT',
    bits_per_packet=TIMESTAMP_HEADER_SIZE,
)
sampled_ts_metrics = [sample_metric(ts_metric, fraction) for fraction in SAMPLE_RATES]

# The square wave method is given a cost of one bit per packet.
sw_metric = Metric(df=sw_rtt_df, name='Square Wave RTT', bits_per_packet=1.0)

In [None]:
# TODO: Summarising with mean/std assumes that the RTTs are normally distributed. They are, however, HEAVY-tailed. How to do stats?
def aggregate_metric(metric: Metric) -> Metric:
    """Return `metric` with its `df` replaced by per-flow RTT statistics.

    The new frame has one row per `flow_hash` and the two-level columns
    `('rtt', 'mean')`, `('rtt', 'std')` and `('rtt', 'count')`. All other
    fields of `metric` are carried over unchanged.
    """
    per_flow = metric.df.groupby('flow_hash')
    rtt_stats = per_flow.agg({'rtt': ['mean', 'std', 'count']})
    return metric._replace(df=rtt_stats)

# Square wave metric first, followed by the timestamp metric at every sample rate.
aggregated_metrics = list(map(aggregate_metric, [sw_metric, *sampled_ts_metrics]))

In [None]:
# Bare expression as the last statement of the cell: shows the aggregated metrics for inspection.
aggregated_metrics

In [None]:
def merge_metrics(metrics: Iterable[Metric]) -> pd.DataFrame:
    """Join the frames of `metrics` side by side into a single DataFrame.

    The `rtt` column label of the metric at position `i` becomes `rtt_{i}` so
    the columns stay distinguishable after the join. Frames are combined with
    `DataFrame.join` and its default (left) join on the index, so the rows of
    the first metric determine the rows of the result.

    Args:
        metrics: Metrics to merge. Any iterable works, including a one-shot
            generator. The metrics' own frames are not modified.

    Returns:
        The joined frame.

    Raises:
        TypeError: If `metrics` is empty (raised by `reduce`).
    """
    # Rename into new frames rather than in place, so callers' metrics keep
    # their `rtt` column, and walk `metrics` once so iterators are supported.
    renamed_frames = [
        metric.df.rename({'rtt': f'rtt_{index}'}, axis='columns')
        for index, metric in enumerate(metrics)
    ]
    return reduce(lambda left, right: left.join(right), renamed_frames)

# NOTE(review): the frame returned by this call is discarded (it is not the
# cell's last expression, so it is not displayed either), and
# draw_metrics_bar_chart below performs the merge itself. Confirm whether this
# call is still needed.
merge_metrics(aggregated_metrics)

def draw_metrics_bar_chart(metrics: Iterable[Metric], unit: str = 'ms'):
    """Draw a bar chart of the per-row mean RTT of every metric, with std error bars.

    Args:
        metrics: Metrics whose frames carry `('rtt', 'mean')` and
            `('rtt', 'std')` columns, as produced by `aggregate_metric`. Any
            iterable is accepted.
        unit: Text used as the y-axis label. It is only a label; no values are
            converted.
    """
    # Materialise once: the metrics are counted and iterated several times
    # below, which a one-shot iterable would not survive.
    metrics = list(metrics)
    df = merge_metrics(metrics)
    positions = range(len(metrics))
    mean_cols = [(f'rtt_{index}', 'mean') for index in positions]
    std_cols = [(f'rtt_{index}', 'std') for index in positions]

    ax = df[mean_cols].plot(kind='bar',
                            figsize=(15, 10),
                            # One row of error values per plotted column.
                            yerr=df[std_cols].values.T,
                            legend=True,
                            fontsize=14)
    # One tick per row of the merged frame; the labels are hidden.
    ax.set_xticklabels([])
    ax.set_ylabel(unit)

    # Keep the plotted handles but relabel them with metric name and bits per packet.
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles,
              [f'{m.name}, b/p: {int(m.bits_per_packet)}' for m in metrics],
              fontsize=14)

    plt.show()
    
# The RTTs were converted to microseconds above, so label the axis accordingly
# instead of relying on the function's 'ms' default.
draw_metrics_bar_chart(aggregated_metrics, unit='µs')