In [1]:
import logging
import os

import evaluate
import jiwer
import numpy as np
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset, load_from_disk
from IPython.display import Audio as AudioDisp
from transformers import T5ForConditionalGeneration, T5Tokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor

from asr_w_spellchecker import ST6

In [2]:
# Evaluation corpus, passed to `datasets.load_dataset(path, name, split=...)`.
# Alternative corpus tried earlier: "mozilla-foundation/common_voice_11_0".
DATASET = {
    "path": "facebook/voxpopuli",
    "name": "cs",
    "split": "test",
}

# Target sampling rate for audio fed to wav2vec2 (used for decoding/resampling
# and by the processor).
SAMPLING_RATE = 16_000

# Root directory of the local checkpoints. Override with the ASR_MODELS_ROOT
# environment variable instead of editing a hard-coded absolute path; the
# default keeps the original location.
MODELS_ROOT = os.environ.get("ASR_MODELS_ROOT", "/home/sulcm/models")

WAV2VEC_MODEL_NAME = f"{MODELS_ROOT}/wav2vec2/wav2vec2-cs-v23"

T5_MODEL_NAME = f"{MODELS_ROOT}/t5/t5-spellchecker-cs-v20"

In [None]:
# Raw evaluation audio; the "audio" column is decoded at SAMPLING_RATE on access.
# NOTE(review): the next cell rebinds `dataset` to the T5 correction corpus, after
# which cells reading dataset[idx]['audio'] / dataset["normalized_text"] fail.
dataset = load_dataset(**DATASET).cast_column(
    "audio", Audio(sampling_rate=SAMPLING_RATE)
)

In [3]:
# Pre-computed (asr_transcription -> target_output) pairs for the T5 spell-checker.
# NOTE(review): this rebinds `dataset`, discarding the VoxPopuli audio loaded in the
# previous cell; cells reading dataset[idx]['audio'] or dataset["normalized_text"]
# only work if run before this one. TODO: give this corpus its own name (t5_dataset).
dataset = load_from_disk(
    dataset_path="/home/sulcm/datasets/t5/asr-correction-cs-v23"
)

In [4]:
# Show splits, columns and row counts of whatever `dataset` currently holds.
dataset

DatasetDict({
    train: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 18902
    })
    validation: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 1103
    })
    test: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 1123
    })
})

In [None]:
# Load each evaluation metric once; `eval_metrics` and `metrics` are the same dict.
eval_metrics = {name: evaluate.load(name) for name in ("sacrebleu", "wer", "cer")}
metrics = eval_metrics

In [None]:
# Pick one utterance to inspect end-to-end (172 is just an example index).
idx = 172

# Fetch the row once instead of indexing `dataset[idx]` twice (each row access
# may re-read and re-decode the audio column).
sample = dataset[idx]
input_audio = sample['audio']
sentence = sample['normalized_text'].lower()

print(sentence)
# Play the decoded waveform at its (resampled) rate rather than `path`: the path
# is not guaranteed to be a playable local file, and it would play the original,
# un-resampled audio instead of what the models actually receive.
AudioDisp(input_audio['array'], rate=input_audio['sampling_rate'])

# ST6: wav2vec2 + T5 spell-checker pipeline

In [None]:
# ST6 model built from the wav2vec2 and T5 checkpoints configured above,
# logging at DEBUG level.
st6_model = ST6(
    t5_path=T5_MODEL_NAME,
    wav2vec2_path=WAV2VEC_MODEL_NAME,
    logging_level=logging.DEBUG,
)

In [None]:
# Run the ST6 model on the selected utterance's raw waveform array.
st6_model(input_audio['array'])

# Wav2Vec 2.0 (acoustic model only)

In [None]:
# CTC acoustic model and its processor (feature extractor + tokenizer), both
# loaded from the same checkpoint.
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC_MODEL_NAME)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_MODEL_NAME)

In [None]:
# Greedy CTC transcription of the selected utterance with the plain wav2vec2 model.
# Inputs follow the model's device: a later cell moves `wav2vec_model` to the GPU
# in place (nn.Module.to), after which CPU inputs would fail on re-run.
inputs = wav2vec_processor(
    input_audio['array'], sampling_rate=SAMPLING_RATE, return_tensors='pt'
).to(wav2vec_model.device)

with torch.no_grad():
    logits = wav2vec_model(**inputs).logits  # (batch, frames, vocab)

# Most likely token per frame; batch_decode turns the CTC ids into text.
pred_ids = torch.argmax(logits, dim=-1)
transcription = wav2vec_processor.batch_decode(pred_ids)
transcription

In [None]:
# Size of the CTC logits: (batch, frames, vocab).
logits.size()

In [None]:
# Size of the per-frame argmax ids: (batch, frames).
pred_ids.size()

In [None]:
# Size of the waveform tensor fed to wav2vec2.
# NOTE(review): later cells rebind `inputs` to T5 tokenizer output, which has no
# "input_values"; run this before them.
inputs["input_values"].size()

