In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import plot_utils as u
import scipy.signal as signal
import re

In [None]:
# Root folder of the experiment; the cells below scan its sub-directories for simulation runs.
result_dir: Path = Path("beta_modulation/")

In [None]:
# Overview grid: one row per simulation run, with columns for the population-averaged cortical
# soma voltage, its Welch PSD, its spectrogram and its 25 Hz Morlet-wavelet amplitude.
fs = 2000  # sampling rate [Hz] passed to welch / spectrogram
transient_samples = 15000  # leading samples excluded from the Welch estimate (7.5 s at fs)
wavelet_freq = 25  # Morlet centre frequency [Hz]
# For morlet2 (default w=5) the centre frequency is w * fs / (2 * pi * s); solve for the width s.
wavelet_width = 5 * fs / (2 * wavelet_freq * np.pi)

# Collect the run directories once (sorted, so the row order is reproducible) instead of
# scanning the folder twice and relying on both passes yielding the same order.
sim_dirs = sorted(d for d in result_dir.iterdir() if (d / "Cortical_Pop").is_dir())
num_sims = len(sim_dirs)

if num_sims == 0:
    # plt.subplots(0, 4) would raise; report instead of failing the whole notebook run.
    print(f"No simulation directories containing Cortical_Pop found in {result_dir}")
else:
    # squeeze=False keeps `axs` 2-D, so axs[i][j] also works when there is only one run.
    fig, axs = plt.subplots(num_sims, 4, figsize=(30, num_sims * 3), squeeze=False)
    column_titles = [
        "Mean CTX soma voltage",
        "Welch PSD",
        "Spectrogram",
        f"{wavelet_freq} Hz Morlet wavelet amplitude",
    ]
    for col, title in enumerate(column_titles):
        axs[0][col].set_title(title)

    for i, d in enumerate(sim_dirs):
        sig = u.load_cortical_soma_voltage(d)
        avgsig = np.mean(sig, axis=1)  # assumes sig is (time, neuron) — TODO confirm
        axs[i][0].plot(avgsig)
        axs[i][0].set_ylabel(d.name)
        f, spectrum = signal.welch(avgsig[transient_samples:], fs=fs)
        axs[i][1].plot(f[:50], spectrum[:50])  # lowest 50 frequency bins only
        fxx, txx, sxx = signal.spectrogram(avgsig, fs=fs)
        axs[i][2].pcolormesh(txx, fxx[:10], sxx[:10, :])  # lowest 10 frequency bins only
        # NOTE: signal.cwt / signal.morlet2 are deprecated since SciPy 1.12 and removed in 1.15.
        sww = signal.cwt(avgsig, signal.morlet2, [wavelet_width])
        axs[i][3].plot(np.abs(sww[0]))

In [None]:
# Reference beta-burst modulation waveform. `stt` and `modulation_signal` are reused by the
# cells below to overlay the modulating input on each run's plots.
burst_dir = Path("../Cortex_BasalGanglia_DBS_model")
tt = np.loadtxt(burst_dir / "burst_times_1.txt", delimiter=",")
aa = np.loadtxt(burst_dir / "burst_level_1.txt", delimiter=",")
# NOTE(review): the meaning of (6000, 30000, 0.01) is defined by u.burst_txt_to_signal —
# presumably start time, end time [ms] and time step; confirm against plot_utils.
stt, modulation_signal = u.burst_txt_to_signal(tt, aa, 6000, 30000, 0.01)

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(stt, modulation_signal, label="Modulation signal")
ax.scatter(tt, aa, label="Burst times / levels from file")
ax.set_xlabel("Time [ms]")
ax.set_ylabel("Burst level")
ax.set_title("Beta-burst modulation signal (burst_*_1.txt)")
ax.legend();

In [None]:
# One figure per simulation run: a raster of cortical (CTX) spikes with, stacked underneath on
# the same axes, the modulating current, the controller's beta estimate and the 25 Hz Morlet
# wavelet amplitude of the population-averaged CTX voltage.
fontsize = 15
fontsize_big = 22
fs = 2000  # sampling rate [Hz]; sample_index * 1000 / fs gives time in ms
skip = 3  # draw every `skip`-th neuron in the raster
wavelet_freq = 25  # Morlet centre frequency [Hz]
wavelet_width = 5 * fs / (2 * wavelet_freq * np.pi)  # morlet2 width for wavelet_freq (default w=5)
# Raw string: "\." inside a plain string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning from Python 3.12).
scale_pattern = re.compile(r"'beta_burst_modulation_scale': ([\.0-9]+),")

