# In [None]:
import optuna
import torch
from torch.utils.data import DataLoader
# Assuming you have adapted XTTS-v2's training code into a modular function
from your_xtts_training_module import train_one_epoch, evaluate  # Replace with your actual training functions
from your_xtts_data_loading import create_data_loaders # Replace with your data loading

def objective(trial: "optuna.trial.Trial"):
    """Run one Optuna trial: sample hyperparameters, fine-tune, and score.

    The names ``dataset_config``, ``YourXTTSModel``, ``criterion`` and
    ``speaker_verification_model`` are not defined inside this function; they
    are template placeholders that must exist at module level before the
    study is run.

    Args:
        trial: Optuna trial used to sample hyperparameters, report the
            per-epoch metric, and ask the pruner whether to stop early.

    Returns:
        The speaker verification accuracy returned by ``evaluate`` after the
        final epoch. The study maximizes this value.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner decides to stop the
            trial after an intermediate report.
    """
    # 1. Suggest Hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    gradient_clip = trial.suggest_float("gradient_clip", 0.5, 5.0)
    # Example: Only fine-tune last N layers
    # NOTE(review): this value is sampled but nothing below reads it yet, so
    # it enlarges the search space without affecting training. Wire it into
    # the freezing logic (or drop the suggestion) once the real layer naming
    # is known.
    num_layers_to_finetune = trial.suggest_int("num_layers_to_finetune", 4, 12)

    # 2. Load Data (10% Subset) - Replace with YOUR data loading code
    train_loader, eval_loader = create_data_loaders(
        dataset_config,  # Assuming you have a dataset_config
        batch_size=batch_size,
        train_split=0.9,  # Adjust split as needed
        max_samples=0.1,  # Limit to 10% of dataset
    )

    # 3. Initialize Model and Optimizer (Adapt to XTTS-v2)
    # Fall back to CPU when no GPU is visible instead of failing on .to("cuda").
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = YourXTTSModel().to(device)  # Replace with your actual XTTS-v2 model loading
    # Freeze earlier layers (example)
    for name, param in model.named_parameters():
        if "some_early_layer" in name:  # Replace with your layer naming logic
            param.requires_grad = False

    # Only hand the optimizer the parameters that are still trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=lr,
        weight_decay=weight_decay,
    )

    # 4. Training Loop (Limited Epochs)
    num_epochs = 30  # Or something reasonable for the 10% dataset
    for epoch in range(num_epochs):
        # Train one epoch (Assuming you have a function for this)
        train_one_epoch(model, train_loader, optimizer, criterion, gradient_clip)  # Adapt to XTTS

        # 5. Evaluation (Adapt to XTTS and Consider Speaker Verification)
        # The loss is returned by evaluate() but is not part of the reported metric.
        _eval_loss, speaker_verification_accuracy = evaluate(
            model, eval_loader, criterion, speaker_verification_model
        )  # Adapt to XTTS

        # Report the intermediate metric so the pruner can compare this trial
        # against others at the same epoch.
        trial.report(speaker_verification_accuracy, epoch)  # Or a combined metric

        # Handle pruning based on intermediate values.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return speaker_verification_accuracy  # Or your combined metric

# Build the study: the objective returns speaker verification accuracy, so the
# search direction is "maximize"; MedianPruner backs trial.should_prune().
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(),
)
# Run the search (adjust n_trials to the available compute budget).
study.optimize(objective, n_trials=50)

# Report the best trial's objective value and its sampled hyperparameters.
print("Best trial:")
trial = study.best_trial
print("  Value: {}".format(trial.value))
print("  Params: ")
for param_name, param_value in trial.params.items():
    print("    {}: {}".format(param_name, param_value))