In [None]:
# NOTE(review): duplicate of the earlier `logits.shape` cell; safe to delete.
logits.shape

# T5

In [9]:
# Spell-checker seq2seq model and its SentencePiece tokenizer, same checkpoint.
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_NAME)
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Base vocabulary size reported by the tokenizer (32000 in the saved output).
# NOTE(review): the model's embedding matrix printed further down has 32128 rows,
# so this number does not cover every embedding row — presumably added/sentinel
# tokens plus padding; confirm with len(t5_tokenizer).
t5_tokenizer.vocab_size

32000

In [6]:
# NOTE(review): same display as the earlier `dataset` cell; duplicate.
dataset

DatasetDict({
    train: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 18902
    })
    validation: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 1103
    })
    test: Dataset({
        features: ['asr_transcription', 'target_output'],
        num_rows: 1123
    })
})

In [26]:
# Token length (T5 tokenizer, including the appended EOS) of every reference
# correction in the TEST split. Tokenize the whole column in one call instead
# of one tokenizer call per example.
test_target_token_lengths = [
    len(ids) for ids in t5_tokenizer(dataset["test"]["target_output"]).input_ids
]
# Despite its name, `train_ds` holds test-split lengths; kept because later cells read it.
train_ds = test_target_token_lengths

In [25]:
# NOTE(review): this saved output comes from execution 25, i.e. before the cell
# that currently defines `train_ds` (execution 26) ran, so it reflects an earlier
# `train_ds`. Re-run or delete; the next cell shows the current value.
np.mean(train_ds)

27.863757791629563

In [27]:
# Mean token length of the reference corrections collected in `train_ds`.
np.asarray(train_ds).mean()

26.908281389136242

In [None]:
# Tokenize every reference sentence with the "spell check: " task prefix, padded
# to a common length.
# NOTE(review): dataset["normalized_text"] only exists while `dataset` is still the
# VoxPopuli split, and this overwrites the wav2vec2 `inputs` from above.
inputs = t5_tokenizer(
    ["spell check: " + text for text in dataset["normalized_text"]],
    padding=True,
    return_tensors="pt",
)

In [None]:
# Longest tokenized prompt, in tokens. The attention mask marks real (non-pad)
# tokens explicitly, so this does not rely on pad_token_id being 0 the way
# counting non-zero input_ids did; the 100th percentile is simply the maximum.
int(inputs.attention_mask.sum(dim=1).max())

In [None]:
# Spell-check the greedy wav2vec2 transcription with the T5 model.
# Reference corrections average ~27 T5 tokens (see the mean-length cells above),
# so the previous max_new_tokens=20 cut typical outputs short; allow more room.
# Generation still stops at EOS.
T5_MAX_NEW_TOKENS = 128

# padding=True keeps this working for batches of several transcriptions; inputs
# follow the model's device (it may have been moved to the GPU in place).
inputs = t5_tokenizer(
    ["spell check: " + sentence for sentence in transcription],
    padding=True,
    return_tensors="pt",
).to(t5_model.device)

output_sequences = t5_model.generate(**inputs, max_new_tokens=T5_MAX_NEW_TOKENS)

t5_tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

# Scratch: model internals and checkpoint inspection

In [4]:
# Move the T5 model to the GPU when one is available, otherwise stay on CPU.
# nn.Module.to() works in place and returns the same module, so `t5_nn_modules`
# and `t5_model` are the same object afterwards (both on DEVICE).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_nn_modules = t5_model.to(DEVICE)

In [5]:
# Print the module tree of the T5 model (shared embedding, encoder/decoder stacks).
t5_nn_modules

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [None]:
# Move the wav2vec2 model to the GPU when one is available, otherwise stay on CPU.
# nn.Module.to() works in place and returns the same module, so `wav2vec_model`
# is now on DEVICE too; cells feeding it CPU tensors fail after this on a GPU box.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
w2v2_nn_modules = wav2vec_model.to(DEVICE)

In [None]:
# Print the module tree of the wav2vec2 CTC model.
w2v2_nn_modules

In [None]:
# Decoded waveform array of the selected utterance and its shape.
x = input_audio["array"]
np.shape(x)

In [None]:
# Approximate number of wav2vec2 output frames: samples / total stride of the
# conv feature encoder (product of the per-layer strides, 320 in the standard
# config). Derived from `x` and the config instead of hard-coded 92736 / 320;
# the exact count also depends on the kernel sizes.
x.shape[-1] / np.prod(wav2vec_model.config.conv_stride)

In [None]:
# Waveform as a (1, samples) float32 tensor on the wav2vec2 model's device.
# torch.as_tensor on the ndarray avoids torch.tensor([...])'s slow
# list-of-ndarrays path, and following the model's device avoids hard-coding "cuda".
raw_waveform = (
    torch.as_tensor(input_audio["array"], dtype=torch.float32)
    .unsqueeze(0)
    .to(w2v2_nn_modules.device)
)

