# Package installation and imports

In [None]:
# Environment setup: clone the Coqui TTS and trainer repositories and install
# the Python packages imported by the cells below.
# NOTE(review): none of the installs pin a version, so the environment may
# differ between runs — consider pinning once a working set is known.
# Alternative TTS fork, currently disabled:
#!git clone https://github.com/idiap/coqui-ai-TTS.git
!git clone https://github.com/coqui-ai/TTS.git
!git clone https://github.com/eginhard/coqui-trainer
!pip install Coqpit
!pip install TTS
!pip install transformers torchaudio
!pip install coqui-trainer
# Do not restart the runtime after running this cell.
# The installs may print a few warnings/errors; per the original author these can be ignored.

In [None]:
import numpy
import os
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from dataclasses import dataclass, field
from typing import List, Dict
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
from TTS.tts.datasets.dataset import TTSDataset

In [None]:
from dataclasses import dataclass
from torch import nn
import pandas as pd
from trainer import Trainer, TrainerConfig, TrainerModel
from trainer.trainer import TrainerArgs
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
from TTS.tts.utils.text.characters import Graphemes

# LJSpeech dataset

In [None]:
# Download the LJSpeech dataset without checking the SSL certificate
!wget --no-check-certificate https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2

# Extract the dataset
!tar -xjf LJSpeech-1.1.tar.bz2

# Verify the extraction by listing the contents
!ls LJSpeech-1.1


--2024-11-28 16:17:52--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 169.150.236.105, 2400:52e0:1a00::845:1
Connecting to data.keithito.com (data.keithito.com)|169.150.236.105|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [text/plain]
Saving to: ‘LJSpeech-1.1.tar.bz2’


2024-11-28 16:18:07 (172 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]

metadata.csv  README  wavs


# GlowTTS Adaptation and Training

Implements the text-to-SSL conversion using a modified GlowTTS architecture:

### Architecture Overview
1. Configuration Setup:
   - num_chars: 148 for English character set
   - out_channels: 1024 to match WavLM features
   - hidden_channels: 192 for encoder/decoder
   - encoder_type: "rel_pos_transformer"

2. Model Components:
   - Transformer-based text encoder
   - Duration predictor
   - Flow-based decoder
   - Speaker-independent design

### Key Features
- Non-autoregressive architecture
- Flow-based feature generation
- Duration prediction for proper alignment
- Batch processing support
- Device-agnostic implementation

## SSL Encoder

In [None]:
class SSLEncoder:
    """Inference-only wrapper around the pretrained ``microsoft/wavlm-large`` model.

    The model is loaded once, moved to ``device`` and switched to eval mode.
    ``extract_features`` then maps raw waveforms to the hidden states of one
    chosen layer.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load WavLM onto ``device`` (CUDA when available, otherwise CPU)."""
        self.device = device
        print(f"Loading WavLM model to {device}...")
        self.model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
        self.model.eval()
        print("WavLM model loaded successfully!")

    @torch.no_grad()
    def extract_features(self, waveform, sample_rate=16000, layer=6):
        """Extract WavLM hidden states for ``waveform``.

        Args:
            waveform: Audio tensor. A 1-D tensor is treated as a single
                utterance and given a leading batch dimension; tensors with
                more dimensions are passed to the model unchanged.
            sample_rate: Sampling rate of ``waveform`` in Hz. Anything other
                than 16000 is resampled to 16000 first.
            layer: Index into the model's ``hidden_states`` tuple. Defaults
                to 6, the value previously hard-coded here.

        Returns:
            The tensor ``hidden_states[layer]`` produced by the model.

        Raises:
            IndexError: If ``layer`` is outside the returned hidden states.
        """
        # Resample if sample rate is not 16000 Hz
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)

        # Ensure waveform is properly batched
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Move waveform to the model's device before the forward pass
        waveform = waveform.to(self.device)
        outputs = self.model(waveform, output_hidden_states=True)

        # Select the requested layer (6 by default)
        return outputs.hidden_states[layer]

## Data Preprocessing

This code sets up the data processing pipeline for training SSL-TTS. It performs three key operations:

1. **Audio Loading and Resampling**
   - Loads audio files from LJSpeech dataset
   - Resamples them from 22.05kHz to 16kHz (required by WavLM)

