Imports

In [1]:
import csv
import os

import numpy as np
import pywt

import mne
from mne.datasets import sample

# Locate the MNE sample dataset and load the sample_audvis_filt-0-40 recording,
# preloading all samples into memory.
data_path = sample.data_path()
# os.path.join builds the same path whether data_path is a plain string (older
# MNE) or a pathlib.Path (newer MNE), where `Path + str` raises TypeError.
raw_fname = os.path.join(data_path, 'MEG', 'sample',
                         'sample_audvis_filt-0-40_raw.fif')
raw = mne.io.Raw(raw_fname, preload=True)

Opening raw data file /anaconda/lib/python2.7/examples/MNE-sample-data/MEG/sample/sample_audvis_filt-0-40_raw.fif...
    Read a total of 4 projection items:
        PCA-v1 (1 x 102)  idle
        PCA-v2 (1 x 102)  idle
        PCA-v3 (1 x 102)  idle
        Average EEG reference (1 x 60)  idle
Current compensation grade : 0
    Range : 6450 ... 48149 =     42.956 ...   320.665 secs
Ready.
Reading 0 ... 41699  =      0.000 ...   277.709 secs...
[done]


Function to denoise a signal by zeroing its small wavelet detail coefficients and reconstructing it from the wavelet basis

In [2]:
def getSDFromMAD(x):
    """Robust standard-deviation estimate based on the median absolute deviation.

    MAD = median(|x - median(x)|).  For normally distributed data the standard
    deviation is approximately 1.4826 * MAD, and unlike the sample standard
    deviation this estimate is barely moved by a few large outliers.

    Parameters
    ----------
    x : numpy array
        Values to summarise.

    Returns
    -------
    float
        1.4826 times the median absolute deviation of `x`.
    """
    centre = np.median(x)
    abs_deviation = np.abs(x - centre)
    return 1.4826 * np.median(abs_deviation)

def reconstruct(data, wavelet_type, level, num_sd):
    """Denoise `data` by hard-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``.  Every detail coefficient
    whose magnitude is below ``num_sd`` times a robust (MAD-based) estimate of
    the standard deviation of all detail coefficients (pooled over levels) is
    set to zero.  The signal is then rebuilt with ``pywt.waverec``.

    Parameters
    ----------
    data : array_like
        Signal to decompose (presumably 1-D -- TODO confirm other shapes).
    wavelet_type : str
        Wavelet name understood by PyWavelets, e.g. 'db2'.
    level : int or None
        Decomposition level forwarded to ``pywt.wavedec``.
    num_sd : float
        Threshold, in units of the estimated standard deviation.

    Returns
    -------
    tuple
        ``(reconstructed, concatenated_coefs, num_zeroed)``:
        - the reconstructed signal, which may be longer than `data`;
        - all coefficients concatenated, taken *before* thresholding;
        - the number of detail coefficients that were set to zero.
    """
    coefs = pywt.wavedec(data, wavelet_type, level=level)
    # np.concatenate copies, so this snapshot is unaffected by the zeroing below.
    concatenated_coefs = np.concatenate(coefs)
    # coefs[0] holds the approximation; only the detail arrays are thresholded.
    # The slice shares the same array objects, so in-place edits reach `coefs`.
    detail_coefs = coefs[1:]
    num_zeroed = 0
    # np.concatenate([]) raises, so only estimate the spread (and threshold)
    # when there is at least one detail array.
    if detail_coefs:
        threshold = getSDFromMAD(np.concatenate(detail_coefs)) * num_sd
        for detail in detail_coefs:
            small = np.abs(detail) < threshold
            detail[small] = 0
            num_zeroed += int(np.count_nonzero(small))
    reconstructed = pywt.waverec(coefs, wavelet_type)
    return (reconstructed, concatenated_coefs, num_zeroed)

Functions to evaluate reconstruction quality (the R² statistic) and the fraction of wavelet coefficients that were zeroed

In [3]:
def r2stat(y, f):
    """Coefficient of determination R² of the fit `f` to the observations `y`.

    R² = 1 - SS_res / SS_tot, with SS_res = sum((y - f)**2) and
    SS_tot = sum((y - mean(y))**2).

    Parameters
    ----------
    y : array_like
        Observed values.
    f : array_like
        Fitted / reconstructed values, same length as `y`.

    Returns
    -------
    numpy float
        1.0 for a perfect fit.  It can be negative when `f` fits worse than the
        mean of `y`.  If `y` is constant, SS_tot is 0 and numpy's division
        yields nan or -inf (with a RuntimeWarning).
    """
    # Coercing to float arrays lets plain lists work (list - list raises
    # TypeError) and avoids Python 2 integer floor division for int inputs.
    y = np.asarray(y, dtype=float)
    f = np.asarray(f, dtype=float)
    sstot = np.sum((y - np.mean(y)) ** 2)
    ssres = np.sum((y - f) ** 2)
    return 1 - ssres / sstot

