In [None]:
from ratdata import data_manager as dm, process, plot as rdplot
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [None]:
# Connect the data_manager module to the database file; every dm.* query below relies on it.
db_path = 'rat_data.db'
dm.db_connect(db_path)

In [None]:
# Recording explored in the single-recording cells below.
# Other recordings that were inspected the same way:
#   '2021-05-24T09-48-46 rat3 - OFT2.mat'
#   '2021-07-16T09-02-07 rat3 OFT2 random.mat'
#   '2021-05-19T08-57-35 rat3 - OFT2 130Hz.mat'
rec_filename = '2021-07-12T09-23-24 rat3 OFT2 on-off.mat'
rec_query = dm.RecordingFile.select().where(dm.RecordingFile.filename == rec_filename)
rec = rec_query.get()

In [None]:
# Load the selected recording and average it over axis 0 of the electrode data
# (presumably the electrode/channel axis — TODO confirm) to get one trace.
data, dt, time = dm.get_electrode_data_from_recording(rec, select_slice=True)
mean_data = np.mean(data, axis=0)
# Round rather than truncate: with float error in dt, int(1/dt) can turn
# e.g. 999.9999... into 999 instead of the intended 1000.
fs = int(round(1 / dt))
title = 'Biomarker for %s' % rec.filename

In [None]:
%matplotlib widget
rdplot.plot_biomarker_steps(mean_data[57 * fs: 65 * fs], fs, time=time, plot_title=title)

In [None]:
# Burst detection on the single recording loaded above: threshold both beta
# envelopes at their own perc-th percentile and compare duration vs amplitude.
perc = 50
# Rate argument to process.beta_envelope; also converts burst sample spans to
# seconds below (presumably the envelope sampling rate in Hz — TODO confirm).
env_fs = 300
s, env1, env2 = process.beta_envelope(mean_data, fs, env_fs)
th1 = np.percentile(env1, perc)
th2 = np.percentile(env2, perc)
bursts1 = process.beta_bursts(env1, th1)
bursts2 = process.beta_bursts(env2, th2)
print(th1, th2)
# Each burst e is used as (start, end, amplitude, ...): e[1] - e[0] is its length in samples.
durations1 = [(e[1] - e[0]) / env_fs for e in bursts1]
durations2 = [(e[1] - e[0]) / env_fs for e in bursts2]
amplitudes1 = [e[2] for e in bursts1]
amplitudes2 = [e[2] for e in bursts2]

n_preview = 1000  # number of leading samples shown in the top panel
fig = plt.figure(figsize=(20, 9))
ax1 = plt.subplot2grid((2, 2), (0, 0), colspan=2)
ax1.plot(s[:n_preview], label='s')
ax1.plot(env1[:n_preview], label='Hilbert envelope')
ax1.plot(env2[:n_preview], label='rect MA envelope')
ax1.legend()
ax2 = plt.subplot2grid((2, 2), (1, 0))
ax2.scatter(durations1, amplitudes1)
ax2.set(title='Hilbert', xlabel='Duration [s]', ylabel='Amplitude [a.u.]')
ax3 = plt.subplot2grid((2, 2), (1, 1))
# This panel shows the rect-MA bursts; it previously repeated the Hilbert ones.
ax3.scatter(durations2, amplitudes2)
ax3.set(title='rect MA', xlabel='Duration [s]', ylabel='Amplitude [a.u.]');

In [None]:
# For every rat and stimulation type, concatenate the Hilbert and rect-MA beta
# envelopes of all matching recordings into rat_env_all[rat][stim_type].
# Rate argument to process.beta_envelope (presumably the envelope sampling rate in Hz — TODO confirm).
env_fs = 300
rat_env_all = dict()
for rat in dm.Rat.select():
    print(f'{rat.full_label}:')
    rat_env_all[rat.full_label] = dict()
    for stim_row in dm.StimSettings.select(dm.StimSettings.stim_type).distinct():
        st = stim_row.stim_type
        print(st, ':')
        # Reset per stimulation type. These used to be reset only per rat, so each
        # stim type's entry also held the envelopes of all stim types processed before it.
        env1_parts = []
        env2_parts = []
        recordings = (
            dm.RecordingFile.select()
            .join(dm.StimSettings)
            .where((dm.RecordingFile.rat == rat) & (dm.StimSettings.stim_type == st))
        )
        for recording in recordings:
            print(recording.filename)
            data, dt, time = dm.get_electrode_data_from_recording(recording, select_slice=True)
            mean_data = np.mean(data, axis=0)
            # Round rather than truncate so float error in dt cannot lose 1 Hz.
            fs = int(round(1 / dt))
            s, env1, env2 = process.beta_envelope(mean_data, fs, env_fs)
            env1_parts.append(env1)
            env2_parts.append(env2)
        # Concatenate once rather than growing an array inside the loop (quadratic copying).
        env1_all = np.concatenate(env1_parts) if env1_parts else np.array([])
        env2_all = np.concatenate(env2_parts) if env2_parts else np.array([])
        rat_env_all[rat.full_label][st] = {
            'hilbert': env1_all,
            'rectma': env2_all
        }

