After submitting all jobs with `source slurm/whisper_decode_video_slurm_wrapper.sh`, use this notebook to print the results of all decoding runs. It loads the decoding WER / BLEU scores and prints them in a convenient table.

In [24]:
import os
def print_results(results, beam, modalities=None, noises=None):
    """Print the clean / noisy score table and non-English averages for one beam size.

    Args:
        results: Nested dict indexed as
            ``results[beam][lang][modality][str(noise)]`` holding numeric scores.
        beam: Beam-size key of ``results`` to report.
        modalities: Sequence whose first entry is the audio-only modality key and
            whose second entry is the audio-visual key. When omitted, the
            notebook-level ``modalities`` variable is used.
        noises: Sequence whose first entry is the noise key reported as "clean"
            and whose second entry is the one reported as "babble". When omitted,
            the notebook-level ``noises`` variable is used.

    Returns:
        None. Everything is written to stdout with ``print``.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    if modalities is None:
        modalities = globals()['modalities']
    if noises is None:
        noises = globals()['noises']

    print("Beam size: {}".format(beam))
    per_lang = results[beam]
    languages = list(per_lang.keys())
    audio, audio_visual = modalities[0], modalities[1]
    clean, noisy = str(noises[0]), str(noises[1])

    # (row label, average label, modality key, noise key) in printing order.
    # The average labels repeat on purpose: within each pair the first line is
    # audio-only and the second is audio-visual.
    rows = [
        ('Audio-Clean', 'Avg clean non En', audio, clean),
        ('Audio-Visual-Clean', 'Avg clean non En', audio_visual, clean),
        ('Audio-Babble-LRS3', 'Avg noisy non En', audio, noisy),
        ('Audio-Visual-Babble-LRS3', 'Avg noisy non En', audio_visual, noisy),
    ]

    # One space-separated line per row, with a trailing space after every value.
    for label, _, modality, noise in rows:
        values = [per_lang[lang][modality][noise] for lang in languages]
        print(label + ' ' + ''.join(str(val) + ' ' for val in values))

    # Average over every language except English, selected by language code
    # rather than by position so the result does not depend on dict ordering.
    for _, avg_label, modality, noise in rows:
        non_en = [per_lang[lang][modality][noise] for lang in languages if lang != 'en']
        if non_en:
            average = round(sum(non_en) / len(non_en), 1)
        else:
            # No non-English language to average over (e.g. an English-only table).
            average = 'n/a'
        print("{}: {}".format(avg_label, average))

In [25]:
import os
# Load the score of every (beam, language, modality, noise) decoding run for one
# checkpoint and print them as a table.

# root = '../decode/models/checkpoint/'

root = '../decode/models/'
checkpoint = 'whisper-flamingo_multi-all_small.pt'

# ASR zero-shot
# root = '../decode/'
# checkpoint = 'large-v2'
# checkpoint = 'medium'
# checkpoint = 'small'

# fixed
# langs = ['en', 'ar', 'de', 'el', 'es', 'fr', 'it', 'pt', 'ru']
langs = ['en', 'es', 'fr', 'it', 'pt',]
# langs = ['en', 'ar', 'de', 'el', 'ru']

noises = [1000, 0]  # [clean, noisy]: print_results reports noises[0] as clean, noises[1] as babble
modalities = ['asr', 'avsr']
# beams = [1]
beams = [5]
visible = 0 # full eval set
noise_fn = 'lrs3'

# Name of the score file written by each decoding run.
# file = 'bleu.368862'
file = 'wer.368862'
# Number of leading characters to drop from the first line before parsing the score.
# NOTE(review): presumably the line looks like "WER: <value>" -- confirm against a real file.
# prefix = 5 if lang == 'en' else 6
prefix = 5

# Every slot starts at 0, so a run whose score file is missing or unparsable
# shows up as 0 in the printed table.
results = {beam: {lang: {modality: {str(noise): 0 for noise in noises} for modality in modalities} for lang in langs} for beam in beams}
for beam in beams:
    for lang in langs:
        for noise in noises:
            for modality in modalities:
                score_path = os.path.join(
                    root, checkpoint, lang, 'test', modality,
                    'snr-{}'.format(noise), 'visible-{}'.format(visible),
                    'beam-{}'.format(beam), noise_fn, file,
                )
                try:
                    with open(score_path) as f:
                        first_line = f.readline().strip('\n')
                    results[beam][lang][modality][str(noise)] = round(float(first_line[prefix:]), 2)
                except (OSError, ValueError):
                    # Best effort: the run is missing (OSError) or its first line is
                    # not a number (ValueError). Anything else is a real bug and
                    # should surface instead of silently producing zeros.
                    continue

print('Languages ', end='')
for lang in langs:
    print(lang + ' ', end='')
print()
for beam in beams:
    print_results(results, beam)

Languages en es fr it pt 
Beam size: 5
Audio-Clean 0 0 0 0 0 
Audio-Visual-Clean 4.21 9.56 13.77 12.74 12.87 
Audio-Babble-LRS3 0 0 0 0 0 
Audio-Visual-Babble-LRS3 8.73 33.61 31.94 41.19 42.78 
Avg clean non En: 0.0
Avg clean non En: 12.2
Avg noisy non En: 0.0
Avg noisy non En: 37.4


# Noise Type Analysis

In [26]:
# Load the score of every (noise type, beam, language, modality, SNR) decoding run
# for one checkpoint into a nested dict and display it.

# root = '../decode/models/checkpoint/'
# checkpoint = 'whisper-flamingo_medium_multi-all_normalized_b0.6'

root = '../decode/models/'
# checkpoint= 'whisper_multi-all_medium.pt'
# checkpoint = 'whisper_multi-all_small.pt'
# checkpoint = 'whisper-flamingo_multi-all_small.pt'
checkpoint = 'whisper-flamingo_multi-all_medium.pt'

# ASR zero-shot
# checkpoint = 'medium'
# checkpoint = 'small'
# root = '../decode/'

# Only used by the commented-out compute_wer(...) alternative inside the loop.
normalizer = 'fairseq'
langs = ['en', 'es', 'fr', 'it', 'pt',]
noises = [-10, -5, 0, 5, 10]  # SNR values; each selects an 'snr-{}' directory
noise_fns = ['lrs3', 'muavic', 'babble', 'speech','music', 'noise', ]
modalities = ['asr', 'avsr']
beams = [5]
visible = 0 # full set

# Name of the score file written by each decoding run.
# file = 'wer.368862' if lang == 'en' else 'bleu.368862'
file = 'wer.368862'
# Number of leading characters to drop from the first line before parsing the score.
# NOTE(review): presumably the line looks like "WER: <value>" -- confirm against a real file.
# prefix = 5 if lang == 'en' else 6
prefix = 5

# Every slot starts at 0, so a run whose score file is missing or unparsable
# stays 0 in the resulting dict.
results = {noise_fn: {beam: {lang: {modality: {str(noise): 0 for noise in noises} for modality in modalities} for lang in langs} for beam in beams} for noise_fn in noise_fns}
for noise_fn in noise_fns:
    print(noise_fn)
    for beam in beams:
        for lang in langs:
            for noise in noises:
                for modality in modalities:
                    wer_path = os.path.join(root, checkpoint, lang, 'test', modality, 'snr-{}'.format(noise), 'visible-{}'.format(visible), 'beam-{}'.format(beam), noise_fn, file)
                    # print(wer_path)
                    try:
                        with open(wer_path) as f:
                            first_line = f.readline().strip('\n')
                        results[noise_fn][beam][lang][modality][str(noise)] = round(float(first_line[prefix:]), 1)
                        # results[noise_fn][beam][lang][modality][str(noise)] = round(compute_wer(wer_path, normalizer, lang), 1)
                    except (OSError, ValueError):
                        # Best effort: the run is missing (OSError) or its first line
                        # is not a number (ValueError). Other errors are real bugs
                        # and are allowed to propagate.
                        continue
results

lrs3
muavic
babble
speech
music
noise


{'lrs3': {5: {'en': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 40.0, '-5': 25.7, '0': 7.5, '5': 3.8, '10': 3.4}},
   'es': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 91.2, '-5': 70.3, '0': 28.0, '5': 13.9, '10': 9.9}},
   'fr': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 98.1, '-5': 63.3, '0': 27.5, '5': 16.0, '10': 12.7}},
   'it': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 92.7, '-5': 73.3, '0': 35.2, '5': 18.0, '10': 12.6}},
   'pt': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 94.5, '-5': 74.3, '0': 36.0, '5': 20.0, '10': 13.7}}}},
 'muavic': {5: {'en': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 40.9, '-5': 32.2, '0': 8.7, '5': 4.0, '10': 3.6}},
   'es': {'asr': {'-10': 0, '-5': 0, '0': 0, '5': 0, '10': 0},
    'avsr': {'-10': 97.6, '-5': 77.8, '0': 32.2, '5': 14.4, '10': 10.1}},
   'fr': {'asr': {'-10

In [27]:
# Print, for each language, one row holding the scores of every noise type at
# every SNR (noise types outermost, SNRs innermost).
# languages = list(results[noise_fn][beam].keys())

print(checkpoint)
# modality = 'asr'
modality = 'avsr'
beam = 5
# Column header: the five SNR values, repeated once per noise type.
print('-10  -5  0   5   10  ' * 6)
for lang in langs:
    print(lang)
    cells = (
        '{} '.format(results[noise_fn][beam][lang][modality][str(noise)])
        for noise_fn in noise_fns
        for noise in noises
    )
    print(''.join(cells))

whisper-flamingo_multi-all_medium.pt
-10  -5  0   5   10  -10  -5  0   5   10  -10  -5  0   5   10  -10  -5  0   5   10  -10  -5  0   5   10  -10  -5  0   5   10  
en
40.0 25.7 7.5 3.8 3.4 40.9 32.2 8.7 4.0 3.6 39.7 22.1 6.1 3.7 3.6 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
es
91.2 70.3 28.0 13.9 9.9 97.6 77.8 32.2 14.4 10.1 91.0 0 26.2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
fr
98.1 63.3 27.5 16.0 12.7 107.2 75.1 30.7 17.1 12.7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
it
92.7 73.3 35.2 18.0 12.6 0 0 37.0 18.3 12.6 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
pt
94.5 74.3 36.0 20.0 13.7 100.4 79.7 38.6 20.1 14.1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 


# Test text normalization

In [17]:
import json
import editdistance
from whisper.normalizers import EnglishTextNormalizer, BasicTextNormalizer
from fairseq.scoring.wer import WerScorer, WerScorerConfig

def compute_wer(wer_path, normalizer, lang):
    """Recompute the word error rate (in percent) of one decoding run.

    The hypotheses and references are read from the ``wer.json`` file that sits
    next to the run's ``wer.368862`` file.

    Args:
        wer_path: Path to the run's ``wer.368862`` file. The file itself is not
            read; ``'wer.368862'`` is replaced by ``'wer.json'`` in the path.
        normalizer: Text normalization applied before scoring:
            ``'whisper'`` normalizes both sides with Whisper's text normalizer,
            ``'none'`` compares the whitespace-split raw strings, and any other
            value scores with fairseq's ``WerScorer`` (13a tokenizer,
            punctuation removed, lowercased).
        lang: Language code. Only consulted for ``'whisper'``: ``'en'`` selects
            ``EnglishTextNormalizer``, anything else ``BasicTextNormalizer``.

    Returns:
        The WER as a percentage. ``0.0`` is returned by the ``'whisper'`` and
        ``'none'`` paths when the references contain no words.
    """
    with open(wer_path.replace('wer.368862', 'wer.json'), 'r') as fp:
        data = json.load(fp)
    hypo = data['pred']
    refs = data['refs']

    if normalizer == 'whisper' or normalizer == 'none':
        if normalizer == 'whisper':
            std = EnglishTextNormalizer() if lang == 'en' else BasicTextNormalizer()
        else:
            std = None
        w_err, w_len = 0, 0
        for h, r in zip(hypo, refs):
            ref_words = (std(r) if std is not None else r).split()
            hyp_words = (std(h) if std is not None else h).split()
            w_err += editdistance.eval(ref_words, hyp_words)
            # The denominator must count the same (normalized) reference words the
            # edit distance was computed on; counting the raw reference words
            # would skew the rate whenever normalization changes the word count.
            w_len += len(ref_words)
        # Guard against an empty reference set instead of dividing by zero.
        return 100. * w_err / w_len if w_len else 0.0

    # Fairseq scoring: accumulate every pair, then score once at the end.
    scorer = WerScorer(
        WerScorerConfig(
            wer_tokenizer="13a",
            wer_remove_punct=True,
            wer_char_level=False,
            wer_lowercase=True
        )
    )
    for h, r in zip(hypo, refs):
        scorer.add_string(ref=r, pred=h)
    return scorer.score()



In [18]:
import os
# Recompute the WER of every (beam, language, modality, noise) decoding run for
# one checkpoint with the selected text normalizer, then print the table.

# root = '../decode/models/checkpoint/'
# checkpoint = 'whisper-flamingo_multi-all_small.pt'
root = '../decode/models/'
checkpoint='whisper_multi-all_small.pt'

# ASR zero-shot
# root = '../decode/'
# checkpoint = 'large-v2'
# checkpoint = 'medium'
# checkpoint = 'small'

# fixed
langs = ['en', 'ar', 'de', 'el', 'es', 'fr', 'it', 'pt', 'ru']
# langs = ['en', 'es', 'fr', 'it', 'pt',]
# langs = ['en', 'ar', 'de', 'el', 'ru']

noises = [1000, 0]  # [clean, noisy]: print_results reports noises[0] as clean, noises[1] as babble
modalities = ['asr', 'avsr']
beams = [5]
visible = 0 # full eval set
noise_fn = 'lrs3'
# normalizer = 'whisper'
normalizer = 'fairseq'

file = 'wer.368862'

# Every slot starts at 0, so a run that is missing or cannot be scored shows up
# as 0 in the printed table.
results = {beam: {lang: {modality: {str(noise): 0 for noise in noises} for modality in modalities} for lang in langs} for beam in beams}
for beam in beams:
    for lang in langs:
        for noise in noises:
            for modality in modalities:
                # NOTE: new decoding includes noise file name
                wer_path = os.path.join(root, checkpoint, lang, 'test', modality, 'snr-{}'.format(noise), 'visible-{}'.format(visible), 'beam-{}'.format(beam), noise_fn, file)
                # NOTE: old decoding doesn't include noise file name
                # wer_path = os.path.join(root, checkpoint, lang, 'test', modality, 'snr-{}'.format(noise), 'visible-{}'.format(visible), 'beam-{}'.format(beam), file)
                # Only runs that produced the score file are reported; the score
                # itself is recomputed by compute_wer from the sibling wer.json.
                if not os.path.isfile(wer_path):
                    continue
                try:
                    results[beam][lang][modality][str(noise)] = round(compute_wer(wer_path, normalizer, lang), 1)
                except (OSError, ValueError, KeyError, ZeroDivisionError):
                    # Best effort: wer.json is missing/unreadable (OSError), is not
                    # valid JSON (ValueError), lacks the expected keys (KeyError) or
                    # has no reference words (ZeroDivisionError). Anything else,
                    # e.g. compute_wer not being defined yet, should surface.
                    continue

print('Languages ', end='')
for lang in langs:
    print(lang + ' ', end='')
print()
print(checkpoint)
# print(noises)
# NOTE(review): machine-specific absolute path; the lookup below is skipped when it does not exist.
ckpt_root = '/usr/users/roudi/whisper-flamingo/models/checkpoint'
try:
    # If the checkpoint is a symlink, also print the name of the file it points to.
    print(os.readlink(os.path.join(ckpt_root, checkpoint)).split('/')[-1])
except OSError:
    # Not a symlink, or the path does not exist on this machine.
    pass
for beam in beams:
    print_results(results, beam)

Languages en ar de el es fr it pt ru 
whisper_multi-all_small.pt
Beam size: 5
Audio-Clean 3.9 73.4 26.4 18.5 9.5 13.8 12.8 12.8 21.2 
Audio-Visual-Clean 0 0 0 0 0 0 0 0 0 
Audio-Babble-LRS3 16.0 99.5 59.9 56.7 41.7 35.8 50.6 50.2 46.7 
Audio-Visual-Babble-LRS3 0 0 0 0 0 0 0 0 0 
Avg clean non En: 23.6
Avg clean non En: 0.0
Avg noisy non En: 55.1
Avg noisy non En: 0.0
