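# Benchmark Simulated Data Using an SNR Cutoff and Holdout Validation

For several minimum-SNR thresholds, this notebook trains (or loads) an autoencoder ensemble, embeds held-out simulated sessions, clusters them with a Gaussian mixture model (using either the true number of units or a BIC-selected number), and reports accuracy, precision and recall.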

### Imports

In [2]:
## Mount Google Drive Data (If using Google Colaboratory)
# Catch Exception rather than using a bare `except:`, so KeyboardInterrupt/SystemExit
# still propagate. Print the error, so the reason for a failed mount stays visible.
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
except Exception as err:  # e.g. ImportError when google.colab is not installed
    print("Mounting Failed.", err)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
## External Libraries
import sys
import os
import torch
import torch.nn as nn
import end2end
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from autoencode import AEEnsemble
from datasets import UnsupervisedDataset, SupervisedDataset, BenchmarkDataset
from sklearn.mixture import GaussianMixture
!pip install graspologic
from graspologic.utils import remap_labels
import seaborn as sns
from sklearn.manifold import TSNE
# Fix random seeds so the stochastic parts of the analysis are repeatable.
# GaussianMixture falls back to numpy's global RNG when random_state is None.
# t-SNE is seeded explicitly. Torch seeding covers autoencoder initialisation,
# to the extent the torch ops used are deterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
tsne = TSNE(n_components=2, random_state=SEED)
sns.set_theme()
sns.set_context("paper")

# General Data Directory ##TODO: Please fill in the appropriate directory
os.chdir("/content/")
data_dir = "./gdrive/MyDrive/pedreira"
results_dir = "./gdrive/MyDrive/results"

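### Define Datasets and Autoencoder Ensemble

The test set holds out the first session for each unit count from 2 to 20. All remaining sessions are used for training.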

In [3]:
sup_data = BenchmarkDataset(data_dir)

# Hold-out split: the first session with each unit count 2..20 goes to the
# test set. Every other session index goes to the training set.
session_unit_counts = list(sup_data.num_units)
test_idx = [session_unit_counts.index(n_units) for n_units in range(2, 21)]
all_idx = np.arange(len(sup_data))
train_idx = all_idx[~np.isin(all_idx, test_idx)]
train_data = torch.utils.data.Subset(sup_data, train_idx)
test_data = torch.utils.data.Subset(sup_data, test_idx)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

ae = AEEnsemble(
    optim=torch.optim.Adam,
    convolutional_encoding=False,
    batch_size=32,
    epochs=50,
    lr=(0.001, 0.001, 0.001),
    device=device,
    activ=nn.ReLU,
)

Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0


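### Define Functions

These are the evaluation helpers:
- `gt_gmm` clusters each session with a GMM that uses the true number of units.
- `auto_gmm` chooses the number of GMM components (up to 21) by BIC.
- `viz_stats` and `viz_tsne` plot the results.
- `embed_test` encodes the held-out sessions after dropping units below the minimum SNR.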

In [4]:
def gt_gmm(latent_vecs, test_targets, random_state=None):
    """Cluster each test session with a GMM sized to the true number of units.

    For every session a ``GaussianMixture`` with one component per distinct
    ground-truth label is fit to the latent vectors. The predicted clusters are
    aligned to the ground-truth labels with ``remap_labels``. Then per-session
    accuracy and unit-averaged precision/recall are computed.

    Args:
        latent_vecs: sequence of per-session latent matrices (one row per spike).
        test_targets: sequence of per-session ground-truth label arrays, aligned
            row-for-row with ``latent_vecs``. They are compared elementwise, so
            numpy arrays are expected.
        random_state: forwarded to ``GaussianMixture``. ``None`` keeps sklearn's
            default behaviour.

    Returns:
        avg_stats: ``(avg_acc, avg_prec, avg_recall)``. Accuracy is weighted by
            session size. Precision and recall are plain means over sessions.
        session_stats: ``(test_acc, test_prec, test_recall)``, one entry per session.
        remapped_preds: list of per-session predicted labels after remapping.
    """
    test_acc, test_prec, test_recall = [], [], []
    session_sizes = []
    remapped_preds = []
    for latent, targets in zip(latent_vecs, test_targets):
        session_sizes.append(len(targets))
        units = set(targets)
        gmm = GaussianMixture(n_components=len(units), random_state=random_state)
        pred = gmm.fit_predict(latent)
        remapped_pred = remap_labels(targets, pred)
        remapped_preds.append(remapped_pred)

        prec, recall = [], []
        for unit in units:
            TP = np.sum(np.logical_and(remapped_pred == unit, targets == unit))
            FP = np.sum(np.logical_and(remapped_pred == unit, targets != unit))
            FN = np.sum(np.logical_and(remapped_pred != unit, targets == unit))
            # FP == 0 / FN == 0 score as 1; this also covers the undefined 0/0 case.
            prec.append(1 if FP == 0 else TP / (TP + FP))
            recall.append(1 if FN == 0 else TP / (TP + FN))

        test_prec.append(np.mean(prec))
        test_recall.append(np.mean(recall))
        test_acc.append(sum(remapped_pred == targets) / len(targets))

    session_weights = np.array(session_sizes) / sum(session_sizes)
    avg_acc = np.sum(test_acc * session_weights)
    avg_prec = np.mean(test_prec)  # was np.mean(test_acc): precision must average test_prec
    avg_recall = np.mean(test_recall)
    avg_stats = avg_acc, avg_prec, avg_recall
    session_stats = test_acc, test_prec, test_recall
    return avg_stats, session_stats, remapped_preds

def auto_gmm(latent_vecs, test_targets, max_n_comps=21, random_state=None):
    """Cluster each test session with a GMM whose size is chosen by BIC.

    For every session, GMMs with 1..``max_n_comps`` components are fit and the
    lowest-BIC fit supplies the prediction. The range is capped at the number
    of latent vectors, because GaussianMixture rejects more components than
    samples. Predictions are aligned to the ground truth with
    ``remap_labels``. Per-session accuracy and unit-averaged precision/recall
    are then computed. The predicted and true unit counts are printed.

    Args:
        latent_vecs: sequence of per-session latent matrices (one row per spike).
        test_targets: sequence of per-session ground-truth label arrays, aligned
            row-for-row with ``latent_vecs``. They are compared elementwise, so
            numpy arrays are expected.
        max_n_comps: largest number of mixture components to try.
        random_state: forwarded to ``GaussianMixture``. ``None`` keeps sklearn's
            default behaviour.

    Returns:
        avg_stats: ``(avg_acc, avg_prec, avg_recall)``. Accuracy is weighted by
            session size. Precision and recall are plain means over sessions.
        session_stats: ``(test_acc, test_prec, test_recall)``, one entry per session.
        remapped_preds: list of per-session predicted labels after remapping.
    """
    test_acc, test_prec, test_recall = [], [], []
    session_sizes = []
    remapped_preds = []
    for latent, targets in zip(latent_vecs, test_targets):
        session_sizes.append(len(targets))
        units = set(targets)
        num_units = len(units)

        bics = []
        preds = []
        for n_comps in range(1, min(max_n_comps, len(latent)) + 1):
            gmm = GaussianMixture(n_components=n_comps, random_state=random_state)
            preds.append(gmm.fit_predict(latent))
            bics.append(gmm.bic(latent))
        best = int(np.argmin(bics))  # candidate k = best + 1
        pred = preds[best]
        print("predicted num_units=", best + 1)
        print("true num_units=", num_units)

        remapped_pred = remap_labels(targets, pred)
        remapped_preds.append(remapped_pred)

        prec, recall = [], []
        for unit in units:
            TP = np.sum(np.logical_and(remapped_pred == unit, targets == unit))
            FP = np.sum(np.logical_and(remapped_pred == unit, targets != unit))
            FN = np.sum(np.logical_and(remapped_pred != unit, targets == unit))
            # FP == 0 / FN == 0 score as 1; this also covers the undefined 0/0 case.
            prec.append(1 if FP == 0 else TP / (TP + FP))
            recall.append(1 if FN == 0 else TP / (TP + FN))

        test_prec.append(np.mean(prec))
        test_recall.append(np.mean(recall))
        test_acc.append(sum(remapped_pred == targets) / len(targets))

    session_weights = np.array(session_sizes) / sum(session_sizes)
    avg_acc = np.sum(test_acc * session_weights)
    avg_prec = np.mean(test_prec)  # was np.mean(test_acc): precision must average test_prec
    avg_recall = np.mean(test_recall)
    avg_stats = avg_acc, avg_prec, avg_recall
    session_stats = test_acc, test_prec, test_recall
    return avg_stats, session_stats, remapped_preds

def viz_stats(avg_stats, session_stats, _title, figname):
    """Print averaged metrics, plot the per-session metrics and save the figure.

    Args:
        avg_stats: ``(avg_acc, avg_prec, avg_recall)`` as returned by gt_gmm/auto_gmm.
        session_stats: ``(test_acc, test_prec, test_recall)`` per-session sequences.
        _title: figure title.
        figname: file name, saved under the notebook-level ``results_dir``.
    """
    avg_acc, avg_prec, avg_recall = avg_stats
    test_acc, test_prec, test_recall = session_stats
    print("Average Accuracy=", avg_acc)
    print("Average Precision=", avg_prec)
    print("Average Recall=", avg_recall)

    # Assumes sessions are ordered by true unit count starting at 2, as in the
    # hold-out split built from range(2, 21). The x-axis follows the session
    # count instead of hard-coding 2..20.
    n_units = np.arange(2, 2 + len(test_acc))
    fig, ax = plt.subplots()
    ax.plot(n_units, test_acc, label="Accuracy")
    ax.plot(n_units, test_prec, label="Precision")
    ax.plot(n_units, test_recall, label="Recall")
    ax.legend()
    ax.set_xlabel("True Number of Units (before SNR)")
    ax.set_ylabel("Performance")
    ax.set_title(_title)
    fig.savefig(results_dir+"/"+figname)

def viz_tsne(unit_nums, latent_vecs, test_targets, remapped_preds, fignames):
    """Save side-by-side t-SNE plots of selected sessions (true vs. predicted labels).

    Each entry ``n`` of ``unit_nums`` selects the session at index ``n - 2`` of
    the per-session lists, because the first session has 2 units. Its latent
    vectors are embedded with the notebook-level ``tsne`` estimator. One figure
    per entry is saved to ``results_dir + "/" + figname``.
    """
    for n_units, figname in zip(unit_nums, fignames):
        session = n_units - 2  # per-session lists start at the 2-unit session
        true_labels = test_targets[session]
        pred_labels = remapped_preds[session]
        embedded = tsne.fit_transform(latent_vecs[session].cpu())

        fig, (ax_true, ax_pred) = plt.subplots(ncols=2, sharex=True, sharey=True)
        ax_true.set_title("Ground Truth Labels")
        ax_true.set_xticks([])
        ax_true.set_yticks([])
        ax_pred.set_title("Predicted Labels")
        # One scatter per label id 0..max, even when a label has no points.
        # Presumably this keeps both panels stepping through the colour cycle together.
        for ax, labels in ((ax_true, true_labels), (ax_pred, pred_labels)):
            for label in range(np.max(labels) + 1):
                points = embedded[labels == label]
                ax.scatter(points[:, 0], points[:, 1], marker=".", s=.5)
        plt.savefig(results_dir + "/" + figname)

def embed_test(test_data, ae, min_snr=None):
    """Encode each test session with every encoder of ``ae``, keeping only high-SNR units.

    Each item of ``test_data`` is unpacked as ``(spikes, targets, snrs, num_units)``.
    Unit ``u`` is kept when ``snrs[u] >= min_snr``. Spikes whose target is not a
    kept unit are dropped. The outputs of all encoders in ``ae.encoders`` are
    concatenated along dim 1.

    Args:
        test_data: iterable of ``(spikes, targets, snrs, num_units)`` tuples.
            ``targets`` and ``snrs`` are used as numpy arrays.
        ae: object exposing ``encoders`` (callables on a float tensor) and a
            ``device`` string such as ``"cpu"`` or ``"cuda:0"``.
        min_snr: minimum SNR a unit needs to be kept. If omitted, the
            notebook-global ``min_snr`` is used, which keeps old calls working.

    Returns:
        ``(latent_vecs, test_targets)``: per-session latent tensors on the CPU
        and the matching filtered target arrays.

    Raises:
        NameError: if ``min_snr`` is neither passed nor defined globally.
    """
    if min_snr is None:
        # Backward compatibility: callers that omit the threshold pick up the
        # notebook-level `min_snr` (e.g. the evaluation loop variable).
        if "min_snr" not in globals():
            raise NameError("name 'min_snr' is not defined; pass min_snr explicitly")
        min_snr = globals()["min_snr"]

    latent_vecs = []
    test_targets = []
    # Inference only: don't build autograd graphs for whole sessions.
    with torch.no_grad():
        for spikes, targets, snrs, _num_units in test_data:
            kept_units = np.arange(len(snrs))[snrs >= min_snr]
            keep = np.isin(targets, kept_units)
            session_spikes = torch.as_tensor(spikes[keep], dtype=torch.float32).to(ae.device)
            session_latent = torch.cat([encoder(session_spikes) for encoder in ae.encoders], dim=1)
            latent_vecs.append(session_latent.detach().cpu())
            test_targets.append(targets[keep])
    return latent_vecs, test_targets

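### Train and Evaluate for Different Minimum SNR Values

For each minimum SNR threshold:
- Train a new ensemble unless `ae_is_trained` marks it as already trained.
- Load the trained ensemble.
- Embed the held-out sessions and evaluate both the ground-truth GMM and the BIC-selected GMM.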

In [8]:
min_snrs = [0, 8, 16, 24]
ae_is_trained = [True, True, False, False] #whether the autoencoder is already trained on the corresponding snr
TSNE_UNIT_NUMS = [2, 10, 20]  # sessions (by true unit count) shown in t-SNE plots

#Train Loop: only thresholds not yet trained are benchmarked. ae.benchmark
# presumably saves the "benchmark_snr_<snr>" checkpoint loaded below (TODO confirm).
for min_snr, is_trained in zip(min_snrs, ae_is_trained):
    print("min_snr=", min_snr)
    if is_trained:
        continue
    ae = AEEnsemble(
        optim=torch.optim.Adam,
        convolutional_encoding=False,
        batch_size=32,
        epochs=50,
        lr=(0.001, 0.001, 0.001),
        device=device,
        activ=nn.ReLU
    )
    ae.benchmark(min_snr, train_data, test_data, on_drive=True)

#Eval Loop
# Each entry: (plot-title prefix, file-name prefix, clustering function).
methods = [
    ("Ground-Truth GMM", "gtgmm", gt_gmm),
    ("Auto GMM", "autogmm", auto_gmm),
]
snr_acc, snr_prec, snr_recall = [], [], []
for min_snr in min_snrs:
    print("min_snr=", min_snr)
    ae.load(prefix="benchmark_snr_%s" % min_snr, on_drive=True)
    # embed_test is called without an explicit threshold, so it filters units
    # using this loop's `min_snr` from the notebook namespace.
    latent_vecs, test_targets = embed_test(test_data, ae)

    accs, precs, recalls = [], [], []
    for title, tag, cluster_fn in methods:
        avg_stats, session_stats, remapped_preds = cluster_fn(latent_vecs, test_targets)
        acc, prec, recall = avg_stats
        accs.append(acc)
        precs.append(prec)
        recalls.append(recall)
        viz_stats(avg_stats, session_stats, "%s: minimum SNR=%s" % (title, min_snr), "%s_stats_snr%s" % (tag, min_snr))
        fignames = ["%s_tsne_snr%s_numunits%s" % (tag, min_snr, n) for n in TSNE_UNIT_NUMS]
        viz_tsne(TSNE_UNIT_NUMS, latent_vecs, test_targets, remapped_preds, fignames)

    snr_acc.append(accs)
    snr_prec.append(precs)
    snr_recall.append(recalls)

snr_acc = np.array(snr_acc)
np.save(results_dir+"/snr_acc.npy", snr_acc)
snr_prec = np.array(snr_prec)
np.save(results_dir+"/snr_prec.npy", snr_prec)
snr_recall = np.array(snr_recall)
np.save(results_dir+"/snr_recall.npy", snr_recall)

# Column 0 = ground-truth GMM (solid), column 1 = auto GMM (dashed).
fig, ax = plt.subplots()
for stats, metric, color in ((snr_acc, "Accuracy", "C0"),
                             (snr_prec, "Precision", "C1"),
                             (snr_recall, "Recall", "C2")):
    ax.plot(min_snrs, stats[:, 0], c=color, ls="-", label="Average %s (GT GMM)" % metric)
    ax.plot(min_snrs, stats[:, 1], c=color, ls="--", label="Average %s (Auto GMM)" % metric)
ax.legend()
ax.set_xlabel("Minimum SNR")
ax.set_ylabel("Performance")
ax.set_title("V0")
fig.savefig(results_dir+"/v0_snr_stats")

min_snr= 0
min_snr= 8
min_snr= 16
Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0
Using cuda:0

EPOCH 1 of 50


KeyboardInterrupt: ignored