In [2]:
from collections import defaultdict
from joblib import Parallel, delayed
import pandas as pd
from tqdm import tqdm_notebook as tqdm
import os

In [3]:
import librosa
import numpy as np

In [4]:
from IPython.display import display

In [5]:
import jams
# Report the installed JAMS version so the outputs in this notebook can be tied to it.
print(jams.__version__)

0.2.1


In [6]:
# Boundary tolerances, both expressed in seconds.
#   TOLERANCE (100 ms): span_segment snaps a segment edge onto the track start/end
#       instead of inserting a filler segment when the edge lies this close.
#   ALIGN_TOLERANCE (3 s): the largest shift align_annotations accepts when moving
#       an upper-level segment start onto a lower-level boundary.
TOLERANCE, ALIGN_TOLERANCE = 0.1, 3.0

In [7]:
def merge_annotations(upper=None, lower=None):
    """Combine an upper- and a lower-level segment annotation into one 'multi_segment' annotation.

    The result reuses ``upper``'s annotation metadata and receives the sandbox
    entries of ``upper`` and then ``lower`` (later entries overwrite earlier ones
    on key clashes). Each observation keeps its time, duration and confidence;
    its value is wrapped as ``{'label': <original value>, 'level': 0}`` for rows
    from ``upper`` and ``level: 1`` for rows from ``lower``. Upper rows come
    first, followed by lower rows.

    Parameters
    ----------
    upper, lower : jams.Annotation
        Source annotations. Their data frames are copied, not modified.

    Returns
    -------
    jams.Annotation
        A new annotation in the 'multi_segment' namespace.
    """
    merged = jams.Annotation(namespace='multi_segment',
                             annotation_metadata=upper.annotation_metadata)

    # Best-effort: a source without a readable sandbox attribute is skipped.
    for source in (upper, lower):
        try:
            merged.sandbox.update(**source.sandbox)
        except AttributeError:
            pass

    # Tag every label with its hierarchy level (0 = upper, 1 = lower).
    tagged = []
    for level, source in enumerate((upper, lower)):
        rows = source.data.copy()
        rows['value'] = [{'label': label, 'level': level} for label in rows['value']]
        tagged.append(rows)

    merged.data = jams.JamsFrame.from_dataframe(pd.concat(tagged, ignore_index=True))
    return merged
    

