In [4]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to imported project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import torch
from pickle_cache import PickleCache
from pyannote.core import Segment
from IPython.display import display, clear_output
from ipywidgets import Button
from collections import defaultdict
from random import sample
from intervaltree import IntervalTree
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from voice_feedback.audio import AudioFile
from voice_feedback.utils import two_button

from rekall import Interval, IntervalSet, Bounds3D, IntervalSetMapping
from vgrid import VGridSpec, VideoMetadata, FlatFormat, LabelState
from vgrid_jupyter import VGridWidget

In [6]:
# Directory (relative to this notebook) that holds the input recordings.
DATA_PATH = '../data'
# Cache object used below via pcache.get(key, fn) / pcache.set(key, value).
# NOTE(review): presumably persists pickled results across kernel restarts —
# confirm against the pickle_cache package.
pcache = PickleCache()

In [7]:
# Recording to diarize; lives under DATA_PATH.
audio_path = f'{DATA_PATH}/zoom.wav'


def diarize():
    """Run the pyannote 'dia_ami' pipeline on ``audio_path``.

    The pipeline is fetched through ``torch.hub.load`` from the
    'pyannote/pyannote-audio' repository and called with a dict carrying a
    'uri' and the 'audio' file path. Returns whatever the pipeline returns.
    """
    dia_pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia_ami')
    request = {'uri': 'filename', 'audio': audio_path}
    return dia_pipeline(request)


# Go through the cache so the pipeline is (presumably) only run when no
# 'diarization' entry exists yet — confirm against PickleCache.get.
diarization = pcache.get('diarization', diarize)

In [8]:
# Handle on the recording; used below through audio.interval(start, end),
# audio.display(...) and audio.write(path, ...).
audio = AudioFile(audio_path)

In [9]:
# Group the diarized speech turns by speaker label, keeping only turns whose
# label is a plain str and whose duration is longer than two seconds.
segments = {}
for seg, track, label in diarization.itertracks(yield_label=True):
    if not (type(label) is str and seg.duration > 2.):
        continue
    segments.setdefault(label, []).append(seg)

In [10]:
# One-shot iterator over (speaker label, segments) pairs; next_item() advances it.
split_segs = iter(segments.items())
# speaker label -> True ("Me") / False ("Not me"), filled in by the button callbacks.
seg_labels = {}

def next_item():
    """Clear the cell output and show the labelling UI for the next speaker.

    Pulls the next ``(label, segments)`` pair from ``split_segs`` and hands it
    to ``show``. When the iterator is exhausted the function returns quietly,
    leaving the output cleared.

    Only iterator exhaustion is treated as "done": any exception raised by
    ``show`` propagates instead of being swallowed, so a failure while
    rendering a speaker is visible rather than silently ending the session.
    """
    clear_output()
    item = next(split_segs, None)
    if item is None:
        # Every speaker has been presented; nothing left to show.
        return
    show(*item)

def show(k, segs):
    """Play a few random clips of speaker ``k`` and ask whether it is "me".

    Args:
        k: Speaker label; used as the key written into ``seg_labels``.
        segs: Sequence of segments for that speaker. Each must expose
            ``start`` and ``end``, which are passed to ``audio.interval``.

    Displays up to five randomly chosen clips followed by a "Me" and a
    "Not me" button. Clicking either stores True/False in ``seg_labels[k]``
    and advances to the next speaker via ``next_item``.
    """
    # random.sample raises ValueError when k exceeds the population size, so
    # cap the sample size for speakers that have fewer than five segments.
    rand_segs = sample(segs, k=min(5, len(segs)))
    for seg in rand_segs:
        display(audio.display(audio.interval(seg.start, seg.end)))

    me = Button(description='Me')
    not_me = Button(description='Not me')

    def on_click_me(_):
        seg_labels[k] = True
        next_item()

    def on_click_not_me(_):
        seg_labels[k] = False
        next_item()

    me.on_click(on_click_me)
    not_me.on_click(on_click_not_me)

    display(me)
    display(not_me)
        
# Kick off the labelling loop with the first speaker.
next_item()

Button(description='Me', style=ButtonStyle())

Button(description='Not me', style=ButtonStyle())

In [11]:
# Route the speaker labels through the cache. The lambda returns the in-memory
# dict filled by the button clicks above.
# NOTE(review): presumably pcache.get returns a stored 'seg_labels' entry when
# one exists and otherwise calls the lambda and stores its result — confirm.
# If so, run this only after labelling is finished, or a partial dict is cached.
seg_labels = pcache.get('seg_labels', lambda: seg_labels)

