This notebook converts the pre-trained Forward Tacotron PyTorch model to ONNX. In the future it will be updated to support TFLite conversion.

## Acknowledgments

- The pre-trained model is taken from the [ForwardTacotron repository](https://github.com/as-ideas/ForwardTacotron) by Axel Springer.
- The model utilities and helper functions are also taken from the same [repository](https://github.com/as-ideas/ForwardTacotron).

## Setup

In [None]:
# Clone the ForwardTacotron repo; it includes the pretrained files under pretrained/ used below
!git clone https://github.com/as-ideas/ForwardTacotron.git

In [None]:
# Install requirements
%cd ForwardTacotron/  
!apt-get install espeak
!pip install -r requirements.txt

In [None]:
!pip install onnx
!pip install onnxruntime
!pip install pip install git+https://github.com/onnx/onnx-tensorflow.git

## Load Hyper-parameters and Pretrained Vocoders

In [None]:
# Load pretrained models
from pathlib import Path
from typing import Union, Callable, List

import numpy as np

import torch.nn as nn
import torch
import torch.nn.functional as F

from models.tacotron import CBHG
from utils.text import text_to_sequence, clean_text
from utils.text.symbols import phonemes
from utils import hparams as hp

from notebook_utils.synthesize import (
    get_forward_model, get_melgan_model, get_wavernn_model, synthesize, init_hparams)
from utils import hparams as hp
import IPython.display as ipd
# Load the pretrained hyper-parameters into `hp` before building any model
# (get_forward_model below reads its sizes from `hp`).
init_hparams('pretrained/pretrained_hparams.py')
# MelGAN vocoder: used in the Synthesize section to turn mels into audio.
voc_melgan = get_melgan_model() 
# NOTE(review): the WaveRNN vocoder is loaded but not used by the cells below;
# consider dropping it to save load time and memory.
voc_wavernn = get_wavernn_model('pretrained/wave_575K.pyt')

In [4]:
#@title Model Helper Functions

# Identity pitch transform. ForwardTacotron.forward declares its own
# `pitch_function` parameter (same identity default) that shadows this name,
# so this module-level value is not used by the model code shown here.
pitch_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x

class LengthRegulator(nn.Module):
    """Expand per-token encodings to frame rate by repeating each token by its duration.

    ``x`` has shape ``(batch, tokens, channels)`` and ``dur`` has shape
    ``(batch, tokens)``. Token ``j`` of row ``i`` is repeated over the frames
    between the truncated cumulative durations ``tot[i, j-1]`` and
    ``tot[i, j]``. Rows whose total is shorter than the longest row are padded
    with their last token.

    NOTE(review): the frame index is computed in NumPy. ``torch.onnx.export``
    traces the model, so this index is probably recorded as a constant derived
    from the example input's durations rather than as graph ops. TODO confirm
    the exported graph is correct for other inputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dur):
        return self.expand(x, dur)

    @staticmethod
    def build_index(duration, x):
        """Return a LongTensor ``(x.shape[0], max_frames, x.shape[2])`` of token indices.

        ``index[i, t, :]`` is the token that output frame ``t`` of row ``i``
        copies. Negative durations count as zero. The caller's ``duration``
        tensor is left unmodified.
        """
        # Clamp into a new tensor rather than masking in place, so the
        # caller's duration tensor is not rewritten as a side effect.
        duration = duration.clamp(min=0)
        # Cumulative end boundary (in whole frames) of every token.
        tot_duration = duration.cumsum(1).detach().cpu().numpy().astype('int')
        max_duration = int(tot_duration.max().item())
        index = np.zeros([x.shape[0], max_duration, x.shape[2]], dtype='long')
        frames = np.arange(max_duration)
        last_token = tot_duration.shape[1] - 1
        for i, boundaries in enumerate(tot_duration):
            # Boundaries are non-decreasing, so the token covering frame t is
            # the number of boundaries <= t. Frames beyond this row's total
            # length reuse the last token.
            token = np.minimum(np.searchsorted(boundaries, frames, side='right'), last_token)
            index[i] = token[:, None]
        return torch.LongTensor(index).to(duration.device)

    def expand(self, x, dur):
        idx = self.build_index(dur, x)
        y = torch.gather(x, 1, idx)
        return y


class SeriesPredictor(nn.Module):
    """Predict one scalar per sequence position (used for durations and pitch).

    Input ``x`` is ``(batch, length, in_dims)``. It passes through three
    ``BatchNormConv`` layers (kernel 5, ReLU, dropout after each), then a
    bidirectional GRU, then a linear layer to one value per position. The
    result ``(batch, length, 1)`` is divided by ``alpha``.
    """

    def __init__(self, in_dims, conv_dims=256, rnn_dims=64, dropout=0.5):
        super().__init__()
        channel_plan = [in_dims, conv_dims, conv_dims, conv_dims]
        self.convs = torch.nn.ModuleList([
            BatchNormConv(c_in, c_out, 5, activation=torch.relu)
            for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:])
        ])
        self.rnn = nn.GRU(conv_dims, rnn_dims, batch_first=True, bidirectional=True)
        # Bidirectional GRU -> forward and backward states are concatenated.
        self.lin = nn.Linear(2 * rnn_dims, 1)
        self.dropout = dropout

    def forward(self, x, alpha=1.0):
        # Conv1d wants channels on axis 1.
        h = x.transpose(1, 2)
        for layer in self.convs:
            h = F.dropout(layer(h), p=self.dropout, training=self.training)
        h, _ = self.rnn(h.transpose(1, 2))
        return self.lin(h) / alpha


class ConvResNet(nn.Module):
    """1-D convolution stack with residual connections.

    Input ``x`` is ``(batch, length, in_dims)``. One projecting
    ``BatchNormConv`` maps it to ``conv_dims`` channels. Two residual
    ``BatchNormConv`` layers follow (``h = h + conv(h)``). The output is
    ``(batch, length, conv_dims)``.
    """

    def __init__(self, in_dims, conv_dims=256):
        super().__init__()
        self.first_conv = BatchNormConv(in_dims, conv_dims, 5, activation=torch.relu)
        residual_layers = [BatchNormConv(conv_dims, conv_dims, 5, activation=torch.relu)
                           for _ in range(2)]
        self.convs = torch.nn.ModuleList(residual_layers)

    def forward(self, x):
        # Channels-first for Conv1d, then back to channels-last on return.
        h = self.first_conv(x.transpose(1, 2))
        for block in self.convs:
            h = h + block(h)
        return h.transpose(1, 2)

class BatchNormConv(nn.Module):
    """Bias-free 1-D convolution, optional activation, then batch normalization.

    Uses stride 1 and ``kernel // 2`` zero padding, so odd kernels keep the
    sequence length. The activation (if given and truthy) is applied *before*
    the batch norm; the pretrained weights depend on this order.
    """

    def __init__(self, in_channels, out_channels, kernel, activation=None):
        super().__init__()
        same_padding = kernel // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel,
                              stride=1, padding=same_padding, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.activation = activation

    def forward(self, x):
        y = self.conv(x)
        y = self.activation(y) if self.activation else y
        return self.bnorm(y)


class ForwardTacotron(nn.Module):
    """Forward Tacotron acoustic model: symbol ids in, mel spectrogram frames out.

    Stages in ``forward``:
    1. Symbol embedding.
    2. Duration and pitch predictors (``SeriesPredictor``).
    3. CBHG pre-net, with an optional pitch embedding concatenated to its output.
    4. ``LengthRegulator``, which repeats each symbol by its predicted duration.
    5. Bidirectional LSTM, then a linear projection to ``n_mels``.
    6. CBHG post-net, then a linear projection to ``n_mels``.

    Submodule attribute names become the state-dict keys, so they must match
    the pretrained checkpoint read by ``load``.
    """

    def __init__(self,
                 embed_dims,
                 num_chars,
                 durpred_conv_dims,
                 durpred_rnn_dims,
                 durpred_dropout,
                 pitch_conv_dims,
                 pitch_rnn_dims,
                 pitch_dropout,
                 pitch_emb_dims,
                 pitch_proj_dropout,
                 rnn_dim,
                 prenet_k,
                 prenet_dims,
                 postnet_k,
                 postnet_dims,
                 highways,
                 dropout,
                 n_mels):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.lr = LengthRegulator()
        # Predicts one duration value per input symbol.
        self.dur_pred = SeriesPredictor(embed_dims,
                                        conv_dims=durpred_conv_dims,
                                        rnn_dims=durpred_rnn_dims,
                                        dropout=durpred_dropout)
        # Predicts one pitch value per input symbol.
        self.pitch_pred = SeriesPredictor(embed_dims,
                                          conv_dims=pitch_conv_dims,
                                          rnn_dims=pitch_rnn_dims,
                                          dropout=pitch_dropout)
        self.prenet = CBHG(K=prenet_k,
                           in_channels=embed_dims,
                           channels=prenet_dims,
                           proj_channels=[prenet_dims, embed_dims],
                           num_highways=highways)
        # LSTM input = pre-net features (2 * prenet_dims; presumably the CBHG's
        # bidirectional output, TODO confirm) + the pitch embedding channels.
        self.lstm = nn.LSTM(2 * prenet_dims + pitch_emb_dims,
                            rnn_dim,
                            batch_first=True,
                            bidirectional=True)
        self.lin = torch.nn.Linear(2 * rnn_dim, n_mels)
        # Step counter kept as a buffer so it is saved/loaded with the state dict.
        self.register_buffer('step', torch.zeros(1, dtype=torch.long))
        self.postnet = CBHG(K=postnet_k,
                            in_channels=n_mels,
                            channels=postnet_dims,
                            proj_channels=[postnet_dims, n_mels],
                            num_highways=highways)
        self.dropout = dropout
        self.post_proj = nn.Linear(2 * postnet_dims, n_mels, bias=False)
        self.pitch_emb_dims = pitch_emb_dims
        # Pitch conditioning is optional. When pitch_emb_dims is not positive,
        # no projection is created and forward() skips the concatenation.
        if pitch_emb_dims > 0:
            self.pitch_proj = nn.Sequential(
                nn.Conv1d(1, pitch_emb_dims, kernel_size=3, padding=1),
                nn.Dropout(pitch_proj_dropout))

    def forward(self,
                 x: List[int],
                 alpha=1.0,
                 pitch_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x) -> tuple:
        """Synthesize mel frames for a single utterance.

        Args:
            x: symbol ids for one utterance (a list or 1-D tensor). A batch
                dimension of 1 is added here.
            alpha: the predicted durations are divided by this value (see
                ``SeriesPredictor.forward``), so values > 1 shrink positive
                durations.
            pitch_function: applied to the predicted pitch, shape
                ``(1, 1, num_symbols)``, before it is embedded.

        Returns:
            ``(mel, mel_post, pitch)``:
            - ``mel`` and ``mel_post`` are the mel before and after the
              post-net. Both are squeezed; for normal inputs the shape is
              ``(n_mels, frames)``, presumably.
            - ``pitch`` is the predicted pitch, not squeezed.
        """
        # Always run in inference mode (dropout off, batch norm on running
        # stats), including while torch.onnx.export traces this method.
        self.eval()
        device = next(self.parameters()).device  # use same device as parameters
        x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0)

        x = self.embedding(x)
        dur = self.dur_pred(x, alpha=alpha)
        dur = dur.squeeze(2)

        # (1, num_symbols, 1) -> (1, 1, num_symbols): channels-first for pitch_proj's Conv1d.
        pitch_hat = self.pitch_pred(x).transpose(1, 2)
        pitch_hat = pitch_function(pitch_hat)

        x = x.transpose(1, 2)
        x = self.prenet(x)

        if self.pitch_emb_dims > 0:
            # Embed pitch with a 1-D conv and append it as extra feature channels.
            pitch_hat_proj = self.pitch_proj(pitch_hat).transpose(1, 2)
            x = torch.cat([x, pitch_hat_proj], dim=-1)

        # Upsample from symbol rate to frame rate using the predicted durations.
        x = self.lr(x, dur)

        x, _ = self.lstm(x)
        # No-op here: self.training is False after self.eval() above.
        x = F.dropout(x,
                      p=self.dropout,
                      training=self.training)
        x = self.lin(x)
        x = x.transpose(1, 2)

        x_post = self.postnet(x)
        x_post = self.post_proj(x_post)
        x_post = x_post.transpose(1, 2)

        # NOTE(review): squeeze() drops every size-1 dim, not only the batch dim;
        # squeeze(0) would be safer for 1-frame inputs. `dur` is computed but not returned.
        x, x_post, dur = x.squeeze(), x_post.squeeze(), dur.squeeze()
        # x = x.cpu().data.numpy()
        # x_post = x_post.cpu().data.numpy()
        # dur = dur.cpu().data.numpy()

        return x, x_post, pitch_hat

    def pad(self, x, max_len):
        # Truncate or right-pad the last axis to exactly max_len.
        # -11.5129 ~= ln(1e-5), presumably the log-mel value used for silence.
        x = x[:, :, :max_len]
        x = F.pad(x, [0, max_len - x.size(2), 0, 0], 'constant', -11.5129)
        return x

    def get_step(self):
        # Value of the `step` buffer as a Python number.
        return self.step.data.item()

    def load(self, path: Union[str, Path]):
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        state_dict = torch.load(path, map_location=device)
        # strict=False: missing/unexpected keys are ignored instead of raising,
        # so a mismatched checkpoint loads partially without any error.
        self.load_state_dict(state_dict, strict=False)

    def save(self, path: Union[str, Path]):
        # No optimizer argument because saving a model should not include data
        # only relevant in the training process - it should only be properties
        # of the model itself. Let caller take care of saving optimizer state.
        torch.save(self.state_dict(), path)

    def log(self, path, msg):
        # Append `msg` as one line to the text file at `path`.
        with open(path, 'a') as f:
            print(msg, file=f)


In [5]:
def get_forward_model(model_path, device=None):
    """Build a ForwardTacotron from the global hparams and load its weights.

    This definition shadows the ``get_forward_model`` imported from
    ``notebook_utils.synthesize``. It builds the ``ForwardTacotron`` class
    defined in this notebook.

    Args:
        model_path: path of the checkpoint (state dict) to load.
        device: device to place the model on. Defaults to CUDA when it is
            available and to the CPU otherwise, so CPU-only machines work too.

    Returns:
        The ForwardTacotron model with the checkpoint weights loaded.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ForwardTacotron(embed_dims=hp.forward_embed_dims,
                            num_chars=len(phonemes),
                            durpred_rnn_dims=hp.forward_durpred_rnn_dims,
                            durpred_conv_dims=hp.forward_durpred_conv_dims,
                            durpred_dropout=hp.forward_durpred_dropout,
                            pitch_rnn_dims=hp.forward_pitch_rnn_dims,
                            pitch_conv_dims=hp.forward_pitch_conv_dims,
                            pitch_dropout=hp.forward_pitch_dropout,
                            pitch_emb_dims=hp.forward_pitch_emb_dims,
                            pitch_proj_dropout=hp.forward_pitch_proj_dropout,
                            rnn_dim=hp.forward_rnn_dims,
                            postnet_k=hp.forward_postnet_K,
                            postnet_dims=hp.forward_postnet_dims,
                            prenet_k=hp.forward_prenet_K,
                            prenet_dims=hp.forward_prenet_dims,
                            highways=hp.forward_num_highways,
                            dropout=hp.forward_dropout,
                            n_mels=hp.num_mels).to(device)
    # load() maps the checkpoint onto the device the parameters now live on.
    model.load(model_path)
    return model

## Load Model

In [None]:
# Build ForwardTacotron from the pretrained hparams and load the checkpoint.
tts_model = get_forward_model('pretrained/forward_46K.pyt')

# Inference mode (ForwardTacotron.forward also calls eval() on every call).
tts_model.eval()

In [8]:
input_text = 'Checking the quality of Forward Tacotron'

# Normalize the text and convert it to the model's symbol ids.
text = clean_text(input_text.strip())
# 1-D int64 tensor of symbol ids: the example input for export and inference.
x = torch.as_tensor(text_to_sequence(text), dtype=torch.long)
x.size()

torch.Size([39])

## Export to ONNX

In [None]:
# NOTE(review): LengthRegulator builds its frame index in NumPy, so the traced
# graph probably bakes in the durations predicted for this example input.
# TODO confirm the exported model is correct for other inputs before relying on
# the dynamic axes below.
torch.onnx.export(tts_model,                 # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "forward_tac.onnx",        # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=12,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  # mel before post-net, mel after post-net, predicted pitch
                  output_names=['output1', 'output2', 'output4'],
                  # Variable-length axes:
                  # - x is 1-D, so the text length is axis 0.
                  # - The mel outputs vary in frame count on axis 1.
                  # - The pitch output (1, 1, seq) varies on axis 2.
                  dynamic_axes={'input': {0: 'seq_length'},
                                'output1': {1: 'mel_length'},
                                'output2': {1: 'mel_length'},
                                'output4': {2: 'seq_length'}})
print("Model converted succesfully")

## ONNX Model Inference

In [11]:
import onnxruntime

# ONNX Runtime takes NumPy inputs; x is a CPU int64 tensor.
onnx_runtime_input = x.detach().numpy()
# Request the CPU provider explicitly: GPU builds of onnxruntime (>= 1.9) raise
# if `providers` is omitted, and the CPU provider exists in every build.
ort_session = onnxruntime.InferenceSession("forward_tac.onnx",
                                           providers=['CPUExecutionProvider'])

def to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached from autograd and moved to the CPU."""
    return tensor.detach().cpu().numpy()

# Compute the ONNX Runtime prediction. run(None, ...) returns every graph
# output in declaration order: [output1, output2, output4].
ort_inputs = {ort_session.get_inputs()[0].name: onnx_runtime_input}
ort_outs = ort_session.run(None, ort_inputs)

## Synthesize

#### ONNX Model Output

In [None]:
# Vocode the ONNX post-net mel (output2): add a batch axis and move it to the GPU.
m = torch.tensor(ort_outs[1])[None, ...].cuda()
with torch.no_grad():
    wav = voc_melgan.inference(m)
    wav = wav.cpu().numpy()
ipd.Audio(wav, rate=hp.sample_rate)

#### PyTorch Model Output

In [25]:
# Run the PyTorch model on the same input; no autograd graph is needed for inference.
with torch.no_grad():
    _, torch_out, _ = tts_model(x, alpha=1)
    # Add the batch axis the vocoder expects: (1, n_mels, frames).
    torch_out = torch_out.unsqueeze(0)
    # Synthesize with the same MelGAN vocoder used for the ONNX output.
    wav = voc_melgan.inference(torch_out).cpu().numpy()
ipd.Audio(wav, rate=hp.sample_rate)

## Compare PyTorch and ONNX Output

In [13]:
# Compare the PyTorch post-net mel (torch_out) with the ONNX Runtime one (output2).
# Compare against torch_out rather than a tensor derived from ort_outs, otherwise
# the check is vacuous. Drop the batch axis added before vocoding, and detach so
# the tensor can be converted to NumPy.
np.testing.assert_allclose(torch_out.detach().cpu().numpy().squeeze(0), ort_outs[1],
                           rtol=1e-03, atol=1e-04)

In [None]:
# import onnx
# from onnx_tf.backend import prepare

# onnx_model = onnx.load('forward_tac.onnx')
# tf_rep = prepare(onnx_model)
# tf_rep.export_graph('forward_tac.pb')