2. **Feature Extraction**
   - Uses WavLM to convert raw audio into high-level speech features
   - Instead of using mel-spectrograms, we get 1024-dimensional WavLM features
   - These features contain rich information about speech content and speaker characteristics

3. **Batch Processing**
   - Handles variable-length utterances by zero-padding their WavLM features and token sequences to the longest item in the batch
   - Creates batches of features and their corresponding text transcriptions
   - Makes the data ready for training the GlowTTS model

This pipeline transforms raw audio into the format needed for training our SSL-TTS system, where GlowTTS will learn to predict WavLM features from text.

In [None]:
class LJSpeechDataset(Dataset):
    """LJSpeech utterances as ``(WavLM features, token ids)`` pairs.

    Each item loads one wav file, brings it to 16 kHz, runs it through the
    given SSL encoder and tokenizes the transcript from the ``text`` column.
    Both returned tensors live on the encoder's device.
    """

    # Native LJSpeech rate handled by the pre-built resampler, and the rate
    # the features are extracted at.
    SOURCE_SAMPLE_RATE = 22050
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, root_dir, metadata_file, ssl_encoder, tokenizer):
        """
        Args:
            root_dir: Dataset directory containing the ``wavs`` folder.
            metadata_file: Path to the pipe-separated ``metadata.csv``.
            ssl_encoder: Object exposing ``device`` and ``extract_features``.
            tokenizer: Object exposing ``text_to_ids``.
        """
        self.root_dir = root_dir
        self.metadata = self._read_metadata(metadata_file)
        self.resampler = torchaudio.transforms.Resample(orig_freq=22050, new_freq=16000)
        self.tokenizer = tokenizer
        self.ssl_encoder = ssl_encoder
        self.device = ssl_encoder.device  # Get device from encoder

    @staticmethod
    def _read_metadata(metadata_file):
        """Read the pipe-separated metadata into a three-column DataFrame.

        Quote handling is disabled because the transcripts contain literal
        double quotes. With pandas' default quoting an unbalanced ``"`` makes
        the parser treat the following ``|`` separators as part of the field,
        which corrupts the text column.
        """
        import csv

        return pd.read_csv(
            metadata_file,
            sep="|",
            header=None,
            names=["file", "text", "normalized_text"],
            quoting=csv.QUOTE_NONE,
        )

    def __len__(self):
        """Number of metadata rows."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(features, text_tokens)`` for row ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the metadata table.
            Exception: Any error raised while loading or encoding the file is
                logged with the file path and re-raised.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Resolve the path before entering the try block so the error handler
        # can always name the file, and an out-of-range index surfaces as a
        # plain IndexError rather than an unbound-variable error.
        wav_file = os.path.join(
            self.root_dir,
            "wavs",
            self.metadata.iloc[idx, 0] + ".wav"
        )

        try:
            # Load and process audio
            waveform, sample_rate = torchaudio.load(wav_file)

            # Resample to 16 kHz. The cached resampler is only valid for
            # 22.05 kHz input; any other rate gets an explicit conversion.
            if sample_rate == self.SOURCE_SAMPLE_RATE:
                waveform = self.resampler(waveform)
            elif sample_rate != self.TARGET_SAMPLE_RATE:
                waveform = torchaudio.functional.resample(
                    waveform, sample_rate, self.TARGET_SAMPLE_RATE
                )

            # Move waveform to correct device
            waveform = waveform.to(self.device)

            # Extract features
            with torch.no_grad():
                features = self.ssl_encoder.extract_features(waveform)

            # Get and tokenize text, then move to correct device
            text = self.metadata.iloc[idx, 1]
            text_tokens = torch.LongTensor(self.tokenizer.text_to_ids(text)).to(self.device)

            return features, text_tokens

        except Exception as e:
            print(f"Error processing file {wav_file}: {str(e)}")
            raise

def collate_fn(batch):
    """Merge ``(features, text_tokens)`` samples into two padded batch tensors.

    Every feature tensor has its leading dimension squeezed away and is
    zero-padded along its second-to-last axis up to the largest ``size(1)``
    in the batch. Every token tensor is zero-padded on the right up to the
    longest token sequence. The stacked results are returned on the device of
    the first feature tensor.

    Args:
        batch: Sequence of ``(features, text_tokens)`` pairs.

    Returns:
        Tuple ``(features_tensor, texts_tensor)``.
    """
    feature_items, token_items = zip(*batch)

    # Longest sequence of each kind decides the padded size.
    longest_feature = max(item.size(1) for item in feature_items)
    longest_text = max(item.size(0) for item in token_items)

    target_device = feature_items[0].device
    pad = torch.nn.functional.pad

    # Zero-pad the frame axis; the last (channel) axis is left untouched.
    feature_rows = []
    for item in feature_items:
        missing_frames = longest_feature - item.size(1)
        feature_rows.append(pad(item.squeeze(0), (0, 0, 0, missing_frames)))

    # Zero-pad token sequences on the right.
    token_rows = []
    for item in token_items:
        missing_tokens = longest_text - item.size(0)
        token_rows.append(pad(item, (0, missing_tokens)))

    feature_batch = torch.stack(feature_rows).to(target_device)
    token_batch = torch.stack(token_rows).to(target_device)
    return feature_batch, token_batch

## GlowTTS

In [None]:
# Record whether a CUDA device is available.
# NOTE(review): `is_cuda` is not read anywhere else in the visible notebook;
# the classes below query torch.cuda.is_available() themselves — confirm this is still needed.
is_cuda = torch.cuda.is_available()

class ExtendedGraphemes(Graphemes):
    """Grapheme set with digits and extra punctuation appended.

    Starts from the default ``Graphemes`` configuration and extends its
    ``characters`` with the symbols listed in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        # Digits plus punctuation/currency symbols added on top of the defaults.
        extra_symbols = "0123456789.,!?\"'()$"
        self.characters += extra_symbols

