-
Notifications
You must be signed in to change notification settings - Fork 829
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example for using Wav2Vec2 SpeechToText from Huggingface (#1939)
* Example for using HuggingFace Wav2Vec2 * Simpler readme. * Add a few comments * Add output from model in readme Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
- Loading branch information
Showing
6 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
## Speech2Text Wav2Vec2 example: | ||
In this example we will use a pretrained Wav2Vec2 model for Speech2Text using the `transformers` library: https://huggingface.co/docs/transformers/model_doc/wav2vec2 and serve it using torchserve. | ||
|
||
### Prerequisites | ||
Apart from the usual dependencies as shown here: `https://github.com/pytorch/serve/blob/master/docs/getting_started.md`, we need to install `torchaudio` and `transformers`. | ||
|
||
You can install these into your current environment or follow these steps which should give you all necessary prerequisites from scratch: | ||
* Install miniconda: https://docs.conda.io/en/latest/miniconda.html | ||
* run `python ../../ts_scripts/install_dependencies.py` to install binary dependencies | ||
* Install all needed packages with `conda env create -f environment.yml` | ||
* Activate conda environment: `conda activate wav2vec2env` | ||
|
||
### Prepare model and run server | ||
Next, we need to download our wav2vec2 model and archive it for use by torchserve: | ||
```bash | ||
./download_wav2vec2.py # Downloads model and creates folder `./model` with all necessary files | ||
./archive_model.sh # Creates .mar archive using torch-model-archiver and moves it to folder `./model_store` | ||
``` | ||
|
||
Now let's start the server and try it out with our example file! | ||
```bash | ||
torchserve --start --model-store model_store --models Wav2Vec2=Wav2Vec2.mar --ncs | ||
# Once the server is running, let's try it with: | ||
curl -X POST http://127.0.0.1:8080/predictions/Wav2Vec2 --data-binary '@./sample.wav' -H "Content-Type: audio/basic" | ||
# Which will happily return: | ||
I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT% | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash | ||
set -euo pipefail | ||
|
||
mkdir -p model_store | ||
# Extra files add all files necessary for processor | ||
torch-model-archiver --model-name Wav2Vec2 --version 1.0 --serialized-file model/pytorch_model.bin --handler ./handler.py --extra-files "model/config.json,model/special_tokens_map.json,model/tokenizer_config.json,model/vocab.json,model/preprocessor_config.json" -f | ||
mv Wav2Vec2.mar model_store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from transformers import AutoModelForCTC, AutoProcessor | ||
import os | ||
|
||
modelname = "facebook/wav2vec2-base-960h" | ||
model = AutoModelForCTC.from_pretrained(modelname) | ||
processor = AutoProcessor.from_pretrained(modelname) | ||
|
||
modelpath = "model" | ||
os.makedirs(modelpath, exist_ok=True) | ||
model.save_pretrained(modelpath) | ||
processor.save_pretrained(modelpath) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
name: wav2vec2env | ||
channels: | ||
- defaults | ||
- pytorch | ||
- conda-forge | ||
dependencies: | ||
- python>=3.10 | ||
- transformers | ||
- pytorch | ||
- torchaudio | ||
- pip | ||
- pip: | ||
- torch-workflow-archiver | ||
- torch-model-archiver | ||
- torchserve |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import torchaudio | ||
from transformers import AutoProcessor, AutoModelForCTC | ||
import io | ||
|
||
|
||
class Wav2VecHandler(object): | ||
def __init__(self): | ||
self._context = None | ||
self.initialized = False | ||
self.model = None | ||
self.processor = None | ||
self.device = None | ||
# Sampling rate for Wav2Vec model must be 16k | ||
self.expected_sampling_rate = 16_000 | ||
|
||
def initialize(self, context): | ||
"""Initialize properties and load model""" | ||
self._context = context | ||
self.initialized = True | ||
properties = context.system_properties | ||
|
||
# See https://pytorch.org/serve/custom_service.html#handling-model-execution-on-multiple-gpus | ||
self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") | ||
|
||
model_dir = properties.get("model_dir") | ||
self.processor = AutoProcessor.from_pretrained(model_dir) | ||
self.model = AutoModelForCTC.from_pretrained(model_dir) | ||
|
||
def handle(self, data, context): | ||
"""Transform input to tensor, resample, run model and return transcribed text.""" | ||
input = data[0].get("data") | ||
if input is None: | ||
input = data[0].get("body") | ||
|
||
# torchaudio.load accepts file like object, here `input` is bytes | ||
model_input, sample_rate = torchaudio.load(io.BytesIO(input), format="WAV") | ||
|
||
# Ensure sampling rate is the same as the trained model | ||
if sample_rate != self.expected_sampling_rate: | ||
model_input = torchaudio.functional.resample(model_input, sample_rate, self.expected_sampling_rate) | ||
|
||
model_input = self.processor(model_input, sampling_rate = self.expected_sampling_rate, return_tensors="pt").input_values[0] | ||
logits = self.model(model_input)[0] | ||
pred_ids = torch.argmax(logits, axis=-1)[0] | ||
output = self.processor.decode(pred_ids) | ||
|
||
return [output] |
Binary file not shown.