In [27]:
import numpy as np
import matplotlib.pyplot as plt
import os
from CA3model_passive_5params import CA3model_passive_solver


In [28]:
def plot_save_for_diff_params(param_set, output_dir, name):
    """Simulate every parameter vector and save its traces as SVG, PNG and TXT.

    For each entry of ``param_set``, ``CA3model_passive_solver`` is called with
    ``return_trace=True``. Traces 0 and 1 are drawn in the lower panel
    ("Somatic injection"), traces 2 and 3 in the upper panel ("Dendritic
    injection"). The eight feature points returned by the solver are overlaid
    as round markers at the four feature times returned by the solver (two
    points per time).

    Parameters
    ----------
    param_set : iterable of sequences
        Each element must hold at least five numbers. They are labelled
        gL, VL, gc, pp and Cm (in this order) in the figure title and in the
        header of the text file.
    output_dir : str or path-like
        Directory that receives the output files; created if missing.
    name : object or None
        Label drawn in the top-right corner of every figure and used as the
        file-name prefix. ``None`` (or an empty label) omits the corner text
        and falls back to the prefix ``"params"``.

    Notes
    -----
    Files are written as ``<prefix>_<k>.svg``, ``<prefix>_<k>.png`` and
    ``<prefix>_<k>.txt`` with ``k`` counting from 1. The text file has six
    columns (sample index, trace 0, trace 1, sample index, trace 2, trace 3),
    so all four traces are expected to have the same length.
    """
    # The previous `str({name})` built a one-element set and displayed its
    # repr ("{'label'}"); the plain string is what belongs on the figure.
    label = "" if name is None else str(name)
    # File names are derived from the label. Using the repr of the whole
    # parameter collection (as before) yields unwieldy or invalid file names.
    file_prefix = label or "params"

    os.makedirs(output_dir, exist_ok=True)

    for i, params in enumerate(param_set):
        points, p_times, traces = CA3model_passive_solver(params, return_trace=True)
        print(points)

        Vs1, Vd1, Vs2, Vd2 = traces[0], traces[1], traces[2], traces[3]

        # (axes index, title, first trace, second trace, indices into p_times)
        panels = (
            (1, "Somatic injection", Vs1, Vd1, (0, 1)),
            (0, "Dendritic injection", Vs2, Vd2, (2, 3)),
        )

        fig, axs = plt.subplots(2, 1, figsize=(10, 5), sharex=True, sharey=True)
        try:
            for ax_idx, title, vs, vd, time_ids in panels:
                ax = axs[ax_idx]
                samples = np.arange(len(vs))
                ax.plot(samples, vs, 'c', label='Vs')
                ax.plot(samples, vd, 'k', label='Vd')
                # Feature time m carries the point pair (2m, 2m + 1).
                for m in time_ids:
                    ax.plot(p_times[m], points[2 * m], 'o', color='c')
                    ax.plot(p_times[m], points[2 * m + 1], 'o', color='k')
                ax.set_title(title)
                ax.legend()

            fig.suptitle(
                f"gL={params[0]:.4f}, VL={params[1]:.2f}, gc={params[2]:.4f}, "
                f"pp={params[3]:.3f}, Cm={params[4]:.3f}",
                fontsize=10,
            )

            if label:
                # Figure.text positions in figure coordinates by default.
                fig.text(0.95, 0.95, label, fontsize=10,
                         horizontalalignment='right', verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

            stem = os.path.join(output_dir, f"{file_prefix}_{i + 1}")
            fig.savefig(stem + ".svg")
            fig.savefig(stem + ".png")
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so long parameter sweeps do not accumulate open figures.
            plt.close(fig)

        data_to_save = np.column_stack(
            [np.arange(len(Vs1)), Vs1, Vd1, np.arange(len(Vs2)), Vs2, Vd2]
        )

        # np.savetxt prefixes every header line with "# " itself, so the
        # header text must not carry its own comment marker.
        points_line = ", ".join(f"p{k + 1}={points[k]}" for k in range(8))
        times_line = ", ".join(f"t{k + 1}={p_times[k // 2]}" for k in range(8))
        header = (
            f"Parameters: g_L={params[0]},V_L={params[1]}, gc={params[2]}, "
            f"pp={params[3]}, Cm={params[4]} \n"
            f"{points_line}\n"
            f"{times_line}\n"
            "Columns: t1 Vs1 Vd1 t2 Vs2 Vd2"
        )

        np.savetxt(
            os.path.join(output_dir, f"{file_prefix}_{i + 1}.txt"),
            data_to_save,
            header=header,
            fmt="%.6f"
        )


In [22]:
def plot_save_for_diff_params_with_given_datapoints(param_set, output_dir, name, sim_points):
    """Plot simulated traces against externally supplied feature points.

    For each entry of ``param_set``, ``CA3model_passive_solver`` is called with
    ``return_trace=True``. Traces 0 and 1 are drawn in the lower panel
    ("Somatic injection"), traces 2 and 3 in the upper panel ("Dendritic
    injection"). The markers show ``sim_points`` (not the points computed by
    the solver) placed at the four feature times returned by the solver, two
    points per time. The solver's own points are only printed.

    Parameters
    ----------
    param_set : iterable of sequences
        Each element must hold at least five numbers. They are labelled
        gL, VL, gc, pp and Cm (in this order) in the figure title and in the
        header of the text file.
    output_dir : str or path-like
        Directory that receives the output files; created if missing.
    name : object or None
        Label drawn in the top-right corner of every figure and used as the
        file-name prefix. ``None`` (or an empty label) omits the corner text
        and falls back to the prefix ``"params"``.
    sim_points : sequence
        At least eight values; entries ``2m`` and ``2m + 1`` are plotted at
        feature time ``m`` and all eight are recorded as p1..p8 in the header.

    Notes
    -----
    Files are written as ``<prefix>_<k>.svg``, ``<prefix>_<k>.png`` and
    ``<prefix>_<k>.txt`` with ``k`` counting from 1. The text file has six
    columns (sample index, trace 0, trace 1, sample index, trace 2, trace 3),
    so all four traces are expected to have the same length.
    """
    # The previous `str({name})` built a one-element set and displayed its
    # repr ("{'label'}"); the plain string is what belongs on the figure.
    label = "" if name is None else str(name)
    # File names are derived from the label. Using the repr of the whole
    # parameter collection (as before) yields unwieldy or invalid file names.
    file_prefix = label or "params"

    os.makedirs(output_dir, exist_ok=True)

    for i, params in enumerate(param_set):
        points, p_times, traces = CA3model_passive_solver(params, return_trace=True)
        print(points)

        Vs1, Vd1, Vs2, Vd2 = traces[0], traces[1], traces[2], traces[3]

        # (axes index, title, first trace, second trace, indices into p_times)
        panels = (
            (1, "Somatic injection", Vs1, Vd1, (0, 1)),
            (0, "Dendritic injection", Vs2, Vd2, (2, 3)),
        )

        fig, axs = plt.subplots(2, 1, figsize=(10, 5), sharex=True, sharey=True)
        try:
            for ax_idx, title, vs, vd, time_ids in panels:
                ax = axs[ax_idx]
                samples = np.arange(len(vs))
                ax.plot(samples, vs, 'c', label='Vs')
                ax.plot(samples, vd, 'k', label='Vd')
                ax.set_title(title)
                # Feature time m carries the supplied pair (2m, 2m + 1).
                for m in time_ids:
                    ax.plot(p_times[m], sim_points[2 * m], 'o', color='c')
                    ax.plot(p_times[m], sim_points[2 * m + 1], 'o', color='k')
                ax.legend()

            fig.suptitle(
                f"gL={params[0]:.4f}, VL={params[1]:.2f}, gc={params[2]:.4f}, "
                f"pp={params[3]:.3f}, Cm={params[4]:.3f}",
                fontsize=10,
            )

            if label:
                # Figure.text positions in figure coordinates by default.
                fig.text(0.95, 0.95, label, fontsize=10,
                         horizontalalignment='right', verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

            stem = os.path.join(output_dir, f"{file_prefix}_{i + 1}")
            fig.savefig(stem + ".svg")
            fig.savefig(stem + ".png")
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so long parameter sweeps do not accumulate open figures.
            plt.close(fig)

        data_to_save = np.column_stack(
            [np.arange(len(Vs1)), Vs1, Vd1, np.arange(len(Vs2)), Vs2, Vd2]
        )

        # np.savetxt prefixes every header line with "# " itself, so the
        # header text must not carry its own comment marker.
        points_line = ", ".join(f"p{k + 1}={sim_points[k]}" for k in range(8))
        times_line = ", ".join(f"t{k + 1}={p_times[k // 2]}" for k in range(8))
        header = (
            f"Parameters: g_L={params[0]},V_L={params[1]}, gc={params[2]}, "
            f"pp={params[3]}, Cm={params[4]} \n"
            f"{points_line}\n"
            f"{times_line}\n"
            "Columns: t1 Vs1 Vd1 t2 Vs2 Vd2"
        )

        np.savetxt(
            os.path.join(output_dir, f"{file_prefix}_{i + 1}.txt"),
            data_to_save,
            header=header,
            fmt="%.6f"
        )