for d in sorted(result_dir.iterdir()):
    if not (d / "Cortical_Pop").is_dir():
        continue

    # Amplitude of the overlaid modulation trace: 100 x the scale read from the run's log.
    modulation_amplitude = 100
    modulation_amplitude_text = "???"  # shown in the title when the scale cannot be read
    out_files = list(d.glob("*.out"))
    if len(out_files) == 1:  # only parse the log when it is unambiguous
        with open(out_files[0], 'r') as f:
            for line in f:
                m = scale_pattern.match(line)
                if m:
                    modulation_amplitude_text = m.group(1)
                    modulation_amplitude *= float(modulation_amplitude_text)
                    break  # first occurrence only; repeated matches must not compound

    times = np.loadtxt(d / "controller_sample_times.csv")
    beta = np.loadtxt(d / "controller_beta_values.csv")
    sig = u.load_cortical_soma_voltage(d)
    spikes = u.find_population_spikes(sig)

    fig, ax = plt.subplots(figsize=(28, 10))

    # Gather every raster point first and draw them with a single scatter call, which is much
    # faster than one call per neuron and yields one artist for the legend.
    raster_x, raster_y = [], []
    for cell in spikes[::skip]:
        if len(cell) == 0:
            continue  # silent neuron: `zip(*cell)` would give nothing to unpack into x, y
        x, y = zip(*cell)
        raster_x.extend(e * 1000 / fs for e in x)  # sample index -> ms
        raster_y.extend(e / skip for e in y)  # compress rows so skipped neurons leave no gaps
    spike_handle = ax.scatter(raster_x, raster_y, s=5, c="#666666")

    # Traces below the raster use fixed vertical offsets (-1, -7, -12) so they do not overlap.
    modulation, = ax.plot(stt, modulation_amplitude * modulation_signal - 1)
    # NOTE(review): controller times are presumably in seconds (scaled to ms here) — confirm.
    controller_beta, = ax.plot(times * 1000, 8000 * beta - 7)
    zero_beta = ax.axhline(-7, color='black')  # where beta == 0 lands after the offset

    avgsig = np.mean(sig, axis=1)
    # NOTE: signal.cwt / signal.morlet2 are deprecated since SciPy 1.12 and removed in 1.15.
    sww = signal.cwt(avgsig, signal.morlet2, [wavelet_width])
    # Time axis derived from fs (same conversion as the raster) instead of assuming the run
    # always spans exactly 30000 ms.
    wavelet_t = np.arange(len(sww[0])) * 1000 / fs
    wavelet_amplitude, = ax.plot(wavelet_t, 0.08 * np.abs(sww[0]) - 12)

    ax.set_title(f"{d.name}, modulation amplitude={modulation_amplitude_text}", fontsize=fontsize)
    ax.set_xlabel('Time [ms]', fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_yticks([])  # vertical position is only a stacking offset, so hide y ticks
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.legend(
        [spike_handle, modulation, controller_beta, zero_beta, wavelet_amplitude],
        ['CTX neuron spikes', 'Modulating current', 'Controller beta', 'Zero beta level',
         'Morlet wavelet amplitude (avg CTX V)'],
        fontsize=fontsize,
        bbox_to_anchor=(0, 0.05),
        loc="lower left",
    )

In [None]:
# Controller beta estimate over time for every simulation run, with its mean printed per run.
for d in sorted(result_dir.iterdir()):
    times_file = d / "controller_sample_times.csv"
    beta_file = d / "controller_beta_values.csv"
    # Skip stray files and incomplete runs instead of crashing inside np.loadtxt.
    if not (times_file.is_file() and beta_file.is_file()):
        continue
    times = np.loadtxt(times_file)
    beta = np.loadtxt(beta_file)
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.plot(times, beta)
    ax.set_title(d.name)
    # NOTE(review): other cells convert these times with `* 1000` onto a ms axis, so they are
    # presumably seconds — confirm.
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Controller beta")
    print(f"{d.name}: {np.mean(beta)}")

In [None]:
# STN LFP per simulation run: the PSDs of all runs are overlaid in one figure, and each run also
# gets a time-series figure with its controller beta and the modulation signal (rescaled so they
# share the LFP axis).
# Raw string: "\." inside a plain string literal is an invalid escape sequence.
scale_pattern = re.compile(r"'beta_burst_modulation_scale': ([\.0-9]+),")

psd_fig, psd_ax = plt.subplots(figsize=(12, 10))
for d in sorted(result_dir.iterdir()):
    times_file = d / "controller_sample_times.csv"
    beta_file = d / "controller_beta_values.csv"
    # Skip stray files and incomplete runs instead of crashing inside np.loadtxt.
    if not (times_file.is_file() and beta_file.is_file()):
        continue
    times = np.loadtxt(times_file)
    beta = np.loadtxt(beta_file)

    # Scale of the overlaid modulation trace, read from the run's log when it is unambiguous.
    modulation_amplitude = 1
    out_files = list(d.glob("*.out"))
    if len(out_files) == 1:
        with open(out_files[0], 'r') as f:
            for line in f:
                m = scale_pattern.match(line)
                if m:
                    modulation_amplitude_text = m.group(1)
                    modulation_amplitude *= float(modulation_amplitude_text)
                    break  # first occurrence only; repeated matches must not compound

    lfp_t, lfp = u.load_stn_lfp(d, 30000, 6000)
    # welch is called without fs, so frequencies are in cycles/sample (0 .. 0.5).
    # Only the first column of lfp["signal"] is analysed.
    freqs, psd = signal.welch(np.transpose(lfp["signal"])[0])
    psd_ax.plot(freqs, psd, label=d.name)

    fig, ax = plt.subplots(figsize=(20, 5))
    ax.plot(lfp_t, lfp["signal"])
    ax.plot(times * 1000, 10 * beta)  # beta x10 to be visible next to the LFP
    ax.plot(stt, modulation_amplitude * modulation_signal / 1000)
    ax.set_title(d.name)
    # NOTE(review): lfp_t is presumably in ms, matching `times * 1000` and `stt` — confirm.
    ax.set_xlabel("Time [ms]")

psd_ax.set_xlim([0, 0.1])
psd_ax.set_xlabel("Frequency [cycles/sample]")
psd_ax.set_ylabel("PSD (STN LFP)")
if psd_ax.get_lines():  # avoid the "no artists with labels" warning when nothing was plotted
    psd_ax.legend()