
wav2vec 2.0 inference pipeline #2651

Closed
loretoparisi opened this issue Sep 24, 2020 · 112 comments

@loretoparisi

🚀 Feature Request

Provide a simple inference pipeline for the wav2vec 2.0 model.

Motivation

The current inference script, examples/speech_recognition/infer.py, handles a lot of cases, which makes it extremely complex.

Pitch

A single Python script that loads and runs inference with a wav2vec 2.0 pre-trained model, either on a single wav file or on a programmatically loaded waveform signal.
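
A sketch of the kind of interface being requested might look like the snippet below; wav2vec2_infer, load_model, and transcribe are hypothetical placeholder names, not an existing fairseq API.

# Hypothetical usage (all names below are placeholders, not an existing API)
import soundfile as sf
from wav2vec2_infer import load_model, transcribe  # hypothetical helper module

model = load_model("wav2vec_small_960h.pt", dict_path="dict.ltr.txt")

# ...on a single wav file,
print(transcribe(model, "sample_16khz.wav"))

# ...or on a programmatically loaded waveform signal.
wav, sample_rate = sf.read("sample_16khz.wav")
print(transcribe(model, wav, sample_rate=sample_rate))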

Alternatives

Additional context

This kind of inference pipeline would enable individual researchers to test the model on their own audio datasets and against other models.

@sooftware

sooftware commented Sep 25, 2020

If anyone has succeeded in building a simple inference script, I would appreciate it if you could leave it here.
If I succeed, I will post the code here.

@sooftware

I succeeded!!
I'll wrap up the code and put it up here!

@sooftware

sooftware commented Sep 28, 2020

I did it with fairseq version 0.9.0.
Wav2vec 2.0 is not supported in fairseq-0.9.0, so I took the code from the current fairseq repository and adapted it.
I hope this helps.

I will improve the code further and send a pull request.
Here is my code.

import contextlib
import itertools as it
import os

import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, utils
from fairseq.data import Dictionary
from fairseq.models import BaseFairseqModel, FairseqEncoder
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import CriterionType

from examples.wav2vec2.tasks.audio_pretraining import Wav2vec2PretrainingTask


def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == 'wordpiece':
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == 'letter':
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != 'none':
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence


class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, args, tgt_dict=None):
        self.apply_mask = args.apply_mask

        arg_overrides = {
            "dropout": args.dropout,
            "activation_dropout": args.activation_dropout,
            "dropout_input": args.dropout_input,
            "attention_dropout": args.attention_dropout,
            "mask_length": args.mask_length,
            "mask_prob": args.mask_prob,
            "mask_selection": args.mask_selection,
            "mask_other": args.mask_other,
            "no_mask_overlap": args.no_mask_overlap,
            "mask_channel_length": args.mask_channel_length,
            "mask_channel_prob": args.mask_channel_prob,
            "mask_channel_selection": args.mask_channel_selection,
            "mask_channel_other": args.mask_channel_other,
            "no_mask_channel_overlap": args.no_mask_channel_overlap,
            "encoder_layerdrop": args.layerdrop,
            "feature_grad_mult": args.feature_grad_mult,
        }

        if getattr(args, "w2v_args", None) is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.w2v_path, arg_overrides
            )
            w2v_args = state["args"]
        else:
            state = None
            w2v_args = args.w2v_args

        assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same'

        w2v_args.data = args.data
        task = Wav2vec2PretrainingTask.setup_task(w2v_args)
        model = task.build_model(w2v_args)

        if state is not None and not args.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()
        super().__init__(task.source_dictionary)

        d = w2v_args.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(args.final_dropout)
        self.freeze_finetune_updates = args.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(args, 'decoder_embed_dim', d) != d:
            self.proj = Linear(d, args.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


def base_architecture(args):
    args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
    args.dropout_input = getattr(args, "dropout_input", 0)
    args.final_dropout = getattr(args, "final_dropout", 0)
    args.apply_mask = getattr(args, "apply_mask", False)
    args.dropout = getattr(args, "dropout", 0)
    args.attention_dropout = getattr(args, "attention_dropout", 0)
    args.activation_dropout = getattr(args, "activation_dropout", 0)

    args.mask_length = getattr(args, "mask_length", 10)
    args.mask_prob = getattr(args, "mask_prob", 0.5)
    args.mask_selection = getattr(args, "mask_selection", "static")
    args.mask_other = getattr(args, "mask_other", 0)
    args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)

    args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 0)
    args.layerdrop = getattr(args, "layerdrop", 0.0)


class W2lDecoder(object):
    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1

        self.criterion_type = CriterionType.CTC
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.asg_transitions = None

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)

        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)

        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = list()

        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
        ]


class Wav2VecCtc(BaseFairseqModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        add_common_args(parser)

    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)
        w2v_encoder = Wav2VecEncoder(args, target_dict)
        return cls(w2v_encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x


def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    return feats


def load_target_dict(manifest_path='./manifest'):
    dict_path = os.path.join(manifest_path, "dict.ltr.txt")
    target_dict = Dictionary.load(dict_path)
    return target_dict


def load_model(model_path, target_dict):
    # state = checkpoint_utils.load_checkpoint_to_cpu(model_path)
    # args = state["args"]
    w2v = torch.load(model_path)

    # from examples.wav2vec2.models.wav2vec2_asr import Wav2Vec2Model
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)

    return [model]


def main():
    sample, input = dict(), dict()
    WAV_PATH = 'xxx.wav'
    W2V_PATH = 'wav2vec2_vox_960h.pt'

    manifest_path = "MANIFEST_PATH"
    feature = get_feature(WAV_PATH)

    use_cuda = torch.cuda.is_available()

    target_dict = load_target_dict(manifest_path)
    model = load_model(W2V_PATH, target_dict)
    model[0].eval()

    generator = W2lViterbiDecoder(target_dict)
    input["source"] = feature.unsqueeze(0)

    padding_mask = torch.BoolTensor(input["source"].size(1)).fill_(False).unsqueeze(0)

    input["padding_mask"] = padding_mask
    sample["net_input"] = input

    with torch.no_grad():
        hypo = generator.generate(model, sample, prefix_tokens=None)

    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    print(post_process(hyp_pieces, 'letter'))


if __name__ == '__main__':
    main()
  • Output
I CAME TO THE CONCLUSION THAT WHAT WE NEED IN EDUCATION IS MUCH BETTER UNDERSTANDING EXCLUSIVE AND LEARNING FROM A MOTIVATION OF PERSPECTIVE FROM A PSYCHOLOGICAL REPROSPECTIVE

@loretoparisi
Author

@sooftware amazing!!! Did you use the latest version of wav2letter?

@sooftware

I'm not sure, but here are the commands I used.

# Install python libraries
pip install soundfile
pip install torchaudio
pip install sentencepiece

# Update apt-get & Install soundfile
apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y \
&& apt-get -y install apt-utils gcc libpq-dev libsndfile-dev

# Install kenlm
mkdir external_lib
cd external_lib

sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DKENLM_MAX_ORDER=20 -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j 16
export KENLM_ROOT_DIR=$ABSOLUTE_PATH'/external_lib/kenlm/'
cd ../..

# Install Additional Dependencies (ATLAS, OpenBLAS, Accelerate, Intel MKL)
apt-get install libsndfile1-dev libopenblas-dev libfftw3-dev libgflags-dev libgoogle-glog-dev

# Install wav2letter
git clone -b v0.2 https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
pip install -e .
cd ../../..

@sooftware

sooftware commented Sep 28, 2020

I installed wav2letter a few days ago.

@kpister

kpister commented Sep 28, 2020

@sooftware Thanks! I'm getting an import error for ModuleNotFoundError: No module named 'examples.wav2vec2'.
This module doesn't exist in fairseq though. Did you add it from somewhere else?

@mironnn

mironnn commented Sep 28, 2020

@sooftware Could you please specify what you have inside the manifest_path = "MANIFEST_PATH" directory?

What is this path supposed to point to?

@kpister

kpister commented Sep 28, 2020

@mironnn The manifest path only contains the dictionary, from what I can tell. Look at the load_target_dict function:

def load_target_dict(manifest_path='./manifest'):
    dict_path = os.path.join(manifest_path, "dict.ltr.txt")
    target_dict = Dictionary.load(dict_path)
    return target_dict
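
For reference, dict.ltr.txt is a plain fairseq dictionary file with one entry per line in the form "<symbol> <count>", covering the letters, the apostrophe, and the word-boundary symbol |. A sketch of what it looks like (the counts here are illustrative, not the real LibriSpeech values):

| 803288
E 439294
T 319983
A 277827
O 262544
...
' 8464
Q 3308
Z 2124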

@mironnn

mironnn commented Sep 28, 2020

@sooftware Thanks! I'm getting an import error for ModuleNotFoundError: No module named 'examples.wav2vec2'.
This module doesn't exist in fairseq though. Did you add it from somewhere else?

Have the same issue =(

@sooftware

sooftware commented Sep 28, 2020

@kpister I created a wav2vec2 module in the examples folder because I was using fairseq-0.9.0.
I'll write code that runs inference with the latest fairseq! Please wait a little.

@mironnn

@sooftware

sooftware commented Sep 28, 2020

I created a pull request (#2668).
I added recognize.py to the examples/wav2vec/ directory.
Usage is simple.

  • Command
$ python3 examples/wav2vec/recognize.py --wav_path $WAV_PATH --w2v_path $W2V_PATH --target_dict_path $TARGET_DICT_PATH
  • Output
I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I LOVE THEE PURELY AS THEY TURN FROM PRAISE

@sooftware

Here is the code for recognize.py:

import torch
import argparse
import soundfile as sf
import torch.nn.functional as F
import itertools as it
from fairseq import utils
from fairseq.models import BaseFairseqModel
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
from fairseq.data import Dictionary
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
from wav2letter.decoder import CriterionType
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes

parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
parser.add_argument('--wav_path', type=str,
                    default='~/xxx.wav',
                    help='path of wave file')
parser.add_argument('--w2v_path', type=str,
                    default='~/wav2vec2_vox_960h.pt',
                    help='path of pre-trained wav2vec-2.0 model')
parser.add_argument('--target_dict_path', type=str,
                    default='dict.ltr.txt',
                    help='path of target dict (dict.ltr.txt)')


class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)
        w2v_encoder = Wav2VecEncoder(args, target_dict)
        return cls(w2v_encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x


class W2lDecoder(object):
    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1

        self.criterion_type = CriterionType.CTC
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.asg_transitions = None

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)

        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)

        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = list()

        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
        ]


def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == 'wordpiece':
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == 'letter':
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != 'none':
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence


def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    return feats


def load_model(model_path, target_dict):
    w2v = torch.load(model_path)
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)

    return [model]


def main():
    args = parser.parse_args()
    sample = dict()
    net_input = dict()

    feature = get_feature(args.wav_path)
    target_dict = Dictionary.load(args.target_dict_path)

    model = load_model(args.w2v_path, target_dict)
    model[0].eval()

    generator = W2lViterbiDecoder(target_dict)
    net_input["source"] = feature.unsqueeze(0)

    padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)

    net_input["padding_mask"] = padding_mask
    sample["net_input"] = net_input

    with torch.no_grad():
        hypo = generator.generate(model, sample, prefix_tokens=None)

    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    print(post_process(hyp_pieces, 'letter'))


if __name__ == '__main__':
    main()

@loretoparisi
Author

@sooftware thanks. I'm trying a CPU build; in this case I get:

CMake Error at cmake/CUDAUtils.cmake:12 (message):
      CUDA required to build CUDA criterion backend
    Call Stack (most recent call first):
      src/libraries/criterion/CMakeLists.txt:28 (include)

I can see from your script that you build the python bindings, but how do I pass -DCRITERION_BACKEND=CPU to disable CUDA?

@sooftware

sooftware commented Sep 28, 2020

Oh, I'm sorry. I don't know about that issue. T.T

@loretoparisi
Author

I asked about it here: flashlight/wav2letter#842

@mychiux413

@loretoparisi
I tested the CPU case in a Docker environment, and recognize.py did work.

Here are my steps:

  1. Prepare the data wav2vec2 needs under fairseq/data: the model, the dict file, and the wav files:
# For example
fairseq/data/wav2vec_small_960h.pt  # model
fairseq/data/dict.ltr.txt  # dict file
fairseq/data/temp.wav  # the wav you want to test; don't forget to resample it to 16kHz
  2. Prepare the recognize.py mentioned above; I put it at fairseq/examples/wav2vec/recognize.py
  3. Prepare a Dockerfile at fairseq/wav2vec2.CPU.Dockerfile; the build script is:
FROM wav2letter/wav2letter:cpu-latest

ENV USE_CUDA=0
ENV KENLM_ROOT_DIR=/root/kenlm

# will use Intel MKL for featurization but this may cause dynamic loading conflicts.
# ENV USE_MKL=1

ENV LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2018.5.274/linux/mkl/lib/intel64:$LD_LIBRARY_PATH
WORKDIR /root/wav2letter/bindings/python

RUN pip install --upgrade pip && pip install soundfile packaging && pip install -e .

WORKDIR /root
RUN git clone https://github.com/pytorch/fairseq.git
RUN mkdir data
COPY examples/wav2vec/recognize.py /root/fairseq/examples/wav2vec/recognize.py

WORKDIR /root/fairseq
RUN pip install --editable ./ && python examples/speech_recognition/infer.py --help && python examples/wav2vec/recognize.py --help
  4. Go to the fairseq/ directory, then build and run the Docker image:
# build
docker build -t wav2vec2 -f wav2vec2.CPU.Dockerfile .

# run docker
docker run --rm -itd --ipc=host -v $PWD/data:/root/data --name w2v wav2vec2

# go into container
docker exec -it w2v bash

# run recognize
python examples/wav2vec/recognize.py --wav_path ~/data/temp.wav --w2v_path ~/data/wav2vec_small_960h.pt --target_dict_path ~/data/dict.ltr.txt

@loretoparisi
Author

loretoparisi commented Sep 30, 2020

@mychiux413 thank you so much. I'm getting this UserWarning

/root/fairseq/examples/speech_recognition/w2l_decoder.py:39: UserWarning: wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings
  "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
usage: recognize.py [-h] [--wav_path WAV_PATH] [--w2v_path W2V_PATH]
                    [--target_dict_path TARGET_DICT_PATH]
recognize.py: error: unrecognized arguments: --wv2_path /app/data/wav2vec_small_10m.pt

Within the container the command used was

python examples/wav2vec/recognize.py --wav_path /root/data/temp.wav --wv2_path /root/data/wav2vec_small_10m.pt --target_dict_path /root/data/dict.ltr.txt

It should not be there, so I have opened an issue.

@sooftware

@loretoparisi there is a typo: it's not wv2_path, it's w2v_path. :)

@loretoparisi
Author

@sooftware gosh!!! I've checked it ten times!

@sooftware

LoL!! I'm glad I found it now!
@loretoparisi Have you tried evaluating the wav2vec 2.0 model with a KenLM or Transformer LM?

@loretoparisi
Author

@sooftware not yet, but this is definitely something I'm going to do!

@sooftware

Let me know if you succeed! I have an open issue (#2654) about KenLM.
If I succeed, I'll post an update on that issue.

@loretoparisi
Author

loretoparisi commented Sep 30, 2020

@sooftware definitely I will. In the meantime I have pushed everything here with Docker. I made two Dockerfiles: the one suggested by @mychiux413 (👍 thanks) and one edited by me with your commands (👍 thank you too), slightly adapted, starting from a stripped-down python:3.7.4-slim-buster. They both work, but the Docker images have very different sizes:

wav2vec-python3       latest    cfdcb450b427    51 minutes ago    9.97GB
wav2vec-wav2letter    latest    e028493c66b0    2 hours ago       3.37GB

Thank you guys for your help and collaboration! I will keep you posted.

@sooftware

Grrrrrrreat !!!
I am studying wav2vec with great interest. It would be nice if we could help each other. :)

@pkadambi

pkadambi commented Nov 25, 2020

@alexeib I still have the same error on the most recent commit. I built kenlm from the tarball (not via git) and compiled it with -DKENLM_MAX_ORDER=20.

The following command:

python fairseq/examples/speech_recognition/infer.py ./data/manifest --path ./models/w2v2/wav2vec_small_960h.pt --results-path ./results/ --lexicon ./models/w2v2/librispeech_lexicon.lst --w2l-decoder kenlm --lm-model ./models/kenlm/4-gram.bin --task audio_pretraining --nbest 1 --gen-subset dev_other --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 400000 --post-process letter --cpu --num-workers 1 --batch-size 8 --beam 1024

causes this stack trace:

INFO:__main__:| decoding with criterion ctc
INFO:__main__:| loading model(s) from ./models/w2v2/wav2vec_small_960h.pt
INFO:fairseq.data.audio.raw_audio_dataset:loaded 1, skipped 0 samples
INFO:__main__:| ./data/manifest dev_other 1 examples
/home/prad/github/wrapASR/fairseq/examples/speech_recognition/w2l_decoder.py:42: UserWarning: wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings
  "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
Traceback (most recent call last):
  File "fairseq/examples/speech_recognition/infer.py", line 428, in <module>
    cli_main()
  File "fairseq/examples/speech_recognition/infer.py", line 424, in cli_main
    main(args)
  File "fairseq/examples/speech_recognition/infer.py", line 284, in main
    generator = build_generator(args)
  File "fairseq/examples/speech_recognition/infer.py", line 273, in build_generator
    return W2lKenLMDecoder(args, task.target_dictionary)
  File "/home/prad/github/wrapASR/fairseq/examples/speech_recognition/w2l_decoder.py", line 133, in __init__
    super().__init__(args, tgt_dict)
  File "/home/prad/github/wrapASR/fairseq/examples/speech_recognition/w2l_decoder.py", line 56, in __init__
    self.criterion_type = CriterionType.CTC

@alexeib
Contributor

alexeib commented Nov 26, 2020

Looks like it can't import wav2letter. Have you tried installing the python bindings as the error message suggests?

@pkadambi

Ah, thanks! I don't know how I missed that; it must be an issue with my wav2letter install.

@youssefavx

youssefavx commented Dec 4, 2020

Returning to this after a while, I just ran my first test and I'm getting surprisingly poor results. I believe my audio file was 16kHz; it's a lecture that has some noise in it. My WER is 22.9% and my CER is 14.75%.

I used 4-gram (probing), and here is the command I ran:

python examples/speech_recognition/infer.py ~/data/libri --task audio_pretraining --nbest 1 --path ~/data/wav2vec2_vox_960h.pt --gen-subset test --results-path ~/data/result-kenlm --w2l-decoder kenlm --lm-model ~/data/4-gram_probing.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter --cpu --num-workers 1 --batch-size 8 --lexicon ~/data/librispeech_lexicon.lst --beam 1024

Anybody know how to get this down to at the very least 4-5%?

Do I have to use the transformer language model instead?

I am getting this warning:

/root/fairseq/examples/speech_recognition/w2l_decoder.py:41: UserWarning: wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings
  "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"

However, it seems to run either way. Could the word error rate be high due to this?

@alexeib
Contributor

alexeib commented Dec 5, 2020

you need to use the fairseq model (.pt), not the wav2letter model (.bin)

@youssefavx

@alexeib ah! 🤦‍♂️ Sorry about that, you’re right. I’ll try again soon with the actual language model haha.

@youssefavx

youssefavx commented Dec 5, 2020

Okay, so I downloaded and tried to run it, here's what's going on so far:

I tried running this command:

python examples/speech_recognition/infer.py ~/data/libri --task audio_pretraining --nbest 1 --path ~/data/wav2vec2_vox_960h.pt --gen-subset test --results-path ~/data/result-kenlm --w2l-decoder fairseqlm --lm-model ~/data/lm_librispeech_word_transformer.pt --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter --cpu --num-workers 1 --batch-size 8 --lexicon ~/data/librispeech_lexicon.lst --beam 500

But it gave me this error:

Traceback (most recent call last):
  File "examples/speech_recognition/infer.py", line 471, in <module>
    cli_main()
  File "examples/speech_recognition/infer.py", line 467, in cli_main
    main(args)
  File "examples/speech_recognition/infer.py", line 327, in main
    generator = build_generator(args)
  File "examples/speech_recognition/infer.py", line 320, in build_generator
    return W2lFairseqLMDecoder(args, task.target_dictionary)
  File "/root/fairseq/examples/speech_recognition/w2l_decoder.py", line 354, in __init__
    task = tasks.setup_task(lm_args)
  File "/root/fairseq/fairseq/tasks/__init__.py", line 26, in setup_task
    return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs)
  File "/root/fairseq/fairseq/tasks/language_modeling.py", line 158, in setup_task
    dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)
  File "/root/fairseq/fairseq/tasks/language_modeling.py", line 142, in setup_dictionary
    dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
  File "/root/fairseq/fairseq/data/dictionary.py", line 214, in load
    d.add_from_file(f)
  File "/root/fairseq/fairseq/data/dictionary.py", line 227, in add_from_file
    raise fnfe
  File "/root/fairseq/fairseq/data/dictionary.py", line 224, in add_from_file
    with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
FileNotFoundError: [Errno 2] No such file or directory: '/root/data/dict.txt'

Then I tried taking the dict.ltr.txt file from the libri folder and putting it in the data folder, and renaming it to "dict.txt", and I ran this:

python examples/speech_recognition/infer.py ~/data/libri --task audio_pretraining --nbest 1 --path ~/data/wav2vec2_vox_960h.pt --gen-subset test --results-path ~/data/result-kenlm --w2l-decoder fairseqlm --lm-model ~/data/lm_librispeech_word_transformer.pt --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter --cpu --num-workers 1 --batch-size 8 --lexicon ~/data/librispeech_lexicon.lst --beam 500

I got this error:

INFO:fairseq.tasks.language_modeling:dictionary: 32 types
Traceback (most recent call last):
  File "examples/speech_recognition/infer.py", line 471, in <module>
    cli_main()
  File "examples/speech_recognition/infer.py", line 467, in cli_main
    main(args)
  File "examples/speech_recognition/infer.py", line 327, in main
    generator = build_generator(args)
  File "examples/speech_recognition/infer.py", line 320, in build_generator
    return W2lFairseqLMDecoder(args, task.target_dictionary)
  File "/root/fairseq/examples/speech_recognition/w2l_decoder.py", line 355, in __init__
    model = task.build_model(lm_args)
  File "/root/fairseq/fairseq/tasks/language_modeling.py", line 178, in build_model
    model = super().build_model(args)
  File "/root/fairseq/fairseq/tasks/fairseq_task.py", line 548, in build_model
    model = models.build_model(args, self)
  File "/root/fairseq/fairseq/models/__init__.py", line 56, in build_model
    return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task)
  File "/root/fairseq/fairseq/models/transformer_lm.py", line 221, in build_model
    args.quant_noise_pq_block_size,
  File "/root/fairseq/fairseq/modules/adaptive_input.py", line 33, in __init__
    ), "cannot specify cutoff larger than vocab size"

Then I tried downloading the fairseq dict file listed next to the transformer model here: https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019

I then renamed the file to 'dict.txt'

I ran this command:

python examples/speech_recognition/infer.py ~/data/libri --task audio_pretraining --nbest 1 --path ~/data/wav2vec2_vox_960h.pt --gen-subset test --results-path ~/data/result-kenlm --w2l-decoder fairseqlm --lm-model ~/data/lm_librispeech_word_transformer.pt --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter --cpu --num-workers 1 --batch-size 8 --lexicon ~/data/librispeech_lexicon.lst --beam 500

And I got this error:

INFO:fairseq.tasks.language_modeling:dictionary: 221456 types
Traceback (most recent call last):
  File "examples/speech_recognition/infer.py", line 471, in <module>
    cli_main()
  File "examples/speech_recognition/infer.py", line 467, in cli_main
    main(args)
  File "examples/speech_recognition/infer.py", line 327, in main
    generator = build_generator(args)
  File "examples/speech_recognition/infer.py", line 320, in build_generator
    return W2lFairseqLMDecoder(args, task.target_dictionary)
  File "/root/fairseq/examples/speech_recognition/w2l_decoder.py", line 362, in __init__
    self.lm = FairseqLM(self.word_dict, model)
  File "/root/fairseq/examples/speech_recognition/w2l_decoder.py", line 223, in __init__
    model.cuda()
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 463, in cuda
    return self._apply(lambda t: t.cuda(device))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 359, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 359, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 359, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 381, in _apply
    param_applied = fn(param)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 463, in <lambda>
    return self._apply(lambda t: t.cuda(device))
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/__init__.py", line 172, in _lazy_init
    torch._C._cuda_init()
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

Assuming this error in fact follows from a proper command/sequence of events and not some mistake I made in one of the inputs:

Is there some way to run this on CPU? I already have the --cpu flag in that command.

I believe I have a CUDA driver installed, but I'm not sure I have an NVIDIA one; it seems from my settings that I do, but for some reason I guess it's not seeing it?

[Screenshot: Screen Shot 2020-12-05 at 3.27.21 PM]

There's a whole bunch of problems in this area with Apple/Mac/Nvidia that are hairy to get into. I'd rather just run on CPU.

@shreyashub

I want to use wav2vec 2.0 as a featurizer, i.e. get just the embeddings. Can anyone help with this or point me to a starting point?
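
A minimal sketch of one possible approach (assuming the HuggingFace transformers package is installed and the audio is 16 kHz mono; the model name and file path below are just placeholders):

# Minimal sketch: frame-level wav2vec 2.0 embeddings via HuggingFace transformers
import soundfile as sf
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

wav, sample_rate = sf.read("audio_16khz.wav")  # placeholder path, 16 kHz mono
inputs = extractor(wav, sampling_rate=sample_rate, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

embeddings = outputs.last_hidden_state  # (batch, frames, hidden_size), one frame per ~20 ms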

@spygaurad

@sooftware have you added Language Model decoding in this inference pipeline?

@sooftware

@spygaurad I'll try. I'll leave it here if I succeed. There is a high probability that the code up there will not work anymore, since fairseq has been upgraded to version 0.10.1.

@spygaurad

Okay @sooftware, I can try to integrate an LM for version 0.10.1 if I have some references. The code seemed quite complex to me. Thanks for your response.

@samuelazran

Can anyone please provide an example of how to use this pipeline for ASR on a real-time / media stream rather than a static wav file? I have been looking everywhere for it. I see in the Pitch here that a "programmatically loaded waveform signal" should be supported; if I understand correctly, that refers to a sort of online/live ASR.

Help would be very much appreciated.

@sooftware

You can check out a wav2vec 2.0 inference pipeline at https://github.com/kakaobrain/pororo
Pororo has English, Chinese, and Korean wav2vec 2.0 models.
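
If you just need something quick with the pieces above, a rough sketch could look like the snippet below. It is not true streaming decoding; it simply cuts an in-memory waveform into fixed-size chunks and decodes each one independently, reusing post_process and the model/generator/target_dict loaded as in recognize.py's main().

# Rough sketch: chunked decoding of an in-memory waveform (NOT true streaming).
# Assumes model, generator, target_dict are loaded as in recognize.py's main(),
# and post_process is the function defined above.
import torch
import torch.nn.functional as F

SAMPLE_RATE = 16000      # wav2vec 2.0 expects 16 kHz audio
CHUNK_SECONDS = 10       # decode 10-second windows independently


def transcribe_chunk(model, generator, target_dict, chunk):
    feats = torch.from_numpy(chunk).float()
    with torch.no_grad():
        feats = F.layer_norm(feats, feats.shape)  # same normalization as get_feature()
        sample = {"net_input": {
            "source": feats.unsqueeze(0),
            "padding_mask": torch.BoolTensor(1, feats.size(0)).fill_(False),
        }}
        hypo = generator.generate(model, sample, prefix_tokens=None)
    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    return post_process(hyp_pieces, "letter")


def transcribe_stream(model, generator, target_dict, waveform):
    # waveform: 1-D numpy array of 16 kHz mono samples collected from the stream
    step = SAMPLE_RATE * CHUNK_SECONDS
    for start in range(0, len(waveform), step):
        yield transcribe_chunk(model, generator, target_dict, waveform[start:start + step])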

@ytorosjan

Hi @sooftware, thanks for recognize.py, it's a great script you made there. However, I have some issues. Your script works fine with the finetuned models provided for wav2vec 2.0, but when I try to use my own finetuned model it throws an error.

Code I ran -->

python3 examples/wav2vec/recognize.py --wav_path /path/audio_file/10min/file/audio_finetune1/file2375.wav --w2v_path /path/audio_file/wav2vec_small_10m.pt --target_dict_path /path/audio_file/manifest/dict.ltr.txt

Error I'm getting-->

Traceback (most recent call last):
File "examples/wav2vec/recognize.py", line 192, in
main()
File "examples/wav2vec/recognize.py", line 173, in main
model = load_model(args.w2v_path, target_dict)
File "examples/wav2vec/recognize.py", line 159, in load_model
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
File "examples/wav2vec/recognize.py", line 39, in build_model
base_architecture(args)
File "/path/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 633, in base_architecture
args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
AttributeError: 'NoneType' object has no attribute 'no_pretrained_weights'

Can you please look into it, in case there is any mistake I'm making on my side?

I am running into the same issue when trying to run inference with my own finetuned model trained using fairseq-hydra-train.

@lyogavin

I've created pull request #3244 for this.

@stale

stale bot commented Jun 20, 2021

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

@stale stale bot added the stale label Jun 20, 2021
@loretoparisi
Author

loretoparisi commented Jun 20, 2021

Time to close this wav2vec 2.0 specific issue. A lot of time has passed (9 months, and a kid was born!) and now, thanks to the new HuggingFace audio libraries, inference is simple and works like a charm!
Here is an example of ASR with wav2vec 2.0 plus VAD.

https://github.com/loretoparisi/hf-experiments/blob/master/src/asr/README.md
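
For reference, a minimal sketch of HuggingFace-style wav2vec 2.0 inference (model name and wav path are placeholders; the linked repo adds VAD and other details on top of this):

# Minimal sketch of wav2vec 2.0 ASR with HuggingFace transformers
# (placeholders: model name, wav path; audio must be 16 kHz mono).
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

wav, sample_rate = sf.read("sample_16khz.wav")
inputs = processor(wav, sampling_rate=sample_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])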

@tensorfoo

tensorfoo commented Jun 30, 2021

Hi, can this work with my own pretrained and finetuned model (mymodel.pt), or only with one from HuggingFace?

@loretoparisi
Author

@tensorfoo assuming you replace the Wav2Vec2 model here
https://github.com/loretoparisi/hf-experiments/blob/master/src/asr/run.py#L20

it should definitely work.

@tensorfoo

tensorfoo commented Jul 17, 2021

I've tried it with my finetuned model and it gave a utf-8 error. Now I'm trying with the base English model, like: model = Wav2Vec2ForCTC.from_pretrained('/models/others/wav2vec_small.pt'), and it's the same deal. That one isn't finetuned, but it gives the same error as my finetuned one. Any ideas?

/usr/lib/python3.8/codecs.py in decode(self, input, final)
320 # decode input (taking the buffer into account)
321 data = self.buffer + input
--> 322 (result, consumed) = self._buffer_decode(data, self.errors, final)
323 # keep undecoded input until the next call
324 self.buffer = data[consumed:]

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

@8904591612

Hi @sooftware, thanks for recognize.py, it's a great script you made there. However, I have some issues. Your script works fine with the finetuned models provided for wav2vec 2.0, but when I try to use my own finetuned model it throws an error.
Code I ran -->

python3 examples/wav2vec/recognize.py --wav_path /path/audio_file/10min/file/audio_finetune1/file2375.wav --w2v_path /path/audio_file/wav2vec_small_10m.pt --target_dict_path /path/audio_file/manifest/dict.ltr.txt

Error I'm getting-->

Traceback (most recent call last):
File "examples/wav2vec/recognize.py", line 192, in
main()
File "examples/wav2vec/recognize.py", line 173, in main
model = load_model(args.w2v_path, target_dict)
File "examples/wav2vec/recognize.py", line 159, in load_model
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
File "examples/wav2vec/recognize.py", line 39, in build_model
base_architecture(args)
File "/path/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 633, in base_architecture
args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
AttributeError: 'NoneType' object has no attribute 'no_pretrained_weights'

Can you please look into it, in case there is any mistake I'm making on my side?

I am running into the same issue when trying to run inference with my own finetuned model trained using fairseq-hydra-train.

Were you able to solve that issue? Please help, I am running into the same issue.
