In [None]:
import os

from EPGAN import epgan_prep_data, epgan_model, neuron_params, default_dir
from scipy import interpolate

import numpy as np
import matplotlib.pyplot as plt
import torch

In [None]:
# Seed every random number generator used in this notebook (numpy for data
# handling, torch for the EPGAN models) so Restart-Kernel -> Run-All is reproducible.
SEED = 77
np.random.seed(SEED)
torch.manual_seed(SEED)

# Print arrays with 3 decimals and without scientific notation
np.set_printoptions(precision=3, suppress=True)

## Load parameter bounds for rescaling EPGAN outputs

In [None]:
# Parameter bounds live in <default_dir>/EPGAN/data/sim. os.path.join builds the
# path with the platform's separator instead of hard-coded Windows "\\" separators.
# NOTE(review): the working directory itself is changed (not only the load paths)
# because later cells may rely on relative paths — TODO confirm whether they do.
os.chdir(os.path.join(default_dir, "EPGAN", "data", "sim"))

# Min/max bounds used to rescale EPGAN outputs for the small HH-model
Pars_min_S = torch.load("generic_par_train_2024_min_S.pt", weights_only=True)
Pars_max_S = torch.load("generic_par_train_2024_max_S.pt", weights_only=True)

# Min/max bounds used to rescale EPGAN outputs for the large HH-model
Pars_min_L = torch.load("generic_par_train_2024_min_L.pt", weights_only=True)
Pars_max_L = torch.load("generic_par_train_2024_max_L.pt", weights_only=True)

## Define current/voltage clamp intervals

In [None]:
# Current-clamp steps in pA (pico-ampere): -15 pA to +35 pA in 1 pA increments (51 steps)
current_clamp_list = np.arange(-15, 36, 1)

# Indices into current_clamp_list selecting the 11 steps of each protocol.
# The stimulus lists are derived from these indices so the two cannot drift apart.
CC_5pA_inds = np.arange(0, 51, 5)   # -15 ... +35 pA in 5 pA steps
CC_2pA_inds = np.arange(11, 32, 2)  # -4 ... +16 pA in 2 pA steps
CC_1pA_inds = np.arange(13, 24, 1)  # -2 ... +8 pA in 1 pA steps

current_clamp_list_5pA = current_clamp_list[CC_5pA_inds]
current_clamp_list_2pA = current_clamp_list[CC_2pA_inds]
current_clamp_list_1pA = current_clamp_list[CC_1pA_inds]

# Every protocol must provide the same number of current-clamp steps
assert len(current_clamp_list_5pA) == len(current_clamp_list_2pA) == len(current_clamp_list_1pA) == 11

# Voltage-clamp holding potentials in mV (milli-volt): -120 mV to +50 mV in 10 mV steps (18 steps)
voltage_clamp_list = np.arange(-120, 60, 10)

## Load experimental data

In [None]:
# Membrane potential recordings (current-clamp experiments):
# 0 - 7 seconds sampled every 0.01 seconds, re-scaled to deci-volt (1e-1 V).
# For RIM, AIY and AFD the first two loader outputs are kept: the full recording
# and its "_75" counterpart (only the full recordings are stacked below).
RIM_V, RIM_75_V = epgan_prep_data.load_exp_data_V('RIM', 'single')[0:2]
AIY_V, AIY_75_V = epgan_prep_data.load_exp_data_V('AIY', 'single')[0:2]
AFD_V, AFD_75_V = epgan_prep_data.load_exp_data_V('AFD', 'single')[0:2]

# Remaining neurons: only the first loader output is needed
AWB_V = epgan_prep_data.load_exp_data_V('AWB', 'multi')[0]
AWC_V = epgan_prep_data.load_exp_data_V('AWC', 'single')[0]
URX_V = epgan_prep_data.load_exp_data_V('URX', 'multi')[0]
RIS_V = epgan_prep_data.load_exp_data_V('RIS', 'single')[0]
DVC_V = epgan_prep_data.load_exp_data_V('DVC', 'single')[0]
HSN_V = epgan_prep_data.load_exp_data_V('HSN', 'multi')[0]

# Stack all neurons, move the current-clamp-step axis ahead of time, cast to float32.
# Dim = (9, 11, 700) -> (# of neurons, # of current-clamp steps, timepoints)
exp_input_seqs = (
    torch.vstack([
        RIM_V, AIY_V, AFD_V,          # 5 pA current-clamp protocol
        AWB_V, AWC_V, URX_V, RIS_V,   # 2 pA current-clamp protocol
        DVC_V, HSN_V,                 # 1 pA current-clamp protocol
    ])
    .swapaxes(1, 2)
    .float()
)

