# XTTS v2 Fine-tuning for Indic Languages

This notebook fine-tunes XTTS v2 for Indic languages (Hindi, Tamil, Telugu, Bengali, etc.).

**Requirements:**
- Colab T4 GPU (16GB VRAM)
- Google Drive for storage
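- Prepared JSONL manifests from notebook 01 (`manifests/tts_<lang>_train.jsonl` and, optionally, `manifests/tts_<lang>_val.jsonl` under `DATA_DIR`)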

In [None]:
# Check GPU: show which accelerator (if any) this runtime has before loading the model.
# The leading `!` runs a shell command, so this cell only works in IPython/Colab.
!nvidia-smi

In [None]:
# Install dependencies
# NOTE(review): indic-nlp-library and aksharamukha are not imported by any cell shown
# in this notebook; drop them unless a later step actually needs them.
# NOTE(review): reinstalling torch/torchaudio may replace the runtime's preinstalled,
# CUDA-matched build -- confirm the resulting versions are compatible.
!pip install -q TTS torch torchaudio
!pip install -q indic-nlp-library aksharamukha

In [None]:
import os

# Attach Google Drive so the prepared data and trained checkpoints live on Drive.
from google.colab import drive
drive.mount('/content/drive')

# Set paths
DATA_DIR = '/content/drive/MyDrive/indic_speech_data'   # manifests prepared by notebook 01
OUTPUT_DIR = '/content/drive/MyDrive/indic_tts_models'  # checkpoints are written here

os.makedirs(OUTPUT_DIR, exist_ok=True)

## 1. Load XTTS Model

In [None]:
import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager

# Training configuration.
# LANGUAGE selects the manifests and test sentence; change as needed: hi, ta, te, bn
# NOTE(review): confirm the base XTTS v2 tokenizer supports the chosen language.
LANGUAGE = 'hi'

# Effective batch per optimizer step = BATCH_SIZE * GRADIENT_ACCUMULATION.
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 8
LEARNING_RATE = 1e-5
EPOCHS = 50

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

In [None]:
# Load XTTS v2 model: download the pretrained checkpoint and build the model on `device`.
print("Loading XTTS v2 model...")

model_manager = ModelManager()
model_path, config_path, _ = model_manager.download_model('tts_models/multilingual/multi-dataset/xtts_v2')

# NOTE(review): for XTTS models, Coqui's ModelManager.download_model returns the
# model *directory* as model_path and skips its usual file discovery, returning
# None for config_path (verify for the installed TTS version). Passing None to
# load_json would fail, so fall back to the config.json inside that directory.
if config_path is None:
    config_path = os.path.join(model_path, 'config.json')

config = XttsConfig()
config.load_json(config_path)

model = Xtts.init_from_config(config)
# checkpoint_dir is the downloaded model directory returned above.
model.load_checkpoint(config, checkpoint_dir=model_path)
model = model.to(device)

print("Model loaded successfully!")

## 2. Prepare Dataset

In [None]:
import json
import torchaudio
from torch.utils.data import Dataset, DataLoader

class XTTSDataset(Dataset):
    """Audio/text pairs read from a JSONL manifest for XTTS fine-tuning.

    Each non-blank manifest line is a JSON object. This class reads the keys
    ``audio_filepath``, ``text``, ``duration`` and (optionally) ``language``.
    Only clips with ``min_length <= duration <= max_length`` seconds are kept.
    Lines without ``duration`` are treated as 0 seconds and are therefore dropped.

    Args:
        manifest_path: Path to the UTF-8 JSONL manifest.
        max_length: Longest clip to keep, in seconds.
        sample_rate: Rate every returned waveform is resampled to.
        min_length: Shortest clip to keep, in seconds (defaults to the
            original 1.0 s lower bound).
    """

    def __init__(self, manifest_path, max_length=11.0, sample_rate=22050, min_length=1.0):
        self.samples = []
        self.max_length = max_length
        self.min_length = min_length
        self.sample_rate = sample_rate
        # Resample transforms keyed by source rate. Constructing a transform
        # builds its resampling kernel, so reuse one per rate instead of
        # rebuilding it for every item.
        self._resamplers = {}

        skipped = 0
        with open(manifest_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                sample = json.loads(line)
                duration = sample.get('duration', 0)
                if min_length <= duration <= max_length:
                    self.samples.append(sample)
                else:
                    skipped += 1

        print(f"Loaded {len(self.samples)} samples")
        if skipped:
            # Surface the filter so an all-filtered manifest is easy to diagnose.
            print(f"Skipped {skipped} samples outside {min_length}-{max_length}s")

    def __len__(self):
        return len(self.samples)

    def _get_resampler(self, orig_sr):
        """Return a cached Resample transform from ``orig_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
            self._resamplers[orig_sr] = resampler
        return resampler

    def __getitem__(self, idx):
        """Return one sample as a dict with keys ``audio``, ``text`` and ``language``.

        ``audio`` is a 1-D tensor. Multi-channel input is averaged to mono and
        resampled to ``self.sample_rate``. If loading fails, the error is
        printed and one second of zeros is returned instead, so a single bad
        file does not abort an epoch. ``language`` falls back to the
        notebook-level ``LANGUAGE`` with an ``-IN`` suffix.
        """
        sample = self.samples[idx]

        # Load audio
        try:
            waveform, sr = torchaudio.load(sample['audio_filepath'])
            if waveform.shape[0] > 1:
                # Dim 0 holds channels here; average them to mono.
                waveform = waveform.mean(dim=0, keepdim=True)
            if sr != self.sample_rate:
                waveform = self._get_resampler(sr)(waveform)
            waveform = waveform.squeeze(0)
        except Exception as e:
            print(f"Error loading {sample['audio_filepath']}: {e}")
            waveform = torch.zeros(self.sample_rate)

        return {
            'audio': waveform,
            'text': sample['text'],
            'language': sample.get('language', f'{LANGUAGE}-IN'),
        }

In [None]:
# Load datasets
train_manifest = f"{DATA_DIR}/manifests/tts_{LANGUAGE}_train.jsonl"
val_manifest = f"{DATA_DIR}/manifests/tts_{LANGUAGE}_val.jsonl"

# Fail fast: every later cell uses train_dataset. A missing manifest must stop
# here instead of surfacing later as a confusing NameError.
if not os.path.exists(train_manifest):
    raise FileNotFoundError(
        f"Train manifest not found: {train_manifest}. "
        "Please run notebook 01 first to prepare data."
    )

train_dataset = XTTSDataset(train_manifest)
# A shuffled DataLoader rejects an empty dataset, so report the real cause here.
if len(train_dataset) == 0:
    raise ValueError(
        f"No usable samples in {train_manifest}: check the 'duration' values "
        "against XTTSDataset's length limits."
    )

# The validation split is optional: None when its manifest is absent.
# NOTE(review): val_dataset is not used by any later cell in this notebook.
val_dataset = XTTSDataset(val_manifest) if os.path.exists(val_manifest) else None

## 3. Training Loop

In [None]:
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm

def collate_fn(batch):
    """Right-pad a list of 1-D waveforms with zeros and stack them into one tensor.

    Returns a dict with the padded ``audio`` batch, the original per-item
    ``audio_lengths`` (before padding), and the ``text``/``language`` lists in
    batch order. Raises ValueError on an empty batch.
    """
    lengths = [entry['audio'].shape[0] for entry in batch]
    target = max(lengths)

    def _right_pad(wave):
        # Append zeros only to waveforms shorter than the longest one.
        deficit = target - wave.shape[0]
        if deficit > 0:
            return torch.cat([wave, torch.zeros(deficit)])
        return wave

    padded = [_right_pad(entry['audio']) for entry in batch]
    return {
        'audio': torch.stack(padded),
        'audio_lengths': torch.tensor(lengths),
        'text': [entry['text'] for entry in batch],
        'language': [entry['language'] for entry in batch],
    }

# Shuffled training loader; collate_fn right-pads each batch's waveforms.
# Pinned host memory only speeds up host-to-GPU copies, so enable it only
# when training on CUDA (on CPU-only runtimes it has no benefit).
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=(device == 'cuda'),
)

In [None]:
# Setup optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
)

# Mixed precision only applies on CUDA. Disabling both the scaler and autocast
# on other devices keeps the loop runnable there.
use_amp = device == 'cuda'
scaler = GradScaler(enabled=use_amp)


def _apply_accumulated_gradients():
    """Step the optimizer with the scaled, accumulated gradients, then clear them."""
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


# Training loop
model.train()
optimizer.zero_grad()  # start from clean gradients
best_loss = float('inf')

for epoch in range(EPOCHS):
    total_loss = 0
    num_batches = 0  # batches whose backward pass succeeded
    num_failed = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch_idx, batch in enumerate(pbar):
        audio = batch['audio'].to(device)
        audio_lengths = batch['audio_lengths'].to(device)
        texts = batch['text']
        languages = batch['language']

        try:
            with torch.autocast(device_type='cuda', enabled=use_amp):
                # Note: XTTS forward_train API may vary
                # This is a simplified example
                # NOTE(review): in Coqui TTS, Xtts.forward() is not implemented (it
                # raises and points to the dedicated XTTS GPT trainer), so every
                # batch may fail here; verify against the installed TTS version.
                outputs = model.forward(
                    texts[0],
                    languages[0][:2],  # Use 2-char lang code
                    audio[:1],  # Speaker reference
                )

                # Simple reconstruction loss
                loss = torch.nn.functional.mse_loss(
                    outputs['wav'][:audio_lengths[0]],
                    audio[0, :audio_lengths[0]]
                )

            loss = loss / GRADIENT_ACCUMULATION
            scaler.scale(loss).backward()
        except Exception as e:
            num_failed += 1
            print(f"Error in batch {batch_idx}: {e}")
            continue

        num_batches += 1
        total_loss += loss.item() * GRADIENT_ACCUMULATION

        # Count only successful backward passes, so failed batches do not
        # shrink the accumulation window.
        if num_batches % GRADIENT_ACCUMULATION == 0:
            _apply_accumulated_gradients()

        pbar.set_postfix({'loss': f"{total_loss/num_batches:.4f}"})

    # Apply gradients left over when the number of successful batches is not a
    # multiple of GRADIENT_ACCUMULATION, so they don't leak into the next epoch
    # (or get dropped after the last one).
    if num_batches % GRADIENT_ACCUMULATION:
        _apply_accumulated_gradients()

    # Without this guard, an epoch where every batch failed would report a loss
    # of 0.0 and save an untrained model as the "best" checkpoint.
    if num_batches == 0:
        raise RuntimeError(
            f"Epoch {epoch+1}: no batch completed successfully ({num_failed} failed); "
            "nothing was trained and no checkpoint was saved. See the errors printed above."
        )

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")
    if num_failed:
        print(f"  {num_failed} batches failed in epoch {epoch+1}")

    # Save checkpoint
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f"{OUTPUT_DIR}/xtts_{LANGUAGE}_best.pt")
        print(f"Saved best model with loss: {avg_loss:.4f}")

## 4. Test Synthesis

In [None]:
import IPython.display as ipd

# Test text: one sample sentence per supported LANGUAGE; unknown codes fall back to Hindi.
test_texts = {
    'hi': 'नमस्ते, मेरा नाम है और मैं हिंदी बोल रहा हूं।',
    'ta': 'வணக்கம், என் பெயர் தமிழ் மொழி பேசுகிறேன்.',
    'te': 'నమస్కారం, నా పేరు తెలుగు భాషలో మాట్లాడుతున్నాను.',
    'bn': 'নমস্কার, আমার নাম বাংলা ভাষায় কথা বলছি।',
}

model.eval()

test_text = test_texts.get(LANGUAGE, test_texts['hi'])
print(f"Synthesizing: {test_text}")

# Use the first training clip as the voice reference, passed by file path.
# NOTE(review): in Coqui TTS, Xtts.synthesize computes conditioning latents by
# loading speaker_wav from disk, so it expects path(s), not an in-memory tensor;
# confirm for the installed version.
ref_audio_path = train_dataset.samples[0]['audio_filepath']

with torch.no_grad():
    outputs = model.synthesize(
        test_text,
        config,
        speaker_wav=ref_audio_path,
        language=LANGUAGE,
    )

# Play audio at the model's *output* rate. XTTS v2 decodes at
# config.audio.output_sample_rate (24 kHz), not the 22.05 kHz input rate;
# playing at 22050 would sound slowed down and lower-pitched.
ipd.Audio(outputs['wav'], rate=config.audio.output_sample_rate)

## 5. Save Final Model

In [None]:
# Save final model: write the fine-tuned weights and the matching config into one directory.
final_path = f"{OUTPUT_DIR}/xtts_{LANGUAGE}_final"
os.makedirs(final_path, exist_ok=True)

# NOTE(review): this writes a raw state_dict as model.pt. Confirm that whatever
# reloads it (e.g. Xtts.load_checkpoint with checkpoint_dir=final_path) accepts
# this file name and layout and does not also need files such as vocab.json.
weights_file = f"{final_path}/model.pt"
config_file = f"{final_path}/config.json"

torch.save(model.state_dict(), weights_file)
config.save_json(config_file)

print(f"Model saved to {final_path}")