# Fine-tune KaniTTS cho Ti·∫øng Vi·ªát

**Datasets:** 4 subset t·ª´ VieNeu-TTS-500h (NanoCodec)
- Nam mi·ªÅn Nam, N·ªØ mi·ªÅn Nam, Nam mi·ªÅn B·∫Øc, N·ªØ mi·ªÅn B·∫Øc
- M·ªói subset ƒë√£ ƒë∆∞·ª£c encode NanoCodec (22kHz, 0.6kbps, 12.5fps)

---


## üì¶ Step 1: Install Dependencies

C√†i ƒë·∫∑t c√°c th∆∞ vi·ªán c·∫ßn thi·∫øt cho fine-tuning.


In [1]:
# Install dependencies
%uv pip install -q transformers==4.56.0 accelerate==1.10.1 datasets==3.6.0 omegaconf peft wandb
%uv pip install -q flash-attn --no-build-isolation

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from omegaconf import OmegaConf
import wandb
import os
import gc
from datetime import datetime

# Check GPU
print("=" * 70)
print("üîç GPU INFORMATION")
print("=" * 70)
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print("=" * 70)

üîç GPU INFORMATION
CUDA Available: True
GPU: NVIDIA L40S
VRAM: 50.9 GB
CUDA Version: 12.9
PyTorch Version: 2.8.0+cu129
Transformers Version: 4.56.0


## üîê Step 2: Login to HuggingFace & Weights & Biases


In [None]:
import os
from huggingface_hub import login
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("‚úÖ Login th√†nh c√¥ng t·ª´ environment variable!")
else:
    print("‚ö†Ô∏è HF_TOKEN not found in environment variables")


üîê HuggingFace Login
L·∫•y token t·ª´: https://huggingface.co/settings/tokens

‚úÖ Login th√†nh c√¥ng!


## ‚öôÔ∏è Step 3: Configuration

C·∫•u h√¨nh model v√† dataset cho training.


In [None]:
# ===== CONFIGURATION =====

# Base model
BASE_MODEL = "pnnbao-ump/kani-tts-370m-vie"

# Dataset configuration (4 subset: Nam/N·ªØ x Mi·ªÅn Nam/B·∫Øc)
DATASET_CONFIG = {
    "max_duration_sec": 12,  # Ho·∫∑c None n·∫øu mu·ªën gi·ªØ nguy√™n ƒë·ªô d√†i
    "hf_datasets": [
        {
            "reponame": "pnnbao-ump/VieNeu-TTS-500h-nanocodec-male-south",
            "name": None,
            "split": "train",
            "text_col_name": "text",
            "nano_layer_1": "nano_layer_1",
            "nano_layer_2": "nano_layer_2",
            "nano_layer_3": "nano_layer_3",
            "nano_layer_4": "nano_layer_4",
            "encoded_len": "encoded_len",
            "speaker_id": "nam_mien_nam",
            "max_len": None,
            "categorical_filter": None,
        },
        {
            "reponame": "pnnbao-ump/VieNeu-TTS-500h-nanocodec-female-south",
            "name": None,
            "split": "train",
            "text_col_name": "text",
            "nano_layer_1": "nano_layer_1",
            "nano_layer_2": "nano_layer_2",
            "nano_layer_3": "nano_layer_3",
            "nano_layer_4": "nano_layer_4",
            "encoded_len": "encoded_len",
            "speaker_id": "nu_mien_nam",
            "max_len": None,
            "categorical_filter": None,
        },
        {
            "reponame": "pnnbao-ump/VieNeu-TTS-500h-nanocodec-male-north",
            "name": None,
            "split": "train",
            "text_col_name": "text",
            "nano_layer_1": "nano_layer_1",
            "nano_layer_2": "nano_layer_2",
            "nano_layer_3": "nano_layer_3",
            "nano_layer_4": "nano_layer_4",
            "encoded_len": "encoded_len",
            "speaker_id": "nam_mien_bac",
            "max_len": None,
            "categorical_filter": None,
        },
        {
            "reponame": "pnnbao-ump/VieNeu-TTS-500h-nanocodec-female-north",
            "name": None,
            "split": "train",
            "text_col_name": "text",
            "nano_layer_1": "nano_layer_1",
            "nano_layer_2": "nano_layer_2",
            "nano_layer_3": "nano_layer_3",
            "nano_layer_4": "nano_layer_4",
            "encoded_len": "encoded_len",
            "speaker_id": "nu_mien_bac",
            "max_len": None,
            "categorical_filter": None,
        },
    ],
}

# LoRA Configuration
LORA_CONFIG = {
    "r": 48,
    "lora_alpha": 48,
    "lora_dropout": 0.1,
    "target_modules": ['q_proj', 'k_proj', 'v_proj', 'out_proj',
                   'w1', 'w2', 'w3'],
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "use_rslora": True
}