# Load steady-state current profiles (voltage-clamp experiments), rescaled to 1e-10 A.
# Holding potentials start at -120 mV in 10 mV increments; every profile is kept to
# the first 18 steps (-120 mV to +50 mV), matching voltage_clamp_list.
# RIM/AIY/AFD profiles are used as returned — presumably already 18 steps wide,
# since torch.vstack below requires equal widths (TODO confirm).
RIM_SS, RIM_75_SS = epgan_prep_data.load_exp_data_IV('RIM', 'single')[:2]
AIY_SS, AIY_75_SS = epgan_prep_data.load_exp_data_IV('AIY', 'single')[:2]
AFD_SS, AFD_75_SS = epgan_prep_data.load_exp_data_IV('AFD', 'single')[:2]
AWB_SS = epgan_prep_data.load_exp_data_IV('AWB', 'multi')[0][:, :18]
AWC_SS = epgan_prep_data.load_exp_data_IV('AWC', 'single')[0][:, :18]
URX_SS = epgan_prep_data.load_exp_data_IV('URX', 'multi')[0][:, :18]
RIS_SS = epgan_prep_data.load_exp_data_IV('RIS', 'single')[0][:, :18]
DVC_SS = epgan_prep_data.load_exp_data_IV('DVC', 'single')[0][:, :18]
HSN_SS = epgan_prep_data.load_exp_data_IV('HSN', 'multi')[0][:, :18]

# Dim = (9, 18) -> (# of neurons, # of voltage-clamp steps)
exp_SS = torch.vstack([RIM_SS, AIY_SS, AFD_SS, AWB_SS, AWC_SS, URX_SS, RIS_SS, DVC_SS, HSN_SS]).float()

In [None]:
# Initial membrane potential of each recorded neuron, in mV (milli-volt).
# RIM, AIY and AFD use whole-number values; the other six are given at full precision.
RIM_V0 = -38
AIY_V0 = -53
AFD_V0 = -78
AWB_V0 = -75.07054233602229
AWC_V0 = -73.3002653431254
URX_V0 = -46.62392439719344
RIS_V0 = -50.21278449675801
DVC_V0 = -48.59310891738207
HSN_V0 = -54.458039436385434

# Same neuron order as the stacked recordings, converted mV -> deci-volt (1e-1 V):
# multiplying by 1e-2 turns e.g. -38 mV into -0.38 dV.
V_initcond_exp = 1e-2 * np.array([
    RIM_V0, AIY_V0, AFD_V0,
    AWB_V0, AWC_V0, URX_V0,
    RIS_V0, DVC_V0, HSN_V0,
])

## Format experimental data to be compatible with EP-GAN inputs

### Construct current-clamp stimuli array

In [None]:
# Each timepoint is sampled every 0.02 seconds (e.g., 750 timepoints = 15 seconds)

# Current-clamp protocol for RIM, AIY, AFD (-15 ... +35 pA in 5 pA steps)
# Dim = (3, 750, 11) -> (# of neurons, timepoints, # of current-clamp steps)
Iext_exp_5pA = epgan_model.construct_Iext(3, current_clamp_list_5pA)

# Current-clamp protocol for AWB, AWC, URX, RIS (-4 ... +16 pA in 2 pA steps)
# Dim = (4, 750, 11)
Iext_exp_2pA = epgan_model.construct_Iext(4, current_clamp_list_2pA)

# Current-clamp protocol for DVC, HSN (-2 ... +8 pA in 1 pA steps)
# Dim = (2, 750, 11)
Iext_exp_1pA = epgan_model.construct_Iext(2, current_clamp_list_1pA)

# Merge all external stimulation matrices along the # of neurons axis
# Dim = (9, 750, 11)
Iext_exp = torch.cat([Iext_exp_5pA, Iext_exp_2pA, Iext_exp_1pA], dim = 0)

# Current-clamp traces kept for the 75% membrane-potential ablation: the first 8 of
# the 11 traces (indices 0-7), i.e. the last 3 traces are dropped (the highest current
# steps, assuming traces are ordered like the current_clamp_list_* arrays).
# NOTE(review): an earlier comment said "removing first 3"; these indices keep the
# first 8 — confirm which traces the ablation is meant to remove.
input_mask_exp_75 = torch.arange(8)

