In [None]:
import os

import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader

from core.audio_model import RecurrentAttentionModel
from datasets.speech_phonemes_dataset import SpeechPhonemesDataset

In [None]:
# --- Data ---
# Serialized dataset file and the maximum number of samples drawn from it
# (None means "use everything"; the cap is applied in the data-preparation cell).
dataset_path = '../_datasets/speech_phonemes_dataset.pt'
dataset_length = 5000

# --- Splits and batching ---
# Train / validation shares of the dataset; the remainder becomes the test split.
train_ratio, val_ratio = 0.7, 0.2
batch_size = 20
num_epochs = 50

# --- Hardware ---
# Device preference flags: CUDA is checked first, then MPS, with CPU as the fallback.
use_cuda, use_mps = False, True

# --- Output artifacts ---
# The directory is created up front so the save cell at the end cannot fail on a missing folder.
model_dir = '../_models/'
os.makedirs(model_dir, exist_ok=True)
model_file_name = model_dir + 'phonemes_counter_model.pth'
weights_file_name = model_dir + 'weights_phonemes_counter_model.pth'

In [None]:
def collate_fn(batch):
    """Collate (spectrogram, phoneme_count) samples into a padded float32 batch.

    Each spectrogram has its leading axis squeezed (removed when it has size 1)
    and its two remaining axes swapped; the results are then zero-padded along
    their first axis to the longest sample in the batch.
    NOTE(review): this presumes every spectrogram is shaped (1, features, frames)
    so that padding runs over frames - confirm against SpeechPhonemesDataset.

    Args:
        batch: sequence of ``(spectrogram, phoneme_count)`` pairs as produced by
            the dataset.

    Returns:
        Tuple ``(spectrograms, phoneme_counts)``: a float32 tensor of shape
        ``(batch, longest_sequence, other_axis)`` and a float32 tensor holding
        one count per sample, both placed on the selected device.
    """
    # Device preference: CUDA when enabled and available, then MPS, else CPU.
    if use_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
    elif use_mps and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    spectrograms, phoneme_counts = zip(*batch)
    spectrograms = [s.squeeze(0).permute(1, 0) for s in spectrograms]
    spectrograms = pad_sequence(spectrograms, batch_first=True)

    # Cast to float32 on the host *before* the transfer: the MPS backend cannot
    # hold float64 tensors, so moving first and casting afterwards raises for
    # double-precision inputs (and would copy twice the data on any device).
    spectrograms = spectrograms.to(torch.float32).to(device)
    phoneme_counts = torch.tensor(phoneme_counts).to(torch.float32).to(device)

    return spectrograms, phoneme_counts


# Load the serialized dataset object.
# SECURITY NOTE: this unpickles an arbitrary Python object, which can execute code
# embedded in the file - only load dataset files that come from a trusted source.
# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle a custom Dataset instance.
dataset: SpeechPhonemesDataset = torch.load(dataset_path, weights_only=False)

# Optionally work on a random subset of at most `dataset_length` samples.
if dataset_length is not None and dataset_length < len(dataset):
    dataset, _ = random_split(dataset, [dataset_length, len(dataset) - dataset_length])

# Train and validation sizes are rounded to whole batches; the test split takes the rest.
dataset_size = len(dataset)
train_size = round((dataset_size * train_ratio) / batch_size) * batch_size
val_size = round((dataset_size * val_ratio) / batch_size) * batch_size
test_size = dataset_size - train_size - val_size

# Rounding up to whole batches can overshoot the dataset when the ratios sum close to 1.
# The three sizes still add up to dataset_size in that case, so fail loudly here
# instead of handing random_split a negative length.
if test_size < 0:
    raise ValueError(
        f'train_size ({train_size}) + val_size ({val_size}) exceed the dataset size '
        f'({dataset_size}); lower train_ratio/val_ratio or batch_size'
    )

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Only the training loader reshuffles each epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print('Finished data preparation')

In [None]:
def move_model_to_device(model):
    """Place ``model`` on the preferred device and cast it to float32, in place.

    The device is chosen from the notebook flags: CUDA when ``use_cuda`` is set
    and CUDA is available, otherwise MPS when ``use_mps`` is set and MPS is
    available, otherwise the CPU.

    Args:
        model: the ``torch.nn.Module`` to relocate; it is modified in place.

    Returns:
        None.
    """
    backend = "cpu"
    if use_cuda and torch.cuda.is_available():
        backend = "cuda"
    elif use_mps and torch.backends.mps.is_available():
        backend = "mps"

    target = torch.device(backend)
    # Two separate calls: first the transfer, then the dtype cast.
    model.to(target)
    model.to(torch.float32)


# Regression objective for the phoneme-count targets.
criterion = nn.HuberLoss()

# NOTE(review): the meaning of the four positional arguments is defined by
# RecurrentAttentionModel's constructor, which is not visible in this notebook -
# confirm against core.audio_model before changing them.
model = RecurrentAttentionModel(216, 265, 4, 0.3)
# The model is placed on its device before the optimizer is built from its parameters.
move_model_to_device(model)

optimizer = Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
# Cut the learning rate tenfold once the monitored value has not improved for 10 epochs.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=10,
)

print('Model initialized')

In [None]:
# Loss history: the training loop appends one value per epoch to each list.
train_losses, val_losses = [], []

In [None]:
# The split sizes do not change during training, so look them up once.
num_train_samples = len(train_loader.dataset)
num_val_samples = len(val_loader.dataset)

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        # NOTE(review): assumes `outputs` has the same shape as `labels`; a trailing
        # singleton axis would make the loss broadcast - confirm against the model.
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion yields a batch mean; weighting by batch size turns the
        # running total into a sum over samples.
        running_loss += inputs.size(0) * loss.item()

    train_loss = running_loss / num_train_samples
    train_losses.append(train_loss)

    # ---- Validation pass: evaluation mode, no gradient tracking ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += inputs.size(0) * loss.item()

    val_loss = val_loss / num_val_samples
    val_losses.append(val_loss)

    # The plateau scheduler is driven by the per-sample validation loss.
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.0f}, Validation Loss: {val_loss:.0f}")

In [None]:
# Plot the per-epoch training and validation losses on a shared axes.
fig = plt.figure(figsize=(12, 5))
# NOTE(review): the axes occupies slot 1 of a 1x2 grid and slot 2 is never drawn,
# so the right half of the figure stays empty - confirm whether that is intended.
ax = fig.add_subplot(1, 2, 1)
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()

plt.show()

In [None]:
# Final evaluation on the test split: model in evaluation mode, gradients disabled.
# Only the loss is reported; the unused `correct` / `total` counters were dropped
# because nothing ever updated or read them.
model.eval()
test_loss = 0.0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Weight each batch mean by its size so the total is a sum over samples.
        test_loss += loss.item() * inputs.size(0)

# Per-sample mean over the whole test split.
test_loss = test_loss / len(test_loader.dataset)

print(f'Test Loss: {test_loss:.4f}')

In [ ]:
# Persist the trained network in two forms:
#   1. the parameters only (state_dict)
#   2. the whole module object, pickled - loading it back requires the model's
#      class to be importable in the loading environment.
trained_weights = model.state_dict()
torch.save(trained_weights, weights_file_name)
torch.save(model, model_file_name)

print('Model saved')