In [None]:
%load_ext autoreload
%autoreload 2

import bnl
import frameless_eval as fle
import mir_eval
import random
import plotly.express as px
import pandas as pd

import matplotlib.pyplot as plt
from dataclasses import replace

def px_metrics_bar(metrics, title="mir_eval hierarchy metrics", range_y=(0, 1.2), width=450, height=300):
    """Render a bar chart of named scalar scores with plotly express.

    Parameters
    ----------
    metrics : Mapping[str, float]
        Metric name -> score (e.g. the dict returned by
        ``mir_eval.hierarchy.evaluate``). Values are coerced to float.
    title : str
        Figure title.
    range_y : sequence of two floats
        y-axis limits. The default upper bound of 1.2 presumably leaves head
        room for the value labels drawn outside the bars.
    width, height : int
        Figure size in pixels.

    Returns
    -------
    plotly figure
        The bar chart; call ``.show()`` on it to display.
    """
    # One row per metric: columns "metric" (the key) and "score" (the value).
    scores = (
        pd.Series(metrics, dtype=float)
        .rename_axis("metric")
        .reset_index(name="score")
    )
    fig = px.bar(scores, x="metric", y="score", title=title, text="score", range_y=list(range_y))
    # Print each bar's value with 3 decimals just above the bar.
    fig.update_traces(texttemplate="%{text:.3f}", textposition="outside")
    fig.update_layout(
        xaxis_tickangle=-45,
        width=width,
        height=height,
        margin=dict(l=10, r=10, t=50, b=10),
    )
    return fig

In [None]:
# Open the SALAMI collection from its metadata manifest (user-local path).
slm_ds = bnl.data.Dataset(
    manifest_path="~/data/salami/metadata.csv",
)

In [None]:
# Example track / estimate to inspect in detail; change these to look at others.
TRACK_IDX = 11           # index into slm_ds
EST_NAME = 'mu1gamma9'   # key into track.ests

track = slm_ds[TRACK_IDX]
# Use the first reference annotation listed in track.refs.
ref = list(track.refs.values())[0]
# Align the chosen estimate to that reference before comparing them.
est = track.ests[EST_NAME].align(ref)

ref.plot().show()
est.plot().show()

In [None]:
# One boundary contour per `contour` mode. Note the 'prob' contour is kept
# under the name "weight" throughout this notebook.
est_contour_depth, est_contour_count, est_contour_weight = (
    est.contour(mode) for mode in ('depth', 'count', 'prob')
)

In [None]:
# Bandwidth passed to level('mean_shift', ...). It is shared by all three
# contours so their leveled hierarchies are produced with the same setting.
MEAN_SHIFT_BW = 0.15

est_by_depth = est_contour_depth.level('mean_shift', bw=MEAN_SHIFT_BW).to_ms(name='depth').prune_layers().scrub_labels()
est_by_count = est_contour_count.level('mean_shift', bw=MEAN_SHIFT_BW).to_ms(name='count').prune_layers().scrub_labels()
est_by_weight = est_contour_weight.level('mean_shift', bw=MEAN_SHIFT_BW).to_ms(name='weight').prune_layers().scrub_labels()

In [None]:
# "Tall" variants: the same pipeline, but level() is called with its default
# settings instead of mean-shift leveling.
est_by_depth_tall = (
    est_contour_depth.level()
    .to_ms(name='depth')
    .prune_layers()
    .scrub_labels()
)
est_by_count_tall = (
    est_contour_count.level()
    .to_ms(name='count')
    .prune_layers()
    .scrub_labels()
)
est_by_weight_tall = (
    est_contour_weight.level()
    .to_ms(name='weight')
    .prune_layers()
    .scrub_labels()
)

In [None]:
# Show the three mean-shift-leveled estimates at a fixed 450px width.
for leveled in (est_by_depth, est_by_count, est_by_weight):
    leveled.plot().update_layout(width=450).show()

In [None]:
import warnings

print("adjusted levels")
for e, title in zip(
    [est, est_by_depth, est_by_count, est_by_weight],
    ['raw', 'depth', 'count', 'weight'],
):
    # Silence warnings only while scoring. A global filterwarnings('ignore')
    # would also hide warnings raised by every later cell in the session.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        hier_score = mir_eval.hierarchy.evaluate(ref.itvls, ref.labels, e.itvls, e.labels)
    # Drop the L-measure family; plot the remaining scores.
    hier_score = {
        name: value
        for name, value in hier_score.items()
        if name not in ('L-Precision', 'L-Recall', 'L-Measure')
    }
    px_metrics_bar(hier_score, title=title).show()

In [None]:
import warnings

