In [None]:
from itertools import product
from functools import partial
from typing import Callable, NamedTuple, List, Tuple

import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt

from estimator.precision import precision
from generator.sequence import generate_random_sequence
from synchronizer.max_flow.max_flow import default_max_flow_synchronizer
from simulator.ground_truth.predefined import predefined_policies
from simulator.policy import policies_str
from synchronizer.alignment import Alignment
from utils.consume import consume
from utils.eval_loop import eval_loop
from utils.test_signal import TestSignal
from utils.sample import sample_uniform

## Test Tooling
### Samples for plots

In [None]:
class PrecisionSample(NamedTuple):
    """Precision of one evaluated run, tagged with the experiment parameters that produced it."""
    precision: float
    nw_condition: str
    # Integer bit width of the generated symbols (previously mis-annotated as ``str``).
    symbol_bits: int

def precision_sample_fn(nw_condition: str, symbol_bits: int) -> Callable[[TestSignal, Alignment], PrecisionSample]:
    """Build a postprocess callback that turns one synchronization result into a ``PrecisionSample``.

    The returned function takes ``(test_signal, alignment)`` and scores the alignment with
    ``precision(alignment, test_signal.ground_truth)``. The resulting sample also records the
    ``nw_condition`` and ``symbol_bits`` given here, so results can be grouped later.
    """
    def precision_sample(test_signal: TestSignal, alignment: Alignment) -> PrecisionSample:
        return PrecisionSample(
            precision=precision(alignment, test_signal.ground_truth),
            nw_condition=nw_condition,
            symbol_bits=symbol_bits,
        )

    return precision_sample

### Test Signal Generation

In [None]:
def sample_signal_lengths():
    """Draw the pair of lengths used when generating one random test signal.

    Returns a 2-tuple: a value drawn with ``sample_uniform(10, 30)`` and the fixed value 50.
    NOTE(review): what each element means is decided by ``TestSignal.generate`` — confirm there.
    """
    drawn_length = sample_uniform(10, 30)
    fixed_length = 50
    return drawn_length, fixed_length

def generate_random_test_signal_fn(nw_condition: str, symbol_bits: int) -> Callable[[], TestSignal]:
    """Return a zero-argument factory that produces random ``TestSignal`` instances.

    The factory is ``TestSignal.generate`` with these arguments pre-bound:
    - a sequence generator from ``generate_random_sequence(symbol_bits)``,
    - the policies registered under ``nw_condition`` in ``predefined_policies``,
    - ``sample_signal_lengths``.

    The sequence generator is created once here and shared by every call of the factory.
    """
    sequence_generator = generate_random_sequence(symbol_bits)
    condition_policies = predefined_policies[nw_condition]
    return partial(
        TestSignal.generate,
        generator=sequence_generator,
        policies=condition_policies,
        sample_signal_lengths=sample_signal_lengths,
    )

### Synchronizer

In [None]:
# Synchronizer under test; `collect_samples` reads this global for every experiment run below.
synchronizer = default_max_flow_synchronizer

### Collecting Samples

In [None]:
def collect_samples(nw_condition: str, symbol_bits: int, num_samples: int) -> List[PrecisionSample]:
    """Evaluate the global ``synchronizer`` on random test signals and gather precision samples.

    Builds an ``eval_loop`` that draws signals from
    ``generate_random_test_signal_fn(nw_condition, symbol_bits)``, runs them through
    ``synchronizer``, and post-processes each result with ``precision_sample_fn``.
    It then asks ``consume`` for ``num_samples`` results from that loop.
    """
    evaluation = eval_loop(
        generate_test_signal=generate_random_test_signal_fn(nw_condition, symbol_bits),
        synchronizer=synchronizer,
        postprocess=precision_sample_fn(nw_condition, symbol_bits),
    )
    return consume(evaluation, length=num_samples)

## Running Experiments
### Network Parameters
We use a predefined set of network policies for each network condition:

In [None]:
# One entry per network condition: its name, then its policies rendered by `policies_str`.
for key, policies in predefined_policies.items():
    rendered = policies_str(policies)
    print(f'{key}:\n     {rendered}\n')

### Variable Parameters

In [None]:
# Experiment grid: every network condition below is paired with every symbol bit width.
nw_conditions = ['normal']
test_symbol_bits = list(range(2, 7))  # 2, 3, 4, 5, 6

### Runs

In [None]:
# Number of evaluated runs collected for each (nw_condition, symbol_bits) combination.
# NOTE(review): no random seed is set, so results differ between full re-runs.
runs = 5000
samples = []
for nw_condition, symbol_bits in product(nw_conditions, test_symbol_bits):
    batch = collect_samples(nw_condition, symbol_bits, runs)
    samples.extend(batch)

samples_df = pd.DataFrame(samples)

What share of the runs had no events at all? Such runs are recorded with a precision of -1.

In [None]:
# A precision of -1 marks a run with no events (the plots below keep only precision >= 0).
# Count the sentinel with a mask instead of `value_counts()[-1]`: that label lookup raises
# KeyError when no run produced -1, which would break "Run All".
no_event_runs = (samples_df["precision"] == -1).sum()
f'{100 * no_event_runs / len(samples_df):.1f}%'

## Plots

### Symbol Bits vs. Avg. Precision
How does the average precision depend on the number of symbol bits and on the network condition? Runs without any events (negative precision) are excluded from the average.

In [None]:
# Mean precision per symbol-bit width (rows), one column per network condition.
# Runs with negative precision (the -1 "no events" sentinel) are excluded from the average.
agg_df = (
    samples_df[samples_df['precision'] >= 0]
    .groupby(['symbol_bits', 'nw_condition'])['precision']
    .mean()
    .unstack('nw_condition')
    .rename_axis(index='Symbol Bits', columns='NW Condition')
)

# One line per network condition; label the y-axis since the x-axis takes its label from the index name.
ax = agg_df.plot()
ax.set(ylabel='Average precision', title='Average precision vs. symbol bits')
plt.show()

In [None]:
# Table behind the plot above: mean precision per symbol-bit width (rows) and NW condition (columns).
agg_df

In [None]:
# TODO: Plot only one NW condition (currently each NW condition in `nw_conditions` becomes its own line in the plot).