In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt

from IPython.display import clear_output

In [3]:
from torch.utils.data import DataLoader
from hw_tts.datasets import LJSpeechDataset, collate_fn
from hw_tts.melspecs import MelSpectrogram, MelSpectrogramConfig

# Pick the GPU when one is available; the aligner cell below is moved to it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mel-spectrogram extractor built from the project's default config.
featurizer = MelSpectrogram(MelSpectrogramConfig())

# Loader with batch_size=4 and DataLoader's default (non-shuffled) sampling;
# only its first batch is used, as a fixed sample for the sanity checks below.
dataloader = DataLoader(
    LJSpeechDataset("./hw_tts/data/"),
    batch_size=4,
    collate_fn=collate_fn,
)
test_batch = next(iter(dataloader))

# Swap dims 1 and 2 of the featurizer output; later cells treat dim -2 as the
# frame axis (presumably featurizer returns (batch, n_mels, frames) — TODO confirm).
test_batch["melspecs"] = featurizer(
    test_batch["waveforms"]
).transpose(1, 2)

In [4]:
from hw_tts.aligners import GraphemeAligner

# NOTE(review): 22050 is hard-coded; presumably it must match the sample rate
# used by the mel featurizer — TODO read it from MelSpectrogramConfig instead.
aligner = GraphemeAligner(melspec_sr=22050).to(DEVICE)

# Per-token durations used as integer regression targets for FastSpeech.
# - no_grad: the result is cast with `.long()`, so no gradient can flow through
#   it; skipping autograd avoids building a graph through the aligner.
# - The aligner output is scaled by the number of mel frames (dim -2 of the
#   transposed melspecs). Assumes the aligner returns fractions of that length —
#   TODO confirm. `.long()` truncates, so per-utterance sums may fall slightly
#   short of the frame count.
# - `.cpu()` keeps the whole batch on one device: every other tensor in
#   `test_batch` lives on the CPU, while the aligner output is on DEVICE.
with torch.no_grad():
    test_batch["durations"] = (aligner(
        test_batch["waveforms"].to(DEVICE), test_batch["waveforms_lengths"], test_batch["transcripts"]
    ) * test_batch["melspecs"].shape[-2]).long().cpu()

# Model: overfit FastSpeech on a single batch (sanity check)

In [None]:
from hw_tts.models import FastSpeech

# Sanity check: overfit FastSpeech on the single `test_batch` prepared above.
# Seed first so the weight initialisation (and hence the loss curve) is reproducible.
torch.manual_seed(0)

model = FastSpeech(
    d_model=384,
    n_head=2,
    n_tokens=51,
    n_encoders=3,
    n_decoders=3,
    alpha=1.,
    melspec_size=80,
    kernel_size=3
).to(DEVICE)
model.train()

# Put every tensor of the batch on the model's device. The aligner cell may
# leave "durations" on DEVICE while the other tensors stay on the CPU, so the
# raw `test_batch` can mix devices. Non-tensor entries (e.g. transcripts) are
# passed through unchanged, and `test_batch` itself is not mutated.
train_batch = {
    key: value.to(DEVICE) if torch.is_tensor(value) else value
    for key, value in test_batch.items()
}

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

loss_values = []

n_iters = 100
for _ in range(n_iters):
    mels, durs = model(train_batch)

    # Predicted and target spectrograms can differ along dim -2 (the frame axis,
    # per the transpose in the data cell), so only the overlapping frames are
    # compared. Despite its name, this is the *shorter* of the two lengths.
    max_mel_len = min(mels.shape[-2], train_batch["melspecs"].shape[-2])

    mel_loss = criterion(mels[:, :max_mel_len], train_batch["melspecs"][:, :max_mel_len])
    dur_loss = criterion(durs, train_batch["durations"].float())
    loss = mel_loss + dur_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_values.append(loss.item())

    # Redraw the loss curve on a fresh figure every iteration. plt.show() is what
    # actually renders it mid-loop (and closes it under the inline backend);
    # without it nothing appears until the cell ends and the lines pile up on
    # one figure.
    clear_output(wait=True)
    fig, ax = plt.subplots()
    ax.plot(loss_values)
    ax.set(xlabel="iter", ylabel="loss", title="Single-batch training loss")
    ax.grid()
    plt.show()
