Specify the model checkpoints (acoustic model and vocoder) in the config (`TrainConfig`) before running this notebook.

In [None]:
import argparse
import json
import os
from dataclasses import asdict
from functools import lru_cache

import torch
from lightning import seed_everything

from config.config import TrainConfig
from src.models import Generator, TorchSTFT
from src.models.acoustic_model.fastspeech.lightning_model import FastSpeechLightning
from src.utils.utils import set_up_logger, write_wav, crash_with_msg
from src.utils.vocoder_utils import load_checkpoint, synthesize_wav_from_mel

In [None]:
def load_data(datalist):
    """Read a ``|``-separated filelist into a list of field lists.

    Each non-blank line is stripped of surrounding whitespace and split on
    ``"|"``. Blank (or whitespace-only) lines are skipped; otherwise they
    would turn into ``['']`` rows that break downstream column handling.

    Args:
        datalist: path to a UTF-8 text file.

    Returns:
        list of lists of strings, one inner list per non-blank line.
    """
    with open(datalist, encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [s.split("|") for s in stripped if s]

import pandas as pd

all_data = load_data('path/to/all.txt')
# Full corpus listing; columns 'x', 'y', 'z' are not used in this cell.
df = pd.DataFrame(all_data, columns=['x','y','z', 'phone', 'text'])
# Map each lowercased transcript to the phone string of its FIRST occurrence,
# with the first and last characters dropped (presumably enclosing braces
# such as "{...}" -- TODO confirm against the filelist format).
# zip over the two columns avoids the per-row Series construction of iterrows.
text2phone = {}
for text, phone in zip(df['text'], df['phone']):
    text2phone.setdefault(text.lower(), phone[1:-1])

In [None]:
# Number of distinct lowercased transcripts collected in the previous cell (shown as cell output).
len(text2phone)

In [None]:
from torch.utils.data import Dataset
class ESDDataset(Dataset):
    """Inference filelist dataset.

    Each non-blank line of ``datalist`` is a ``|``-separated record. The
    fields read here are ``[0]`` audio path, ``[1]`` speaker id (int),
    ``[3]`` emotion label and ``[4]`` phone sequence. Items are
    ``(audio_name, speaker_id, phone_sequence, emotion_id)``.
    """

    def __init__(self,
                 datalist="path/to/test/filelist"):
        self.datalist = datalist
        self.data = self.load_data()
        # Emotion label (as written in the filelist) -> integer emotion id.
        self.text2label = {
            "neutral": 0,
            "angry": 1,
            "happy": 2,
            "sad": 3,
            "surprise": 4,
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(audio_name, speaker_id, phone_sequence, emotion_id)``.

        Raises:
            IndexError: if ``idx`` is out of range or the record has fewer
                than five fields.
            KeyError: if the emotion label is not in ``self.text2label``.
        """
        record = self.data[idx]
        # File stem: directories and the extension (of any length, not just
        # 4-character ones like ".wav") are removed.
        audio_name = os.path.splitext(os.path.basename(record[0]))[0]
        speaker_id = int(record[1])
        phone = record[4]
        emotion = record[3]
        emotion_id = self.text2label[emotion]
        return audio_name, speaker_id, phone, emotion_id

    def load_data(self):
        """Read ``self.datalist`` into ``|``-split records, skipping blank lines."""
        with open(self.datalist, encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [s.split("|") for s in stripped if s]
# Evaluation split: every record is synthesized by the inference loop further down.
test_dataset = ESDDataset(datalist='path/to/test/filelist')

In [None]:
# Sanity check (shown as cell output): dataset size and the first
# (audio_name, speaker_id, phone_sequence, emotion_id) item.
len(test_dataset), test_dataset[0]

In [None]:
@lru_cache(maxsize=8)
def _load_phones_mapping(phones_path):
    """Load and cache the phone -> id mapping stored as JSON at ``phones_path``.

    Cached because ``get_input_dict`` is called once per utterance; re-reading
    and re-parsing the same file on every call is wasted I/O. The returned
    dict is shared between callers and must not be mutated.
    """
    with open(phones_path, "r") as f:
        return json.load(f)


def get_input_dict(config, phone_sequence, speaker_id, emotion_id):
    """Build a batch-of-one input dict for the acoustic model.

    Args:
        config: object whose ``phones_path`` attribute points at a JSON file
            mapping phone symbols to ids.
        phone_sequence: whitespace-separated phone symbols.
        speaker_id: integer speaker id.
        emotion_id: integer emotion id.

    Returns:
        dict with ``texts`` (long tensor, shape ``(1, num_phones)``),
        ``text_lens`` (long tensor, shape ``(1,)``), ``speakers`` and
        ``emotions`` (shape ``(1,)``), ``ids`` (one-element list), and
        ``None`` for ``mels``, ``mel_lens``, ``pitches``, ``energies``,
        ``durations`` and ``egemap_features``.

    An unknown phone or an empty sequence is reported via ``crash_with_msg``.
    """
    phones_mapping = _load_phones_mapping(config.phones_path)
    phone_ids = []
    # split() with no argument collapses runs of whitespace, so sequences
    # containing double spaces (e.g. "V  T") don't yield "" tokens that
    # would fail the lookup below.
    for p in phone_sequence.split():
        try:
            phone_ids.append(phones_mapping[p])
        except KeyError:
            crash_with_msg(
                f"Couldn't map input sequence: {phone_sequence} into phone ids. \n"
                f"Supported phones: {phones_mapping} \n"
                f"Phone: {p} is not in a dictionary."
            )
    if not phone_ids:
        crash_with_msg(f"Empty phone sequence: {phone_sequence!r}")
    texts = torch.tensor(phone_ids).long().unsqueeze(0)
    text_lens = torch.tensor([texts.shape[1]]).long()
    ids = [f"{speaker_id}_0_{emotion_id}"]
    speakers = torch.tensor([speaker_id])
    emotions = torch.tensor([emotion_id])
    # No ground-truth targets are available at inference time.
    mels, mel_lens, pitches, energies, durations, egemap_features = [None] * 6
    batch_dict = {
        "ids": ids,
        "speakers": speakers,
        "emotions": emotions,
        "texts": texts,
        "text_lens": text_lens,
        "mels": mels,
        "mel_lens": mel_lens,
        "pitches": pitches,
        "energies": energies,
        "durations": durations,
        "egemap_features": egemap_features,
    }
    return batch_dict

In [None]:
import os
import time

from tqdm import tqdm

set_up_logger("inference.log")
config = TrainConfig()

seed_everything(config.seed)

# Vocoder: build from the config, load the "generator" weights and remove
# weight norm before switching to eval mode.
vocoder = Generator(**asdict(config))
stft = TorchSTFT(**asdict(config))
vocoder_state_dict = load_checkpoint(config.vocoder_checkpoint_path)
vocoder.load_state_dict(vocoder_state_dict["generator"])
vocoder.remove_weight_norm()
vocoder.eval()

# Acoustic model restored from config.testing_checkpoint, wired to the vocoder.
model = FastSpeechLightning.load_from_checkpoint(
    config.testing_checkpoint,
    config=config,
    vocoder=vocoder,
    stft=stft,
    train=False,
)
model.eval()
torch.set_float32_matmul_precision(config.matmul_precision)

# One "<audio_name>.wav" per test utterance is written here.
output_dir = 'path/to/wav/save/path'
os.makedirs(output_dir, exist_ok=True)

start_time = time.perf_counter()
# eval() only changes layer behaviour (dropout etc.); no_grad() additionally
# stops autograd from recording the graph, saving memory and time.
with torch.no_grad():
    # Iterate over test_dataset by index (Dataset defines __len__/__getitem__).
    for i in tqdm(range(len(test_dataset))):
        audio_name, speaker_id, phone_sequence, emotion_id = test_dataset[i]
        input_dict = get_input_dict(config, phone_sequence, speaker_id, emotion_id)
        model_output = model.model(model.device, input_dict)
        predicted_mel_len = model_output["mel_len"][0]
        # Drop padding frames past the predicted length before vocoding.
        predicted_mel_no_padding = model_output["predicted_mel"][0, :predicted_mel_len]
        generated_wav = synthesize_wav_from_mel(
            predicted_mel_no_padding, model.vocoder, model.stft
        )
        # Name each file after its source utterance; writing to one fixed
        # path made every iteration overwrite the previous result.
        write_wav(
            os.path.join(output_dir, f"{audio_name}.wav"),
            generated_wav,
            config.sample_rate,
        )
elapsed_time = time.perf_counter() - start_time
print(f"run time: {elapsed_time} seconds.")