def evaluate(data, wavelet_type, level, num_sd):
    """Score a thresholded wavelet reconstruction of `data`.

    Parameters are forwarded unchanged to ``reconstruct``.

    Returns
    -------
    tuple
        ``(r2_stat, ratio_zeroed)``: the R² of the reconstruction against
        `data`, and the fraction of all wavelet coefficients (approximation
        plus detail) that were zeroed.
    """
    reconstructed, concatenated_coefs, num_zeroed = reconstruct(
        data, wavelet_type, level, num_sd)
    # The reconstruction can come back longer than the input; keep only the
    # first len(data) samples so it lines up with the original signal.
    aligned = reconstructed[:len(data)]
    ratio_zeroed = num_zeroed / float(len(concatenated_coefs))
    return r2stat(data, aligned), ratio_zeroed

Function to evaluate each wavelet at thresholds of 3 and 4 standard deviations and store the R² value and the ratio of zeroed coefficients in a CSV table.

In [12]:
def getWaveletPerformance(data, file_name, wavelet_list=None, num_sds=(3, 4),
                          level=None):
    """Evaluate several wavelets on `data` and write the scores to a CSV file.

    For every wavelet, ``evaluate`` is run once per threshold in `num_sds`.
    Each wavelet produces one row: its name, then the R² value for each
    threshold, then the ratio of zeroed coefficients for each threshold.
    Values are formatted with 4 decimals.

    Parameters
    ----------
    data : array_like
        Signal to decompose and reconstruct.
    file_name : str
        Path of the CSV file to create (an existing file is overwritten).
    wavelet_list : list of str, optional
        PyWavelets wavelet names.  Defaults to the list below.
    num_sds : sequence of numbers, optional
        Thresholds, in units of the MAD-based standard deviation.
        Defaults to (3, 4).
    level : int or None, optional
        Decomposition level forwarded to ``evaluate``.  Defaults to None.
    """
    import sys

    if wavelet_list is None:
        wavelet_list = ['haar',
                        'db2', 'db3', 'db5', 'db9', 'db16',
                        'sym2', 'sym3', 'sym5', 'sym9', 'sym16',
                        'coif1', 'coif2', 'coif3', 'coif5',
                        'bior1.3', 'bior2.2', 'bior2.8', 'bior3.1', 'bior3.7', 'bior6.8',
                        'dmey']

    header = (['Wavelet Name']
              + ['R2, %s SD' % sd for sd in num_sds]
              + ['Ratio Zeroed, %s SD' % sd for sd in num_sds])

    # The csv module needs a binary-mode file on Python 2, but a text-mode file
    # opened with newline='' on Python 3 (writing to a 'wb' file raises
    # TypeError there).
    if sys.version_info[0] < 3:
        csv_file = open(file_name, 'wb')
    else:
        csv_file = open(file_name, 'w', newline='')

    # The with-block closes the file, so no explicit close() is needed.
    with csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(header)
        for wavelet_type in wavelet_list:
            results = [evaluate(data, wavelet_type, level, sd) for sd in num_sds]
            r2_values = ["%.4f" % r2 for (r2, _) in results]
            ratios = ["%.4f" % ratio for (_, ratio) in results]
            writer.writerow([wavelet_type] + r2_values + ratios)
    
    

Get 10 s of EEG data (channel EEG 020, 50-60 s) and evaluate wavelet performance

In [13]:
# Work on a copy that keeps only channel 'EEG 020', so `raw` keeps all channels.
# raw.copy().pick_channels(...) is equivalent to the older
# pick_channels(..., copy=True), and also works on MNE releases that removed
# the `copy` argument.
eeg_channel_data = raw.copy().pick_channels(['EEG 020'])
# Sample indices of the 50 s - 60 s window.
start, stop = eeg_channel_data.time_as_index([50, 60])
# Indexing the Raw object yields (data, times); row 0 is the single picked channel.
data, times = eeg_channel_data[:, start:stop]
data = data[0,:]

getWaveletPerformance(data, 'eeg_wavelet_performance.csv')

Get 10 s of MEG data (channel MEG 0113, 50-60 s) and evaluate wavelet performance

In [14]:
# Work on a copy that keeps only channel 'MEG 0113', so `raw` keeps all channels.
# raw.copy().pick_channels(...) is equivalent to the older
# pick_channels(..., copy=True), and also works on MNE releases that removed
# the `copy` argument.
meg_channel_data = raw.copy().pick_channels(['MEG 0113'])
# Sample indices of the 50 s - 60 s window.
start, stop = meg_channel_data.time_as_index([50, 60])
# Indexing the Raw object yields (data, times); row 0 is the single picked channel.
data, times = meg_channel_data[:, start:stop]
data = data[0,:]

getWaveletPerformance(data, 'meg_wavelet_performance.csv')

Get EEG data (channel EEG 020) for the whole recording and evaluate wavelet performance

In [15]:
# Reuse the single-channel EEG object `eeg_channel_data` created in the EEG
# 10 s cell, this time taking every sample of the recording.
data, times = eeg_channel_data[:, :]
# Row 0 is the one picked channel; evaluate it as a 1-D signal.
data = data[0,:]

getWaveletPerformance(data, 'eeg_wavelet_all_time_performance.csv')