# Overview
- In this demo, we apply our audio-visual Whisper-Flamingo to a video not seen during training.
- We load a video and run decoding on both the original audio and the audio with added babble noise.
- We compare Whisper-Flamingo with the original audio-only Whisper and with our audio-only En-X Whisper (fine-tuned for En transcription and En-X translation).
- We show the transcription results in English and translation results in Russian. Our models support transcription in English (En) and En-X translation into 6 languages: Greek (El), Spanish (Es), French (Fr), Italian (It), Portuguese (Pt), and Russian (Ru).
- We use Whisper / Whisper-Flamingo Small for this demo; results will be stronger using the Large models.

# Setup
- Install the required packages and download the resources.
- This takes ~4 mins.
- NOTE: the older version of fairseq used here requires an older version of numpy. Colab may need to restart after installing numpy 1.22; if it does, just run the first cell again and the correct version will already be installed.

In [None]:
# NOTE: older version of fairseq requires older version of numpy - Colab may need to restart after installing numpy 1.22
# Just run this code again and the correct version will already be installed
!pip install numpy==1.22 tensorboard==2.9.1



In [None]:
# Verify the correct versions are loaded
import numpy; print(numpy.__version__)
import tensorboard; print(tensorboard.__version__)

1.22.0
2.9.1
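
The version check above can also be written as a small helper that reports what the running kernel has actually imported, which is what matters after a Colab restart. This is a minimal sketch; `loaded_version` and `version_matches` are hypothetical names, not part of the notebook or of any library.

```python
import importlib

def loaded_version(module_name):
    """Return the __version__ of an importable module, or None if it is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)

def version_matches(installed, pinned):
    """True if `installed` equals `pinned` or only adds components (1.22.0 vs 1.22)."""
    if installed is None:
        return False
    return installed == pinned or installed.startswith(pinned + ".")

# Pins copied from the install cell above.
PINS = {"numpy": "1.22", "tensorboard": "2.9.1"}

if __name__ == "__main__":
    for name, pinned in PINS.items():
        found = loaded_version(name)
        status = "ok" if version_matches(found, pinned) else "MISMATCH"
        print(f"{name}: loaded={found} pinned={pinned} -> {status}")
```

If a package reports a mismatch, restarting the runtime and re-running the install cell is the fix described in the note above.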


In [None]:
# !pip install uv # faster installs compared to pip, but some problems come up

In [None]:
# MuAViC instructions for fairseq https://github.com/facebookresearch/muavic
# Clone the "muavic" branch of av_hubert's repo
!git clone -b muavic https://github.com/facebookresearch/av_hubert.git
# Set the fairseq version
%cd av_hubert
!git submodule init
!git submodule update
%cd fairseq
!pip install --editable ./
# !uv pip install --system --editable ./

fatal: destination path 'av_hubert' already exists and is not an empty directory.
/content/av_hubert
Cloning into '/content/av_hubert/fairseq'...
Submodule path 'fairseq': checked out '272c4c5197250997148fb12c0db6306035f166a4'
/content/av_hubert/fairseq
Obtaining file:///content/av_hubert/fairseq
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Collecting hydra-core<1.1,>=1.0.7 (from fairseq==1.0.0a0+272c4c5)
  Downloading hydra_core-1.0.7-py3-none-any.whl (123 kB)
     123.8/123.8 kB 3.0 MB/s eta 0:00:00
Collecting omegaconf<2.1 (from fairseq==1.0.0a0+272c4c5)
  Downloading omegaconf-2.0.6-py3-none-any.whl (36 kB)
Collecting sacrebleu>=1.4.12

In [None]:
!pip install python-speech-features==0.6 # av-hubert
!pip install tiktoken # whisper
!pip install pytorch-lightning

# !uv pip install --system tiktoken # whisper
# !uv pip install --system pytorch-lightning
# !uv pip install --system torchmetrics

Collecting python-speech-features==0.6
  Downloading python_speech_features-0.6.tar.gz (5.6 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: python-speech-features
  Building wheel for python-speech-features (setup.py) ... done
  Created wheel for python-speech-features: filename=python_speech_features-0.6-py3-none-any.whl size=5870 sha256=21a03f6c6d29cb015f4758e8000870e3ba18c0869f70b576a671b1d910d283bc
  Stored in directory: /root/.cache/pip/wheels/5a/9e/68/30bad9462b3926c29e315df16b562216d12bdc215f4d240294
Successfully built python-speech-features
Installing collected packages: python-speech-features
Successfully installed python-speech-features-0.6
Collecting tiktoken
  Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
     1.8/1.8 MB 9.4 MB/s eta 0:00:00
Installing collected packages: tiktoken
Succe

In [None]:
%cd /content/
!git clone https://github.com/roudimit/whisper-flamingo.git
%cd whisper-flamingo
%pwd

In [None]:
import sys
sys.path.insert(0, '/content/av_hubert/fairseq')
import os
import numpy as np
import torch
from scipy.io import wavfile
import whisper
from utils import add_noise

In [None]:
# verify that we are using the local whisper
print(whisper.__file__)

/content/whisper-flamingo-dev/whisper/__init__.py


In [None]:
# download data and models
!wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/demo.tar.gz
!wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/noise.tar.gz
!wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whisper_en-x_small.pt

--2024-05-09 17:06:45--  https://data.csail.mit.edu/public-release-sls/whisper-flamingo/demo.tar.gz
Resolving data.csail.mit.edu (data.csail.mit.edu)... 128.52.131.233
Connecting to data.csail.mit.edu (data.csail.mit.edu)|128.52.131.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 100167680 (96M) [application/x-gzip]
Saving to: ‘demo.tar.gz’


2024-05-09 17:06:47 (76.3 MB/s) - ‘demo.tar.gz’ saved [100167680/100167680]

--2024-05-09 17:06:47--  https://data.csail.mit.edu/public-release-sls/whisper-flamingo/noise.tar.gz
Resolving data.csail.mit.edu (data.csail.mit.edu)... 128.52.131.233
Connecting to data.csail.mit.edu (data.csail.mit.edu)|128.52.131.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 766132 (748K) [application/x-gzip]
Saving to: ‘noise.tar.gz’


2024-05-09 17:06:47 (6.53 MB/s) - ‘noise.tar.gz’ saved [766132/766132]

--2024-05-09 17:06:47--  https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whispe

In [None]:
!tar -xf demo.tar.gz
# adjust the noise tsv files with the correct path to the noise
!tar -xf noise.tar.gz
!echo $(pwd)/noise/babble/muavic/babble_all.wav > ./noise/babble/muavic/test.tsv
!echo $(pwd)/noise/babble/lrs3/noise.wav > ./noise/babble/lrs3/test.tsv
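
The two `echo` lines above write a one-line manifest containing the absolute path of a noise file. A rough Python equivalent is sketched below for readers who prefer to avoid shell commands; `write_noise_manifest` is a hypothetical helper, not part of the whisper-flamingo repo.

```python
from pathlib import Path

def write_noise_manifest(noise_wav, manifest_path):
    """Write a one-line manifest holding the absolute path to a noise file.

    Mirrors `echo $(pwd)/<noise_wav> > <manifest_path>`; the noise file itself
    is not opened or validated here.
    """
    noise_wav = Path(noise_wav).absolute()
    manifest_path = Path(manifest_path)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.write_text(f"{noise_wav}\n")
    return noise_wav

if __name__ == "__main__":
    # Same paths as the shell commands above, relative to the current directory.
    write_noise_manifest("noise/babble/muavic/babble_all.wav",
                         "noise/babble/muavic/test.tsv")
    write_noise_manifest("noise/babble/lrs3/noise.wav",
                         "noise/babble/lrs3/test.tsv")
```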


# Process Video

In [None]:
from IPython.display import HTML
from base64 import b64encode
def play_video(video_path, width=200):
    # Embed the video as a base64 data URL so it plays inline in the notebook
    with open(video_path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(f"""
    <video width={width} controls>
        <source src="{data_url}" type="video/mp4">
    </video>
    """)

In [None]:
# Note: cv2, dlib, skvideo, tqdm, and the align_mouth helpers are imported in the
# optional pre-processing cell further down; run that cell before calling these functions.
def detect_landmark(image, detector, predictor):
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    rects = detector(gray, 1)
    coords = None  # stays None if no face is detected
    for rect in rects:  # if several faces are found, the last one wins
        shape = predictor(gray, rect)
        coords = np.zeros((68, 2), dtype=np.int32)
        for i in range(0, 68):
            coords[i] = (shape.part(i).x, shape.part(i).y)
    return coords

def preprocess_video(input_video_path, output_video_path, face_predictor_path, mean_face_path):
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(face_predictor_path)
    STD_SIZE = (256, 256)
    mean_face_landmarks = np.load(mean_face_path)
    stablePntsIDs = [33, 36, 39, 42, 45]
    videogen = skvideo.io.vread(input_video_path)
    frames = np.array([frame for frame in videogen])
    # Detect the 68 facial landmarks in every frame
    landmarks = []
    for frame in tqdm(frames):
        landmark = detect_landmark(frame, detector, predictor)
        landmarks.append(landmark)
    preprocessed_landmarks = landmarks_interpolate(landmarks)
    # Crop a 96x96 patch around the mouth landmarks (indices 48-67)
    rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
                      window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
    write_video_ffmpeg(rois, output_video_path, "/usr/bin/ffmpeg")
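
`detect_landmark` returns `None` for frames where no face is found, so the landmark list can contain gaps before `landmarks_interpolate` is called. The sketch below shows one plausible way to fill such gaps by linear interpolation. It is an illustration only: `fill_missing_landmarks` is a hypothetical name, and the actual `landmarks_interpolate` in `align_mouth` may behave differently.

```python
import numpy as np

def fill_missing_landmarks(landmarks):
    """Fill None entries by linear interpolation between the nearest detected frames.

    Leading and trailing gaps copy the nearest detected frame.
    Returns None if no frame has a detection at all.
    """
    valid = [i for i, lm in enumerate(landmarks) if lm is not None]
    if not valid:
        return None
    filled = [None if lm is None else np.asarray(lm, dtype=np.float64)
              for lm in landmarks]
    # Interior gaps: blend the two surrounding detections.
    for start, stop in zip(valid[:-1], valid[1:]):
        for t in range(start + 1, stop):
            w = (t - start) / (stop - start)
            filled[t] = (1 - w) * filled[start] + w * filled[stop]
    # Edge gaps: repeat the first / last detection.
    for t in range(valid[0]):
        filled[t] = filled[valid[0]].copy()
    for t in range(valid[-1] + 1, len(filled)):
        filled[t] = filled[valid[-1]].copy()
    return filled
```

Each entry is assumed to be a `(68, 2)` array of pixel coordinates, matching what `detect_landmark` produces.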

# Play video

This video is from the LRS3 test set, and we already processed it into the lip-based format. If you want to use your own video, uncomment the code below.

In [None]:
face_predictor_path = "demo/shape_predictor_68_face_landmarks.dat"
mean_face_path = "demo/20words_mean_face.npy"

origin_clip_path = "demo/lrs3_0gks6ceq4eQ_test_00007.mp4"
mouth_roi_path = "demo/demo_lrs3_roi.mp4"
# origin_clip_path = "demo/lrs3_0ZfSOArXbGQ_test_00003.mp4"
# mouth_roi_path = "demo/demo_lrs3_roi_2.mp4"

Uncomment this code to pre-process your own video. Note that this is slow since the face detection model runs on CPU.

In [None]:
# origin_clip_path = "demo/lrs3_0gks6ceq4eQ_test_00007.mp4" # change this line to your video path
# mouth_roi_path = "demo/processed.mp4"

# !pip install scikit-video
# import sys
# sys.path.insert(0, '/content/av_hubert/avhubert/preparation')
# from align_mouth import landmarks_interpolate, crop_patch, write_video_ffmpeg
# import dlib, cv2, os
# import numpy as np
# import skvideo
# import skvideo.io
# from tqdm import tqdm
# preprocess_video(origin_clip_path, mouth_roi_path, face_predictor_path, mean_face_path)

Original video:

In [None]:
play_video(origin_clip_path, width=300)

Video after face detection, normalization to the reference mean face, and cropping (used as input to AV-HuBERT):

In [None]:
play_video(mouth_roi_path, width=300)

# Add Babble Noise

In [None]:
import IPython
clean_input = whisper.load_audio(origin_clip_path)
print("Original input")
IPython.display.Audio(clean_input, rate=16000)

Original input


In [None]:
noise_fn = 'noise/babble/lrs3/noise.wav'
sample_rate, noise = wavfile.read(noise_fn)
print("Babble noise based on LRS3")
IPython.display.Audio(noise, rate=16000)

Babble noise based on LRS3


In [None]:
# Negative SNR values make the noise stronger; positive values make it weaker.
# SNR = 0
# SNR = -2.5
SNR = -5.0
noisy_input = add_noise(clean_input * 32768.0, [noise_fn], noise_snr=SNR).flatten().astype(np.float32) / 32768.0
print("Original input with babble noise added at {} SNR".format(SNR))
IPython.display.Audio(noisy_input, rate=16000)

Original input with babble noise added at -5.0 SNR
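
To make the SNR setting concrete, here is a minimal NumPy sketch of mixing noise into a clean signal at a target signal-to-noise ratio in dB. It is not the `add_noise` function from `utils`, whose internals are not shown here; `mix_at_snr` is a hypothetical helper that assumes a non-silent noise clip and tiles or trims it to the length of the clean signal.

```python
import numpy as np

def mix_at_snr(clean, noise, snr_db):
    """Return clean + scaled noise so that the mix has the requested SNR in dB."""
    clean = np.asarray(clean, dtype=np.float64)
    noise = np.asarray(noise, dtype=np.float64)
    # Tile or trim the noise so it covers the whole clean signal.
    reps = int(np.ceil(len(clean) / len(noise)))
    noise = np.tile(noise, reps)[: len(clean)]
    clean_power = np.mean(clean ** 2)
    noise_power = np.mean(noise ** 2)
    # Choose the scale so that 10*log10(clean_power / scaled_noise_power) == snr_db.
    scale = np.sqrt(clean_power / (noise_power * 10 ** (snr_db / 10)))
    return clean + scale * noise

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    t = np.arange(16000) / 16000.0
    clean = np.sin(2 * np.pi * 440 * t)   # 1 s synthetic tone at 16 kHz
    noise = rng.standard_normal(4000)     # shorter clip, gets tiled
    noisy = mix_at_snr(clean, noise, snr_db=-5.0)
    added = noisy - clean
    print(10 * np.log10(np.mean(clean ** 2) / np.mean(added ** 2)))
```

At -5 dB the noise carries roughly three times the power of the speech, which is why the audio-only models below struggle.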


# Transcribe audio with Whisper Small (original OpenAI weights)

In [None]:
import whisper
model = whisper.load_model("small")

100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 92.6MiB/s]


Whisper dropout rate : 0.0


In [None]:
def decode_audio(audio_input, model, lang="en"):
    # The original Whisper expects 30 s of audio, so pad (or trim) to that length
    audio = whisper.pad_or_trim(audio_input)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # Use fp16 only when a GPU is available
    options = whisper.DecodingOptions(fp16=torch.cuda.is_available(),
                                      language=lang, beam_size=1, without_timestamps=True)
    pred = whisper.decode(model, mel, options).text
    return pred

In [None]:
result = decode_audio(clean_input, model)
print("Transcribing original input : {}".format(result))

Transcribing original input : it's using past experience based on similar situations to try to make meaning.


In [None]:
result = decode_audio(noisy_input, model)
print("Transcribing noisy input : {}".format(result))

Transcribing noisy input : She is not the only one who is based on the sort of situation we are trying to do.


# Transcribe / Translate audio with Whisper En-X Small (ours, fine-tuned on LRS3 & MuAViC)

In [None]:
whisper_en_x_model = whisper.load_model("small")
state_dict = torch.load('whisper_en-x_small.pt', map_location=torch.device('cpu'))
state_dict = state_dict['state_dict']
state_dict_updated = {k[6:]: v  for k, v in state_dict.items()} # remove 'model.'
whisper_en_x_model.load_state_dict(state_dict_updated)

Whisper dropout rate : 0.0


<All keys matched successfully>
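
The `k[6:]` slice above drops the first six characters of every key, which assumes each key starts with `model.`. A slightly more defensive variant is sketched below; `strip_prefix` is a hypothetical helper and works on any mapping, so it can be tried without loading a checkpoint.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from every key.

    Raises KeyError if a key does not start with the prefix, instead of
    silently cutting off the first len(prefix) characters.
    """
    stripped = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            raise KeyError(f"key without expected prefix {prefix!r}: {key}")
        stripped[key[len(prefix):]] = value
    return stripped

if __name__ == "__main__":
    # Placeholder values stand in for tensors.
    demo = {"model.encoder.conv1.weight": 1, "model.decoder.ln.bias": 2}
    print(strip_prefix(demo))
```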

In [None]:
def decode_audio_en_x(audio_input, model, lang="en"):
    # Note: we don't pad the audio to 30s for en-x models
    mel = whisper.log_mel_spectrogram(audio_input).to(model.device)
    # Use fp16 only when a GPU is available
    options = whisper.DecodingOptions(fp16=torch.cuda.is_available(),
                                      language=lang, beam_size=1, without_timestamps=True)
    pred = whisper.decode(model, mel, options).text
    return pred
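
The difference between the two audio-only decode helpers is the padding step: the original Whisper is fed a fixed 30 s window (480,000 samples at 16 kHz), while the En-X models take the audio at its natural length. The NumPy sketch below illustrates what padding or trimming a 1-D waveform to that window means; `pad_or_trim_1d` is a hypothetical stand-in, not the library's `whisper.pad_or_trim`.

```python
import numpy as np

SAMPLE_RATE = 16000                       # Whisper's input sample rate
CHUNK_SECONDS = 30                        # length of the fixed input window
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS   # 480000 samples

def pad_or_trim_1d(audio, length=N_SAMPLES):
    """Zero-pad at the end, or cut, a 1-D waveform to exactly `length` samples."""
    audio = np.asarray(audio)
    if len(audio) >= length:
        return audio[:length]
    return np.pad(audio, (0, length - len(audio)))

if __name__ == "__main__":
    short_clip = np.ones(3 * SAMPLE_RATE, dtype=np.float32)   # a 3 s clip
    padded = pad_or_trim_1d(short_clip)
    print(short_clip.shape, padded.shape)
```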

### En Transcription
Note: our model does not capitalize text or add punctuation (besides apostrophes) due to the text normalization applied to the LRS3 training text.

In [None]:
result = decode_audio_en_x(clean_input, whisper_en_x_model)
print("Transcribing original input : {}".format(result))

Transcribing original input : it's using past experience based on similar situations to try to make meaning


In [None]:
result = decode_audio_en_x(noisy_input, whisper_en_x_model)
print("Transcribing noisy input : {}".format(result))

Transcribing noisy input : which is a task theory based on the inner situation of the right music


### En-Russian Translation
Note: our model capitalizes text and adds punctuation since we trained on raw text in MuAViC.

In [None]:
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="ru")
print("Translating original input : {}".format(result))

Translating original input : Он использует прошлые опыты на основе подобных ситуаций, чтобы попытаться сделать смысл.


In [None]:
result = decode_audio_en_x(noisy_input, whisper_en_x_model, lang="ru")
print("Translating noisy input : {}".format(result))

Translating noisy input : Это была задача, основанная на внутренней ситуации, когда мы были раздражены музыкой.


### En-X Translation
Note: our model capitalizes text and adds punctuation since we trained on raw text in MuAViC.

In [None]:
# Greek
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="el")
print("Translating original input : {}".format(result))

Translating original input : Χρησιμοποιεί την παρελθόντη εμπειρία με βάση παρόμοιες καταστάσεις για να προσπαθήσει να κάνει νόημα


In [None]:
# Spanish
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="es")
print("Translating original input : {}".format(result))

Translating original input : Se utiliza la experiencia del pasado basada en situaciones similares para tratar de hacer significado.


In [None]:
# Portuguese
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="pt")
print("Translating original input : {}".format(result))

Translating original input : Ele está usando a experiência do passado com base em situações semelhantes para tentar fazer significado.


In [None]:
# French
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="fr")
print("Translating original input : {}".format(result))

Translating original input : Il utilise l’expérience passée basée sur des situations similaires pour tenter de faire du sens.


In [None]:
# Italian
result = decode_audio_en_x(clean_input, whisper_en_x_model, lang="it")
print("Translating original input : {}".format(result))

Translating original input : utilizza l'esperienza del passato basata su situazioni simili per cercare di fare significato


# Whisper-Flamingo (Audio-Visual)

In [None]:
# Download video models
!wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whisper-flamingo_en-x_small.pt
!wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/large_noise_pt_noise_ft_433h_only_weights.pt

--2024-05-09 17:09:49--  https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whisper-flamingo_en-x_small.pt
Resolving data.csail.mit.edu (data.csail.mit.edu)... 128.52.131.233
Connecting to data.csail.mit.edu (data.csail.mit.edu)|128.52.131.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2611351155 (2.4G)
Saving to: ‘whisper-flamingo_en-x_small.pt’


2024-05-09 17:10:20 (82.4 MB/s) - ‘whisper-flamingo_en-x_small.pt’ saved [2611351155/2611351155]

--2024-05-09 17:10:20--  https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/large_noise_pt_noise_ft_433h_only_weights.pt
Resolving data.csail.mit.edu (data.csail.mit.edu)... 128.52.131.233
Connecting to data.csail.mit.edu (data.csail.mit.edu)|128.52.131.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1910146245 (1.8G)
Saving to: ‘large_noise_pt_noise_ft_433h_only_weights.pt’


2024-05-09 17:10:43 (82.0 MB/s) - ‘large_noise_pt_noise_ft_433h_only_wei

In [None]:
model_type = 'small'
checkpoint= 'whisper-flamingo_en-x_small.pt'
use_av_hubert_encoder = 1
av_fusion = 'separate'
video_model_path = 'large_noise_pt_noise_ft_433h_only_weights.pt'
av_hubert_path =  '/content/av_hubert/avhubert'

In [None]:
def load_model():
    print("Loading Whisper")
    whisper_model = whisper.load_model(model_type,
                                    video=True if av_fusion == 'separate' else 0,
                                    video_model_path=video_model_path,
                                    av_hubert_path=av_hubert_path,
                                    av_hubert_encoder=use_av_hubert_encoder,
                                    av_fusion=av_fusion,
                                    add_gated_x_attn=1 if av_fusion == 'separate' else 0)

    if checkpoint is not None:
        print("Loading checkpoint")
        state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
        print(state_dict.keys())
        state_dict = state_dict['state_dict']
        state_dict_updated = {k[6:]: v  for k, v in state_dict.items()} # remove 'model.'
        try: # newer models have a learnable scalar initialized to 1
            whisper_model.load_state_dict(state_dict_updated)
        except Exception as e: # e.g. missing or unexpected keys in older checkpoints
            print(str(e))
            print("Loading weights with strict=False")
            whisper_model.load_state_dict(state_dict_updated, strict=False)

    if torch.cuda.is_available() and use_av_hubert_encoder == 1:
        whisper_model.encoder.video_projection_scalar.half()
        whisper_model.encoder.video_model.half()
        model_to_num_layers = {'small': 12, 'medium': 24, 'large-v2': 32}
        if av_fusion == 'separate':
            for i in range(model_to_num_layers[model_type]):
                whisper_model.decoder.blocks[i].attn_gate.data = whisper_model.decoder.blocks[i].attn_gate.half()
                whisper_model.decoder.blocks[i].ff_gate.data = whisper_model.decoder.blocks[i].ff_gate.half()
    return whisper_model
whisper_flamingo_en_x_small = load_model()

Loading Whisper
Whisper dropout rate : 0.0
Loading AV-HuBERT encoder




Using AV-HuBERT encoder with parameters: 325136104
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Adding gated x attn layers
Loading checkpoint
dict_keys(['state_dict'])


In [None]:
from utils import load_video_feats
def decode_audio_video(audio, video_path, model, lang="en"):
    # Note: we don't pad the audio to 30s
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    video = load_video_feats(video_path, train=False)
    video = torch.tensor(video.astype(np.float32))
    video = video.unsqueeze(0).permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W]
    video = video.half().to(model.device) if torch.cuda.is_available() else video
    # print(audio.shape, audio.dtype)
    # print(video.shape, video.dtype)

    model.eval() # AV-HuBERT batch norm and dropout
    options = whisper.DecodingOptions(fp16 = True if torch.cuda.is_available() else False,
                                      language=lang, without_timestamps=True, beam_size=1)
    pred = model.decode(mel, options, video).text
    return pred
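
The `unsqueeze(0).permute((0, 4, 1, 2, 3))` step above adds a batch axis and moves the channel axis to the front, turning `[T, H, W, C]` frames into `[B, C, T, H, W]`. The NumPy sketch below reproduces that axis shuffle on dummy data so the shapes can be inspected without loading a video. The frame count and frame size used in the usage example are made up for illustration and are not claimed to match what `load_video_feats` returns.

```python
import numpy as np

def to_channels_first(frames):
    """Convert [T, H, W, C] frames to a batched [1, C, T, H, W] array."""
    frames = np.asarray(frames, dtype=np.float32)
    batched = frames[np.newaxis]                      # [1, T, H, W, C]
    return np.ascontiguousarray(batched.transpose(0, 4, 1, 2, 3))

if __name__ == "__main__":
    dummy = np.zeros((5, 8, 8, 1))          # 5 grayscale frames of 8x8 pixels
    print(to_channels_first(dummy).shape)   # shape becomes (1, 1, 5, 8, 8)
```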

### En Transcription

In [None]:
result = decode_audio_video(clean_input, mouth_roi_path, whisper_flamingo_en_x_small)
print("Transcribing original input : {}".format(result))

Transcribing original input : it's using past experience based on similar situations to try to make meaning


In [None]:
result = decode_audio_video(noisy_input, mouth_roi_path, whisper_flamingo_en_x_small)
print("Transcribing noisy input : {}".format(result))

Transcribing noisy input : it's using past experience based on similar situations to try and make meaning
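
One way to quantify how close a noisy transcription is to the clean one is word error rate (WER): the word-level edit distance divided by the number of reference words. The sketch below is a plain-Python illustration and is not the evaluation code used for the models; `word_error_rate` is a hypothetical helper that splits on whitespace and does no text normalization.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the processed prefix of ref and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)

if __name__ == "__main__":
    clean = "it's using past experience based on similar situations to try to make meaning"
    noisy = "it's using past experience based on similar situations to try and make meaning"
    print(word_error_rate(clean, noisy))
```

Using the clean Whisper-Flamingo output as the reference, the noisy output above differs by a single substituted word out of 13.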


### En-Russian Translation

In [None]:
result = decode_audio_video(clean_input, mouth_roi_path, whisper_flamingo_en_x_small, lang='ru')
print("Translating original input : {}".format(result))

Translating original input : Он использует прошлые опыты на основе подобных ситуаций, чтобы попытаться сделать смысл.


In [None]:
result = decode_audio_video(noisy_input, mouth_roi_path, whisper_flamingo_en_x_small, lang='ru')
print("Translating noisy input : {}".format(result))

Translating noisy input : Она использует паспорту, основанную на внутренних ситуациях, чтобы попытаться изменить музыку.
