# In [None]:
### FINAL HEAT MAP CODE
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
import seaborn as sns
from matplotlib.colors import ListedColormap

# Default font size for every figure drawn below.
plt.rcParams['font.size'] = 12

def signal_trend(data, increase_frac=0.25, decrease_frac=0.75):
    """Classify a 1-D trace by where its global minimum falls.

    Parameters
    ----------
    data : sequence or array
        Non-empty trace; its length defines the region boundaries.
    increase_frac : float, optional
        A minimum located strictly before ``increase_frac * len(data)``
        counts as "early" (default 0.25).
    decrease_frac : float, optional
        A minimum located strictly after ``decrease_frac * len(data)``
        counts as "late" (default 0.75).

    Returns
    -------
    int
        0 if the minimum is late (labelled decreasing / Low-Pass),
        1 if it is early (labelled increasing / High-Pass),
        2 otherwise (labelled peaking / Band-Pass). The "late" test is
        applied first.

    Raises
    ------
    ValueError
        If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("signal_trend() requires a non-empty sequence")
    arg_min = np.argmin(data)

    if arg_min > decrease_frac * n_samples:
        return 0  # decreasing (Low-Pass)
    if arg_min < increase_frac * n_samples:
        return 1  # increasing (High-Pass)
    return 2  # peaking (Band-Pass)

# constants for sim
tau_2 = 5  # ms, AMPA time constant paired with tau_1
tau_2g = 2  # ms, GABA time constant paired with tau_1g
width_block = 1  # ms, duration of each stimulus pulse
diff_onset = 1  # NOTE(review): not referenced in this cell


# tau_1 (AMPA) and tau_1g (GABA) each vary from 2 to 20 ms in steps of 2 ms
a_range = np.arange(2, 22, 2)
b_range = np.arange(2, 22, 2)

# conductance scale factors select the heat-map column, latencies the row
ge_values = [0.8, 0.9, 1.0, 1.1, 1.2]  # typical is 0.5-2 nS
gi_values = [1, 1, 1, 1, 1]  # typical is 0.5-2 nS
lat_e_values = [1, 1, 1, 1, 1]  # ms, typical is 1-3 ms
lat_i_values = [1, 2, 3, 4, 5]  # ms, typical is 2-5 ms


# Simulation parameters
t = 50  # ms, simulated duration
dt = 0.1  # ms, Euler step
N = round(t / dt)  # number of integration steps
# Build the time axis from N so it always has exactly N samples; a
# float-step np.arange(0, t, dt) can gain or lose an element to rounding,
# while the integration loop runs for exactly N steps.
ttotal = np.arange(N) * dt

# colormap: one colour per class returned by signal_trend (0, 1, 2)
cmap = ListedColormap(['blue', 'yellow', 'green'])

# One panel per (latency pair, conductance pair). The grid follows the
# parameter lists (zip stops at the shorter list of each pair) instead of
# being hard-coded, so adding a value cannot index past the axes array.
n_rows = min(len(lat_e_values), len(lat_i_values))
n_cols = min(len(ge_values), len(gi_values))
# squeeze=False keeps `axes` 2-D so axes[row, col] works for any grid size.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
fig.subplots_adjust(hspace=0.4, wspace=0.4)


def _kinetics_norm(tau_1, tau_2):
    """Return (tau_2 / tau_1) ** (tau_1 / (tau_2 - tau_1)).

    In the synapse update, ``s`` relaxes towards this factor times ``x``.
    The closed form divides by zero when ``tau_1 == tau_2``. Its limit as
    the two time constants meet is ``e``, so that value is returned instead
    of the spurious ``1 ** inf == 1``.
    """
    if np.isclose(tau_1, tau_2):
        return np.e
    return (tau_2 / tau_1) ** (tau_1 / (tau_2 - tau_1))


def _pulse_train(onsets, latency, width, dt, n_steps):
    """Return a length-``n_steps`` 0/1 array with one block of ones per onset.

    Each block starts ``latency`` after its onset and lasts ``width``, all
    in the same time unit as ``dt``. Times are converted to indices with
    ``round`` rather than ``int`` so that, e.g., 0.3 / 0.1 == 2.9999... is
    not truncated one sample short. Blocks past the end are clipped by slicing.
    """
    train = np.zeros(n_steps)
    shift = int(round(latency / dt))
    for onset in onsets:
        start = int(round(onset / dt)) + shift
        stop = int(round((onset + width) / dt)) + shift
        train[start:stop] = 1
    return train


def _simulate_voltage(syn_ampa, syn_gaba, tau_1, tau_2, tau_1g, tau_2g,
                      g_ampa, g_gaba, dt, v0=-70, e_leak=-70, gleak=1e-4):
    """Forward-Euler integration of a leaky membrane driven by AMPA and GABA input.

    Each synapse is a two-variable cascade:
      - ``x`` integrates its 0/1 pulse train and decays with ``tau_2`` (``tau_2g``);
      - ``s`` relaxes towards ``_kinetics_norm(...) * x`` with ``tau_1`` (``tau_1g``).

    The drive ``g_gaba * s_gaba - g_ampa * s_ampa`` enters dv/dt together
    with the leak ``-gleak * (v - e_leak)``. The membrane starts at ``v0``.

    Returns an array holding the voltage after each of ``len(syn_ampa)`` steps.
    """
    # Loop-invariant normalisation factors, computed once per simulation.
    norm_ampa = _kinetics_norm(tau_1, tau_2)
    norm_gaba = _kinetics_norm(tau_1g, tau_2g)
    s_ampa = x_ampa = s_gaba = x_gaba = 0.0
    v = v0
    trace = np.empty(len(syn_ampa))
    for p, (drive_e, drive_i) in enumerate(zip(syn_ampa, syn_gaba)):
        # s is advanced with the previous step's x, then x is advanced.
        s_ampa += dt * ((norm_ampa * x_ampa - s_ampa) / tau_1)
        x_ampa += dt * (-x_ampa / tau_2 + drive_e)
        s_gaba += dt * ((norm_gaba * x_gaba - s_gaba) / tau_1g)
        x_gaba += dt * (-x_gaba / tau_2g + drive_i)
        current = g_gaba * s_gaba - g_ampa * s_ampa
        v += dt * (-gleak * (v - e_leak) + current)
        trace[p] = v
    return trace


frequency = 20  # Hz
# Pulse onsets every 1000 / frequency ms inside the simulated window [0, t).
onsets = np.arange(0, t, 1000.0 / frequency)

# Only these tau values get a tick label. Seaborn centres cell i at i + 0.5,
# so tick positions are offset by half a cell.
shown_taus = (4, 8, 12, 16, 20)
x_ticks = [(i + 0.5, int(tau)) for i, tau in enumerate(a_range) if tau in shown_taus]
y_ticks = [(i + 0.5, int(tau)) for i, tau in enumerate(b_range) if tau in shown_taus]

for row, (lat_e, lat_i) in enumerate(zip(lat_e_values, lat_i_values)):
    # Pulse trains depend only on the latencies: build them once per row.
    syn_ampa = _pulse_train(onsets, lat_e, width_block, dt, N)
    syn_gaba = _pulse_train(onsets, lat_i, width_block, dt, N)
    for col, (ge, gi) in enumerate(zip(ge_values, gi_values)):
        # results[n, m]: row n <-> tau_1g (y axis), column m <-> tau_1 (x axis),
        # matching the axis labels set below.
        results = np.zeros((len(b_range), len(a_range)))
        for m, tau_1 in enumerate(a_range):
            for n, tau_1g in enumerate(b_range):
                vacum = _simulate_voltage(syn_ampa, syn_gaba, tau_1, tau_2,
                                          tau_1g, tau_2g, ge, gi, dt)
                results[n, m] = signal_trend(vacum)

        # Heatmap for this combination of latencies and conductances.
        # vmin/vmax are offset by 0.5 so the three colour bands are centred
        # on the class values 0, 1 and 2 (and on the colorbar ticks).
        ax = sns.heatmap(results, linewidth=.5, cmap=cmap, cbar_kws={'ticks': [0, 1, 2]},
                         ax=axes[row, col], vmin=-0.5, vmax=2.5)
        ax.collections[0].colorbar.set_ticklabels(['Low-Pass', 'High-Pass', 'Band-Pass'])
        ax.set_xlabel(r'$\tau_{1}$ (ms)', fontsize=10)
        ax.set_ylabel(r'$\tau_{1g}$ (ms)', fontsize=10)
        ax.set_title(rf'$\Delta_l$={lat_e - lat_i}, $\Delta_g$ = {round(ge - gi, 2)}', fontsize=12)

        ax.set_xticks([pos for pos, _ in x_ticks])
        ax.set_xticklabels([label for _, label in x_ticks])
        ax.set_yticks([pos for pos, _ in y_ticks])
        ax.set_yticklabels([label for _, label in y_ticks])
        # Seaborn draws row 0 at the top; flip so tau_1g increases upward.
        ax.invert_yaxis()

# Add the overall title, fit the panel grid beneath it, then display.
figure = plt.gcf()
figure.suptitle('Signal Trend of Synaptic Currents with Varying Latencies and Conductances', fontsize=18)
figure.tight_layout(rect=(0, 0.03, 1, 0.95))

plt.show()