# Brain Classifier Training and Evaluation

In [None]:
from datetime import datetime
import json
import os
from pathlib import Path
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import WeightedRandomSampler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, multilabel_confusion_matrix, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
import numpy as np

from models.brain_class_encoder import BrainClassifierV1, BrainClassifierV3
from dataloaders.conditional_loaders import ClassConditionalLoader, ECOGLoader
from utils.generic import get_word_from_filepath

In [None]:
%env CUDA_VISIBLE_DEVICES=0

## Data Preparation

In [6]:
data_base_dir = Path('/home/passch/data/')
data_path = data_base_dir / 'HP1_ECoG_conditional/sub-002'
splits_path = data_base_dir / 'datasplits/HP1_ECoG_conditional/sub-002'

# Open train and val files from the precomputed splits. Since these were computed for the speech files, we replace the
# .wav file endings with .npy to get the brain activity samples.
with open(splits_path / 'train.csv', 'r') as f:
    train_files = [fn.replace('.wav','.npy') for fn in f.read().split(',')]
with open(splits_path / 'val.csv', 'r') as f:
    val_files = [fn.replace('.wav','.npy') for fn in f.read().split(',')]

# Set no. of classes to 55 even if dataset may have lower actual number of classes. 
# This ensures compatibility with the class-conditional pretraining model.
n_classes = 55
print(len(train_files), len(val_files))

# Shuffle in-place
rng = np.random.default_rng()
rng.shuffle(train_files)
rng.shuffle(val_files)

171 17


In [7]:
class_label_loader = ClassConditionalLoader(data_base_dir / 'HP_VariaNTS_intersection.txt')

class Dataset(torch.utils.data.Dataset):
    def __init__(self, files) -> None:
        super().__init__()
        self.files=files
    def __getitem__(self, n:int):
        ecog = ECOGLoader.process_ecog(data_path / self.files[n])
        word_vector = class_label_loader(self.files[n]).squeeze()
        return ecog, word_vector
    def __len__(self):
        return len(self.files)
    def labels(self):
        return [get_word_from_filepath(fp) for fp in self.files]

trainset = Dataset(train_files)
valset = Dataset(val_files)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=8, num_workers=4, pin_memory=False, shuffle=False, drop_last=False)

# Make batch size == length of val files
valloader = torch.utils.data.DataLoader(
    valset, batch_size=len(val_files), num_workers=4, pin_memory=False, shuffle=False, drop_last=False)

## Training

In [115]:
LEARNING_RATE = 1e-5
N_EPOCHS = 1000

SAVE_ITER = 100

In [112]:
model = BrainClassifierV1(in_nodes=48 * 55, n_classes=n_classes).cuda()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

start_epoch = 1
model

BrainClassifierV1(
  (network): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=2640, out_features=1320, bias=True)
    (2): ReLU()
    (3): LayerNorm((1320,), eps=1e-05, elementwise_affine=True)
    (4): Dropout(p=0.4, inplace=False)
    (5): Linear(in_features=1320, out_features=660, bias=True)
    (6): ReLU()
    (7): LayerNorm((660,), eps=1e-05, elementwise_affine=True)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=660, out_features=55, bias=True)
    (10): Softmax(dim=1)
  )
)

In [113]:
ckpt_path = f'exp/Brain-Classifier-{datetime.now()}/'

os.makedirs(ckpt_path, exist_ok=False)
print(ckpt_path)

exp/Sub2_Full-Std_MLP_55classes_LR1e-05_dropout_layernorm_no-shuffle_3/


In [114]:
logs = {
    'train_loss_epoch': [], 'train_loss_batch': [], 'train_acc_epoch': [], 'train_acc_batch': [],
    'val_loss_epoch': [], 'val_loss_batch': [], 'val_acc_epoch': [], 'val_acc_batch': [],
}

