# 🦜 VieNeu-TTS Fine-tuning Notebook

This notebook collects all the training code for **VieNeu-TTS-0.3B**.  
To train **VieNeu-TTS** instead, change the model in `training_config` (section 6).

If you run into errors during training or have suggestions, please open an **issue** on GitHub:  
https://github.com/pnnbao97/VieNeu-TTS  

Or contact the author, **Phạm Nguyễn Ngọc Bảo**, directly:
- Email: pnnbao@gmail.com  
- Facebook: https://www.facebook.com/bao.phamnguyenngoc.5

## 📦 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers peft torch datasets librosa soundfile tqdm phonemizer
!pip install -q git+https://github.com/Neuphonic/NeuCodec.git

In [None]:
!apt install espeak-ng -y

## 🔧 2. Setup Utils

In [None]:
!git clone https://github.com/pnnbao97/VieNeu-TTS

In [None]:
import sys
import os
from pathlib import Path

def setup_vieneu_tts():
    """Universal setup for VieNeu-TTS - works on any platform"""
    
    # Find VieNeu-TTS
    search_paths = [
        "/root/VieNeu-TTS",
        "/content/VieNeu-TTS",
        "./VieNeu-TTS",
        "../VieNeu-TTS",
    ]
    
    vieneu_path = None
    for path in search_paths:
        # The imports below come from vieneu_utils, so look for that package
        if os.path.isdir(os.path.join(path, "vieneu_utils")):
            vieneu_path = os.path.abspath(path)
            break
    
    if not vieneu_path:
        raise FileNotFoundError(
            "VieNeu-TTS not found! Clone it:\n"
            "  git clone https://github.com/pnnbao97/VieNeu-TTS"
        )
    
    # Clean and add to path
    sys.path = [p for p in sys.path if "VieNeu-TTS" not in p]
    sys.path.insert(0, vieneu_path)
    
    print(f"✅ VieNeu-TTS: {vieneu_path}")
    
    from vieneu_utils.normalize_text import VietnameseTTSNormalizer
    from vieneu_utils.phonemize_text import phonemize_with_dict
    
    # Initialize normalizer
    normalizer = VietnameseTTSNormalizer()
    
    # Create wrapper function
    def normalize_text(text):
        return normalizer.normalize(text)
    
    print("✅ Utils loaded!")
    
    return vieneu_path, normalize_text, phonemize_with_dict

# Run setup
VIENEU_PATH, normalize_text, phonemize_with_dict = setup_vieneu_tts()

# ========== TEST ==========
def preprocess_text(text):
    """Complete preprocessing pipeline"""
    normalized = normalize_text(text)
    phonemes = phonemize_with_dict(normalized)
    return {
        "original": text,
        "normalized": normalized,
        "phonemes": phonemes
    }

# Quick test
result = preprocess_text("Tôi có 2.000 mẫu audio, giá 5.000.000đ")
print("\n📝 Test result:")
for key, val in result.items():
    print(f"  {key:12s}: {val}")

print("\n🎉 Ready to use!")

## 📥 3. Download Sample Data

Download sample data from Hugging Face (or substitute your own dataset).  
This notebook uses the sample dataset:  
https://huggingface.co/datasets/pnnbao-ump/ngochuyen_voice  

This dataset is used to train the **Ngọc Huyền (Vbee)** voice and is **not part of the VieNeu-TTS-1000h set**,  
so it makes a good illustration of the fine-tuning process.

In [None]:
import io
from datasets import load_dataset, Audio
import soundfile as sf
from tqdm import tqdm

def download_sample_data(output_dir="dataset", num_samples=10):
    raw_audio_dir = os.path.join(output_dir, "raw_audio")
    metadata_path = os.path.join(output_dir, "metadata.csv")
    
    os.makedirs(raw_audio_dir, exist_ok=True)
    
    print("🔄 Downloading dataset from Hugging Face...")
    dataset = load_dataset("pnnbao-ump/ngochuyen_voice", split="train")
    dataset = dataset.cast_column("audio", Audio(decode=False))
    
    print(f"✅ Saving {num_samples} samples...")
    
    with open(metadata_path, 'w', encoding='utf-8') as f:
        count = 0
        for sample in tqdm(dataset, total=num_samples):
            if count >= num_samples:
                break
            
            try:
                audio_data = sample["audio"]
                audio_bytes = audio_data["bytes"]
                audio_array, sampling_rate = sf.read(io.BytesIO(audio_bytes))
                
                text = sample["transcription"]
                original_filename = sample.get("file_name", f"sample_{count:03d}.wav")
                filename = os.path.basename(original_filename)
                
                file_path = os.path.join(raw_audio_dir, filename)
                sf.write(file_path, audio_array, sampling_rate)
                
                f.write(f"{filename}|{text}\n")
                count += 1
            except Exception as e:
                print(f"\n⚠️ Error on sample {count}: {e}")
                continue
    
    print(f"\n🦜 Done! Saved {count} samples to {output_dir}")
    return metadata_path

# Download data (adjust num_samples as needed)
metadata_path = download_sample_data(output_dir="dataset", num_samples=7000)

## 🧹 4. Filter Data

Filter out low-quality samples: unreadable audio, clips shorter than 3 s or longer than 15 s, and transcripts that contain digits or acronyms or lack closing punctuation.
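
To make the transcript rules concrete, here is a small self-contained sketch that applies the same checks to a few made-up sentences and reports which rule each one breaks. The helper name `explain_text_filter` and the example sentences are illustrative assumptions, not part of the notebook's pipeline.

```python
import re

# Same patterns as the filter cell: dotted acronyms ("U.S.A.") and runs of capitals ("TTS").
ACRONYM = re.compile(r"(?:[a-zA-Z]\.){2,}")
ACRONYM_NO_PERIOD = re.compile(r"(?:[A-Z]){2,}")

def explain_text_filter(text: str) -> str:
    """Hypothetical helper: return 'ok' or the first rule the transcript breaks."""
    if not text:
        return "empty"
    if re.search(r"\d", text):
        return "contains digits"
    if ACRONYM.search(text) or ACRONYM_NO_PERIOD.search(text):
        return "contains acronym"
    if text[-1] not in ".,?!":
        return "no closing punctuation"
    return "ok"

# Made-up examples, one per rule.
examples = [
    "Hôm nay trời đẹp.",
    "Tôi có 2 con mèo.",
    "Mô hình TTS này rất hay.",
    "Hôm nay trời đẹp",
]
for sentence in examples:
    print(f"{explain_text_filter(sentence):24s} <- {sentence}")
```

Reporting the reason rather than a bare boolean can help when deciding whether a rule is too strict for your own dataset.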

In [None]:
import re
ACRONYM = re.compile(r"(?:[a-zA-Z]\.){2,}")
ACRONYM_NO_PERIOD = re.compile(r"(?:[A-Z]){2,}")

def text_filter(text: str) -> bool:
    if not text: return False
    if re.search(r"\d", text): return False
    if ACRONYM.search(text) or ACRONYM_NO_PERIOD.search(text): return False
    if text[-1] not in ".,?!": return False
    return True

def filter_dataset(dataset_dir="dataset"):
    metadata_path = os.path.join(dataset_dir, "metadata.csv")
    cleaned_path = os.path.join(dataset_dir, "metadata_cleaned.csv")
    raw_audio_dir = os.path.join(dataset_dir, "raw_audio")
    
    if not os.path.exists(metadata_path):
        print(f"❌ Could not find {metadata_path}")
        return
    
    print("🧹 Filtering data...")
    
    valid_samples = []
    skipped = {"audio_not_found": 0, "audio_error": 0, "duration_out_of_range": 0, "text_invalid": 0}
    
    with open(metadata_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    for line in tqdm(lines, desc="Filtering"):
        parts = line.strip().split('|')
        if len(parts) < 2:
            continue
        
        filename = parts[0]
        text = parts[1]
        file_path = os.path.join(raw_audio_dir, filename)
        
        if not os.path.exists(file_path):
            skipped["audio_not_found"] += 1
            continue
        
        try:
            info = sf.info(file_path)
            duration = info.duration
            
            if not (3.0 <= duration <= 15.0):
                skipped["duration_out_of_range"] += 1
                continue
        except Exception:
            skipped["audio_error"] += 1
            continue
        
        if not text_filter(text):
            skipped["text_invalid"] += 1
            continue
        
        valid_samples.append(f"{filename}|{text}\n")
    
    with open(cleaned_path, 'w', encoding='utf-8') as f:
        f.writelines(valid_samples)
    
    total = len(lines)
    pct = len(valid_samples) / total * 100 if total else 0.0  # avoid dividing by zero on an empty file
    print("\n🦜 FILTER RESULTS:")
    print(f"   - Total: {total} | Valid: {len(valid_samples)} ({pct:.1f}%)")
    print(f"   - Removed: {sum(skipped.values())} ({skipped})")
    print(f"✅ Saved: {cleaned_path}")
    return cleaned_path

cleaned_metadata_path = filter_dataset(dataset_dir="dataset")

## 🔊 5. Encode Audio to VQ Codes

Use NeuCodec to encode the audio into vector-quantized (VQ) codes.
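
As a rough sanity check on sequence length, the sketch below estimates how many speech tokens a clip produces. It assumes NeuCodec emits about 50 codes per second of audio (the nominal rate stated on the NeuCodec model card; verify it for the version you install). The function name `estimate_speech_tokens` is illustrative.

```python
import math

CODES_PER_SECOND = 50  # assumption: NeuCodec's nominal token rate, check the model card

def estimate_speech_tokens(duration_s: float, codes_per_second: int = CODES_PER_SECOND) -> int:
    """Approximate number of VQ codes produced for a clip of the given duration."""
    return math.ceil(duration_s * codes_per_second)

# The filter step keeps clips between 3 and 15 seconds.
for duration in (3.0, 15.0):
    print(f"{duration:>5.1f} s -> ~{estimate_speech_tokens(duration)} speech tokens")
```

Under that assumption, the longest clip kept by the filter (15 s) maps to roughly 750 speech tokens, which leaves room for the phoneme prompt inside the `max_len=2048` used by the dataset class.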

In [None]:
import torch
import librosa
from neucodec import NeuCodec
import json
import random

def encode_dataset(dataset_dir="dataset", max_samples=2000):
    metadata_path = os.path.join(dataset_dir, "metadata_cleaned.csv")
    if not os.path.exists(metadata_path):
        print("🦜 metadata_cleaned.csv not found, falling back to metadata.csv...")
        metadata_path = os.path.join(dataset_dir, "metadata.csv")
    
    output_path = os.path.join(dataset_dir, "metadata_encoded.csv")
    raw_audio_dir = os.path.join(dataset_dir, "raw_audio")
    
    if not os.path.exists(metadata_path):
        print("🦜 Metadata not found!")
        return
    
    print("🦜 Loading NeuCodec model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    codec = NeuCodec.from_pretrained("neuphonic/neucodec").to(device)
    codec.eval()
    
    print(f"🦜 Encoding up to {max_samples} samples (device: {device})")
    
    lines_to_write = []
    skipped_count = 0
    
    with open(metadata_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    # Shuffle and keep at most max_samples
    random.shuffle(lines)
    if len(lines) > max_samples:
        lines = lines[:max_samples]
    
    for line in tqdm(lines, desc="Encoding"):
        parts = line.strip().split('|')
        if len(parts) < 2:
            continue
        
        filename = parts[0]
        text = parts[1]
        audio_path = os.path.join(raw_audio_dir, filename)
        
        if not os.path.exists(audio_path):
            skipped_count += 1
            continue
        
        try:
            wav, sr = librosa.load(audio_path, sr=16000, mono=True)
            wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
            
            with torch.no_grad():
                codes = codec.encode_code(wav_tensor)
                codes = codes.squeeze(0).squeeze(0).cpu().numpy().flatten().tolist()
                codes = [int(x) for x in codes]
            
            # Validate
            if not codes or not all(0 <= c < 65536 for c in codes):
                print(f"🦜 Invalid codes: {filename}")
                skipped_count += 1
                continue
            
            codes_json = json.dumps(codes)
            lines_to_write.append(f"{filename}|{text}|{codes_json}\n")
            
        except Exception as e:
            print(f"🦜 Error {filename}: {e}")
            skipped_count += 1
    
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(lines_to_write)
    
    print(f"\n🦜 Done! Encoded {len(lines_to_write)} samples")
    print(f"   - Saved to: {output_path}")
    print(f"   - Skipped: {skipped_count}")
    return output_path

encoded_metadata_path = encode_dataset(dataset_dir="dataset", max_samples=2000)

## 🎯 6. Setup Training

Configure LoRA and the training arguments.
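
To get a feel for what `r` and `lora_alpha` control, the sketch below counts the trainable parameters LoRA adds to a single linear layer and computes the scaling applied to the update. The layer dimensions are hypothetical placeholders, not the real sizes in VieNeu-TTS-0.3B.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds two matrices per adapted layer: A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r

def lora_scaling(lora_alpha: int, r: int) -> float:
    """With standard LoRA (no rank-stabilised variant), the update is scaled by lora_alpha / r."""
    return lora_alpha / r

# Hypothetical projection size, for illustration only.
d_in, d_out = 1024, 1024
r, lora_alpha = 16, 32

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
print(f"full layer: {full:,} params | LoRA adds: {added:,} ({added / full:.1%})")
print(f"update scaling: {lora_scaling(lora_alpha, r)}")
```

Raising `r` grows the adapter linearly, while the ratio `lora_alpha / r` sets how strongly the learned update is applied on top of the frozen weights.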

In [None]:
from peft import LoraConfig, TaskType, get_peft_model
from transformers import TrainingArguments

# LoRA Config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Training Config
training_config = {
    'model': "pnnbao-ump/VieNeu-TTS-0.3B",
    'run_name': "VieNeu-TTS-LoRA",
    'output_dir': "output",
    'per_device_train_batch_size': 2,
    'gradient_accumulation_steps': 1,
    'learning_rate': 2e-4,
    'max_steps': 5000,  # Lower this for a quick test
    'logging_steps': 50,
    'save_steps': 500,
    'eval_steps': 500,
    'warmup_ratio': 0.05,
    'bf16': True,
}

def get_training_args(config):
    return TrainingArguments(
        output_dir=os.path.join(config['output_dir'], config['run_name']),
        do_train=True,
        do_eval=True,
        max_steps=config['max_steps'],
        per_device_train_batch_size=config['per_device_train_batch_size'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        learning_rate=config['learning_rate'],
        warmup_ratio=config['warmup_ratio'],
        bf16=config['bf16'],
        logging_steps=config['logging_steps'],
        save_steps=config['save_steps'],
        eval_strategy="steps",
        eval_steps=config['eval_steps'],
        save_strategy="steps",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        dataloader_num_workers=2,  # Kept low to avoid dataloader errors
        ddp_find_unused_parameters=False,
    )

print("✅ Training config ready!")

## 📊 7. Dataset Class & Preprocessing

In [None]:
from torch.utils.data import Dataset

def preprocess_sample(sample, tokenizer, max_len=2048):
    speech_gen_start = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_START|>')
    ignore_index = -100
    
    phones = sample["phones"]
    vq_codes = sample["codes"]
    
    codes_str = "".join([f"<|speech_{i}|>" for i in vq_codes])
    chat = f"""user: Convert the text to speech:<|TEXT_PROMPT_START|>{phones}<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}<|SPEECH_GENERATION_END|>"""
    
    ids = tokenizer.encode(chat)
    
    # Pad/truncate
    if len(ids) < max_len:
        ids = ids + [tokenizer.pad_token_id] * (max_len - len(ids))
    elif len(ids) > max_len:
        ids = ids[:max_len]
    
    input_ids = torch.tensor(ids, dtype=torch.long)
    labels = torch.full_like(input_ids, ignore_index)
    
    # Supervise only the speech tokens: labels stay at ignore_index before <|SPEECH_GENERATION_START|>
    speech_gen_start_idx = (input_ids == speech_gen_start).nonzero(as_tuple=True)[0]
    if len(speech_gen_start_idx) > 0:
        speech_gen_start_idx = speech_gen_start_idx[0]
        labels[speech_gen_start_idx:] = input_ids[speech_gen_start_idx:]
    
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    # Padding positions must not contribute to the loss
    labels[attention_mask == 0] = ignore_index
    
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask
    }

class VieNeuDataset(Dataset):
    def __init__(self, metadata_path, tokenizer, max_len=2048):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Missing: {metadata_path}")
        
        with open(metadata_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('|')
                if len(parts) >= 3:
                    self.samples.append({
                        "filename": parts[0],
                        "text": parts[1],
                        "codes": json.loads(parts[2])
                    })
        
        print(f"🦜 Loaded {len(self.samples)} samples from {metadata_path}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        
        try:
            phones = phonemize_with_dict(text)
        except Exception as e:
            print(f"⚠️ Phonemization error: {e}")
            phones = text
        
        data_item = {"phones": phones, "codes": sample["codes"]}
        return preprocess_sample(data_item, self.tokenizer, self.max_len)

print("✅ Dataset class ready!")

## 🚀 8. Train Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator

model_name = training_config['model']
print(f"🦜 Loading model: {model_name}")

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto"
)

# Load Dataset
dataset_path = encoded_metadata_path  # From earlier step
full_dataset = VieNeuDataset(dataset_path, tokenizer)

# Train/Eval split (5%)
val_size = max(1, int(0.05 * len(full_dataset)))
train_size = len(full_dataset) - val_size
train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

print(f"🦜 Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")

# Apply LoRA
print("🦜 Applying LoRA...")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Trainer
args = get_training_args(training_config)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

print("🦜 Starting training! (Good luck)")
trainer.train()

# Save
save_path = os.path.join(training_config['output_dir'], training_config['run_name'])
print(f"🦜 Saving model to: {save_path}")
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print("✅ Training complete!")

## 🦜 Done!

The model has been fine-tuned and saved to `output/VieNeu-TTS-LoRA/`.

You can use this checkpoint to:
- Run inference / generate speech
- Merge the LoRA weights into the base model
- Continue fine-tuning on another dataset
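
For the merge option, one possible approach is PEFT's `merge_and_unload()`. The sketch below wraps it in a function; the name `merge_lora_adapter` and the `merged_dir` output path are assumptions for illustration, and it presumes the tokenizer was saved alongside the adapter, as the training cell does.

```python
def merge_lora_adapter(base_model_id: str, adapter_dir: str, merged_dir: str) -> str:
    """Fold a LoRA adapter into its base model and save a standalone checkpoint."""
    # Imported lazily so that defining the function does not require the libraries.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained(base_model_id)
    model = PeftModel.from_pretrained(base, adapter_dir)

    # merge_and_unload() adds the LoRA deltas to the base weights and returns a plain model.
    merged = model.merge_and_unload()
    merged.save_pretrained(merged_dir)

    # The training cell saved the tokenizer next to the adapter, so reuse it here.
    AutoTokenizer.from_pretrained(adapter_dir).save_pretrained(merged_dir)
    return merged_dir

# Example call (downloads the base model, so it is left commented out):
# merge_lora_adapter("pnnbao-ump/VieNeu-TTS-0.3B", "output/VieNeu-TTS-LoRA", "output/VieNeu-TTS-merged")
```

A merged checkpoint loads without PEFT at inference time, at the cost of storing the full model weights instead of only the adapter.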

In [39]:
import os
from huggingface_hub import (
    HfApi,
    create_repo,
    upload_folder
)

# ===================== CONFIG =====================
HF_USERNAME = "pnnbao-ump"  # ⚠️ change this to your own username
REPO_NAME = "VieNeu-TTS-0.3B-lora-ngoc-huyen"
LOCAL_LORA_DIR = "output/VieNeu-TTS-LoRA"
BASE_MODEL = "pnnbao-ump/VieNeu-TTS-0.3B"
DATASET_URL = "https://huggingface.co/datasets/pnnbao-ump/ngochuyen_voice"

# ===================== README =====================
README_CONTENT = f"""
---
language: vi
license: cc-by-nc-4.0
base_model: {BASE_MODEL}
library_name: peft
tags:
  - lora
  - text-to-speech
  - tts
  - vietnamese
  - vieneu-tts
---

# 🦜 VieNeu-TTS-LoRA (Ngọc Huyền)

LoRA adapter fine-tuned from the base model **VieNeu-TTS-0.3B**
to reproduce the **Ngọc Huyền (Vbee)** voice.  

Fine-tuning code for VieNeu-TTS: https://github.com/pnnbao97/VieNeu-TTS

---

## 🔗 Base Model
- Base model: `{BASE_MODEL}`
- This repo contains **only the LoRA adapter**, not the base model.

---

## 📦 Dataset
- {DATASET_URL}

---

## 🚀 Usage

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "{BASE_MODEL}",
    device_map="auto"
)

model = PeftModel.from_pretrained(
    base_model,
    "{HF_USERNAME}/{REPO_NAME}"
)
```

## Credits

- Base model: Phạm Nguyễn Ngọc Bảo
- LoRA fine-tuning: Phạm Nguyễn Ngọc Bảo
"""

In [None]:
from huggingface_hub import login

login(token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxx")  # Your Hugging Face token; make sure it has write access

In [40]:
repo_id = f"{HF_USERNAME}/{REPO_NAME}"
print(f"🦜 Creating repo: {repo_id}")
create_repo(
    repo_id=repo_id,
    repo_type="model",
    exist_ok=True
)

# Write README.md
readme_path = os.path.join(LOCAL_LORA_DIR, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(README_CONTENT.strip())

print("🦜 Uploading LoRA adapter to Hugging Face...")
upload_folder(
    folder_path=LOCAL_LORA_DIR,
    repo_id=repo_id,
    repo_type="model",
    commit_message="Upload VieNeu-TTS LoRA adapter"
)

print("✅ Upload completed successfully!")
print(f"🔗 https://huggingface.co/{repo_id}")

🦜 Creating repo: pnnbao-ump/VieNeu-TTS-0.3B-lora-ngoc-huyen
🦜 Uploading LoRA adapter to Hugging Face...


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...-LoRA/checkpoint-4500/rng_state.pth: 100%|##########| 14.6kB / 14.6kB            

  ...S-LoRA/checkpoint-4500/scheduler.pt: 100%|##########| 1.47kB / 1.47kB            

  ...A/checkpoint-4500/training_args.bin: 100%|##########| 5.78kB / 5.78kB            

  ...-LoRA/checkpoint-5000/rng_state.pth: 100%|##########| 14.6kB / 14.6kB            

  ...S-LoRA/checkpoint-5000/scheduler.pt: 100%|##########| 1.47kB / 1.47kB            

  ...A/checkpoint-5000/training_args.bin: 100%|##########| 5.78kB / 5.78kB            

  ...-TTS-LoRA/adapter_model.safetensors: 100%|##########| 12.8MB / 12.8MB            

  ...oint-5000/adapter_model.safetensors: 100%|##########| 12.8MB / 12.8MB            

  ...S-LoRA/checkpoint-4500/optimizer.pt: 100%|##########| 25.8MB / 25.8MB            

  ...S-LoRA/checkpoint-5000/optimizer.pt: 100%|##########| 25.8MB / 25.8MB            

✅ Upload completed successfully!
🔗 https://huggingface.co/pnnbao-ump/VieNeu-TTS-0.3B-lora-ngoc-huyen
