# Classifying Audio using Spectrograms and CNNs

## Note on Data locations

I recommend putting the audio (WAV) and image (PNG) files in the `data\audio` and `data\images` directories, respectively.

The `data\` directory is already listed in the `.gitignore` file, so these large binary files won't be included in commits.

### Example:

<img src="data_structure_example.png" style="width:400px; height:auto;">

## Imports

In [None]:
import os
import datetime
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
from timeit import default_timer as timer
from typing import Tuple

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import librosa
import librosa.display

from nnViewer import wrap_model, run_gui

## AudioFile class

- `file_path`: Path to the audio file
- `file_name`: Name of the audio file (extracted from the path)
- `label`: Label of the audio file (derived from the parent directory name)
- `audio`: Loaded audio data
- `sample_rate`: Sampling rate of the audio file
- `duration`: Duration of the audio file in seconds

### Methods

- `display_waveform()`: Display the waveform of the audio file
- `play()`: Play the audio file by displaying an inline audio player widget
- `trim(top_db=50)`: Trim silent parts of the audio using a decibel threshold
- `create_spectrogram()`: Generate a mel spectrogram of the audio file
- `display_spectrogram()`: Display the spectrogram of the audio file
- `save_spectrogram(output_dir=None, skip_existing=True)`: Save the spectrogram as a PNG file


In [None]:
class AudioFile:
    """
    A class to handle audio files and provide utilities for analysis and visualization.

    Attributes:
        file_path (str): Path to the audio file.
        file_name (str): Name of the audio file (extracted from the path).
        label (str): Label of the audio file (derived from the parent directory name).
        audio (np.ndarray): Loaded audio data.
        sample_rate (int): Sampling rate of the audio file.
        duration (float): Duration of the audio file in seconds.
    """

    def __init__(self, file_path):
        """
        Initialize the AudioFile instance by loading the audio file and extracting metadata.

        Args:
            file_path (str): Path to the audio file.
        """
        self.file_path = file_path
        self.file_name = os.path.basename(file_path)
        self.label = os.path.basename(os.path.dirname(self.file_path))
        self.audio, self.sample_rate = librosa.load(file_path)
        self.audio = librosa.util.normalize(self.audio)   # normalize audio
        self.duration = librosa.get_duration(y=self.audio, sr=self.sample_rate)

    def display_waveform(self):
        """
        Display the waveform of the audio file.
        """
        librosa.display.waveshow(self.audio, sr=self.sample_rate)
        plt.show()
        plt.close()

    def play(self):
        """
        Play the audio file by displaying an inline audio player widget.
        """
        ipd.display(ipd.Audio(self.audio, rate=self.sample_rate))

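    def trim(self, top_db=50):
        """
        Trim leading and trailing silence from the audio based on a decibel threshold.

        Args:
            top_db (int, optional): Threshold in decibels below the reference level
                under which audio is considered silent. Defaults to 50.
        """
        self.audio, _ = librosa.effects.trim(self.audio, top_db=top_db)
        # keep the stored duration in sync with the trimmed audio
        self.duration = librosa.get_duration(y=self.audio, sr=self.sample_rate)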

    def create_spectrogram(self):
        """
        Create a mel spectrogram of the audio file.

        Returns:
            np.ndarray: The mel spectrogram in decibel units.
        """
        mel_scale_sgram = librosa.feature.melspectrogram(
            y=self.audio,
            sr=self.sample_rate,
            power=1)
        mel_sgram = librosa.amplitude_to_db(mel_scale_sgram, ref=np.min)
        return mel_sgram

    def display_spectrogram(self):
        """
        Display the spectrogram of the audio file.
        """
        _spectrogram = self.create_spectrogram()

        fig, ax = plt.subplots()
        img = librosa.display.specshow(
            _spectrogram,
            sr=self.sample_rate,
            x_axis='time',
            y_axis='mel',
            ax=ax)
        plt.colorbar(img, format='%+2.0f dB')

        # remove whitespace around image
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

        plt.show()
        plt.close(fig)

    def save_spectrogram(self, output_dir=None, skip_existing=True):
        """
        Save the spectrogram as a PNG file.

        Args:
            output_dir (str, optional): Directory to save the spectrogram. Defaults to the directory of the audio file.
            skip_existing (bool, optional): Whether to skip saving if the file already exists. Defaults to True.
        """
        if not output_dir:
            output_dir = os.path.dirname(self.file_path)
        else:
            output_dir = os.path.join(output_dir, self.label)

        base, _ = os.path.splitext(self.file_name)
        output_file = os.path.join(output_dir, base + ".png")

        if skip_existing and os.path.exists(output_file):
            return

        spectrogram = self.create_spectrogram()
        librosa.display.specshow(spectrogram, sr=self.sample_rate)

        os.makedirs(output_dir, exist_ok=True)
        # save, removing whitespace
        plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
        plt.close()
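
`create_spectrogram()` converts amplitudes to decibels with `ref=np.min`, so the quietest bin becomes the 0 dB reference and every other bin is positive. The following is a simplified sketch of that conversion in plain Python. It assumes the basic `20 * log10(value / ref)` relationship only and leaves out the small-value floor (`amin`) and the `top_db` range clipping that `librosa.amplitude_to_db` also applies. The helper name `amplitude_to_db_simple` is hypothetical.

```python
import math


def amplitude_to_db_simple(values, ref):
    """Convert amplitudes to decibels relative to `ref` (simplified)."""
    return [20.0 * math.log10(v / ref) for v in values]


# Illustrative amplitudes, each 10x larger than the previous one
amplitudes = [0.001, 0.01, 0.1, 1.0]

# Using the minimum as the reference puts the quietest value at 0 dB
db = amplitude_to_db_simple(amplitudes, ref=min(amplitudes))

for amp, level in zip(amplitudes, db):
    print(f"amplitude {amp:>6}: {level:5.1f} dB")
```

Each tenfold increase in amplitude adds 20 dB, which is why the colour bar in `display_spectrogram()` starts at zero and only grows.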


### Example of using AudioFile class

In [None]:
_audio_file = os.path.join("data", "audio", "Speech Commands", "backward", "0a2b400e_nohash_0.wav")
test_audio = AudioFile(_audio_file)

test_audio.display_waveform()
test_audio.display_spectrogram()
test_audio.trim(top_db=50)
test_audio.play()
test_audio.display_waveform()
test_audio.display_spectrogram()
test_audio.play()

## Convert Audio Files to Spectrograms

- Set `input_dir` to the directory of audio files; the output directory is derived from it as `data/images/<name of input_dir>`
- Call `process_directory()`
- If `skip_existing` is True, existing spectrogram PNG files are skipped (recommended)


### NOTE:

- Only run this cell if you need to save out all the spectrograms. It takes a while and is prone to crashing (hence the use of `skip_existing`, so it can continue where it left off).
- The `process_directory(...)` line at the bottom is commented out to avoid accidental runs.

In [None]:
def process_directory(input_dir, skip_existing=True, include_trimmed=False):
    output_dir = os.path.join("data", "images", os.path.basename(input_dir))
    output_dir_trimmed = os.path.join(output_dir + " (trimmed)")
    for root, dirs, files in os.walk(input_dir):
        # sort directories alphabetically
        dirs.sort()
        directory = os.path.basename(root)
        print(f"Processing directory: {directory}")
        for file in files:
            if file.endswith('.wav'):
                audio = None
                # trim off .wav from file
                base, _ = os.path.splitext(file)
                output_file = os.path.join(output_dir, directory, base + ".png")
                
                if not (skip_existing and os.path.exists(output_file)):
                    # load file
                    audio = AudioFile(os.path.join(root, file))
                    # save spectrogram
                    audio.save_spectrogram(output_dir, skip_existing=skip_existing)

                if include_trimmed:
                    output_file_trimmed = os.path.join(output_dir_trimmed, directory, base + ".png")
                    if not (skip_existing and os.path.exists(output_file_trimmed)):
                        # have we loaded the file already?
                        if not audio:
                            audio = AudioFile(os.path.join(root, file))
                        # trim and save; pass the base directory, because
                        # save_spectrogram() appends the label sub-directory itself
                        audio.trim()
                        audio.save_spectrogram(output_dir_trimmed, skip_existing=skip_existing)
                
# process_directory('data/audio/Speech Commands_noise', skip_existing=True, include_trimmed=False)
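
As a quick illustration of how `process_directory()` maps an audio file to its spectrogram file, the sketch below rebuilds the output path with the same `os.path` calls. The helper name `spectrogram_path` and its `images_root` parameter are hypothetical and exist only for this example.

```python
import os


def spectrogram_path(wav_path, input_dir, images_root=os.path.join("data", "images")):
    """Return the PNG path that corresponds to a WAV file under input_dir."""
    # the label is the name of the directory that contains the WAV file
    label = os.path.basename(os.path.dirname(wav_path))
    # strip the .wav extension from the file name
    base, _ = os.path.splitext(os.path.basename(wav_path))
    return os.path.join(images_root, os.path.basename(input_dir), label, base + ".png")


input_dir = os.path.join("data", "audio", "Speech Commands")
wav_path = os.path.join(input_dir, "backward", "0a2b400e_nohash_0.wav")

png_path = spectrogram_path(wav_path, input_dir)
print(png_path)
```

Because the output path is fully determined by the input path, an interrupted run can be resumed by checking whether the PNG already exists, which is what `skip_existing` does.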

### Corrupt image check

In [None]:
from PIL import Image

# Define your directories
output_dir = os.path.join("data", "images", "Speech Commands")
output_dir_trimmed = os.path.join("data", "images", "Speech Commands (trimmed)")

def check_png_corruption(directories, output_file="corrupt_pngs.txt"):
    corrupt_files = []
    for directory in directories:
        for root, dirs, files in os.walk(directory):
            for file in files:
                if file.lower().endswith('.png'):
                    file_path = os.path.join(root, file)
                    try:
                        with Image.open(file_path) as img:
                            img.verify()  # This will raise an exception if the file is corrupted.
                    except Exception as e:
                        print(f"Corrupted PNG found: {file_path} (Error: {e})")
                        corrupt_files.append(file_path)

    # Write the results to a text file
    if corrupt_files:
        with open(output_file, "w") as f:
            for filename in corrupt_files:
                f.write(f"{filename}\n")
        print(f"Found {len(corrupt_files)} corrupt PNG(s). Details saved to '{output_file}'.")
    else:
        print("No corrupt PNG files found.")

# List of directories to check
directories_to_check = [output_dir, output_dir_trimmed]

# Run the check
check_png_corruption(directories_to_check)
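
A cheaper first pass, separate from the `Image.verify()` check above, is to test whether each file starts with the 8-byte PNG signature. This is only a sketch of a pre-check: it catches empty files and files that are not PNGs at all, but not damage further into the file, so it does not replace `verify()`. The helper name `has_png_signature` is hypothetical.

```python
import os
import tempfile

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # first 8 bytes of every valid PNG file


def has_png_signature(path):
    """Return True if the file starts with the 8-byte PNG signature."""
    with open(path, "rb") as f:
        return f.read(len(PNG_SIGNATURE)) == PNG_SIGNATURE


# Demonstrate on two throwaway files: one with a valid header, one empty
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.png")
    empty = os.path.join(tmp, "empty.png")
    with open(good, "wb") as f:
        f.write(PNG_SIGNATURE + b"rest of file")
    open(empty, "wb").close()

    good_ok = has_png_signature(good)
    empty_ok = has_png_signature(empty)

print(good_ok, empty_ok)
```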

## Convolutional Neural Net

*Pre-process TODO*

- spectrogram: look into librosa specshow options. remove black bar at top of many. check axes.

#### Setup and Parameters

In [2]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Metal Performance Shaders
else:
    device = "cpu"

print(f"Using {device} device")


datetime_stamp = datetime.datetime.now().strftime('%y%m%d-%H%M')

# setup tensorboard
writer_path = f"./logs/run_{datetime_stamp}"
writer = SummaryWriter(writer_path)

# path to save model
model_path = f"./models/cnn_ryan/{datetime_stamp}"
            

## PARAMETERS ##
image_size = (256, 190) # from 496x369. Closely maintains aspect ratio
num_channels = 3 # RGB images
# for DataLoader
batch_size = 64
num_workers = 2

Using mps device


#### Transform and Load data

In [3]:
data_transforms = transforms.Compose([
    transforms.Resize((image_size[0], image_size[1])), # Resize expects (height, width); scale every image to image_size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# ImageFolder will load data from subdirectories and assign integer labels
data_dir = "data/images/Speech Commands"
dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)

# get class labels
class_labels = [label for label in dataset.class_to_idx]
num_classes = len(class_labels)
print(f"Classes: {num_classes}")

# optionally train on a randomly selected subset of the data for speed (fraction = 1 uses all of it)
fraction = 1
subset_size = int(len(dataset) * fraction)
indices = torch.randperm(len(dataset))[:subset_size]
dataset = torch.utils.data.Subset(dataset, indices)


# split into training and validation sets
val_split = 0.2
val_size = int(val_split * len(dataset))
train_size = len(dataset) - val_size

X_train, X_val = torch.utils.data.random_split(dataset, [train_size, val_size])
print(f"Training: {train_size}\nValidation: {val_size}")

# load batches
train_batches = torch.utils.data.DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True)

val_batches = torch.utils.data.DataLoader(
    X_val,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True)

Classes: 35
Training: 84664
Validation: 21165
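
The split sizes printed above follow directly from the 80/20 split, and the number of batches per epoch follows from `batch_size = 64`. A small arithmetic check, assuming the `DataLoader` keeps the final partial batch (its default behaviour):

```python
import math

dataset_size = 84664 + 21165   # training + validation counts printed above
val_split = 0.2
batch_size = 64

# same arithmetic as the notebook cell
val_size = int(val_split * dataset_size)
train_size = dataset_size - val_size

# the last, smaller batch is kept, so round up
train_batches_per_epoch = math.ceil(train_size / batch_size)
val_batches_per_epoch = math.ceil(val_size / batch_size)

print(train_size, val_size)
print(train_batches_per_epoch, val_batches_per_epoch)
```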


#### Helper Functions

In [4]:
def display_spectrogram(img):
    img = img / 2 + 0.5     # de-normalize
    plt.axis('off')
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))

def display_spectrogram_batches(batches, writer_path=None):
    _iter = iter(batches)
    images, _ = next(_iter)
    img_grid = torchvision.utils.make_grid(images)
    display_spectrogram(img_grid)

    if writer_path:
        # write to tensorboard
        writer.add_images(writer_path, img_grid.unsqueeze(0))

def calc_convolution_output(image_size: Tuple[int, int], kernel_size, stride=1, padding=0):
    _w = int((image_size[0] - kernel_size[0] + 2 * padding) / stride) + 1
    _h = int((image_size[1] - kernel_size[1] + 2 * padding) / stride) + 1
    return (_w, _h)

def calc_pooling_output(image_size: Tuple[int, int], kernel_size, stride=None):
    if stride is None:
        stride = kernel_size
    _w = int((image_size[0] - kernel_size[0]) / stride[0]) + 1
    _h = int((image_size[1] - kernel_size[1]) / stride[1]) + 1
    return (_w, _h)
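
As a worked check of the two size helpers, the sketch below traces a 256x190 input through the same layer sequence as the `SpectrogramCNN` defined later in this notebook: a stride-2 3x3 convolution, then alternating 2x2 max pools and stride-1 3x3 convolutions. The helper names `conv_out` and `trace` are hypothetical and work on one dimension at a time.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one dimension: floor((W - F + 2P) / S) + 1."""
    return (size - kernel + 2 * padding) // stride + 1


def trace(height, width, layers):
    """Apply each (kernel, stride, padding) layer and record the output size."""
    shapes = []
    for kernel, stride, padding in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
        shapes.append((height, width))
    return shapes


layers = [
    (3, 2, 1),  # conv1: 3x3, stride 2, padding 1
    (2, 2, 0),  # pool1: 2x2 max pool
    (3, 1, 1),  # conv2: 3x3, stride 1, padding 1
    (2, 2, 0),  # pool2
    (3, 1, 1),  # conv3
    (2, 2, 0),  # pool3
]

shapes = trace(256, 190, layers)
final_h, final_w = shapes[-1]
feature_dims = 256 * final_h * final_w  # 256 output channels after conv3

for shape in shapes:
    print(shape)
print("feature dimensions:", feature_dims)
```

Odd sizes are rounded down at each pooling step (95 becomes 47, 47 becomes 23), which is why the flattened feature size has to be computed rather than guessed.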

#### __Neural Net Class__


#### ResNet style

In [None]:
# A simple convolutional layer with BatchNorm and ReLU
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, zero_bn=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # If zero_bn is True, initialize the BatchNorm weight (gamma) to zero.
        if zero_bn:
            nn.init.constant_(self.bn.weight, 0)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

# The stem: a sequence of ConvLayers with the first layer doing stride-2,
# followed by a max pooling layer.
def _resnet_stem(*sizes):
    # sizes should be a sequence like (in_channels, mid_channels, out_channels)
    stem_layers = [
        ConvLayer(sizes[i], sizes[i+1], kernel_size=3, stride=2 if i == 0 else 1, padding=1)
        for i in range(len(sizes) - 1)
    ]
    stem_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    return stem_layers

# Basic residual block with two 3x3 convolutions.
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        # Adjust the shortcut if dimensions differ
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

# Bottleneck residual block with a 1x1, 3x3, 1x1 structure.
class BottleneckBlock(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Reduce channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Restore channels (multiplying by expansion factor)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes):
        super().__init__()
        # Use a stem instead of a single initial convolution:
        # For example, a stem from 3 channels -> 32 channels -> 64 channels.
        self.stem = nn.Sequential(*_resnet_stem(3, 32, 64))
        # After the stem, the output has 64 channels.
        self.in_channels = 64

        # Residual layers
        self.layer1 = self._make_layer(block, 64,  num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Final fully connected layer
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.stem(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling; adjust kernel size as needed.
        out = F.avg_pool2d(out, out.size()[2:])
        out = out.view(out.size(0), -1)
        return self.linear(out)

# ResNet 18 style, using BasicBlock
model18 = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

# ResNet 50 style, using BottleneckBlock
# model50 = ResNet(BottleneckBlock, [3, 4, 6, 3], num_classes=num_classes)

# summary() prints the table itself, so it does not need to be wrapped in print()
summary(model18, input_size=(num_channels, image_size[0], image_size[1]))
# summary(model50, input_size=(num_channels, image_size[0], image_size[1]))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 32, 128, 95]             864
       BatchNorm2d-2          [-1, 32, 128, 95]              64
              ReLU-3          [-1, 32, 128, 95]               0
         ConvLayer-4          [-1, 32, 128, 95]               0
            Conv2d-5          [-1, 64, 128, 95]          18,432
       BatchNorm2d-6          [-1, 64, 128, 95]             128
              ReLU-7          [-1, 64, 128, 95]               0
         ConvLayer-8          [-1, 64, 128, 95]               0
         MaxPool2d-9           [-1, 64, 64, 48]               0
           Conv2d-10           [-1, 64, 64, 48]          36,864
      BatchNorm2d-11           [-1, 64, 64, 48]             128
           Conv2d-12           [-1, 64, 64, 48]          36,864
      BatchNorm2d-13           [-1, 64, 64, 48]             128
       BasicBlock-14           [-1, 64,
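
The parameter counts in the summary can be checked by hand. A bias-free `Conv2d` has `in_channels * out_channels * kernel_height * kernel_width` weights, and a `BatchNorm2d` layer has two learnable values per channel. The helper names below are hypothetical; the numbers they are compared with come from the table above.

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=False):
    """Number of learnable parameters in a square-kernel Conv2d layer."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def batchnorm2d_params(channels):
    """Learnable scale (gamma) and shift (beta) per channel."""
    return 2 * channels


counts = {
    "Conv2d-1": conv2d_params(3, 32, 3),      # stem, first conv
    "BatchNorm2d-2": batchnorm2d_params(32),
    "Conv2d-5": conv2d_params(32, 64, 3),     # stem, second conv
    "BatchNorm2d-6": batchnorm2d_params(64),
    "Conv2d-10": conv2d_params(64, 64, 3),    # first conv of the first BasicBlock
}

for name, count in counts.items():
    print(f"{name:>14}: {count:,}")
```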

In [None]:
class SpectrogramCNN(nn.Module):
    def __init__(self, num_classes):
        super(SpectrogramCNN, self).__init__()
        
        # Kernel calculations
        # W: input dimension; F: filter size; P: padding; S: stride
        # output = floor((W - F + 2P) / S) + 1
         
        # input: (3, 256, 190)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)
        _map_size = calc_convolution_output(image_size, kernel_size=self.conv1.kernel_size, stride=self.conv1.stride[0], padding=self.conv1.padding[0])
        self.instance_norm = nn.InstanceNorm2d(64)    # speaker normalization
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool1.kernel_size)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        _map_size = calc_convolution_output(_map_size, kernel_size=self.conv2.kernel_size, stride=self.conv2.stride[0], padding=self.conv2.padding[0])
        self.batch_norm_2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool2.kernel_size)
        
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        _map_size = calc_convolution_output(_map_size, kernel_size=self.conv3.kernel_size, stride=self.conv3.stride[0], padding=self.conv3.padding[0])
        self.batch_norm_3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool3.kernel_size)
        
        # final feature dimensions - channels x height x width
        feature_dims = 256 * _map_size[0] * _map_size[1]
        print(f'feature dimensions: {feature_dims}')
        
        # Fully Connected with larger intermediate layers
        self.fc1 = nn.Linear(in_features=feature_dims, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=num_classes)
        
        # dropout for regularization
        self.dropout2d = nn.Dropout2d(0.2)  # spatial dropout for conv layers
        self.dropout = nn.Dropout(0.4)      # regular dropout for FC layers
    
    def forward(self, x):
        # block 1
        x = F.relu(self.instance_norm(self.conv1(x)))  # using speaker normalization here
        x = self.pool1(x)
        x = self.dropout2d(x)
        
        # block 2
        x = F.relu(self.batch_norm_2(self.conv2(x)))
        x = self.pool2(x)
        x = self.dropout2d(x)
        
        # block 3
        x = F.relu(self.batch_norm_3(self.conv3(x)))
        x = self.pool3(x)
        x = self.dropout2d(x)
        
        # flatten
        x = x.view(x.size(0), -1)
        
        # FC layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x
    

model = SpectrogramCNN(num_classes)
summary(model, input_size=(num_channels, image_size[0], image_size[1]))  # summary() prints the table itself
model = model.to(device)

#### Training

In [15]:
# display training time
def display_training_time(start, end):
    total_time = end - start
    print(f"Training time : {total_time:.3f} seconds")
    return total_time

# display training info for each epoch
def display_training_info(epoch, val_loss, train_loss, accuracy, learning_rate):
    accuracy = round(accuracy, 2)
    print(f"\nEpoch: {epoch} | Training loss: {train_loss.item():.3f} | Validation loss: {val_loss.item():.3f} | Accuracy: {accuracy:.2f}% | LR: {learning_rate:.4g}")
    
# calculate accuracy
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

# training
def train_neural_net(epochs, model, loss_func, optimizer, train_batches, val_batches, stop_patience=4, lr_scheduler=None):
    final_accuracy = 0
    last_val_accuracy = 0
    epochs_without_improvement = 0
            
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # === training ===
        model.train()
        with torch.enable_grad():
            train_loss = 0
            for images, labels in tqdm(train_batches, desc="Training Batches", leave=False):
                labels = labels.to(device)
                predictions = model(images.to(device))
                loss = loss_func(predictions, labels)
                train_loss += loss.detach()  # detach so the running total does not hold on to the autograd graph
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # update learning rate here if using OneCycleLR
                if isinstance(lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
                    lr_scheduler.step()

            train_loss /= len(train_batches)


        # === evaluation ===
        val_loss, val_accuracy = 0, 0
        model.eval()
        with torch.inference_mode():
            for images, labels in tqdm(val_batches, desc="Validation Batches", leave=False):
                labels = labels.to(device)
                predictions = model(images.to(device))
                val_loss += loss_func(predictions, labels)
                val_accuracy += accuracy_fn(y_true=labels, y_pred=predictions.argmax(dim=1))

            val_loss /= len(val_batches)
            val_accuracy /= len(val_batches)
            final_accuracy = val_accuracy


        # update learning rate here if using ReduceLROnPlateau
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(val_loss)

        # get current learning rate
        learning_rate = optimizer.param_groups[0]['lr']

        display_training_info(epoch+1, val_loss, train_loss, val_accuracy, learning_rate)

        # log to tensorboard
        writer.add_scalars("Loss", {"Training": train_loss, "Validation": val_loss}, epoch)
        writer.add_scalar("Accuracy", val_accuracy, epoch)
        writer.add_scalar("Learning Rate", learning_rate, epoch)        

        # save a checkpoint if validation accuracy improved relative to the previous epoch
        if val_accuracy > last_val_accuracy:
            os.makedirs(model_path, exist_ok=True)
            torch.save(model.state_dict(), f"{model_path}/model_params.pt")
            torch.save(model, f"{model_path}/model_full.pt")
            print(f"Checkpoint saved at epoch {epoch+1} with val accuracy {val_accuracy:.2f}%")
            epochs_without_improvement = 0  # reset counter
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s)")

        if epochs_without_improvement >= stop_patience:
            print(f"Early stopping: no improvement in validation accuracy for {stop_patience} epochs.")
            break

        # update last_val_accuracy
        last_val_accuracy = val_accuracy
        
    return final_accuracy
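
The loop above compares each epoch's validation accuracy with the previous epoch's, so a checkpoint can be overwritten by a model that scores lower than an earlier one. A common alternative is to compare against the best accuracy seen so far. The sketch below shows that bookkeeping in isolation. `BestAccuracyTracker` is a hypothetical helper, not part of the notebook, and the accuracy values are illustrative.

```python
class BestAccuracyTracker:
    """Track the best validation accuracy and count epochs without a new best."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def update(self, val_accuracy):
        """Return True if this epoch set a new best (a checkpoint should be saved)."""
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.epochs_without_improvement = 0
            return True
        self.epochs_without_improvement += 1
        return False

    @property
    def should_stop(self):
        """True once `patience` consecutive epochs have failed to beat the best."""
        return self.epochs_without_improvement >= self.patience


tracker = BestAccuracyTracker(patience=4)
history = [93.31, 92.67, 93.06, 93.30, 93.79]  # illustrative accuracies (%)

improved = [tracker.update(acc) for acc in history]
print(improved)
print("best:", tracker.best, "| stop:", tracker.should_stop)
```

With a previous-epoch comparison, the third and fourth values would both trigger a save even though neither beats 93.31. With best-so-far tracking, only the first and last do.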

In [None]:
model = model18
model = model.to(device)

max_epochs = 80
stop_patience = 6   # if no improvement in validation accuracy for this many epochs, stop training
learning_rate = 3e-3

# parameters for OneCycleLR
max_learning_rate = 1e-2
steps_per_epoch = len(train_batches)
total_steps = max_epochs * steps_per_epoch  # total number of training steps

loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# alternative scheduler (stepped once per epoch with the validation loss):
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_learning_rate, total_steps=total_steps)


train_time_start_on_gpu = timer()
model_accuracy = train_neural_net(max_epochs, model, loss_func, optimizer, train_batches, val_batches, stop_patience, lr_scheduler)
print(f"\nTraining complete : {model_accuracy} %")
display_training_time(start=train_time_start_on_gpu, end=timer())

writer.flush()
writer.close()
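
To make the learning-rate values reported per epoch easier to interpret, here is a plain-Python sketch of a two-phase cosine one-cycle schedule. It assumes the default `OneCycleLR` settings (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing) and 1323 training batches per epoch. The function names are hypothetical; this approximates the schedule's shape and is not PyTorch's implementation.

```python
import math


def cosine_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate after `step` scheduler steps of a two-phase one-cycle policy."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1   # last step of the warm-up phase
    last_step = total_steps - 1
    if step <= warmup_end:
        # phase 1: ramp up from initial_lr to max_lr
        return cosine_anneal(initial_lr, max_lr, step / warmup_end)
    # phase 2: anneal down from max_lr to min_lr
    pct = (step - warmup_end) / (last_step - warmup_end)
    return cosine_anneal(max_lr, min_lr, pct)


steps_per_epoch = 1323                 # assumed number of training batches per epoch
total_steps = 80 * steps_per_epoch     # max_epochs * steps_per_epoch

for epoch in (1, 12, 24, 25):
    lr = one_cycle_lr(epoch * steps_per_epoch, total_steps, max_lr=1e-2)
    print(f"epoch {epoch}: LR {lr:.4g}")
```

Under these assumptions the rate starts at `max_lr / 25`, peaks at `max_lr` after 30% of the planned steps (the end of epoch 24 of 80), and then decays. If early stopping ends the run before that point, the model is only ever trained during the warm-up phase.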

Epoch: 1 | Training loss: 0.766 | Validation loss: 0.335 | Accuracy: 89.91% | LR: 0.0004411
Checkpoint saved at epoch 1 with val accuracy 89.91%

Epoch: 2 | Training loss: 0.242 | Validation loss: 0.306 | Accuracy: 90.38% | LR: 0.0005636
Checkpoint saved at epoch 2 with val accuracy 90.38%

Epoch: 3 | Training loss: 0.196 | Validation loss: 0.222 | Accuracy: 93.04% | LR: 0.0007654
Checkpoint saved at epoch 3 with val accuracy 93.04%

Epoch: 4 | Training loss: 0.180 | Validation loss: 0.232 | Accuracy: 93.31% | LR: 0.001043
Checkpoint saved at epoch 4 with val accuracy 93.31%

Epoch: 5 | Training loss: 0.169 | Validation loss: 0.244 | Accuracy: 92.67% | LR: 0.001392
No improvement for 1 epoch(s)

Epoch: 6 | Training loss: 0.160 | Validation loss: 0.239 | Accuracy: 93.06% | LR: 0.001806
Checkpoint saved at epoch 6 with val accuracy 93.06%

Epoch: 7 | Training loss: 0.154 | Validation loss: 0.217 | Accuracy: 93.30% | LR: 0.002278
Checkpoint saved at epoch 7 with val accuracy 93.30%

Epoch: 8 | Training loss: 0.144 | Validation loss: 0.217 | Accuracy: 93.79% | LR: 0.0028
Checkpoint saved at epoch 8 with val accuracy 93.79%

Epoch: 9 | Training loss: 0.132 | Validation loss: 0.262 | Accuracy: 92.75% | LR: 0.003363
No improvement for 1 epoch(s)

Epoch: 10 | Training loss: 0.125 | Validation loss: 0.238 | Accuracy: 93.41% | LR: 0.003958
Checkpoint saved at epoch 10 with val accuracy 93.41%

Epoch: 11 | Training loss: 0.117 | Validation loss: 0.255 | Accuracy: 93.52% | LR: 0.004574
Checkpoint saved at epoch 11 with val accuracy 93.52%

Epoch: 12 | Training loss: 0.108 | Validation loss: 0.202 | Accuracy: 94.34% | LR: 0.0052
Checkpoint saved at epoch 12 with val accuracy 94.34%

Epoch: 13 | Training loss: 0.099 | Validation loss: 0.201 | Accuracy: 94.40% | LR: 0.005827
Checkpoint saved at epoch 13 with val accuracy 94.40%

Epoch: 14 | Training loss: 0.095 | Validation loss: 0.202 | Accuracy: 94.74% | LR: 0.006443
Checkpoint saved at epoch 14 with val accuracy 94.74%

Epoch: 15 | Training loss: 0.085 | Validation loss: 0.215 | Accuracy: 94.21% | LR: 0.007037
No improvement for 1 epoch(s)

Epoch: 16 | Training loss: 0.083 | Validation loss: 0.189 | Accuracy: 94.81% | LR: 0.0076
Checkpoint saved at epoch 16 with val accuracy 94.81%

Epoch: 17 | Training loss: 0.071 | Validation loss: 0.199 | Accuracy: 95.04% | LR: 0.008122
Checkpoint saved at epoch 17 with val accuracy 95.04%

Epoch: 18 | Training loss: 0.069 | Validation loss: 0.202 | Accuracy: 94.93% | LR: 0.008594
No improvement for 1 epoch(s)

Epoch: 19 | Training loss: 0.061 | Validation loss: 0.269 | Accuracy: 93.94% | LR: 0.009008
No improvement for 2 epoch(s)

Epoch: 20 | Training loss: 0.059 | Validation loss: 0.264 | Accuracy: 93.68% | LR: 0.009357
No improvement for 3 epoch(s)

Epoch: 21 | Training loss: 0.053 | Validation loss: 0.223 | Accuracy: 94.95% | LR: 0.009635
Checkpoint saved at epoch 21 with val accuracy 94.95%

Epoch: 22 | Training loss: 0.047 | Validation loss: 0.186 | Accuracy: 95.69% | LR: 0.009837
Checkpoint saved at epoch 22 with val accuracy 95.69%

Epoch: 23 | Training loss: 0.044 | Validation loss: 0.216 | Accuracy: 95.17% | LR: 0.009959
No improvement for 1 epoch(s)

Epoch: 24 | Training loss: 0.037 | Validation loss: 0.223 | Accuracy: 95.20% | LR: 0.01
Checkpoint saved at epoch 24 with val accuracy 95.20%

Epoch: 25 | Training loss: 0.035 | Validation loss: 0.214 | Accuracy: 95.61% | LR: 0.009992
Checkpoint saved at epoch 25 with val accuracy 95.61%

Epoch: 26 | Training loss: 0.034 | Validation loss: 0.231 | Accuracy: 95.09% | LR: 0.009969
No improvement for 1 epoch(s)

Epoch: 27 | Training loss: 0.023 | Validation loss: 0.321 | Accuracy: 94.27% | LR: 0.009929
No improvement for 2 epoch(s)

Epoch: 28 | Training loss: 0.023 | Validation loss: 0.282 | Accuracy: 95.09% | LR: 0.009875
Checkpoint saved at epoch 28 with val accuracy 95.09%