In [167]:
import os
import itertools
import time
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import OneClassSVM
from sklearn.ensemble import IsolationForest
from sklearn.metrics import f1_score
from imblearn.over_sampling import SMOTE
from collections import Counter

In [27]:
DATA_DIR = '/u/13/italinv1/unix/PycharmProjects/veera-thesis/aparc_data/'
BATCH_SIZE = 64
TEST_SPLIT = 0.2
RANDOM_SEED = 123
AGG = 'max'

cases = ['%03d' % n for n in range(28)]

freq_bands = [
    (1, 4),
    (4, 8),
    (8, 10),
    (10, 13),
    (13, 30),
    (30, 40)
]

In [3]:
def download_from_csv(filename):
    # Helper for creating a matrix out of csv data
    df = pd.read_csv(filename, header=None)
    matrix = df.values

    return matrix

In [56]:
def preprocess(data, normalize=True, decim_freqs=None):
    if data.shape[-1] % 2 != 0:
        data = np.insert(data, -1, data[:,:,-1], axis=2)
        
    if decim_freqs and type(decim_freqs) == int:
        data = data[:, :, ::decim_freqs]
    
    elif type(decim_freqs) == str and decim_freqs == 'bands':
        binned_data = []
        for (lo, hi) in freq_bands:
            lo_ind, hi_ind = lo*8, hi*8
            binned_data.append(np.max(data[:,:,lo_ind:hi_ind], axis=-1))
        data = np.transpose(np.asarray(binned_data), (1, 2, 0))
    
    data = data.reshape(data.shape[0], -1)

    if normalize:
        x_scaler = StandardScaler()
        data = x_scaler.fit_transform(data)
    
    return data

In [144]:
def get_dataset(data_path=None):
    if data_path is None:
        return
    
    X = []
    y = []
    for filename in glob(os.path.join(data_path, f"*{AGG}-aparc-data.csv")):
        data_arr = download_from_csv(filename)
        if os.path.basename(filename)[:3] in cases:
            label = 1
        else:
            label = 0
        X.append(data_arr)
        y.append(label)
    X = preprocess(np.asarray(X), decim_freqs='bands', normalize=False)
    return X, np.asarray(y, dtype=int)

In [145]:
X, y = get_dataset(DATA_DIR)
print(X.shape, y.shape)

(666, 18000) (666,)


In [170]:
sm = SMOTE(random_state=RANDOM_SEED, sampling_strategy={0:6000, 1:25})
X_res, y_res = sm.fit_resample(X, y)
print('Original dataset shape %s' % Counter(y))
print('Resampled dataset shape %s' % Counter(y_res))



Original dataset shape Counter({0: 641, 1: 25})
Resampled dataset shape Counter({0: 6000, 1: 25})


In [171]:
X_train, X_test, y_train, y_test = train_test_split(X_res, y_res, test_size=TEST_SPLIT, random_state=RANDOM_SEED)

In [172]:
X_test.shape

(1205, 18000)

In [173]:
y_test.shape

(1205,)

In [174]:
print("Train set is %d documents (%d positive)" % (len(y_train), sum(y_train)))
print("Test set is %d documents (%d positive)" % (len(y_test), sum(y_test)))

Train set is 4820 documents (22 positive)
Test set is 1205 documents (3 positive)


In [193]:
contamination = sum(y_train) / len(y_train)
#clf = SGDClassifier(random_state=RANDOM_SEED, penalty='l1', class_weight={0:5, 1:1})
#clf = OneClassSVM(gamma='scale', nu=0.04)
clf = IsolationForest(contamination=contamination, n_estimators=100, n_jobs=2)

In [194]:
clf.fit(X_train[y_train==0])

IsolationForest(contamination=0.004564315352697096, n_jobs=2)

In [195]:
print(clf.offset_)

-0.5332244137809049


In [196]:
y_hat = clf.predict(X_test)
y_hat

array([1, 1, 1, ..., 1, 1, 1])

In [197]:
y_test2 = y_test.copy()
y_test2[y_test==0] = 1
y_test2[y_test==1] = -1
y_test2

array([1, 1, 1, ..., 1, 1, 1])

In [199]:
f1_score(y_test2, y_hat, pos_label=-1)

0.18181818181818182

In [206]:
def print_results(y_true, y_pred):

    total_anom = sum(y_true == -1)  
    tp = sum((y_true == -1) & (y_pred == -1))
    fp = sum((y_true != -1) & (y_pred == -1))
    tn = sum((y_true != -1) & (y_pred != -1))
    fn = sum((y_true == -1) & (y_pred != -1))
    
    print('[TP] {}\t\t[FP] {}\t\t[MISSED] {}'.format(tp, fp, total_anom-tp))
    print('[TN] {}\t[FN] {}'.format(tn, fn))

In [207]:
print_results(y_test2, y_hat)

[TP] 1		[FP] 7		[MISSED] 2
[TN] 1195	[FN] 2