In [137]:
for epoch in (pbar := tqdm(range(start_epoch, start_epoch+N_EPOCHS), desc='Training', ncols=125)):
    # TRAINING
    model.train()
    train_loss_epoch = 0.
    train_acc_epoch = 0.
    for i, data in enumerate(trainloader):
        x, y = data
        x = x.cuda()
        y = y.cuda()

        optimizer.zero_grad()
        y_pred = model(x)
        
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        
        train_loss_batch = loss.item()
        logs['train_loss_batch'].append(train_loss_batch)
        train_loss_epoch += train_loss_batch

        train_acc_batch = accuracy_score(
            torch.argmax(y, dim=1).cpu(), torch.argmax(y_pred, dim=1).cpu())
        logs['train_acc_batch'].append(train_acc_batch)
        train_acc_epoch += train_acc_batch
   
    train_loss_epoch /= len(trainloader)
    logs['train_loss_epoch'].append(train_loss_epoch)

    train_acc_epoch /= len(trainloader)
    logs['train_acc_epoch'].append(train_acc_epoch)

    # VALIDATING
    model.eval()
    val_loss_epoch = 0.
    val_acc_epoch = 0.
    for i, data in enumerate(valloader):
        x, y = data
        x = x.cuda()
        y = y.cuda()
        
        y_pred = model(x)
        
        loss = loss_fn(y_pred, y)
        
        val_loss_batch = loss.item()
        logs['val_loss_batch'].append(val_loss_batch)
        val_loss_epoch += val_loss_batch

        val_acc_batch = accuracy_score(
            torch.argmax(y, dim=1).cpu(), torch.argmax(y_pred, dim=1).cpu())
        logs['val_acc_batch'].append(val_acc_batch)
        val_acc_epoch += val_acc_batch

    val_loss_epoch /= len(valloader)
    logs['val_loss_epoch'].append(val_loss_epoch)

    val_acc_epoch /= len(valloader)
    logs['val_acc_epoch'].append(val_acc_epoch)

    # Update loss values in progress bar
    pbar.set_postfix({
        'Train Loss': f'{logs["train_loss_epoch"][-1]:.2f}', 
        'Val Loss': f'{logs["val_loss_epoch"][-1]:.2f}',
        'Train Acc': f'{logs["train_acc_epoch"][-1]:.3f}',
        'Val Acc': f'{logs["val_acc_epoch"][-1]:.3f}',
    })

    # Save model checkpoint
    if epoch % SAVE_ITER == 0:
        torch.save(
            {'model_state_dict': model.state_dict()},
            os.path.join(ckpt_path, f'{epoch}.pkl')
        )

start_epoch = epoch + 1

Training: 100%|██████████| 1200/1200 [14:03<00:00,  1.42it/s, Train Loss=3.04, Val Loss=3.78, Train Acc=1.000, Val Acc=0.235]


Save logs to disk

In [138]:
with open(os.path.join(ckpt_path, 'logs.json'), "w") as f:
    json.dump(logs, f)

## Plotting Logged Metrics

In [None]:
start_idx = 0
end_idx = -1

plt.figure(figsize=(12,4))

plt.subplot(121)
plt.title('Cross Entropy Loss')
plt.plot(logs['train_loss_epoch'][start_idx:end_idx], color='orange', label='Training')
plt.plot(logs['val_loss_epoch'][start_idx:end_idx], color='blue', label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Cross Entropy')
plt.legend(loc=3)
plt.grid()

plt.subplot(122)
plt.title('Accuracy')
plt.plot(logs['train_acc_epoch'][start_idx:end_idx], color='orange', label='Training')
plt.plot(logs['val_acc_epoch'][start_idx:end_idx], color='blue', label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc=4)
plt.grid()

plt.savefig(os.path.join(ckpt_path, 'loss_progression'))
plt.show()

## Confusion Matrices

In [6]:
def get_word_from_token(token):
    return list(class_label_loader.word_tokens.keys())\
        [list(class_label_loader.word_tokens.values()).index(token)]

In [None]:
all_preds_train = []
all_true_train = []

all_preds_val = []
all_true_val = []

model.eval()

for i, data in enumerate(trainloader):
    x, y = data
    x = x.cuda()
    
    y_pred = model(x)

    all_preds_train.extend(torch.argmax(y_pred, dim=1).cpu().numpy())
    all_true_train.extend(torch.argmax(y, dim=1).numpy())

for i, data in enumerate(valloader):
    x, y = data
    x = x.cuda()

    y_pred = model(x)
    all_preds_val.extend(torch.argmax(y_pred, dim=1).cpu().numpy())
    all_true_val.extend(torch.argmax(y, dim=1).numpy())

In [None]:
# Convert predicted class indices back to words
all_preds_train = [get_word_from_token(l) for l in all_preds_train]
all_true_train = [get_word_from_token(l) for l in all_true_train]
all_preds_val = [get_word_from_token(l) for l in all_preds_val]
all_true_val = [get_word_from_token(l) for l in all_true_val]

print('Train acc.:', accuracy_score(all_true_train, all_preds_train))
print('Val. acc. :', accuracy_score(all_true_val, all_preds_val))

In [59]:
# Export predictions
save_dict = {
    'train_preds': all_preds_train,
    'train_true': all_true_train,
    'val_preds': all_preds_val,
    'val_true': all_true_val,
}
with open(os.path.join(ckpt_path, 'predictions_e{epoch}.json'), 'w', encoding='utf-8') as f:
    json.dump(save_dict, f, ensure_ascii=False, indent=2)

In [None]:
plt.rcParams.update({
    'axes.labelsize': 18, 
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

In [None]:
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.set_title('Training', fontsize=24)
ConfusionMatrixDisplay.from_predictions(
    all_true_train, all_preds_train, ax=ax, colorbar=False, xticks_rotation=45)
plt.grid()
plt.tight_layout()
plt.savefig(os.path.join(ckpt_path, f'conf_mat_train_e{epoch}'))
plt.show()

In [None]:
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.set_title('Validation', fontsize=24)
ConfusionMatrixDisplay.from_predictions(
    all_true_val, all_preds_val, ax=ax, colorbar=False, xticks_rotation=45)
plt.grid()
plt.tight_layout()
plt.savefig(os.path.join(ckpt_path, f'conf_mat_val_e{epoch}'))
plt.show()