Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WavlmForXVector #603

Merged
merged 8 commits into from Feb 28, 2024
Merged

Add support for WavlmForXVector #603

merged 8 commits into from Feb 28, 2024

Conversation

D4ve-R
Copy link
Contributor

@D4ve-R D4ve-R commented Feb 23, 2024

Adding support for wavlm with xvector head on top.
The onnx version of microsoft/wavlm-base-plus-sv can be found at D4ve-R/wavlm-base-plus-sv.
Aims to be as close to the python implementation as possible.

@D4ve-R
Copy link
Contributor Author

D4ve-R commented Feb 23, 2024

import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('D4ve-R/wavlm-base-plus-sv');
const audio = await read_audio('FILE_URL', 16000);
const inputs = await processor(audio);
const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv');
const output = await model(inputs);

 // {
 //   embeddings: Tensor {
 //     dims: [ 1, 512 ],
 //     type: 'float32',
 //     data: Float32Array(512) [-0.349443256855011, ...],
 //     size: 512
 //   },
 //   logits: Tensor {
 //     dims: [ 1, 512 ],
 //     type: 'float32',
 //     data: Float32Array(512) [0.022836603224277496, ...],
 //     size: 512
 //   }
 // }

@xenova
Copy link
Owner

xenova commented Feb 23, 2024

Wow! This PR looks perfect! 😍 I look forward to reviewing and merging over the weekend!

@D4ve-R
Copy link
Contributor Author

D4ve-R commented Feb 23, 2024

Thank you! Awesome 👍
I also have code for WavLMForAudioFrameClassification & Wav2Vec2ForAudioFrameClassification, will open a pr once this is merged!
Lmk if i should/can add docs somewhere.

Copy link
Owner

@xenova xenova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Just nits: variable names + comments

src/models.js Outdated Show resolved Hide resolved
src/models.js Outdated Show resolved Hide resolved
src/models.js Outdated Show resolved Hide resolved
src/models.js Outdated Show resolved Hide resolved
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@xenova
Copy link
Owner

xenova commented Feb 28, 2024

I did some additional testing w/ quantization settings, and it's clear that the best combination is per_channel=False and reduce_range=False. You can test it out by running the following code snippet, replacing the model ID with D4ve-R/wavlm-base-plus-sv, with and without { quantized: false }. I will also add it to the model cards.

import { AutoProcessor, AutoModel, read_audio, cos_sim } from '@xenova/transformers';

// Load processor and model
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');

// Helper function to compute speaker embedding from audio URL
async function compute_embedding(url) {
    const audio = await read_audio(url);
    const inputs = await processor(audio);
    const { embeddings } = await model(inputs);
    return embeddings.data;
}

// Generate speaker embeddings
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/sv_speaker';
const speaker_1_1 = await compute_embedding(`${BASE_URL}-1_1.wav`);
const speaker_1_2 = await compute_embedding(`${BASE_URL}-1_2.wav`);
const speaker_2_1 = await compute_embedding(`${BASE_URL}-2_1.wav`);
const speaker_2_2 = await compute_embedding(`${BASE_URL}-2_2.wav`);

// Compute similarity scores
console.log(cos_sim(speaker_1_1, speaker_1_2)); // 0.959439158881247 (Both are speaker 1)
console.log(cos_sim(speaker_1_2, speaker_2_1)); // 0.618130172602329 (Different speakers)
console.log(cos_sim(speaker_2_1, speaker_2_2)); // 0.962999814169370 (Both are speaker 2)

@xenova xenova merged commit b5a548f into xenova:main Feb 28, 2024
4 checks passed
@xenova
Copy link
Owner

xenova commented Feb 28, 2024

Clean addition! Thanks so much @D4ve-R!

I also have code for WavLMForAudioFrameClassification & Wav2Vec2ForAudioFrameClassification, will open a pr once this is merged!
Lmk if i should/can add docs somewhere.

I look forward to reviewing your future PRs 🔥

@D4ve-R
Copy link
Contributor Author

D4ve-R commented Feb 28, 2024

Thank you!! This was really fun!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants