# Model Evaluations

Quantitative evaluations of all speech generation models (SG-U, SG-C, B2S-Uv, B2S-Cv, B2S-Ur) using the Frechet Audio Distance and Inception Score.

In [None]:
from multiprocessing import Pool
import os
import sys
from tqdm import tqdm
sys.path.append('..') # append code directory to path

from frechet_audio_distance import FrechetAudioDistance, load_audio_task
import numpy as np

from utils.inception_score import InceptionScore

In [2]:
# Set the CUDA_VISIBLE_DEVICES environment variable to 0 for this kernel.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [3]:
# Read the comma-separated word list into `words`.
words_path = '../../data/HP_VariaNTS_intersection.txt'
with open(words_path, 'r') as words_file:
    words = words_file.read().split(',')

We subclass [`FrechetAudioDistance`](https://github.com/gudgud96/frechet-audio-distance) to override how it loads audio. By default, it loads all audio files in a given directory. However, we want more control over which files are loaded, e.g. to compare artificial and real speakers. Therefore, the subclass takes a list of files instead.

In [4]:
class CustomFAD(FrechetAudioDistance):
    """Frechet Audio Distance computed over explicit lists of audio files.

    The parent class loads every audio file found in a directory. This subclass
    loads the files named in a list instead, so the caller decides exactly which
    files form the background set and which form the evaluation set.
    """

    def __init__(self, use_pca=False, use_activation=False, verbose=False, audio_load_worker=8):
        """Forward the configuration to ``FrechetAudioDistance``.

        The arguments are passed by keyword so that they cannot be bound to the
        wrong parent parameter if the parent's positional order differs.
        """
        super().__init__(
            use_pca=use_pca,
            use_activation=use_activation,
            verbose=verbose,
            audio_load_worker=audio_load_worker,
        )

    def load_audio_files_from_list(self, file_list):
        """Load every file in ``file_list`` with ``load_audio_task``.

        Loading runs in a pool of ``self.audio_load_worker`` processes. ``imap``
        yields results in input order, so the returned list lines up with
        ``file_list``. A progress bar is shown only when ``self.verbose`` is set.
        """
        with Pool(self.audio_load_worker) as pool:
            loaded = tqdm(
                pool.imap(load_audio_task, file_list),
                total=len(file_list),
                disable=not self.verbose,
            )
            # Consume the iterator while the pool is still alive.
            return list(loaded)

    def score(self, background_files, eval_files, store_embds=False):
        """Return the FAD between two sets of audio files.

        This mirrors the parent's ``score``, except that both sets are loaded
        with ``load_audio_files_from_list`` instead of the parent's directory
        loader. The parent's loader is a double-underscore (name-mangled) private
        method, which is why ``score`` is re-implemented here rather than the
        loader being overridden directly.

        Args:
            background_files: paths of the reference audio files.
            eval_files: paths of the audio files to evaluate.
            store_embds: if true, save both embedding arrays to
                ``embds_background.npy`` and ``embds_eval.npy`` in the current
                working directory.

        Returns:
            The value returned by ``calculate_frechet_distance`` for the
            embedding statistics of the two sets.

        Raises:
            ValueError: if either file list is empty.
            AssertionError: if either set yields no embeddings.
        """
        # Fail fast: reject empty inputs before loading audio and computing embeddings.
        if len(background_files) == 0:
            raise ValueError("background_files is empty")
        if len(eval_files) == 0:
            raise ValueError("eval_files is empty")

        audio_background = self.load_audio_files_from_list(background_files)
        audio_eval = self.load_audio_files_from_list(eval_files)

        embds_background = self.get_embeddings(audio_background)
        embds_eval = self.get_embeddings(audio_eval)

        if store_embds:
            np.save("embds_background.npy", embds_background)
            np.save("embds_eval.npy", embds_eval)

        assert len(embds_background) != 0 and len(embds_eval) != 0, \
            "no embeddings were computed for at least one of the two sets"

        mu_background, sigma_background = self.calculate_embd_statistics(embds_background)
        mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)

        return self.calculate_frechet_distance(
            mu_background,
            sigma_background,
            mu_eval,
            sigma_eval,
        )

In [5]:
# FAD scorer with use_pca, use_activation and verbose all disabled.
fad_options = dict(use_pca=False, use_activation=False, verbose=False)
frechet = CustomFAD(**fad_options)

