In [None]:
# Install the notebook's dependencies.
# NOTE: Colab may show a "restart the runtime" warning after this cell runs.
# Per the original author it is only a warning: do NOT restart, the notebook works without it.
!pip install TTS transformers torchaudio
# NOTE(review): gTTS is not referenced by any cell visible in this notebook — confirm it is still needed.
!pip install gTTS

In [None]:
# Complete SSL-TTS implementation
import torch
import torchaudio
import torch.nn as nn
from transformers import WavLMModel
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS
import torch.nn.functional as F
from typing import List
import os


In [None]:
class SSLEncoder:
    """Wrapper around a pretrained WavLM-Large model used as a frozen feature extractor."""

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load ``microsoft/wavlm-large`` onto ``device`` and switch it to eval mode.

        Args:
            device: Torch device string. Defaults to ``'cuda'`` when CUDA is
                available (decided once, when the class is defined), else ``'cpu'``.
        """
        self.device = device
        print(f"Loading WavLM model to {device}...")
        wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
        self.model = wavlm.to(device)
        self.model.eval()
        print("WavLM model loaded successfully!")

    @torch.no_grad()
    def extract_features(self, waveform, sample_rate=16000):
        """Return entry 6 of WavLM's ``hidden_states`` for ``waveform``.

        Args:
            waveform: Audio tensor. A 1-D tensor is treated as a single
                utterance and gets a leading batch axis; any other rank is
                passed to the model unchanged.
            sample_rate: Sample rate of ``waveform`` in Hz. Anything other
                than 16000 is resampled to 16000 first.

        Returns:
            ``outputs.hidden_states[6]`` as produced by the model, computed
            without gradient tracking.
        """
        wavlm_rate = 16000
        audio = waveform

        # The model is always fed 16 kHz audio.
        if sample_rate != wavlm_rate:
            audio = torchaudio.functional.resample(audio, sample_rate, wavlm_rate)

        # A bare 1-D signal needs a batch axis before it can go through the model.
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)

        audio = audio.to(self.device)
        hidden_states = self.model(audio, output_hidden_states=True).hidden_states
        return hidden_states[6]

In [6]:
class TextToSSL(nn.Module):
    """Text front-end: a GlowTTS model configured to emit 1024-channel frames.

    The model is built with ``out_channels=1024`` (the WavLM feature size used
    elsewhere in this notebook) and a grapheme tokenizer. Weights are restored
    from ``glowtts_checkpoint.pth`` in the working directory when that file
    exists; otherwise the model keeps its random initialisation.
    """

    def __init__(self):
        """Build the GlowTTS model and tokenizer and load the checkpoint if present."""
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Initialize GlowTTS config with proper dimensions
        config = GlowTTSConfig(
            num_chars=148,  # Increased for full character set
            hidden_channels_enc=192,
            hidden_channels_dec=192,
            out_channels=1024,  # WavLM feature dimension
            use_encoder_prenet=True,
            encoder_type="rel_pos_transformer",
            dropout_p_dec=0.1,
        )

        # Initialize GlowTTS model
        self.glow_tts = GlowTTS(config).to(self.device)

        # Tokenizer classes are imported here, as in the original cell, so this
        # class does not depend on edits to the notebook's import cell.
        from TTS.tts.utils.text.tokenizer import TTSTokenizer
        from TTS.tts.utils.text.characters import Graphemes

        self.tokenizer = TTSTokenizer(
            use_phonemes=False,
            characters=Graphemes(),
        )

        # Load checkpoint if available
        checkpoint_path = "glowtts_checkpoint.pth"
        if os.path.exists(checkpoint_path):
            # map_location lets a checkpoint saved on a GPU load on a CPU-only
            # host (and vice versa), matching the device fallback above.
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.glow_tts.load_state_dict(checkpoint['model'])
        else:
            # Make the silent fallback visible: without weights the output is noise.
            print(f"Warning: {checkpoint_path} not found; GlowTTS is randomly "
                  "initialized and its output will not be meaningful.")

        self.glow_tts.eval()

    def tokenize_text(self, text):
        """Convert one string to a 1-D ``LongTensor`` of token ids on ``self.device``."""
        token_ids = torch.LongTensor(self.tokenizer.text_to_ids(text)).to(self.device)
        return token_ids

    def generate(self, text: List[str], text_lengths=None):
        """Generate SSL features from text input.

        Args:
            text: List of strings to synthesize as one batch. A single ``str``
                is also accepted and treated as a one-element batch (instead of
                being iterated character by character).
            text_lengths: Optional per-item token lengths. When omitted they
                are computed from the tokenized inputs as a ``LongTensor`` on
                ``self.device``.

        Returns:
            The ``'model_outputs'`` entry of ``GlowTTS.inference``.

        Raises:
            ValueError: If ``text`` is empty.
        """
        if isinstance(text, str):
            text = [text]
        if len(text) == 0:
            raise ValueError("text must contain at least one string")

        token_seqs = [self.tokenize_text(item) for item in text]

        # Calculate lengths if not provided
        if text_lengths is None:
            text_lengths = torch.tensor([len(seq) for seq in token_seqs]).long().to(self.device)

        # Right-pad every sequence with id 0 up to the longest length in the batch.
        max_length = int(max(text_lengths))
        padded_texts = []
        for seq in token_seqs:
            missing = max_length - len(seq)
            if missing > 0:
                # new_zeros keeps the dtype and device of the token tensor.
                seq = torch.cat([seq, seq.new_zeros(missing)])
            padded_texts.append(seq)

        # Stack into batch
        text_tensor = torch.stack(padded_texts)

        # Generate features using GlowTTS
        outputs = self.glow_tts.inference(
            text_tensor,
            aux_input={"x_lengths": text_lengths}
        )
        return outputs['model_outputs']

In [None]:
class TTSPipeline(nn.Module):
    """Text -> GlowTTS features -> kNN match against a target speaker -> HiFi-GAN audio.

    Attributes (set in ``__init__``):
        device: ``'cuda'`` when available, else ``'cpu'``.
        ssl_encoder: ``SSLEncoder`` used to featurize target-speaker audio.
        text_to_ssl: ``TextToSSL`` used to turn text into source features.
        vocoder: The ``hifigan`` attribute of the ``bshall/knn-vc`` hub model.
        i: Counter used to number saved output files.
    """

    def __init__(self, source='LJSpeech', target='LJSpeech'):
        """Build all sub-models.

        Args:
            source: Unused; kept for interface compatibility.
            target: Unused; kept for interface compatibility.
        """
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.ssl_encoder = SSLEncoder()
        self.text_to_ssl = TextToSSL()
        # Initialize the vocoder by getting it directly from knn_vc.
        # `device` is forwarded to the hub entry point so the vocoder lands on
        # the same device as the rest of the pipeline (including CPU-only hosts).
        knn_vc = torch.hub.load(
            'bshall/knn-vc',
            'knn_vc',
            pretrained=True,
            prematched=True,
            trust_repo=True,
            device=self.device,
        )
        self.vocoder = knn_vc.hifigan
        self.i = 0

    def get_features(self, path=None, waveform=None, get_target=False):
        """Extract SSL features for one utterance.

        Args:
            path: Audio file to load; resampled to 16 kHz when needed. Ignored
                when ``waveform`` is given.
            waveform: Already-loaded audio tensor. It is not resampled here,
                so it is assumed to be at 16 kHz.
            get_target: When True, voice-activity trimming is applied to both
                ends of the signal before feature extraction.

        Returns:
            Feature tensor on ``self.device``; a leading batch axis of size 1
            is removed when the encoder output has more than two dimensions.

        Raises:
            ValueError: If neither ``path`` nor ``waveform`` is provided.
        """
        if waveform is not None:
            x = waveform
        elif path is not None:
            x, sample_rate = torchaudio.load(path, normalize=True)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
                x = resampler(x)
        else:
            raise ValueError("either `path` or `waveform` must be provided")

        # Multi-channel audio arrives as (channels, time). Without a downmix the
        # encoder would treat each channel as a separate batch item and the
        # resulting 3-D features would break concatenation and kNN matching.
        if x.dim() == 2 and x.size(0) > 1:
            x = x.mean(dim=0, keepdim=True)

        if get_target:
            # Vad is applied once to the signal and once to its time-reversed
            # copy, then the result is flipped back, so both ends get trimmed.
            transform = torchaudio.transforms.Vad(sample_rate=16000, trigger_level=7.0)
            x_trim = transform(x)
            x_reversed = torch.flip(x_trim, (-1,))
            x_reversed_trim = transform(x_reversed)
            x_full_trim = torch.flip(x_reversed_trim, (-1,))
            x = x_full_trim

        features = self.ssl_encoder.extract_features(x, sample_rate=16000)
        features = features.squeeze(0) if features.dim() > 2 else features
        return features.to(self.device)

    def get_target_features(self, wavs):
        """Get features from target speaker utterances.

        Args:
            wavs: A single path or a list of paths to target-speaker audio.

        Returns:
            Features of all utterances concatenated along dim 0, on ``self.device``.
        """
        if isinstance(wavs, str):
            wavs = [wavs]
        features = [self.get_features(path, get_target=True) for path in wavs]
        features = torch.cat(features, dim=0)
        return features.to(self.device)

    def cosine_dist(self, source_features, target_features):
        """Pairwise cosine distance between rows of two 2-D feature matrices.

        The dot products are recovered from Euclidean distances via
        ``a.b = (|a|^2 + |b|^2 - |a - b|^2) / 2`` so a single ``torch.cdist``
        call does the heavy lifting.

        Args:
            source_features: Tensor of shape ``(n_source, dim)``.
            target_features: Tensor of shape ``(n_target, dim)``.

        Returns:
            Tensor of shape ``(n_source, n_target)`` holding ``1 - cosine similarity``.
        """
        source_norms = torch.norm(source_features, p=2, dim=-1)
        matching_norms = torch.norm(target_features, p=2, dim=-1)
        dotprod = -torch.cdist(source_features[None], target_features[None], p=2)[0]**2 + \
                  source_norms[:, None]**2 + matching_norms[None]**2
        dotprod /= 2
        dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
        return dists

    def KNN(self, source_features, target_features, k=4):
        """Replace each source frame by the mean of its ``k`` nearest target frames.

        Args:
            source_features: Tensor of shape ``(n_source, dim)``.
            target_features: Tensor of shape ``(n_target, dim)``; the pool to pick from.
            k: Number of neighbours to average (default 4). Clamped to
                ``n_target`` so short target recordings do not crash ``topk``.

        Returns:
            Tensor of shape ``(n_source, dim)``.

        Raises:
            ValueError: If ``k`` is below 1 or the target pool is empty.
        """
        n_target = target_features.size(0)
        if n_target == 0:
            raise ValueError("target_features must contain at least one frame")
        if k < 1:
            raise ValueError("k must be at least 1")
        k = min(k, n_target)

        synth_set = target_features
        dists = self.cosine_dist(source_features, target_features)
        best = dists.topk(k=k, largest=False, dim=-1)
        selected_features = synth_set[best.indices].mean(dim=1)
        return selected_features

    @torch.no_grad()
    def forward(self, text, target_wavs, save_file=False, lambda_value=1.0,
                output_dir='/content'):
        """Generate speech given text input and target speaker references.

        Runs without gradient tracking; this is an inference-only path.

        Args:
            text: The sentence to synthesize.
            target_wavs: Path or list of paths to target-speaker audio.
            save_file: When True, write the result to
                ``<output_dir>/generated_waveform_NN.wav`` and bump the counter.
            lambda_value: Interpolation weight; 1.0 uses only the kNN-selected
                target frames, 0.0 only the source frames.
            output_dir: Directory for saved files (default ``'/content'``);
                created if it does not exist.

        Returns:
            The loudness-adjusted waveform with a leading axis added.
        """
        # Convert text to SSL features using GlowTTS
        source_features = self.text_to_ssl.generate([text])
        print(f"Source features initial shape: {source_features.shape}")

        # Get target speaker features
        target_features = self.get_target_features(target_wavs)
        print(f"Target features shape: {target_features.shape}")

        # Ensure source_features has the right shape (sequence_length, feature_dim)
        if source_features.dim() == 3:
            source_features = source_features.squeeze(0)
        print(f"Source features shape before KNN: {source_features.shape}")

        # Perform KNN retrieval
        selected_features = self.KNN(source_features, target_features)
        print(f"Selected features shape: {selected_features.shape}")

        # Linear interpolation
        converted_features = lambda_value * selected_features + (1 - lambda_value) * source_features

        # The vocoder is fed (batch, seq_len, feature_dim): only a batch axis is
        # added here, the features are NOT transposed.
        converted_features = converted_features.unsqueeze(0)
        print(f"Converted features shape for vocoder: {converted_features.shape}")

        # Generate waveform using HiFi-GAN
        generated_waveform = self.vocoder(converted_features.to(self.device)).squeeze()

        # Normalize loudness: apply the gain that moves the measured loudness to -16
        src_loudness = torchaudio.functional.loudness(generated_waveform[None], 16000)
        tgt_loudness = -16
        generated_waveform = torchaudio.functional.gain(generated_waveform,
                                                      tgt_loudness - src_loudness)

        if save_file:
            os.makedirs(output_dir, exist_ok=True)
            # :02d gives the same zero-padded names as before (00..09, 10, 11, ...).
            output_path = os.path.join(output_dir, f'generated_waveform_{self.i:02d}.wav')
            torchaudio.save(output_path, generated_waveform.unsqueeze(0).cpu(), sample_rate=16000)
            print(f"Generated waveform saved at {output_path}")
            self.i += 1  # Increment self.i after saving the waveform

        return generated_waveform.unsqueeze(0)

Example Usage\
The TextToSSL (GlowTTS) model needs only a text input to generate WavLM features. The TTSPipeline takes text and a target audio file and produces an audio file. Feel free to download the LJSpeech dataset and use its recordings as the target audio.

In [None]:
# Download the LJSpeech 1.1 archive.
# NOTE(review): --no-check-certificate turns off TLS certificate verification for this
# download — confirm it is still required and drop it if the host's certificate validates.
!wget --no-check-certificate https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2

# Extract the dataset (bzip2-compressed tar) into ./LJSpeech-1.1
!tar -xjf LJSpeech-1.1.tar.bz2

# Verify the extraction by listing the contents
!ls LJSpeech-1.1

In [7]:
# Stand-alone text front-end: no vocoder or target speaker involved
text_to_ssl = TextToSSL()

# One short and one long prompt, generated together as a single batch
text_inputs = [
    "Hello world",
    "We've enhanced Connect with new features to manage and pay for multiple users from one account. These features were previously available only in beta. As part of this update, we're also revising our terms and conditions.",
]
features = text_to_ssl.generate(text_inputs)

In [None]:
# Build the full text -> speech pipeline
pipeline = TTSPipeline()

# Target speaker audio; swap in LJSpeech wavs if you decide to use that dataset
#target_wavs = [path_to_wav1, path_to_wav2, ...]
targ = "/content/test.wav"

# Synthesize the sentence in the target speaker's voice and save it to disk
output = pipeline(
    "We've enhanced Connect with new features to manage and pay for multiple users from one account. These features were previously available only in beta. As part of this update, we're also revising our terms and conditions.",
    targ,
    save_file=True,
)

Loading WavLM model to cuda...
WavLM model loaded successfully!


Using cache found in /root/.cache/torch/hub/bshall_knn-vc_master


Removing weight norm...
[HiFiGAN] Generator loaded with 16,523,393 parameters.
WavLM-Large loaded with 315,453,120 parameters.
Source features initial shape: torch.Size([1, 220, 1024])
Target features shape: torch.Size([301, 1024])
Source features shape before KNN: torch.Size([220, 1024])
Selected features shape: torch.Size([220, 1024])
Converted features shape for vocoder: torch.Size([1, 220, 1024])
Generated waveform saved at /content/generated_waveform_00.wav
