# Motor evoked potentials (MEPs) in diseased vs. healthy mice

In [None]:
from tqdm import tqdm
from mepextract.extracting import Extractor
from scipy.signal import find_peaks
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
import os
import pickle
import json

# 'science', 'grid' and 'vibrant' are styles registered by `import scienceplots` above.
# NOTE(review): 'custom' is not a scienceplots style — presumably a user style file in
# matplotlib's stylelib directory; confirm it is installed or this call will fail.
plt.style.use(['science', 'grid', 'custom', 'vibrant'])

## Define data paths and create a 10-row test spreadsheet

In [None]:
# Build a small test spreadsheet from the first ten rows of the full local copy,
# so the pipeline can be tried on a handful of sessions before a full run.
local = '/Users/fomo/Documents/Research/UNIC Research/Motor Evoked Potentials Test/MICE_EEG_2024 (93-257).csv'

temp = pd.read_csv(local)
temp2 = temp.head(10)  # identical rows to temp.iloc[:10]

temp2.to_csv(
    '/Users/fomo/Documents/Research/UNIC Research/Motor Evoked Potentials Test/MICE_EEG_2024 (93-102).csv',
    index=False,
)

In [None]:
# Session spreadsheets: `test` is the 10-row subset written by the previous cell,
# `full` is the complete spreadsheet stored next to the raw recordings.
test = '/Users/fomo/Documents/Research/UNIC Research/Motor Evoked Potentials Test/MICE_EEG_2024 (93-102).csv'
full = '/Volumes/STORAGE 1.0/UNIC Research/Motor Evoked Potentials/MEPSSEP_raw_data/post 93/MICE_EEG_2024 (93-257).csv'

# Only the full spreadsheet is loaded; it drives the extraction loop.
spreadsheet = pd.read_csv(full)

# Raw-data root on the external drive, and the local folder holding test copies.
master_folder = '/Volumes/STORAGE 1.0/UNIC Research/Motor Evoked Potentials/MEPSSEP_raw_data/post 93'
test_folder = '/Users/fomo/Documents/Research/UNIC Research/Motor Evoked Potentials Test'

# Show the loaded spreadsheet as the cell output.
spreadsheet

## Extract MEPs for all non-rejected trials

In [None]:
# Extract the event-locked MEP data for every spreadsheet row whose sessionType
# is not 'reject'. Each entry of `extracted` is a dict holding the session id
# ('trial'), the full spreadsheet row ('notes') and `extractor.mep` ('data').

# number of spreadsheet rows to iterate over
n = len(spreadsheet)
sampling_rate = 30000

extracted = []

# iterrows() yields index labels, so the `.at` lookup below stays correct even when
# the spreadsheet index is not 0..n-1 (range(n) + label indexing silently assumed it was).
for i, notes in tqdm(spreadsheet.iterrows(), total=n, desc="Processing trials"):
    # rows marked 'reject' are skipped entirely
    if notes['sessionType'] == 'reject':
        continue

    # relevant information
    trial = spreadsheet.at[i, 'session']
    # `phenotype` was not defined anywhere in this notebook, so the Extractor call
    # raised NameError. Use the row's phenoCode, which is the group label the
    # peak-detection cell reads from the notes.
    # NOTE(review): confirm Extractor's `group` argument expects this raw value.
    phenotype = notes['phenoCode']

    # defining extractor object
    extractor = Extractor(
        master_folder=master_folder,
        trial=trial,
        group=phenotype,
        recording_channels=[5, 7],
        sampling_rate=sampling_rate)
    # window kept around each event (presumably in samples — the peak-detection
    # cell builds its time axis on that assumption)
    extractor.pre_stimulus = 300
    extractor.post_stimulus = 3000
    extractor.notes = notes

    # extracting relevant data
    extractor.extract_raw()
    extractor.extract_events(event_channel_number=13)
    extractor.get_event_data(export=False)

    extracted.append({'trial': trial, 'notes': notes, 'data': extractor.mep})
    

## Detect peaks in every extracted trial (originally test code for the first ten trials)

In [None]:
from scipy.ndimage import gaussian_filter1d as gf

# Detect MEP peaks on the event-averaged, smoothed trace of each recording channel
# of every extracted trial. `detected_peaks` receives one dict per (trial, channel):
#   {'peaks': [(sample_index, amplitude), ...], 'group': int(phenoCode),
#    'current': z-scored currentLevel of the trial}

# time axis in ms for the 300 pre-stimulus + 3000 post-stimulus samples set on the extractor
time_in_ms = np.arange(-300, 3000) * (1/sampling_rate * 1000)
recording_channels = [5, 7]
# keyword arguments passed straight to scipy.signal.find_peaks
search = {'height': (0, 10), 'width': (1, 50), 'distance': 400}
# a peak is kept only when lo < sample_index < hi
# NOTE(review): each trace spans 3300 samples with the stimulus at index 300, so an
# upper bound of 3000 drops the last 300 samples; confirm whether 3300 was intended.
peak_window = (300, 3000)

