In [None]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from src.longitudinal_ssms import TTC
from src.two_dimensional_ssms import TTC2D
from src.efficiency_utils import evaluate_efficiency

In [None]:
# Load the example samples from the HDF5 store and list the SSM names that the
# benchmarking cell passes (by name) to `evaluate_efficiency`.
samples = pd.read_hdf(path_or_buf='./assets/samples.h5', key='example')
SSMs = [
    'TTC',
    'DRAC',
    'MTTC',
    'PSD',
    'TTC2D',
    'ACT',
    'TAdv',
]

In [None]:
# Example: compute TTC for every sample and get the result back as a DataFrame;
# preview only its 'TTC' column.
results = TTC(samples, toreturn='dataframe')
results.loc[:, ['TTC']].head(n=5)

In [None]:
# Example: compute TTC2D as raw values and compare its distribution with the
# 'TTC' column computed in the previous cell (labelled TTC1D here).
# Both histograms share one set of bin edges so the bars line up; values that
# fall outside the [0, 15] s edges are not counted in either histogram.
ttc_bins = np.linspace(0, 15, 35)
ttc2d_values = TTC2D(samples, toreturn='values')

# Explicit figure/axes; named `fig_ttc` so it cannot overwrite the `fig` used
# by the efficiency-plot / savefig cells.
fig_ttc, ax_ttc = plt.subplots()
ax_ttc.hist(ttc2d_values, bins=ttc_bins, alpha=0.5, label='TTC2D')
ax_ttc.hist(results['TTC'], bins=ttc_bins, alpha=0.5, label='TTC1D')
ax_ttc.legend()
ax_ttc.set_xlabel('Time to Collision (s)')
ax_ttc.set_ylabel('Frequency')
plt.show()  # render the figure without echoing the last artist's repr

In [None]:
# Benchmark every SSM on workloads of increasing size. Each workload is the base
# sample set repeated row-wise until it holds exactly `num_pairs` rows.
NUM_BASE_SAMPLES = 10_000                  # rows of `samples` used as the base set
PAIR_COUNTS = [10_000, 100_000, 1_000_000]  # workload sizes (rows) to time
NUM_RUNS = 20  # 3rd positional arg of evaluate_efficiency; presumably the number of timed repetitions — confirm

# Positional slice: the former `.loc[:1e4-1]` used a float label bound and only
# selected 1e4 rows when the index happened to be 0..N-1.
samples = samples.iloc[:NUM_BASE_SAMPLES]
if samples.empty:
    raise ValueError('No samples available to build the benchmark workloads.')

# Build each workload once rather than once per SSM. Row i of a workload is base
# row i % len(samples): identical to concatenating copies of the base set, but it
# always yields exactly `num_pairs` rows, so the reported pair count is accurate
# even when the base set is smaller than NUM_BASE_SAMPLES or larger than a target.
workloads = {
    num_pairs: samples.iloc[np.arange(num_pairs) % len(samples)].reset_index(drop=True)
    for num_pairs in PAIR_COUNTS
}

# eval_results maps (ssm_name, num_pairs) -> run_time as returned by
# evaluate_efficiency (per-run timings; summarised by median in the plot cell).
eval_results = {}
for ssm in tqdm(SSMs):
    for num_pairs, test_samples in workloads.items():
        avg_time, run_time = evaluate_efficiency(test_samples, ssm, NUM_RUNS, average_only=False)
        print(f'{ssm} with {num_pairs} pairs: {avg_time:.2f} seconds on average.')
        eval_results[(ssm, num_pairs)] = run_time

In [None]:
# Box plots of the collected run times: one panel per workload size, one box per
# SSM. Sizes are read from `eval_results` itself, so this cell stays in sync with
# whatever sizes the benchmarking cell actually evaluated.
pair_counts = sorted({key[1] for key in eval_results})
positions = list(range(len(SSMs)))

# squeeze=False keeps `axes` 2-D even when only one size was benchmarked.
fig, axes = plt.subplots(1, len(pair_counts), figsize=(5 * len(pair_counts), 4.5), squeeze=False,
                         constrained_layout=True, gridspec_kw={'wspace': 0.05})
fig.suptitle('Time taken with different SSMs for a large number of vehicle pairs')
for ax, num_pairs in zip(axes[0], pair_counts):
    # '1eK' for exact powers of ten; otherwise show the comma-grouped count.
    exponent = int(round(np.log10(num_pairs)))
    size_label = f'1e{exponent}' if 10 ** exponent == num_pairs else f'{num_pairs:,}'
    ax.set_title(f'For {size_label} pairs')

    timings = [eval_results[(ssm, num_pairs)] for ssm in SSMs]
    # A single boxplot call draws every SSM's box at its own position.
    ax.boxplot(timings, positions=positions, showfliers=True, widths=0.6)

    # Tick labels carry each SSM's median run time for quick reading.
    ax.set_xticks(positions)
    ax.set_xticklabels([f'{ssm}\n({np.median(times):.2f}s)' for ssm, times in zip(SSMs, timings)])
    ax.set_ylabel('Time (s)')

In [None]:
# Export the efficiency-comparison figure (`fig` from the plotting cell) as SVG,
# trimming surrounding whitespace via bbox_inches='tight'.
fig.savefig(
    './assets/efficiency_comparison.svg',
    bbox_inches='tight',
    dpi=400,
)