In [None]:
import json
import os
from dataclasses import asdict
from typing import Tuple, Callable

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from hyperparams import TVAEParams
from tvae import TVAE
from earlystopping import EarlyStopping
from utils import *
from raw_audio_dataloader import get_tvae_dataloaders

In [None]:
# Signature of a VAE loss criterion: it takes four tensors (two data tensors,
# then mu and logvar) and returns three tensors
# (total_loss, mse_loss, kl_loss).
# NOTE(review): the original note documented the order as
# (reconstruction, target, mu, logvar), but train()/validate() call
# criterion(tgt, recon, mu, logvar), i.e. target first. Confirm which order
# vae_loss actually expects.
CriterionType = Callable[
    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]

In [None]:
def train(model:TVAE, train_loader: DataLoader, optimzer: optim.Optimizer, criterion:CriterionType, hp:TVAEParams)->Tuple[float, float, float]:
    """Run one training epoch and return per-sample average losses.

    Args:
        model: the TVAE. It is called as ``model(src, tgt)`` and must return
            ``(recon, mu, logvar)``.
        train_loader: yields ``(src, tgt)`` batches.
        optimzer: optimizer over ``model``'s parameters. The misspelled name
            is kept because callers pass it by keyword.
        criterion: loss callable returning ``(total, mse, kl)`` tensors.
        hp: hyperparameters; only ``hp.device`` is read here.

    Returns:
        ``(total_loss, mse_loss, kl_loss)``, each averaged per sample over all
        samples actually seen in this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_mse_loss = 0.0
    total_kl_loss = 0.0
    # Count the samples actually processed instead of using len(dataset), so the
    # averages stay correct with drop_last=True or a subset sampler.
    n_seen = 0

    for src, tgt in train_loader:
        src, tgt = src.to(hp.device), tgt.to(hp.device)  # [B, C, Seq]
        optimzer.zero_grad()
        recon, mu, logvar = model(src, tgt)  # recon: [B, Seq, C]
        # Permute tgt to [B, Seq, C] so it lines up with recon for the loss.
        tgt = tgt.permute(0, 2, 1)
        # NOTE(review): target is passed first here; confirm this matches
        # vae_loss's expected argument order.
        loss, mse_loss, kl_loss = criterion(tgt, recon, mu, logvar)

        loss.backward()
        optimzer.step()

        # The criterion's losses are weighted by batch size to turn them back
        # into sums. This assumes they are per-sample means; confirm in vae_loss.
        batch_size = src.size(0)
        total_loss += loss.item() * batch_size
        total_mse_loss += mse_loss.item() * batch_size
        total_kl_loss += kl_loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train(): train_loader yielded no samples")

    return total_loss / n_seen, total_mse_loss / n_seen, total_kl_loss / n_seen

In [None]:
def validate(model:TVAE, val_loader: DataLoader, criterion:CriterionType, hp:TVAEParams)->Tuple[float, float, float]:
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: the TVAE. It is called as ``model(src, tgt)`` and must return
            ``(recon, mu, logvar)``. It is left in eval mode afterwards.
        val_loader: yields ``(src, tgt)`` batches.
        criterion: loss callable returning ``(total, mse, kl)`` tensors.
        hp: hyperparameters; only ``hp.device`` is read here.

    Returns:
        ``(total_loss, mse_loss, kl_loss)``, each averaged per sample over all
        samples actually seen.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_mse_loss = 0.0
    total_kl_loss = 0.0
    # Count the samples actually processed rather than using len(dataset),
    # which is wrong with drop_last=True or a subset sampler.
    n_seen = 0

    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(hp.device), tgt.to(hp.device)  # [B, C, Seq]
            recon, mu, logvar = model(src, tgt)  # recon: [B, Seq, C]
            # Permute tgt to [B, Seq, C] so it lines up with recon for the loss.
            tgt = tgt.permute(0, 2, 1)
            # NOTE(review): target is passed first here; confirm this matches
            # vae_loss's expected argument order.
            loss, mse_loss, kl_loss = criterion(tgt, recon, mu, logvar)

            batch_size = src.size(0)
            total_loss += loss.item() * batch_size
            total_mse_loss += mse_loss.item() * batch_size
            total_kl_loss += kl_loss.item() * batch_size
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("validate(): val_loader yielded no samples")

    return total_loss / n_seen, total_mse_loss / n_seen, total_kl_loss / n_seen

In [None]:
# Build the hyperparameter object with its default values and echo every field
# so the run's configuration appears in the notebook output.
hp = TVAEParams()
print("Parameters: " + str(asdict(hp)))

In [None]:
# Instantiate the TVAE on the configured device and report its size.
model = TVAE(hp=hp).to(hp.device)
# Total element count across all parameter tensors (trainable or not).
total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {total_params}", model, sep="\n")

In [None]:
# Loss function used for both training and validation. vae_loss is not defined
# in this notebook; presumably it comes from `from utils import *`.
criterion = vae_loss
# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=hp.lr)

In [None]:
# Build the train/validation loaders and report how many sequences each holds.
train_loader, val_loader = get_tvae_dataloaders(hp=hp)
print('Train Seq #{}, Val Seq #{}'.format(len(train_loader.dataset), len(val_loader.dataset)))

In [None]:
# Early-stopping settings, named so they are easy to find and tune.
# NOTE(review): presumably training stops after ES_PATIENCE epochs without the
# monitored loss improving by at least ES_MIN_DELTA; confirm in earlystopping.py.
# The EarlyStopping object keeps state across epochs, so re-run this cell before
# re-running the training cell.
ES_PATIENCE = 5
ES_MIN_DELTA = 0.001
es = EarlyStopping(patience=ES_PATIENCE, min_delta=ES_MIN_DELTA)

In [None]:
# True: train from scratch and save logs. False: load the saved checkpoint instead.
set_train: bool = True

In [None]:
# Either train with early stopping and write per-epoch metrics to JSON, or load a
# previously saved checkpoint from hp.model_dir / hp.model_file_name.
if set_train:
    log_data = []  # one metrics dict per completed epoch
    for epoch in range(hp.num_epochs):
        train_loss, train_mse_loss, train_kl_loss = train(model=model, train_loader=train_loader, optimzer=optimizer, criterion=criterion, hp=hp)
        val_loss, val_mse_loss, val_kl_loss = validate(model=model, val_loader=val_loader, criterion=criterion, hp=hp)
        print(f"EPOCH {epoch}")
        print(f"Train [ LOSS:{train_loss:.3f}, MSE:{train_mse_loss:.3f}, KL:{train_kl_loss:.3f}]")
        print(f"Val [ LOSS:{val_loss:.3f}, MSE:{val_mse_loss:.3f}, KL:{val_kl_loss:.3f}]")

        # Record the epoch before the early-stop check so the final epoch is
        # logged even when training stops on it.
        log_data.append({
            "epoch": epoch,
            "train_mse_loss": train_mse_loss,
            "train_kl_loss": train_kl_loss,
            "train_total_loss": train_loss,
            "val_mse_loss": val_mse_loss,
            "val_kl_loss": val_kl_loss,
            "val_total_loss": val_loss,
        })

        # Early stopping monitors the scalar validation loss, not the loss function.
        es(val_loss=val_loss, model=model, model_dir=hp.model_dir, model_file_name=hp.model_file_name)
        if es.early_stop:
            print(f"Early stop triggered @ EPOCH: {epoch}")
            break

    # Save the full per-epoch history (an empty list if hp.num_epochs == 0).
    os.makedirs(hp.log_dir, exist_ok=True)
    log_path = os.path.join(hp.log_dir, hp.train_log_file)
    with open(log_path, 'w') as f:
        json.dump(log_data, f, indent=4)
    print(f"Training logs saved to {log_path}")
    print("Training and validation completed ")
    # NOTE(review): `model` now holds the last epoch's weights. If EarlyStopping
    # saves the best checkpoint, consider reloading it before generating audio.
else:
    model.load_checkpoint(os.path.join(hp.model_dir, hp.model_file_name))

In [None]:
# Generate and save audio samples with the current `model`. save_generated_audio
# is not defined in this notebook; presumably it comes from `from utils import *`.
# NOTE(review): after training, `model` holds the final epoch's weights, which
# are not necessarily the best checkpoint.
save_generated_audio(model,hp=hp)