Using cache found in /home/passch/.cache/torch/hub/harritaylor_torchvggish_master


For the Inception Score, we use a custom class that relies on the `sklearn` MLP we trained in `src/notebooks/speech_classifier.ipynb` (see the implementation in `src/utils/inception_score.py` for details).

In [7]:
# Inception Score scorer, pointed at the pickled speech-classifier pipeline.
speech_clf_path = '../../exp/speech_classifier/speech_clf_pipeline_variants_aug.pickle'
inception_score = InceptionScore(clf_path=speech_clf_path, verbose=False)

# VariaNTS-based Models

In [9]:
# Collect the path of every entry in the VariaNTS dataset directory
variants_path = "../../data/VariaNTS/VariaNTS_words_16kHz_HP_synth_aug_flattened_fixed-length"
variants_files = [
    os.path.join(variants_path, name)
    for name in os.listdir(variants_path)
]

As a reference score for the FAD, we compute the FAD between real and artificial speakers.

In [None]:
# Speaker IDs up to this value are treated as real speakers, higher ones as artificial.
MAX_REAL_SPEAKER_ID = 16


def speaker_id(path):
    """Return the integer speaker ID encoded in a VariaNTS file name.

    The first '_'-separated field of the file name is stripped of its leading
    character and the remainder is parsed as an integer. ``os.path.basename``
    is used so the result does not depend on the path separator.
    """
    return int(os.path.basename(path).split('_')[0][1:])


real_speakers = [fn for fn in variants_files if speaker_id(fn) <= MAX_REAL_SPEAKER_ID]
fake_speakers = [fn for fn in variants_files if speaker_id(fn) > MAX_REAL_SPEAKER_ID]
assert len(real_speakers) == len(fake_speakers), (
    f"expected equally many real and artificial files, "
    f"got {len(real_speakers)} and {len(fake_speakers)}"
)

fad_score = frechet.score(real_speakers, fake_speakers)
print('Real vs. Fake:', round(fad_score, 4))

In [None]:
# Inception Score for the full VariaNTS set and for each speaker group.
for label, files in (
    ('IS Real+Fake:', variants_files),
    ('IS Real:', real_speakers),
    ('IS Fake:', fake_speakers),
):
    inc_score = inception_score(files)
    print(label, round(inc_score, 4))

## SG-U (No augs)


In [16]:
# Generated waveforms of the SG-U model trained without augmentations.
eval_path = "../../exp/SG-U_v9_noaug/waveforms/1000/all"
eval_files = [
    os.path.join(eval_path, name)
    for name in os.listdir(eval_path)
]

In [28]:
# FAD of the generated files against the full VariaNTS set.
fad_score = frechet.score(background_files=variants_files, eval_files=eval_files)
round(fad_score, 4)

100%|██████████| 29920/29920 [00:08<00:00, 3569.44it/s]
100%|██████████| 880/880 [00:02<00:00, 338.91it/s]
100%|██████████| 29920/29920 [02:47<00:00, 178.50it/s]
100%|██████████| 880/880 [00:04<00:00, 181.63it/s]


23.352

In [17]:
# Inception Score of the files currently in `eval_files`, displayed rounded to four decimals.
inc_score = inception_score(eval_files)
round(inc_score, 4)

100%|██████████| 880/880 [00:19<00:00, 44.46it/s]


3.0851

## SG-U

In [18]:
# Generated waveforms of the SG-U model.
eval_path = "../../exp/SG-U_v9/waveforms/230/all"
eval_files = [
    os.path.join(eval_path, name)
    for name in os.listdir(eval_path)
]

In [31]:
# FAD of the generated files against the full VariaNTS set.
fad_score = frechet.score(background_files=variants_files, eval_files=eval_files)
round(fad_score, 4)

100%|██████████| 29920/29920 [00:09<00:00, 3096.30it/s]
100%|██████████| 880/880 [00:00<00:00, 3516.14it/s]
100%|██████████| 29920/29920 [02:47<00:00, 178.69it/s]
100%|██████████| 880/880 [00:04<00:00, 181.63it/s]


23.3348

In [19]:
# Inception Score of the files currently in `eval_files`, displayed rounded to four decimals.
inc_score = inception_score(eval_files)
round(inc_score, 4)

100%|██████████| 880/880 [00:18<00:00, 47.83it/s]