In [None]:
# First 1-D convolution of the wav2vec2 feature encoder, looked up by dotted path.
conv0 = w2v2_nn_modules.get_submodule("wav2vec2.feature_extractor.conv_layers.0.conv")

In [None]:
# Conv1d weight size: (out_channels, in_channels / groups, kernel_size).
conv0.weight.size()

In [None]:
# Output of the first feature-encoder convolution alone (no norm / activation yet).
# Call the module itself rather than `.forward` so registered hooks run, and skip
# autograd bookkeeping: this is inspection only.
with torch.no_grad():
    tcn_waveform = w2v2_nn_modules.wav2vec2.feature_extractor.conv_layers[0].conv(raw_waveform)

In [None]:
# Layer 0's configured activation (presumably GELU, per the variable name) applied
# directly to the raw conv output.
# NOTE(review): the HF feature-encoder conv layer applies conv -> layer_norm ->
# activation, so activating first does not reproduce the model's layer-0 output.
gelu_act = w2v2_nn_modules.wav2vec2.feature_extractor.conv_layers[0].activation(tcn_waveform)

In [None]:
# Layer-0 output of the feature encoder, reproduced by hand in the order the HF
# conv layer uses: conv -> layer_norm (over channels) -> activation. The original
# applied the norm to the already-activated tensor, i.e. in the wrong order.
# Layout stays (frames, channels) as before, so the next cell's `.T` still gives
# (channels, frames). Assumes `layer_norm` is a LayerNorm over the channel dim
# (feat_extract_norm="layer"), as the original transpose implied — TODO confirm.
# The name `group_norm` is kept because later cells read it.
layer0 = w2v2_nn_modules.wav2vec2.feature_extractor.conv_layers[0]
with torch.no_grad():
    group_norm = layer0.activation(layer0.layer_norm(tcn_waveform.transpose(-2, -1)))

In [None]:
# Size of the first row of the layer-0 output.
group_norm[0].size()

In [None]:
# Second feature-encoder convolution applied to the layer-0 output, transposed
# back so channels come first.
conv1 = w2v2_nn_modules.get_submodule("wav2vec2.feature_extractor.conv_layers.1.conv")
tcn_waveform_2 = conv1.forward(group_norm.T)

In [None]:
# Approximate output length of the first conv layer: samples / its stride (5 in
# the standard config), read from `x` and the model config instead of 92736 / 5.
x.shape[-1] / wav2vec_model.config.conv_stride[0]

In [None]:
# Size of the waveform tensor fed to the first conv: (1, samples).
raw_waveform.size()

In [None]:
# Frames produced for a 40 000-sample input (2.5 s at 16 kHz): divide by every
# conv-layer stride, i.e. by their product (5 * 2**6 = 320 in the standard
# config), taken from the model config instead of a chain of literals.
40_000 / np.prod(wav2vec_model.config.conv_stride)

In [None]:
# NOTE(review): duplicate of the earlier frames estimate; safe to delete.
# Approximate wav2vec2 output frames: samples / product of the conv strides,
# derived from `x` and the config instead of hard-coded 92736 / 320.
x.shape[-1] / np.prod(wav2vec_model.config.conv_stride)

In [None]:
# wav2vec2 output frame rate (frames per second) for this clip:
# sampling rate * output frames / input samples (~50 Hz, one frame per ~20 ms).
# Assumes `logits` are this clip's CTC logits, laid out (batch, frames, vocab);
# replaces the hard-coded 16000 * 289 / 92736.
SAMPLING_RATE * logits.shape[1] / x.shape[-1]

In [None]:
# Size of the first conv's raw output.
tcn_waveform.size()

In [None]:
# Size of the first row of the second conv's output.
tcn_waveform_2[0].size()

In [None]:
# The T5 correction corpus again, this time under its own name. It is the same
# path that an earlier cell loads into `dataset`.
t5_dataset = load_from_disk(
    dataset_path="/home/sulcm/datasets/t5/asr-correction-cs-v23"
)

In [None]:
# One (asr_transcription, target_output) example from the test split.
t5_dataset["test"][1]

In [None]:
# Raw state dict (parameter name -> tensor) of an older checkpoint (v1, not the v23
# model configured above). Despite the name, `model` is a dict, not a module.
# map_location="cpu" lets it load on machines without the GPU it was saved from;
# weights_only=True refuses to unpickle arbitrary objects (the .bin is a pickle;
# requires torch >= 1.13).
model = torch.load(
    "/home/sulcm/models/wav2vec2/wav2vec2-cs-v1/pytorch_model.bin",
    map_location="cpu",
    weights_only=True,
)

In [None]:
# Parameter names stored in the loaded checkpoint.
model.keys()

In [None]:
# Size of the last encoder layer's final LayerNorm weight.
model['wav2vec2.encoder.layers.11.final_layer_norm.weight'].size()

In [None]:
# Size of the CTC output projection weight.
model['lm_head.weight'].size()