# Training Configuration
TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
    "weight_decay": 0.02,
    "max_grad_norm": 1.0,
    "bf16": True,
    "optim": "adamw_torch",
    "logging_steps": 20,
    "save_steps": 500,
    "save_total_limit": 3,
    "output_dir": "./checkpoints/kani_tts_vi",
    "report_to": "wandb",
    "run_name": f"kani-tts-vi-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    "gradient_checkpointing": True, 
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
}

# W&B Project
WANDB_PROJECT = "kani-tts-vi-finetune"

print("‚úÖ Configuration loaded!")
print(f"   Base Model: {BASE_MODEL}")
print(f"   Max Duration: {DATASET_CONFIG['max_duration_sec']}s")
print(f"   S·ªë dataset: {len(DATASET_CONFIG['hf_datasets'])}")
for ds in DATASET_CONFIG["hf_datasets"]:
    print(f"     - {ds['speaker_id']}: {ds['reponame']}")
print(f"   LoRA Rank: {LORA_CONFIG['r']}")
print(f"   Batch Size: {TRAINING_CONFIG['per_device_train_batch_size']}")
print(f"   Epochs: {TRAINING_CONFIG['num_train_epochs']}")

‚úÖ Configuration loaded!
   Base Model: nineninesix/kani-tts-370m
   Dataset: pnnbao-ump/VieNeu-TTS-140h-nanocodec
   Max Duration: Nones
   LoRA Rank: 48
   Batch Size: 8
   Epochs: 3


## üìä Step 4: Load and Process Dataset

Load dataset t·ª´ HuggingFace v√† x·ª≠ l√Ω cho training.


In [6]:
import urllib.request
try:
    url = "https://raw.githubusercontent.com/nineninesix-ai/KaniTTS-Finetune-pipeline/main/dataset_processor.py"
    urllib.request.urlretrieve(url, "dataset_processor.py")
    print("‚úÖ ƒê√£ download dataset_processor.py")
except Exception as e:
    print(f"‚ùå Kh√¥ng th·ªÉ download: {e}")
    print("üí° S·∫Ω s·ª≠ d·ª•ng code t√≠ch h·ª£p trong notebook")


‚úÖ ƒê√£ download dataset_processor.py


In [None]:
import sys
import os
from torch.utils.data import Dataset
import torch
from datasets import load_dataset, concatenate_datasets
from omegaconf import OmegaConf
from transformers import AutoTokenizer
import locale
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np

locale.getpreferredencoding = lambda: "UTF-8"

# TrainDataPreProcessor class
class TrainDataPreProcessor:
    def __init__(self, tokenizer_name: str, max_dur: int, speaker_id: str = None):
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_dur = max_dur
        self.speaker_id = speaker_id
        
        # Token configuration
        self.tokeniser_length = 64400
        self.start_of_text = 1
        self.end_of_text = 2
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032
    
    def add_codes(self, example):
        snac_layers = ['nano_layer_1', 'nano_layer_2', 'nano_layer_3', 'nano_layer_4']
        codes = [example[i] for i in snac_layers]
        codes = np.array(codes).T
        all_codes = codes + np.array([self.codebook_size * i for i in range(4)])
        
        # remove duplicates
        mask = np.any(all_codes[1:] != all_codes[:-1], axis=1)
        keep = np.insert(mask, 0, True)
        all_codes = all_codes[keep]
        
        # flatten to sequence
        all_codes = all_codes + self.audio_tokens_start
        example["codes_list"] = all_codes.flatten().tolist()
        return example
    
    def create_input_ids(self, example):
        if self.speaker_id is not None:
            text_prompt = f"{self.speaker_id.lower()}: {example['text']}"
        else:
            text_prompt = example["text"]
        
        text_ids = self.text_tokenizer.encode(text_prompt, add_special_tokens=True)
        text_ids.append(self.end_of_text)
        
        example["text_tokens"] = text_ids
        input_ids = (
            [self.start_of_human]
            + example["text_tokens"]
            + [self.end_of_human]
            + [self.start_of_ai]
            + [self.start_of_speech]
            + example["codes_list"]
            + [self.end_of_speech]
            + [self.end_of_ai]
        )
        example["input_ids"] = input_ids
        example["labels"] = input_ids
        example["attention_mask"] = [1] * len(input_ids)
        return example
    
    def __call__(self, dataset: Dataset) -> Dataset:
        print(f'üîÑ Processing dataset with {len(dataset)} samples...')
        
        if self.max_dur:
            fps = 12.5
            dataset_len = len(dataset)
            dataset = dataset.filter(lambda i: i['encoded_len']/fps <= self.max_dur)
            filtred_len = len(dataset)
            print(f'üìä Filtered by duration: {filtred_len} rows from {dataset_len}')
        
        dataset = dataset.map(self.add_codes, remove_columns=['nano_layer_1', 'nano_layer_2', 'nano_layer_3', 'nano_layer_4'], desc='Add Audio Codes')
        dataset = dataset.filter(lambda x: x["codes_list"] is not None, desc='Check codes list')
        dataset = dataset.filter(lambda x: len(x["codes_list"]) > 0, desc='Check Codes list length')
        dataset = dataset.map(self.create_input_ids, remove_columns=["text", "codes_list"], desc='Create input ids')
        
        columns_to_keep = ["input_ids", "labels", "attention_mask"]
        columns_to_remove = [col for col in dataset.column_names if col not in columns_to_keep]
        dataset = dataset.remove_columns(columns_to_remove)
        
        print(f'‚úÖ Processing completed: {len(dataset)} samples')
        return dataset

