In [1]:
from featureextraction import generate_embeddings
import h5py
import matplotlib.pyplot as plt
from models import get_baseline_convolutional_encoder, build_siamese_net
import numpy as np
import os
import random
from scipy.signal import decimate
from sklearn.manifold import TSNE
import tensorflow as tf
import yaml

In [None]:
# define species set
# Each line of cv_folds.txt is split on commas into one fold. The file is read
# inside a context manager so the handle is closed as soon as parsing finishes.
with open('cv_folds.txt') as folds_file:
    dataset_folds = [line.rstrip().split(',') for line in folds_file]
# Only fold 0 is used: its first seven entries are the training species and the
# remaining entries are the test species.
train_set = dataset_folds[0][0:7]
test_set = dataset_folds[0][7:]

# load parameters from config_file
config_file = 'siam13_4c_ew32_25.yaml'
models_path = 'models'
model_file = os.path.join(models_path, config_file)
with open(model_file) as cfile:
    config_params = yaml.safe_load(cfile)
eeg_epoch_width_in_s = config_params['epoch_width']
num_classes = config_params['num_classes']
# Downsampling factor for scipy.signal.decimate; the embedding cell decimates
# the raw epochs by this same value.
decimate_factor = 4

In [None]:
# read EEG epochs - randomly sample num_samples from each class of each species
num_samples = 128
epochs_path = 'data/epochs_{}c'.format(num_classes)
file_template = '{}_BL5_ew{}.h5'


def _sample_epochs(species_list):
    """Randomly sample wake and sleep epochs for every species in species_list.

    For each species the HDF5 file named by file_template is opened and
    num_samples entries are drawn without replacement from file['eeg']['wake']
    and from file['eeg']['sleep']. Samples are routed by substring match on the
    species name: 'Sham' first, otherwise 'TBI'. Species matching neither are
    opened but contribute nothing.

    Args:
        species_list: iterable of species name strings.

    Returns:
        Tuple of four lists: (sham_wake, sham_sleep, tbi_wake, tbi_sleep).

    Raises:
        ValueError: from random.sample when a dataset holds fewer than
            num_samples epochs.
    """
    sham_wake, sham_sleep, tbi_wake, tbi_sleep = [], [], [], []
    for species in species_list:
        print('Working on', species)
        file_path = os.path.join(epochs_path, file_template.format(species, eeg_epoch_width_in_s))
        with h5py.File(file_path, 'r') as h5_file:
            if 'Sham' in species:
                wake_out, sleep_out = sham_wake, sham_sleep
            elif 'TBI' in species:
                wake_out, sleep_out = tbi_wake, tbi_sleep
            else:
                continue
            # Wake is sampled before sleep so the sequence of random draws is
            # the same as when the loops were written out by hand.
            wake_out.extend(random.sample(list(h5_file['eeg']['wake']), num_samples))
            sleep_out.extend(random.sample(list(h5_file['eeg']['sleep']), num_samples))
    return sham_wake, sham_sleep, tbi_wake, tbi_sleep


# Name key: s/t = Sham/TBI, w/s = wake/sleep, trailing "_s" = sampled raw epochs.
# NOTE(review): no random seed is set, so the sampled epochs differ between runs.
train_sw_s_epochs, train_ss_s_epochs, train_tw_s_epochs, train_ts_s_epochs = _sample_epochs(train_set)
test_sw_s_epochs, test_ss_s_epochs, test_tw_s_epochs, test_ts_s_epochs = _sample_epochs(test_set)

In [None]:
# generate embeddings for EEG epochs
def _embed(raw_epochs):
    """Decimate raw epochs by decimate_factor and embed them with model_file.

    Uses the notebook-level decimate_factor instead of a repeated literal so
    the downsampling rate is controlled from a single place.

    Args:
        raw_epochs: list of sampled epochs; decimate is applied along its
            default (last) axis.

    Returns:
        Whatever generate_embeddings returns for the decimated epochs.
    """
    return generate_embeddings(decimate(raw_epochs, decimate_factor), model_file)


