In [3]:
# Cell 2: Imports
import os
import json
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Any

# Core ML/Data Libraries
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import librosa
import numpy as np
import datasets
from datasets import load_dataset, VerificationMode
from tqdm.auto import tqdm

# Hugging Face Trainer
from transformers import (
    HfArgumentParser,
    set_seed,
    TrainerCallback,
    Trainer,
    PretrainedConfig,
    TrainingArguments as HfTrainingArguments,
    # get_last_checkpoint
)

# Language and Text Processing
from langdetect import detect
import pykakasi
from num2words import num2words # For our Spanish example

import sys
import os

src_path = os.path.abspath(os.path.join(os.getcwd(), "../../src"))
if src_path not in sys.path:
    sys.path.insert(0, src_path)
print("src path added:", src_path)

# Chatterbox Specific Imports
from chatterbox.tts import ChatterboxTTS, punc_norm, REPO_ID
from chatterbox.models.t3.t3 import T3, T3Cond
from chatterbox.models.t3.modules.t3_config import T3Config
from chatterbox.models.s3tokenizer import S3_SR
from chatterbox.models.s3gen.s3gen import S3Token2Mel
# from chatterbox.utils.training_args import CustomTrainingArguments

src path added: d:\Code\voice_clone\local_test\src


In [4]:
from typing import Optional
from dataclasses import dataclass, field
from transformers.training_args import TrainingArguments as HfTrainingArguments

# --- Custom Training Arguments ---
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    early_stopping_patience: Optional[int] = field(
        default=None, metadata={"help": "Enable early stopping with specified patience. Default: None (disabled)."}
    )
    use_torch_profiler: bool = field(
        default=False, metadata={"help": "Enable PyTorch profiler and dump traces to TensorBoard."}
    )
    dataloader_persistent_workers: bool = field(
        default=True, metadata={"help": "Use persistent workers for the dataloader."}
    )

In [None]:
import csv
import json
from pathlib import Path

# Paths
audio_dir = Path("dataset/bengali_audio_files")
manifest_path = Path("dataset/bengali_audio_manifest.csv")

# Read manifest and create JSON files
with open(manifest_path, encoding="utf-8") as f:
    reader = csv.DictReader(f)
    for row in reader:
        audio_filename = row.get("audio_filepath") or row.get("audio_path") or row.get("wav") or row.get("filename")
        text = row.get("text") or row.get("transcript") or row.get("sentence")
        if not audio_filename or not text:
            continue
        audio_path = audio_dir / Path(audio_filename).name
        json_path = audio_path.with_suffix(".json")
        with open(json_path, "w", encoding="utf-8") as jf:
            json.dump({"text": text}, jf, ensure_ascii=False, indent=2)
        print(f"Created {json_path}")

print("All JSON files created.")

In [29]:
# Cell 3: Configuration
# -------------------
# MODEL ARGUMENTS
# -------------------
# Here we define where to get the base model from.
# We'll use the default Chatterbox model from the Hugging Face Hub.
@dataclass
class ModelArguments:
    model_name_or_path: str = "ResembleAI/chatterbox"
    cache_dir: Optional[str] = None
    freeze_voice_encoder: bool = True
    freeze_s3gen: bool = True

model_args = ModelArguments()

# -------------------
# DATA ARGUMENTS
# -------------------
# Here we define where your data is and how to process it.
# ➡️ **CHANGE `dataset_dir` to the path of your audio/text folder.**
@dataclass
class DataArguments:
    dataset_dir: str = "D:/Code/voice_clone/local_test/custom_train/dataset/bengali_audio_files" # ⬅️ CHANGE THIS
    eval_split_size: float = 0.05 # Use 5% of the data for validation
    max_text_len: int = 256
    max_speech_len: int = 800
    audio_prompt_duration_s: float = 3.0

data_args = DataArguments()

