In [None]:
import os
from functools import partial

import matplotlib.pyplot as plt
import torch
from torch.nn import MSELoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from core.audio_model import AudioModel
from datasets.lmdb_clean_noisy_dataset import LmdbCleanNoisyDataset
from utils.device_utils import device_collate_fn, to_device_fn

In [None]:
# --- Run configuration -------------------------------------------------------
train_version = "v2"      # tag embedded in checkpoint file names
dataset_length = None     # int -> train on a random subset of that size; None -> full dataset
train_ratio = 0.9         # approximate fraction of samples used for training (rest is validation)
batch_size = 256
num_workers = 2
prefetch_factor = 2
pin_memory = True
use_mps = False
use_cuda = True
num_epochs = 50

# Seed torch's global RNG so the random train/val split, DataLoader shuffling
# and any torch-based weight initialisation are reproducible across runs.
seed = 42
torch.manual_seed(seed)

# --- Output locations --------------------------------------------------------
model_dir = '../_models/'
os.makedirs(model_dir, exist_ok=True)
checkpoint_dir = model_dir + 'checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)
weights_file_name = model_dir + "weights_speech_denoiser_model.pth"
model_file_name = model_dir + "speech_denoiser_model.pth"

In [None]:
# Open the LMDB-backed dataset. The training loop unpacks each item as
# (inputs, targets); presumably (noisy, clean) audio -- confirm against
# LmdbCleanNoisyDataset.
datasets_dir = '../_datasets/'
lmdb_path = datasets_dir + 'clean_noisy_dataset.lmdb'
dataset = LmdbCleanNoisyDataset(lmdb_path)

In [None]:
# Bind the device-selection flags once so the helpers only need the data.
# NOTE(review): custom_collate_fn is not passed to the DataLoaders below;
# batches are moved to the device inside the training loop instead.
device_flags = {'use_cuda': use_cuda, 'use_mps': use_mps}
custom_collate_fn = partial(device_collate_fn, **device_flags)
custom_to_device_fn = partial(to_device_fn, **device_flags)

In [None]:
print(f'Dataset total size: {len(dataset)}')

# Optionally train on a random subset of `dataset_length` samples.
if dataset_length is not None and dataset_length < len(dataset):
  dataset, _ = random_split(dataset, [dataset_length, len(dataset) - dataset_length])

dataset_size = len(dataset)

# Round the training split DOWN to a whole number of batches so training
# batches are full. Rounding to the nearest multiple can overshoot the dataset
# size (random_split then fails) or hit zero on small datasets.
target_train_size = int(dataset_size * train_ratio)
train_size = (target_train_size // batch_size) * batch_size
if train_size == 0:
  # Too few samples for one full batch: fall back to the plain ratio split.
  train_size = target_train_size
val_size = dataset_size - train_size
if train_size == 0 or val_size == 0:
  # Both losses are averaged per sample later, so empty splits cannot work.
  raise ValueError(
    f'Cannot split {dataset_size} samples with train_ratio={train_ratio}: '
    f'train={train_size}, val={val_size}; both splits must be non-empty')

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoader rejects an explicit prefetch_factor when num_workers == 0, so it
# is only passed when worker processes are actually used.
loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers, 'pin_memory': pin_memory}
if num_workers > 0:
  loader_kwargs['prefetch_factor'] = prefetch_factor

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Sample counts of each split (not batch counts).
train_loader_len = len(train_loader.dataset)
val_loader_len = len(val_loader.dataset)

print(
  f'Finished train data preparation, train loader size: {train_loader_len}, val loader size: {val_loader_len}')

In [None]:
# Build the denoiser. Set `resume_from_weights = True` to warm-start from the
# state_dict written to `weights_file_name` by the final cell of a previous run.
resume_from_weights = False

model = AudioModel()
model.init_weights()
if resume_from_weights:
  # map_location='cpu' lets weights saved from a GPU run load on any machine;
  # the model is moved to the target device just below.
  state_dict = torch.load(weights_file_name, map_location='cpu')
  model.load_state_dict(state_dict)