def process_shard(shard_idx, shard_data, tokenizer_name, max_dur, speaker_id):
    print(f'üöÄ WORKER {shard_idx}: Starting processing...')
    processor = TrainDataPreProcessor(tokenizer_name, max_dur, speaker_id)
    processed_shard = processor(shard_data)
    print(f'‚úÖ WORKER {shard_idx}: Completed processing')
    return processed_shard

class ItemDataset:
    def __init__(self, item_cfg: OmegaConf, tokenizer_name: str, max_dur: int, n_shards: int = None):
        print(f'üì¶ Loading dataset "{item_cfg.reponame}"...')
        self.item_cfg = item_cfg
        self.tokenizer_name = tokenizer_name
        self.max_dur = max_dur
        self.speaker_id = item_cfg.get('speaker_id')
        self.max_len = item_cfg.get('max_len')
        
        if n_shards is None:
            self.n_shards = min(mp.cpu_count(), 8)
        else:
            self.n_shards = n_shards
        
        self.dataset = load_dataset(
            item_cfg.reponame,
            item_cfg.name,
            split=item_cfg.split,
            num_proc=10
        )
        
        print(f'üìä Loaded {len(self.dataset)} samples')
        
        if item_cfg.get('categorical_filter'):
            print(f'üîß Filtering by {item_cfg.categorical_filter.column_name} = {item_cfg.categorical_filter.value}')
            self.dataset = self.dataset.filter(
                lambda x: x[item_cfg.categorical_filter.column_name] == item_cfg.categorical_filter.value
            )
            print(f'‚úÖ Filtered to {len(self.dataset)} samples')
        
        print(f'üîÑ Renaming columns...')
        rename_dict = {
            item_cfg.text_col_name: 'text',
            item_cfg.nano_layer_1: 'nano_layer_1',
            item_cfg.nano_layer_2: 'nano_layer_2',
            item_cfg.nano_layer_3: 'nano_layer_3',
            item_cfg.nano_layer_4: 'nano_layer_4',
            item_cfg.encoded_len: 'encoded_len',
        }
        self.dataset = self.dataset.rename_columns(rename_dict)
        print(f'‚úÖ Column renaming completed')
    
    def __call__(self):
        print(f'üîÑ Starting parallel processing with {self.n_shards} shards...')
        
        shards = []
        for i in range(self.n_shards):
            shard = self.dataset.shard(num_shards=self.n_shards, index=i)
            shards.append((shard, i))
            print(f'üì¶ SHARD {i}: Created with {len(shard)} samples')
        
        processed_shards = []
        with ProcessPoolExecutor(max_workers=self.n_shards) as executor:
            future_to_shard = {
                executor.submit(process_shard, shard_idx, shard, self.tokenizer_name, self.max_dur, self.speaker_id): shard_idx
                for shard, shard_idx in shards
            }
            
            for future in as_completed(future_to_shard):
                shard_idx = future_to_shard[future]
                try:
                    processed_shard = future.result()
                    processed_shards.append((shard_idx, processed_shard))
                    print(f'‚úÖ COMPLETED: Shard {shard_idx} processing finished')
                except Exception as exc:
                    print(f'‚ùå ERROR: Shard {shard_idx} generated an exception: {exc}')
                    raise exc
        
        processed_shards.sort(key=lambda x: x[0])
        final_shards = [shard for _, shard in processed_shards]
        
        print(f'üîó Concatenating {len(final_shards)} processed shards...')
        final_dataset = concatenate_datasets(final_shards)
        
        if self.max_len is not None:
            final_dataset = final_dataset.shuffle(seed=42).select(range(self.max_len))
        
        print(f'‚úÖ Dataset processing completed! Final size: {len(final_dataset)} samples')
        return final_dataset

# Create dataset config
print("=" * 70)
print("üì• LOADING AND PROCESSING DATASETS")
print("=" * 70)
print(f"Max duration (seconds): {DATASET_CONFIG['max_duration_sec']}")
print(f"T·ªïng s·ªë dataset s·∫Ω load: {len(DATASET_CONFIG['hf_datasets'])}")
print("=" * 70)