# -------------------
# TRAINING ARGUMENTS
# -------------------
# Here we define the training parameters (batch size, learning rate, etc.)
# ➡️ **CHANGE `output_dir` to where you want to save checkpoints.**
training_args = CustomTrainingArguments(
    output_dir="custom_train/notebooks/checkpoints_bengali", # ⬅️ CHANGE THIS
    num_train_epochs=10,
    per_device_train_batch_size=4,         # Lower this if you run out of GPU memory (e.g., to 2)
    gradient_accumulation_steps=4,         # Effective batch size = 4 * 4 = 16
    learning_rate=5e-5,
    warmup_steps=500,
    logging_steps=20,
    save_strategy="epoch",
    save_total_limit=2,
    fp16=True,                             # Set to False if your GPU doesn't support it
    report_to="tensorboard",
    dataloader_num_workers=4,
    do_train=True,
    do_eval=True,
    eval_strategy="epoch",
)

# Setup logger
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

In [35]:
# Cell 4: SpeechFineTuningDataset Class
class SpeechFineTuningDataset(Dataset):
    def __init__(self,
                 data_args: DataArguments,
                 t3_config: T3Config,
                 dataset_source: List[Dict[str, str]],
                 chatterbox_model):
        self.data_args = data_args
        self.chatterbox_t3_config = t3_config
        self.dataset_source = dataset_source
        self.chatterbox_model = chatterbox_model
        self.text_tokenizer = self.chatterbox_model.tokenizer
        self.speech_tokenizer = self.chatterbox_model.s3gen.tokenizer
        self.voice_encoder = self.chatterbox_model.ve
        self.s3_sr = S3_SR
        self.enc_cond_audio_len_samples = int(data_args.audio_prompt_duration_s * self.s3_sr)

    def __len__(self):
        return len(self.dataset_source)

    def __getitem__(self, idx) -> Optional[Dict[str, Union[torch.Tensor, float]]]:
        item = self.dataset_source[idx]
        audio_path = item["audio"]
        text = item["text"]
        
        try:
            wav_16k, _ = librosa.load(audio_path, sr=self.s3_sr, mono=True)
            if wav_16k is None or len(wav_16k) == 0: return None
        except Exception as e:
            logger.error(f"Error loading audio {audio_path}: {e}")
            return None

        speaker_emb_np = self.voice_encoder.embeds_from_wavs([wav_16k], sample_rate=self.s3_sr)
        speaker_emb = torch.from_numpy(speaker_emb_np[0])

        normalized_text = punc_norm(text)
        
        # ==========================================================
        # ⬇️ THIS IS WHERE YOU ADD/MODIFY LANGUAGE-SPECIFIC LOGIC ⬇️
        # ==========================================================
        try:
            lang = detect(normalized_text)
        except:
            lang = "en" # Default to English if detection fails

        if lang == "ja":
            pka_converter = pykakasi.kakasi()
            pka_converter.setMode("J","H"); pka_converter.setMode("K","H"); pka_converter.setMode("H","H")
            conv = pka_converter.getConverter()
            normalized_text = conv.do(normalized_text)
        elif lang == "fr":
            normalized_text = "[fr] " + normalized_text
        elif lang == "de":
            normalized_text = "[de] " + normalized_text
        
        # --- OUR NEW SPANISH CODE ---
        elif lang == "es":
            # 1. Add the language ID token. The model will learn this means "speak in Spanish".
            normalized_text = "[es] " + normalized_text
            
            # 2. (Optional but recommended) Add text normalization rules.
            # Example: convert "123" to "ciento veintitrés".
            import re
            def expand_numbers_es(text):
                return re.sub(r'(\d+)', lambda m: num2words(int(m.group(1)), lang='es'), text)
            normalized_text = expand_numbers_es(normalized_text)
        # --- END OF NEW LANGUAGE CODE ---

        raw_text_tokens = self.text_tokenizer.text_to_tokens(normalized_text).squeeze(0)
        text_tokens = F.pad(raw_text_tokens, (1, 0), value=self.chatterbox_t3_config.start_text_token)
        text_tokens = F.pad(text_tokens, (0, 1), value=self.chatterbox_t3_config.stop_text_token)
        if len(text_tokens) > self.data_args.max_text_len:
            text_tokens = text_tokens[:self.data_args.max_text_len-1]
            text_tokens = torch.cat([text_tokens, torch.tensor([self.chatterbox_t3_config.stop_text_token])])
        text_token_len = torch.tensor(len(text_tokens), dtype=torch.long)

        raw_speech_tokens, speech_lens = self.speech_tokenizer.forward([wav_16k])
        if raw_speech_tokens is None: return None
        raw_speech_tokens = raw_speech_tokens.squeeze(0)[:speech_lens.squeeze(0).item()]
        
        speech_tokens = F.pad(raw_speech_tokens, (1, 0), value=self.chatterbox_t3_config.start_speech_token)
        speech_tokens = F.pad(speech_tokens, (0, 1), value=self.chatterbox_t3_config.stop_speech_token)
        if len(speech_tokens) > self.data_args.max_speech_len:
            speech_tokens = speech_tokens[:self.data_args.max_speech_len-1]
            speech_tokens = torch.cat([speech_tokens, torch.tensor([self.chatterbox_t3_config.stop_speech_token])])
        speech_token_len = torch.tensor(len(speech_tokens), dtype=torch.long)

        cond_audio = wav_16k[:self.enc_cond_audio_len_samples]
        cond_prompt, _ = self.speech_tokenizer.forward([cond_audio], max_len=self.chatterbox_t3_config.speech_cond_prompt_len)
        cond_prompt = cond_prompt.squeeze(0) if cond_prompt is not None else torch.zeros(self.chatterbox_t3_config.speech_cond_prompt_len, dtype=torch.long)

        if cond_prompt.size(0) != self.chatterbox_t3_config.speech_cond_prompt_len:
             cond_prompt = F.pad(cond_prompt, (0, self.chatterbox_t3_config.speech_cond_prompt_len - cond_prompt.size(0)), value=0)

        return {
            "text_tokens": text_tokens.long(), "text_token_lens": text_token_len.long(),
            "speech_tokens": speech_tokens.long(), "speech_token_lens": speech_token_len.long(),
            "t3_cond_speaker_emb": speaker_emb.float(),
            "t3_cond_prompt_speech_tokens": cond_prompt.long(),
            "t3_cond_emotion_adv": torch.tensor(0.5, dtype=torch.float),
        }


