wav2vec 2.0 inference pipeline #2651
Comments
If anyone succeeded in making a brief inference, I would appreciate it if you could leave it here.
I succeeded!!
I did it in Fairseq version 0.9.0. I will improve the code further and send a pull request.
import os
import math
import sys
import torch
import torch.nn.functional as F
import numpy as np
import itertools as it
import torch.nn as nn
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.tasks.audio_pretraining import AudioPretrainingTask
from fairseq.data import Dictionary
from fairseq.models import BaseFairseqModel
import soundfile as sf
from wav2letter.decoder import CriterionType
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
import contextlib
from fairseq.models import FairseqEncoder
from examples.wav2vec2.tasks.audio_pretraining import Wav2vec2PretrainingTask
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == 'wordpiece':
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == 'letter':
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol is not None and symbol != 'none':
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
class Wav2VecEncoder(FairseqEncoder):
def __init__(self, args, tgt_dict=None):
self.apply_mask = args.apply_mask
arg_overrides = {
"dropout": args.dropout,
"activation_dropout": args.activation_dropout,
"dropout_input": args.dropout_input,
"attention_dropout": args.attention_dropout,
"mask_length": args.mask_length,
"mask_prob": args.mask_prob,
"mask_selection": args.mask_selection,
"mask_other": args.mask_other,
"no_mask_overlap": args.no_mask_overlap,
"mask_channel_length": args.mask_channel_length,
"mask_channel_prob": args.mask_channel_prob,
"mask_channel_selection": args.mask_channel_selection,
"mask_channel_other": args.mask_channel_other,
"no_mask_channel_overlap": args.no_mask_channel_overlap,
"encoder_layerdrop": args.layerdrop,
"feature_grad_mult": args.feature_grad_mult,
}
if getattr(args, "w2v_args", None) is None:
state = checkpoint_utils.load_checkpoint_to_cpu(
args.w2v_path, arg_overrides
)
w2v_args = state["args"]
else:
state = None
w2v_args = args.w2v_args
assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same'
w2v_args.data = args.data
task = Wav2vec2PretrainingTask.setup_task(w2v_args)
model = task.build_model(w2v_args)
if state is not None and not args.no_pretrained_weights:
model.load_state_dict(state["model"], strict=True)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = w2v_args.encoder_embed_dim
self.w2v_model = model
self.final_dropout = nn.Dropout(args.final_dropout)
self.freeze_finetune_updates = args.freeze_finetune_updates
self.num_updates = 0
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
elif getattr(args, 'decoder_embed_dim', d) != d:
self.proj = Linear(d, args.decoder_embed_dim)
else:
self.proj = None
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, source, padding_mask, tbc=True, **kwargs):
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
x, padding_mask = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
if self.proj:
x = self.proj(x)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
}
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
def base_architecture(args):
args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
args.dropout_input = getattr(args, "dropout_input", 0)
args.final_dropout = getattr(args, "final_dropout", 0)
args.apply_mask = getattr(args, "apply_mask", False)
args.dropout = getattr(args, "dropout", 0)
args.attention_dropout = getattr(args, "attention_dropout", 0)
args.activation_dropout = getattr(args, "activation_dropout", 0)
args.mask_length = getattr(args, "mask_length", 10)
args.mask_prob = getattr(args, "mask_prob", 0.5)
args.mask_selection = getattr(args, "mask_selection", "static")
args.mask_other = getattr(args, "mask_other", 0)
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0)
args.layerdrop = getattr(args, "layerdrop", 0.0)
class W2lDecoder(object):
def __init__(self, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = 1
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
# encoder_out = models[0].encoder(**encoder_input)
encoder_out = models[0](**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = list()
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
]
class Wav2VecCtc(BaseFairseqModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
add_common_args(parser)
def __init__(self, w2v_encoder, args):
super().__init__()
self.w2v_encoder = w2v_encoder
self.args = args
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, args, target_dict):
"""Build a new model instance."""
base_architecture(args)
w2v_encoder = Wav2VecEncoder(args, target_dict)
return cls(w2v_encoder, args)
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output["encoder_out"]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
def get_feature(filepath):
def postprocess(feats, sample_rate):
        if feats.dim() == 2:  # stereo -> mono: average the two channels
            feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
wav, sample_rate = sf.read(filepath)
feats = torch.from_numpy(wav).float()
feats = postprocess(feats, sample_rate)
return feats
def load_target_dict(manifest_path='./manifest'):
dict_path = os.path.join(manifest_path, "dict.ltr.txt")
target_dict = Dictionary.load(dict_path)
return target_dict
def load_model(model_path, target_dict):
# state = checkpoint_utils.load_checkpoint_to_cpu(model_path)
# args = state["args"]
w2v = torch.load(model_path)
# from examples.wav2vec2.models.wav2vec2_asr import Wav2Vec2Model
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
return [model]
def main():
sample, input = dict(), dict()
WAV_PATH = 'xxx.wav'
W2V_PATH = 'wav2vec2_vox_960h.pt'
manifest_path = "MANIFEST_PATH"
    feature = get_feature(WAV_PATH)
use_cuda = torch.cuda.is_available()
target_dict = load_target_dict(manifest_path)
model = load_model(W2V_PATH, target_dict)
model[0].eval()
generator = W2lViterbiDecoder(target_dict)
input["source"] = feature.unsqueeze(0)
padding_mask = torch.BoolTensor(input["source"].size(1)).fill_(False).unsqueeze(0)
input["padding_mask"] = padding_mask
sample["net_input"] = input
with torch.no_grad():
hypo = generator.generate(model, sample, prefix_tokens=None)
hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
print(post_process(hyp_pieces, 'letter'))
if __name__ == '__main__':
main()
@sooftware amazing!!! Did you use the latest version of …?
I'm not sure, but I have the command that I used: …
I installed wav2letter a few days ago.
@sooftware Thanks! I'm getting an import error for …
@sooftware Could you please specify what you have inside the … file? Is this the path to …?
@mironnn The manifest path only contains the dictionary, from what I can tell. Look at the load_target_dict function in the script above.
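For reference (an illustration, not taken from the thread): dict.ltr.txt is a plain fairseq dictionary file with one "<symbol> <count>" pair per line, covering the letters, the apostrophe, and the word-boundary symbol "|". The counts below are made up:
| 803288
E 400000
T 290000
A 260000
...
' 5000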
I have the same issue =(
I created a pull request (#2668).
Here is the code:
import torch
import argparse
import soundfile as sf
import torch.nn.functional as F
import itertools as it
from fairseq import utils
from fairseq.models import BaseFairseqModel
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
from fairseq.data import Dictionary
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
from wav2letter.decoder import CriterionType
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
parser.add_argument('--wav_path', type=str,
default='~/xxx.wav',
help='path of wave file')
parser.add_argument('--w2v_path', type=str,
default='~/wav2vec2_vox_960h.pt',
help='path of pre-trained wav2vec-2.0 model')
parser.add_argument('--target_dict_path', type=str,
default='dict.ltr.txt',
help='path of target dict (dict.ltr.txt)')
class Wav2VecCtc(BaseFairseqModel):
def __init__(self, w2v_encoder, args):
super().__init__()
self.w2v_encoder = w2v_encoder
self.args = args
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, args, target_dict):
"""Build a new model instance."""
base_architecture(args)
w2v_encoder = Wav2VecEncoder(args, target_dict)
return cls(w2v_encoder, args)
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output["encoder_out"]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
class W2lDecoder(object):
def __init__(self, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = 1
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
# encoder_out = models[0].encoder(**encoder_input)
encoder_out = models[0](**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = list()
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
]
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == 'wordpiece':
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == 'letter':
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol is not None and symbol != 'none':
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
def get_feature(filepath):
def postprocess(feats, sample_rate):
        if feats.dim() == 2:  # stereo -> mono: average the two channels
            feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
wav, sample_rate = sf.read(filepath)
feats = torch.from_numpy(wav).float()
feats = postprocess(feats, sample_rate)
return feats
def load_model(model_path, target_dict):
w2v = torch.load(model_path)
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
return [model]
def main():
args = parser.parse_args()
sample = dict()
net_input = dict()
feature = get_feature(args.wav_path)
target_dict = Dictionary.load(args.target_dict_path)
model = load_model(args.w2v_path, target_dict)
model[0].eval()
generator = W2lViterbiDecoder(target_dict)
net_input["source"] = feature.unsqueeze(0)
padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
net_input["padding_mask"] = padding_mask
sample["net_input"] = net_input
with torch.no_grad():
hypo = generator.generate(model, sample, prefix_tokens=None)
hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
print(post_process(hyp_pieces, 'letter'))
if __name__ == '__main__':
    main()
@sooftware thanks, I'm trying a CPU build; in this case I get a …
I can see from your script that you built the python bindings, but how do I include the …?
Oh, I'm sorry. I don't know about that issue. T.T
I asked it here: flashlight/wav2letter#842
@loretoparisi Here is my process below:
# For example
fairseq/data/wav2vec_small_960h.pt # model
fairseq/data/dict.ltr.txt # dict file
fairseq/data/temp.wav # the wav you want to test; don't forget to resample it to 16 kHz
FROM wav2letter/wav2letter:cpu-latest
ENV USE_CUDA=0
ENV KENLM_ROOT_DIR=/root/kenlm
# will use Intel MKL for featurization but this may cause dynamic loading conflicts.
# ENV USE_MKL=1
ENV LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2018.5.274/linux/mkl/lib/intel64:$LD_LIBRARY_PATH
WORKDIR /root/wav2letter/bindings/python
RUN pip install --upgrade pip && pip install soundfile packaging && pip install -e .
WORKDIR /root
RUN git clone https://github.com/pytorch/fairseq.git
RUN mkdir data
COPY examples/wav2vec/recognize.py /root/fairseq/examples/wav2vec/recognize.py
WORKDIR /root/fairseq
RUN pip install --editable ./ && python examples/speech_recognition/infer.py --help && python examples/wav2vec/recognize.py --help
# build
docker build -t wav2vec2 -f wav2vec2.CPU.Dockerfile .
# run docker
docker run --rm -itd --ipc=host -v $PWD/data:/root/data --name w2v wav2vec2
# go into container
docker exec -it w2v bash
# run recognize
python examples/wav2vec/recognize.py --wav_path ~/data/temp.wav --w2v_path ~/data/wav2vec_small_960h.pt --target_dict_path ~/data/dict.ltr.txt
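Since the note above says the wav must be 16 kHz, here is a small helper to resample a file to 16 kHz mono before running recognize.py. It is not part of the thread's scripts and assumes librosa and soundfile are installed; the paths are examples.

import librosa
import soundfile as sf

def to_16k_mono(in_path, out_path):
    # librosa.load resamples and downmixes to mono in one call
    wav, sr = librosa.load(in_path, sr=16000, mono=True)
    sf.write(out_path, wav, 16000)

to_16k_mono("data/original.wav", "data/temp.wav")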
@mychiux413 thank you so much. I'm getting this error: …
Within the container, the command used was …
It should not be there, so I have opened an issue.
@loretoparisi there is a typo: not wv2_path, but w2v_path. :)
@sooftware gosh!!! I've checked it ten times!
LoL!! I'm glad I found it now!
@sooftware not yet, but this is definitely something I'm going to do!
Let me know if you succeed! I have an issue (#2654) with KenLM.
@sooftware definitely I will. In the meanwhile I have pushed everything here with Docker. I built two images:
wav2vec-python3 latest cfdcb450b427 51 minutes ago 9.97GB
wav2vec-wav2letter latest e028493c66b0 2 hours ago 3.37GB
Thank you guys for your help and collaboration! I will keep you posted.
Grrrrrrreat !!!
@alexeib I still have the same error on the most recent commit. I built kenlm using the tarball (not via git), compiled with -DKENLM_MAX_ORDER=20. The following command …
causes the stack trace …
Looks like it can't import wav2letter. Have you tried installing the python bindings, as the error message suggests?
Ah thanks! I don't know how I missed that, must be an issue with my wav2letter install.
Returning to this after a while, I just ran my first test and I'm getting surprisingly poor results. I believe my audio file was 16 kHz, a lecture that has some noise in it. My WER is 22.9% and my CER is 14.75%. I used the 4-gram (probing) LM, and here is the command I ran: …
Anybody know how to get this down to at least 4-5%? Do I have to use the transformer language model instead? I am getting this warning: …
However, it seems to run either way. Could the word error rate be high due to this?
You need to use the fairseq model (.pt), not the wav2letter model (.bin).
@alexeib ah! 🤦♂️ Sorry about that, you’re right. I’ll try again soon with the actual language model haha.
Okay, so I downloaded and tried to run it. Here's what's going on so far: I tried running this command: …
But it gave me this error: …
Then I tried taking the dict.ltr.txt file from the libri folder, putting it in the data folder, renaming it to "dict.txt", and running this: …
I got this error: …
Then I tried downloading the fairseq dict file listed next to the transformer model here: https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019. I then renamed the file to 'dict.txt' and ran this command: …
And I got this error: …
If this error in fact follows from a proper command/sequence of events and not from some mistake I made in one of the inputs, is there some way to run this on CPU? I already have the … I believe I have a CUDA driver installed, but I'm not sure whether I have an NVIDIA one; it seems from my settings that I do, but for some reason I guess it's not seeing it. There's a whole bunch of problems in this area with Apple/Mac/Nvidia that are hairy to get into. I'd rather just run on CPU.
I want to use wav2vec 2.0 as a featurizer, i.e. get just the embeddings. Can anyone help with this or point me to a starting point?
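Not from the thread, but as a possible starting point: the Wav2VecEncoder code above gets its features from w2v_model.extract_features(...), and you can call the same method directly on a pre-trained (not CTC fine-tuned) checkpoint. The checkpoint name below is an example, and the exact return type of extract_features differs between fairseq versions, so treat this as a sketch rather than a drop-in script.

import torch
import soundfile as sf
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks

# load a pre-trained checkpoint such as wav2vec_small.pt (example path)
state = checkpoint_utils.load_checkpoint_to_cpu("wav2vec_small.pt")
w2v_args = state["args"]
task = tasks.setup_task(w2v_args)
model = task.build_model(w2v_args)
model.load_state_dict(state["model"], strict=True)
model.eval()

wav, sample_rate = sf.read("audio_16khz.wav")   # 16 kHz mono audio
feats = torch.from_numpy(wav).float()
with torch.no_grad():
    feats = F.layer_norm(feats, feats.shape)    # same normalization as get_feature()
    # in the fairseq version the scripts above target, this returns (x, padding_mask)
    x, _ = model.extract_features(source=feats.unsqueeze(0), padding_mask=None, mask=False)
print(x.shape)                                   # B x T' x C contextual embeddings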
@sooftware have you added Language Model decoding in this inference pipeline?
@spygaurad I'll try. I'll leave it here if I succeed. There is a high probability that the code up there will not work now that fairseq has been upgraded to version 0.10.1.
Okay @sooftware, I can try to integrate the LM for version 0.10.1 if I have some references. The code seemed quite complex to me. Thanks for your response.
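For reference, one possible route (an untested sketch, not something confirmed in this thread) is to swap the W2lViterbiDecoder in recognize.py for fairseq's own W2lKenLMDecoder from examples/speech_recognition/w2l_decoder.py. The argument names and values below are assumptions based on the options infer.py exposes (--lexicon, --kenlm-model, --lm-weight, --word-score, ...), so check them against the w2l_decoder.py in your fairseq version before use.

from argparse import Namespace
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

# decoder hyper-parameters; field names and values are assumptions to verify
decoder_args = Namespace(
    nbest=1,
    criterion="ctc",
    kenlm_model="4-gram.bin",              # KenLM binary LM (example path)
    lexicon="librispeech_lexicon.lst",     # word -> letter-sequence lexicon (example path)
    beam=500,
    beam_threshold=25.0,
    beam_size_token=100,
    lm_weight=2.0,
    word_score=-1.0,
    unk_weight=float("-inf"),
    sil_weight=0.0,
    unit_lm=False,
)
generator = W2lKenLMDecoder(decoder_args, target_dict)
# generator.generate(model, sample, prefix_tokens=None) is then used exactly as
# with W2lViterbiDecoder in the scripts above.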
Can anyone please provide an example of how to use this pipeline for ASR on a real-time / media stream rather than a static wav file? I was looking everywhere for it. I see in the Pitch here that a "programmatically loaded waveform signal" should be supported; if I understand correctly, it refers to a sort of online/live ASR. Help will be very much appreciated.
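The model itself is not a streaming model, but a rough, untested sketch of the "programmatically loaded waveform signal" case is to reuse the model, generator and target_dict built in recognize.py above (plus its post_process function) and feed chunks of 16 kHz audio, for example from a microphone buffer, instead of reading a wav file. transcribe_chunk is a hypothetical helper, not part of the thread's code.

import numpy as np
import torch
import torch.nn.functional as F

def transcribe_chunk(chunk: np.ndarray, model, generator, target_dict):
    """chunk: 1-D float array of 16 kHz mono samples, a few seconds long."""
    feats = torch.from_numpy(chunk).float()
    with torch.no_grad():
        feats = F.layer_norm(feats, feats.shape)   # same normalization as get_feature()
        sample = {"net_input": {
            "source": feats.unsqueeze(0),
            "padding_mask": torch.BoolTensor(1, feats.size(0)).fill_(False),
        }}
        hypo = generator.generate(model, sample, prefix_tokens=None)
    tokens = hypo[0][0]["tokens"].int().cpu()
    return post_process(target_dict.string(tokens), "letter")

# e.g. call transcribe_chunk(buffer, model, generator, target_dict) on each buffered chunk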
You can check out a wav2vec 2.0 inference pipeline at https://github.com/kakaobrain/pororo
I am running into the same issue when trying to run inference with my own finetuned model trained using fairseq-hydra-train.
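A possible explanation (my assumption, not confirmed in the thread): checkpoints saved by fairseq-hydra-train store their configuration as an omegaconf object under state["cfg"] rather than an argparse Namespace under state["args"], so the scripts above that read w2v["args"] see None. A quick way to check what your checkpoint contains:

import torch

state = torch.load("mymodel.pt", map_location="cpu")
print("args:", state.get("args"))                # often None for hydra-trained checkpoints
print("has cfg:", state.get("cfg") is not None)  # hydra checkpoints carry cfg instead
if state.get("cfg") is not None:
    print(state["cfg"].keys())                   # typically includes 'model', 'task', ...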
I've created pull request #3244 for this.
This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
Time to close this wav2vec 2.0 specific issue. A lot of time has passed (9 months, and a kid was born!) and now, thanks to HuggingFace's new audio libraries, inference is simple and works like a charm! https://github.com/loretoparisi/hf-experiments/blob/master/src/asr/README.md
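For completeness, the HuggingFace route that comment refers to looks roughly like this (a sketch; facebook/wav2vec2-base-960h is one of the public checkpoints, substitute your own, and the wav path is an example):

import torch
import soundfile as sf
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()

speech, sample_rate = sf.read("audio_16khz.wav")            # 16 kHz mono audio
inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)                # greedy CTC decoding
print(processor.batch_decode(predicted_ids)[0])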
Hi, can this work with my own pretrained and finetuned model (mymodel.pt), or only one from HuggingFace?
@tensorfoo assuming you replace the Wav2Vec2 model here, it should definitely work.
I've tried it with my finetuned model and it gave a utf8 error. Now I'm trying with the base English model, like: …
Were you able to solve that issue? Please help, I am running into the same issue.
🚀 Feature Request
Provide a simple inference pipeline for the wav2vec 2.0 model.
Motivation
The current inference script examples/speech_recognition/infer.py handles a lot of cases, resulting in it being extremely complex.
Pitch
A single python script that loads and runs inference with a wav2vec 2.0 pre-trained model on a single wav file or on a programmatically loaded waveform signal.
Alternatives
Additional context
This kind of inference pipeline would enable independent researchers to test the model on their own audio datasets and against other models.