In [None]:
import os
from functools import partial

import torch
from matplotlib import pyplot as plt
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader

from core.abs_loss import AbsLoss
from core.recurrent_attention_model import RecurrentAttentionModel
from utils.device_utils import to_device_fn

In [None]:
# --- Data ---------------------------------------------------------------
dataset_path = '../_datasets/test_valentini_speech_syllables_dataset.pt'
# Upper bound on the number of samples used; the data-preparation cell skips
# subsampling when this is None or not smaller than the loaded dataset.
dataset_length = 500
train_ratio = 0.8
batch_size = 20

# --- Hardware -----------------------------------------------------------
# Forwarded as keyword arguments to to_device_fn.
# NOTE(review): presumably these select the Apple MPS / CUDA backends — confirm in utils.device_utils.
use_mps = True
use_cuda = False

# --- Training -----------------------------------------------------------
num_epochs = 50

# --- Output artifacts ---------------------------------------------------
model_dir = '../_models/'
# Create the output directory up front; exist_ok makes re-running this cell harmless.
os.makedirs(model_dir, exist_ok=True)
weights_file_name = os.path.join(model_dir, 'weights_syllable_counter_model.pth')
model_file_name = os.path.join(model_dir, 'syllable_counter_model.pth')

In [None]:
# Pre-bind the device flags from the configuration cell so that call sites only
# pass the object to move.
# NOTE(review): presumably to_device_fn returns the moved object — confirm in utils.device_utils.
custom_to_device_fn = partial(
    to_device_fn,
    use_mps=use_mps,
    use_cuda=use_cuda,
)

In [None]:
def collate_fn(batch):
    """Assemble a list of (spectrogram, syllable_count) samples into one batch.

    Each spectrogram has its leading dimension squeezed and its two remaining
    axes swapped, then all of them are zero-padded along the (new) first axis
    to the longest one in the batch and stacked batch-first. The syllable
    counts are packed into a single tensor. Both results are passed through
    ``custom_to_device_fn`` before being returned.

    NOTE(review): assumes every spectrogram is shaped (1, features, time), so
    the padded batch is (batch, max_time, features) — TODO confirm against the
    dataset.

    :param batch: sequence of ``(spectrogram, syllable_count)`` pairs.
    :return: tuple ``(padded_spectrograms, syllable_counts)``.
    """
    raw_specs, counts = zip(*batch)

    # pad_sequence pads along dim 0, so the variable-length axis has to come first.
    padded = pad_sequence(
        [spec.squeeze(0).permute(1, 0) for spec in raw_specs],
        batch_first=True,
    )
    padded = custom_to_device_fn(padded)

    targets = custom_to_device_fn(torch.tensor(counts))
    return padded, targets


# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# dataset files from a trusted source. On recent PyTorch releases the default
# of `weights_only` changed; loading a pickled Dataset object may require
# passing weights_only=False explicitly — confirm against the installed version.
dataset = torch.load(dataset_path)

# Dedicated seeded generator: the subsample and the train/validation split are
# the same on every run, so validation numbers are comparable between runs and
# the held-out samples stay held out. It does not touch the global RNG.
split_generator = torch.Generator().manual_seed(42)

# Optionally cap the dataset to `dataset_length` randomly chosen samples.
if dataset_length is not None and dataset_length < len(dataset):
    dataset, _ = random_split(
        dataset,
        [dataset_length, len(dataset) - dataset_length],
        generator=split_generator,
    )

dataset_size = len(dataset)
# Round the training share to a whole number of batches so every training
# batch is full.
train_size = round((dataset_size * train_ratio) / batch_size) * batch_size
val_size = dataset_size - train_size

# For small datasets the rounding above can yield 0 or more than dataset_size.
# A negative val_size is not rejected by random_split (the lengths still sum
# to dataset_size) and silently produces an empty validation set, which only
# surfaces later as a ZeroDivisionError in the training loop. Fail early.
if not 0 < train_size < dataset_size:
    raise ValueError(
        f'Cannot split {dataset_size} samples into non-empty train/validation sets '
        f'with train_ratio={train_ratio} and batch_size={batch_size} '
        f'(train_size={train_size}, val_size={val_size})'
    )

train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print('Finished data preparation')

In [None]:
# Positional constructor arguments for RecurrentAttentionModel.
# NOTE(review): their meaning is defined in core.recurrent_attention_model and
# is not visible from this notebook — TODO name them once confirmed.
model_args = (1, 256, 2)
model = RecurrentAttentionModel(*model_args)

# The return value is not used; this relies on the model being moved in place.
# NOTE(review): confirm that to_device_fn moves nn.Module instances in place.
custom_to_device_fn(model)

print('Model initialized')

In [None]:
# Loss, optimizer and learning-rate schedule used by the training loop.
criterion = AbsLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Scale the learning rate by `factor` once the monitored value (the validation
# loss passed to scheduler.step) has stopped decreasing for `patience` epochs,
# never going below `min_lr`.
plateau_options = dict(
    mode='min',
    factor=0.1,
    patience=5,
    threshold=0.01,
    cooldown=5,
    min_lr=1e-5,
)
scheduler = ReduceLROnPlateau(optimizer, **plateau_options)

In [None]:
# Per-epoch loss history. Kept in its own cell so that re-running the training
# cell appends to the existing curves instead of resetting them.
train_losses, val_losses = [], []

In [None]:
# Sample counts do not change between epochs, so look them up once.
num_train_samples = len(train_loader.dataset)
num_val_samples = len(val_loader.dataset)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss_sum = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size so the epoch value is a
        # per-sample average.
        # NOTE(review): assumes AbsLoss returns a batch mean — TODO confirm.
        train_loss_sum += loss.item() * inputs.shape[0]

    train_loss = train_loss_sum / num_train_samples
    train_losses.append(train_loss)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    val_loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            batch_loss = criterion(model(inputs), labels)
            val_loss_sum += batch_loss.item() * inputs.shape[0]

    val_loss = val_loss_sum / num_val_samples
    val_losses.append(val_loss)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Read the learning rate back after the scheduler step so the log shows
    # the value that the next epoch will use.
    lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.2f}, Validation Loss: {val_loss:.2f}, LR: {lr:.1e}")

print('Finished Training')

In [None]:
# Plot the training and validation loss curves on a single axes.
# A single full-width axes is used: the previous 1x2 subplot grid only ever
# filled its left panel and left the right half of the figure blank.
fig, ax = plt.subplots(figsize=(12, 5))

# Epochs are numbered from 1 to match the "Epoch N" lines printed during training.
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()

plt.show()

In [None]:
# Persist two artifacts: the parameters alone (state_dict) and the whole
# pickled model object.
# NOTE(review): loading the whole-model file later needs the model's class to
# be importable under the same module path.
for payload, destination in (
    (model.state_dict(), weights_file_name),
    (model, model_file_name),
):
    torch.save(payload, destination)

print('Model saved')