# Return value is not reassigned: presumably to_device_fn relies on
# nn.Module.to() moving the module in place -- TODO confirm, since tensors
# elsewhere use the returned value.
custom_to_device_fn(model)

print('Model initialized')

In [None]:
# Mean-squared error between the model output and the target signal.
criterion = MSELoss()

# Adam with a small L2 penalty on the weights.
optimizer = Adam(
  model.parameters(),
  lr=1e-4,
  weight_decay=1e-5,
)

# Plateau schedule on the monitored loss: divide the LR by 10 once the loss has
# not improved by at least 1% (relative) for more than 3 epochs, then wait
# 2 epochs before reacting again. Never go below 1e-7.
plateau_settings = {
  'mode': 'min',
  'factor': 0.1,
  'patience': 3,
  'threshold': 1e-2,
  'cooldown': 2,
  'min_lr': 1e-7,
}
scheduler = ReduceLROnPlateau(optimizer, **plateau_settings)

print('Criterion and optimizer initialized')

In [None]:
# Per-epoch loss history. Kept in its own cell so re-running the training cell
# appends to the same history instead of resetting it.
train_losses, val_losses = [], []

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, to_device, max_grad_norm=1.0):
  """Run one optimisation pass over `loader`.

  Returns the per-sample mean training loss for the epoch.
  """
  model.train()
  total_loss = 0.0
  num_samples = 0
  for inputs, targets in loader:
    inputs = to_device(inputs)
    targets = to_device(targets)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    # Cap the global gradient norm to keep updates stable.
    clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    # Assumes criterion returns a batch mean (true for the default MSELoss()),
    # so re-weight by batch size to get a per-sample average over the epoch.
    batch_len = inputs.size(0)
    total_loss += loss.item() * batch_len
    num_samples += batch_len
  return total_loss / num_samples


def _evaluate(model, loader, criterion, to_device):
  """Return the per-sample mean loss of `model` over `loader` without gradients."""
  model.eval()
  total_loss = 0.0
  num_samples = 0
  with torch.no_grad():
    for inputs, targets in loader:
      inputs = to_device(inputs)
      targets = to_device(targets)
      loss = criterion(model(inputs), targets)
      batch_len = inputs.size(0)
      total_loss += loss.item() * batch_len
      num_samples += batch_len
  return total_loss / num_samples


for epoch in range(num_epochs):
  train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, custom_to_device_fn)
  train_losses.append(train_loss)

  val_loss = _evaluate(model, val_loader, criterion, custom_to_device_fn)
  val_losses.append(val_loss)

  # ReduceLROnPlateau lowers the LR when the validation loss stops improving.
  scheduler.step(val_loss)

  lr = optimizer.param_groups[0]['lr']
  print(
    f"Epoch: {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Learning Rate: {lr:.1e}")

  checkpoint = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Saved so a resumed run continues the LR schedule instead of restarting it.
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': val_loss
  }
  # os.path.join avoids a doubled separator when checkpoint_dir ends with '/'.
  checkpoint_path = os.path.join(
    checkpoint_dir, f"checkpoint_{train_version}_epoch_{epoch + 1}_loss_{val_loss:.4f}.pt")
  torch.save(checkpoint, checkpoint_path)

print('Finished Training')

In [None]:
# Training vs. validation loss per epoch. The x-axis is 1-based to match the
# "Epoch: N" log lines and the checkpoint file names. Each series gets its own
# range, so a run interrupted between the two appends still plots.
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()

plt.show()

In [None]:
# Persist the final model twice: the bare weights (state_dict, loadable into a
# fresh AudioModel) and the whole pickled module (needs AudioModel importable
# at load time).
for obj, path in ((model.state_dict(), weights_file_name), (model, model_file_name)):
  torch.save(obj, path)

print('Model saved')