In [None]:
# Extract beta bursts for every rat and stimulation type and save one pickle per rat.
# Each envelope (Hilbert, rect MA) is thresholded at its own perc-th percentile.
perc = 75
# Rate argument to process.beta_envelope; also converts burst sample spans to
# seconds (presumably the envelope sampling rate in Hz — TODO confirm).
env_fs = 300
for rat in dm.Rat.select():
    print(f'{rat.full_label}:')
    rat_bursts = dict()
    rat_bursts['perc'] = perc
    rat_bursts['rat'] = rat.full_label
    for stim_row in dm.StimSettings.select(dm.StimSettings.stim_type).distinct():
        st = stim_row.stim_type
        durations = []
        durations_rect = []
        amplitudes = []
        amplitudes_rect = []
        print(st, ':')
        recordings = (
            dm.RecordingFile.select()
            .join(dm.StimSettings)
            .where((dm.RecordingFile.rat == rat) & (dm.StimSettings.stim_type == st))
        )
        for recording in recordings:
            data, dt, time = dm.get_electrode_data_from_recording(recording, select_slice=True)
            mean_data = np.mean(data, axis=0)
            # Round rather than truncate so float error in dt cannot lose 1 Hz.
            fs = int(round(1 / dt))
            s, env1, env2 = process.beta_envelope(mean_data, fs, env_fs)
            th1 = np.percentile(env1, perc)
            # Threshold the rect-MA envelope with its own percentile (was computed from env1).
            th2 = np.percentile(env2, perc)
            bursts1 = process.beta_bursts(env1, th1)
            bursts2 = process.beta_bursts(env2, th2)
            # Each burst e is used as (start, end, amplitude, ...).
            durations.extend((e[1] - e[0]) / env_fs for e in bursts1)
            durations_rect.extend((e[1] - e[0]) / env_fs for e in bursts2)
            amplitudes.extend(e[2] for e in bursts1)
            amplitudes_rect.extend(e[2] for e in bursts2)
            stim = recording.stim.get_or_none()
            print(recording.filename, stim.stim_type if stim is not None else None)
        rat_bursts[st] = {
            'amplitudes': amplitudes,
            'durations': durations,
            'amplitudes_rect': amplitudes_rect,
            'durations_rect': durations_rect
        }
    with open(f'data/beta_bursts_{perc}_{rat.full_label}_13_30_Hz.pickle', 'wb') as f:
        pickle.dump(rat_bursts, f)

In [None]:
# Re-plot the saved per-rat burst summaries for each threshold percentile.
# NOTE(review): the extraction cell writes 'data/beta_bursts_{perc}_{rat}_13_30_Hz.pickle'
# (and only for perc=75). This cell reads files without that suffix, i.e. pickles from an
# earlier run. Set burst_file_suffix = '_13_30_Hz' if the current outputs are intended.
burst_file_suffix = ''
for perc in [50, 75]:
    for rat in dm.Rat.select():
        with open(f'data/beta_bursts_{perc}_{rat.full_label}{burst_file_suffix}.pickle', 'rb') as f:
            rat_bursts = pickle.load(f)
        # Check rather than reassign: overwriting the loop variable with the stored
        # value silently changed which file was opened for the next rat.
        if rat_bursts['perc'] != perc:
            raise ValueError(f"{f.name} stores perc={rat_bursts['perc']}, expected {perc}")
        rat_label = rat_bursts['rat']
        rat_type = rat.group
        rdplot.plot_beta_bursts_one_rat(rat_bursts, rat_type, f'plots/beta_bursts/{perc}_{rat_label}')

In [None]:
# Grid cell (row, col) of each rat's panel in the 3x4 no-stimulation summary figures.
order = {
    'rat C4': (2, 0),
    'rat C6': (0, 3),
    'rat B1': (0, 1),
    'rat B2': (1, 1),
    'rat B3': (2, 1),
    'rat B5': (1, 0),
    'rat A4': (0, 0),
    'rat D1': (0, 2),
    'rat D2': (1, 2),
    'rat D4': (2, 2)
}
# Grid cells with no rat assigned; they are hidden.
unused_cells = [(1, 3), (2, 3)]
# Bursts with duration <= this value [s] are left out of the plots.
min_burst_duration = 0.1


def _long_bursts(stim_bursts, dur_key, amp_key, min_duration):
    """Return (durations, amplitudes) arrays keeping only bursts longer than min_duration.

    stim_bursts is one per-stim-type entry of a saved burst pickle, e.g. rat_bursts['nostim'].
    """
    durations = np.asarray(stim_bursts[dur_key], dtype=float)
    amplitudes = np.asarray(stim_bursts[amp_key], dtype=float)
    keep = durations > min_duration
    return durations[keep], amplitudes[keep]


