# 50-audio-training
> Starting to use audio for training

In this notebook, we use a few custom-labeled transcripts (see [Issue #49](https://github.com/vanderbilt-data-science/wise/issues/49) for details) to extract the subsegments of the audio files that correspond to each label. With these labeled clips, we can directly train the classification head of a Wav2Vec2 sequence classification model. Below, we look into subsetting the data reliably and training the model.

In [None]:
#all_no_test
#default_exp audio_modeling

In [None]:
#export
#modeling imports
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, pipeline, TrainingArguments, Trainer
from datasets import load_metric
import torch
import soundfile as sf
import librosa

#ds imports
import pandas as pd
import numpy as np

#python imports
import os.path
import glob
import re
import warnings

# Organize data
First, we need to get the data into a usable form. We'll write a few helper functions to help with this.

In [None]:
#file constants
base_prefix = '/data/p_dsi/wise/data/'
sample_csv_dir = base_prefix + 'test_files/'
audio_dir = base_prefix + 'resampled_audio_16khz/'
test_audio_id = '055-1'

In [None]:
sampling_rate = 16000

## Read in sampled csv
For now, we're only looking at the few files that have been hand-labeled with timestamps. Let's start with just one of them.

In [None]:
available_csvs = glob.glob(sample_csv_dir + '*.csv')
len(available_csvs)

2

In [None]:
#print some info
print('Using file:', available_csvs[0])

#read dataframe and preview
ts_df = pd.read_csv(available_csvs[0])
display(ts_df.head())
ts_df.shape

Using file: /data/p_dsi/wise/data/test_files/055-1.csv


| | id | transcript_filepath | wave_filename | speech | start_timestamp | end_timestamp | label | transcriber_id | Notes |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | (okay) we are gonna go on and get started guys. | 00:01.000 | 00:03.380 | NEU | 198 | |
| 1 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | we are gonna do a little bit of reviewing with... | 00:03.750 | 00:06.763 | NEU | 198 | |
| 2 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | (now) keep in mind that we are playing the goo... | 00:07.150 | 00:12.769 | NEU | 198 | |
| 3 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | everyone look up here please. | 00:14.012 | 00:16.260 | NEU | 198 | |
| 4 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | let's go over the problems. | 00:16.615 | 00:18.000 | NEU | 198 | |


(207, 9)

Things look as expected here. We can clearly see, though, that the timestamps (mm:ss.fff strings) need some work before we can use them as sample indices into the audio.

## Conversion of timestamp to sampling index
Here, we'll write a function that converts a timestamp into a sample index at a given sampling rate.
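Written out, for a timestamp M:S.F (minutes M, seconds S, and a fractional part F right-padded with zeros to three digits, so .50 becomes 500 ms) and a sampling rate $f_s$, the conversion computes the time in seconds and scales it to samples, rounding down or up:

$$
t = 60M + S + \frac{F}{1000}, \qquad
i_{\text{floor}} = \left\lfloor t \, f_s \right\rfloor, \qquad
i_{\text{ceil}} = \left\lceil t \, f_s \right\rceil
$$

For example, 00:03.380 at $f_s = 16000$ Hz gives $t = 3.38$ s and $t \, f_s = 54080$ samples.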

In [None]:
#export
def timestamp2index(ts, round_type = 'ceil', sampling_rate=16000):
    '''
    Function timestamp2index: converts a timestamp with format mm:ss.fff to a sample index given the sampling rate
        ts: string of the timestamp, formatted as mm:ss.fff (minutes:seconds.fraction of a second)
        round_type (default 'ceil'): string of rounding to perform; can be 'ceil' or 'floor'
        sampling_rate (default 16000): integer of the sampling rate (in Hz) of the audio
    Returns: integer of index of converted timestamp or None if formatted incorrectly
    '''
    
    #define regex (raw string so \d is not treated as a string escape; \. matches only a literal dot)
    ts_pat = re.compile(r'(\d{1,2}):(\d{1,2})\.(\d{1,3})')
    
    #get the match
    ts_match = ts_pat.match(ts)
    
    #throw a warning if you have issues
    if ts_match is None:
        warnings.warn('There is an issue with value: {0} and it could not be converted.'.format(ts))
        return None
    
    #convert to full time (note that ljust zero pads on the right)
    ts_seconds = 60*int(ts_match.group(1)) + int(ts_match.group(2)) + int(ts_match.group(3).ljust(3,'0'))/1000
    
    #identify rounding type
    round_func = np.ceil
    if round_type == 'floor':
        round_func = np.floor

    #create index and apply rounding
    ts_ind = int(round_func(ts_seconds * sampling_rate))
    
    return ts_ind

In [None]:
#A few unit tests
ts_utests = ['00:00.000',
             '01:00.000',
             '00:01.000',
             '00:00.500',
             '01:01.50']
for uts in ts_utests:
    print('Timestamp:', uts, 'Index:', timestamp2index(uts))

Timestamp: 00:00.000 Index: 0
Timestamp: 01:00.000 Index: 960000
Timestamp: 00:01.000 Index: 16000
Timestamp: 00:00.500 Index: 8000
Timestamp: 01:01.50 Index: 984000


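These match the expected values: for example, 01:00.000 is 60 s × 16000 Hz = 960000 samples, and the two-digit fraction in 01:01.50 is read as 500 ms. Let's add the indices to the data. We floor each start and ceil each end so that every clip fully covers its labeled interval.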

In [None]:
ts_df['start_index'] = ts_df['start_timestamp'].apply(lambda x: timestamp2index(x, round_type='floor'))
ts_df['end_index'] = ts_df['end_timestamp'].apply(lambda x: timestamp2index(x, round_type='ceil'))

In [None]:
ts_df.head(3)

| | id | transcript_filepath | wave_filename | speech | start_timestamp | end_timestamp | label | transcriber_id | Notes | start_index | end_index |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | (okay) we are gonna go on and get started guys. | 00:01.000 | 00:03.380 | NEU | 198 | | 16000 | 54080 |
| 1 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | we are gonna do a little bit of reviewing with... | 00:03.750 | 00:06.763 | NEU | 198 | | 60000 | 108208 |
| 2 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | (now) keep in mind that we are playing the goo... | 00:07.150 | 00:12.769 | NEU | 198 | | 114400 | 204304 |


The start and end indices look correct.
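One caveat: timestamp2index multiplies a floating-point number of seconds by the sampling rate before rounding. Since a value like 3.38 has no exact binary representation, the product can in principle land a hair above the true integer, and np.ceil would then add one spurious sample. A minimal alternative sketch, assuming the same mm:ss.fff format and the same floor-start/ceil-end convention, does the arithmetic exactly with Python's fractions.Fraction. The names timestamp_to_seconds_exact and interval_to_slice are hypothetical helpers, not part of the notebook.

```python
from fractions import Fraction
import math
import re

# Same mm:ss.fff layout as timestamp2index; the dot is escaped so only a literal '.' matches
TS_PATTERN = re.compile(r'(\d{1,2}):(\d{1,2})\.(\d{1,3})')

def timestamp_to_seconds_exact(ts):
    '''Parse an mm:ss.fff timestamp into an exact Fraction of seconds (hypothetical helper).'''
    match = TS_PATTERN.fullmatch(ts.strip())
    if match is None:
        raise ValueError(f'Unrecognized timestamp: {ts!r}')
    minutes, seconds = int(match.group(1)), int(match.group(2))
    millis = int(match.group(3).ljust(3, '0'))  # '50' -> 500 ms, as in timestamp2index
    return 60 * minutes + seconds + Fraction(millis, 1000)

def interval_to_slice(start_ts, end_ts, sampling_rate=16000):
    '''Return (start, end) sample indices: floor the start, ceil the end, with no float rounding.'''
    start = math.floor(timestamp_to_seconds_exact(start_ts) * sampling_rate)
    end = math.ceil(timestamp_to_seconds_exact(end_ts) * sampling_rate)
    return start, end

print(interval_to_slice('00:01.000', '00:03.380'))  # (16000, 54080), matching row 0 above
```

Because every intermediate value is an exact rational number, the only rounding is the deliberate floor or ceil at the end.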

## Adding on integer label
The model also needs integer class labels, so let's create a label-to-integer mapping and add an integer label column.

In [None]:
#Create dictionary
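label_dict = {0:"PRS", 1:"OTR", 2:"NEU", 3:"REP"}

#Invert original to map label names to integer ids
rev_label_dict = {value:key for key, value in label_dict.items()}

#Substitute in dataframe; map() gives NaN for any label missing from the dictionary, so check for that
ts_df['i_label'] = ts_df['label'].map(rev_label_dict)
assert ts_df['i_label'].notna().all(), 'Unmapped labels: {0}'.format(ts_df.loc[ts_df['i_label'].isna(), 'label'].unique())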

# Preparing Inputs to Model
Here, we'll use Facebook's pretrained wav2vec2 models, but first we need to prepare the inputs: split the data, cut the audio into labeled clips, and run the clips through the model's processor.

## Split the data
For now, we simply shuffle the labeled segments and split them at random, with roughly 80% for training and 20% for validation.

In [None]:
#randomly permute
arr_df = ts_df.sample(frac=1, random_state=2021)

#assign split based on position after shuffling: 0 = train, 1 = validation
#(rows whose position exceeds ceil(0.8 * n) go to validation)
arr_df = arr_df.reset_index()
arr_df = arr_df.rename(columns={'index':'true_order'})
arr_df['split'] = (arr_df.index>np.ceil(len(arr_df)*0.8)).astype(int)
arr_df.head(3)

| | true_order | id | transcript_filepath | wave_filename | speech | start_timestamp | end_timestamp | label | transcriber_id | Notes | start_index | end_index | i_label | split |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 99 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | (okay) four plus two plus six. | 05:02.843 | 05:06.045 | OTR | 198 | | 4845488 | 4896720 | 1 | 0 |
| 1 | 24 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | nine. | 01:16.100 | 01:16.852 | NEU | 198 | | 1217600 | 1229632 | 2 | 0 |
| 2 | 23 | 055-1 | ~/Box Sync/DSI Documents/cleaned_data/cleaned_... | ~/Box Sync/DSI Documents/Audio Files & Tanscri... | name? | 01:14.500 | 01:15.302 | OTR | 198 | | 1192000 | 1204832 | 1 | 0 |
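The rule in the split cell sends rows whose shuffled position is strictly greater than ceil(0.8 × 207) = 166 to validation, which yields 167 training and 40 validation clips. A purely random split like this can leave a rare label with few or no validation examples. Below is a minimal sketch of a stratified alternative, a swapped-in technique rather than what this notebook does, using only the standard library. stratified_split is a hypothetical helper that returns 0/1 flags in the same convention as the split column.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_frac=0.2, seed=2021):
    '''Return 0 (train) / 1 (validation) flags, sampling test_frac of each class separately.'''
    # group row positions by label
    by_class = defaultdict(list)
    for pos, label in enumerate(labels):
        by_class[label].append(pos)

    rng = random.Random(seed)  # local generator, so results are reproducible
    split = [0] * len(labels)
    for positions in by_class.values():
        rng.shuffle(positions)
        n_test = round(len(positions) * test_frac)  # a class with a single row stays in train
        for pos in positions[:n_test]:
            split[pos] = 1
    return split

labels = ['NEU'] * 10 + ['OTR'] * 5
split = stratified_split(labels)
print(sum(split[:10]), sum(split[10:]))  # 2 1
```

Applied to the notebook, this would look like arr_df['split'] = stratified_split(arr_df['label'].tolist()), leaving the rest of the pipeline unchanged.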


## Pre-process audio data

In [None]:
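#read audio data; the clip indices above assume the file is sampled at sampling_rate
class_audio, class_sr = sf.read(audio_dir + test_audio_id + '.wav')
assert class_sr == sampling_rate, 'Expected {0} Hz audio, got {1} Hz'.format(sampling_rate, class_sr)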

In [None]:
#get subsets of audio as a list
audio_clips_train = [class_audio[start:end] for start, end in arr_df.query('split==0')[['start_index', 'end_index']].values]
audio_clips_test = [class_audio[start:end] for start, end in arr_df.query('split==1')[['start_index', 'end_index']].values]

In [None]:
#this looks about right
print(len(audio_clips_train))
print(len(audio_clips_train[0]))
print(len(audio_clips_test))
print(len(audio_clips_test[0]))

167
51232
40
167952
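As a sanity check, dividing a clip's sample count by the 16 kHz rate recovers its duration. The first training clip is row 0 of the shuffled table (05:02.843 to 05:06.045), so its 51232 samples should span 306.045 - 302.843 = 3.202 s. A tiny sketch of that arithmetic, using only the numbers printed above:

```python
sampling_rate = 16000  # Hz, as set earlier in the notebook

# sample counts printed above for the first training clip and the first validation clip
clip_lengths = {'first_train_clip': 51232, 'first_val_clip': 167952}

for name, n_samples in clip_lengths.items():
    print(f'{name}: {n_samples} samples = {n_samples / sampling_rate:.3f} s')
# first_train_clip: 51232 samples = 3.202 s
# first_val_clip: 167952 samples = 10.497 s
```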


In [None]:
#load processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

In [None]:
#process inputs appropriately
train_inputs = processor(audio_clips_train, return_tensors="pt", padding="longest", sampling_rate=sampling_rate)
test_inputs = processor(audio_clips_test, return_tensors="pt", padding="longest", sampling_rate=sampling_rate)
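Because the processor is called on each whole list at once with padding="longest", every training clip is padded to the length of the longest training clip, and likewise for validation. Below is a simplified pure-Python sketch of that padding step, assuming right-padding with 0.0 (the Wav2Vec2 feature extractor's default padding value). It is an illustration, not a re-implementation: the real feature extractor also normalizes each clip when do_normalize is enabled and only returns an attention mask when configured to. pad_to_longest is a hypothetical name.

```python
def pad_to_longest(clips, padding_value=0.0):
    '''Right-pad each clip to the length of the longest one.

    Returns (padded_clips, attention_mask); mask entries are 1 for real samples, 0 for padding.
    '''
    longest = max(len(clip) for clip in clips)
    padded, mask = [], []
    for clip in clips:
        n_pad = longest - len(clip)
        padded.append(list(clip) + [padding_value] * n_pad)
        mask.append([1] * len(clip) + [0] * n_pad)
    return padded, mask

padded, mask = pad_to_longest([[0.1, 0.2, 0.3], [0.5]])
print(padded)  # [[0.1, 0.2, 0.3], [0.5, 0.0, 0.0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]
```

With clip lengths ranging from under a second to ten or more seconds, padding the whole set to a single length means one long clip inflates the memory used by every short one; padding per batch (for example, in a data collator) is a common way around this.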

# Train the model
Now that all of our inputs are ready, let's try to train the model!

In [None]:
#number of classes for the classification head
no_classes = len(label_dict)

In [None]:
#Create custom Datasets Class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
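        #the encodings are already tensors, so index them directly instead of re-wrapping with torch.tensor
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)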
        return item

    def __len__(self):
        return len(self.labels)

#Create datasets from encodings
train_dataset = CustomDataset(train_inputs, arr_df.query('split==0')['i_label'].tolist())
val_dataset = CustomDataset(test_inputs, arr_df.query('split==1')['i_label'].tolist())

## Create model for task

In [None]:
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=no_classes, id2label=label_dict, label2id=rev_label_dict)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForSequenceClassification: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.weight', 'projector.bias', 'wav2vec2.masked_spec_embed', 'classifier.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be 

The warning above is expected, and we're happy to see it. The checkpoint's speech-recognition head (lm_head) is discarded, and the new sequence classification layers (projector and classifier) are randomly initialized, so their weights are meaningless until we train them. Perfect!
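Our understanding of the Hugging Face implementation (worth checking against the transformers source for the installed version) is that these new layers form a small head: the encoder's frame-level hidden states are projected to a smaller size (classifier_proj_size, 256 by default), averaged over time, and mapped to one logit per label. If no attention mask reaches the model, which we believe is the default for the base-960h feature extractor, the time average also includes frames computed from zero padding, which can dilute short clips padded to the longest one. A minimal NumPy sketch of that computation for one clip, with random stand-in weights rather than the library's code; classify_clip is a hypothetical name:

```python
import numpy as np

def classify_clip(hidden_states, frame_mask, proj_w, proj_b, cls_w, cls_b):
    '''Sketch of a wav2vec2-style classification head for a single clip.

    hidden_states: (frames, hidden) encoder outputs
    frame_mask:    (frames,) array with 1 for real frames and 0 for padded frames, or None
    proj_w, proj_b: projector weights, shapes (hidden, proj) and (proj,)
    cls_w, cls_b:   classifier weights, shapes (proj, num_labels) and (num_labels,)
    '''
    projected = hidden_states @ proj_w + proj_b  # (frames, proj)
    if frame_mask is None:
        pooled = projected.mean(axis=0)  # every frame counts, padding included
    else:
        weights = frame_mask[:, None].astype(projected.dtype)
        pooled = (projected * weights).sum(axis=0) / weights.sum()  # mean over real frames only
    return pooled @ cls_w + cls_b  # (num_labels,) logits

rng = np.random.default_rng(0)
hidden = rng.normal(size=(5, 4))  # 5 frames, of which the last 2 are padding
params = (rng.normal(size=(4, 3)), rng.normal(size=3), rng.normal(size=(3, 2)), rng.normal(size=2))
print(classify_clip(hidden, np.array([1, 1, 1, 0, 0]), *params))
print(classify_clip(hidden, None, *params))  # differs: padded frames enter the average
```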

## Setup and model training

In [None]:
#set parameters around training
training_args = TrainingArguments("test_trainer",
                                  num_train_epochs = 3,
                                  logging_strategy='epoch', 
                                  evaluation_strategy='epoch',
                                  per_device_train_batch_size=3,
                                  per_device_eval_batch_size=3,
                                  report_to='all'
                                 )

#define the metric; accuracy is only a placeholder, since it can be misleading if the classes are imbalanced
metric = load_metric("accuracy")

#function to calculate metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
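If the four labels turn out to be imbalanced, a model that mostly predicts the majority class can still score well on accuracy. Macro-averaged F1, which weights each class equally, is one common alternative. A minimal NumPy sketch of a drop-in compute_metrics replacement; compute_metrics_macro_f1 is a hypothetical name, and skipping classes that appear in neither the labels nor the predictions is one choice among several:

```python
import numpy as np

def compute_metrics_macro_f1(eval_pred):
    '''Accuracy plus macro-averaged F1 from a (logits, labels) pair.'''
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    f1_scores = []
    for cls in np.union1d(labels, predictions):  # classes seen in labels or predictions
        tp = np.sum((predictions == cls) & (labels == cls))
        fp = np.sum((predictions == cls) & (labels != cls))
        fn = np.sum((predictions != cls) & (labels == cls))
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        f1_scores.append(2 * tp / denom if denom else 0.0)

    return {'accuracy': float(np.mean(predictions == labels)),
            'macro_f1': float(np.mean(f1_scores))}

# predictions are [0, 1, 1, 2] against labels [0, 0, 1, 2]
logits = np.array([[2.0, 0, 0, 0], [0, 3.0, 0, 0], [0, 2.0, 0, 0], [0, 0, 5.0, 0]])
print(compute_metrics_macro_f1((logits, [0, 0, 1, 2])))
```

It could be passed to the Trainer as compute_metrics=compute_metrics_macro_f1 in place of the accuracy-only function.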

In [None]:
#train the model
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=processor,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)
trainer.train()

***** Running training *****
  Num examples = 167
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 63
  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


| Step | Training Loss |
|---|---|
| 21 | 1.3264 |
| 42 | 1.239 |
| 63 | 1.1935 |

Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=63, training_loss=1.2529784005785745, metrics={'train_runtime': 38.9867, 'train_samples_per_second': 12.851, 'train_steps_per_second': 1.616, 'total_flos': 3.905273132082662e+16, 'train_loss': 1.2529784005785745, 'epoch': 3.0})
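The step count in the log follows from the batch size. With 167 training clips and the per-device batch size of 8 that the log reports, each epoch has ceil(167 / 8) = 21 batches, so 3 epochs give the 63 optimization steps shown, with the loss logged every 21 steps. Note that the log's batch size differs from the per_device_train_batch_size=3 in the TrainingArguments cell; at 3, the same arithmetic would give ceil(167 / 3) × 3 = 168 steps. A small sketch of this arithmetic, assuming a single device, no gradient accumulation, and that the final partial batch is kept; total_optimization_steps is a hypothetical helper:

```python
import math

def total_optimization_steps(num_examples, batch_size, num_epochs):
    '''Optimizer steps when the last partial batch of each epoch is kept (no drop_last).'''
    batches_per_epoch = math.ceil(num_examples / batch_size)
    return batches_per_epoch * num_epochs

print(total_optimization_steps(167, 8, 3))  # 63, matching the log
print(total_optimization_steps(167, 3, 3))  # 168
```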

In [None]:
#evaluate on the held-out validation clips (the 40 examples reported below)
trainer.evaluate(val_dataset)

***** Running Evaluation *****
  Num examples = 40
  Batch size = 8
  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


{'eval_loss': 1.1072285175323486,
 'eval_accuracy': 0.55,
 'eval_runtime': 2.1104,
 'eval_samples_per_second': 18.954,
 'eval_steps_per_second': 2.369,
 'epoch': 3.0}

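Well! This is pretty exciting: we can train the model end to end, and the training loss falls from 1.33 to 1.19 over three epochs. The performance, on the other hand, is poor: 55% accuracy on the 40 validation clips, with a validation loss of 1.11. There are several ways I think this could be remedied, the first of which is training for more epochs. We'll take a look!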