dataset_summaries = []
processed_datasets = []

for ds_idx, ds_cfg in enumerate(DATASET_CONFIG["hf_datasets"], start=1):
    item_cfg_dict = {
        "reponame": ds_cfg["reponame"],
        "name": ds_cfg.get("name"),
        "split": ds_cfg.get("split", "train"),
        "text_col_name": ds_cfg["text_col_name"],
        "nano_layer_1": ds_cfg["nano_layer_1"],
        "nano_layer_2": ds_cfg["nano_layer_2"],
        "nano_layer_3": ds_cfg["nano_layer_3"],
        "nano_layer_4": ds_cfg["nano_layer_4"],
        "encoded_len": ds_cfg["encoded_len"],
        "speaker_id": ds_cfg.get("speaker_id"),
        "max_len": ds_cfg.get("max_len"),
        "categorical_filter": ds_cfg.get("categorical_filter"),
    }

    item_cfg = OmegaConf.create(item_cfg_dict)

    print("\n" + "-" * 70)
    print(f"üéØ Dataset {ds_idx}: {ds_cfg['reponame']}")
    print(f"    Speaker ID prompt : {ds_cfg.get('speaker_id')}")
    if ds_cfg.get("categorical_filter"):
        filt = ds_cfg["categorical_filter"]
        print(f"    Filter: {filt['column_name']} = {filt['value']}")
    if ds_cfg.get("max_len"):
        print(f"    Max samples: {ds_cfg['max_len']}")

    item_dataset = ItemDataset(
        item_cfg=item_cfg,
        tokenizer_name=BASE_MODEL,
        max_dur=DATASET_CONFIG["max_duration_sec"],
        n_shards=8,
    )

    subset_dataset = item_dataset()
    subset_dataset = subset_dataset.shuffle(seed=42)

    processed_datasets.append(subset_dataset)
    dataset_summaries.append(
        {
            "speaker_id": ds_cfg.get("speaker_id"),
            "reponame": ds_cfg["reponame"],
            "num_samples": len(subset_dataset),
        }
    )

if not processed_datasets:
    raise RuntimeError("Kh√¥ng load ƒë∆∞·ª£c dataset n√†o. Ki·ªÉm tra c·∫•u h√¨nh DATASET_CONFIG.")

full_dataset = concatenate_datasets(processed_datasets)
full_dataset = full_dataset.shuffle(seed=42)

# ‚úÖ TH√äM: Split train/validation (10% cho validation)
VALIDATION_SPLIT = 0.1  # 10% cho validation
total_samples = len(full_dataset)
val_samples = int(total_samples * VALIDATION_SPLIT)

train_dataset = full_dataset.select(range(total_samples - val_samples))
val_dataset = full_dataset.select(range(total_samples - val_samples, total_samples))

print("\n" + "=" * 70)
print("‚úÖ DATASET READY!")
print("=" * 70)
print(f"Total samples: {total_samples:,}")
print(f"Train samples: {len(train_dataset):,} ({100*(1-VALIDATION_SPLIT):.1f}%)")
print(f"Validation samples: {len(val_dataset):,} ({100*VALIDATION_SPLIT:.1f}%)")
print(f"Features: {train_dataset.column_names}")
print("-" * 70)
for summary in dataset_summaries:
    print(
        f"   ‚Ä¢ {summary['speaker_id']}: {summary['num_samples']:,} samples ({summary['reponame']})"
    )
print("=" * 70)
print("\nüí° Validation set s·∫Ω ƒë∆∞·ª£c d√πng ƒë·ªÉ:")
print("   - Ph√°t hi·ªán overfitting (quan tr·ªçng v·ªõi LoRA rank cao)")
print("   - Ch·ªçn best model t·ª± ƒë·ªông")
print("   - Monitor training progress")
print("=" * 70)


üì• LOADING AND PROCESSING DATASET
Dataset: pnnbao-ump/VieNeu-TTS-140h-nanocodec
Max samples: None
Speaker ID: None
Max duration: Nones
üì¶ Loading dataset "pnnbao-ump/VieNeu-TTS-140h-nanocodec"...


README.md:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/58.3M [00:00<?, ?B/s]

Setting num_proc from 10 back to 1 for the train split to disable multiprocessing as it only contains one shard.


Generating train split:   0%|          | 0/74858 [00:00<?, ? examples/s]