@dataclass
class GlowTTSWavLMConfig(TrainerConfig):
    """Trainer settings for the GlowTTS model that predicts WavLM features."""

    # --- schedule / logging ---
    epochs: int = 1590
    batch_size: int = 32  # read by GlowTTSWavLM.get_data_loader
    print_step: int = 25

    # --- model ---
    wavlm_feature_dim: int = 1024  # WavLM feature size; used as GlowTTS out_channels

    # --- optimisation ---
    lr: float = 3e-6
    lr_scheduler: str = "noamlr"
    lr_scheduler_params: Dict = field(default_factory=lambda: dict(warmup_steps=0.1))

    # --- evaluation ---
    # run_eval_steps: int = 50000   (disabled)
    run_eval: bool = False
    # mixed_precision: bool = True  (disabled)

class GlowTTSWavLM(TrainerModel):
    """GlowTTS network trained to predict WavLM features from text.

    Holds three pieces:
      * ``glow``: a Coqui ``GlowTTS`` model whose ``out_channels`` equals the
        WavLM feature size from the config;
      * ``ssl_encoder``: the ``SSLEncoder`` that produces training targets
        (a plain object, so its weights are not part of ``self.parameters()``);
      * ``tokenizer``: a grapheme ``TTSTokenizer`` built on ``ExtendedGraphemes``.
    """

    def __init__(self, config: GlowTTSWavLMConfig):
        super().__init__()

        # Kept so get_optimizer() can honour the configured learning rate.
        self._train_config = config

        # Initialize GlowTTS base configuration
        glow_config = GlowTTSConfig(
            num_chars=148,
            hidden_channels_enc=192,
            hidden_channels_dec=192,
            out_channels=config.wavlm_feature_dim,
            use_encoder_prenet=True,
            encoder_type="rel_pos_transformer",
            dropout_p_dec=0.1,
        )

        # Initialize components
        self.glow = GlowTTS(glow_config)
        self.ssl_encoder = SSLEncoder()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = TTSTokenizer(
            use_phonemes=False,
            characters=ExtendedGraphemes(),
        )

        # Move model components to device
        self.glow = self.glow.to(self.device)
        self.to(self.device)  # Move the whole model to device

    def forward(self, input: torch.Tensor, *args, aux_input=None, **kwargs) -> Dict:
        """Run the wrapped GlowTTS forward pass.

        Args:
            input (torch.Tensor): Padded batch of text token ids.
            aux_input (Dict): Auxiliary inputs, read with ``.get``:
                - x_lengths (torch.Tensor): Lengths of the text sequences
                - y (torch.Tensor): Target WavLM features
                - y_lengths (torch.Tensor): Lengths of the target sequences
              ``None`` is treated as an empty dict.

        Returns:
            Dict: The dictionary returned by ``self.glow`` with an extra
            ``"model_outputs"`` entry that aliases its ``"z"`` entry.
            ``train_step`` additionally reads ``y_mean``, ``y_log_scale``,
            ``logdet``, ``durations_log`` and ``total_durations_log`` from it.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        if aux_input is None:
            aux_input = {}

        x_lengths = aux_input.get("x_lengths")
        y = aux_input.get("y")
        y_lengths = aux_input.get("y_lengths")

        # Call the GlowTTS forward pass
        outputs = self.glow(
            input,  # text input
            x_lengths,  # text lengths
            y,  # target WavLM features
            y_lengths,  # feature lengths
            aux_input  # any additional inputs
        )

        # Ensure the output dictionary has the required 'model_outputs' key
        outputs["model_outputs"] = outputs.get("z")  # Using z as the main output

        return outputs

    @staticmethod
    def _collate_with_lengths(batch):
        """Collate a batch and also return the unpadded sequence lengths.

        The lengths must be captured before ``collate_fn`` pads the samples;
        afterwards they can no longer be recovered from the padded tensors.
        """
        feature_lengths = torch.tensor(
            [sample_features.size(1) for sample_features, _ in batch], dtype=torch.long
        )
        text_lengths = torch.tensor(
            [sample_tokens.size(0) for _, sample_tokens in batch], dtype=torch.long
        )
        features, texts = collate_fn(batch)
        return features, texts, feature_lengths, text_lengths

    def train_step(self, batch, criterion):
        """Compute model outputs and losses for one batch.

        Args:
            batch: Either ``(features, texts, feature_lengths, text_lengths)``
                as produced by this model's data loader, or the bare
                ``(features, texts)`` pair from ``collate_fn``. ``features``
                is the padded ``(batch, frames, channels)`` tensor.
            criterion: Loss callable, invoked with the GlowTTS outputs and
                both length tensors.

        Returns:
            Tuple ``(outputs, loss_dict)``.
        """
        if len(batch) == 4:
            features, texts, feature_lengths, text_lengths = batch
        else:
            features, texts = batch
            feature_lengths = text_lengths = None

        # Ensure all inputs are on the correct device
        features = features.to(self.device)
        texts = texts.to(self.device)

        # Without explicit lengths, fall back to the padded sizes. Frame counts
        # come from the time axis (dim 1), not the channel axis.
        batch_size = features.size(0)
        if feature_lengths is None:
            feature_lengths = torch.full((batch_size,), features.size(1), dtype=torch.long)
        if text_lengths is None:
            text_lengths = torch.full((batch_size,), texts.size(1), dtype=torch.long)
        feature_lengths = feature_lengths.long().to(self.device)
        text_lengths = text_lengths.long().to(self.device)

        # Create aux_input dictionary
        aux_input = {
            "x_lengths": text_lengths,
            "y": features,
            "y_lengths": feature_lengths
        }

        # Forward pass
        outputs = self.forward(texts, aux_input=aux_input)

        # Calculate loss
        loss_dict = criterion(
            outputs["z"],
            outputs["y_mean"],
            outputs["y_log_scale"],
            outputs["logdet"],
            feature_lengths,
            outputs["durations_log"],
            outputs["total_durations_log"],
            text_lengths
        )

        return outputs, loss_dict

    def optimize(self, batch, trainer):
        """Custom optimization step: forward, backward and optimizer update.

        Args:
            batch: Batch as accepted by ``train_step``.
            trainer: Trainer instance providing ``criterion``, ``optimizer``,
                ``total_steps_done`` and ``grad_accum_steps``.

        Returns:
            Tuple ``(outputs, loss_dict)`` from ``train_step``.
        """
        # Forward pass and loss computation
        outputs, loss_dict = self.train_step(batch, trainer.criterion)

        # "loss" is already the combined objective (the training log shows
        # loss == log_mle + loss_dur). Summing every entry would count each
        # term twice, so only fall back to the sum when "loss" is absent.
        if "loss" in loss_dict:
            total_loss = loss_dict["loss"]
        else:
            total_loss = sum(loss_dict.values())

        # Backward pass with gradient scaling
        self.scaled_backward(total_loss, trainer, trainer.optimizer)

        # Optimizer step
        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
            trainer.optimizer.step()
            trainer.optimizer.zero_grad()

        return outputs, loss_dict

    @torch.no_grad()
    def eval_step(self, batch, criterion=None):
        """Run ``train_step`` without gradient tracking.

        Args:
            batch: Batch as accepted by ``train_step``.
            criterion: Loss callable, or an object exposing one as
                ``.criterion``. When ``None`` a fresh criterion is created.

        Returns:
            Tuple ``(outputs, loss_dict)``.
        """
        if criterion is None:
            criterion = self.get_criterion()
        else:
            # Accept either the loss itself or a wrapper holding it.
            criterion = getattr(criterion, "criterion", criterion)
        return self.train_step(batch, criterion)

    def get_criterion(self):
        """Get the loss criterion"""
        from TTS.tts.layers.losses import GlowTTSLoss
        return GlowTTSLoss()

    def get_optimizer(self):
        """Build the Adam optimizer using the configured learning rate.

        Falls back to 0.000003 when the config carries no ``lr``.
        """
        learning_rate = getattr(self._train_config, "lr", 0.000003)
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def get_data_loader(self, config, assets, is_eval, samples=None, verbose=True, num_gpus=1, rank=0):
        """Build the LJSpeech data loader for training or evaluation.

        The same metadata file is used in both modes; ``is_eval`` only turns
        shuffling off. ``assets``, ``samples``, ``verbose``, ``num_gpus`` and
        ``rank`` are accepted for interface compatibility and not used.
        Batches are 4-tuples ``(features, texts, feature_lengths, text_lengths)``.
        """
        dataset = LJSpeechDataset(
            root_dir="LJSpeech-1.1",
            metadata_file=os.path.join("LJSpeech-1.1", "metadata.csv"),
            ssl_encoder=self.ssl_encoder,
            tokenizer=self.tokenizer
        )

        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=not is_eval,
            collate_fn=self._collate_with_lengths,
            drop_last=True,
            num_workers=0
        )



In [None]:
def main(checkpoint_path='/content/best_model_15133.pth'):
    """Build the model, optionally restore a checkpoint, and start training.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'model'`` entry holds
            the state dict to restore. Pass ``None`` to start from freshly
            initialised weights. The default keeps the original behaviour.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is given but does not exist.
    """
    # Initialize configuration
    config = GlowTTSWavLMConfig()
    config.batch_size = 32

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize model
    model = GlowTTSWavLM(config)
    model = model.to(device)

    if checkpoint_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load the model state dict
        model.load_state_dict(checkpoint['model'])

    # Initialize trainer
    trainer = Trainer(
        TrainerArgs(),
        config,
        model=model,
        output_path=os.getcwd(),
        gpu=0 if torch.cuda.is_available() else None
    )

    # Start training
    trainer.fit()

if __name__ == "__main__":
    main()

Loading WavLM model to cuda...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/2.22k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

WavLM model loaded successfully!


  checkpoint = torch.load('/content/best_model_21268.pth')
 > Training Environment:
 | > Backend: Torch
 | > Mixed precision: False
 | > Precision: float32
 | > Current device: 0
 | > Num. of GPUs: 1
 | > Num. of CPUs: 12
 | > Num. of Torch Threads: 6
 | > Torch seed: 54321
 | > Torch CUDNN: True
 | > Torch CUDNN deterministic: False
 | > Torch CUDNN benchmark: False
 | > Torch TF32 MatMul: False
 > Start Tensorboard: tensorboard --logdir=output/run-November-28-2024_04+27PM-0000000

 > Model has 35388609 parameters

[4m[1m > EPOCH: 0/1590[0m
 --> output/run-November-28-2024_04+27PM-0000000

[1m > TRAINING (2024-11-28 16:27:52) [0m

[1m   --> TIME: 2024-11-28 16:27:59 -- STEP: 0/409 -- GLOBAL_STEP: 0[0m
     | > loss: -1.0882741212844849  (-1.0882741212844849)
     | > log_mle: -1.406938910484314  (-1.406938910484314)
     | > loss_dur: 0.3186648190021515  (0.3186648190021515)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 3.5883  (3.588322877883911)
     | > lo

He was summoned to the Mansion House, where he repeated his request, crying, "Accordez moi cette grâce," with much urgency.
 [!] Character 'â' not found in the vocabulary. Discarding it.
Müller protested after sentence of death had been passed upon him that he had been convicted on a false statement of facts.
 [!] Character 'ü' not found in the vocabulary. Discarding it.
he broke open the chest and stole £4700 in notes, with a quantity of gold and some silver.
 [!] Character '£' not found in the vocabulary. Discarding it.
So great is the authority exercised by him,|So great is the authority exercised by him,
 [!] Character '|' not found in the vocabulary. Discarding it.



[1m   --> TIME: 2024-11-28 16:28:47 -- STEP: 25/409 -- GLOBAL_STEP: 25[0m
     | > loss: 0.07189086079597473  (0.7555053353309631)
     | > log_mle: -0.2619713544845581  (0.42243337869644165)
     | > loss_dur: 0.33386221528053284  (0.3330719530582428)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3378  (0.36444716453552245)
     | > loader_time: 1.4246  (1.5114511775970458)


[1m   --> TIME: 2024-11-28 16:29:30 -- STEP: 50/409 -- GLOBAL_STEP: 50[0m
     | > loss: 0.19165048003196716  (0.4489688462018966)
     | > log_mle: -0.20028388500213623  (0.09457263588905335)
     | > loss_dur: 0.3919343650341034  (0.354396208524704)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3263  (0.3440762233734131)
     | > loader_time: 1.3237  (1.4529456615447998)



While your bread is taking a three hours’ rise, you are free in body and mind for other things.
 [!] Character '’' not found in the vocabulary. Discarding it.
When arrested on the day of the assassination, he had in his possession a Smith & Wesson 38 caliber revolver
 [!] Character '&' not found in the vocabulary. Discarding it.



[1m   --> TIME: 2024-11-28 16:30:12 -- STEP: 75/409 -- GLOBAL_STEP: 75[0m
     | > loss: 0.24252840876579285  (0.30182783802350355)
     | > log_mle: -0.13570153713226318  (-0.05339456478754679)
     | > loss_dur: 0.37822994589805603  (0.35522240181763964)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.293  (0.3352998987833659)
     | > loader_time: 1.3327  (1.4224552694956463)


[1m   --> TIME: 2024-11-28 16:30:54 -- STEP: 100/409 -- GLOBAL_STEP: 100[0m
     | > loss: 0.28964290022850037  (0.21338904708623882)
     | > log_mle: -0.10029721260070801  (-0.14849021255970002)
     | > loss_dur: 0.3899401128292084  (0.36187925890088074)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3462  (0.3337572383880616)
     | > loader_time: 1.3843  (1.4043866300582883)



So I gave her a transfer and opened the door and she was going out the gentleman I had picked up about two blocks [back]
 [!] Character '[' not found in the vocabulary. Discarding it.
So I gave her a transfer and opened the door and she was going out the gentleman I had picked up about two blocks [back]
 [!] Character ']' not found in the vocabulary. Discarding it.



[1m   --> TIME: 2024-11-28 16:31:37 -- STEP: 125/409 -- GLOBAL_STEP: 125[0m
     | > loss: -0.18404024839401245  (0.1647198631763458)
     | > log_mle: -0.5191107988357544  (-0.20042934560775758)
     | > loss_dur: 0.33507055044174194  (0.3651492081880569)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3346  (0.33377989959716803)
     | > loader_time: 1.3669  (1.3932109851837156)



Into the “crater” dug out in the middle, pour the sponge, warm water, the molasses, and soda dissolved in hot water.
 [!] Character '“' not found in the vocabulary. Discarding it.
Into the “crater” dug out in the middle, pour the sponge, warm water, the molasses, and soda dissolved in hot water.
 [!] Character '”' not found in the vocabulary. Discarding it.



[1m   --> TIME: 2024-11-28 16:32:19 -- STEP: 150/409 -- GLOBAL_STEP: 150[0m
     | > loss: -0.4794243276119232  (0.11101292967796325)
     | > log_mle: -0.7929145097732544  (-0.25646921594937644)
     | > loss_dur: 0.3134901821613312  (0.36748214513063426)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3724  (0.33418945153554286)
     | > loader_time: 1.3468  (1.3842518329620361)


[1m   --> TIME: 2024-11-28 16:33:01 -- STEP: 175/409 -- GLOBAL_STEP: 175[0m
     | > loss: 0.041770100593566895  (0.0731823492050171)
     | > log_mle: -0.3613091707229614  (-0.2957280196462359)
     | > loss_dur: 0.4030792713165283  (0.3689103684255055)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3167  (0.3334522560664586)
     | > loader_time: 1.3977  (1.3778261811392654)



He had a nice taste in bric-à-brac, and was considered a good judge of pictures.
 [!] Character 'à' not found in the vocabulary. Discarding it.



[1m   --> TIME: 2024-11-28 16:33:43 -- STEP: 200/409 -- GLOBAL_STEP: 200[0m
     | > loss: 0.1650693714618683  (0.024932592660188668)
     | > log_mle: -0.2347468137741089  (-0.34388052076101305)
     | > loss_dur: 0.3998161852359772  (0.36881311275064943)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3481  (0.3332423257827759)
     | > loader_time: 1.3792  (1.372394526004792)


[1m   --> TIME: 2024-11-28 16:34:25 -- STEP: 225/409 -- GLOBAL_STEP: 225[0m
     | > loss: -0.3463161885738373  (-0.006028768486446812)
     | > log_mle: -0.6961883306503296  (-0.37288279665840995)
     | > loss_dur: 0.3498721420764923  (0.3668540275096893)
     | > current_lr: 2.0554804791094465e-06 
     | > step_time: 0.3383  (0.33296140882703995)
     | > loader_time: 1.3619  (1.3689751254187696)


[1m   --> TIME: 2024-11-28 16:35:07 -- STEP: 250/409 -- GLOBAL_STEP: 250[0m
     | > loss: -0.23447063565254211  (-0.03811470532417298)
     | > log_mle: -0.6050094366073608  (-0.40393

an avowed "snatcher" and habitué of the Fortune of War, a public-house in Smithfield frequented openly by men of this awful profession.
 [!] Character 'é' not found in the vocabulary. Discarding it.
and the raison d'être of the penalty, which in principle so many opposed, would be gone.
 [!] Character 'ê' not found in the vocabulary. Discarding it.
The greatest causes célèbre, however, of recent times were the turf frauds by which the Comtesse de Goncourt was swindled
 [!] Character 'è' not found in the vocabulary. Discarding it.


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[1m   --> TIME: 2024-11-29 08:06:45 -- STEP: 3/409 -- GLOBAL_STEP: 33950[0m
     | > loss: -1.554314374923706  (-1.4011663595835369)
     | > log_mle: -1.6891556978225708  (-1.639520287513733)
     | > loss_dur: 0.13484136760234833  (0.2383539229631424)
     | > current_lr: 2.2561829371270255e-07 
     | > step_time: 0.3572  (0.3116176128387451)
     | > loader_time: 1.3916  (1.361732800801595)


[1m   --> TIME: 2024-11-29 08:07:27 -- STEP: 28/409 -- GLOBAL_STEP: 33975[0m
     | > loss: -1.1803768873214722  (-1.5459882787295751)
     | > log_mle: -1.5322970151901245  (-1.8415165671280451)
     | > loss_dur: 0.35192012786865234  (0.2955282899950232)
     | > current_lr: 2.2561829371270255e-07 
     | > step_time: 0.3128  (0.310452972139631)
     | > loader_time: 1.3304  (1.3389836549758911)


[1m   --> TIME: 2024-11-29 08:08:09 -- STEP: 53/409 -- GLOBAL_STEP: 34000[0m
     | > loss: -1.4548499584197998  (-1.534293734

In [None]:
# Keep this cell running indefinitely. Sleeping between iterations avoids
# pinning a CPU core the way an empty busy-loop does. Interrupt the cell to stop.
import time

while True:
  time.sleep(60)

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): this cell sits after the infinite-loop cell above, so a
# top-to-bottom "Run all" never reaches it — run it manually or move it earlier.
from google.colab import drive
drive.mount('/content/drive')