Example for using Wav2Vec2 SpeechToText from Huggingface (#1939)
* Example for using HuggingFace Wav2Vec2

* Simpler readme.

* Add a few comments

* Add output from model in readme

Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
altre and msaroufim committed Nov 3, 2022
1 parent 33e1e97 commit e18d8b8
Showing 6 changed files with 110 additions and 0 deletions.
27 changes: 27 additions & 0 deletions examples/speech2text_wav2vec2/README.md
@@ -0,0 +1,27 @@
## Speech2Text Wav2Vec2 example
In this example we use a pretrained Wav2Vec2 model from the `transformers` library (https://huggingface.co/docs/transformers/model_doc/wav2vec2) for speech-to-text and serve it with TorchServe.

### Prerequisites
Apart from the usual dependencies listed in `https://github.com/pytorch/serve/blob/master/docs/getting_started.md`, we need to install `torchaudio` and `transformers`.

You can install these into your current environment, or follow these steps to set up all prerequisites from scratch:
* Install miniconda: https://docs.conda.io/en/latest/miniconda.html
* Run `python ../../ts_scripts/install_dependencies.py` to install binary dependencies
* Install all needed packages with `conda env create -f environment.yml`
* Activate the conda environment: `conda activate wav2vec2env`

### Prepare model and run server
Next, we need to download our wav2vec2 model and archive it for use by torchserve:
```bash
./download_wav2vec2.py # Downloads model and creates folder `./model` with all necessary files
./archive_model.sh # Creates .mar archive using torch-model-archiver and moves it to folder `./model_store`
```

Now let's start the server and try it out with our example file!
```bash
torchserve --start --model-store model_store --models Wav2Vec2=Wav2Vec2.mar --ncs
# Once the server is running, send it the sample file:
curl -X POST http://127.0.0.1:8080/predictions/Wav2Vec2 --data-binary '@./sample.wav' -H "Content-Type: audio/basic"
# Which returns:
I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT
```
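
The same request can also be sent from Python. The following is a minimal sketch using only the standard library; it assumes the server started above is listening on `127.0.0.1:8080`, and the helper names `build_request` and `transcribe` are illustrative, not part of this example's files.

```python
import urllib.request

# Assumed endpoint: mirrors the curl command above (inference port 8080, model name Wav2Vec2)
URL = "http://127.0.0.1:8080/predictions/Wav2Vec2"


def build_request(audio_bytes: bytes, url: str = URL) -> urllib.request.Request:
    """Build a POST request carrying raw WAV bytes, like `curl --data-binary`."""
    return urllib.request.Request(
        url,
        data=audio_bytes,
        headers={"Content-Type": "audio/basic"},
        method="POST",
    )


def transcribe(path: str, url: str = URL) -> str:
    """Send the WAV file at `path` to the server and return the response text."""
    with open(path, "rb") as f:
        request = build_request(f.read(), url)
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")


# Usage (requires a running server):
#   print(transcribe("./sample.wav"))
```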
7 changes: 7 additions & 0 deletions examples/speech2text_wav2vec2/archive_model.sh
@@ -0,0 +1,7 @@
#!/bin/bash
set -euo pipefail

mkdir -p model_store
# --extra-files bundles the config and tokenizer files the processor needs
torch-model-archiver --model-name Wav2Vec2 --version 1.0 --serialized-file model/pytorch_model.bin --handler ./handler.py --extra-files "model/config.json,model/special_tokens_map.json,model/tokenizer_config.json,model/vocab.json,model/preprocessor_config.json" -f
mv Wav2Vec2.mar model_store
13 changes: 13 additions & 0 deletions examples/speech2text_wav2vec2/download_wav2vec2.py
@@ -0,0 +1,13 @@
#!/usr/bin/env python3

from transformers import AutoModelForCTC, AutoProcessor
import os

modelname = "facebook/wav2vec2-base-960h"
model = AutoModelForCTC.from_pretrained(modelname)
processor = AutoProcessor.from_pretrained(modelname)

modelpath = "model"
os.makedirs(modelpath, exist_ok=True)
model.save_pretrained(modelpath)
processor.save_pretrained(modelpath)
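
Before archiving, it can help to confirm that the download step produced every file that `archive_model.sh` passes to `torch-model-archiver`. The sketch below is a hypothetical helper (not part of this commit); the filenames are copied from the archiver command in this example. Depending on the installed `transformers` version, the weights may be saved under a different filename than `pytorch_model.bin`, in which case this check would report it as missing.

```python
from pathlib import Path
from typing import List

# Filenames taken from the torch-model-archiver command in archive_model.sh
EXPECTED_FILES = [
    "pytorch_model.bin",
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "vocab.json",
    "preprocessor_config.json",
]


def missing_model_files(model_dir: str = "model") -> List[str]:
    """Return the expected filenames that are not present in `model_dir`."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

An empty list means the folder contains everything the archiver command refers to.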
15 changes: 15 additions & 0 deletions examples/speech2text_wav2vec2/environment.yml
@@ -0,0 +1,15 @@
name: wav2vec2env
channels:
  - defaults
  - pytorch
  - conda-forge
dependencies:
  - python>=3.10
  - transformers
  - pytorch
  - torchaudio
  - pip
  - pip:
    - torch-workflow-archiver
    - torch-model-archiver
    - torchserve
48 changes: 48 additions & 0 deletions examples/speech2text_wav2vec2/handler.py
@@ -0,0 +1,48 @@
import io

import torch
import torchaudio
from transformers import AutoModelForCTC, AutoProcessor


class Wav2VecHandler(object):
    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.processor = None
        self.device = None
        # Sampling rate for Wav2Vec model must be 16k
        self.expected_sampling_rate = 16_000

    def initialize(self, context):
        """Initialize properties and load model"""
        self._context = context
        properties = context.system_properties

        # See https://pytorch.org/serve/custom_service.html#handling-model-execution-on-multiple-gpus
        gpu_id = properties.get("gpu_id")
        use_cuda = torch.cuda.is_available() and gpu_id is not None
        self.device = torch.device(f"cuda:{gpu_id}" if use_cuda else "cpu")

        model_dir = properties.get("model_dir")
        self.processor = AutoProcessor.from_pretrained(model_dir)
        # Move the model to the selected device and put it in inference mode
        self.model = AutoModelForCTC.from_pretrained(model_dir).to(self.device).eval()
        self.initialized = True

    def handle(self, data, context):
        """Transform input to tensor, resample, run model and return transcribed text."""
        input = data[0].get("data")
        if input is None:
            input = data[0].get("body")

        # torchaudio.load accepts a file-like object; here `input` is bytes
        model_input, sample_rate = torchaudio.load(io.BytesIO(input), format="WAV")

        # Ensure the sampling rate matches the one the model was trained on
        if sample_rate != self.expected_sampling_rate:
            model_input = torchaudio.functional.resample(model_input, sample_rate, self.expected_sampling_rate)

        model_input = self.processor(model_input, sampling_rate=self.expected_sampling_rate, return_tensors="pt").input_values[0]
        # Inference only: skip gradient tracking and run on the model's device
        with torch.no_grad():
            logits = self.model(model_input.to(self.device))[0]
        # Greedy CTC decoding: most likely token per frame, then collapse to text
        pred_ids = torch.argmax(logits, dim=-1)[0]
        output = self.processor.decode(pred_ids)

        return [output]
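
The last lines of `handle` perform greedy CTC decoding: take the most likely token per audio frame, then let the processor turn the id sequence into text. The following pure-Python sketch illustrates what that decoding step does conceptually. It is not the `transformers` implementation; it assumes that the blank symbol is the padding token and that `|` marks word boundaries, and the toy vocabulary is hypothetical (the real one is in `vocab.json`).

```python
from itertools import groupby
from typing import Dict, List


def ctc_greedy_decode(pred_ids: List[int], id_to_token: Dict[int, str],
                      blank_id: int = 0, word_delimiter: str = "|") -> str:
    """Collapse repeated ids, drop blanks, and join the remaining tokens into text."""
    # 1. Merge consecutive duplicates: [4, 4, 0, 4] -> [4, 0, 4]
    collapsed = [token_id for token_id, _ in groupby(pred_ids)]
    # 2. Remove the blank token, which separates genuinely repeated characters
    kept = [token_id for token_id in collapsed if token_id != blank_id]
    # 3. Map ids to characters and turn the word delimiter into a space
    text = "".join(id_to_token[token_id] for token_id in kept)
    return text.replace(word_delimiter, " ").strip()


# Toy vocabulary (hypothetical, for illustration only)
vocab = {0: "<pad>", 1: "|", 2: "H", 3: "E", 4: "L", 5: "O"}
# One id per audio frame; the blank (0) between the two runs of 4 keeps both L's
frame_ids = [2, 2, 3, 4, 4, 0, 4, 5, 5, 0]
print(ctc_greedy_decode(frame_ids, vocab))  # prints HELLO
```

Without the blank between the two runs of `4`, step 1 would merge them into a single `L`, which is why CTC models emit a blank between repeated characters.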
Binary file added examples/speech2text_wav2vec2/sample.wav
Binary file not shown.
