## Speaker Identification using Whisper
* Train a classification model to identify if the speaker in a audio segment is lex or not

### Data:
* Expecting training audio clips in data/audio_dataset. Download from [here](https://drive.google.com/drive/folders/1LkR8oskSMNFo-YV-YOGGPucr-J0nEk9W?usp=share_link)

In [29]:
import torch
from pathlib import Path
import os
import sys
import pandas as pd
from tqdm import tqdm

print('is cuda available:', torch.cuda.is_available())

# Add whisper repo to path to import
repo_dir = Path(os.getcwd()).parents[0]/'whisper'
sys.path.append(str(repo_dir))
import whisper


is cuda available: True


### Load Whisper model

In [6]:
model = whisper.load_model("small.en")


### Labelled dataset

* Contains start and end defining segment of the audio clip
* audio_name is the name of the podcast. The podcasts audio files should be present in data/lex_podcasts
* `is_lex` is the label. 1 if speaker from start time to end time of the audio clip is lex. 0 if not

#### Data augmentation:
* All samples with audio_idx < 0 are auto generated based on keywords. Choose to use this however in training, since all audio clips will share a very similar spectral pattern

In [32]:
audio_dataset_dir = Path('data/audio_dataset')
labelled_path = 'data/labelled_dataset.csv'
df = pd.read_csv(labelled_path)

df.head()

Unnamed: 0,start,end,text,fname,audio_name,audio_idx,is_lex
0,02:49:11.280,02:49:15.120,"And, you know, some people also ask, are you ...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",0,0.0
1,02:20:14.140,02:20:17.260,I still do that often.,episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",1,1.0
2,00:19:15.360,00:19:17.320,things that you put into context of GPT.,episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",2,0.0
3,02:45:11.760,02:45:16.000,"and that also gives, you know, huge perspecti...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",3,0.0
4,01:33:44.600,01:33:49.160,"You, it's often the way how it works is you o...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",4,0.0


In [33]:
podcasts = list(df['audio_name'].unique())
print('Num audio clips:', len(df))
print(f'Num unique podcasts: {len(podcasts)}')

print(f'\n\nPostcasts containing the most tagged clips\n{df["audio_name"].value_counts().head(8)}')
print(f'\n\nPostcasts containing the least tagged clips\n{df["audio_name"].value_counts().tail(2)}')

Num audio clips: 698
Num unique podcasts: 68


Postcasts containing the most tagged clips
Elon Musk： Neuralink, AI, Autopilot, and the Pale Blue Dot ｜ Lex Fridman Podcast #49                   65
Ray Dalio： Principles, the Economic Machine, AI & the Arc of Life ｜ Lex Fridman Podcast #54            21
Judea Pearl： Causal Reasoning, Counterfactuals, and the Path to AGI ｜ Lex Fridman Podcast #56          21
Dmitry Korkin： Computational Biology of Coronavirus ｜ Lex Fridman Podcast #90                          21
Jeremy Howard： fast.ai Deep Learning Courses and Research ｜ Lex Fridman Podcast #35                    21
Cumrun Vafa： String Theory ｜ Lex Fridman Podcast #204                                                  21
Po-Shen Loh： Mathematics, Math Olympiad, Combinatorics & Contact Tracing ｜ Lex Fridman Podcast #183    20
Jim Keller： Moore's Law, Microprocessors, and First Principles ｜ Lex Fridman Podcast #70               20
Name: audio_name, dtype: int64


Postcasts containing the leas

## Get Whisper Embeddings
* Requires GPU to finish fast. If it takes too long on CPU, consider using a smaller Whisper model 

In [7]:
if torch.cuda.is_available():
    model = model.cuda()
    
if not os.path.exists(audio_dataset_dir):
    if not audio_dataset_dir.exists():
        raise ValueError('Expecting audio clips data in ', audio_dataset_dir)
        
filenames = df['audio_idx'].apply(lambda x: str(x) + '.mp3')
audio_paths  = [audio_dataset_dir/filename for filename in filenames]

idx_to_path = {idx: path for idx, path in enumerate(audio_paths)}

hidden_l1 = []
hidden_l2 = []
hidden_last = []

metadata_outputs = []
for audio_path in tqdm(audio_paths, total=len(audio_paths), disable=False):
    
    assert os.path.exists(audio_path)

    # load audio
    audio = whisper.load_audio(str(audio_path))
    audio = whisper.pad_or_trim(audio)
    
    # create mel spectogram input for whisper encoder
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    mel = mel[None, :, :]
    
    # forward pass through encoder
    with torch.no_grad():
        _ = model.embed_audio(mel)
    del _
    
    # get various hidden states of enocder
    hidden_l1.append(model.encoder.encoder_out1.cpu())
    hidden_l2.append(model.encoder.encoder_out2.cpu())
    hidden_last.append(model.encoder.encoder_out_last.cpu())
    
    
    metadata_outputs.append({"metadata": None, "audio_path": audio_path.name})


100%|████████████████████████████████████████████████████████████████████████████████| 648/648 [01:43<00:00,  6.27it/s]


In [9]:
df2 = pd.concat([df, pd.DataFrame(metadata_outputs)], axis=1)
df2.head()

Unnamed: 0,start,end,text,fname,audio_name,audio_idx,is_lex,metadata,audio_path
0,02:49:11.280,02:49:15.120,"And, you know, some people also ask, are you ...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",0,0.0,,0.mp3
1,02:20:14.140,02:20:17.260,I still do that often.,episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",1,1.0,,1.mp3
2,00:19:15.360,00:19:17.320,things that you put into context of GPT.,episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",2,0.0,,2.mp3
3,02:45:11.760,02:45:16.000,"and that also gives, you know, huge perspecti...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",3,0.0,,3.mp3
4,01:33:44.600,01:33:49.160,"You, it's often the way how it works is you o...",episode_215,"Wojciech Zaremba： OpenAI Codex, GPT-3, Robotic...",4,0.0,,4.mp3


In [None]:
# import pickle

# for checkpoint_dir in ['data', 'data/checkpoint']:
#     checkpoint_dir = Path(checkpoint_dir)
    
#     if not checkpoint_dir.exists():
#         checkpoint_dir.mkdir()
        
# df.to_csv(checkpoint_dir/'df_checkpoint.csv', index=False)

# with open(checkpoint_dir/'vector_outputs2.pkl', 'wb') as f:
#     pickle.dump(vector_outputs2, f)
    
# with open(checkpoint_dir/'vector_outputs.pkl', 'wb') as f:
#     pickle.dump(vector_outputs, f)

### Create Feature Vectors for Classifier

In [46]:
def get_feature_vector(vector_outputs):
    """
    Get features to train classifier. 
    
    Concatate mean features across 3 time windows 
    
    """
    new_vectors = []
    for row in vector_outputs:
        out = [row[0, :500, :].mean(1).flatten(), row[0, 500:1000, :].mean(1).flatten(), row[0, 1000:, :].mean(1).flatten()]
        new_vectors.append(torch.cat(out, -1)[None, :])
        
    new_vectors = torch.cat(new_vectors, 0)
    return new_vectors

def get_feature_vector2(vector_outputs):
    """
    Get features to train classifier. 
    
    Mean of features across entire timewindow of a clip
    
    """
    new_vectors = []
    for row in vector_outputs:
        out = row[0, :, :].mean(0)
        new_vectors.append(out[None, :])
        
    new_vectors = torch.cat(new_vectors, 0)
    return new_vectors

def get_feature_vector3(vector_outputs):
    """
    Get features to train classifier. 
    
    Mean and std of features across entire timewindow of a clip
    
    """
    
    new_vectors = []
    for row in vector_outputs:
        out = row[0, :, :].mean(0)
        out2 = row[0, :, :].std(0)
        new_vectors.append(torch.cat([out[None, :], out2[None, :]], dim=1))
        
    new_vectors = torch.cat(new_vectors, 0)
    return new_vectors


new_vectors = get_feature_vector3(hidden_l1)


In [47]:
from collections import Counter
import numpy as np 

def cv_split(df, seed=42):
    """
    Split on speaker
    """
    
    import random
    
    randgen = random.Random(seed)
        
    subdf = df[df['audio_idx'] >= 0]
    
    # negative audio_idx are not manually labelled, heuristically labelled. 
    # Near 100% are accurate, but speech pattern is same so only use few samples
#     augmented_df = df[df['audio_idx'] < -1].sample(15)
    augmented_df = pd.DataFrame()
    
    # split based on speakerid 
    sources = list(subdf['audio_name'].unique())
    
    test_split = randgen.sample(sources, len(sources) // 4)
    train_split = list(set(sources).difference(test_split))

    test_df = df[df['audio_name'].isin(test_split)]
    train_df = df[df['audio_name'].isin(train_split)]
    
    # add heuristic data to labelled data
    train_df = pd.concat([train_df, augmented_df], axis=0)
                               
    
    return train_df, test_df
    
def generate_splits(df, num_splits=5):
    """
    Splitting on speaker ID randomly leads to very high class imbalance in test set - in some podcasts lex tags are very few. 
    The strategy is to keep splitting on random seeds until a split of 40%-60% is reached
    
    """
    
    cvs = []
    seen_seeds = set()
    for split in range(num_splits):
        seeds = np.random.randint(low=0, high=1000, size=(50,))
        for seed in seeds: 
            # cant use same seed again 
            if seed in seen_seeds:
                continue
            train_df, test_df = cv_split(df, seed=seed)
            counts = test_df['is_lex'].value_counts()
            counts = counts/counts.sum()
            if (counts.loc[1.0] >= 0.40) and (counts.loc[1.0] <= 0.60):
                print(f'Split found with ratio: {counts.to_dict()}: seed: {seed}')
                break 
                
        cvs.append((train_df, test_df))
        seen_seeds.add(seed)
        
    return cvs 

        
splits = generate_splits(df, num_splits=5)        


Split found with ratio: {0.0: 0.5942028985507246, 1.0: 0.4057971014492754}: seed: 800
Split found with ratio: {0.0: 0.5989010989010989, 1.0: 0.4010989010989011}: seed: 404
Split found with ratio: {0.0: 0.5786516853932584, 1.0: 0.42134831460674155}: seed: 11
Split found with ratio: {0.0: 0.5944444444444444, 1.0: 0.40555555555555556}: seed: 766


In [54]:
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, classification_report
from sklearn.exceptions import ConvergenceWarning
import warnings 
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn import svm

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)

def get_metrics(true, pred):
    f1, recall, precision, accuracy = f1_score(true, pred), recall_score(true, pred), precision_score(true, pred), accuracy_score(true, pred)
    
    return {'f1': f1, 'recall': recall, 'precision': precision, 'accuracy': accuracy}

def train_eval(X, splits):
    """
    Train model
    
    Inputs: 
        - X: All features X 
        - splits: List of tuples of (train_df, test_df)
    
    """
    test_metrics = []
    train_metrics = []
    fold_preds = []
    for train_df, test_df in splits:
        X_train, y_train = X[list(train_df.index), :], train_df['is_lex']
        X_test, y_test = X[list(test_df.index), :], test_df['is_lex']

        scalar = preprocessing.StandardScaler()
        
        # overfits fast with logistic regression
#         clf = LogisticRegression(random_state=0, max_iter=15, C=.7)
        clf = svm.SVC(kernel='rbf', C=.7)

        pipeline = Pipeline([('transformer', scalar), ('estimator', clf)])

        pipeline.fit(X_train, y_train)
        pred_train = pipeline.predict(X_train)
        pred_test = pipeline.predict(X_test)


        train_metrics.append(get_metrics(y_train, pred_train))
        test_metrics.append(get_metrics(y_test, pred_test))
        
        fold_df = test_df.copy()
        fold_df['preds'] = pred_test
        
        fold_preds.append(fold_df)
        
        
        

    train_metrics = pd.DataFrame(train_metrics)
    test_metrics = pd.DataFrame(test_metrics)

    display('Test stats ', test_metrics.describe().loc[[ 'mean','std']])
    display('Train stats ', train_metrics.describe().loc[[ 'mean','std']])
    
    return train_metrics, test_metrics, fold_preds



## Featurize and train

### Featurize:
* Each hidden state is of shape (batch_size, 1500, hidden_size). The 1500 is the hidden state across 1500 time periods. We need to summarize features across the 3 time windows and create a feature vector of size (batch_size, hidden_size)
* This what get_feature_vector functions do: 
    * `get_feature_vector3`: Performs the best. Creates mean and std features across time window. Concatenate mean and std across time. Creates feature of shape (batch_size, hidden_size + hidden_size).
    * `get_feature_vector2`: Creates mean features across time windows. (batch_size, hidden_size)
    * `get_feature_vector1`: Calculates mean features across 3 time windows of 1500 (1-500, 500-1000, 1000-1500). Creates feature shape of (batch_size, hidden_size * 3)
 
### Train and Predict
* Train classifier

In [58]:
# utilize different hidden states for training and evaluate
for X, hidden_name in zip([hidden_l1, hidden_l2, hidden_last], ['first_hidden', 'second_hidden', 'last_hidden']):
          
    # create features. 
    # X_use: (num_samples, 768 * 2)
    X_use = get_feature_vector3(X)
    
    print(f'\n----Metrics of output of {hidden_name} encoder block output----\n')
    train_metrics, test_metrics, pred_dfs = train_eval(X_use, splits)
    


----Metrics of output of first_hidden encoder block output----



'Test stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.705483,0.690793,0.736748,0.780618
std,0.076161,0.127613,0.075481,0.052567


'Train stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.880556,0.813446,0.960803,0.936389
std,0.019417,0.036626,0.008286,0.0067



----Metrics of output of second_hidden encoder block output----



'Test stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.798997,0.812523,0.793472,0.843412
std,0.039503,0.083739,0.063103,0.027118


'Train stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.933816,0.898089,0.972818,0.963249
std,0.009799,0.020933,0.006511,0.003601



----Metrics of output of last_hidden encoder block output----



'Test stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.376044,0.251682,0.917949,0.693209
std,0.146189,0.127624,0.092173,0.06515


'Train stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.740069,0.595155,0.994393,0.882272
std,0.080055,0.105802,0.007821,0.025725


### Predict
* Using hidden state at layer 2 performs the best. 

In [63]:

X_use = get_feature_vector3(hidden_l2)
train_metrics, test_metrics, pred_dfs = train_eval(X_use, splits)


'Test stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.798997,0.812523,0.793472,0.843412
std,0.039503,0.083739,0.063103,0.027118


'Train stats '

Unnamed: 0,f1,recall,precision,accuracy
mean,0.933816,0.898089,0.972818,0.963249
std,0.009799,0.020933,0.006511,0.003601


In [68]:
fold_idx = 0
foldk_preds = pred_dfs[fold_idx]

display('Metrics ', test_metrics.loc[fold_idx])

display('Preds', foldk_preds.head())

'Metrics '

f1           0.785047
recall       0.750000
precision    0.823529
accuracy     0.833333
Name: 0, dtype: float64

'Preds'

Unnamed: 0,start,end,text,fname,audio_name,audio_idx,is_lex,preds
19,02:04:05.240,02:04:10.240,"Yes, he manipulated her as well, lied to her,...",episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",20,1.0,0.0
20,00:50:15.200,00:50:19.520,"And in a sense, not really knowing what was g...",episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",21,0.0,0.0
21,01:41:24.240,01:41:25.240,Thank you.,episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",22,0.0,0.0
22,03:12:28.240,03:12:39.240,"This sounds not to mock people, but this soun...",episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",23,1.0,0.0
23,00:33:04.560,00:33:05.520,Alec Baldwin.,episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",24,1.0,0.0


#### List to audio_idx clip and explore the predictions

In [75]:
foldk_preds.sample(15)

Unnamed: 0,start,end,text,fname,audio_name,audio_idx,is_lex,preds
482,02:09:39.120,02:09:43.600,Long ago when you were baby Po or today,episode_183,"Po-Shen Loh： Mathematics, Math Olympiad, Combi...",494,1.0,1.0
420,00:10:53.140,00:10:55.100,How do you think about language?,episode_120,François Chollet： Measures of Intelligence ｜ L...,431,1.0,1.0
23,00:33:04.560,00:33:05.520,Alec Baldwin.,episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",24,1.0,0.0
479,01:29:44.120,01:29:45.680,where you know what a fraction is,episode_183,"Po-Shen Loh： Mathematics, Math Olympiad, Combi...",491,0.0,0.0
354,00:38:14.560,00:38:21.760,"you know, if there are millions or billions o...",episode_154,"Avi Loeb： Aliens, Black Holes, and the Mystery...",364,0.0,0.0
36,00:55:29.600,00:55:35.920,you'll be more tired in the afternoon. So if ...,episode_288,"Sarma Melngailis： Bad Vegan, Fraud, Prison, an...",37,0.0,0.0
310,00:12:49.400,00:12:52.400,but the fact that a virus can just take over,episode_090,Dmitry Korkin： Computational Biology of Corona...,317,1.0,1.0
577,01:12:11.520,01:12:15.080,Something is still pushing things outward.,episode_232,"Brian Greene： Quantum Gravity, The Big Bang, A...",594,0.0,0.0
351,02:15:01.680,02:15:07.600,like proposing different kind of hypotheses o...,episode_154,"Avi Loeb： Aliens, Black Holes, and the Mystery...",361,1.0,1.0
307,01:29:37.400,01:29:46.400,"And then the third is, you know, things go mu...",episode_090,Dmitry Korkin： Computational Biology of Corona...,314,1.0,1.0
