In [10]:
import torch
import os
import json
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from IPython.display import Audio

from loss_functions import GeneratorLoss
from data_class import TextMelCollate, TextMelLoader
from hparams import create_hparams
from inference_utils import recover_wav
from model_utils import (
    load_generator,
    load_model_checkpoint,
    validate_generator,
    parse_batch_generator,
)

In [11]:
# Hyper-parameters are shared by the dataset, the collate function and the loader.
hparams = create_hparams()

# Validation split, batched in file order (no shuffling) with the project's
# text/mel collate function.
validation_dataset = TextMelLoader(hparams.validation_files, hparams)
collate_fn = TextMelCollate(hparams.n_frames_per_step)
validation_loader = DataLoader(
    validation_dataset,
    collate_fn=collate_fn,
    batch_size=hparams.batch_size,
    shuffle=False,
)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [12]:
# Build the generator from the hyper-parameters, passing the selected device.
model = load_generator(hparams, device)

# Generator loss object, under its correctly spelled name.
criterion = GeneratorLoss()
# Backward-compatible alias: this notebook originally spelled the name
# "critertion"; keep it bound so any cell still using the old spelling works.
critertion = criterion

In [13]:
# Generator checkpoint to restore. The default is the path this notebook was
# originally run with; set the EMO_GAN_GENERATOR_CHECKPOINT environment variable
# to use a different file without editing the notebook.
model_file = os.environ.get(
    "EMO_GAN_GENERATOR_CHECKPOINT",
    "/home/xzodia/dev/emo-gan/outputs/train_gan/train_6/generator/checkpoint_epoch_100.pth.tar",
)

# Rebind `model` to whatever the project helper returns for this checkpoint.
model = load_model_checkpoint(model_file, model)

In [14]:
# Training split, built with normalize=False (the validation set above uses the
# loader's default for that argument).
trainset = TextMelLoader(hparams.training_files, hparams, normalize=False)
collate_fn = TextMelCollate(hparams.n_frames_per_step)

# One utterance per batch, drawn in random order. No seed is set here, so the
# sample picked below differs from run to run.
train_loader = DataLoader(
    trainset,
    num_workers=1,
    shuffle=True,
    batch_size=1,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)

# Take exactly one batch. Fail loudly if the loader is empty instead of leaving
# `x` / `y` undefined (or silently stale from a previous kernel execution).
batch = next(iter(train_loader), None)
if batch is None:
    raise ValueError(
        "train_loader produced no batches; check hparams.training_files"
    )
x, y = parse_batch_generator(batch)

In [15]:
# `x` must hold exactly five items; only the first (padded text) and the third
# (emotion vectors) are used by this notebook.
text_padded, _, emotion_vectors, _, _ = x

# `y` must hold exactly three items; only the first is kept, as the reference mel.
real_mel, _, _ = y

In [16]:
# Put the model in evaluation mode before inference so that train-only layer
# behavior (e.g. dropout, batch-norm statistics updates) is disabled. This is a
# no-op if the model is already in eval mode.
model.eval()

# Gradients are not needed for inference.
with torch.no_grad():
    mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference(
        text_padded, emotion_vectors
    )

# Compare the generated mel size against the reference mel size.
print(mel_outputs_postnet.size())
print(real_mel.size())

torch.Size([1, 80, 148])
torch.Size([1, 80, 132])


In [17]:
# Recover a waveform via the project helper, passing the reference mel and the
# generated (post-net) mel, then embed an audio player for it.
# NOTE(review): the role of each recover_wav argument is not visible here — confirm.
audio = recover_wav(real_mel, mel_outputs_postnet, n_iter=50)
audio_display = Audio(data=audio, rate=hparams.sampling_rate)
display(audio_display)

In [18]:
# Same recovery call, but with the reference mel passed for both arguments.
# NOTE(review): presumably a ground-truth reconstruction to compare the
# generated audio against — confirm against recover_wav.
audio = recover_wav(real_mel, real_mel, n_iter=50)
audio_display = Audio(data=audio, rate=hparams.sampling_rate)
display(audio_display)