üìä Loaded 74858 samples
üîÑ Renaming columns...
‚úÖ Column renaming completed
üîÑ Starting parallel processing with 8 shards...
üì¶ SHARD 0: Created with 9358 samples
üì¶ SHARD 1: Created with 9358 samples
üì¶ SHARD 2: Created with 9357 samples
üì¶ SHARD 3: Created with 9357 samples
üì¶ SHARD 4: Created with 9357 samples
üì¶ SHARD 5: Created with 9357 samples
üì¶ SHARD 6: Created with 9357 samples
üì¶ SHARD 7: Created with 9357 samples
üöÄ WORKER 5: Starting processing...üöÄ WORKER 4: Starting processing...üöÄ WORKER 3: Starting processing...üöÄ WORKER 6: Starting processing...üöÄ WORKER 2: Starting processing...üöÄ WORKER 1: Starting processing...üöÄ WORKER 7: Starting processing...üöÄ WORKER 0: Starting processing...









tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/434 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

üîÑ Processing dataset with 9357 samples...
üîÑ Processing dataset with 9357 samples...üîÑ Processing dataset with 9357 samples...

üîÑ Processing dataset with 9357 samples...
üîÑ Processing dataset with 9358 samples...
üîÑ Processing dataset with 9358 samples...
üîÑ Processing dataset with 9357 samples...üîÑ Processing dataset with 9357 samples...



Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9358 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9358 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Add Audio Codes:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9358 [00:00<?, ? examples/s]

Check codes list:   0%|          | 0/9358 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9357 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9358 [00:00<?, ? examples/s]

Check Codes list length:   0%|          | 0/9358 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9357 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9358 [00:00<?, ? examples/s]

Create input ids:   0%|          | 0/9358 [00:00<?, ? examples/s]

‚úÖ Processing completed: 9357 samples
‚úÖ WORKER 4: Completed processing
‚úÖ Processing completed: 9357 samples
‚úÖ WORKER 6: Completed processing
‚úÖ Processing completed: 9357 samples
‚úÖ WORKER 2: Completed processing
‚úÖ Processing completed: 9357 samples‚úÖ Processing completed: 9357 samples

‚úÖ WORKER 7: Completed processing‚úÖ WORKER 3: Completed processing

‚úÖ Processing completed: 9357 samples
‚úÖ WORKER 5: Completed processing‚úÖ Processing completed: 9358 samples‚úÖ Processing completed: 9358 samples


‚úÖ WORKER 1: Completed processing‚úÖ WORKER 0: Completed processing

‚úÖ COMPLETED: Shard 4 processing finished
‚úÖ COMPLETED: Shard 6 processing finished
‚úÖ COMPLETED: Shard 2 processing finished
‚úÖ COMPLETED: Shard 7 processing finished
‚úÖ COMPLETED: Shard 3 processing finished
‚úÖ COMPLETED: Shard 5 processing finished
‚úÖ COMPLETED: Shard 1 processing finished
‚úÖ COMPLETED: Shard 0 processing finished
üîó Concatenating 8 processed shards...
‚úÖ Dataset processing 

## üß† Step 5: Load Model and Tokenizer

Load base model v√† tokenizer, sau ƒë√≥ apply LoRA.