7.0429

## SG-C

In [14]:
# Generated waveforms of the SG-C model.
# NOTE(review): absolute, machine-specific path, unlike the relative '../../exp' paths
# used in the SG-U sections - confirm this is intentional.
eval_path = "/home/passch/exp/ClassCond-PT-v3_h256_d36_T200_betaT0.02_L1000_cond/waveforms/180/all"
eval_files = [
    os.path.join(eval_path, name)
    for name in os.listdir(eval_path)
]

In [34]:
# FAD of the generated files against the full VariaNTS set.
fad_score = frechet.score(background_files=variants_files, eval_files=eval_files)
round(fad_score, 4)

100%|██████████| 29920/29920 [00:07<00:00, 3853.32it/s]
100%|██████████| 880/880 [00:02<00:00, 323.01it/s]
100%|██████████| 29920/29920 [02:54<00:00, 171.13it/s]
100%|██████████| 880/880 [00:04<00:00, 181.17it/s]


23.3266

In [15]:
# Inception Score of the files currently in `eval_files`, displayed rounded to four decimals.
inc_score = inception_score(eval_files)
round(inc_score, 4)

100%|██████████| 880/880 [00:12<00:00, 69.25it/s]


13.5953

## B2S-Cv (Brain- & Class-conditional Finetuning)

In [16]:
# Generated waveforms of the B2S-Cv model, listed separately for the train and val splits.
# NOTE(review): absolute, machine-specific path - confirm this is intentional.
eval_path = "/home/passch/exp/BrainClassCond-FT-VariaNTS-v9_h256_d36_T200_betaT0.02_L1000_cond/waveforms/800"
train_dir = os.path.join(eval_path, "train")
val_dir = os.path.join(eval_path, "val")

eval_files_train = [os.path.join(train_dir, name) for name in os.listdir(train_dir)]
eval_files_val = [os.path.join(val_dir, name) for name in os.listdir(val_dir)]
eval_files = eval_files_train + eval_files_val

In [36]:
# FAD against the full VariaNTS set for: all generated files, train split, val split.
for files in (eval_files, eval_files_train, eval_files_val):
    fad_score = frechet.score(variants_files, files)
    print(round(fad_score, 4))

100%|██████████| 29920/29920 [00:09<00:00, 3026.37it/s]
100%|██████████| 1008/1008 [00:03<00:00, 314.91it/s]
100%|██████████| 29920/29920 [02:56<00:00, 169.17it/s]
100%|██████████| 1008/1008 [00:05<00:00, 178.00it/s]


23.3273


100%|██████████| 29920/29920 [00:09<00:00, 3107.94it/s]
100%|██████████| 848/848 [00:00<00:00, 3069.81it/s]
100%|██████████| 29920/29920 [02:49<00:00, 176.52it/s]
100%|██████████| 848/848 [00:04<00:00, 180.13it/s]


23.3276


100%|██████████| 29920/29920 [00:09<00:00, 3300.39it/s]
100%|██████████| 160/160 [00:00<00:00, 3103.76it/s]
100%|██████████| 29920/29920 [02:48<00:00, 177.25it/s]
100%|██████████| 160/160 [00:00<00:00, 184.41it/s]


23.3257


In [17]:
# Inception Score for: all generated files, train split, val split.
for files in (eval_files, eval_files_train, eval_files_val):
    inc_score = inception_score(files)
    print(round(inc_score, 4))

100%|██████████| 1008/1008 [00:18<00:00, 54.92it/s]


12.4959


100%|██████████| 848/848 [00:07<00:00, 114.38it/s]


13.3938


100%|██████████| 160/160 [00:01<00:00, 89.91it/s]


3.9999


## B2S-Uv (Brain-conditional Finetuning, VariaNTS speech)

In [18]:
# Generated waveforms of the brain-conditional model, listed separately for the train and val splits.
# NOTE(review): absolute, machine-specific path - confirm this is intentional.
eval_path = "/home/passch/exp/BrainCond-FT-VariaNTS-v3_h256_d36_T200_betaT0.02_L1000_cond/waveforms/140"
train_dir = os.path.join(eval_path, "train")
val_dir = os.path.join(eval_path, "val")

eval_files_train = [os.path.join(train_dir, name) for name in os.listdir(train_dir)]
eval_files_val = [os.path.join(val_dir, name) for name in os.listdir(val_dir)]
eval_files = eval_files_train + eval_files_val

