In [54]:
import matplotlib.pyplot as plt
import torch
import os

In [55]:
# Titles for the eight histogram panels. Panel i shows column i of ``x_sim``;
# panels fill the 2x4 grid row-major (columns 0-3 top row, 4-7 bottom row).
_PANEL_TITLES = (
    'p1: somatic inj, Vs, t1',
    'p2: somatic inj, Vd, t1',
    'p3: somatic inj, Vs, t2',
    'p4: somatic inj, Vd, t2',
    'p5: dendritic inj, Vs, t3',
    'p6: dendritic inj, Vd, t3',
    'p7: dendritic inj, Vs, t4',
    'p8: dendritic inj, Vd, t4',
)

# Per-panel x positions of the dashed reference lines, in panel order.
# NOTE(review): provenance of these numbers is not recorded here; presumably
# "real" are observed values and "simulated" come from a reference
# simulation — confirm and document the source.
_REAL_REFERENCE = (-5.6, -1.7, -10.5, -5.4, -1.75, -8.36, -5.3, -12.1)
_SIM_REFERENCE = (
    -8.799761107919297,
    -7.780730625812472,
    -8.965067337261244,
    -7.506729998553134,
    -6.587400141841358,
    -7.5225870365844125,
    -8.031715153049307,
    -8.530727611781275,
)

_BIN_COLOR = 'skyblue'
_REAL_COLOR = 'springgreen'
_SIM_COLOR = 'red'


def plot_save_histogram(input_file, output_dir, output_file, data_type, bins):
    """Plot per-column histograms of ``x_sim`` and save them as SVG and PNG.

    Loads ``input_file`` with ``torch.load`` and draws one histogram for each
    of the first eight columns of ``loaded["x_sim"]`` on a 2x4 grid. It can
    overlay dashed vertical reference lines. It then writes
    ``<output_dir>/<output_file>_bins<bins>.svg`` and ``.png``.

    Parameters
    ----------
    input_file : str or path-like
        File written by ``torch.save`` holding a mapping with key ``"x_sim"``.
        ``x_sim`` must be 2-D with at least 8 columns. Other keys are ignored.
    output_dir : str or path-like
        Directory for the saved figures; created if it does not exist.
    output_file : str
        Used both as the figure suptitle and as the output filename stem.
    data_type : str
        Which reference lines to overlay: ``"real"``, ``"simulated"`` or
        ``"both"`` (real lines drawn first). Any other value draws no lines.
    bins
        Forwarded to ``Axes.hist`` as ``bins``; also embedded in filenames.

    Raises
    ------
    KeyError
        If the loaded mapping has no ``"x_sim"`` entry.
    IndexError
        If ``x_sim`` is not 2-D or has fewer than 8 columns.
    """
    # torch.load unpickles the file, so only pass trusted inputs.
    # map_location="cpu" lets files saved from GPU tensors load on CPU-only hosts.
    loaded = torch.load(input_file, map_location="cpu")
    x_sim = loaded["x_sim"]
    if isinstance(x_sim, torch.Tensor):
        # matplotlib converts through numpy; detach() avoids the error numpy
        # raises for tensors that require grad.
        x_sim = x_sim.detach().cpu().numpy()

    # (values, colour) reference-line sets to draw, in draw order.
    reference_sets = {
        "real": [(_REAL_REFERENCE, _REAL_COLOR)],
        "simulated": [(_SIM_REFERENCE, _SIM_COLOR)],
        "both": [(_REAL_REFERENCE, _REAL_COLOR), (_SIM_REFERENCE, _SIM_COLOR)],
    }.get(data_type, [])

    fig, axes = plt.subplots(2, 4, figsize=(15, 7), sharex=False, constrained_layout=True)
    try:
        for col, (ax, title) in enumerate(zip(axes.flat, _PANEL_TITLES)):
            ax.hist(x_sim[:, col], bins=bins, color=_BIN_COLOR, edgecolor='black')
            ax.set_title(title)

        fig.suptitle(f"{output_file}", fontsize=16)

        for values, color in reference_sets:
            for ax, x in zip(axes.flat, values):
                ax.axvline(x=x, color=color, linestyle='--', linewidth=2)

        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, f"{output_file}_bins{bins}.svg"))
        fig.savefig(os.path.join(output_dir, f"{output_file}_bins{bins}.png"))
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls (e.g. in a loop) don't accumulate open figures.
        plt.close(fig)

In [None]:
# NOTE(review): scratch/debug cell. The code below is disabled by wrapping it
# in a bare string literal, so running this cell has no effect. In the visible
# code, `input_file` exists only as a parameter of `plot_save_histogram`.
# Re-enabling this as-is would presumably raise NameError unless `input_file`
# is defined elsewhere. Consider deleting this cell, or turning it into a
# small shape-check helper.
"""
file = input_file

loaded = torch.load(file)

theta = loaded["theta"]
x_sim = loaded["x_sim"]

print("theta shape:", theta.shape)
print("x_sim shape:", x_sim.shape)

print(x_sim)
"""