To test:
* optimizers: Adam, AdamW
* learning rate schedulers: StepLR, CosineAnnealingLR
* more epochs
* positional embedding (`pos_embedding=True`)
* dropout
* more heads

In [1]:
import sys

sys.path.append("./../../src")

from Dataset import SpeechCommandsDataset
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
from Transformer import SpeechCommandTransformer, train_transformer, calculate_class_weights, plot_confusion_matrix, plot_accuracy_loss, set_seed
import torch
from torch.optim import AdamW
import torch.nn as nn
# Release unused cached GPU memory held by PyTorch's allocator before any experiment runs.
torch.cuda.empty_cache()
# Fix the seed through the project helper; presumably seeds the relevant RNGs for
# reproducibility — verify against Transformer.set_seed.
set_seed(213)

In [None]:
# Settings shared by the train and test loaders; only shuffling differs between them.
loader_options = {"batch_size": 16, "num_workers": 6}

# Training split: shuffled on every epoch.
train_dataset = SpeechCommandsDataset("../../data/train", mode="modified")
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)

# Test split: fixed order.
test_dataset = SpeechCommandsDataset("../../data/test", mode="modified")
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Class weights are computed from the training set and moved to the chosen
# device before being handed to the cross-entropy loss.
class_weights = calculate_class_weights(train_dataset)
weights_on_device = class_weights.to(device)
criterion = nn.CrossEntropyLoss(weight=weights_on_device)

In [None]:
# Baseline experiment: no optimizer is passed, so train_transformer falls back
# to whatever optimizer it uses by default.
model = SpeechCommandTransformer(
    num_classes=len(train_dataset.class_to_idx),
    embed_dim=256,
    device=device,
    stride=1,
).to(device)

history = train_transformer(
    train_loader,
    test_loader,
    model=model,
    num_epochs=20,
    device=device,
    criterion=criterion,
)
(
    train_losses,
    train_accuracies,
    test_losses,
    test_accuracies,
    train_true_labels,
    train_pred_labels,
    test_true_labels,
    test_pred_labels,
) = history

# Confusion matrices for the train split first, then the test split;
# train_dataset is passed to both calls.
for true_labels, pred_labels in (
    (train_true_labels, train_pred_labels),
    (test_true_labels, test_pred_labels),
):
    plot_confusion_matrix(true_labels, pred_labels, train_dataset, normalize=False)

# Accuracy and loss curves for both splits.
plot_accuracy_loss(train_accuracies, train_losses, test_accuracies, test_losses)

In [None]:
# Experiment: fresh model trained with an explicit AdamW optimizer; the
# scheduling argument is not passed, so train_transformer's default applies.
model = SpeechCommandTransformer(
    num_classes=len(train_dataset.class_to_idx),
    embed_dim=256,
    device=device,
    stride=1,
).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

history = train_transformer(
    train_loader,
    test_loader,
    model=model,
    num_epochs=20,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
)
(
    train_losses,
    train_accuracies,
    test_losses,
    test_accuracies,
    train_true_labels,
    train_pred_labels,
    test_true_labels,
    test_pred_labels,
) = history

In [None]:
# Confusion matrices for the train split first, then the test split;
# train_dataset is passed to both calls.
for true_labels, pred_labels in (
    (train_true_labels, train_pred_labels),
    (test_true_labels, test_pred_labels),
):
    plot_confusion_matrix(true_labels, pred_labels, train_dataset, normalize=False)

# Accuracy and loss curves for both splits.
plot_accuracy_loss(train_accuracies, train_losses, test_accuracies, test_losses)

In [None]:
# Experiment: AdamW with scheduling explicitly turned off.
model = SpeechCommandTransformer(
    num_classes=len(train_dataset.class_to_idx),
    embed_dim=256,
    device=device,
    stride=1,
).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

history = train_transformer(
    train_loader,
    test_loader,
    model=model,
    num_epochs=20,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduling=False,
)
(
    train_losses,
    train_accuracies,
    test_losses,
    test_accuracies,
    train_true_labels,
    train_pred_labels,
    test_true_labels,
    test_pred_labels,
) = history

In [None]:
# Confusion matrices for the train split first, then the test split;
# train_dataset is passed to both calls.
for true_labels, pred_labels in (
    (train_true_labels, train_pred_labels),
    (test_true_labels, test_pred_labels),
):
    plot_confusion_matrix(true_labels, pred_labels, train_dataset, normalize=False)

# Accuracy and loss curves for both splits.
plot_accuracy_loss(train_accuracies, train_losses, test_accuracies, test_losses)

In [None]:
# Experiment: AdamW with scheduling explicitly turned off.
# NOTE(review): the model, optimizer and training arguments here match the
# preceding scheduling=False experiment exactly — confirm whether a repeat run
# was intended or a different setting was meant to be tried.
model = SpeechCommandTransformer(
    num_classes=len(train_dataset.class_to_idx),
    embed_dim=256,
    device=device,
    stride=1,
).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

history = train_transformer(
    train_loader,
    test_loader,
    model=model,
    num_epochs=20,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduling=False,
)
(
    train_losses,
    train_accuracies,
    test_losses,
    test_accuracies,
    train_true_labels,
    train_pred_labels,
    test_true_labels,
    test_pred_labels,
) = history

In [None]:
# Confusion matrices for the train split first, then the test split;
# train_dataset is passed to both calls.
for true_labels, pred_labels in (
    (train_true_labels, train_pred_labels),
    (test_true_labels, test_pred_labels),
):
    plot_confusion_matrix(true_labels, pred_labels, train_dataset, normalize=False)

# Accuracy and loss curves for both splits.
plot_accuracy_loss(train_accuracies, train_losses, test_accuracies, test_losses)

In [None]:
# Experiment: same AdamW / scheduling=False setup, but the model is built with
# pos_embedding=True.
model = SpeechCommandTransformer(
    num_classes=len(train_dataset.class_to_idx),
    embed_dim=256,
    device=device,
    stride=1,
    pos_embedding=True,
).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

history = train_transformer(
    train_loader,
    test_loader,
    model=model,
    num_epochs=20,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduling=False,
)
(
    train_losses,
    train_accuracies,
    test_losses,
    test_accuracies,
    train_true_labels,
    train_pred_labels,
    test_true_labels,
    test_pred_labels,
) = history

In [None]:
# Confusion matrices for the train split first, then the test split;
# train_dataset is passed to both calls.
for true_labels, pred_labels in (
    (train_true_labels, train_pred_labels),
    (test_true_labels, test_pred_labels),
):
    plot_confusion_matrix(true_labels, pred_labels, train_dataset, normalize=False)

# Accuracy and loss curves for both splits.
plot_accuracy_loss(train_accuracies, train_losses, test_accuracies, test_losses)