In [36]:
# Cell 5: Data Collator
@dataclass
class SpeechDataCollator:
    t3_config: T3Config
    text_pad_token_id: int
    speech_pad_token_id: int

    def __call__(self, features: List[Optional[Dict[str, Any]]]) -> Dict[str, Any]:
        valid_features = [f for f in features if f is not None]
        if not valid_features: return {}
        
        text_tokens = [f["text_tokens"] for f in valid_features]
        speech_tokens = [f["speech_tokens"] for f in valid_features]
        max_text_len = max(len(t) for t in text_tokens)
        max_speech_len = max(len(s) for s in speech_tokens)

        padded_text = torch.stack([F.pad(t, (0, max_text_len - len(t)), value=self.text_pad_token_id) for t in text_tokens])
        padded_speech = torch.stack([F.pad(s, (0, max_speech_len - len(s)), value=self.speech_pad_token_id) for s in speech_tokens])
        
        IGNORE_ID = -100
        labels_text = padded_text[:, 1:].clone()
        labels_text[labels_text == self.text_pad_token_id] = IGNORE_ID
        
        labels_speech = padded_speech[:, 1:].clone()
        labels_speech[labels_speech == self.speech_pad_token_id] = IGNORE_ID

        return {
            "text_tokens": padded_text,
            "text_token_lens": torch.stack([f["text_token_lens"] for f in valid_features]),
            "speech_tokens": padded_speech,
            "speech_token_lens": torch.stack([f["speech_token_lens"] for f in valid_features]),
            "t3_cond_speaker_emb": torch.stack([f["t3_cond_speaker_emb"] for f in valid_features]),
            "t3_cond_prompt_speech_tokens": torch.stack([f["t3_cond_prompt_speech_tokens"] for f in valid_features]),
            "t3_cond_emotion_adv": torch.stack([f["t3_cond_emotion_adv"] for f in valid_features]).view(len(valid_features), 1, 1),
            "labels_text": labels_text,
            "labels_speech": labels_speech,
        }