train_sw_epochs = _embed(train_sw_s_epochs)
train_ss_epochs = _embed(train_ss_s_epochs)
train_tw_epochs = _embed(train_tw_s_epochs)
train_ts_epochs = _embed(train_ts_s_epochs)
test_sw_epochs = _embed(test_sw_s_epochs)
test_ss_epochs = _embed(test_ss_s_epochs)
test_tw_epochs = _embed(test_tw_s_epochs)
test_ts_epochs = _embed(test_ts_s_epochs)

In [None]:
# reduce dimensionality using t-SNE
# NOTE(review): newer scikit-learn releases rename n_iter to max_iter - confirm
# against the installed version.
tsne = TSNE(perplexity=30, early_exaggeration=12.0, learning_rate=600.0, method='exact', init='pca', verbose=1,
            n_iter=40000)


def _project_jointly(groups):
    """Embed several groups of vectors into one shared 2-D t-SNE map.

    t-SNE has no transform step and every fit_transform call yields its own
    arbitrary coordinate system, so groups that are drawn on the same axes
    must be embedded in a single fit. The groups are stacked, projected once,
    and split back into per-group arrays in the original order.

    Args:
        groups: sequence of array-likes, one per group.
            Assumes each is 2-D (n_epochs, embedding_dim) - TODO confirm
            against generate_embeddings.

    Returns:
        List of arrays, one per input group, each with the rows of the joint
        projection that belong to that group.
    """
    arrays = [np.asarray(group) for group in groups]
    # Row offsets at which the stacked projection is cut back into groups.
    boundaries = np.cumsum([len(array) for array in arrays])[:-1]
    projected = tsne.fit_transform(np.concatenate(arrays, axis=0))
    return np.split(projected, boundaries)


# Train and test groups are plotted in separate figures, so each split gets its
# own joint fit. With method='exact' a joint fit is slower than four small fits.
# NOTE: this cell overwrites the embedding variables with their 2-D projections;
# re-run the embedding cell before re-running this one.
train_sw_epochs, train_ss_epochs, train_tw_epochs, train_ts_epochs = _project_jointly(
    [train_sw_epochs, train_ss_epochs, train_tw_epochs, train_ts_epochs])
test_sw_epochs, test_ss_epochs, test_tw_epochs, test_ts_epochs = _project_jointly(
    [test_sw_epochs, test_ss_epochs, test_tw_epochs, test_ts_epochs])

In [None]:
# plot embeddings
# Create the figure and its axes in a single call so the requested size applies
# to the figure that is actually drawn and saved. A separate plt.figure() call
# would only leave an extra, empty figure behind.
fig, ax = plt.subplots(figsize=(24.0, 24.0))
ax.scatter(train_sw_epochs[:, 0], train_sw_epochs[:, 1], label='train_sw')
ax.scatter(train_ss_epochs[:, 0], train_ss_epochs[:, 1], label='train_ss')
ax.scatter(train_tw_epochs[:, 0], train_tw_epochs[:, 1], label='train_tw')
ax.scatter(train_ts_epochs[:, 0], train_ts_epochs[:, 1], label='train_ts')
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
ax.set_title('t-SNE of training-set embeddings')
ax.legend()
ax.grid(True)
# Save through the figure handle so the file always contains this figure.
fig.savefig('{}_train.png'.format(config_file))
plt.show()

In [None]:
# plot embeddings
# Create the figure and its axes in a single call so the requested size applies
# to the figure that is actually drawn and saved. A separate plt.figure() call
# would only leave an extra, empty figure behind.
fig, ax = plt.subplots(figsize=(24.0, 24.0))
ax.scatter(test_sw_epochs[:, 0], test_sw_epochs[:, 1], label='test_sw')
ax.scatter(test_ss_epochs[:, 0], test_ss_epochs[:, 1], label='test_ss')
ax.scatter(test_tw_epochs[:, 0], test_tw_epochs[:, 1], label='test_tw')
ax.scatter(test_ts_epochs[:, 0], test_ts_epochs[:, 1], label='test_ts')
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
ax.set_title('t-SNE of test-set embeddings')
ax.legend()
ax.grid(True)
# Save through the figure handle so the file always contains this figure.
fig.savefig('{}_test.png'.format(config_file))
plt.show()