In [None]:
from ratdata import data_manager as dm, process, ingest, plot as rdplot
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.signal as signal
import itertools

In [None]:
# Connect the ratdata data manager to the recordings database that the
# dm.RecordingFile queries in the cells below read from.
RAT_DB_FILE = 'rat_data.db'
dm.db_connect(RAT_DB_FILE)

In [None]:
# dm.RecordingFile.file_id values of the recordings compared in this notebook,
# grouped (by variable name) into 6-OHDA rats and sham-operated controls.
id_ohda = [
    378,
    449,
]
id_sham = [
    273,
]

In [None]:
def recording_length(recording_id, fs=20000):
    """Return the usable length, in seconds, of the recording with the given file id.

    Rejected recordings print a hint and return ``None``. For sliced recordings
    the stored slice length is returned (and reported). Otherwise the length is
    the number of samples on the last axis of ``electrode_data`` divided by ``fs``.

    Parameters
    ----------
    recording_id : int
        ``dm.RecordingFile.file_id`` of the recording.
    fs : int, optional
        Sampling rate in Hz used to convert samples to seconds for unsliced
        recordings (default 20000, matching the ``fs`` used for plotting).

    Returns
    -------
    float or None
        Length in seconds, or ``None`` if the recording is rejected.
    """
    r = dm.RecordingFile.select().where(dm.RecordingFile.file_id == recording_id).get()
    if dm.is_recording_rejected(r.filename):
        print(f'Choose a different recording than {r.filename}')
        return None
    if dm.is_recording_sliced(r.filename):
        length = r.slice.get().length
        print(f'Selecting slice of length {length} from {r.filename}')
        return length
    # Only unsliced recordings need the raw samples, so the .mat file is read
    # here rather than unconditionally.
    recording_data = ingest.read_mce_matlab_file(Path(r.dirname) / r.filename)
    return recording_data.electrode_data.shape[-1] / fs

In [None]:
# Shortest usable length (s) over all selected recordings. Every PSD below is
# computed on a segment of this length so the spectra are directly comparable.
# The ids are flattened first: taking min() of the per-group lists would compare
# the lists lexicographically and could return the wrong minimum.
recording_lengths = [recording_length(file_id) for file_id in itertools.chain(id_ohda, id_sham)]
if any(length is None for length in recording_lengths):
    raise ValueError('At least one selected recording is rejected; choose different file ids')
min_recording_length = min(recording_lengths)

In [None]:
# Base font size for tick labels, axis labels, titles and legends in the PSD
# figures; panel letters are drawn at twice this size.
fontsize_ax: int = 22

In [None]:
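# NOTE(review): superseded by plot_two_spectra_by_file_id below. Not runnable as-is:
# the loop iterates over the id *lists* (id_ohda, id_sham), not individual file ids.
# plt.figure(figsize=(12, 10))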
# ax = plt.gca()
# fs = 20000

# colors = [
#     rdplot.sham_ohda_palette['ohda'],
#     rdplot.sham_ohda_palette['sham']
#     ]
# legend = []
# plot_start = int(2 * min_recording_length)
# plot_stop = int(100 * min_recording_length)
# for i, recording_id in enumerate([id_ohda, id_sham]):
#     r = dm.RecordingFile.select().where(dm.RecordingFile.file_id == recording_id).get()
#     recording_data = ingest.read_mce_matlab_file(Path(r.dirname) / r.filename)
#     electrode_data = recording_data.electrode_data.mean(0)
#     legend.append(f'{r.rat.full_label} ({r.rat.group})')
#     if dm.is_recording_sliced(r.filename):
#         start = int(r.slice.get().start * fs)
#         end = start + int(min_recording_length * fs)
#     else:
#         start = 0
#         end = int(min_recording_length * fs)
#     data_for_psd = electrode_data[start:end]
#     f, psd = signal.welch(data_for_psd, fs, nperseg=len(data_for_psd))
#     ax.tick_params(axis='both', which='major', labelsize=fontsize_ax)
#     ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
#     ax.yaxis.offsetText.set_fontsize(fontsize_ax)
#     ax.plot(f[plot_start:plot_stop], psd[plot_start:plot_stop], color=colors[i], linewidth=3, alpha=0.8)
#     ax.set_xlabel('Frequency [Hz]', fontsize=fontsize_ax)
#     ax.set_ylabel('Power spectral density [mV$^2$/Hz]', fontsize=fontsize_ax)
# ax.set_title(f'FFT PSD (segment length = {min_recording_length} s)', fontsize=fontsize_ax)
# plt.legend(legend, fontsize=fontsize_ax)

# plt.savefig('plots/6ohda_vs_sham.png', bbox_inches='tight')
# plt.savefig('plots/6ohda_vs_sham.svg', bbox_inches='tight')

In [None]:
def plot_two_spectra_by_file_id(id1, id2, ax, colors, plot_start_n, plot_stop_n, fs, title, corner_label,
                                segment_length=None):
    """Plot the Welch PSDs of two recordings on one axis for a group comparison.

    Each recording is looked up by ``dm.RecordingFile.file_id`` and loaded from its
    .mat file. Its ``electrode_data`` is averaged over axis 0 and cut to
    ``segment_length`` seconds. The cut starts at the stored slice start for sliced
    recordings, otherwise at sample 0. The segment is passed to
    ``scipy.signal.welch`` with ``nperseg=2*fs``, which gives 0.5 Hz bins.

    Parameters
    ----------
    id1, id2 : int
        File ids, plotted in this order with ``colors[0]`` and ``colors[1]``.
    ax : matplotlib.axes.Axes
        Axis to draw on.
    colors : sequence
        At least two matplotlib colours.
    plot_start_n, plot_stop_n : int
        Slice of Welch frequency-bin *indices* (not Hz) to plot.
    fs : int
        Sampling rate in Hz.
    title : str
        Axis title.
    corner_label : str
        Panel letter drawn outside the top-left of the plotting area.
    segment_length : float, optional
        Seconds of data used per recording. Defaults to the notebook-level
        ``min_recording_length`` so existing calls behave as before.
    """
    if segment_length is None:
        segment_length = min_recording_length
    n_samples = int(segment_length * fs)

    lines = []
    legend = []
    for i, recording_id in enumerate([id1, id2]):
        r = dm.RecordingFile.select().where(dm.RecordingFile.file_id == recording_id).get()
        recording_data = ingest.read_mce_matlab_file(Path(r.dirname) / r.filename)
        # Average over axis 0 (presumably the electrode axis; time is the last axis).
        electrode_data = recording_data.electrode_data.mean(0)
        # Any group other than 'control' is labelled as 6-OHDA.
        group = 'sham' if r.rat.group == 'control' else '6-OHDA'
        legend.append(f'{r.rat.full_label} ({group})')
        start = int(r.slice.get().start * fs) if dm.is_recording_sliced(r.filename) else 0
        data_for_psd = electrode_data[start:start + n_samples]
        # 2-second Welch segments -> fs / (2 * fs) = 0.5 Hz frequency resolution.
        f, psd = signal.welch(data_for_psd, fs, nperseg=(2 * fs))
        line, = ax.plot(f[plot_start_n:plot_stop_n], psd[plot_start_n:plot_stop_n],
                        color=colors[i], linewidth=4, alpha=0.8)
        lines.append(line)

    # Axis styling is independent of the individual recordings, so apply it once.
    ax.tick_params(axis='both', which='major', labelsize=fontsize_ax)
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    ax.yaxis.offsetText.set_fontsize(fontsize_ax)
    ax.set_xlabel('Frequency [Hz]', fontsize=fontsize_ax)
    ax.set_ylabel('Power spectral density [mV$^2$/Hz]', fontsize=fontsize_ax)
    ax.set_title(f'{title}', fontsize=fontsize_ax)
    ax.set_ylim([0, 2.0e-4])
    # Panel letter in data coordinates: x=-20 Hz, just above the 2.0e-4 y-limit.
    ax.text(-20, 2.12e-4, corner_label, fontsize=2 * fontsize_ax)
    ax.legend(lines, legend, fontsize=fontsize_ax)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(24, 8))
fs = 20000  # sampling rate [Hz]

colors = [
    rdplot.sham_ohda_palette['ohda'],
    rdplot.sham_ohda_palette['sham']
    ]
# plot_two_spectra_by_file_id uses Welch segments of 2*fs samples, i.e. 0.5 Hz
# bins; convert the 2-100 Hz display band into frequency-bin indices.
freq_resolution = 0.5  # Hz per Welch bin
plot_start = int(2 / freq_resolution)
plot_stop = int(100 / freq_resolution)

titles = ['Beta peak in 6-OHDA rat only', 'Beta peak in neither of the rats']
plot_labels = ['A', 'B']
pairs = list(itertools.product(id_ohda, id_sham))
assert len(pairs) == len(axs) == len(titles) == len(plot_labels), \
    'need exactly one axis, title and panel label per (6-OHDA, sham) pair'
for ax, (ohda_id, sham_id), title, panel_label in zip(axs, pairs, titles, plot_labels):
    # The sham recording is passed first, so the palette is reversed to keep the
    # sham/6-OHDA colours consistent.
    plot_two_spectra_by_file_id(sham_id, ohda_id, ax, colors[::-1], plot_start, plot_stop, fs, title, panel_label)

# Create the output directory so saving also works on a fresh checkout.
Path('plots').mkdir(exist_ok=True)
fig.savefig('plots/6ohda_vs_sham_2_100_Hz.png', bbox_inches='tight')
fig.savefig('plots/6ohda_vs_sham_2_100_Hz.svg', bbox_inches='tight')