# Decoding Subjects from Oscillation Data

...

In [1]:
%matplotlib inline

import random
import numpy as np
import matplotlib.pyplot as plt

import itertools

from om.meg.single import MegSubj
from om.meg.decoding import *

from om.core.db import OMDB
from om.core.osc import Osc
from om.core.io import load_meg_list

In [2]:
# Get database object, set which database to use, and check available files
db = OMDB()
dat_source = 'HCP'
sub_nums, source = db.check_dat_files('fooof', dat_source, verbose=True)

# Drop the outlier subject. The result is sorted, because a bare set difference
#   leaves the subjects in arbitrary hash order, which would then carry over into
#   the order in which subjects are grouped and loaded later in the notebook.
outlier_subjs = {662551}
sub_nums = sorted(set(sub_nums) - outlier_subjs)

# Set up oscillation object
osc = Osc(default=True)


Number of Subjects available: 81

Subject numbers with FOOOF data available: 
[100307, 102816, 105923, 106521, 109123, 111514, 112920, 113922, 116524, 116726, 140117, 146129, 153732, 154532, 156334, 158136, 162026, 162935, 164636, 166438, 172029, 174841, 175237, 175540, 181232, 185442, 187547, 189349, 191033, 191437, 191841, 192641, 195041, 198653, 204521, 205119, 212318, 212823, 214524, 221319, 223929, 233326, 248339, 250427, 255639, 257845, 283543, 287248, 293748, 352132, 352738, 353740, 358144, 406836, 433839, 512835, 555348, 559053, 568963, 581450, 599671, 601127, 660951, 662551, 665254, 667056, 679770, 706040, 707749, 715950, 725751, 735148, 783462, 814649, 825048, 877168, 891667, 898176, 912447, 917255, 990366]



## KNN Classification

The following KNN classification is trained with oscillations from a group of subjects. The test question is: given a single oscillation from a hold-out test set, can we decode which subject that oscillation comes from?

Features:
- This uses 3 features: center frequency, power and bandwidth, from all oscillations (not band specific)

Note:
- This approach uses data from within a single run, for each subject.
- Given this, the analysis is basically asking: are individual subjects' oscillations idiosyncratic enough that, given a new oscillation from the same run, we can guess which subject it comes from?
- This isn't really how we want to be decoding, but pulling data from separate epochs requires a significant amount of re-organizing and re-computing of data, which is not ready yet.

In [3]:
# Draw a random sample of subject groups to run the decoding analysis on
group_size = 4
n_run = 2

# Seed the sampler so that the same groups are drawn on every run of the notebook
random.seed(0)

# Number of possible groups, C(n_subjects, group_size), computed without enumerating
#   them (each intermediate value is itself a binomial coefficient, so `//` is exact)
n_combs = 1
for k in range(group_size):
    n_combs = n_combs * (len(sub_nums) - k) // (k + 1)
if n_run > n_combs:
    raise ValueError('Requested more groups than there are possible combinations.')

# Sample distinct groups directly, rather than materializing every combination
#   (80 subjects in groups of 4 would be ~1.6 million tuples held in memory).
#   Indices are sorted so each group keeps the subject order of `sub_nums`.
sampled = set()
while len(sampled) < n_run:
    inds = sorted(random.sample(range(len(sub_nums)), group_size))
    sampled.add(tuple(sub_nums[ind] for ind in inds))
comb_run = sorted(sampled)

print(len(comb_run))

2


In [4]:
# Run repeated KNN decoding on each sampled group of subjects
all_results = []
for subj_group in comb_run:

    # Load the subjects in the current group
    subjs = load_meg_list(subj_group, all_oscs=True, db=db, dat_source=dat_source)

    # Average the KNN result across 10 repeated runs for this group
    group_scores = [knn(subjs) for _ in range(10)]
    all_results.append(np.mean(group_scores))

# Chance level for comparison, and the average result across all groups
chance = 1 / group_size
grand_avg = np.mean(all_results)

In [5]:
# Report the average result against chance level, both as percentages
print('Overall performance is {:4.2f}%, with chance performance of {:4.2f}%'.format(
    grand_avg * 100, chance * 100))

Overall performance is 53.40%, with chance performance of 25.00%


## TESTS

In [12]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score

In [23]:
# Load all subjects in `sub_nums`; to try a smaller subset instead, set `cur_subs` to a
#   list of subject numbers, e.g. [706040, 707749, 106521, 662551, 665254, 667056, 679770]
cur_subs = sub_nums
subjs = load_meg_list(cur_subs, all_oscs=True, db=db, dat_source=dat_source)

# Arrange the loaded subjects into data and labels, as used by the classifier below
dat, labels = arrange_dat(subjs)

In [24]:
# Score a 10-nearest-neighbour classifier on the arranged data, using cv=20
neigh = KNeighborsClassifier(n_neighbors=10)
score = cross_val_score(estimator=neigh, X=dat, y=labels, cv=20)

In [25]:
# Report the mean cross-validation score against chance (one over the number of subjects)
print('Overall performance is {:4.2f}%, with chance performance of {:4.2f}%'.format(
    100 * np.mean(score),
    1 / len(subjs) * 100))

Overall performance is 9.08%, with chance performance of 1.25%


In [22]:
# Repeat the KNN decoding 20 times on the loaded subjects, and print the average result
res = [knn(subjs) for _ in range(20)]

print(np.mean(res))

0.204


## TESTING

In [None]:
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Take the first two loaded subjects to compare against each other
meg_subj_1 = subjs[0]
meg_subj_2 = subjs[1]

In [None]:
# Indices of the oscillations to plot from each subject (the first four of each)
inds_1 = list(range(4))
inds_2 = list(range(4))

In [None]:
# 3D scatter of the selected oscillations (center frequency x power x bandwidth),
#   with one colour per subject
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Successive scatter calls on the same axes overlay by default, so no `plt.hold()`
#   calls are needed. `plt.hold` was removed in matplotlib 3.0, and in older versions
#   a bare call toggled hold off, so the second scatter would replace the first.
ax.scatter(meg_subj_1.centers_all[inds_1], meg_subj_1.powers_all[inds_1],
           meg_subj_1.bws_all[inds_1], label='Subject 1')
ax.scatter(meg_subj_2.centers_all[inds_2], meg_subj_2.powers_all[inds_2],
           meg_subj_2.bws_all[inds_2], color='r', label='Subject 2')

# Label the figure so it can be read on its own
ax.set_xlabel('Center Frequency')
ax.set_ylabel('Power')
ax.set_zlabel('Bandwidth')
ax.legend();

In [None]:
# 2D scatter of center frequency against bandwidth for the selected oscillations.
#   Both calls draw onto the same axes by default; the former `plt.hold()` call is
#   dropped, as it was removed in matplotlib 3.0 and, before that, toggled hold off.
plt.scatter(meg_subj_1.centers_all[inds_1], meg_subj_1.bws_all[inds_1],
            label='Subject 1')
plt.scatter(meg_subj_2.centers_all[inds_2], meg_subj_2.bws_all[inds_2],
            color='r', label='Subject 2')

# Label the figure so it can be read on its own
plt.xlabel('Center Frequency')
plt.ylabel('Bandwidth')
plt.legend();