In [8]:
def span_segment(ann=None, duration=None):
    """Make a SALAMI segment annotation cover exactly ``[0, duration]`` seconds.

    Only 'segment_salami_upper' and 'segment_salami_lower' annotations are
    processed; any other namespace is left untouched.

    Steps:
      1. Reject the annotation if any segment starts more than 1 s past
         ``duration``.
      2. Drop segments that start after ``duration``.
      3. Start of track: if the earliest segment begins within ``TOLERANCE``
         seconds of 0 (or is labelled "silence"), stretch it back to 0 while
         keeping its end time fixed; otherwise insert a filler segment
         ('YYYYY' upper / 'yyyyy' lower, confidence 0) covering the gap.
      4. End of track: if the latest-ending segment ends within ``TOLERANCE``
         of ``duration`` (or is "silence"), stretch or trim it to end exactly
         at ``duration``; otherwise append a filler segment ('ZZZZZ' /
         'zzzzz', confidence 0).
      5. Sort rows by start time and store the result in ``ann.data``.

    Parameters
    ----------
    ann : jams.Annotation
        Annotation to repair; its ``data`` is replaced.
    duration : float
        Track duration in seconds.

    Raises
    ------
    RuntimeError
        If an observation starts more than 1 s beyond ``duration``, or if no
        observation starts within ``[0, duration]``.
    """
    if ann.namespace == 'segment_salami_upper':
        pre, post = 'YYYYY', 'ZZZZZ'
    elif ann.namespace == 'segment_salami_lower':
        pre, post = 'yyyyy', 'zzzzz'
    else:
        return

    frame = ann.data

    # Give a one-second buffer here
    if (frame['time'] > pd.Timedelta(duration + 1.0, unit='s')).any():
        raise RuntimeError('track length exceeded in observation: {0:3g}/{1:3g}'.format(
                frame['time'].max().total_seconds(),
                duration))

    # Drop any annotations where time > duration. The index is reset so row
    # labels are 0..n-1 again.
    # NOTE(review): jams' add_observation appears to append at label
    # len(frame); a gap left by the filter could otherwise make it overwrite an
    # existing row.
    frame = jams.JamsFrame(frame[frame['time'] <= pd.Timedelta(duration, unit='s')]
                           .reset_index(drop=True))

    if frame.empty:
        raise RuntimeError('no observations within track duration: {0:3g}'.format(duration))

    # idxmin/idxmax return row labels, which is what .loc and ['col'][idx]
    # expect below (Series.argmin/argmax return positions in newer pandas).
    idx = frame['time'].idxmin()
    min_time = frame['time'][idx]

    if pd.Timedelta(0) < min_time:
        if min_time <= pd.Timedelta(TOLERANCE, unit='s') or frame['value'][idx].lower() == 'silence':
            # Stretch the first segment back to zero. Its duration grows by the
            # same amount, so its end (the boundary with the next segment) does
            # not move and no gap is opened.
            frame.loc[idx, 'duration'] = frame['duration'][idx] + min_time
            frame.loc[idx, 'time'] = pd.Timedelta(0)
        else:
            # Too far from zero: pad a filler segment covering [0, min_time)
            frame.add_observation(time=0,
                                  duration=min_time.total_seconds(),
                                  value=pre, confidence=0)

    ends = frame['time'] + frame['duration']
    idx = ends.idxmax()
    max_time = ends[idx]

    if max_time != pd.Timedelta(duration, unit='s'):
        if pd.Timedelta(duration - TOLERANCE, unit='s') <= max_time or frame['value'][idx].lower() == 'silence':
            # Stretch (or trim) the last-ending segment to finish exactly at duration
            frame.loc[idx, 'duration'] = pd.Timedelta(duration - frame['time'][idx].total_seconds(), unit='s')
        else:
            # Too far from the end: pad a filler segment covering [max_time, duration)
            frame.add_observation(time=max_time.total_seconds(),
                                  duration=duration - max_time.total_seconds(),
                                  value=post, confidence=0)

    # Sort in place so the frame keeps its JamsFrame type
    frame.sort_values('time', inplace=True)

    ann.data = frame
    
def align_annotations(upper=None, lower=None):
    """Snap upper-level segment boundaries onto the lower-level boundaries.

    Every start and end time of ``upper`` is replaced by the matching time among
    all distinct start/end times of ``lower``, as chosen by
    ``librosa.util.match_events`` (nearest-event matching). Segments whose
    snapped end is not after their snapped start are dropped. ``upper.data`` is
    replaced; ``lower`` is only read.

    Parameters
    ----------
    upper : jams.Annotation
        Coarse segment annotation; its ``data`` is rewritten.
    lower : jams.Annotation
        Fine segment annotation that supplies the candidate boundaries.

    Raises
    ------
    AssertionError
        If any upper-level start would move by more than ``ALIGN_TOLERANCE``
        seconds. The exception argument is the largest such shift. The error
        is raised explicitly rather than with ``assert``, so the check is not
        stripped under ``python -O``.
    """
    i_upper, _ = upper.data.to_interval_values()
    i_lower, _ = lower.data.to_interval_values()

    starts, ends = i_upper[:, 0], i_upper[:, 1]

    # Candidate boundaries: every distinct start/end time in the lower level
    t_times = np.unique(np.ravel(i_lower))
    starts_adj = t_times[librosa.util.match_events(starts, t_times)]
    ends_adj = t_times[librosa.util.match_events(ends, t_times)]

    # Every upper-level start must have a lower-level boundary within ALIGN_TOLERANCE
    deviation = np.abs(starts_adj - starts)
    if not np.all(deviation <= ALIGN_TOLERANCE):
        raise AssertionError(np.max(deviation))

    new_df = jams.JamsFrame()

    # Only add intervals with positive duration
    for s, t, l, c in zip(starts_adj, ends_adj, upper.data.value, upper.data.confidence):
        if t > s:
            new_df.add_observation(time=s, duration=t - s, value=l, confidence=c)

    upper.data = new_df

