# Dataset

## Initialize Dataset

In [None]:
from matplotlib import pyplot as plt
from matplotlib import rcParams
from scipy import signal
import pandas as pd
import numpy as np
import requests
import os

In [None]:
# Local cache file for the dataset and the remote location it is fetched from.
fname = 'motor_imagery.npz'
url = "https://osf.io/ksqv8/download"

# Only hit the network when no local copy of the archive exists yet.
if not os.path.isfile(fname):
    try:
        r = requests.get(url)
    except requests.ConnectionError:
        print("!!! Failed to download data !!!")
    else:
        if r.status_code == requests.codes.ok:
            # Persist the raw response body so later runs skip the download.
            with open(fname, "wb") as fid:
                fid.write(r.content)
        else:
            print("!!! Failed to download data !!!")

# Figure defaults shared by every plot created afterwards.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [None]:
class process():
    """
    Preprocessing helpers for the voltage recordings.

    The input is a dict with a key 'V' containing the voltage data.
    The class holds no state; ``preprocess`` is a static method so it can be
    called either as ``process.preprocess(data)`` or on an instance.
    """

    def __init__(self):
        return

    @staticmethod
    def preprocess(data, fs=1000):
        """
        Turn raw voltage traces into a normalized, smoothed power envelope.

        Parameters
        ----------
        data : dict
            Must contain key 'V'; the array is filtered along axis 0.
        fs : float, optional
            Sampling rate handed to ``signal.butter`` (default 1000). The
            cutoffs below (50 and 10) are expressed in the same units.

        Returns
        -------
        numpy.ndarray
            Array with the same shape as ``data['V']`` in which every column
            has been divided by its own mean along axis 0.
        """
        V = data['V'].astype('float32')

        # 3rd-order Butterworth high-pass at 50; filtfilt runs the filter
        # forwards and backwards along axis 0.
        b, a = signal.butter(3, [50], btype='high', fs=fs)
        V = signal.filtfilt(b, a, V, 0)

        # Squared magnitude of the high-passed signal.
        V = np.abs(V)**2

        # 3rd-order Butterworth low-pass at 10 smooths the squared signal.
        b, a = signal.butter(3, [10], btype='low', fs=fs)
        V = signal.filtfilt(b, a, V, 0)

        # Normalize each column by its mean over axis 0.
        V = V/V.mean(0)
        return V

class plots():
    """
    Plotting helpers for per-channel traces.

    All methods are static, so they can be called as ``plots.name(...)`` or on
    an instance. Every data argument is indexed as ``data[:, channel]`` and
    plotted against ``trange``.
    """

    def __init__(self):
        return

    @staticmethod
    def _draw_channel(traces, channel, trange):
        """Plot one channel of every array in ``traces`` on the current axes."""
        for trace in traces:
            plt.plot(trange, trace[:, channel])
        plt.title('ch%d'%channel)
        plt.xticks([0, 1000, 2000])
        plt.ylim([0, 4])

    @staticmethod
    def _draw_grid(traces, trange, nchan):
        """Plot up to ``nchan`` channels of ``traces`` on a 5x10 subplot grid."""
        plt.figure(figsize=(20,10))
        # Never index past the channels actually present in the data.
        shown = min(nchan, *(trace.shape[1] for trace in traces))
        for j in range(shown):
            plt.subplot(5,10,j+1)
            plots._draw_channel(traces, j, trange)

    @staticmethod
    def singlechannel1(data, channel, trange):
        """Open a new figure and plot column ``channel`` of ``data``."""
        plt.figure(figsize=(20,10))
        plots._draw_channel((data,), channel, trange)

    @staticmethod
    def singlechannel2(data, data2, channel, trange):
        """Open a new figure and overlay column ``channel`` of both arrays."""
        plt.figure(figsize=(20,10))
        plots._draw_channel((data, data2), channel, trange)

    @staticmethod
    def all_channels1(data, trange, nchan=46):
        """
        Plot the first ``nchan`` channels of ``data``, one subplot each.

        ``nchan`` defaults to 46 and is clamped to the number of columns in
        ``data``. The grid has 5x10 slots, so values above 50 exceed it.
        """
        plots._draw_grid((data,), trange, nchan)

    @staticmethod
    def all_channels2(data, data2, trange, nchan=46):
        """
        Overlay ``data`` and ``data2`` for the first ``nchan`` channels.

        ``nchan`` defaults to 46 and is clamped to the smaller column count of
        the two arrays. The grid has 5x10 slots, so values above 50 exceed it.
        """
        plots._draw_grid((data, data2), trange, nchan)

In [None]:
# Read the 'dat' entry from the archive. The context manager closes the
# underlying .npz file handle as soon as the array has been loaded.
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load archives obtained from a trusted source.
with np.load(fname, allow_pickle=True) as archive:
    DataLoad = archive['dat']

# Quick structural summary of the loaded object (notebook cell output).
type(DataLoad), len(DataLoad), DataLoad.shape, DataLoad[0][0].keys()

In [None]:
# Per-participant outputs, keyed by participant index:
#   realV / imagineV     -> epoched, preprocessed voltage arrays
#   realSet / imagineSet -> metadata entries copied from each recording
realV = {}
imagineV = {}
realSet = {}
imagineSet = {}

desiredKeys = ['t_off', 'stim_id', 't_on', 'V', 'scale_uv', 'locs', 'srate']

# Sample offsets of one epoch relative to its onset. Built once because it
# does not change between participants; kept at module level on purpose.
trange = np.arange(0, 2000)


def _extract_epochs(recording, window):
    """Preprocess ``recording`` and cut one window of samples per onset.

    Parameters
    ----------
    recording : dict-like
        Must provide 'V' (consumed by ``process.preprocess``) and 't_on'
        (sample index at which each epoch starts).
    window : numpy.ndarray
        Sample offsets added to every onset.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(recording['t_on']), len(window), n_channels).
    """
    filtered = process.preprocess(recording)
    nchan = filtered.shape[1]
    onsets = recording['t_on']
    nstim = len(onsets)
    # Broadcasting yields one row of absolute sample indices per onset.
    ts = onsets[:, np.newaxis] + window
    return np.reshape(filtered[ts, :], (nstim, len(window), nchan))


for i in range(len(DataLoad)):
    real_rec = DataLoad[i][0]
    imagine_rec = DataLoad[i][1]

    print(f"Sample rate of participant (real) {i}: {real_rec['srate']}")
    print(f"Sample rate of participant (imagine) {i}: {imagine_rec['srate']}")

    # Both recordings go through exactly the same pipeline; only the
    # destination dictionaries differ.
    for rec, epochs_out, meta_out in (
        (real_rec, realV, realSet),
        (imagine_rec, imagineV, imagineSet),
    ):
        V_epochs = _extract_epochs(rec, trange)
        print(V_epochs.shape)
        epochs_out[i] = V_epochs
        meta_out[i] = {key: rec[key] for key in desiredKeys}



In [None]:
# realV maps participant index -> preprocessed voltage data of the real
# movement recording, epoched to (stimulus onsets, 2000 samples, channels).
realV
# realSet maps participant index -> dict with the desiredKeys entries copied
# from the real movement recording.
realSet

# imagineV maps participant index -> preprocessed voltage data of the imagined
# movement recording, epoched to (stimulus onsets, 2000 samples, channels).
imagineV
# imagineSet maps participant index -> dict with the desiredKeys entries
# copied from the imagined movement recording.
imagineSet