In [13]:
# Collect every segment belonging to a speaker labelled "me", ordered by
# start time, and store the list in the cache under 'me_segs'.
me_segs = [
    seg
    for k, me in seg_labels.items()
    for seg in segments[k]
    if me
]
me_segs.sort(key=lambda s: s.start)
pcache.set('me_segs', me_segs)

In [14]:
# Build a VGrid widget showing the first ten "me" segments so that bad-sounding
# spans can be marked by hand in the next cell.
video_id = 1
# The metadata object is constructed from the local file and its path is then
# overwritten with a bare filename.
# NOTE(review): presumably the constructor needs the local file to read the
# metadata, while the overwritten path is resolved against video_endpoint
# below — confirm against vgrid.
# NOTE(review): '../data' is hard-coded here rather than taken from DATA_PATH.
video = VideoMetadata(path='../data/zoom_0.mp4', id=video_id)
video.path = 'zoom_0.mp4'

# Only the first ten segments are sent to the widget.
segs_to_label = me_segs[:10]
intervals = IntervalSet([Interval(Bounds3D(seg.start, seg.end)) for seg in segs_to_label])
interval_map = IntervalSetMapping({video_id: intervals})

# NOTE(review): the video endpoint is a hard-coded host; the widget will not
# load video when that server is unreachable.
vgrid_spec = VGridSpec(
  video_meta=[video],
  vis_format=FlatFormat(interval_map),
  video_endpoint='http://charlotte.stanford.edu:8887',
  auto_zoom=True)

vgrid_widget = VGridWidget(vgrid_spec=vgrid_spec.to_json())
# Must stay the last expression so the notebook renders the widget.
vgrid_widget

VGridWidget(vgrid_spec={'interval_blocks': [{'interval_sets': [{'name': 'default', 'interval_set': [{'bounds':…

In [None]:
# Read back the intervals that were added in the VGrid widget and index them,
# sorted by start time, in an interval tree of "bad" spans.
label_state = LabelState(lambda: vgrid_widget.label_state)
new_intvls = [
    (intvl.bounds['t1'], intvl.bounds['t2'])
    for block_labels in label_state.block_labels().values()
    for intvl in block_labels.new_intervals().get_intervals()
]
bad_sounds = list(new_intvls)
bad_sounds.sort(key=lambda i: i[0])
bad_tree = IntervalTree.from_tuples((start, end, None) for (start, end) in bad_sounds)

In [None]:
# Start from all "me" segments, then cut every hand-marked bad span out of
# them; what remains is treated as the "good" audio.
good_tree = IntervalTree.from_tuples((seg.start, seg.end) for seg in me_segs)
for intvl in bad_tree:
    bad_begin, bad_end = intvl.begin, intvl.end
    good_tree.chop(bad_begin, bad_end)

In [None]:
def tree_to_wavs(tree, out_dir, dur):
    """Cut every interval in ``tree`` into ``dur``-long clips and write them as WAVs.

    Args:
        tree: Iterable of intervals exposing ``begin`` and ``end``.
        out_dir: Output directory; created (with parents) if missing.
        dur: Clip length, in the same units as the interval bounds.

    For each interval, clips start at ``begin``, ``begin + dur``, ... A clip is
    written when it ends strictly before the interval's end; the first clip of
    an interval is always written, even when the interval is shorter than
    ``dur`` (in which case the clip extends past the interval's end).

    Files are named ``{interval index:05d}_{clip index:03d}.wav``. The clip
    index is part of the name so that several clips taken from one interval do
    not overwrite each other.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    for i, intvl in tqdm(list(enumerate(tree))):
        (start, end) = (intvl.begin, intvl.end)
        for j, start2 in enumerate(np.arange(start, end, dur)):
            end2 = start2 + dur
            # j == 0 is the clip that starts at the interval's own start.
            if end2 < end or j == 0:
                audio.write(f'{out_dir}/{i:05d}_{j:03d}.wav', audio.interval(start2, end2))

# Export two-second clips for both classes: the hand-marked bad spans and the
# remaining "me" audio.
# NOTE(review): the output directories are hard-coded under '../data' rather
# than derived from DATA_PATH.
tree_to_wavs(bad_tree, '../data/stutter/bad', dur=2.)
tree_to_wavs(good_tree, '../data/stutter/good', dur=2.)