<a href="https://colab.research.google.com/github/r-ogrady/dissertation2/blob/main/c37_train_AMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 2 - Training a CNN on a word recognition task

This notebook follows on from the processing step. We now train a CNN to recognise the digit spoken in each audio clip. We use spectrograms and AlexNet so that our results can be compared with the work of Becker et al. (https://arxiv.org/abs/1807.03418).

**Work flow**
1. We set up the CNN by extending the `nn.Module` class and define the training loop (the only hyperparameter we vary is the learning rate). We use a validation set to impose early stopping with a patience of 10 epochs.
2. The data has already been processed, so we load it and create a dataset to use with the dataloader.
3. Train! 
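
The early-stopping rule in step 1 can be looked at in isolation. Below is a minimal, framework-free sketch; the class name `PatienceStopper` and its `step` method are hypothetical and do not appear in this notebook, which keeps the same kind of counter inline in the training loop.

```python
class PatienceStopper:
    """Stop when validation loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss:
            # improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# made-up validation losses, with a short patience so the stop is easy to see
stopper = PatienceStopper(patience=3)
losses = [0.9, 0.5, 0.6, 0.7, 0.55, 0.4]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)  # 5 0.5
```

Note that the final loss of 0.4 is never seen: three epochs without improvement on 0.5 end the run first, which is the trade-off a patience value controls.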

**Findings**

In our processing, we deviated slightly from Becker's procedure. We do so here too (e.g. a different optimiser: Adam instead of SGD), but the results are still comparable. We end up with a model that reaches high validation accuracy, and training is then abandoned because my dissertation required more 'challenging' tasks. It is worth noting, however, that Becker's approach gives a model with 96% accuracy, whereas we reach over 99% validation accuracy with only 40% of the data they used.

**Notes**
* Again, training is deliberately abandoned. Here, I was seeking to create a base model to demonstrate that the novel loss function in my dissertation provides an improvement. But with close to 100% accuracy on the base model, there would be little room left to demonstrate an improvement!

In [1]:
import numpy as np
import time
import pandas as pd
import copy
import random
import gc
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

In [None]:
# random seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
generator1=torch.Generator().manual_seed(42)

## Define network and training loop

In [None]:
# create CNN with AlexNet architecture
# credit to author Nouman
# https://blog.paperspace.com/alexnet-pytorch/

class AlexNet(nn.Module):
    def __init__(self, input_channels=3,num_classes=10):
        super(AlexNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, 96, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc2= nn.Sequential(
            nn.Linear(4096, num_classes))
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
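
The `9216` in the first fully connected layer follows from the spatial size of the feature map after `layer5`. As a quick check, the sketch below applies the usual output-size formula for convolution and pooling layers, assuming the 227 x 227 input used in this notebook; the helper `out_size` is hypothetical and only does the arithmetic.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial size after a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


size = 227                                             # input height/width
size = out_size(size, kernel=11, stride=4)             # layer1 conv -> 55
size = out_size(size, kernel=3, stride=2)              # layer1 pool -> 27
size = out_size(size, kernel=5, stride=1, padding=2)   # layer2 conv -> 27
size = out_size(size, kernel=3, stride=2)              # layer2 pool -> 13
for _ in range(3):                                     # layer3-5 convs keep 13
    size = out_size(size, kernel=3, stride=1, padding=1)
size = out_size(size, kernel=3, stride=2)              # layer5 pool -> 6

flattened = 256 * size * size                          # 256 channels from layer5
print(size, flattened)  # 6 9216
```

A different input resolution would change the flattened size, so `nn.Linear(9216, 4096)` ties this network to 227 x 227 inputs.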

In [5]:
# training loop
def trainer0(n_epochs,lrs,train_dataloader,val_dataloader):
    
    # maximise available memory
    gc.collect()
    torch.cuda.empty_cache()
    
    # storage
    # data dict for storing metrics per lr and epoch
    epoch_keys=["train_losses","train_accuracies",
                "val_losses","val_accuracies","time"]
    
    data_dict={lr:
      {key:[] for key in epoch_keys} for lr in lrs}
    
    # we will store the model filenames for reference
    best_model_strings=[]

    # we train a new model for each lr value
    for i,lr in enumerate(lrs):
        # load model for each loop
        model = AlexNet(input_channels=1).to(device)
        optimizer = optim.Adam(params = model.parameters(),lr=lr)
        criterion = nn.CrossEntropyLoss()

        # initialise variables for EarlyStopping
        best_loss = float('inf')
        best_model_weights = None
        patience = 10


        print(f"training for lr = {lr}")

        for epoch in range(n_epochs):
            t0 = time.time()
            # initialise variables for recording training loss and accuracy
            running_loss = 0.0
            running_correct = 0
            running_total = 0

            model.train()
            for input,labels in train_dataloader:
                desired_labels=labels[0]

                input = input.to(device)
                desired_labels = desired_labels.to(device)


                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output,desired_labels)
                loss.backward()
                optimizer.step()

                # training loss and accuracy for this batch
                # this is scaled by the batch size and divided back later
                running_loss += loss.item() * input.size(0)

                _, predicted = torch.max(output.data, 1)
                running_total += desired_labels.size(0)
                running_correct += (predicted == desired_labels).sum().item()

            # training loss and accuracy for epoch (scaled back)
            train_loss = running_loss/len(train_dataloader.dataset)

            train_accuracy = 100 * running_correct/running_total

            # store training loss and accuracy in our dictionary
            epoch_dict=data_dict[lr]
            epoch_dict["train_losses"].append(train_loss)
            epoch_dict["train_accuracies"].append(train_accuracy)

            # initialise variables for recording val loss and accuracy
            val_running_loss = 0.0
            val_running_correct = 0
            val_running_total = 0

            model.eval()
            with torch.no_grad():
                for input,labels in val_dataloader:
                    desired_labels=labels[0]
                    input = input.to(device)
                    desired_labels = desired_labels.to(device)

                    output = model(input)

                    # validation loss and accuracy for batch
                    val_running_loss += criterion(output,desired_labels).item() * desired_labels.size(0)

                    _, predicted = torch.max(output.data, 1)
                    val_running_total += desired_labels.size(0)
                    val_running_correct += (predicted == desired_labels).sum().item()


            # validation loss and accuracy for epoch
            val_loss = val_running_loss / len(val_dataloader.dataset)

            val_accuracy = 100 * val_running_correct / val_running_total

            # store
            epoch_time=time.time()-t0
            epoch_dict["val_losses"].append(val_loss)

            epoch_dict["val_accuracies"].append(val_accuracy)
            epoch_dict["time"].append(epoch_time)

            # print for epoch
            if epoch % 5 == 0:
                print(f"epoch: {epoch + 1}, time: {epoch_time:0.2f}")
                print("training loss: ",
                      f"{train_loss:0.2f}, accuracy: {train_accuracy:.2f}")
                print("validation loss: ",
                      f"{val_loss:0.2f}, accuracy: {val_accuracy:.2f}")
                print()

            # EarlyStopping
            if val_loss < best_loss:
                best_loss = val_loss
                best_model_weights = copy.deepcopy(model.state_dict())
                patience = 10
            else:
                patience -= 1
                if patience == 0:
                    print(f"Early stop at epoch {epoch+1}\n")
                    break

        # load and store the best model weights
        model.load_state_dict(best_model_weights)
        time_stamp=time.strftime("%Y%m%d-%H%M%S")
        file_path=f"output/model_base_lr{lr}_{time_stamp}.pth"
        torch.save(model.state_dict(), file_path)
        print(f"saved model as '{file_path}'")
        print()
        best_model_strings.append(file_path)

    return data_dict, best_model_strings
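
The loop above multiplies each batch loss by the batch size and later divides by the dataset size. `nn.CrossEntropyLoss` reports the mean over a batch by default, so this weighting recovers the exact per-sample mean even when the last batch is smaller than the rest (4000 training samples do not divide evenly into batches of 64). A small sketch with made-up numbers:

```python
# Hypothetical per-sample losses for a tiny "dataset" of 5 samples, batch size 2.
sample_losses = [1.0, 3.0, 2.0, 4.0, 10.0]
batch_size = 2
batches = [sample_losses[i:i + batch_size]
           for i in range(0, len(sample_losses), batch_size)]

# Mean loss per batch: what a criterion with mean reduction would report.
batch_means = [sum(b) / len(b) for b in batches]

# Approach used in trainer0: weight each batch mean by its size.
weighted = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(sample_losses)

# Naive alternative: average the batch means directly.
naive = sum(batch_means) / len(batch_means)

true_mean = sum(sample_losses) / len(sample_losses)
print(weighted, naive, true_mean)  # 4.0 5.0 4.0
```

The naive average over-weights the single sample in the short final batch, which is why the weighted form is used.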


In [7]:
# test accuracy of a given model
# this is not used because we abandon the training
def tester(model_string,test_dataloader,target_type=None):
    gc.collect()
    torch.cuda.empty_cache()
    model = AlexNet(input_channels=1)
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(model_string, map_location=device))
    model.to(device)
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for input,target in test_dataloader:
            if target_type=="tuple":
                target=target[0]
            input,target = input.to(device),target.to(device)
            scores = model(input)
            _, predictions = scores.max(1)
            # .item() converts the count to a Python int so it prints cleanly
            num_correct += (predictions == target).sum().item()
            num_samples +=  predictions.size(0)

    accuracy = float(num_correct)/float(num_samples)*100

    print(model_string)
    print(f'Got {num_correct}/{num_samples} with accuracy {accuracy:.2f}\n')

    return accuracy

## Load data

In [None]:
# load the data
my_data=torch.load('data/AudioMNIST_processed/my_small_data.pt')

In [9]:
# we create our dataset class
# we return a second label (gender) for additional work in my dissertation
class MyDataset(Dataset):
    def __init__(self, my_list):
        self.my_list = my_list

    def __len__(self):
        return len(self.my_list[0])

    def __getitem__(self, index):
        spec=self.my_list[0][index]
        digit=self.my_list[1][index]
        gender=self.my_list[3][index]

        return spec, (digit,gender)
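
Each item is a pair `(spec, (digit, gender))`. When a `DataLoader` batches such items with its default collate function, it groups each field across the samples, so the labels arrive as a pair of batches (digits, then genders). That is why the training loop takes `labels[0]`. The sketch below is a simplified pure-Python stand-in for that grouping; `toy_collate` is hypothetical, and the real collate function stacks tensors rather than building lists.

```python
def toy_collate(samples):
    """Group each field across samples, recursing into nested tuples."""
    first = samples[0]
    if isinstance(first, tuple):
        return [toy_collate(list(field)) for field in zip(*samples)]
    return list(samples)  # the real collate would stack tensors here


# Three fake items shaped like MyDataset's output: (spec, (digit, gender)).
items = [("spec_a", (3, 0)), ("spec_b", (7, 1)), ("spec_c", (1, 0))]
specs, labels = toy_collate(items)
digits, genders = labels

print(specs)      # ['spec_a', 'spec_b', 'spec_c']
print(labels[0])  # [3, 7, 1]
```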

In [10]:
amnist_dataset=MyDataset(my_data)

# split the data into train/validation/test
amnist_train_dataset,amnist_val_dataset,amnist_test_dataset=torch.utils.data.random_split(amnist_dataset,[2/3,1/6,1/6],generator=generator1)
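
When `random_split` is given fractions, the PyTorch documentation describes the sizes as the floor of each share, with any leftover samples handed out one at a time in round-robin order. The sketch below reproduces only that arithmetic (not the shuffling); `split_lengths` is a hypothetical helper, and the total of 6000 is taken from the lengths printed in the next cell.

```python
import math


def split_lengths(n, fractions):
    """Integer split sizes: floor each share, then hand out leftovers round-robin."""
    lengths = [math.floor(n * frac) for frac in fractions]
    remainder = n - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(split_lengths(6000, [2/3, 1/6, 1/6]))  # [4000, 1000, 1000]
print(split_lengths(10, [0.5, 0.25, 0.25]))  # [6, 2, 2]
```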

In [11]:
# double check shapes and size
print(f"amnist train len {len(amnist_train_dataset)}")
print(f"amnist val len {len(amnist_val_dataset)}")
print(f"amnist test len {len(amnist_test_dataset)}")

print(f"image shape {amnist_train_dataset[0][0].shape}")
print(f'labels {amnist_train_dataset[0][1]}')

amnist train len 4000
amnist val len 1000
amnist test len 1000
image shape torch.Size([1, 227, 227])
labels (tensor(3), tensor(0))


In [12]:
# dataloaders
amnist_train_dataloader = DataLoader(amnist_train_dataset, batch_size = 64, shuffle = True,generator = generator1)
amnist_val_dataloader = DataLoader(amnist_val_dataset, batch_size = 64, shuffle = True,generator = generator1)
amnist_test_dataloader = DataLoader(amnist_test_dataset, batch_size = 64, shuffle = True,generator = generator1)

## Train

### Use GPU if possible

In [13]:
# GPU
if torch.cuda.is_available():
    device = "cuda"
    # show GPU details - this is a colab command
    !nvidia-smi
else:
    device = "cpu"

print(f"\nUsing {device}")

Tue Sep  3 10:44:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P8               9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Model overview

In [14]:
# show the model architecture and size
model0 = AlexNet(input_channels=1).to(device)
# overview of model - we enter the input size
summary(model0,(1, 227, 227),device = device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 55, 55]          11,712
       BatchNorm2d-2           [-1, 96, 55, 55]             192
              ReLU-3           [-1, 96, 55, 55]               0
         MaxPool2d-4           [-1, 96, 27, 27]               0
            Conv2d-5          [-1, 256, 27, 27]         614,656
       BatchNorm2d-6          [-1, 256, 27, 27]             512
              ReLU-7          [-1, 256, 27, 27]               0
         MaxPool2d-8          [-1, 256, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]         885,120
      BatchNorm2d-10          [-1, 384, 13, 13]             768
             ReLU-11          [-1, 384, 13, 13]               0
           Conv2d-12          [-1, 384, 13, 13]       1,327,488
      BatchNorm2d-13          [-1, 384, 13, 13]             768
             ReLU-14          [-1, 384,
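
The parameter counts in the summary can be checked by hand. A convolution has one weight per input channel, output channel and kernel position, plus one bias per output channel; a batch-norm layer has a learnable scale and shift per channel. The helpers below are hypothetical and only do this arithmetic for the layers visible in the output above.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (in_ch * out_ch * k * k) plus one bias per output channel."""
    return in_ch * out_ch * kernel * kernel + out_ch


def batchnorm2d_params(channels):
    """One learnable scale and one shift per channel."""
    return 2 * channels


print(conv2d_params(1, 96, 11))    # 11712   (Conv2d-1)
print(batchnorm2d_params(96))      # 192     (BatchNorm2d-2)
print(conv2d_params(96, 256, 5))   # 614656  (Conv2d-5)
print(conv2d_params(256, 384, 3))  # 885120  (Conv2d-9)
print(conv2d_params(384, 384, 3))  # 1327488 (Conv2d-12)
```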

### Training loop
Note: we interrupt this because we start getting very high accuracy (which we don't actually want for my dissertation!)

In [15]:
# hyperparams
# lrs and n_epochs for all
lrs=[0.0001,0.0005,0.001]
n_epochs=100

# loss (kept for reference only: trainer0 creates its own nn.CrossEntropyLoss)
base_loss_fn = nn.CrossEntropyLoss()

In [16]:
# our output from the training loop is a dictionary and list of file names
model02_dict,model02_strings=trainer0(n_epochs,lrs,amnist_train_dataloader,amnist_val_dataloader)

training for lr = 0.0001
epoch: 1, time: 6.55
training loss:  0.87, accuracy: 70.10
validation loss:  0.41, accuracy: 83.40

epoch: 6, time: 5.87
training loss:  0.05, accuracy: 98.50
validation loss:  0.34, accuracy: 91.10

epoch: 11, time: 5.79
training loss:  0.00, accuracy: 99.97
validation loss:  0.01, accuracy: 99.70

epoch: 16, time: 5.71
training loss:  0.01, accuracy: 99.75
validation loss:  0.03, accuracy: 99.20

epoch: 21, time: 5.72
training loss:  0.01, accuracy: 99.78
validation loss:  0.01, accuracy: 99.50

epoch: 26, time: 5.96
training loss:  0.00, accuracy: 99.92
validation loss:  0.02, accuracy: 99.30



KeyboardInterrupt: 
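
Because `trainer0` only returns its metrics once it finishes, interrupting the cell means the assignment never happens and the per-epoch history is lost. One possible way to keep partial results, which is not what this notebook does, is to catch the interrupt inside the loop. In the sketch below, `train_with_interrupt` and `fake_epoch` are hypothetical stand-ins, and the interrupt is simulated rather than typed.

```python
def train_with_interrupt(n_epochs, run_epoch):
    """Run epochs until done or interrupted; always return the history so far."""
    history = []
    try:
        for epoch in range(n_epochs):
            history.append(run_epoch(epoch))
    except KeyboardInterrupt:
        print(f"interrupted after {len(history)} completed epochs")
    return history


def fake_epoch(epoch):
    """Stand-in for one training epoch; simulates a manual stop during epoch index 3."""
    if epoch == 3:
        raise KeyboardInterrupt
    return {"epoch": epoch + 1, "val_loss": 1.0 / (epoch + 1)}


partial = train_with_interrupt(100, fake_epoch)
print(len(partial))  # 3
```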

In [None]:
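# saving (only works if trainer0 ran to completion; after the KeyboardInterrupt above,
# model02_dict and model02_strings were never assigned)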
time_stamp=time.strftime("%Y%m%d-%H%M%S")
torch.save(model02_dict, f'output/model02_dict_{time_stamp}.pt')
print(f"saved dictionary as 'model02_dict_{time_stamp}.pt'")

# reference for loading
print(model02_strings)