### Merge membrane potential traces with current-clamp stimuli array

In [None]:
# Build EP-GAN's membrane-potential input: for each neuron and current-clamp step,
# the recorded voltage trace followed by the injected current trace.
# Dim = (9, 11, 700) -> (# of neurons, # of current-clamp steps, timepoints)
# 0 - 350 timepoints -> membrane potential traces (unit = 1e-1 V)
# 350 - 700 timepoints -> External current stimuli traces (unit = 1e-10 A)
exp_batches_features_V = torch.cat(
    [
        # every 2nd recorded sample (0.01 s -> 0.02 s spacing)
        exp_input_seqs[:, :, ::2],
        # (neurons, steps, timepoints) layout, 200 timepoints trimmed from each end,
        # scaled by 1e-2 (presumably pA -> 1e-10 A)
        Iext_exp.swapaxes(1, 2)[:, :, 200:-200] * 1e-2,
    ],
    dim=2,
)

In [None]:
# Input ablation scenarios (RIM, AIY, AFD only, i.e. the first 3 neurons)

# Membrane potential input ablation: keep only the current-clamp traces selected by
# input_mask_exp_75 (8 of the 11 steps)
# Dim = (3, 8, 700)
exp_batches_features_V75 = exp_batches_features_V[:3, input_mask_exp_75, :]

# Steady-state input ablation (75% steady-state currents for RIM, AIY, AFD):
# the last 2 voltage-clamp steps of each "_75" profile are dropped.
# Dim = (3, 18) — NOTE(review): holds only if each *_75_SS profile has 20 steps; confirm
exp_SS_75 = torch.vstack([RIM_75_SS, AIY_75_SS, AFD_75_SS]).float()[:, :-2]

In [None]:
# Sanity check: tensor shapes of the two EP-GAN inputs
print("Membrane potential input: ", exp_batches_features_V.shape, sep=" ")
print("Steady-state currents input: ", exp_SS.shape, sep=" ")

## Plot membrane potential and steady-state current inputs to EP-GAN

In [None]:
# EP-GAN inputs for RIM (neuron 0), drawn with the explicit Axes interface
fig, (ax_v, ax_ss) = plt.subplots(1, 2, figsize=(10, 3))

# Left: one line per current-clamp step; samples 0-349 are the membrane potential
# trace and samples 350-699 the external current stimulus appended to it
ax_v.plot(exp_batches_features_V[0].T)
ax_v.set_title("Membrane potential (1e-1 V) and Iext (1e-10 A)")
ax_v.set_xlabel("Sample (0-349: V, 350-699: Iext)")

# Right: steady-state current against the voltage-clamp holding potential
ax_ss.plot(voltage_clamp_list, exp_SS[0], '-o')
ax_ss.set_title("Steady-state currents inputs (1e-10 A)")
ax_ss.set_xlabel("Holding potential (mV)")

plt.show()

## Load pre-trained EPGAN models

In [None]:
# Pretrained generators live in <default_dir>/EPGAN/pretrained_models; os.path.join
# keeps the path portable instead of hard-coding Windows "\\" separators.
os.chdir(os.path.join(default_dir, "EPGAN", "pretrained_models"))

# These .pth files hold whole pickled model objects (they are used directly via
# .eval()/.to() below), not bare state_dicts, so weights_only=False is required on
# PyTorch >= 2.6, where the default became True.
# SECURITY: this unpickles arbitrary objects — only load checkpoints from a trusted source.
# map_location="cpu" lets checkpoints saved from a GPU load on CPU-only machines.
EPGAN_S = torch.load("S1.pth", map_location="cpu", weights_only=False)  # For estimating small HH-model parameters
EPGAN_L = torch.load("L1.pth", map_location="cpu", weights_only=False)  # For estimating large HH-model parameters

# Assign the result so the cell does not echo the full model repr
EPGAN_S = EPGAN_S.to(device = torch.device("cpu"))
EPGAN_L = EPGAN_L.to(device = torch.device("cpu"))

## Predict parameters

In [None]:
# Parameter names, in the order returned by generic_model_params.keys();
# used to label the estimated parameter vectors printed at the end
param_labels = np.array([*neuron_params.generic_model_params.keys()])

In [None]:
import seaborn as sns