# Cell 6: Model Wrapper for Hugging Face Trainer
class T3ForFineTuning(torch.nn.Module):
    def __init__(self, t3_model: T3, chatterbox_t3_config: T3Config):
        super().__init__()
        self.t3 = t3_model
        self.chatterbox_t3_config = chatterbox_t3_config
        class HFCompatibleConfig(PretrainedConfig): model_type = "chatterbox_t3_finetune"
        self.config = HFCompatibleConfig()

    def forward(self, labels_text=None, labels_speech=None, **kwargs):
        current_t3_cond = T3Cond(
            speaker_emb=kwargs["t3_cond_speaker_emb"],
            cond_prompt_speech_tokens=kwargs["t3_cond_prompt_speech_tokens"],
            emotion_adv=kwargs["t3_cond_emotion_adv"]
        ).to(device=self.t3.device)
        
        loss_text, loss_speech, _ = self.t3.loss(
            t3_cond=current_t3_cond, text_tokens=kwargs["text_tokens"], text_token_lens=kwargs["text_token_lens"],
            speech_tokens=kwargs["speech_tokens"], speech_token_lens=kwargs["speech_token_lens"],
            labels_text=labels_text, labels_speech=labels_speech
        )
        return {"loss": loss_text + loss_speech}
        
# Cell 7: Logging Callback
class DetailedLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and 'loss' in logs:
            logger.info(f"Step {state.global_step}: Loss = {logs['loss']:.4f}, LR = {logs.get('learning_rate', 'N/A')}")

In [38]:
# Cell 8: Setup Logging and Seed
logger.info("Training parameters %s", training_args)
set_seed(training_args.seed)

# Cell 9: Load Pre-trained Chatterbox Model
logger.info(f"Loading base model from {model_args.model_name_or_path}...")
chatterbox_model = ChatterboxTTS.from_pretrained(
    # model_args.model_name_or_path,
    # save_dir=download_dir,
    device="cpu"
)
t3_model = chatterbox_model.t3
chatterbox_t3_config_instance = t3_model.hp

# Cell 10: Freeze Model Layers
logger.info("Freezing specified model layers...")
if model_args.freeze_voice_encoder:
    for param in chatterbox_model.ve.parameters(): param.requires_grad = False
    logger.info("Voice Encoder frozen.")
if model_args.freeze_s3gen:
    for param in chatterbox_model.s3gen.parameters(): param.requires_grad = False
    logger.info("S3Gen model frozen.")
for param in t3_model.parameters(): param.requires_grad = True
logger.info("T3 model set to trainable.")

# Cell 11: Load Your Local Dataset Files
def load_local_dataset(dataset_dir: str) -> List[Dict[str, str]]:
    dataset_path = Path(dataset_dir)
    json_files = list(dataset_path.glob("**/*.json"))
    files = []
    logger.info(f"Found {len(json_files)} JSON files in {dataset_path}.")
    for json_file in tqdm(json_files, desc="Loading dataset files"):
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        text = data.get("text", "").strip()
        # Find audio file with the same name but different extension (.wav, .mp3, etc.)
        audio_path = next(dataset_path.glob(f"{json_file.stem}.*"), None)
        if audio_path and audio_path.exists() and text:
            files.append({"audio": str(audio_path), "text": text})
    return files

all_files = load_local_dataset(data_args.dataset_dir)
if not all_files:
    raise ValueError("No data files found! Check your `dataset_dir` path and file structure.")
logger.info(f"Successfully loaded {len(all_files)} audio-text pairs.")

np.random.shuffle(all_files)
split_idx = int(len(all_files) * (1 - data_args.eval_split_size))
train_files, eval_files = all_files[:split_idx], all_files[split_idx:]
logger.info(f"Split dataset: {len(train_files)} for training, {len(eval_files)} for evaluation.")

# Cell 12: Create PyTorch Datasets and Collator
logger.info("Initializing datasets and data collator...")
train_dataset = SpeechFineTuningDataset(
    data_args,
    chatterbox_t3_config_instance,
    train_files,
    chatterbox_model=chatterbox_model,
)

eval_dataset = SpeechFineTuningDataset(
    data_args,
    chatterbox_t3_config_instance,
    eval_files,
    chatterbox_model=chatterbox_model,
)

data_collator = SpeechDataCollator(
    chatterbox_t3_config_instance,
    chatterbox_t3_config_instance.stop_text_token,
    chatterbox_t3_config_instance.stop_speech_token
)

# Cell 13: Initialize the Trainer
logger.info("Initializing the Hugging Face Trainer...")
hf_trainable_model = T3ForFineTuning(t3_model, chatterbox_t3_config_instance)