# spreadsheet-wide statistics used to z-score each trial's stimulation current
current_mean = spreadsheet['currentLevel'].mean()
current_std = spreadsheet['currentLevel'].std()

n = len(extracted)
# fig, ax = plt.subplots(n, 1, figsize=(21, 7*n), dpi=210)

detected_peaks = []

for i, trial_record in enumerate(tqdm(extracted, desc="Processing trials")):

    temp = trial_record['data']
    group = int(trial_record['notes']['phenoCode'])
    # z-score the current ONCE per trial. This used to sit inside the channel loop,
    # so the second channel was scaled by an already-standardised, re-standardised value.
    current = (trial_record['notes']['currentLevel'] - current_mean)/current_std

    for channel in recording_channels:

        # presumably temp is indexed (channel, sample, event) — TODO confirm against Extractor.mep
        data = temp[channel, :, :]

        # average across axis 1 (presumably events), then lightly smooth
        smoothed_mean = gf(data.mean(axis=1).flatten(), sigma=1.25)

        # z-score the averaged trace and scale it by the inverse standardised current.
        # NOTE(review): a below-average current gives a negative z-score, which flips the
        # trace's sign; a current equal to the mean divides by zero.
        standardised = (1/current)*(smoothed_mean - smoothed_mean.mean())/smoothed_mean.std()

        # positive peaks, and negative peaks found on the inverted trace
        pos_peaks, _ = find_peaks(standardised, **search)
        neg_peaks, _ = find_peaks(-standardised, **search)

        delay = {'peaks': [], 'group': group, 'current': current}

        lo, hi = peak_window
        for peak in (*pos_peaks, *neg_peaks):
            if lo < peak < hi:
                delay['peaks'].append((peak, standardised[peak]))

        detected_peaks.append(delay)

#         ax[i].plot(time_in_ms, standardised, label=f'Channel {channel}')
#         ax[i].fill_between(time_in_ms, standardised - 1, standardised + 1, alpha=0.5)
#
#         # Plot detected peaks
#         ax[i].plot(time_in_ms[pos_peaks], standardised[pos_peaks], 'x', color='red')
#         ax[i].plot(time_in_ms[neg_peaks], standardised[neg_peaks], 'x', color='blue')
#
#         ax[i].set_title(f'Trial {extracted[i]["trial"]}')
#         ax[i].set_ylim(-15, 15)
#         # ax[i].set_xlim(-10, 10)
#         ax[i].set_xlabel('Time (ms)')
#         ax[i].set_ylabel('Amplitude')
#         ax[i].legend()
#
# plt.tight_layout()
# plt.show()
# fig.savefig('/Users/fomo/Desktop/first_ten_trials_zone_one.png')

## Plot peak delays

In [None]:
# Display the per-(trial, channel) peak records produced by the peak-detection cell.
detected_peaks

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Scatter every detected peak: x = delay after the stimulus in ms, y = absolute
# standardised amplitude, coloured by group (green: phenoCode 1, red: any other).
# Needs `sampling_rate` (extraction cell) and `detected_peaks` (peak-detection cell).

# Peak indices count from the start of the extracted window, which begins 300 samples
# before the stimulus (extractor.pre_stimulus = 300), so index 300 corresponds to 0 ms.
# The old plot put these raw indices on an axis labelled "Delay".
stimulus_index = 300

# colour -> (delays in ms, |amplitudes|); gathered first so each group is drawn with
# a single scatter call instead of one call per peak
points = {'green': ([], []), 'red': ([], [])}
labels = {'green': 'phenoCode 1', 'red': 'other phenoCode'}

for trial in detected_peaks:
    color = 'green' if trial['group'] == 1 else 'red'
    delays_ms, amplitudes = points[color]
    for index, amplitude in trial['peaks']:
        delays_ms.append((index - stimulus_index) / sampling_rate * 1000)
        amplitudes.append(np.abs(amplitude))

plt.figure(figsize=(14, 7), dpi=210)
for color, (delays_ms, amplitudes) in points.items():
    plt.scatter(delays_ms, amplitudes, color=color, alpha=0.5, label=labels[color])

plt.xlabel('Delay (ms)', fontsize=14)
plt.ylabel('Amplitude in S.Ds', fontsize=14)
# NOTE(review): the peak window (sample indices 300-3000 at 30 kHz) covers about 0-90 ms
# after the stimulus; the "5ms" in this title may be stale.
plt.title('Scatter Plot of Peaks by Group in 5ms after artefact', fontsize=16)
plt.legend()
plt.tight_layout()
plt.savefig('/Users/fomo/Desktop/first_ten_trials_zone_one.png')
plt.show()
