In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import lr_scheduler, Adam

from core.audio_handler import load_dataset
from core.audio_model import AudioModel
from core.dataset_handler import DatasetHandler

In [None]:
# Run-level configuration.
num_epochs = 30
# NOTE(review): passed to DatasetHandler.split_dataset_into_data_loaders; presumably
# the number of samples drawn from the dataset — confirm against core.dataset_handler.
dataset_length = 2000

# Seed torch's RNGs (torch.manual_seed seeds all devices) so that weight initialisation
# and any torch-based shuffling/splitting repeat across runs. Code in core.* that draws
# from Python's `random` or NumPy would need to be seeded separately.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Load the serialized dataset and split it into training / validation loaders.
dataset_handler = DatasetHandler(batch_size=16, use_mps=True)
# NOTE(review): load_dataset's implementation is not visible here; if it unpickles
# (e.g. torch.load), only point it at trusted files.
dataset = load_dataset("dataset.pt")
train_loader, val_loader = dataset_handler.split_dataset_into_data_loaders(dataset, dataset_length)

# Fail fast here: the training loop averages losses over each loader's batches, so an
# empty split would otherwise only surface later as a less informative error.
assert len(train_loader) > 0, "training loader is empty"
assert len(val_loader) > 0, "validation loader is empty"

print('Finished data preparation')

In [None]:
# Build the network plus its loss, optimizer and learning-rate schedule.
# Hyperparameters are named up front so they are easy to find and tune.
learning_rate = 1e-4
weight_decay = 1e-3
lr_reduction_factor = 0.1   # new_lr = old_lr * factor when the plateau triggers
plateau_patience = 10       # epochs without improvement before the LR is cut

model = AudioModel(use_mps=True)
model.init_weights()

criterion = nn.MSELoss()
optimizer = Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# mode='min': the monitored metric (the loss passed to scheduler.step) should decrease.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_reduction_factor,
    patience=plateau_patience,
)

print('Model initialized')

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader`; return the mean per-batch loss.

    Raises:
        ValueError: if `loader` yields no batches (the mean would be undefined).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("training loader produced no batches")
    return total_loss / num_batches


def _evaluate(model, loader, criterion):
    """Return the mean per-batch loss over `loader` without updating any weights.

    Raises:
        ValueError: if `loader` yields no batches (the mean would be undefined).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in loader:
            total_loss += criterion(model(inputs), targets).item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("validation loader produced no batches")
    return total_loss / num_batches


train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss = _evaluate(model, val_loader, criterion)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # ReduceLROnPlateau decides whether to cut the LR based on the validation loss.
    scheduler.step(val_loss)

    # `.4g` keeps small MSE values (e.g. 0.01234) visible; a 0-decimal format rounds them to 0.
    print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.4g}, Validation Loss: {val_loss:.4g}")

print('Finished Training')

In [None]:
# Plot per-epoch training and validation loss on a single axes. Epochs are numbered
# from 1 to match the printout of the training loop.
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()

plt.show()

In [None]:
# Persist the trained model. The destination and format are decided by AudioModel.save(),
# which is not visible here — TODO confirm the output path. These are the weights as they
# stand after the final epoch; this notebook does no best-validation checkpointing.
model.save()
print('Model saved')