trainer = Trainer(
    model=hf_trainable_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    callbacks=[DetailedLoggingCallback()],
)


09/18/2025 19:02:29 - INFO - __main__ - Training parameters CustomTrainingArguments(
_n_gpu=0,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=4,
dataloader_persistent_workers=True,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
early_stopping_patience=None,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_b

loaded PerthNet (Implicit) at step 250,000


Loading dataset files: 100%|██████████| 91/91 [00:00<00:00, 2261.51it/s]
09/18/2025 19:02:39 - INFO - __main__ - Successfully loaded 91 audio-text pairs.
09/18/2025 19:02:39 - INFO - __main__ - Split dataset: 86 for training, 5 for evaluation.
09/18/2025 19:02:39 - INFO - __main__ - Initializing datasets and data collator...
09/18/2025 19:02:39 - INFO - __main__ - Initializing the Hugging Face Trainer...
Loading dataset files: 100%|██████████| 91/91 [00:00<00:00, 2261.51it/s]
09/18/2025 19:02:39 - INFO - __main__ - Successfully loaded 91 audio-text pairs.
09/18/2025 19:02:39 - INFO - __main__ - Split dataset: 86 for training, 5 for evaluation.
09/18/2025 19:02:39 - INFO - __main__ - Initializing datasets and data collator...
09/18/2025 19:02:39 - INFO - __main__ - Initializing the Hugging Face Trainer...


In [None]:
# Cell 14: Run Training
logger.info("*** Starting T3 model fine-tuning ***")
train_result = trainer.train()
logger.info("*** Training finished ***")

09/18/2025 19:03:28 - INFO - __main__ - *** Starting T3 model fine-tuning ***
  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Cell 15: Save Final Model Artifacts
logger.info("Saving the final fine-tuned model...")

# 1. Save the fine-tuned T3 weights
final_output_dir = Path(training_args.output_dir) / "final_model"
final_output_dir.mkdir(parents=True, exist_ok=True)
output_t3_path = final_output_dir / "t3_cfg.safetensors"

t3_to_save = trainer.model.t3
from safetensors.torch import save_file
save_file(t3_to_save.state_dict(), output_t3_path)
logger.info(f"T3 weights saved to {output_t3_path}")

# 2. Copy the other essential model files
import shutil
files_to_copy = ["ve.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]
for f_name in files_to_copy:
    src_path = download_dir / f_name
    if src_path.exists():
        shutil.copy2(src_path, final_output_dir / f_name)
        logger.info(f"Copied {f_name} to final model directory.")

logger.info(f"✅ Complete fine-tuned model saved in: {final_output_dir}")

In [None]:
from chatterbox.tts import ChatterboxTTS
import soundfile as sf
from pathlib import Path

# 1. Path to your fine-tuned model directory
MODEL_PATH = Path("/path/to/your/checkpoints/final_model")

# 2. Path to a short audio clip of the target speaker's voice
# This is used to create the speaker embedding. It can be any clip from your training data.
SPEAKER_WAV_PATH = "/path/to/your/spanish_dataset/some_audio_file.wav"

# 3. Load your fine-tuned model
print("Loading the fine-tuned model...")
model = ChatterboxTTS.from_local(ckpt_dir=str(MODEL_PATH), device="cuda") # or "cpu"

# 4. Define the text you want to generate
# IMPORTANT: Use the language token you trained on!
text_to_generate_bengali = "[bn] আমি বাংলায় কথা বলি।"
text_to_generate_english = "Hello, this is a test in English." # This should also work!

# 5. Generate the audio
print(f"Generating audio for: {text_to_generate_bengali}")
wav_bengali = model.generate(
    text=text_to_generate_bengali,
    speaker_wav=SPEAKER_WAV_PATH
)

print(f"Generating audio for: {text_to_generate_english}")
wav_english = model.generate(
    text=text_to_generate_english,
    speaker_wav=SPEAKER_WAV_PATH
)

# 6. Save the audio to a file
sf.write("output_bengali.wav", wav_bengali, samplerate=24000)
sf.write("output_english.wav", wav_english, samplerate=24000)

print("✅ Audio files saved successfully!")