# Global plot styling for the figures drawn from here on: white background,
# default font scale. set_theme is the documented interface; sns.set is a legacy alias.
sns.set_theme(style = 'white', font_scale = 1)

In [None]:
# Estimate small HH-model parameters with EPGAN_S, rescaled with the small-model
# bounds Pars_min_S / Pars_max_S.
# Full-input result Dim = (9, 203) -> (# of neurons, # of parameters)
# Inference only: eval mode and no autograd graph.
EPGAN_S.eval()

with torch.no_grad():
    # 1) Full inputs for all 9 neurons
    pars_gen_exp_S = epgan_model.test_exp_inputs(
        EPGAN_S, exp_batches_features_V, exp_SS,
        Pars_min_S, Pars_max_S, V_initcond_exp,
    )

    # 2) Membrane potential ablation (75% remaining) for RIM, AIY, AFD:
    #    reduced current-clamp traces, full steady-state currents
    pars_gen_exp_v75_S = epgan_model.test_exp_inputs(
        EPGAN_S, exp_batches_features_V75, exp_SS[:3],
        Pars_min_S, Pars_max_S, V_initcond_exp[:3],
    )

    # 3) Steady-state current ablation (75% remaining) for RIM, AIY, AFD:
    #    full current-clamp traces, reduced steady-state currents
    pars_gen_exp_iv75_S = epgan_model.test_exp_inputs(
        EPGAN_S, exp_batches_features_V[:3], exp_SS_75,
        Pars_min_S, Pars_max_S, V_initcond_exp[:3],
    )

In [None]:
# Estimate large HH-model parameters with EPGAN_L, rescaled with the large-model
# bounds Pars_min_L / Pars_max_L.
# Full-input result Dim = (9, 203) -> (# of neurons, # of parameters)
# Inference only: eval mode and no autograd graph.
EPGAN_L.eval()

with torch.no_grad():
    # 1) Full inputs for all 9 neurons
    pars_gen_exp_L = epgan_model.test_exp_inputs(
        EPGAN_L, exp_batches_features_V, exp_SS,
        Pars_min_L, Pars_max_L, V_initcond_exp,
    )

    # 2) Membrane potential ablation (75% remaining) for RIM, AIY, AFD:
    #    reduced current-clamp traces, full steady-state currents
    pars_gen_exp_v75_L = epgan_model.test_exp_inputs(
        EPGAN_L, exp_batches_features_V75, exp_SS[:3],
        Pars_min_L, Pars_max_L, V_initcond_exp[:3],
    )

    # 3) Steady-state current ablation (75% remaining) for RIM, AIY, AFD:
    #    full current-clamp traces, reduced steady-state currents
    pars_gen_exp_iv75_L = epgan_model.test_exp_inputs(
        EPGAN_L, exp_batches_features_V[:3], exp_SS_75,
        Pars_min_L, Pars_max_L, V_initcond_exp[:3],
    )

In [None]:
# Print estimated small HH-model parameters for the first neuron (RIM)
# See supplementary files of EP-GAN paper for the units used by the parameters
# Neuron index -> name:
# 0: RIM
# 1: AIY
# 2: AFD
# 3: AWB
# 4: AWC
# 5: URX
# 6: RIS
# 7: DVC
# 8: HSN

# One "name  value" row per parameter. The floats are formatted directly (3 decimals,
# matching np.set_printoptions(precision=3)); stacking them with the string labels via
# np.vstack would coerce them into full-precision strings and ignore that setting.
label_width = max(map(len, param_labels))
for label, value in zip(param_labels, np.asarray(pars_gen_exp_S[0], dtype=float)):
    print(f"{label:<{label_width}}  {value:12.3f}")

In [None]:
# Print estimated large HH-model parameters for the first neuron (RIM)
# See supplementary files of EP-GAN paper for the units used by the parameters
# Neuron index -> name:
# 0: RIM
# 1: AIY
# 2: AFD
# 3: AWB
# 4: AWC
# 5: URX
# 6: RIS
# 7: DVC
# 8: HSN

# One "name  value" row per parameter. The floats are formatted directly (3 decimals,
# matching np.set_printoptions(precision=3)); stacking them with the string labels via
# np.vstack would coerce them into full-precision strings and ignore that setting.
label_width = max(map(len, param_labels))
for label, value in zip(param_labels, np.asarray(pars_gen_exp_L[0], dtype=float)):
    print(f"{label:<{label_width}}  {value:12.3f}")