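__This notebook__ shows how to run inference with WaveStudent models trained by `distill_waveglow.py`. It synthesizes test-set audio with the WaveGlow teacher and with each student, then measures their inference speed. It requires the same environment that was used for training.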

In [1]:
%env CUDA_VISIBLE_DEVICES=2
%load_ext autoreload
%autoreload 2

from IPython.display import Audio, display, clear_output
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch.utils.data as data
import torch.nn.functional as F
import torch.nn as nn
import torch

import numpy as np
import argparse
import math
import os

import librosa # was: 0.7.2
from pathlib import Path
from scipy.io import wavfile

import sys
sys.path.append('..')

from mel2samp import Mel2Samp, load_wav_to_torch
from models import defaults
from denoiser import StudentDenoiser, TeacherDenoiser

# Seed the NumPy and PyTorch (CPU + current CUDA device) RNGs so sampled inputs are reproducible.
SEED = 1337
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

env: CUDA_VISIBLE_DEVICES=7


In [2]:
def save_wav(audio, path):
    """Write ``audio`` to ``path`` as 16-bit PCM, level-normalized with ``sox ... norm -1``.

    Args:
        audio: 1-D float NumPy array (as produced by the generation cell); values
            outside [-1, 1] are clipped before conversion to int16.
        path: destination ``.wav`` path.

    Raises:
        subprocess.CalledProcessError: if ``sox`` exits with a non-zero status.
        FileNotFoundError: if the ``sox`` executable is not on ``PATH``.
    """
    import subprocess
    import tempfile

    # Use a private temporary file instead of a fixed "tmp.wav" in the working
    # directory, so concurrent runs cannot clobber each other (or a user's file).
    fd, tmp_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        wavfile.write(
            tmp_path, defaults.SAMPLING_RATE,
            (np.clip(audio, -1, 1) * 32767).astype(np.int16)
        )
        # Argument list, no shell: paths with spaces or shell metacharacters are safe,
        # and check=True surfaces sox failures instead of silently skipping the file.
        subprocess.run(['sox', tmp_path, str(path), 'norm', '-1'], check=True)
    finally:
        # Always remove the intermediate file, even if sox fails.
        os.remove(tmp_path)

In [3]:
from torch.utils.data import DataLoader

# STFT / mel settings come from `models.defaults` (presumably the configuration the
# models were trained with -- confirm before changing them).
dataset_kwargs = {
    'split': 'test',
    'segment_length': None,  # NOTE(review): presumably None keeps whole utterances -- confirm in Mel2Samp
    'filter_length': defaults.STFT_FILTER_LENGTH,
    'hop_length': defaults.STFT_HOP_LENGTH,
    'win_length': defaults.STFT_WIN_LENGTH,
    'sampling_rate': defaults.SAMPLING_RATE,
    'mel_fmin': defaults.MEL_FMIN,
    'mel_fmax': defaults.MEL_FMAX,
}
test_dataset = Mel2Samp("../data/wavs", **dataset_kwargs)

# One utterance per batch, in file order: the generation loop maps batch i back to
# test_dataset.audio_files[i], which relies on both of these settings.
test_loader_kwargs = {
    'batch_size': 1,
    'shuffle': False,
    'num_workers': 1,
    'pin_memory': False,
}
test_loader = DataLoader(test_dataset, **test_loader_kwargs)

In [4]:
# Paths, relative to this notebook's directory.
model_path = '../pretrained_models'  # teacher / student checkpoints
sample_path = '../samples'           # root directory for generated audio

# Inference settings.
sigma = 1.0            # passed to teacher.sample_inputs_for (presumably the noise std)
device = 'cuda'
dtype = torch.float16  # models and inputs are run in half precision

### Create teacher model

In [5]:
from models.waveglow_teacher import WaveGlowTeacher

# Load the WaveGlow teacher (fp16 when the notebook-wide dtype is float16), switch it
# to inference mode and move it to the target device.
teacher_path = os.path.join(model_path, 'wg_teacher_ch256_wn12.pth')
teacher = WaveGlowTeacher.load(teacher_path, fp16=dtype == torch.float16)
teacher = teacher.train(False).to(device)
# Denoiser applied to teacher outputs in the generation and benchmark cells.
teacher_denoiser = TeacherDenoiser(teacher, mode='zeros')

Cast WaveGlow to fp16


### Create student models

In [6]:
from denoiser import StudentDenoiser
from models.students import WaveNetStudent, WideFlowStudent, FlowStudent, AffineStudent


def create_student(model_path, ch=96, wn=4, student_arch='wg'):
    """Build a distilled student, load its checkpoint and prepare it for inference.

    Args:
        model_path: directory containing checkpoints named
            ``{student_arch}_student_ch{ch}_wn{wn}.pth``.
        ch: hidden channels of the student (``hid_channels``).
        wn: number of WaveNets in the student (``n_wavenets``).
        student_arch: one of ``'wg'``, ``'flow'``, ``'wide_flow'``, ``'affine'``.

    Returns:
        The student with weight norm removed, in eval mode, cast to the global
        ``dtype`` and moved to the global ``device``.

    Raises:
        ValueError: if ``student_arch`` is not a supported architecture.
    """
    supported_archs = ('wg', 'flow', 'wide_flow', 'affine')
    if student_arch not in supported_archs:
        # Fail early with a clear message instead of a confusing error further down.
        raise ValueError(
            f"unknown student_arch {student_arch!r}; expected one of {supported_archs}")

    student_cls = {
        'wg': WaveNetStudent,
        'flow': FlowStudent,
        'wide_flow': WideFlowStudent,
        'affine': AffineStudent,
    }[student_arch]
    # All student architectures are built with the same constructor arguments.
    student = student_cls(in_channels=8, mel_channels=640, hid_channels=ch, n_wavenets=wn,
                          wavenet_layers=8, kernel_size=3).to(device).train(True)

    ckpt_path = os.path.join(model_path, f'{student_arch}_student_ch{ch}_wn{wn}.pth')
    # map_location: a checkpoint saved on another GPU index still loads when only one
    # device is visible. torch.load unpickles -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    student.load_state_dict(checkpoint['state_dict'])

    if student_arch in ('flow', 'affine'):
        # FIXME: do not add dummy params
        student.in_proj = nn.Linear(1, 1).to(device).to(dtype)

    def remove_all(layers):
        # remove_weight_norm strips the reparametrization and returns the same module.
        return torch.nn.ModuleList(torch.nn.utils.remove_weight_norm(layer) for layer in layers)

    wavenets = student.WN if hasattr(student, "WN") else student.wavenets
    # Loop variable is `wavenet`, not `wn`, so it does not shadow the `wn` argument.
    for wavenet in wavenets:
        wavenet.start = torch.nn.utils.remove_weight_norm(wavenet.start)
        wavenet.in_layers = remove_all(wavenet.in_layers)
        # Assign back to the attribute that was read; a different name would register
        # the same conv a second time as an extra submodule.
        wavenet.cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layer)
        wavenet.res_skip_layers = remove_all(wavenet.res_skip_layers)
    return student.train(False).to(dtype).to(device)

In [7]:
# (student_arch, hidden channels, number of WaveNets) for every student to evaluate.
# The dict key -- also used as the output sub-directory name -- is derived from these
# values, so it always matches the checkpoint it was loaded from.
student_configs = [
    ('wg', 96, 2),
    ('wg', 96, 4),
    ('wg', 128, 4),
    ('flow', 96, 4),
    ('wide_flow', 96, 4),
    ('affine', 96, 4),
]
students = {
    f'{arch}_ch{ch}_wn{n_wavenets}': create_student(model_path, ch=ch, wn=n_wavenets, student_arch=arch)
    for arch, ch, n_wavenets in student_configs
}

# One denoiser per student, each built with the shared teacher.
student_denoisers = {
    key: StudentDenoiser(student, teacher, mode='zeros')
    for key, student in students.items()
}

In [8]:
import shutil

DENOISER_STRENGTH = 0.004  # same denoising strength for the teacher and every student
MAX_RESAMPLES = 100        # upper bound on teacher re-sampling per utterance


def _out_of_range(audio):
    """Return True if any sample of ``audio`` lies outside [-1, 1]."""
    return bool((audio < -1).any() or (audio > 1).any())


# Include sigma in the directory name so runs with different sigmas do not overwrite each other.
result_path = os.path.join(sample_path, f'release_sigma_{sigma:g}')
# Start from an empty output directory (pure-Python equivalent of `rm -rf` + `mkdir -p`).
shutil.rmtree(result_path, ignore_errors=True)
os.makedirs(result_path, exist_ok=True)

torch.cuda.empty_cache()
# test_loader uses shuffle=False and batch_size=1, so batch i is test_dataset.audio_files[i].
for sample_i, (mel, _) in enumerate(tqdm(test_loader)):
    # Do not reuse `sample_path` here: it is the output root used to build `result_path`,
    # and overwriting it would break re-running this cell.
    audio_path = test_dataset.audio_files[sample_i]
    sample_id = str(audio_path).split('/')[-1].split('.')[0]
    mel = mel.to(dtype).to(device)

    with torch.no_grad():
        ### Teacher ###
        # Re-sample until the raw teacher output fits in [-1, 1]; the students then
        # receive exactly the same inputs.
        inputs = teacher.sample_inputs_for(mel, sigma=sigma)
        teacher_prediction = teacher(*inputs)
        n_resamples = 0
        while _out_of_range(teacher_prediction) and n_resamples < MAX_RESAMPLES:
            inputs = teacher.sample_inputs_for(mel, sigma=sigma)
            teacher_prediction = teacher(*inputs)
            n_resamples += 1
        if _out_of_range(teacher_prediction):
            tqdm.write(f"{sample_id}: teacher output still outside [-1, 1] after "
                       f"{MAX_RESAMPLES} re-samples; it will be clamped")

        teacher_prediction = teacher_denoiser(teacher_prediction, strength=DENOISER_STRENGTH)
        teacher_audio = teacher_prediction.reshape(-1).clamp(-1, 1).cpu().float().numpy()

        ### Students ###
        student_mel = inputs[0]
        # Identical for every student, so build it once per utterance.
        student_input = torch.cat(inputs[1:], dim=1)
        student_audios = {}
        for key, student in students.items():
            student_prediction = student(student_input, student_mel).permute(0, 2, 1).flatten(1)
            student_prediction = student_denoisers[key](student_prediction, strength=DENOISER_STRENGTH)
            student_audios[key] = student_prediction.reshape(-1).clamp(-1, 1).cpu().float().numpy()

    # One sub-directory per model: "teacher" first, then each student key.
    all_audios = {'teacher': teacher_audio, **student_audios}
    for model_name, audio in all_audios.items():
        model_dir = os.path.join(result_path, model_name)
        os.makedirs(model_dir, exist_ok=True)
        save_wav(audio, os.path.join(model_dir, f"{sample_id}.wav"))

  sampling_rate, data = read(full_path)


  0%|          | 0/50 [00:00<?, ?it/s]

  sampling_rate, data = read(full_path)


### Estimate inference speed in MHz (millions of generated audio samples per second)

In [None]:
@torch.no_grad()
def get_teacher_inference_speed(model, denoiser, repetitions=50):
    model.eval()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings=np.zeros((repetitions,1))
    
    mel = test_dataset[0][0]
    inputs = model.sample_inputs_for(mel.unsqueeze(0).half().cuda())
    inputs = [input.repeat(32, 1, 2) for input in inputs]
    
    #GPU-WARM-UP
    for _ in range(10):
        output = denoiser(model(*inputs))
    num_samples = np.prod(output.shape)

    # MEASURE PERFORMANCE
    for rep in tqdm(range(repetitions)):
        starter.record()
        denoiser(model(*inputs))
        ender.record()
        # WAIT FOR GPU SYNC
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    mhz = num_samples / mean_syn / 1000
    return mhz

@torch.no_grad()
def get_student_inference_speed(model, denoiser, repetitions=50):
    model.eval()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings=np.zeros((repetitions,1))
    
    mel = test_dataset[0][0]
    inputs = teacher.sample_inputs_for(mel.unsqueeze(0).half().cuda())
    inputs = [input.repeat(32, 1, 2) for input in inputs]
    student_mel = inputs[0]
    student_input = torch.cat(inputs[1:], dim=1)
    
    #GPU-WARM-UP
    for _ in range(10):
        output = denoiser(model(student_input, student_mel).permute(0, 2, 1).flatten(1))
    num_samples = np.prod(output.shape)
    
    # MEASURE PERFORMANCE
    for rep in tqdm(range(repetitions)):
        starter.record()
        denoiser(model(student_input, student_mel).permute(0, 2, 1).flatten(1))
        ender.record()
        # WAIT FOR GPU SYNC
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    mhz = num_samples / mean_syn / 1000
    return mhz


@torch.no_grad()
def get_melgan_inference_speed(model, repetitions=50):
    model.mel2wav.eval()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings=np.zeros((repetitions,1))
    
    wav = test_dataset[0][1]
    melgan_mel = model(wav.unsqueeze(0)).repeat(32, 1, 2).half()
    
    #GPU-WARM-UP
    for _ in range(10):
        output = melgan.inverse(melgan_mel)
    num_samples = np.prod(output.shape)

    # MEASURE PERFORMANCE
    for rep in tqdm(range(repetitions)):
        starter.record()
        out = melgan.inverse(melgan_mel)
        ender.record()
        # WAIT FOR GPU SYNC
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    mhz = num_samples / mean_syn / 1000
    return mhz

In [None]:
# Throughput of the teacher followed by its denoiser.
mhz = get_teacher_inference_speed(model=teacher, denoiser=teacher_denoiser)
print("Teacher MHz:", mhz)

In [None]:
# Throughput of every student followed by its own denoiser.
for key, student_denoiser in student_denoisers.items():
    mhz = get_student_inference_speed(students[key], student_denoiser)
    print(f"{key} MHz: {mhz}")