In [1]:
import fairseq
import torch
from transformers import SpeechToSpeechModel, Wav2Vec2FeatureExtractor

2022-08-04 12:08:14 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [2]:
# Folder passed to `from_pretrained` for the Hugging Face feature extractor and model below.
hf_path = "./pytorch_dump_folder"
# Checkpoint file handed to fairseq's `load_model_ensemble_and_task` below.
fairseq_wav2vec2_path = "./w2v2_mbart_LND_w_ASR.pt"

In [3]:
# Hugging Face feature extractor, read from the converted dump folder.
processor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)

# fairseq reference: load the original checkpoint, overriding the stored
# "data" and "task" arguments so loading works from the current directory.
# NOTE(review): `model` is indexed with [0] later, so it is presumably a list of
# ensemble members at this point — confirm against the fairseq loader.
fairseq_arg_overrides = {"data": "./", "task": "speech_to_text"}
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [fairseq_wav2vec2_path],
    arg_overrides=fairseq_arg_overrides,
)

# Hugging Face model under test, read from the same dump folder.
hf_model = SpeechToSpeechModel.from_pretrained(hf_path)

2022-08-04 12:08:19 | INFO | fairseq.tasks.speech_to_text | dictionary size (dict_1003_unitmbart.txt): 1,007


In [4]:
# Keep only the first entry of what the fairseq loader returned and call `.eval()` on it.
# NOTE(review): this rebinds `model` from the loader's return value to a single entry,
# so re-running this cell without re-running the loading cell will presumably fail — confirm.
model = model[0].eval()

In [40]:
# Decoder-only (CausalLM) input: one sequence holding the token ids 0..9,
# fed to both decoders as the "previous output tokens".
# NOTE(review): ids 0..9 may include special tokens (the fairseq positional
# embedding dumped at the end of the notebook is all zeros for id 1) — confirm.
prev_tokens = torch.arange(10).unsqueeze(0)

In [30]:
# fairseq reference pass with autograd disabled: run the decoder on the
# previous-token ids only and keep element 0 of what it returns.
with torch.no_grad():
    fsq_decoder_result = model.decoder.forward(prev_output_tokens=prev_tokens)
    fsq_output = fsq_decoder_result[0]

In [31]:
# Hugging Face pass with autograd disabled: run the decoder on the same
# token ids and keep the `.logits` field of its output.
with torch.no_grad():
    hf_decoder_result = hf_model.decoder(prev_tokens)
    hf_output = hf_decoder_result.logits

In [41]:
# Sanity check before comparing values: both decoders must return tensors of the same shape.
hf_shape, fsq_shape = hf_output.shape, fsq_output.shape
assert hf_shape == fsq_shape, f"Shapes don't match. Got {hf_shape} for HF and {fsq_shape} for fsq"

In [51]:
# Position 0 only: compare the first decoder step of both implementations at a tight tolerance.
# The failure message must report the HF-vs-fairseq gap; subtracting the HF slice from
# itself (as the message did before) always yields 0 and hides the real difference.
assert torch.allclose(hf_output[:, 0, :], fsq_output[:, 0, :], atol=1e-4), (
    f"Values don't match. Max diff={torch.max(torch.abs(hf_output[:, 0, :] - fsq_output[:, 0, :]))}"
)

In [54]:
# Report, for every sequence position, the largest absolute gap between the
# Hugging Face and fairseq outputs at that position.
for position in range(hf_output.size(1)):
    hf_slice = hf_output[:, position, :]
    fsq_slice = fsq_output[:, position, :]
    worst_gap = (hf_slice - fsq_slice).abs().max()
    print(f"Token {position}, Max diff = {worst_gap}")

Token 0, Max diff = 0.0
Token 1, Max diff = 1.0779876708984375
Token 2, Max diff = 0.1926802098751068
Token 3, Max diff = 0.16107210516929626
Token 4, Max diff = 0.13886451721191406
Token 5, Max diff = 0.13402068614959717
Token 6, Max diff = 0.12233522534370422
Token 7, Max diff = 0.14203158020973206
Token 8, Max diff = 0.12118098139762878
Token 9, Max diff = 0.1102476716041565


In [47]:
# Whole-sequence comparison at a looser tolerance (1e-2) than the position-0 check.
outputs_match = torch.allclose(hf_output, fsq_output, atol=1e-2)
assert outputs_match, (
    f"Values don't match. Max diff={(hf_output - fsq_output).abs().max()}"
)

AssertionError: Values don't match. Max diff=1.0779876708984375

In [None]:
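# fairseq positional embeddings for token positions 0 to 9
# Note: row 0 matches HF row 2 below, row 1 is all zeros, and rows 2-8 match HF rows 3-9 (shifted by one).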
0 tensor([0.9093, 0.9236, 0.9365,  ..., 1.0000, 1.0000, 1.0000])
1 tensor([0., 0., 0.,  ..., 0., 0., 0.])
2 tensor([0.1411, 0.1939, 0.2453,  ..., 1.0000, 1.0000, 1.0000])
3 tensor([-0.7568, -0.7082, -0.6570,  ...,  1.0000,  1.0000,  1.0000])
4 tensor([-0.9589, -0.9804, -0.9939,  ...,  1.0000,  1.0000,  1.0000])
5 tensor([-0.2794, -0.3805, -0.4756,  ...,  1.0000,  1.0000,  1.0000])
6 tensor([0.6570, 0.5578, 0.4520,  ..., 1.0000, 1.0000, 1.0000])
7 tensor([0.9894, 1.0000, 0.9906,  ..., 1.0000, 1.0000, 1.0000])
8 tensor([0.4121, 0.5527, 0.6768,  ..., 1.0000, 1.0000, 1.0000])
9 tensor([-0.5440, -0.3863, -0.2194,  ...,  1.0000,  1.0000,  1.0000])

In [None]:
# HF positional embeddings for token positions 0 to 9
0 tensor([0., 0., 0.,  ..., 1., 1., 1.])
1 tensor([0.8415, 0.8317, 0.8219,  ..., 1.0000, 1.0000, 1.0000])
2 tensor([0.9093, 0.9236, 0.9364,  ..., 1.0000, 1.0000, 1.0000])
3 tensor([0.1411, 0.1938, 0.2451,  ..., 1.0000, 1.0000, 1.0000])
4 tensor([-0.7568, -0.7083, -0.6572,  ...,  1.0000,  1.0000,  1.0000])
5 tensor([-0.9589, -0.9804, -0.9939,  ...,  1.0000,  1.0000,  1.0000])
6 tensor([-0.2794, -0.3803, -0.4752,  ...,  1.0000,  1.0000,  1.0000])
7 tensor([0.6570, 0.5580, 0.4524,  ..., 1.0000, 1.0000, 1.0000])
8 tensor([0.9894, 1.0000, 0.9907,  ..., 1.0000, 1.0000, 1.0000])
9 tensor([0.4121, 0.5524, 0.6764,  ..., 1.0000, 1.0000, 1.0000])