In [38]:
# FAD against the full VariaNTS set for: all generated files, train split, val split.
for files in (eval_files, eval_files_train, eval_files_val):
    fad_score = frechet.score(variants_files, files)
    print(round(fad_score, 4))

100%|██████████| 29920/29920 [00:09<00:00, 3092.47it/s]
100%|██████████| 1008/1008 [00:02<00:00, 349.12it/s]
100%|██████████| 29920/29920 [02:50<00:00, 175.65it/s]
100%|██████████| 1008/1008 [00:06<00:00, 165.83it/s]


23.3333


100%|██████████| 29920/29920 [00:09<00:00, 3301.65it/s]
100%|██████████| 848/848 [00:00<00:00, 3278.55it/s]
100%|██████████| 29920/29920 [02:51<00:00, 174.95it/s]
100%|██████████| 848/848 [00:04<00:00, 179.29it/s]


23.3334


100%|██████████| 29920/29920 [00:10<00:00, 2972.45it/s]
100%|██████████| 160/160 [00:00<00:00, 2637.80it/s]
100%|██████████| 29920/29920 [02:47<00:00, 178.71it/s]
100%|██████████| 160/160 [00:00<00:00, 167.25it/s]


23.3332


In [19]:
# Inception Score for: all generated files, train split, val split.
for files in (eval_files, eval_files_train, eval_files_val):
    inc_score = inception_score(files)
    print(round(inc_score, 4))

100%|██████████| 1008/1008 [00:18<00:00, 55.49it/s]


11.4359


100%|██████████| 848/848 [00:09<00:00, 93.70it/s] 


11.9151


100%|██████████| 160/160 [00:01<00:00, 94.95it/s] 

7.9639





# Reconstruction-based Model B2S-Ur

There is only one model, B2S-Ur, trained to reconstruct the actual speaker's voice recorded during the reading task. Because of this, its objective metrics cannot be compared with those of the other models above. Nonetheless, the code below computes the FAD for completeness, in case this becomes relevant at some point. Since we did not train a speech classifier on this speaker's data, there is no code to compute the Inception Score.

In [None]:
# Get files for the speech dataset recorded during reading of the Harry Potter chapter
hp_audio_path = "../../data/HP1_ECoG_conditional/sub-002_fixed-length"
hp_files = [os.path.join(hp_audio_path, fn) for fn in os.listdir(hp_audio_path) if fn.endswith('.wav')]

# with open('/home/passch/data/datasplits/HP1_ECoG_conditional/sub-002/train.csv', 'r') as f:
#     hp_train_files = f.read().split(',')
# with open('/home/passch/data/datasplits/HP1_ECoG_conditional/sub-002/val.csv', 'r') as f:
#     hp_val_files = f.read().split(',')
# hp_train_files = [os.path.join(hp_audio_path, fn) for fn in hp_train_files]
# hp_val_files = [os.path.join(hp_audio_path, fn) for fn in hp_val_files]

In [None]:
# For HP reference score, randomly separate the data 50 times and compute mean
# and variance of the FADs of all of them 
hp_fad_scores = []
for i in tqdm(range(50)):
    half = len(hp_files) // 2
    np.random.shuffle(hp_files)
    fad_score = frechet.score(hp_files[:half], hp_files[half:])
    hp_fad_scores.append(fad_score)
print(np.mean(hp_fad_scores), np.std(hp_fad_scores))
hp_files = sorted(hp_files)

100%|██████████| 50/50 [01:25<00:00,  1.71s/it]

0.5854788612242084 0.06481413255675972





In [29]:
# List the model-generated files of the train and val splits
eval_path = "../../exp/B2S-UR_v5/waveforms/70"
train_dir = os.path.join(eval_path, "train")
val_dir = os.path.join(eval_path, "val")

eval_files_train = [os.path.join(train_dir, name) for name in os.listdir(train_dir)]
eval_files_val = [os.path.join(val_dir, name) for name in os.listdir(val_dir)]
eval_files = eval_files_train + eval_files_val

In [29]:
# FAD against the HP recordings for: all generated files, train split, val split.
for files in (eval_files, eval_files_train, eval_files_val):
    fad_score = frechet.score(hp_files, files)
    print(round(fad_score, 4))

6.0585
6.0584
6.059
