This notebook converts a pre-trained HiFi-GAN vocoder model to TensorFlow Lite (TFLite) format, going PyTorch → ONNX → TensorFlow SavedModel → TFLite.

## Clone the Repository

In [None]:
# Clone the official HiFi-GAN repo; its `models` and `meldataset` modules are imported further down.
!git clone https://github.com/jik876/hifi-gan.git
%cd hifi-gan/
!pip install -r requirements.txt
# Installed after requirements.txt on purpose, so this pinned numba version is what ends up installed.
!pip install numba==0.48
# Nightly version needed for tflite inference
# NOTE(review): an unpinned nightly is not reproducible; the outputs below were produced with 2.5.0-dev20201230.
!pip install tf-nightly

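**Runtime restart required:** restart the runtime before continuing so the newly installed packages (`numba`, `tf-nightly`) are loaded.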

## Install ONNX Tooling and Imports

In [None]:
!pip install onnx
!pip install onnxruntime
!pip install pip install git+https://github.com/onnx/onnx-tensorflow.git

In [2]:
import json
import IPython
import os
import shutil
import torch
import onnxruntime
import onnx
import numpy as np
import tensorflow as tf

%cd hifi-gan/

from onnx_tf.backend import prepare
from models import Generator
from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav


# Download Hifi-GAN V3
!gdown --id 18TNnHbr4IlduAWdLrKcZrqmbfPOed1pS -O generator_v3

# Download sample audio file
!wget https://storage.googleapis.com/demo-experiments/demo_tts.wav

# Download config file
!wget https://raw.githubusercontent.com/jik876/hifi-gan/master/config_v3.json

/content/hifi-gan
Downloading...
From: https://drive.google.com/uc?id=18TNnHbr4IlduAWdLrKcZrqmbfPOed1pS
To: /content/hifi-gan/generator_v3
5.87MB [00:00, 51.8MB/s]
--2020-12-31 08:47:05--  https://storage.googleapis.com/demo-experiments/demo_tts.wav
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.137.128, 142.250.101.128, 2607:f8b0:4023:c03::80, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.137.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116780 (114K) [audio/wav]
Saving to: ‘demo_tts.wav’


2020-12-31 08:47:05 (110 MB/s) - ‘demo_tts.wav’ saved [116780/116780]

--2020-12-31 08:47:06--  https://raw.githubusercontent.com/jik876/hifi-gan/master/config_v3.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting respons

## Helper Functions

In [3]:
def get_mel(x):
    """Compute the mel spectrogram of waveform batch `x`.

    All STFT/mel settings are read from the global HiFi-GAN config `h`, so the
    spectrogram matches what the generator was trained on.
    """
    stft_params = (h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
                   h.win_size, h.fmin, h.fmax)
    return mel_spectrogram(x, *stft_params)

In [4]:
# HiFi-GAN V3 hyper-parameters; everything in this notebook runs on the CPU.
config = 'config_v3.json'
device = 'cpu'
with open(config) as config_file:
    data = config_file.read()

class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes.

    ``d.key`` and ``d['key']`` refer to the same storage, which lets the
    HiFi-GAN code access config values as ``h.n_fft`` etc.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the instance namespace to the dict itself so attribute access
        # and item access share one mapping.
        self.__dict__ = self

json_config = json.loads(data)
h = AttrDict(json_config)

# Load the pre-trained V3 generator weights onto the configured device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
torch_checkpoints = torch.load("generator_v3", map_location=torch.device(device))
torch_generator_weights = torch_checkpoints["generator"]
torch_model = Generator(h).to(device)
torch_model.load_state_dict(torch_generator_weights)

<All keys matched successfully>

## Generate a Mel Spectrogram from Audio


In [5]:
# Load the sample utterance (the cwd is the repo directory, where it was downloaded)
# and normalise the integer samples by MAX_WAV_VALUE before computing the mel.
wav, sr = load_wav('demo_tts.wav')
# The mel is computed with h.sampling_rate, so the audio must actually be at that rate.
assert sr == h.sampling_rate, f"expected {h.sampling_rate} Hz audio, got {sr} Hz"
wav = wav / MAX_WAV_VALUE
wav = torch.FloatTensor(wav).to(device)
x = get_mel(wav.unsqueeze(0))

In [6]:
torch_model.eval()
# torch's remove_weight_norm raises ValueError when the hook is already gone,
# so guard it to keep this cell re-runnable.
if not getattr(torch_model, '_weight_norm_removed', False):
    torch_model.remove_weight_norm()
    torch_model._weight_norm_removed = True

# Inference only: no autograd graph needed.
with torch.no_grad():
    hifigan_output = torch_model(x)

Removing weight norm...


## Play Audio

In [7]:
output = hifigan_output.squeeze()
audio = output.detach().numpy()

# Play back at the sampling rate from the model config rather than a hard-coded value.
IPython.display.display(IPython.display.Audio(audio, rate=h.sampling_rate))

## ONNX Conversion

#### Generate random input

In [8]:
# Fixed seed so the PyTorch vs. ONNX comparison below is reproducible.
torch.manual_seed(0)
# Random mel of shape (batch=1, num_mels, 100 frames); also serves as the ONNX tracing input.
x = torch.randn(1, h.num_mels, 100)
onnx_runtime_input = x.detach().numpy()
with torch.no_grad():
    torch_out = torch_model(x)
store_out = torch_out[0].detach().numpy()
print("Output size", torch_out[0].size())

Output size torch.Size([1, 25600])


In [9]:
# Export the model to ONNX. Only the time axes vary: the mel input is
# (batch, num_mels, frames) and the generator output is (batch, 1, samples),
# so the output's variable axis is 2 (axis 1 is the single audio channel).
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "hifigan.onnx",            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {2: 'seq_length'},      # variable-length mel frames
                                'output': {2: 'audio_length'}})  # variable-length waveform samples
print("Model converted succesfully")

Model converted succesfully


## ONNX Inference

In [10]:
# Select the CPU provider explicitly; newer onnxruntime GPU builds refuse to guess.
ort_session = onnxruntime.InferenceSession("hifigan.onnx", providers=['CPUExecutionProvider'])

def to_numpy(tensor):
    """Convert a PyTorch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU first,
    since ``Tensor.numpy()`` rejects tensors that require grad or live on an
    accelerator.
    """
    # (The original printed the whole tensor here — a leftover debug statement.)
    return tensor.detach().cpu().numpy()

# Run the exported ONNX graph on the same random mel fed to the PyTorch model.
ort_input_name = ort_session.get_inputs()[0].name
ort_inputs = {ort_input_name: onnx_runtime_input}
ort_outs = ort_session.run(None, ort_inputs)

#### Compare onnx output and PyTorch model output

In [11]:
# Drop the leading batch axis so the ONNX output lines up with store_out (= torch_out[0]).
onnx_output = np.squeeze(ort_outs[0], axis=0)
np.testing.assert_allclose(store_out, onnx_output, rtol=1e-3, atol=1e-4)

## Convert the ONNX Model to a TensorFlow SavedModel

In [12]:
# onnx-tf wraps the ONNX graph in a TensorFlow representation. export_graph writes
# a SavedModel *directory* (despite the ".pb" name), which the TFLite converter loads.
onnx_model = onnx.load('hifigan.onnx')
tf_rep = prepare(onnx_model)
tf_rep.export_graph('hifigan.pb')



INFO:tensorflow:Assets written to: hifigan.pb/assets


INFO:tensorflow:Assets written to: hifigan.pb/assets


## TFLite Conversion

In [13]:
def convert_tflite(quantization='dr', saved_model_dir='hifigan.pb'):
    """Convert the exported TensorFlow SavedModel to a quantized TFLite file.

    Args:
        quantization: 'dr' for dynamic-range quantization or 'float16' for
            float16 quantization. Also used in the output filename.
        saved_model_dir: SavedModel directory written by ``tf_rep.export_graph``
            (a directory despite the ``.pb`` suffix).

    Returns:
        Path of the written model, ``hifigan_{quantization}.tflite``.

    Raises:
        ValueError: if ``quantization`` is not a supported mode. Previously an
            unknown value silently produced a dynamic-range model under a
            misleading filename.
    """
    supported = ('dr', 'float16')
    if quantization not in supported:
        raise ValueError(
            f"Unsupported quantization {quantization!r}; expected one of {supported}")

    loaded = tf.saved_model.load(saved_model_dir)
    concrete_func = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    # Optimize.DEFAULT on its own gives dynamic-range quantization; adding float16
    # to the supported types switches it to float16 quantization.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantization == 'float16':
        converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()

    model_name = f'hifigan_{quantization}.tflite'
    with open(model_name, 'wb') as f:
        f.write(tflite_model)
    return model_name

In [14]:
# Dynamic Range Quantization
convert_tflite()
!du -sh hifigan_dr.tflite

3.5M	hifigan_dr.tflite


In [15]:
# Float16 Quantization
convert_tflite('float16')
!du -sh hifigan_float16.tflite

2.9M	hifigan_float16.tflite


In [16]:
# Record which TensorFlow build performed the conversion (tf is imported in the imports cell).
tf.__version__

'2.5.0-dev20201230'

## TFLite Inference

In [17]:
def tflite_inference(input, quantization='dr'):
    """Run a converted HiFi-GAN TFLite model on a mel spectrogram.

    Args:
        input: Mel spectrogram array; the cells in this notebook pass shape
            (1, num_mels, frames) from ``get_mel``. It is converted to a
            contiguous float32 array because ``Interpreter.set_tensor`` rejects
            other dtypes (e.g. float64).
        quantization: Which converted model to load, i.e. the file
            ``hifigan_{quantization}.tflite`` ('dr' or 'float16').

    Returns:
        The interpreter's first output tensor as a NumPy array.
    """
    model_name = f'hifigan_{quantization}.tflite'
    mel = np.ascontiguousarray(input, dtype=np.float32)

    interpreter = tf.lite.Interpreter(model_path=model_name)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # The model was exported with a dynamic frame axis; resize the input to the
    # actual mel shape before allocating tensors.
    interpreter.resize_tensor_input(input_details[0]['index'], list(mel.shape), strict=True)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], mel)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])

In [18]:
# Recompute the mel of the sample utterance as the TFLite input
# (the cwd is the repo directory, where demo_tts.wav was downloaded).
wav, sr = load_wav('demo_tts.wav')
# The mel uses h.sampling_rate, so the audio must actually be at that rate.
assert sr == h.sampling_rate, f"expected {h.sampling_rate} Hz audio, got {sr} Hz"
wav = wav / MAX_WAV_VALUE
wav = torch.FloatTensor(wav).to(device)
x = get_mel(wav.unsqueeze(0))
# NOTE: `input` shadows the builtin; kept because the inference cells below use this name.
input = x.detach().numpy()
input.shape

(1, 80, 228)

#### Dynamic Range Model inference

In [19]:
# Run the dynamic-range quantized model on the sample mel spectrogram.
output = tflite_inference(input, quantization='dr')

In [20]:
from IPython.display import Audio

output = output.squeeze()
# Play at the sampling rate from the model config instead of a hard-coded 22050.
Audio(output, rate=h.sampling_rate)

#### Float16 Model Inference

In [21]:
output = tflite_inference(input, quantization='float16')

output = output.squeeze()
# Play at the sampling rate from the model config instead of a hard-coded 22050.
Audio(output, rate=h.sampling_rate)