In [None]:
!pip install neurobench

import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset

from neurobench.benchmarks import Benchmark
from neurobench.datasets import MSWC
from neurobench.datasets.MSWC_IncrementalLoader import IncrementalFewShot
from tqdm import tqdm

!pip install --force-reinstall --no-deps git+https://github.com/tanviriitb/TCN-library.git@global-pool
!pip install torchsummary
!pip install -U numpy

import torch.optim as optim

In [None]:
# data in repo root dir
ROOT = "./data/"

# Directory that receives the checkpoints written during training.
directory = "./model_data"

# exist_ok=True makes the creation idempotent and race-free: an already
# existing directory is not an error, while any other OSError (for example a
# permission problem) still propagates unchanged.
os.makedirs(directory, exist_ok=True)

# DataLoader settings shared by the loaders built later in the notebook.
NUM_WORKERS = 4
BATCH_SIZE = 256

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# PIN_MEMORY is forwarded to the DataLoaders; it is enabled only when the
# selected device is the plain "cuda" device.
PIN_MEMORY = device == torch.device("cuda")
device

In [None]:
# Chooses the preprocessing front end built further down in the notebook:
# True selects S2SPreProcessor, False selects MFCCPreProcessor.
SPIKING = False

In [None]:
from neurobench.preprocessing import MFCCPreProcessor, S2SPreProcessor

# Front-end settings. n_fft, n_mels, n_mfcc and hop_length are consumed by the
# MFCC branch below; win_length is defined here but not read in this cell.
n_fft = 512
win_length = None
hop_length = 240
n_mels = 20
n_mfcc = 20

if not SPIKING:
    # MFCC front end operating on 48 kHz audio.
    encode = MFCCPreProcessor(
        sample_rate=48000,
        n_mfcc=n_mfcc,
        melkwargs=dict(
            n_fft=n_fft,
            n_mels=n_mels,
            hop_length=hop_length,
            mel_scale="htk",
            f_min=20,
            f_max=4000,
        ),
        device=device,
    )
else:
    # S2S front end; its sample rate and hop length are overridden through
    # configure() together with the threshold.
    encode = S2SPreProcessor(device, transpose=True)
    config_change = dict(sample_rate=48000, hop_length=240)
    encode.configure(threshold=1.0, **config_change)

In [None]:
# Base-class training split; its loader reshuffles the samples every epoch.
base_train_set = MSWC(root=ROOT, subset="base", procedure="training")

train_loader = DataLoader(base_train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=True)

# Base-class validation split.
base_validation_set = MSWC(root=ROOT, subset="base", procedure="validation")

# Shuffling adds nothing to evaluation. A fixed order keeps the composition of
# every batch (including the smaller final one) identical from epoch to epoch,
# so validation metrics are reproducible and comparable across epochs.
validation_loader = DataLoader(base_validation_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=False)

In [None]:
from tcn_lib import TCN
from torchsummary import summary

# NOTE(review): feature_count is not read anywhere in this cell; it presumably
# mirrors the width of the last TCN stage (128) — confirm before relying on it.
feature_count = 128

# Four-stage TCN: two stages of 64 channels followed by two of 128, each with
# a kernel size of 9. The positional arguments are kept positional because the
# parameter names of this tcn_lib fork are not visible from the notebook.
model = TCN(
    20,
    200,
    [64, 64, 128, 128],
    [9, 9, 9, 9],
    batch_norm=True,
    weight_norm=True,
    residual=True,
    bottleneck=True,
    groups=32,
    dropout=0.2,
)
model = model.to(device)

# Print a per-layer overview for an input of size (20, 200).
summary(model, (20, 200))

In [None]:
class SaveBestModel:
    """
    Save the best model seen so far while training.

    Call the instance once per epoch with that epoch's validation loss. When
    the loss is strictly lower than the lowest loss seen so far, the model's
    ``state_dict`` is written to ``save_path`` together with the 1-based
    epoch number; otherwise nothing happens.

    Args:
        best_valid_loss: Loss that must be beaten before the first checkpoint
            is written. Defaults to infinity so that the first finite loss is
            always saved, whatever the scale of the loss function.
        save_path: File the checkpoint is written to. Missing parent
            directories are created on demand.

    Note:
        A NaN loss never compares lower than the stored best, so it never
        triggers a save.
    """
    def __init__(
        self, best_valid_loss=float("inf"),
        save_path='./model_data/best_model.pth'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model
    ):
        """
        Checkpoint ``model`` if ``current_valid_loss`` is a new minimum.

        Args:
            current_valid_loss: Validation loss of the epoch just finished.
            epoch: 0-based epoch index; it is reported and stored as
                ``epoch + 1``.
            model: Module whose ``state_dict()`` is saved.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nLeast validation error: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")

            # torch.save does not create directories, so make sure the target
            # directory exists. A bare file name has an empty dirname and
            # needs no directory.
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                }, self.save_path)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Define your hyperparameters
epochs = 15
lr_max = 5e-3  # Initial maximum learning rate
lr_min = 1e-5  # Minimum learning rate
T_max = epochs  # Number of epochs for one cycle

# Create loss function, optimizer, and learning rate scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr_max)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=lr_min)
save_best_model = SaveBestModel()

for epoch in tqdm(range(epochs)):
    # Per-epoch accumulators. The *_avg_loss variables hold a running SUM of
    # per-sample losses and are turned into means after the epoch; the names
    # are kept from the original notebook.
    train_avg_loss = 0
    train_correct = 0
    train_total = 0

    validation_avg_loss = 0
    validation_correct = 0
    validation_total = 0

    # ---- training pass ----
    model.train()

    for data, target in train_loader:
        data, target = encode((data.to(device), target.to(device)))
        # NOTE(review): squeeze() without a dim drops every size-1 axis, so a
        # final batch holding a single sample would also lose its batch axis.
        # Confirm the encoder's output layout and squeeze only the intended
        # axis.
        data = data.squeeze()

        optimizer.zero_grad()

        # The model returns a pair; the second element is fed to the loss.
        features, output = model(data)

        # Calculate Loss
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch; weight it by the batch
        # size so a smaller final batch does not skew the epoch average.
        train_avg_loss += loss.item() * target.size(0)
        _, predicted = output.max(1)
        train_total += target.size(0)
        train_correct += predicted.eq(target).sum().item()

    # ---- validation pass ----
    model.eval()

    with torch.no_grad():
        for data, target in validation_loader:
            data, target = encode((data.to(device), target.to(device)))
            data = data.squeeze()

            _, output = model(data)
            loss = criterion(output, target)

            validation_avg_loss += loss.item() * target.size(0)
            _, predicted = output.max(1)
            validation_total += target.size(0)
            validation_correct += predicted.eq(target).sum().item()

    # Per-sample means over the whole epoch, and accuracies in percent.
    train_loss, train_acc = train_avg_loss / train_total, 100 * train_correct / train_total
    validation_loss, validation_acc = validation_avg_loss / validation_total, 100 * validation_correct / validation_total

    # Report 1-based epochs so the numbers agree with the ones SaveBestModel
    # prints and stores.
    print(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f} - Train Acc: {train_acc:.2f}")
    print(f"Epoch {epoch + 1} - Validation Loss: {validation_loss:.4f} - Validation Acc: {validation_acc:.2f}")

    save_best_model(
        validation_loss, epoch, model
    )

    # Step the scheduler
    lr_scheduler.step()

print('Finished Training')