In [8]:
print("üìö Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
print(f"   Vocab size: {tokenizer.vocab_size}")

print("\nüß† Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # H100 h·ªó tr·ª£ t·ªët bf16
    attn_implementation="flash_attention_2"  # T·∫≠n d·ª•ng Flash Attention tr√™n H100
)

print(f"\n‚úÖ Model loaded!")
print(f"   Device: {next(model.parameters()).device}")
print(f"   Dtype: {next(model.parameters()).dtype}")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Check GPU memory
if torch.cuda.is_available():
    print(f"\nüíæ GPU Memory:")
    print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"   Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
    print(f"   Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1e9:.2f} GB")


üìö Loading tokenizer...
   Vocab size: 64400

üß† Loading model...


config.json: 0.00B [00:00, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.48G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]


‚úÖ Model loaded!
   Device: cuda:0
   Dtype: torch.bfloat16
   Parameters: 369,847,040

üíæ GPU Memory:
   Allocated: 0.75 GB
   Reserved: 0.79 GB
   Free: 50.07 GB


In [9]:
print("üîß Applying LoRA...")
lora_config = LoraConfig(**LORA_CONFIG)
model = get_peft_model(model, lora_config)

# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\n‚úÖ LoRA applied!")
print(f"   Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
print(f"   Total parameters: {total_params:,}")

# Check GPU memory after LoRA
if torch.cuda.is_available():
    print(f"\nüíæ GPU Memory (after LoRA):")
    print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"   Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")


üîß Applying LoRA...

‚úÖ LoRA applied!
   Trainable parameters: 16,023,552 (4.15%)
   Total parameters: 385,870,592

üíæ GPU Memory (after LoRA):
   Allocated: 0.81 GB
   Reserved: 0.86 GB


## üöÄ Step 6: Initialize Training

Setup trainer v√† b·∫Øt ƒë·∫ßu training.


In [None]:
# Initialize W&B
wandb.init(
    project=WANDB_PROJECT,
    name=TRAINING_CONFIG["run_name"],
    config={
        **TRAINING_CONFIG,
        **LORA_CONFIG,
        "base_model": BASE_MODEL,
        "max_duration_sec": DATASET_CONFIG["max_duration_sec"],
        "dataset_repos": [ds["reponame"] for ds in DATASET_CONFIG["hf_datasets"]],
        "speaker_ids": [ds.get("speaker_id") for ds in DATASET_CONFIG["hf_datasets"]],
        "max_samples_per_repo": [ds.get("max_len") for ds in DATASET_CONFIG["hf_datasets"]],
    }
)

# Ensure tokenizer has pad_token
pad_token_id = 64407
if tokenizer.pad_token is None:
    tokenizer.pad_token_id = pad_token_id
    print(f"‚úÖ Set pad_token_id to {pad_token_id}")

# Custom Data Collator (GI·ªêNG C≈®, KH√îNG ƒê·ªîI)
from transformers import DataCollatorForLanguageModeling
import torch

class CustomDataCollator(DataCollatorForLanguageModeling):
    def __call__(self, features):
        labels = [f.pop("labels") for f in features] 
        
        batch = self.tokenizer.pad(
            features,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        
        max_length = batch["input_ids"].shape[1]
        padded_labels = []
        for label in labels:
            padded_label = label + [-100] * (max_length - len(label))
            padded_labels.append(padded_label)
        
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        
        return batch

data_collator = CustomDataCollator(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8,
)

# ‚úÖ T√çNH TO√ÅN EVAL_STEPS V√Ä SAVE_STEPS
steps_per_epoch = len(train_dataset) // (
    TRAINING_CONFIG['per_device_train_batch_size'] * 
    TRAINING_CONFIG['gradient_accumulation_steps']
)
eval_steps = max(250, steps_per_epoch // 4)  # ‚úÖ Eval 4 l·∫ßn m·ªói epoch

print("=" * 70)
print("üìä TRAINING CONFIGURATION")
print("=" * 70)
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Eval steps: {eval_steps} (will eval ~{steps_per_epoch // eval_steps} times per epoch)")
print(f"Total training steps: {steps_per_epoch * TRAINING_CONFIG['num_train_epochs']}")
print("=" * 70)

# ‚úÖ TRAINING ARGUMENTS V·ªöI FULL VALIDATION SUPPORT
from transformers import TrainingArguments, EarlyStoppingCallback

training_config_dict = TRAINING_CONFIG.copy()
training_config_dict.pop('save_steps', None)
training_config_dict.pop('save_total_limit', None)

training_args = TrainingArguments(
    **training_config_dict,
    overwrite_output_dir=True,
    remove_unused_columns=True,
    
    # ‚úÖ VALIDATION & EVALUATION
    eval_strategy="steps",
    eval_steps=eval_steps,
    do_eval=True,  # ‚úÖ TH√äM: Explicitly enable evaluation
    
    # ‚úÖ CHECKPOINTING
    save_strategy="steps",
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=3,
    
    # ‚úÖ MONITORING
    logging_first_step=True,  # ‚úÖ Log ngay step ƒë·∫ßu
    eval_on_start=True,  # ‚úÖ Eval tr∆∞·ªõc khi train ƒë·ªÉ c√≥ baseline
)

print("\nüóÇÔ∏è  Creating trainer...")

# ‚úÖ CHECK MODEL STATE
from peft import PeftModel
if not isinstance(model, PeftModel):
    print("‚ö†Ô∏è Model ch∆∞a ƒë∆∞·ª£c apply LoRA! ƒêang apply l·∫°i...")
    lora_config = LoraConfig(**LORA_CONFIG)
    model = get_peft_model(model, lora_config)
    print("‚úÖ LoRA ƒë√£ ƒë∆∞·ª£c apply l·∫°i!")

model.train()
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Trainable parameters: {trainable_params:,}")
if trainable_params == 0:
    raise RuntimeError("‚ùå Model kh√¥ng c√≥ trainable parameters! Ki·ªÉm tra l·∫°i LoRA config.")

# ‚úÖ EARLY STOPPING CALLBACK
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,  # Stop if no improvement for 3 evals
    early_stopping_threshold=0.001  # Minimum improvement threshold
)

# ‚úÖ CREATE TRAINER WITH EARLY STOPPING
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[early_stopping_callback],  # ‚úÖ Add early stopping
)

print(f"\n‚úÖ Trainer ready!")
print(f"   Total training steps: {trainer.state.max_steps if hasattr(trainer.state, 'max_steps') else steps_per_epoch * TRAINING_CONFIG['num_train_epochs']}")
print(f"   Eval strategy: {training_args.eval_strategy}")
print(f"   Eval steps: {eval_steps}")
print(f"   Early stopping patience: 3 evaluations")
print(f"   Gradient checkpointing: {TRAINING_CONFIG.get('gradient_checkpointing', False)}")
print("=" * 70)

[34m[1mwandb[0m: Tracking run with wandb version 0.22.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/root/wandb/run-20251107_020414-69fqyi6w[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mkani-tts-vi-20251107-020347[0m
[34m[1mwandb[0m: ‚≠êÔ∏è View project at [34m[4mhttps://wandb.ai/scream3ktr6-ewyss/vietnamese-tts-finetune[0m
[34m[1mwandb[0m: üöÄ View run at [34m[4mhttps://wandb.ai/scream3ktr6-ewyss/vietnamese-tts-finetune/runs/69fqyi6w[0m
  trainer = Trainer(
Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


üìä TRAINING CONFIGURATION
Training samples: 67,373
Validation samples: 7,485
Steps per epoch: 2105
Eval steps: 526 (will eval ~4 times per epoch)
Total training steps: 6315

üóÇÔ∏è  Creating trainer...
   Trainable parameters: 16,023,552

‚úÖ Trainer ready!
   Total training steps: 0
   Eval strategy: IntervalStrategy.STEPS
   Eval steps: 526
   Early stopping patience: 3 evaluations
   Gradient checkpointing: True


## üéØ Step 7: Start Training


In [11]:
import time

print("=" * 70)
print("üöÄ STARTING TRAINING")
print("=" * 70)
print(f"Dataset size: {len(train_dataset):,} samples")
print(f"Validation size: {len(val_dataset):,} samples")
print(f"Batch size: {TRAINING_CONFIG['per_device_train_batch_size']}")
print(f"Gradient accumulation: {TRAINING_CONFIG['gradient_accumulation_steps']}")
print(f"Effective batch size: {TRAINING_CONFIG['per_device_train_batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps']}")
print(f"Epochs: {TRAINING_CONFIG['num_train_epochs']}")
print(f"Eval steps: {eval_steps} (validation s·∫Ω ƒë∆∞·ª£c ch·∫°y m·ªói {eval_steps} steps)")
print("=" * 70)
print("\nüí° B·∫°n s·∫Ω th·∫•y validation loss trong:")
print("   - Console output (m·ªói eval_steps)")
print("   - W&B dashboard (train_loss vs eval_loss)")
print("=" * 70)

start_time = time.time()

# Start training
trainer.train()

end_time = time.time()
training_time = end_time - start_time

print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED!")
print("=" * 70)
print(f"Total training time: {training_time / 3600:.2f} hours ({training_time / 60:.1f} minutes)")
print(f"Average time per epoch: {training_time / TRAINING_CONFIG['num_train_epochs'] / 60:.1f} minutes")
print("=" * 70)
print("\nüìä Validation results:")
print(f"   Best eval_loss: {trainer.state.best_metric if hasattr(trainer.state, 'best_metric') else 'N/A'}")
print(f"   Best model checkpoint: {trainer.state.best_model_checkpoint if hasattr(trainer.state, 'best_model_checkpoint') else 'N/A'}")
print("=" * 70)

wandb.finish()


üöÄ STARTING TRAINING
Dataset size: 67,373 samples
Validation size: 7,485 samples
Batch size: 8
Gradient accumulation: 4
Effective batch size: 32
Epochs: 3
Eval steps: 526 (validation s·∫Ω ƒë∆∞·ª£c ch·∫°y m·ªói 526 steps)

üí° B·∫°n s·∫Ω th·∫•y validation loss trong:
   - Console output (m·ªói eval_steps)
   - W&B dashboard (train_loss vs eval_loss)


You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
0,No log,5.726142
526,4.949700,4.942051
1052,4.652100,4.66265
1578,4.595200,4.588301
2104,4.541400,4.549474
2630,4.502600,4.525425
3156,4.489000,4.507982
3682,4.506300,4.493707
4208,4.480600,4.483937
4734,4.462400,4.478301


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
[34m[1mwandb[0m: updating run metadata



‚úÖ TRAINING COMPLETED!
Total training time: 1.98 hours (118.8 minutes)
Average time per epoch: 39.6 minutes

üìä Validation results:
   Best eval_loss: 4.473038673400879
   Best model checkpoint: ./checkpoints/kani_tts_vi/checkpoint-6312


[34m[1mwandb[0m: uploading history steps 329-329, summary
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:               eval/loss ‚ñà‚ñÑ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
[34m[1mwandb[0m:            eval/runtime ‚ñÇ‚ñÑ‚ñÅ‚ñÑ‚ñÅ‚ñÇ‚ñà‚ñÇ‚ñÇ‚ñÇ‚ñÑ‚ñÑ‚ñÑ
[34m[1mwandb[0m: eval/samples_per_second ‚ñá‚ñÖ‚ñà‚ñÖ‚ñà‚ñá‚ñÅ‚ñá‚ñá‚ñá‚ñÖ‚ñÖ‚ñÖ
[34m[1mwandb[0m:   eval/steps_per_second ‚ñá‚ñÖ‚ñà‚ñÖ‚ñà‚ñá‚ñÅ‚ñá‚ñá‚ñá‚ñÖ‚ñÖ‚ñÖ
[34m[1mwandb[0m:             train/epoch ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà
[34m[1mwandb[0m:       train/global_step ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
[34m[1mwandb[0m:         train/grad_norm ‚ñÉ‚ñÉ‚ñÅ‚ñÇ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñà‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñÜ‚ñá‚ñá‚ñá‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñÜ
[34m[1mwandb[0m:     train/learning_rate ‚ñÉ‚ñá‚ñà‚

## üíæ Step 8: Save Model

Merge LoRA weights v√† l∆∞u model.


In [12]:
print("üîÑ Merging LoRA weights...")
merged_model = model.merge_and_unload()

output_path = TRAINING_CONFIG["output_dir"]
os.makedirs(output_path, exist_ok=True)

print(f"üíæ Saving model to {output_path}...")
merged_model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)

print(f"\n‚úÖ Model saved successfully!")
print(f"   Path: {output_path}")
print(f"\nüì§ To upload to HuggingFace Hub:")
print(f"   huggingface-cli upload <your-username>/vietnamese-tts-model {output_path} --private")


üîÑ Merging LoRA weights...
üíæ Saving model to ./checkpoints/kani_tts_vi...

‚úÖ Model saved successfully!
   Path: ./checkpoints/kani_tts_vi

üì§ To upload to HuggingFace Hub:
   huggingface-cli upload <your-username>/vietnamese-tts-model ./checkpoints/kani_tts_vi --private


In [16]:
!huggingface-cli upload pnnbao-ump/kani-tts-370m-vi "./checkpoints/kani_tts_vi" --private

It seems you are trying to upload a large folder at once. This might take some time and then fail if the folder is too large. For such cases, it is recommended to upload in smaller batches or to use `HfApi().upload_large_folder(...)`/`hf upload-large-folder` instead. For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder.
Start hashing 43 files.
Finished hashing 43 files.
Processing Files (0 / 0)                : |        |  0.00B /  0.00B            
New Data Upload                         : |        |  0.00B /  0.00B            [A

  ...ts_vi/checkpoint-5786/rng_state.pth:  77%|‚ñà‚ñà‚ñà | 11.3kB / 14.6kB            [A[A


  ...ts_vi/checkpoint-6318/rng_state.pth:  77%|‚ñà‚ñà‚ñà | 11.3kB / 14.6kB            [A[A[A



  ...ts_vi/checkpoint-6312/rng_state.pth:  77%|‚ñà‚ñà‚ñà | 11.3kB / 14.6kB            [A[A[A[A




  ...tts_vi/checkpoint-5786/optimizer.pt:   0%|    |  568kB /  128MB         

## üìä Step 9: Training Summary

Xem l·∫°i k·∫øt qu·∫£ training.


In [None]:
print("=" * 70)
print("üìä TRAINING SUMMARY")
print("=" * 70)
print(f"Base Model: {BASE_MODEL}")
print("Datasets:")
for summary in dataset_summaries:
    print(f"   - {summary['speaker_id']}: {summary['reponame']} ({summary['num_samples']:,} samples)")
print(f"Training Samples: {len(train_dataset):,}")
print(f"LoRA Rank: {LORA_CONFIG['r']}")
print(f"Batch Size: {TRAINING_CONFIG['per_device_train_batch_size']}")
print(f"Epochs: {TRAINING_CONFIG['num_train_epochs']}")
print(f"Learning Rate: {TRAINING_CONFIG['learning_rate']}")
print(f"Training Time: {training_time / 3600:.2f} hours")
print(f"Model Saved: {output_path}")
print("=" * 70)

# Check final GPU memory
if torch.cuda.is_available():
    print(f"\nüíæ Final GPU Memory:")
    print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"   Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
    print(f"   Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1e9:.2f} GB")

print("\nüéâ Fine-tuning completed successfully!")
print("\nüí° Next steps:")
print("   1. Test the model with inference script")
print("   2. Upload to HuggingFace Hub if needed")
print("   3. Evaluate audio quality")


üìä TRAINING SUMMARY
Base Model: nineninesix/kani-tts-370m
Dataset: pnnbao-ump/VieNeu-TTS-140h-nanocodec
Training Samples: 67,373
LoRA Rank: 48
Batch Size: 8
Epochs: 3
Learning Rate: 2e-05
Training Time: 1.98 hours
Model Saved: ./checkpoints/kani_tts_vi

üíæ Final GPU Memory:
   Allocated: 0.96 GB
   Reserved: 24.36 GB
   Free: 26.51 GB

üéâ Fine-tuning completed successfully!

üí° Next steps:
   1. Test the model with inference script
   2. Upload to HuggingFace Hub if needed
   3. Evaluate audio quality