print('levels unadjusted, tall hierarchies')
for e, title in zip(
    [est, est_by_depth_tall, est_by_count_tall, est_by_weight_tall],
    ['raw', 'depth_tall', 'count_tall', 'weight_tall'],
):
    # Silence warnings only while scoring. A global filterwarnings('ignore')
    # would also hide warnings raised by every later cell in the session.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        hier_score = mir_eval.hierarchy.evaluate(ref.itvls, ref.labels, e.itvls, e.labels)
    # Drop the L-measure family; plot the remaining scores.
    hier_score = {
        name: value
        for name, value in hier_score.items()
        if name not in ('L-Precision', 'L-Recall', 'L-Measure')
    }
    px_metrics_bar(hier_score, title=title).show()

In [None]:
# Does boundary cleaning visibly change anything? Depth contour: raw, then
# cleaned with 'absorb' (window=2) and with 'kde' (bw=1).
for method, kwargs in [(None, {}), ('absorb', {'window': 2}), ('kde', {'bw': 1})]:
    view = est_contour_depth if method is None else est_contour_depth.clean(method, **kwargs)
    view.plot().show()

In [None]:
# Count contour: raw, then cleaned with 'absorb' (window=2) and 'kde' (bw=1).
for method, kwargs in [(None, {}), ('absorb', {'window': 2}), ('kde', {'bw': 1})]:
    view = est_contour_count if method is None else est_contour_count.clean(method, **kwargs)
    view.plot().show()

In [None]:
# Weight contour: raw, then cleaned with 'absorb' and 'kde'.
# NOTE(review): unlike the depth/count cells, no window/bw is passed here, so
# clean() uses its own defaults — confirm this difference is intentional.
for method, kwargs in [(None, {}), ('absorb', {}), ('kde', {})]:
    view = est_contour_weight if method is None else est_contour_weight.clean(method, **kwargs)
    view.plot().show()

In [None]:
# Weight contour: default leveling vs. mean-shift leveling with bw=0.1.
for args, kwargs in [((), {}), (('mean_shift',), {'bw': 0.1})]:
    est_contour_weight.level(*args, **kwargs).plot().show()

In [None]:
# Count contour: default leveling vs. mean-shift leveling with bw=0.1.
for args, kwargs in [((), {}), (('mean_shift',), {'bw': 0.1})]:
    est_contour_count.level(*args, **kwargs).plot().show()

In [None]:
# Depth contour: default leveling vs. mean-shift leveling with bw=0.1.
# NOTE(review): this cell was a verbatim copy of the weight-leveling cell,
# leaving the depth contour unexamined; switched to est_contour_depth to match
# the depth/count/weight pattern of the cleaning cells. Revert if the repeat
# was intentional.
est_contour_depth.level().plot().show()
est_contour_depth.level('mean_shift', bw=0.1).plot().show()

## Effect of the mono-casting pipeline on T-measure across all SALAMI (SLM) tracks

In [None]:
import bnl
from bnl.exp import test_mono_casting_effects as tmce
from pqdm.processes import pqdm

# Re-open the dataset so this section does not depend on the exploration
# cells above, then run the mono-casting experiment on every track using 8
# worker processes. With pqdm's default exception handling, a track that
# raises yields the exception object in list_of_dfs instead of stopping the
# run (hence the DataFrame type filter further down).
slm_ds = bnl.data.Dataset(manifest_path="~/data/salami/metadata.csv")
list_of_dfs = pqdm(slm_ds, tmce, n_jobs=8)

In [None]:
# Tally result types over all tracks instead of spot-checking one arbitrary
# element: DataFrame entries are successes; anything else (e.g. an exception
# object returned by pqdm) marks a failed track.
pd.Series([type(result).__name__ for result in list_of_dfs]).value_counts()


In [None]:
# Keep only the tracks whose experiment returned a results DataFrame.
valid_dfs = list(filter(lambda result: type(result) is pd.DataFrame, list_of_dfs))

In [None]:
# (successful tracks, tracks attempted) — the bare success count alone does not
# show how many tracks failed.
len(valid_dfs), len(list_of_dfs)

In [None]:
import pandas as pd

# pd.concat([]) raises an opaque "No objects to concatenate" ValueError; fail
# with a message that points at the real cause when every track failed.
if not valid_dfs:
    raise ValueError("no track produced a results DataFrame; inspect list_of_dfs for the errors")
# Stack the per-track frames into one table with a fresh 0..N-1 index.
all_results_df = pd.concat(valid_dfs, ignore_index=True)

In [None]:
# Preview the combined results rather than rendering the whole table.
all_results_df.head()

In [None]:
# NOTE(review): this plots the single example estimate from the top of the
# notebook, not the batch results above — consider moving or removing it.
# It shows the last entry of `est` (presumably its bottom layer).
last_layer = est[-1]
last_layer.plot()