In [9]:
def fix_jams(jamsfile):
    """Load a JAMS file, repair its SALAMI segment annotations, and add merged annotations.

    Annotations are grouped by annotator name. Each upper- and lower-level
    SALAMI segment annotation is made to span the whole track with
    ``span_segment``. For every annotator that ends up with both levels, the
    upper level is snapped onto the lower-level boundaries with
    ``align_annotations``. A combined 'multi_segment' annotation from
    ``merge_annotations`` is then appended to the JAMS object.

    Failures are reported with ``print`` and that annotator/level is skipped:
    a ``RuntimeError`` from ``span_segment``, a missing level, or an
    ``AssertionError`` from ``align_annotations``.

    Parameters
    ----------
    jamsfile : str
        Path of the JAMS file to load.

    Returns
    -------
    jams.JAMS
        The modified object. Nothing is written to disk here.
    """
    J = jams.load(jamsfile)
    duration = J.file_metadata.duration

    # annotator name -> {'upper': Annotation, 'lower': Annotation}
    ann_dict = defaultdict(dict)

    # All upper-level annotations are processed before the lower-level ones.
    for level, namespace in (('upper', 'segment_salami_upper'),
                             ('lower', 'segment_salami_lower')):
        for ann in J.search(namespace=namespace):
            name = ann.annotation_metadata.annotator.name
            try:
                span_segment(ann, duration=duration)
                ann_dict[name][level] = ann
            except RuntimeError as exc:
                print(exc, jamsfile, name, level + '\n')

    for key in ann_dict:
        if len(ann_dict[key]) != 2:
            print('Align and merge failed: {}/{}'.format(jamsfile, key))
            continue
        try:
            align_annotations(**ann_dict[key])
            J.annotations.append(merge_annotations(**ann_dict[key]))
        except AssertionError as exc:
            print(exc, jamsfile, key)

    return J

In [10]:
# Directory holding the SALAMI .jams files to repair. Each file found here is
# loaded, fixed, and saved back to the same path by the following cell.
OUT_DIR = 'out_dir/'

files = jams.util.find_with_extension(OUT_DIR, 'jams')

In [11]:
# Repair every file found above and save the result back over the original path.
# NOTE: fix_jams appends a merged 'multi_segment' annotation whenever alignment
# succeeds, without checking for an existing one. Re-running this cell on files
# that were already repaired can therefore add duplicate merged annotations, so
# regenerate the input files before re-running.
for jfn in tqdm(files):
    J = fix_jams(jfn)
    J.save(jfn)

(AssertionError(39.009523999999999,), 'out_dir/SALAMI_108.jams', u'Shuli Tang')
(RuntimeError('track length exceeded in observation: 288.392/262.322',), 'out_dir/SALAMI_114.jams', u'Eleni Vasilia Maltas', 'upper\n')
(RuntimeError('track length exceeded in observation: 289.948/262.322',), 'out_dir/SALAMI_114.jams', u'John Turner', 'upper\n')
(RuntimeError('track length exceeded in observation: 294.104/262.322',), 'out_dir/SALAMI_114.jams', u'Evan S. Johnson', 'upper\n')
(RuntimeError('track length exceeded in observation: 401.914/262.322',), 'out_dir/SALAMI_114.jams', u'Shuli Tang', 'upper\n')
(RuntimeError('track length exceeded in observation: 273.183/262.322',), 'out_dir/SALAMI_114.jams', u'Colin Hua', 'lower\n')
(RuntimeError('track length exceeded in observation: 288.392/262.322',), 'out_dir/SALAMI_114.jams', u'Eleni Vasilia Maltas', 'lower\n')
(RuntimeError('track length exceeded in observation: 289.948/262.322',), 'out_dir/SALAMI_114.jams', u'John Turner', 'lower\n')
(RuntimeErro