In [20]:
import os, sys, glob
import pandas as pd
import subprocess

def run_command(cmd, cwd=None):
    """Utility function to run a shell command."""
    return subprocess.check_output(cmd, shell=True, cwd=cwd).decode().strip()

def copy_contents(in_dir, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    if not os.path.exists(in_dir):
        raise ValueError(f'Input directory {in_dir} does not exist.')
    
    run_command(f'cp -r {in_dir}/* {out_dir}')


In [21]:
from typing import Dict, Optional
from pathlib import Path
import json

def read_index(path: str) -> Dict:
    """Reads the data index, removing the initial stages of nesting."""
    path = Path(path)
    kind = os.path.splitext(path.name)[0]
    data = json.loads(path.read_text())
    return data["protocols"][kind]["subjects"]

def _index_dict_to_dataframe(data: Dict) -> pd.DataFrame:
    subjects = data.keys()
    entries = []

    for subject in subjects:
        experiments = data[subject]["experiments"]
        for experiment in experiments:
            sessions = experiments[experiment]["sessions"]
            for session in sessions:
                entry = sessions[session]
                entry["subject"] = subject
                entry["experiment"] = experiment
                entry["session"] = int(session)
                entries.append(entry)

    df = pd.DataFrame(entries)
    return df


In [22]:
# load session index directly
exps = ['ltpFR2']
data = list()
data.append(_index_dict_to_dataframe(read_index('/protocols/ltp.json')))
index_df = pd.concat(data, sort=True)

index_df

# if sys.version.split(' ')[0] >= '3.10' and sys.version.split(' ')[0] < '3.11':
#     index_df = pd.read_csv(f'session_index_{"-".join(exps)}.csv')
# else:
#     import cmlreaders as cml

#     index_df = cml.get_data_index('ltp')
#     print('Available experiments:\n', index_df.experiment.unique())
#     index_df = index_df.query('experiment in @exps')
#     index_df.to_csv(f'session_index_{"-".join(exps)}.csv')

#     # load example session events
#     sess_df = index_df.iloc[25]
#     r = cml.CMLReader(subject=sess_df.subject, session=sess_df.session, experiment=sess_df.experiment)
#     evs = r.load('events')
#     print('Event types:\n', evs.type.unique())

Unnamed: 0,all_events,experiment,import_type,math_events,original_session,session,subject,subject_alias,task_events
0,protocols/ltp/subjects/LTP063/experiments/ltpF...,ltpFR,build,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,0,LTP063,LTP063,protocols/ltp/subjects/LTP063/experiments/ltpF...
1,protocols/ltp/subjects/LTP063/experiments/ltpF...,ltpFR,build,protocols/ltp/subjects/LTP063/experiments/ltpF...,1,1,LTP063,LTP063,protocols/ltp/subjects/LTP063/experiments/ltpF...
2,protocols/ltp/subjects/LTP063/experiments/ltpF...,ltpFR,build,protocols/ltp/subjects/LTP063/experiments/ltpF...,10,10,LTP063,LTP063,protocols/ltp/subjects/LTP063/experiments/ltpF...
3,protocols/ltp/subjects/LTP063/experiments/ltpF...,ltpFR,build,protocols/ltp/subjects/LTP063/experiments/ltpF...,11,11,LTP063,LTP063,protocols/ltp/subjects/LTP063/experiments/ltpF...
4,protocols/ltp/subjects/LTP063/experiments/ltpF...,ltpFR,build,protocols/ltp/subjects/LTP063/experiments/ltpF...,12,12,LTP063,LTP063,protocols/ltp/subjects/LTP063/experiments/ltpF...
...,...,...,...,...,...,...,...,...,...
7564,protocols/ltp/subjects/PLTP811/experiments/pre...,prelim,build,,0,0,PLTP811,PLTP811,protocols/ltp/subjects/PLTP811/experiments/pre...
7565,protocols/ltp/subjects/PLTP812/experiments/pre...,prelim,build,,0,0,PLTP812,PLTP812,protocols/ltp/subjects/PLTP812/experiments/pre...
7566,protocols/ltp/subjects/PLTP813/experiments/pre...,prelim,build,,0,0,PLTP813,PLTP813,protocols/ltp/subjects/PLTP813/experiments/pre...
7567,protocols/ltp/subjects/PLTP814/experiments/pre...,prelim,build,,0,0,PLTP814,PLTP814,protocols/ltp/subjects/PLTP814/experiments/pre...


In [23]:
def get_folders_containing_files(directory=".", pattern="*"):
    """
    Get folders containing files that match the specified glob pattern.
    
    Parameters:
    - directory (str): The starting directory for the search.
    - pattern (str): The glob pattern to search for. e.g. "*.wav", "audio_*.mp3", "document_?.txt", etc.

    Returns:
    - list: A list of folders containing files that match the glob pattern.
    """

    # Get all the files matching the pattern in the directory and its sub-directories.
    matching_files = glob.glob(os.path.join(directory, '**', pattern), recursive=True)
    
    # Get the unique directories containing the files that match the pattern.
    folders = set(os.path.dirname(file) for file in matching_files)

    return list(folders)


# build train-val-test split
from sklearn.model_selection import GroupShuffleSplit

random_seed = 42
train_prop = 0.5
val_prop = 0.2
test_prop = 0.3
assert train_prop + val_prop + test_prop == 1

index_df['split_group'] = index_df['subject'] + '_' + index_df['experiment']
train_val_idx, test_idx = next(GroupShuffleSplit(train_size=train_prop + val_prop, 
                                                 test_size=test_prop, 
                                                 random_state=random_seed).split(index_df,
                                                                                 groups=index_df['split_group']))
train_idx, val_idx = next(GroupShuffleSplit(train_size=train_prop / (train_prop + val_prop),
                                            test_size=val_prop / (train_prop + val_prop),
                                            random_state=random_seed + 1).split(index_df.iloc[train_val_idx],
                                                                                groups=index_df.iloc[train_val_idx]['split_group']))
index_dfs = {'train': index_df.iloc[train_val_idx].iloc[train_idx],
             'val': index_df.iloc[train_val_idx].iloc[val_idx],
             'test': index_df.iloc[test_idx]}

# assert no overlap across splits
assert not set(index_dfs['train'].split_group).intersection(index_dfs['test'].split_group)
assert not set(index_dfs['train'].split_group).intersection(index_dfs['val'].split_group)
assert not set(index_dfs['val'].split_group).intersection(index_dfs['test'].split_group)

# obtain session data directories
input_dirs = {'train': list(), 'val': list(), 'test': list()}
for exp in exps:
    # get all session directories containing .wav files
    wav_dirs = get_folders_containing_files(pattern="0.wav", directory=f'/data/eeg/scalp/ltp/{exp}')
    # drop sessions marked bad
    wav_dirs = {d for d in wav_dirs if not 'bad' in d.lower()}
    for split in input_dirs:
        # keep only sessions that were processed through event_creation for quality control
        exp_dirs = {f'/data/eeg/scalp/ltp/{sess.experiment}/{sess.subject}/session_{sess.session}' 
                    for _, sess in index_dfs[split].iterrows()}
        exp_dirs = exp_dirs.intersection(wav_dirs)
        input_dirs[split].extend(list(exp_dirs))
        print(split)

for split in input_dirs:
    print(f'{split} sessions available: {len(input_dirs[split])}')
print('Example session directories:')
input_dirs['train'][:5]

train
val
test
train sessions available: 1161
val sessions available: 540
test sessions available: 768
Example session directories:


['/data/eeg/scalp/ltp/ltpFR2/LTP334/session_11',
 '/data/eeg/scalp/ltp/ltpFR2/LTP365/session_19',
 '/data/eeg/scalp/ltp/ltpFR2/LTP269/session_16',
 '/data/eeg/scalp/ltp/ltpFR2/LTP384/session_2',
 '/data/eeg/scalp/ltp/ltpFR2/LTP301/session_8']

In [24]:
# obtain session processing output directories
base_dir = os.getcwd()

tags = ['base-whisperx']
all_output_dirs = dict()
all_input_dirs = dict()
output_dirs = dict()

for tag in tags:
    output_dirs[tag] = dict()
    all_input_dirs[tag] = list()
    all_output_dirs[tag] = list()
    for split in input_dirs:
        output_dirs[tag][split] = [
            os.path.join('results', tag, d) if not os.path.isabs(d) else os.path.join('results', tag, split, d[1:])
            for d in input_dirs[split]
        ]
        all_output_dirs[tag].extend(output_dirs[tag][split])
        all_input_dirs[tag].extend(input_dirs[split])
print('Sessions to process:', len(all_input_dirs[tag]))


# was thinking the easiest method to access ground truth would be to compare .csv automated annotation outputs with .ann files, 
# but then decided it'd be cleaner to just use cmlreaders

# # create data set with ground truth annotations (.ann) and word list files (.lst)

# dataset_path = f'data/{exp}'

# for split in splits:
#     split_path = os.path.join(dataset_path, split)
#     for path in output_dirs[tag][split]:
#         raw_data_path = path.split(f'{tag}/{split}')[-1]
#         cp_files = glob.glob(os.path.join(raw_data_path, '*.ann')) + glob.glob(os.path.join(raw_data_path, '*.lst'))
#         out_path = split_path + (raw_data_path if os.path.isabs(raw_data_path) else '/' + raw_data_path)
#         if not os.path.exists(out_path): os.makedirs(out_path)
#         for full_file in cp_files:
#             file = os.path.split(full_file)[-1]
#             run_command(f'cp {full_file} {os.path.join(out_path, file)}')

Sessions to process: 2469


In [25]:
# save all annotation input/output directories
import pickle
with open('input_dirs.pkl', 'wb') as f:
    pickle.dump(all_input_dirs, f)
with open('output_dirs.pkl', 'wb') as f:
    pickle.dump(all_output_dirs, f)
    
# for convenience also save out output directories broken out by train/val/test splits
with open('input_dirs_splits.pkl', 'wb') as f:
    pickle.dump(input_dirs, f)
with open('output_dirs_splits.pkl', 'wb') as f:
    pickle.dump(output_dirs, f)

In [11]:
# Define the number of sessions you want to test
num_test_sessions = 5  # Adjust as needed

# Select the first 'num_test_sessions' input and output directories
test_input_dirs = all_input_dirs[tag][:num_test_sessions]
test_output_dirs = all_output_dirs[tag][:num_test_sessions]

print('Test sessions selected:', len(test_input_dirs))
print('Example session directories:')
print(test_input_dirs)

# Save test input/output directories
with open('test_input_dirs.pkl', 'wb') as f:
    pickle.dump(test_input_dirs, f)
with open('test_output_dirs.pkl', 'wb') as f:
    pickle.dump(test_output_dirs, f)

print('Test input and output directories have been saved to "test_input_dirs.pkl" and "test_output_dirs.pkl".')

Test sessions selected: 5
Example session directories:
['/data/eeg/scalp/ltp/ltpFR2/LTP341/session_18', '/data/eeg/scalp/ltp/ltpFR2/LTP365/session_2', '/data/eeg/scalp/ltp/ltpFR2/LTP341/session_4', '/data/eeg/scalp/ltp/ltpFR2/LTP389/session_0', '/data/eeg/scalp/ltp/ltpFR2/LTP246/session_22']
Test input and output directories have been saved to "test_input_dirs.pkl" and "test_output_dirs.pkl".


In [23]:
# load annotation input/output directories
import pickle
with open('input_dirs.pkl', 'rb') as f:
    input_dirs = pickle.load(f)
with open('output_dirs.pkl', 'rb') as f:
    output_dirs = pickle.load(f)

In [15]:
input_dirs

{'base-whisperx': ['/data/eeg/scalp/ltp/ltpFR2/LTP341/session_18',
  '/data/eeg/scalp/ltp/ltpFR2/LTP365/session_2',
  '/data/eeg/scalp/ltp/ltpFR2/LTP341/session_4',
  '/data/eeg/scalp/ltp/ltpFR2/LTP389/session_0',
  '/data/eeg/scalp/ltp/ltpFR2/LTP246/session_22',
  '/data/eeg/scalp/ltp/ltpFR2/LTP295/session_5',
  '/data/eeg/scalp/ltp/ltpFR2/LTP260/session_8',
  '/data/eeg/scalp/ltp/ltpFR2/LTP385/session_11',
  '/data/eeg/scalp/ltp/ltpFR2/LTP307/session_7',
  '/data/eeg/scalp/ltp/ltpFR2/LTP362/session_2',
  '/data/eeg/scalp/ltp/ltpFR2/LTP269/session_11',
  '/data/eeg/scalp/ltp/ltpFR2/LTP355/session_10',
  '/data/eeg/scalp/ltp/ltpFR2/LTP379/session_14',
  '/data/eeg/scalp/ltp/ltpFR2/LTP133/session_16',
  '/data/eeg/scalp/ltp/ltpFR2/LTP258/session_9',
  '/data/eeg/scalp/ltp/ltpFR2/LTP386/session_11',
  '/data/eeg/scalp/ltp/ltpFR2/LTP373/session_6',
  '/data/eeg/scalp/ltp/ltpFR2/LTP341/session_1',
  '/data/eeg/scalp/ltp/ltpFR2/LTP354/session_1',
  '/data/eeg/scalp/ltp/ltpFR2/LTP336/session

In [9]:
# # load original input/output directories that didn't contain subdirectories for splits ('{tag}/{split}') in paths
# import pickle
# with open('input_dirs_no_split.pkl', 'rb') as f:
#     old_all_input_dirs = pickle.load(f)
# with open('output_dirs_no_split.pkl', 'rb') as f:
#     old_all_output_dirs = pickle.load(f)


In [9]:
# # transfer output directories from original structure to structure including split subdirectories
# new_outs_no_split = set([d.replace('train/', '').replace('val/', '').replace('test/', '') for d in all_output_dirs[tag]])#[:10]#[0].split('data/')[-1]
# old_outs = set(old_all_output_dirs[tag])
# assert new_outs == old_outs

# for old_dir, new_dir in zip(old_all_output_dirs[tag], all_output_dirs[tag]):
#     try:
#         copy_contents(old_dir, new_dir)
#     except ValueError as e:
#         print(e)

## Debugging session mismatch issue

In [9]:
tag = 'base-whisperx'
splits = ['train', 'val', 'test']
for inp, out in zip(input_dirs[tag], output_dirs[tag]):
    # print(inp)
    out_split = None
    for split in splits:
        if split in out: out_split = split
    if isinstance(out_split, type(None)): raise ValueError
    out_strip = out.split(f'{tag}/{out_split}')[-1]
    if inp != out_strip:
        print(inp)
        print(out_strip)
        print()

ValueError: 

In [None]:
tag = 'base-whisperx'
for inp, out in zip(old_all_input_dirs[tag], old_all_output_dirs[tag]):
    # print(inp)
    out_strip = out.split(f'{tag}')[-1]
    assert inp == out_strip
    # print()

In [None]:
for i, d in enumerate(input_dirs[tag]):
    if 'LTP123' in d and 'session_5' in d:
        print(i)

In [10]:
old_all_input_dirs[tag][387]

NameError: name 'old_all_input_dirs' is not defined

In [11]:
for i, d in enumerate(old_all_input_dirs[tag]):
    if 'LTP123' in d and 'session_5' in d:
        print(i)

NameError: name 'old_all_input_dirs' is not defined

In [None]:
input_dirs[tag][133]

In [None]:
# load debug trial (LTP123 ltpFR2 session 5 trial 1)
exp = 'ltpFR2'
import cmlreaders as cml  # (can't install cmlreaders for py3.10)

index_df = cml.get_data_index('ltp')
print('Available experiments:\n', index_df.experiment.unique())
index_df = index_df.query('experiment == @exp')

# load example session events
sess_df = index_df.query('subject == "LTP123" and experiment==@exp and session==5').iloc[0]
r = cml.CMLReader(subject=sess_df.subject, session=sess_df.session, experiment=sess_df.experiment)
evs = r.load('events')
print('Event types:\n', evs.type.unique())


In [None]:
evs.columns

In [None]:
rec_evs = evs.query('type == "REC_WORD" and trial == 1')[['subject', 'experiment', 'session', 'trial', 'item_name']]
rec_evs

# Run model

In [10]:
from automated_annot import run_whisperx
run_whisperx(all_input_dirs[tag][0], all_output_dirs[tag][0])

  from .autonotebook import tqdm as notebook_tqdm
Downloading model.bin: 100%|██████████| 3.09G/3.09G [00:40<00:00, 76.6MB/s]


No language specified, language will be first be detected for each audio file (increases inference time).


TypeError: TranscriptionOptions.__new__() missing 4 required positional arguments: 'max_new_tokens', 'clip_timestamps', 'hallucination_silence_threshold', and 'hotwords'

In [None]:
from automated_annot import run_whisperx
from cmldask import CMLDask
from dask.distributed import wait

tag = 'base-whisperx'
dask_args = {'job_name': 'auto_annotate', 'memory_per_job': "9GB", 'max_n_jobs': 35,
            'death_timeout': 600, 'extra': ['--no-dashboard'], 'log_directory': 'logs'}

client = CMLDask.new_dask_client_slurm(**dask_args)
dask_inputs = [all_input_dirs[tag][:1], all_output_dirs[tag][:1]]
futures = client.map(run_whisperx, *dask_inputs)
wait(futures)

In [None]:
del client

In [None]:
# estimated number of days to run whisperx over data set assuming ~5 minutes per list recording
n_sess = 1200  #len(all_input_dirs[tag])
n_lists = 24
n_cores = 150
min_per_list = 5
n_days = n_sess * n_lists * min_per_list / (n_cores * 24 * 60)
print(f'Estimated days of model inference: {n_days:0.4}')

In [None]:
# get processed sessions and ZIP

tag = 'base-whisperx'
model_out = 'whisperx_out'

splits = ['train', 'val', 'test']
output_paths = dict()

if 'output_dir' not in globals():
    import pickle
    with open('output_dirs_splits.pkl', 'rb') as f:
        output_dirs = pickle.load(f)

dataset_path = os.path.join(f'results/{tag}')
for split in ['*'] + splits:
    split_path = os.path.join(dataset_path, f'{split}')
    path = os.path.join(split_path, f'data/eeg/scalp/ltp/ltpFR2/*/session_*/{model_out}/*.csv'.replace(' ', '').replace("'", ''))
    outputs = glob.glob(path)
    
    sess_path = os.path.join(split_path, f'data/eeg/scalp/ltp/ltpFR2/*/session_*/{model_out}'.replace(' ', '').replace("'", ''))
    sess_outputs = glob.glob(sess_path)
    print('Split:', split.replace('*', 'All'))
    print('\tNumber of processed .wav files:', len(outputs))
    if split == '*':
        print('\tNumber of processed sessions:', len(sess_outputs))
    else:
        print('\tNumber of processed sessions:', len(sess_outputs), '/', len(output_dirs[tag][split]))
    output_paths[split] = outputs

orig_cwd = os.getcwd()
os.chdir(dataset_path)
try:
    for split in splits:
        zip_file = f'{tag}_{split}.zip'
        if os.path.exists(zip_file): run_command(f'rm {zip_file}')
        run_command(f'zip -r {zip_file} {split}/')
except Exception as e:
    os.chdir(orig_cwd)
os.chdir(orig_cwd)


In [None]:

# get word pool and sanity consistency across experiment
import pandas as pd
evs = pd.read_csv('/home1/rdehaan/projects/DataMemoryBrainsSolutions/data/LTPFR2_WORD__REC_WORD_EVENTS.csv')


In [None]:
evs.subject.unique().shape

In [None]:
word_evs = evs.query('type == "WORD"')
word_evs.

In [None]:
evs.type.unique()

In [None]:
sub = "LTP123"
sess = 5
sess_evs = evs.query('subject == @sub and session == @sess and trial == 1 and type in ["REC_WORD"]')
sess_evs.mstime
# reltime from first word, so not from start of recall period since REC_START events not included in compiled event CSV
# correct with offset from automated annotations
sess_evs['reltime'] = sess_evs.mstime - sess_evs.mstime.iloc[0] + 734
sess_evs[['item_name', 'reltime']]

## Voice Activity Detection (VAD) Exploration

In [18]:
def find_first_blank_line(filename):
    with open(filename, 'r') as file:
        for line_number, line in enumerate(file, 1):
            # Using strip() to remove whitespace characters, including spaces and newline
            # If the result is an empty string, then it's a blank line
            if not line.strip():
                return line_number
    return None

# find_first_blank_line(f'dependencies/annotation_gauntlet/session_0_QC/{list_num}.ann')


def load_ann(ann_file):
    ann = pd.read_csv(ann_file, delimiter='\t',
                      names=['onset', 'item_num', 'item_name'], 
                      skiprows=find_first_blank_line(ann_file))
    return ann


def SaveFig(basename):
    d = os.path.split(basename)[0]
    if not os.path.exists(d):
        os.makedirs(d)
    plt.savefig(basename+'.png', bbox_inches='tight')
    plt.savefig(basename+'.pdf', bbox_inches='tight')


In [19]:
# TODO 
# check for last time point of VAD output

import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt
from whisperx.vad import Binarize


def plot_recall_recording(audio, vad_pred, ann, auto_ann, 
                          sr = 16000,  # sample rate in Hz
                          n_secs=15, plot_basename=None):
    # Penn-Total Recall defaults
    lowcut = 1000
    highcut = min(16000, sr / 2 - 1)
    order = 3
    b, a = butter(order, [lowcut, highcut], btype='band', fs=sr)
    audio_filt = filtfilt(b, a, audio)

    t = np.linspace(0, 30, len(audio))[:n_secs * sr]
    vad_out = vad_pred.data
    t_vad = np.linspace(0, 30, len(vad_out))[:int(n_secs / 30 * len(vad_out))]
    vad_plot = vad_out[:int(n_secs / 30 * len(vad_out))]
    plt.figure(figsize=(50, 8))
    audio_plot = audio_filt[:n_secs * sr] * 2.5 * vad_out.std() / audio_filt.max()
    plt.plot(t, audio_plot, label='Audio')
    yplot_min = np.max([audio_plot.min(), -0.7]) - 0.2
    plt.plot(t_vad, vad_plot, label='WhisperX VAD')
    for i, (_, row) in enumerate(ann.iterrows()):
        if row.onset / 1000 > n_secs: continue
        plt.text(x=row.onset / 1000, y=1.05, s=row.item_name, fontsize=30, color='g')
        if i == 0: plt.vlines(x=row.onset / 1000, ymin=yplot_min, ymax=1.0, linestyles='--', color='g', label='True Onset')
        else: plt.vlines(x=row.onset / 1000, ymin=yplot_min, ymax=1.0, linestyles='--', color='g')

    for i, (_, row) in enumerate(auto_ann.iterrows()):
        if row.onset / 1000 > n_secs: continue
        plt.text(x=row.onset / 1000 + 0.03, y=yplot_min+0.1, s=row.item_name, fontsize=30, color='r')
        true_ann = ann.query('item_name == @row.item_name')
        if true_ann.shape[0] > 1:
            # search for nearest match among repeats
            true_ann['diff'] = np.abs(true_ann.onset - row.onset)
            idx = np.argmin(true_ann['diff'])
            true_ann = true_ann.iloc[idx:idx+1]
            
        if true_ann.shape[0] == 1:
            err = row.onset - true_ann.onset.item()
            plt.text(x=row.onset / 1000 + 0.03, y=yplot_min -0.035, s=f'{int(err)}', fontsize=30, color='r')
        if i == 0: plt.vlines(x=row.onset / 1000, ymin=yplot_min, ymax=1.0, linestyles='--', color='r', label='Auto Onset')
        else: plt.vlines(x=row.onset / 1000, ymin=yplot_min, ymax=1.0, linestyles='--', color='r')

    plt.legend(fontsize=25)
    plt.xlabel('Recording Time (s)', fontsize=30)
    plt.xticks(fontsize=25)
    plt.ylabel('Probability of Speech', fontsize=30)
    _ = plt.yticks(fontsize=25)
    plt.ylim([yplot_min - 0.05, 1.0])
    
    vad_params = dict(onset=0.8, offset=0.78)
    # vad_params = dict(onset=0.5, offset=0.363)  # whisperX defaults
    binarize = Binarize(**vad_params)

    segments = binarize(vad_pred)

    segments_list = []
    for speech_turn in segments.get_timeline():
        segments_list.append([speech_turn.start, speech_turn.end])

    data = segments_list
    centers = [(a+b)/2 for a, b in data if b < n_secs]
    errors = [(b-a)/2 for a, b in data if b < n_secs]
    y_positions = [0.75] * len(centers)
    plt.errorbar(centers, y_positions, xerr=errors, fmt='o', linestyle='', color='b', ecolor='b', elinewidth=2.5, capsize=30)
    plt.yticks(y_positions)
    plt.grid(True, axis='x', linestyle='--', alpha=0.7)
    
    if plot_basename:
        SaveFig(plot_basename)

plot_recall_recording(audio, vad_pred, ann, auto_ann)

ModuleNotFoundError: No module named 'whisperx'

In [None]:
import numpy as np
import torch
import pickle
import pdb
from whisperx.vad import load_vad_model
from whisperx.asr import SAMPLE_RATE
import whisperx

use_gpu = False
device = "cuda:0" if use_gpu else "cpu"
vad = load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None)


in_dir = 'test/whisperx_vad_fun/audio'
ann_dir = 'dependencies/annotation_gauntlet/session_0_QC'
# in_dir = 'dependencies/annotation_gauntlet/session_0'
# from gauntlet:
# hard initial phonemes: 11.wav
# mix of hard/soft initial phonemes: 10.wav

# with open('input_dirs.pkl', 'rb') as f:
#     input_dirs = pickle.load(f)
# in_dir = input_dirs['train'][0]

for list_num in range(1, 12): 
    file = in_dir + f'/{list_num}.npy'
    audio = np.load(file)
    # load .ann files, which demarcate end of header with blank line
    ann_file = os.path.join(ann_dir, f'{list_num}.ann')
    ann = load_ann(ann_file)[['item_name', 'onset', 'item_num']]
    
    auto_ann = pd.read_csv(os.path.join(ann_dir, 'whisperx_out', f'{list_num}.csv'))
    auto_ann = auto_ann.rename(columns={'Word': 'item_name', 'Onset': 'onset', 'Offset': 'offset', 'Probability': 'probability'})

    vad_pred = vad({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
    plot_recall_recording(audio, vad_pred, ann, auto_ann, plot_basename=f'test/whisperx_vad_fun/plots/recall_period_list{list_num}')
    
plt.show()

In [14]:
auto_ann

NameError: name 'auto_ann' is not defined

In [15]:
ann

NameError: name 'ann' is not defined

In [None]:
segments_list