In [1]:
import tensorflow as tf
from IPython.display import Audio, display
import torchaudio
import numpy as np
import torch
from tqdm import tqdm

from main import SoundStreamEncoder, load_weights

# Configuration: every magic number used by this cell lives here.
SAMPLE_RATE = 16000   # Hz; target rate for the resampled clip
FRAME_SIZE = 320      # samples per encoder call (20 ms at SAMPLE_RATE)
CLIP_START_S = 3      # start of the excerpt to process, in seconds
CLIP_END_S = 8        # end of the excerpt (exclusive), in seconds

# Re-implemented encoder: build it for one (1, FRAME_SIZE, 1, 1) frame, then
# let `main.load_weights` fill it in (presumably with the reference Lyra
# weights — see main.py).
my_model = SoundStreamEncoder()
my_model.build([1, FRAME_SIZE, 1, 1])
my_model = load_weights(my_model)

# Load the sample (torchaudio returns (channels, samples)), resample it,
# keep [CLIP_START_S, CLIP_END_S) and downmix to mono while keeping a leading
# axis, so `wav` ends up with shape (1, n_samples).
wav, sr = torchaudio.load('./sample.wav')
wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
wav = wav[:, CLIP_START_S * SAMPLE_RATE: CLIP_END_S * SAMPLE_RATE]
wav = torch.mean(wav, 0)[None]
# Fail here with a readable message instead of deep inside run_model.
assert wav.shape[1] > 0, f"clip is empty: sample.wav is shorter than {CLIP_START_S} s"

def run_model(wav, my=False, num_quantizers=16):
    """Run ``wav`` through the Lyra TFLite pipeline frame by frame and return the decoded audio.

    Each 320-sample frame is embedded either by the reference encoder in
    ``soundstream_encoder.tflite`` (``my=False``) or by the notebook-level
    ``my_model`` (a ``main.SoundStreamEncoder``, ``my=True``). The embedding
    is then passed through the ``encode`` and ``decode`` signatures of
    ``quantizer.tflite`` and finally through the decoder in ``lyragan.tflite``.
    All interpreters are created inside this call, so separate calls never
    share interpreter objects.

    Args:
        wav: Audio of shape (1, n_samples), e.g. a torch tensor from
            torchaudio or anything else ``np.asarray`` accepts. It is
            converted to a float32 numpy array before use.
        my: If True, embed frames with ``my_model`` instead of the TFLite encoder.
        num_quantizers: Value passed (as a 1-element array) to the quantizer's
            ``num_quantizers`` input. Defaults to 16, the value previously
            hard-coded here.

    Returns:
        numpy.ndarray: the per-frame decoder outputs concatenated along axis 1.
        A trailing partial frame is zero-padded to 320 samples before
        encoding, so its decoded output also covers that padding.

    Raises:
        ValueError: if ``wav`` contains no samples.
    """
    frame_size = 320  # samples per encoder call; my_model is built for (1, 320, 1, 1)

    # TFLite expects numpy inputs. torchaudio.load yields float32 by default,
    # so for the notebook's input the values are unchanged.
    audio = np.asarray(wav, dtype=np.float32)
    n_samples = audio.shape[1]
    if n_samples == 0:
        raise ValueError("run_model: wav contains no samples")

    encoder = tf.lite.Interpreter('./soundstream_encoder.tflite').get_signature_runner()
    # One interpreter can serve both quantizer signatures; no need to load the file twice.
    quantizer = tf.lite.Interpreter('./quantizer.tflite')
    quant_encode = quantizer.get_signature_runner('encode')
    quant_decode = quantizer.get_signature_runner('decode')
    decoder = tf.lite.Interpreter('./lyragan.tflite').get_signature_runner()
    quantizer_count = np.array([num_quantizers])

    decoded_frames = []
    progress_label = 'my encoder' if my else 'tflite encoder'
    for start in tqdm(range(0, n_samples, frame_size), desc=progress_label):
        frame = audio[:, start:start + frame_size]
        short_by = frame_size - frame.shape[1]
        if short_by > 0:
            # Zero-pad a trailing partial frame: the reshape below needs exactly
            # frame_size samples, and the TFLite encoder then receives the same
            # input shape on every call.
            frame = np.pad(frame, ((0, 0), (0, short_by)))

        if my:
            embeddings = my_model(frame.reshape(1, frame_size, 1, 1))
        else:
            embeddings = encoder(input_audio=frame)['output_0']

        codes = quant_encode(input_frames=embeddings, num_quantizers=quantizer_count)['output_0']
        quantized = quant_decode(encoding_indices=codes)['output_0']
        decoded_frames.append(decoder(input_audio=quantized)['output_0'])

    return np.concatenate(decoded_frames, 1)

In [2]:
# Reference run: frames are embedded by the shipped TFLite encoder (my=False default).
lyra_decoded_audio = run_model(wav=wav)
# Comparison run: same pipeline, with the re-implemented encoder `my_model` swapped in.
my_decoded_audio = run_model(wav=wav, my=True)

100%|█████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 145.15it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:06<00:00, 41.17it/s]


In [3]:
# Wrap each signal in a 16 kHz IPython player (same construction order as before)...
lyra_obj = Audio(data=lyra_decoded_audio, rate=16000)
my_obj = Audio(data=my_decoded_audio, rate=16000)
orig_obj = Audio(data=wav, rate=16000)
# ...then render the original first, followed by the two reconstructions to compare against it.
display(orig_obj, lyra_obj, my_obj)