def _scatter_with_means(ax, durations, amplitudes, color, title):
    """Scatter amplitude against duration on ax and mark the mean of each with a reference line."""
    ax.scatter(durations, amplitudes, color=color)
    ax.set_title(title)
    if durations.size:  # the mean of an empty selection would be NaN
        ax.axvline(np.mean(durations))
        ax.axhline(np.mean(amplitudes))


def _format_grid(axes, cells, hidden_cells, xlim, ylim):
    """Label and set limits on the given panels of axes, and hide hidden_cells."""
    for cell in cells:
        axes[cell].set_xlabel('Duration [s]')
        axes[cell].set_ylabel('Amplitude [a.u.]')
        axes[cell].set_xlim(xlim)
        axes[cell].set_ylim(ylim)
    for cell in hidden_cells:
        axes[cell].axis('off')


def _padded(limit):
    """Axis range from -10 % to +110 % of limit, used for the full-scale panels."""
    return [-0.1 * limit, 1.1 * limit]


for perc in [75]:
    fig, axs = plt.subplots(3, 4, figsize=(20, 15))
    fig_zoom, axs_zoom = plt.subplots(3, 4, figsize=(20, 15))
    fig_rect, axs_rect = plt.subplots(3, 4, figsize=(20, 15))
    fig_rect_zoom, axs_rect_zoom = plt.subplots(3, 4, figsize=(20, 15))
    # Largest duration/amplitude over all rats, tracked separately per envelope method
    # so the rect-MA panels are no longer scaled to the Hilbert values.
    max_amplitude = 0
    max_duration = 0
    max_amplitude_rect = 0
    max_duration_rect = 0
    for rat in dm.Rat.select():
        with open(f'data/beta_bursts_{perc}_{rat.full_label}_13_30_Hz.pickle', 'rb') as f:
            rat_bursts = pickle.load(f)
        # NOTE(review): 'bounds' is not read in this cell; kept in case later code uses it.
        rat_bursts['bounds'] = [14, 18]
        rat_type = rat.group
        ii = order[rat.full_label]
        c = rdplot.sham_ohda_palette['sham' if rat_type == 'control' else 'ohda']
        nostim = rat_bursts['nostim']
        d, a = _long_bursts(nostim, 'durations', 'amplitudes', min_burst_duration)
        d_r, a_r = _long_bursts(nostim, 'durations_rect', 'amplitudes_rect', min_burst_duration)

        # Guard the maxima: max() of an empty selection raised ValueError.
        if d.size:
            max_duration = max(max_duration, d.max())
            max_amplitude = max(max_amplitude, a.max())
        if d_r.size:
            max_duration_rect = max(max_duration_rect, d_r.max())
            max_amplitude_rect = max(max_amplitude_rect, a_r.max())

        _scatter_with_means(axs[ii], d, a, c, rat.full_label)
        _scatter_with_means(axs_zoom[ii], d, a, c, rat.full_label)
        # Mean lines on the rect-MA panels now come from the rect-MA bursts
        # (previously the Hilbert d/a were used).
        _scatter_with_means(axs_rect[ii], d_r, a_r, c, rat.full_label)
        _scatter_with_means(axs_rect_zoom[ii], d_r, a_r, c, rat.full_label)

    # Format once after all rats are plotted so the full-scale limits cover every rat
    # (this used to run inside the per-rat loop).
    rat_cells = list(order.values())
    _format_grid(axs, rat_cells, unused_cells, _padded(max_duration), _padded(max_amplitude))
    _format_grid(axs_zoom, rat_cells, unused_cells, [-0.01, 0.51], [-0.01, 0.51])
    _format_grid(axs_rect, rat_cells, unused_cells,
                 _padded(max_duration_rect), _padded(max_amplitude_rect))
    _format_grid(axs_rect_zoom, rat_cells, unused_cells, [-0.01, 1.01], [-0.01, 1.01])

    fig.suptitle('Beta (13-30) burst duration and amplitude under no stimulation (Hilbert)', size=18)
    fig_zoom.suptitle('Beta (13-30) burst duration and amplitude under no stimulation (zoomed in) (Hilbert)', size=18)
    fig_rect.suptitle('Beta (13-30) burst duration and amplitude under no stimulation (rect MA)', size=18)
    fig_rect_zoom.suptitle('Beta (13-30) burst duration and amplitude under no stimulation (zoomed in) (rect MA)', size=18)
    fig.subplots_adjust(hspace=0.25, top=0.93)
    fig_zoom.subplots_adjust(hspace=0.25, top=0.93)
    fig_rect.subplots_adjust(hspace=0.25, top=0.93)
    fig_rect_zoom.subplots_adjust(hspace=0.25, top=0.93)