# 42-wav-embedding-preprocess
> Generating embeddings for all allowed audio files

In this notebook, we build a dataframe of embeddings for all of the audio files for which we manually provided timestamps.

In [None]:
#all_no_test
#default_exp audio_preprocessing

In [None]:
#export
# modeling packages
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import soundfile as sf
import torch
import librosa
import warnings
import difflib

# data science packages
import pandas as pd
import numpy as np

#saving data format files
import pyarrow as pa
import pyarrow.parquet as pq

# other python packages
import os.path
import glob
import re

# Set preliminaries
Here, we define the base filepath and look at some of the audio files that we have.

In [None]:
# base file path on accre
base_prefix = '/data/p_dsi/wise/data/'
audio_prefix = base_prefix + 'resampled_audio_16khz/'
audio_csv_prefix = '/data/p_dsi/wise/data/test_files/'

sampling_rate = 16000

# get list of resampled audio files
audio_files_list = glob.glob(audio_prefix + '*.wav')
print(audio_files_list[0])
len(audio_files_list)

/data/p_dsi/wise/data/resampled_audio_16khz/134-2.wav


109

In [None]:
#get list of data files
csv_files_list = glob.glob(audio_csv_prefix + '*.csv')
print(csv_files_list[0])
len(csv_files_list)

/data/p_dsi/wise/data/test_files/134-1.csv


11

# Pre-process data
Here, we provide some validation and pre-processing functionality. The timestamp CSV files have blank timestamp cells, non-standard timestamp formats, and other problems. Let's write some functions to handle them.

Note that the following complexities exist in the data:
* Files `120-1` and `251-1` currently have only 1 digit in the milliseconds.
* File `008-1` has malformed timestamps: some have an extra colon-separated segment at the beginning, giving 4 groupings instead of 3 (e.g., `00:09:41.870` instead of `09:41.870`).
* File `123-1` has only 2 digits in the milliseconds.

## Time helper function
Here, we need to express times in milliseconds so that we can do arithmetic on them. The following code expects well-formed timestamps (`mm:ss.fff`, with at least one fractional digit). It is only used after the timestamp pre-processing below, which either corrects timestamps or removes them altogether. Behavior on other timestamp formats is not guaranteed.
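For a well-formed timestamp with minutes $m$, seconds $s$ and fractional-second digits $f$, the conversion is plain arithmetic. Digits past the third fractional digit are truncated, not rounded:

$$
t_{\text{ms}} = 60\,000\,m + 1\,000\,s + \left\lfloor 1\,000 \times 0.f \right\rfloor
$$

For example:
* `09:41.870` gives $60\,000 \cdot 9 + 1\,000 \cdot 41 + 870 = 581\,870$ ms.
* `00:02.25055` gives $2\,000 + \lfloor 250.55 \rfloor = 2\,250$ ms, matching the output of the test cell below.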

In [None]:
#export
def to_millseconds(time):
    '''
    Function to_millseconds: converts a timestamp string of format 'mm:ss.fff' to milliseconds
    Inputs: time: String in required format (one or more fractional-second digits)
    Outputs: integer of converted time in milliseconds; digits past the third fractional digit are truncated
    '''
    
    if not isinstance(time, str):
        raise TypeError('The input datatype of {0} must be a string to use to_milliseconds.'.format(time))
    
    #Timestamp pattern: 2-digit minutes, 2-digit seconds, 1+ fractional digits
    ts_target = re.compile(r'\d{2}:\d{2}\.\d+')
    if ts_target.fullmatch(time) is None:
        raise RuntimeError("The input of {0} does not match the format mm:ss.fff. Fix this before continuing."
                           .format(time))
    
    #get split pieces: minutes, seconds, fractional seconds
    minutes, seconds, frac = re.split(r":|\.", time)
    
    #get milliseconds with integer arithmetic (avoids float rounding; truncates extra digits)
    ms = int(minutes)*60*1000 + int(seconds)*1000 + int(frac[:3].ljust(3, '0'))
    
    return ms

In [None]:
print(to_millseconds("00:02.200"))
print(to_millseconds("00:02.25055"))
try:
    print(to_millseconds(4))
except Exception as e:
    print(e)

2200
2250
The input datatype of 4 must be a string to use to_milliseconds.


Some files contain NaNs or filler characters (such as the dash placeholders in `008-1`) instead of timestamps. Let's write pre-processing functions to handle these. Timestamp pre-processing also exists in a later commit of the repo and can be incorporated later.

In [None]:
#export
def short_warning(message):
    '''
    Function short_warning: shortened version of warnings.warn_explicit to remove unnecessary echo
        Input: message to be printed as warning message
        Output: warning
    '''
    warnings.warn_explicit(message, UserWarning, '', 0)
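`short_warning` reports problems as `UserWarning`s, so the messages can also be captured programmatically instead of read from the cell output. The sketch below uses the standard-library `warnings.catch_warnings(record=True)`. `collect_warnings` and `drop_non_strings` are hypothetical helpers written for illustration; they are not part of the notebook.

```python
import warnings


def short_warning(message):
    # Same call as the notebook helper: filename '' and lineno 0 keep the echo short
    warnings.warn_explicit(message, UserWarning, '', 0)


def collect_warnings(func, *args, **kwargs):
    '''Run func(*args, **kwargs); return its result and the warning messages it emitted.'''
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')  # record every warning, including repeats
        result = func(*args, **kwargs)
    return result, [str(w.message) for w in caught]


def drop_non_strings(values):
    '''Hypothetical stand-in for a cleaning step that warns once per dropped row.'''
    kept = []
    for i, value in enumerate(values):
        if not isinstance(value, str):
            short_warning('Row {0} has non-string value {1}; dropping.'.format(i, value))
            continue
        kept.append(value)
    return kept


kept, messages = collect_warnings(drop_non_strings, ['00:02.200', None, '00:03.250'])
print(kept)      # ['00:02.200', '00:03.250']
print(messages)  # ['Row 1 has non-string value None; dropping.']
```

In the notebook, `collect_warnings(preprocess_audio_segments_csv, raw_df)` would return the cleaned dataframe together with every warning message. Those messages could then be written to a log file.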

The following helper function, `_fix_added_timestamp`, pre-processes irregularly formatted timestamps. It corrects some timestamps automatically and flags the rest for dropping (indicated in the code as a fatal error). The function reports all of its actions. The validation it performs is:
* Makes sure timestamps are strings; a non-string input is flagged as a fatal error
* Counts the segments separated by `:` or `.` (it expects 3, as in 00:00.000). With 4 segments (as in 00:00:00.000), it drops the leading segment; any other count is flagged as a fatal error
* If the timestamp is not already formatted as 00:00.000, it left-pads the minutes and seconds and right-pads the milliseconds with zeros to reach the desired lengths. Note that the code will not _remove_ digits, so a timestamp such as 00:000.000 remains unchanged. However, our dataset does not have this problem.
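The same rules can be sketched as a plain-string function without pandas, which makes them easy to check against the irregular formats listed earlier. This is an illustrative re-implementation with a hypothetical name, `normalize_timestamp`. It returns `None` wherever `_fix_added_timestamp` would record a fatal error. Like the original, it assumes that the extra leading segment in a 4-part timestamp is a zero hours field.

```python
import re


def normalize_timestamp(ts):
    '''Return ts in mm:ss.fff form, or None where _fix_added_timestamp would flag a fatal error.'''
    # Non-strings (e.g. NaN) cannot be fixed automatically
    if not isinstance(ts, str):
        return None
    pieces = re.split(r'[:.]', ts)
    # 4 parts, e.g. '00:09:41.870': assume a leading zero-hours field and drop it
    if len(pieces) == 4:
        pieces = pieces[1:]
    if len(pieces) != 3:
        return None
    minutes, seconds, frac = pieces
    # Left-pad minutes/seconds, right-pad the fraction; digits are never removed
    return '{0}:{1}.{2}'.format(minutes.rjust(2, '0'), seconds.rjust(2, '0'), frac.ljust(3, '0'))


for raw in ['00:09:41.870', '00:02.2', '01:51.43', '00:10.900']:
    print(raw, '->', normalize_timestamp(raw))
```

For example, `'00:09:41.870'` becomes `'09:41.870'` and `'00:02.2'` becomes `'00:02.200'`. An already well-formed value such as `'00:10.900'` passes through unchanged.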

In [None]:
#exporti
def _fix_added_timestamp(row_info):
    '''
    Function _fix_added_timestamp: validates timestamps and tries to fix them; returns the row with a
    `fatal_errors` entry added. This is a pandas helper function and should only be used via DataFrame.apply.
    Input: row_info: pandas Series corresponding to a single row
    Returns: row_info with corrected timestamps (or the original timestamps if they could not be fixed) and a new
    'fatal_errors' entry counting the timestamps that could not be converted (0, 1, or 2).
    '''
    
    #Timestamp pattern to use later
    ts_target = re.compile(r'\d{2}:\d{2}\.\d{3}')
    
    #Keep count of fatal errors
    fatal_errors = 0

    for ts_type in ['start_timestamp', 'end_timestamp']:
        
        #Make sure it's a string
        if not isinstance(row_info[ts_type], str):
            short_warning('{0}: Row {1} has a {2} that is not a string with value {3}. Cannot automatically fix.'
                          .format(row_info['id'], row_info.name, ts_type, row_info[ts_type]))
            fatal_errors = fatal_errors + 1
            continue
            
        #See if it has too many segments
        ts_pieces = re.split(r":|\.", row_info[ts_type])
        if len(ts_pieces) != 3:
            if len(ts_pieces) == 4:
                short_warning('{0}: Row {1} {2} with value {3} has 4 time parts instead of 3. Automatically fixing...'
                              .format(row_info['id'], row_info.name, ts_type, row_info[ts_type]))
                
                ts_pieces = ts_pieces[1:4]
                row_info[ts_type] = ts_pieces[0] + ':' + ts_pieces[1] + '.' + ts_pieces[2]
            else:
                short_warning('{0}: Row {1} with value {2} has {3} pieces in {4} and cannot be fixed automatically. Please amend.'
                             .format(row_info['id'], row_info.name, row_info[ts_type], len(ts_pieces), ts_type))
                fatal_errors = fatal_errors + 1
                continue
        
        #If it's perfect, let's just be done
        if ts_target.match(row_info[ts_type]) is not None:
            continue

        #Otherwise, let's get it into the right format
        ts_pieces[0] = ts_pieces[0].rjust(2,'0')
        ts_pieces[1] = ts_pieces[1].rjust(2,'0')
        ts_pieces[2] = ts_pieces[2].ljust(3,'0')
        
        #Update values
        short_warning('{0}: Row {1} {2} has the incorrect format of {3}. Automatically fixing...'
                      .format(row_info['id'], row_info.name, ts_type, row_info[ts_type]))
        row_info[ts_type] = ts_pieces[0] + ':' + ts_pieces[1] + '.' + ts_pieces[2]
        
    #Save fatal errors
    row_info['fatal_errors'] = fatal_errors
            
    return row_info
        

The `preprocess_audio_segments_csv` function does the rest of the processing and, at the end, removes the rows with fatal errors. It performs the following cleaning/pre-processing steps:
* Drops columns whose names start with "Unnamed"
* Strips leading and trailing whitespace from the start and end timestamps
* Drops rows that are NA in the start or end timestamp
* Calls the timestamp-fixing function above
* Drops fatal-error rows due to malformed timestamps
* Adds columns with the timestamps converted to milliseconds
* Adds a duration column
* Flags rows with negative or zero durations as fatal errors
* Flags rows with `duration_ms` > `duration_max` (15 s by default) as fatal errors
* Drops all fatal-error rows
* Drops the `fatal_errors` column
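The duration checks in that list reduce to a single condition: `0 < duration_ms <= duration_max`. The following plain-Python sketch uses a hypothetical helper, `split_by_duration`. It is applied to a few (start_ms, end_ms) pairs taken from this notebook's outputs: one kept utterance from `008-1` and three rows reported as dropped in the warnings further below.

```python
def split_by_duration(rows, duration_max=15000):
    '''Split (start_ms, end_ms) pairs into kept and dropped (start, end, duration) tuples.'''
    kept, dropped = [], []
    for start_ms, end_ms in rows:
        duration_ms = end_ms - start_ms
        # Fatal error: non-positive duration, or longer than duration_max
        if 0 < duration_ms <= duration_max:
            kept.append((start_ms, end_ms, duration_ms))
        else:
            dropped.append((start_ms, end_ms, duration_ms))
    return kept, dropped


# One valid utterance plus three rows that the notebook reports dropping
rows = [(10900, 11420), (524420, 522760), (13900, 167200), (111430, 111380)]
kept, dropped = split_by_duration(rows)
print(kept)     # [(10900, 11420, 520)]
print(dropped)  # [(524420, 522760, -1660), (13900, 167200, 153300), (111430, 111380, -50)]
```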

In [None]:
#export
def preprocess_audio_segments_csv(csv_df, duration_max=15000):
    '''
    Function preprocess_audio_segments_csv: pre-processes manually-entered timestamps to ensure correct format
    Inputs: csv_df: original dataframe with at least columns start_timestamp, end_timestamp, and id
            duration_max (default 15000): maximum length allowed for an utterance
    Returns: pandas dataframe with corrected or dropped timestamps, corresponding timestamps in ms, and duration
    '''
    
    #Work on a copy so the caller's dataframe is not modified in place
    csv_df = csv_df.copy()
    
    #Drop unwanted "Unnamed" columns
    drop_cols = [drop_col for drop_col in csv_df.columns if drop_col.startswith('Unnamed')]
    csv_df.drop(columns=drop_cols, inplace=True)
    
    #Strip any leading or trailing whitespace
    csv_df['start_timestamp'] = csv_df['start_timestamp'].str.strip()
    csv_df['end_timestamp'] = csv_df['end_timestamp'].str.strip()
    
    #Drop NAs in the timestamps and notify of how many rows were dropped
    orig_sz = len(csv_df)
    csv_df.dropna(subset=['start_timestamp', 'end_timestamp'], inplace=True)
    n_dropped = orig_sz - len(csv_df)
    if n_dropped != 0:
        short_warning("You had {0} NA rows in start_timestamp or end_timestamp which were dropped."
                      .format(n_dropped))
    
    #See if we have wrong formats on timestamps and process or notify
    csv_df = csv_df.apply(_fix_added_timestamp, axis='columns')
    
    #Determine if the df can continue forward based on timestamps
    no_fatal_errors = csv_df['fatal_errors'].sum()
    if no_fatal_errors != 0:

        #display errors and get all rows except those with fatal errors
        error_rows = csv_df.query('fatal_errors!=0')
        short_warning('File {0} has {1} timestamp errors that cannot be automatically corrected. Dropping these rows.\nDropped row summary due to timestamp (truncated table):\n{2}'
                      .format(csv_df['id'].iloc[0], no_fatal_errors, error_rows[['id', 'start_timestamp', 'end_timestamp']]))
        csv_df = csv_df.drop(index=error_rows.index)
    
    #Convert times to milliseconds and calculate duration
    csv_df["start_ms"] = csv_df["start_timestamp"].apply(to_millseconds)
    csv_df["end_ms"] = csv_df["end_timestamp"].apply(to_millseconds)
    csv_df["duration_ms"] = csv_df['end_ms'] - csv_df["start_ms"]
    
    #Validate ms: durations must be positive and no longer than duration_max
    bad_duration = (csv_df['duration_ms'] <= 0) | (csv_df['duration_ms'] > duration_max)
    csv_df['fatal_errors'] = bad_duration.astype(int)
    no_fatal_errors = csv_df['fatal_errors'].sum()
    if no_fatal_errors != 0:
        
        #display errors and get all rows except those with fatal errors
        error_rows = csv_df.query('fatal_errors!=0')
        short_warning('File {0} has {1} time duration issues. Dropping these rows.\nDropped row summary due to duration (truncated table):\n{2}'
                      .format(csv_df['id'].iloc[0], no_fatal_errors,
                              error_rows[['id', 'start_ms', 'end_ms', 'duration_ms']]))
        csv_df = csv_df.drop(index=error_rows.index)
        
    #Once we've removed fatal errors (or had none), drop the column
    csv_df.drop(columns=['fatal_errors'], inplace=True)  
    
    #Get the indices together correctly
    csv_df.reset_index(drop=True, inplace=True)
    
    return csv_df

## A few unit tests
Let's just do basic due diligence to make sure that this is working...

In [None]:
test_df = pd.read_csv('/data/p_dsi/wise/data/test_files/008-1.csv')
test_df = preprocess_audio_segments_csv(test_df)

This appears to work correctly. We've cleared the output here because the same warnings appear below when all of the files are processed.
# Generate cleaned csv files
Now, let's generate the pre-processed versions of the CSV files. We'll use `csv_files_list` from above and also extract each file number (e.g., `134-1`) from its filename.

In [None]:
#Get file numbers
csv_files_nos = [os.path.splitext(os.path.basename(fname))[0] for fname in csv_files_list]
print(csv_files_nos)

['134-1', '055-1', '083-2', '273-3', '120-1', '083-3', '251-1', '008-1', '123-1', '134-2', '083-1']


In [None]:
#Load data
raw_dfs = [pd.read_csv(fname) for fname in csv_files_list]

In [None]:
for raw_df, file_no in zip(raw_dfs, csv_files_nos):
    print(file_no, ':', len(raw_df))

134-1 : 220
055-1 : 207
083-2 : 135
273-3 : 252
120-1 : 166
083-3 : 166
251-1 : 187
008-1 : 170
123-1 : 194
134-2 : 158
083-1 : 129


In [None]:
#Clean data
cleaned_dfs = [preprocess_audio_segments_csv(raw_df) for raw_df in raw_dfs]



Dropped row summary due to duration (truncated table):
        id  start_ms  end_ms  duration_ms
221  273-3    524420  522760        -1660


Dropped row summary due to duration (truncated table):
        id  start_ms  end_ms  duration_ms
3    251-1     13900  167200       153300
5    251-1     28200  148300       120100
85   251-1    287200  286700         -500
131  251-1    452200  443300        -8900
Dropped row summary due to timestamp (truncated table):
       id start_timestamp end_timestamp
88  008-1               —             —
89  008-1               —             —
90  008-1               —             —
Dropped row summary due to duration (truncated table):
        id  start_ms  end_ms  duration_ms
164  008-1    579341  578016        -1325


Dropped row summary due to duration (truncated table):
        id  start_ms  end_ms  duration_ms
32   123-1    111430  111380          -50
63   123-1    130250  193000        62750
192  123-1    597750   30000      -567750
193  123-1     30000  601870       571870
Dropped row summary due to duration (truncated table):
       id  start_ms  end_ms  duration_ms
28  134-2    119403  140445        21042


# Functions to generate the embeddings
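Now, we'll generate the audio embeddings. The main function, `add_audio_embeddings_info`, does three things:
* reads the audio file;
* converts each segment's start and end times from milliseconds to sample indices;
* generates embeddings with the selected model and processor.

The helper function `_get_audio_embeddings` computes the embeddings for a single segment.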
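The conversion from milliseconds to sample indices deserves a note. The notebook computes `int(ms) * (sampling_rate // 1000)`. That is exact at 16 kHz, but at rates that are not multiples of 1000 Hz it truncates the number of samples per millisecond; at 22,050 Hz, for example, it uses 22 instead of 22.05. Multiplying before dividing avoids this. The two hypothetical helpers below compare the formulas. Only 16 kHz audio is used in this notebook, so the difference is illustrative.

```python
def index_khz_first(ms, sampling_rate):
    # the notebook's formula: integer samples-per-ms first, then multiply
    return int(ms) * (sampling_rate // 1000)


def index_multiply_first(ms, sampling_rate):
    # multiply first, then floor-divide: floor of the exact ms * rate / 1000
    return int(ms) * sampling_rate // 1000


for rate in (16000, 22050, 44100):
    print(rate, index_khz_first(1000, rate), index_multiply_first(1000, rate))
# 16000 16000 16000
# 22050 22000 22050
# 44100 44000 44100
```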

In [None]:
#exporti
def _get_audio_embeddings(row_info, wav_file, aud_processor, aud_mdl, samp_rate):
    '''
    Function _get_audio_embeddings: generates embeddings for a wave file using a model. Function not to be used
    directly without pandas .apply function.
    Inputs: row_info: pandas Series of row info with minimally start_index and end_index
            wav_file: list or numpy array of wave file
            aud_processor: huggingface audio processor for inputs
            aud_mdl: huggingface audio model to generate embeddings
            samp_rate: sampling rate of audio
    Outputs: pandas Series of row info with added columns 'last_hidden_state',
    'shape_state', and 'last_hidden_state_mean'
    '''
    
    #Get the processed input values using the processor
    input_values = aud_processor(wav_file[row_info['start_index'] : row_info['end_index']],
                                 return_tensors="pt", sampling_rate = samp_rate).input_values
    
    #Get the embeddings values; inference only, so skip building the autograd graph
    with torch.no_grad():
        last_hidden_state = aud_mdl(input_values).last_hidden_state[0,:,:]
    row_info['last_hidden_state'] = last_hidden_state.tolist()
    row_info['shape_state'] = list(last_hidden_state.shape)
    row_info['last_hidden_state_mean'] = torch.mean(last_hidden_state, dim=0).tolist()
    
    #Return
    return row_info
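`last_hidden_state_mean` averages over the time axis (`torch.mean(..., dim=0)`). This collapses the variable-length `(frames, 768)` matrix into a fixed 768-dimensional vector. A plain-Python equivalent on a toy matrix shows which axis is averaged; `mean_pool` is a hypothetical name.

```python
def mean_pool(hidden_state):
    '''Average a (frames x dims) nested list over the frame axis; returns a list of length dims.'''
    n_frames = len(hidden_state)
    if n_frames == 0:
        raise ValueError('hidden_state has no frames')
    n_dims = len(hidden_state[0])
    # for each dimension d, average the d-th value across all frames
    return [sum(frame[d] for frame in hidden_state) / n_frames for d in range(n_dims)]


toy = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # 2 frames x 3 dims
print(mean_pool(toy))  # [2.0, 3.0, 4.0]
```

Mean pooling gives every utterance a vector of the same length, regardless of its duration, at the cost of discarding the temporal order of the frames. The full matrix is still kept in `last_hidden_state` for cases where that order matters.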

In [None]:
#export
def add_audio_embeddings_info(pd_audio,
                              audio_no,
                              audio_processor,
                              audio_mdl,
                              sampling_rate = 16000, 
                              base_prefix = "/data/p_dsi/wise/data/resampled_audio_16khz/"):
    '''
    Input argument: 
        pd_audio: cleaned dataframe (e.g., output of preprocess_audio_segments_csv) with start_ms and end_ms columns
        audio_no: String of audio_number (e.g., '083-1')
        audio_processor: HF audio processor (e.g., instantiated Wav2Vec2Processor)
        audio_mdl: HF audio base model (e.g., instantiated Wav2Vec2Model)
        sampling_rate (default 16000): integer sampling rate of the audio
        base_prefix (default '/data/p_dsi/wise/data/resampled_audio_16khz/'): String of filepath to audio files
    Output:
        a pandas dataframe containing the original csv data plus additional columns, including the last hidden state matrix and its mean vector
    '''
    
    #Print some info
    print('Working on file:', audio_no)
    
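    #Read in the corresponding audio file and check its sampling rate
    audio_wave, sr = sf.read(base_prefix + audio_no + '.wav')
    if sr != sampling_rate:
        raise ValueError('File {0} has sampling rate {1}, expected {2}.'.format(audio_no, sr, sampling_rate))
    
    #Calculate indices in audio file (multiply before dividing so non-kHz rates stay exact)
    cal_index = lambda x: int(x) * sampling_rate // 1000
    pd_audio["start_index"] = pd_audio["start_ms"].apply(cal_index)
    pd_audio["end_index"] = pd_audio["end_ms"].apply(cal_index)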
    
    #Add embeddings information
    pd_audio = pd_audio.apply(lambda x: _get_audio_embeddings(x, audio_wave,
                                                              audio_processor, audio_mdl, sampling_rate),
                             axis='columns')
    
    #Reset index to make sure continuous numbering
    pd_audio.reset_index(drop=True, inplace=True)
           
    #Return
    return pd_audio

## Unit test
Now, let's just check and make sure this is making sense.

In [None]:
# Initialize the processor and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
test_df = add_audio_embeddings_info(test_df, '008-1', processor, model)
test_df.head()

Working on file: 008-1


Unnamed: 0,id,transcript_filepath,wave_filename,speech,start_timestamp,end_timestamp,label,transcriber_id,start_ms,end_ms,duration_ms,start_index,end_index,last_hidden_state,shape_state,last_hidden_state_mean
0,008-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,acorn.,00:10.900,00:11.420,NEU,198,10900,11420,520,174400,182720,"[[-0.0706741139292717, 0.0020495434291660786, ...","[25, 768]","[-0.05303679406642914, 0.002984378021210432, -..."
1,008-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,and you all did a very nice job ((most of you ...,00:11.420,00:17.350,PRS,198,11420,17350,5930,182720,277600,"[[-0.06873799860477448, -0.019848378375172615,...","[296, 768]","[0.004339495673775673, -0.023257896304130554, ..."
2,008-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,(now) name has asked a very good question.,00:17.350,00:19.850,PRS,198,17350,19850,2500,277600,317600,"[[-0.11779239028692245, -0.004400355275720358,...","[124, 768]","[0.008603579364717007, 0.003087687771767378, -..."
3,008-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,does acorns turn tall (okay).,00:19.850,00:24.371,OTR,198,19850,24371,4521,317600,389936,"[[0.03819892182946205, 0.04985547810792923, -0...","[225, 768]","[-0.022842226549983025, 0.02001359686255455, -..."
4,008-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,(what) what grows to be tall.,00:25.016,00:27.106,OTR,198,25016,27106,2090,400256,433696,"[[-0.06892960518598557, 0.008833239786326885, ...","[104, 768]","[-0.013253878802061081, 0.002626909641548991, ..."
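The `shape_state` column can be sanity-checked against the segment lengths. The wav2vec2 feature encoder is a stack of unpadded 1-D convolutions, so a segment of `end_index - start_index` samples yields a predictable number of frames. The kernel sizes and strides below are the defaults of `Wav2Vec2Config` in `transformers`. We assume they match `facebook/wav2vec2-base-960h`; this can be confirmed via `model.config.conv_kernel` and `model.config.conv_stride`.

```python
# Default feature-encoder geometry of transformers' Wav2Vec2Config (assumed to match
# facebook/wav2vec2-base-960h; confirm with model.config.conv_kernel / conv_stride)
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def n_output_frames(n_samples, kernels=CONV_KERNEL, strides=CONV_STRIDE):
    '''Number of frames produced by a stack of unpadded 1-D convolutions.'''
    length = n_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length


# (start_index, end_index) pairs copied from the test_df.head() output above
segments = [(174400, 182720), (182720, 277600), (277600, 317600),
            (317600, 389936), (400256, 433696)]
frames = [n_output_frames(end - start) for start, end in segments]
print(frames)  # [25, 296, 124, 225, 104], the first entries of shape_state
```

The total stride is 5 * 2^6 = 320 samples, so each output frame advances by 20 ms of 16 kHz audio.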


# Generate embeddings file
Now, we'll generate the embeddings for all of the files. This step was the motivation for adding `duration_max` to the pre-processing. Some of the annotated snippets were more than 500 s long (e.g., one row in `123-1` spans about 572 s), which is far too long to embed as a single segment. Moreover, because individual utterances, phrases, and sentences are transcribed, a duration of 500 s is implausible and most likely an annotation error.
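A rough back-of-the-envelope estimate shows why such snippets are a problem. It rests on two assumptions not stated in the original notebook:
* wav2vec2-base has a 20 ms frame stride (320 samples at 16 kHz).
* Each self-attention layer materializes a full $N \times N$ float32 score matrix per head.

For the 571,870 ms row in `123-1`:

$$
N \approx \frac{571\,870\ \text{ms}}{20\ \text{ms/frame}} \approx 28\,600\ \text{frames},
\qquad
N^2 \approx 8.2 \times 10^{8}\ \text{scores} \approx 3.3\ \text{GB per head per layer}
$$

By contrast, the 520 ms utterance in the `test_df` output above has $N = 25$ frames, i.e., 625 scores.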

In [None]:
# Generate embeddings files
embeds_list = [add_audio_embeddings_info(cleaned_df, file_no, processor, model)
              for cleaned_df, file_no in zip(cleaned_dfs, csv_files_nos)]

Working on file: 134-1
Working on file: 055-1
Working on file: 083-2
Working on file: 273-3
Working on file: 120-1
Working on file: 083-3
Working on file: 251-1
Working on file: 008-1
Working on file: 123-1
Working on file: 134-2
Working on file: 083-1


## Check the label columns
Here, we make sure that all of the labels are valid. For any label not in the accepted list, we use the `difflib` package to find the closest accepted label and substitute it. The function issues a warning for every replacement.
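The matching step can be sketched in isolation with the standard library. `closest_label` is a hypothetical helper that mirrors `_check_label`: it uses a case-insensitive `difflib.SequenceMatcher` ratio, and the first best match wins. This approach always returns some label, even for a string unrelated to all of them. `difflib.get_close_matches` with a `cutoff` is one way to return no match instead.

```python
import difflib

LABELS = ["OTR", "NEU", "REP", "PRS"]


def closest_label(label, label_list=LABELS):
    '''Return label if accepted, else the accepted label with the highest SequenceMatcher ratio.'''
    if label in label_list:
        return label
    # max() keeps the first label on ties, like np.argmax in _check_label
    return max(label_list,
               key=lambda cand: difflib.SequenceMatcher(a=label.lower(), b=cand.lower()).ratio())


print(closest_label('PRZ'))   # PRS
print(closest_label('otr'))   # OTR
print(closest_label('NEU.'))  # NEU
# With a cutoff, an unrelated string yields no match instead of a forced one
print(difflib.get_close_matches('X', LABELS, n=1, cutoff=0.6))  # []
```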

In [None]:
#exporti
def _check_label(row, label_list):
    '''
    Function _check_label: Internal helper function with .apply in pandas to check label. Not to be used directly.
    Inputs: row: pandas Series of dataframe row with minimally 'label' column
            label_list: list of accepted labels in df
    Returns: warning or fixed label in row
    '''
    
    if row['label'] not in label_list:
        #Get match ratio
        matches = [difflib.SequenceMatcher(a=row['label'].lower(), b=test_label.lower()).ratio()
                   for test_label in label_list]
        
        #Get index of best match and set it
        maxindex = np.argmax(matches)
        best_label = label_list[maxindex]
        
        short_warning('File {0}: Row {1} has label {2}; replaced with {3}'
                      .format(row['id'], row.name, row['label'], best_label))
        
        #Fix
        row['label'] = best_label
        
    return row

In [None]:
#export
def check_label(df, label_list=None):
    """
    Check whether df contains any invalid labels and replace them with the closest accepted label
    Inputs: df: pandas data frame
            label_list (default None): list of accepted label names in label column or None to use defaults
    Output: issues a warning for each invalid label; returns df with corrected labels
    """
    
    if label_list is None:
        label_list = ["OTR", "NEU", "REP", "PRS"]
        
    #Make sure label is right
    df = df.apply(lambda x: _check_label(x, label_list), axis='columns')
    
    return df

In [None]:
#Fix labels on all dfs
embeds_list = [check_label(df) for df in embeds_list]



## Save data as parquet file format
Now, we'll save all of the data as parquet files in the `embedding_parquet` directory under `base_prefix` on ACCRE.

In [None]:
#export
def write_nd_parquet(df, filepath):
    '''
    Function write_nd_parquet: writes a parquet file with complex columns. May be unnecessary.
    Inputs: df: dataframe to be written
            filepath: full filepath for output
    Output: None, prints the filepath that the dataframe was written to.
    '''
    
    #Convert to table
    pq_table = pa.Table.from_pandas(df)
    
    #Save file
    pq.write_table(pq_table, filepath)
    print('Wrote dataframe to:', filepath)
    
    return

In [None]:
#Actually save the files (create the output directory if needed)
out_dir = base_prefix + 'embedding_parquet/'
os.makedirs(out_dir, exist_ok=True)
for embed_df, file_no in zip(embeds_list, csv_files_nos):
    write_nd_parquet(embed_df, out_dir + file_no + '.parquet')

Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/134-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/055-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/083-2.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/273-3.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/120-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/083-3.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/251-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/008-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/123-1.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/134-2.parquet
Wrote dataframe to: /data/p_dsi/wise/data/embedding_parquet/083-1.parquet


## Reload one of the parquet files and show its contents

In [None]:
reload_arrow_c134_1 = pd.read_parquet(base_prefix + 'embedding_parquet/134-1.parquet')
reload_arrow_c134_1.head()

Unnamed: 0,id,transcript_filepath,wave_filename,speech,start_timestamp,end_timestamp,label,transcriber_id,start_ms,end_ms,duration_ms,start_index,end_index,last_hidden_state,shape_state,last_hidden_state_mean
0,134-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,good.,00:02.164,00:03.250,PRS,198,2164,3250,1086,34624,52000,"[[-0.04941119998693466, 0.02282359078526497, -...","[54, 768]","[-0.008398521691560745, 0.03856978937983513, -..."
1,134-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,what's this word.,00:03.250,00:04.014,OTR,198,3250,4014,764,52000,64224,"[[0.1669340878725052, -0.08026524633169174, -0...","[37, 768]","[0.048272937536239624, 0.007540153339505196, -..."
2,134-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,SEE spells~,00:06.308,00:08.170,OTR,198,6308,8170,1862,100928,130720,"[[0.021184563636779785, 0.03635965287685394, -...","[92, 768]","[-0.04078539460897446, 0.023049654439091682, -..."
3,134-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,what do you do with your eyes.,00:08.170,00:10.371,OTR,198,8170,10371,2201,130720,165936,"[[-0.02608766406774521, 0.012260298244655132, ...","[109, 768]","[-0.013341746293008327, 0.0059875613078475, -0..."
4,134-1,~/Box Sync/DSI Documents/cleaned_data/cleaned_...,~/Box Sync/DSI Documents/Audio Files & Tanscri...,see.,00:12.000,00:13.473,NEU,198,12000,13473,1473,192000,215568,"[[-0.025457218289375305, 0.053853556513786316,...","[73, 768]","[-0.044089991599321365, 0.02164131961762905, -..."


In [None]:
#Check to make sure it's working on the reload
display(reload_arrow_c134_1['last_hidden_state_mean'][0][:10])
print(len(reload_arrow_c134_1['last_hidden_state_mean'][0]))

array([-0.00839852,  0.03856979, -0.04993065, -0.03876822,  0.02194165,
       -0.08988117,  0.0664359 , -0.03430568,  0.08918358, -0.37447736])

768


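Fantastic! The reloaded mean embedding matches the values in the `head()` output above (e.g., -0.00839852 vs. -0.008398521691560745) and has the expected 768 dimensions.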
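Comparing the first ten values by eye is a good start. A stricter check recomputes the mean from the stored matrix and compares it with `shape_state` and `last_hidden_state_mean`. The sketch below is a hypothetical helper, `check_embedding_row`. It is written in plain Python so that it accepts either nested lists or the arrays of arrays that `pd.read_parquet` returns. The tolerances are assumptions, chosen to absorb float32 rounding.

```python
import math


def check_embedding_row(hidden_state, shape_state, mean_vec, rel_tol=1e-5, abs_tol=1e-5):
    '''Return True if hidden_state has shape shape_state and its frame-wise mean matches mean_vec.'''
    n_frames, n_dims = (int(v) for v in shape_state)
    if n_frames == 0:
        return False
    # convert to plain floats so nested lists and numpy arrays of arrays behave the same
    frames = [[float(v) for v in frame] for frame in hidden_state]
    if len(frames) != n_frames or any(len(frame) != n_dims for frame in frames):
        return False
    if len(mean_vec) != n_dims:
        return False
    recomputed = [sum(frame[d] for frame in frames) / n_frames for d in range(n_dims)]
    return all(math.isclose(a, float(b), rel_tol=rel_tol, abs_tol=abs_tol)
               for a, b in zip(recomputed, mean_vec))


toy_state = [[1.0, 2.0], [3.0, 4.0]]  # 2 frames x 2 dims
print(check_embedding_row(toy_state, [2, 2], [2.0, 3.0]))  # True
print(check_embedding_row(toy_state, [2, 2], [2.0, 3.5]))  # False
print(check_embedding_row(toy_state, [3, 2], [2.0, 3.0]))  # False
```

On the reloaded data, this would be called as `check_embedding_row(reload_arrow_c134_1['last_hidden_state'][0], reload_arrow_c134_1['shape_state'][0], reload_arrow_c134_1['last_hidden_state_mean'][0])`.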