Skip to content

Commit

Permalink
Merge pull request #17 from sensein/dev
Browse files Browse the repository at this point in the history
adding voice cloning
  • Loading branch information
fabiocat93 authored May 14, 2024
2 parents a987d46 + 8447e12 commit ebbfa6d
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 1 deletion.
15 changes: 15 additions & 0 deletions scripts/experiment4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""This script is used to test the voice cloning task."""
from senselab.audio.tasks.preprocessing import resample_hf_dataset
from senselab.audio.tasks.voice_cloning import clone_voice_in_dataset_with_KNNVC
from senselab.utils.tasks.input_output import read_files_from_disk

dataset = read_files_from_disk(["/Users/fabiocat/Documents/git/sensein/senselab/src/tests/data_for_testing/audio_48khz_mono_16bits.wav"])

print("Resampling dataset...")
dataset = resample_hf_dataset(dataset, 16000)
print("Resampled dataset.")

cloned_dataset = clone_voice_in_dataset_with_KNNVC(dataset, dataset)

print("cloned_dataset")
#print(cloned_dataset)
68 changes: 68 additions & 0 deletions src/senselab/audio/tasks/voice_cloning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""This module implements some utilities for the voice cloning task."""
from typing import Any, Dict, Optional

import torch
from datasets import Dataset

from senselab.utils.functions import DeviceType, _select_device_and_dtype
from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict


def clone_voice_in_dataset_with_KNNVC(
source_dataset: Dict[str, Any],
target_dataset: Dict[str, Any],
source_audio_column: str = 'audio',
target_audio_column: str = 'audio',
model_id: str = "bshall/knn-vc",
model_revision: str = 'master',
prematched_vocoder: bool = True,
topk: int = 4,
device: Optional[DeviceType] = None,
) -> Dict[str, Any]:
"""Clones the voice in the dataset using KNNVC."""
def _setup_knn_vc_model(
model_id: str,
model_revision: str,
prematched_vocoder: bool,
device: Optional[DeviceType] = None
) -> Any: # noqa: ANN401
"""Prepare a KNNVC pipeline."""
repo_id = f"{model_id}:{model_revision}"
device, torch_dtype = _select_device_and_dtype(device_options=[device] if device else [DeviceType.CUDA, DeviceType.CPU])
knn_vc = torch.hub.load(repo_id, 'knn_vc', prematched=prematched_vocoder, trust_repo=True, pretrained=True, device=device.value)
return knn_vc, device, torch_dtype

def _clone_voice_in_row_with_KNNVC(
source_row: Dataset,
target_dataset: Dataset,
knn_vc_model: Any, # noqa: ANN401
torch_dtype: torch.dtype,
source_audio_column: str = 'audio',
target_audio_column: str = 'audio'
) -> Dict[str, torch.Tensor]:
def _get_waveform(dataset: Dataset, column: str) -> torch.Tensor:

audio = dataset[column]
waveform = torch.tensor(audio['array'], dtype=torch_dtype)
sampling_rate = audio['sampling_rate']
if sampling_rate != 16000:
raise ValueError(f"{column} sampling rate {sampling_rate} is not supported. Only 16kHz sampling rates are supported.")
return waveform

source_waveform = _get_waveform(source_row, source_audio_column)
target_waveform = _get_waveform(target_dataset[0], target_audio_column)

query_seq = knn_vc_model.get_features(source_waveform)
matching_set = knn_vc_model.get_matching_set([target_waveform])
out_wav = knn_vc_model.match(query_seq, matching_set, topk=topk)
return {"cloned_waveform": out_wav}

hf_source_dataset = _from_dict_to_hf_dataset(source_dataset, audio_columns=[source_audio_column])
hf_target_dataset = _from_dict_to_hf_dataset(target_dataset, audio_columns=[target_audio_column])

knn_vc, device, torch_dtype = _setup_knn_vc_model(model_id=model_id, model_revision=model_revision, prematched_vocoder=prematched_vocoder, device=device)

cloned_dataset = hf_source_dataset.map(lambda x: _clone_voice_in_row_with_KNNVC(x, hf_target_dataset, knn_vc, torch_dtype, source_audio_column, target_audio_column))
cloned_dataset = cloned_dataset.remove_columns([source_audio_column])

return _from_hf_dataset_to_dict(cloned_dataset)
6 changes: 6 additions & 0 deletions src/senselab/audio/tasks/voice_cloning_pydra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""This module defines a pydra API for the voice cloning task."""
import pydra

from senselab.audio.tasks.voice_cloning import clone_voice_in_dataset_with_KNNVC

clone_voice_in_dataset_with_KNNVC_pt = pydra.mark.task(clone_voice_in_dataset_with_KNNVC)
Empty file added src/senselab/text/.gitkeep
Empty file.
2 changes: 1 addition & 1 deletion src/senselab/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _select_device_and_dtype(device_options: list[DeviceType] = [DeviceType.CPU,
torch_dtype = torch.float16 # Using half precision for CUDA
elif torch.backends.mps.is_available() and DeviceType.MPS in device_options:
device = DeviceType.MPS
torch_dtype = torch.float16 # Using half precision for MPS if suitable
torch_dtype = torch.float32 # Default to float32 on MPS for better precision
else:
device = DeviceType.CPU
torch_dtype = torch.float32 # Default to float32 on CPU for better precision
Expand Down

0 comments on commit ebbfa6d

Please sign in to comment.