Add support for WavlmForXVector #603

Merged · 8 commits · Feb 28, 2024
4 changes: 4 additions & 0 deletions scripts/convert.py
@@ -99,6 +99,10 @@
'per_channel': False,
'reduce_range': False,
},
'wavlm': {
'per_channel': False,
'reduce_range': False,
},
}

MODELS_WITHOUT_TOKENIZERS = [
6 changes: 6 additions & 0 deletions scripts/supported_models.py
@@ -1019,6 +1019,12 @@
'microsoft/wavlm-base-plus',
'microsoft/wavlm-large',
],

# Audio XVector (e.g., for speaker verification)
'audio-xvector': [
'microsoft/wavlm-base-plus-sv',
'microsoft/wavlm-base-sv',
],
},
'whisper': {
# Automatic speech recognition
68 changes: 68 additions & 0 deletions src/models.js
@@ -4735,6 +4735,49 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
}
}

/**
* WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
*
* **Example:** Extract speaker embeddings with `WavLMForXVector`.
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);
*
* // Run model with inputs
* const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');
* const outputs = await model(inputs);
* // {
* // logits: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [0.5847219228744507, ...],
* // size: 512
* // },
* // embeddings: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [-0.09079201519489288, ...],
* // size: 512
* // }
* // }
* ```
*/
export class WavLMForXVector extends WavLMPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<XVectorOutput>} An object containing the model's output logits and speaker embeddings.
*/
async _call(model_inputs) {
return new XVectorOutput(await super._call(model_inputs));
}
}
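
In practice, the `embeddings` output is compared across utterances for speaker verification. Below is a minimal sketch using the `cos_sim` helper exported by `@xenova/transformers` — the audio URLs and the 0.85 decision threshold are illustrative assumptions, not part of this PR:

```javascript
import { AutoProcessor, AutoModel, read_audio, cos_sim } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');

// Compute a speaker embedding for a single audio URL.
async function embed(url) {
    const audio = await read_audio(url, 16000);
    const inputs = await processor(audio);
    const { embeddings } = await model(inputs);
    return Array.from(embeddings.data);
}

// Hypothetical recordings; substitute your own.
const a = await embed('https://example.com/utterance_1.wav');
const b = await embed('https://example.com/utterance_2.wav');

// Cosine similarity in [-1, 1]; higher suggests the same speaker.
// The decision threshold (e.g. ~0.85) is task-dependent and must be tuned.
const similarity = cos_sim(a, b);
console.log(similarity > 0.85 ? 'same speaker (likely)' : 'different speakers (likely)');
```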

//////////////////////////////////////////////////
// SpeechT5 models
/**
@@ -5483,6 +5526,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);

const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
['wavlm', ['WavLMForXVector', WavLMForXVector]],
]);

const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
]);
@@ -5523,6 +5570,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5741,6 +5789,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
}

export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
}
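
A usage sketch for the new auto class — like the other auto classes, it resolves the concrete architecture from the checkpoint's `model_type` (illustrative, not part of the diff):

```javascript
import { AutoModelForXVector } from '@xenova/transformers';

// The checkpoint's config.json reports model_type 'wavlm', so this
// resolves to WavLMForXVector via MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES.
const model = await AutoModelForXVector.from_pretrained('Xenova/wavlm-base-plus-sv');
```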

export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}
@@ -5793,6 +5845,22 @@ export class SequenceClassifierOutput extends ModelOutput {
}
}

/**
* Base class for outputs of XVector models.
*/
export class XVectorOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.logits Classification hidden states before AMSoftmax, of shape `(batch_size, config.xvector_output_dim)`.
* @param {Tensor} output.embeddings Utterance embeddings used for vector similarity-based retrieval, of shape `(batch_size, config.xvector_output_dim)`.
*/
constructor({ logits, embeddings }) {
super();
this.logits = logits;
this.embeddings = embeddings;
}
}

/**
* Base class for outputs of token classification models.
*/
2 changes: 1 addition & 1 deletion src/processors.js
@@ -158,7 +158,7 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
function validate_audio_inputs(audio, feature_extractor) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
-            `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` +
+            `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
}
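
The change above only adds the missing space between the two concatenated sentences of the error message. For reference, the remedy the message suggests looks like this (illustrative sketch; the model id and URL are placeholders):

```javascript
import { AutoProcessor, read_audio } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');

// Passing a URL string directly would trigger the error above;
// read_audio first decodes and resamples the file to a Float32Array.
const audio = await read_audio('https://example.com/sample.wav', 16000);
const inputs = await processor(audio);
```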