# Main training loop
The orchestrator that coordinates everything.
Contains:
- Epoch loop
- Calls training function
- Calls validation function
- Tracks best model
- Early stopping logic
- Model checkpointing (saving)
- Training history tracking (for plotting)

In [25]:
import torch

In [3]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, scheduler, save_path='/Users/hela/Code/pata/best_model.pth'):
    trained_model = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[]}  # dict for graphs
    best_val_acc = 0.0
    patience = 10  # wait N epochs without improvement before stopping
    patience_counter = 0  # count epochs without improvement

    print('\n'+'='*70)
    print('Starting training')
    print('='*70+'\n')

    # loop for each epoch
    for epoch in range(num_epochs):
        print(f'\nEpoch [{epoch+1}/{num_epochs}]')
        # training
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        # validation
        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion)
        # update lr
        scheduler.step(val_loss)  # adjust lr if val_loss doesn't improve
        
        # print results
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # save history for graphs
        trained_model['train_loss'].append(train_loss)
        trained_model['train_acc'].append(train_acc)
        trained_model['val_loss'].append(val_loss)
        trained_model['val_acc'].append(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'epoch':epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'val_acc': val_acc,
                        'val_loss': val_loss},
                       save_path)
            print(f"✓ Model saved. (Val Acc: {val_acc:.2f}%)")
            patience_counter = 0
        else:
            patience_counter += 1

        # early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
    
    print('\n'+'='*70)
    print('Training completed')
    print('='*70+'\n')

    return trained_model

In [22]:
# FOR CHECK:
import torch.nn as nn
import import_ipynb
import pytorch_model_05_valid_fn
import pytorch_model_04_training_fn
import pytorch_model_03_CNN_class
import pytorch_model_02_transfoms_dataloaders
import pytorch_model_01_dataset_class
import torch.optim as optim

model = pytorch_model_03_CNN_class.SpectrogramCNN()
train_loader, val_loader = pytorch_model_02_transfoms_dataloaders.prepare_data_loaders()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001,weight_decay=0.50)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=5)

Dataset initialized with: 1600 samples
Label mapping: {'pa': 0, 'ta': 1}
Class distribution:
label
pa    800
ta    800
Name: count, dtype: int64

Dataset splits (val_split=0.2):
Training set: 1280
Validation set: 320


                                                                                           

Dataset initialized with: 1600 samples
Label mapping: {'pa': 0, 'ta': 1}
Class distribution:
label
pa    800
ta    800
Name: count, dtype: int64

Dataset splits (val_split=0.2):
Training set: 1280
Validation set: 320


                                                                                           

Dataset initialized with: 1600 samples
Label mapping: {'pa': 0, 'ta': 1}
Class distribution:
label
pa    800
ta    800
Name: count, dtype: int64

Dataset splits (val_split=0.2):
Training set: 1280
Validation set: 320




In [24]:
# CHECK:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, scheduler, save_path='/Users/hela/Code/pata/best_model.pth'):
    trained_model = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[]}  # dict for graphs
    best_val_acc = 0.0
    patience = 10  # wait N epochs without improvement before stopping
    patience_counter = 0  # count epochs without improvement

    print('\n'+'='*70)
    print('Starting training')
    print('='*70+'\n')

    # loop for each epoch
    for epoch in range(num_epochs):
        print(f'\nEpoch [{epoch+1}/{num_epochs}]')
        # training
        train_loss, train_acc = pytorch_model_04_training_fn.train_epoch(model, train_loader, criterion, optimizer)
        # validation
        val_loss, val_acc, _, _ = pytorch_model_05_valid_fn.validate_epoch(model, val_loader, criterion)
        # update lr
        scheduler.step(val_loss)  # adjust lr if val_loss doesn't improve
        
        # print results
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # save history for graphs
        trained_model['train_loss'].append(train_loss)
        trained_model['train_acc'].append(train_acc)
        trained_model['val_loss'].append(val_loss)
        trained_model['val_acc'].append(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'epoch':epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'val_acc': val_acc,
                        'val_loss': val_loss},
                       save_path)
            print(f"✓ Model saved. (Val Acc: {val_acc:.2f}%)")
            patience_counter = 0
        else:
            patience_counter += 1

        # early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
    
    print('\n'+'='*70)
    print('Training completed')
    print('='*70+'\n')

    return trained_model
    
    
trained_model = train_model(model=model,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      criterion=criterion,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      num_epochs=10,
                      save_path='/Users/hela/Code/pata/best_model.pth')


Starting training


Epoch [1/10]


                                                                                           

Train Loss: 0.6949 | Train Acc: 50.31%
Val Loss: 0.6927 | Val Acc: 51.56%
Learning Rate: 0.001000




NameError: name 'torch' is not defined