In [None]:
import os

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import lr_scheduler, Adam
from torch.utils.data import DataLoader, random_split

from core.audio_model import AudioModel
from datasets.clean_noisy_dataset import CleanNoisyDataset

In [None]:
# --- Configuration: everything a reader might want to tune lives here. ---
dataset_path = '../_datasets/clean_noisy_dataset.pt'
dataset_length = 10000  # random subset size drawn from the stored dataset; None keeps every sample
train_ratio = 0.8       # fraction of samples used for training (rounded to whole batches later)
batch_size = 16
use_mps = True          # train on Apple-silicon MPS when it is available, else CPU
num_epochs = 50
seed = 42

# Seed torch's global RNG so the random_split calls, DataLoader shuffling and
# (presumably) AudioModel.init_weights are reproducible across Restart & Run All.
torch.manual_seed(seed)

model_dir = '../_models/'
os.makedirs(model_dir, exist_ok=True)
weights_file_name = os.path.join(model_dir, "weights_speech_denoiser_model.pth")
model_file_name = os.path.join(model_dir, "speech_denoiser_model.pth")

In [None]:
def collate_fn(batch, device=None):
    """Collate dataset samples into float32 tensors placed on the training device.

    Args:
        batch: list of samples produced by the dataset; presumably ``(input, target)``
            tensor pairs, given how the training loop unpacks each batch — TODO confirm
            against CleanNoisyDataset.
        device: optional target device. When omitted, MPS is used if ``use_mps`` is set
            and MPS is available, otherwise CPU.

    Returns:
        A list with one stacked tensor per sample field, cast to float32 and moved to
        ``device``.
    """
    if device is None:
        device = torch.device("mps" if use_mps and torch.backends.mps.is_available() else "cpu")
    batch = torch.utils.data.dataloader.default_collate(batch)
    # Cast BEFORE transferring: MPS cannot hold float64 tensors, so moving a float64
    # tensor to MPS first would raise; casting on the source device also avoids
    # shipping wider data to the accelerator.
    return [x.to(torch.float32).to(device) for x in batch]


# The file is expected to hold a whole pickled CleanNoisyDataset object rather than plain
# tensors, so it needs weights_only=False (the default flipped to True in PyTorch 2.6).
# NOTE: this unpickles arbitrary objects — only load dataset files you created or trust.
dataset: CleanNoisyDataset = torch.load(dataset_path, weights_only=False)
if dataset_length is not None and dataset_length < len(dataset):
    # Keep a random subset of `dataset_length` samples and discard the remainder.
    dataset, _ = random_split(dataset, [dataset_length, len(dataset) - dataset_length])

dataset_size = len(dataset)
# Round the training split to a whole number of batches.
train_size = round((dataset_size * train_ratio) / batch_size) * batch_size
# For small datasets the rounding above can produce 0 or more samples than exist, which
# would make random_split fail or leave one of the loaders empty; fail early and clearly.
assert 0 < train_size < dataset_size, (
    f"train_size={train_size} is invalid for dataset_size={dataset_size}; "
    f"adjust dataset_length, train_ratio or batch_size"
)

train_dataset, val_dataset = random_split(dataset, [train_size, dataset_size - train_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print('Finished data preparation')

In [None]:
def move_model_to_mps(m):
    """Place module ``m`` on the MPS device as float32, when MPS use is enabled and available.

    Does nothing otherwise (the module stays where it is). ``m`` is modified in place
    via ``Module.to``; nothing is returned.
    """
    if not (use_mps and torch.backends.mps.is_available()):
        return
    m.to(torch.device("mps"))
    m.to(torch.float32)


# Build the denoiser, initialise its weights, and move it to the training device.
model = AudioModel()
model.init_weights()
move_model_to_mps(model)

# Optimisation hyper-parameters.
learning_rate = 1e-4
l2_penalty = 1e-3       # Adam's weight_decay (added to the gradient, i.e. classic L2)
plateau_factor = 0.1    # multiply the LR by this when validation loss plateaus
plateau_patience = 10   # epochs without improvement before the LR is reduced

criterion = nn.MSELoss()
optimizer = Adam(params=model.parameters(), lr=learning_rate, weight_decay=l2_penalty)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=plateau_factor,
    patience=plateau_patience,
)

print('Model initialized')

In [None]:
def _train_one_epoch(net, loader, loss_fn, opt):
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss.

    Assumes ``loss_fn`` uses mean reduction (the nn.MSELoss default), so each batch loss is
    re-weighted by its batch size to avoid over-counting a short final batch.
    """
    net.train()
    total_loss = 0.0
    num_samples = 0
    for inputs, targets in loader:
        opt.zero_grad()
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        opt.step()
        batch_len = inputs.size(0)
        total_loss += loss.item() * batch_len
        num_samples += batch_len
    return total_loss / num_samples


def _evaluate(net, loader, loss_fn):
    """Return the sample-weighted mean loss of ``net`` over ``loader`` without updating weights."""
    net.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = net(inputs)
            batch_len = inputs.size(0)
            total_loss += loss_fn(outputs, targets).item() * batch_len
            num_samples += batch_len
    return total_loss / num_samples


train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    train_losses.append(train_loss)

    val_loss = _evaluate(model, val_loader, criterion)
    val_losses.append(val_loss)

    # ReduceLROnPlateau lowers the learning rate once validation loss stops improving.
    scheduler.step(val_loss)

    # Scientific notation: a fixed ".0f" format would print 0 for any loss below 0.5.
    print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.4e}, Validation Loss: {val_loss:.4e}")

print('Finished Training')

In [None]:
# Plot per-epoch training and validation loss on a single set of axes
# (the previous 1x2 subplot grid left its right-hand panel empty).
fig, ax = plt.subplots(figsize=(12, 5))
epoch_numbers = range(1, len(train_losses) + 1)  # 1-based to match the "Epoch N" log lines
ax.plot(epoch_numbers, train_losses, label='Training Loss')
ax.plot(epoch_numbers, val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()

plt.show()

In [None]:
# Persist two artifacts: the bare state_dict (load with model.load_state_dict) and the
# whole pickled module (loading it requires the AudioModel class to be importable).
for artifact, path in ((model.state_dict(), weights_file_name), (model, model_file_name)):
    torch.save(artifact, path)

print('Model saved')