In [3]:
pip install torch

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/16/dd/1bf10180ba812afa1aa7427466083d731bc37b9a1157ec929d0cfeef87eb/torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl.metadata
  Using cached torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl.metadata (24 kB)
Using cached torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl (146.7 MB)
Installing collected packages: torch
Successfully installed torch-2.1.0
Note: you may need to restart the kernel to use updated packages.


In [5]:
pip install torchdiffeq

Collecting torchdiffeq
  Using cached torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3
Note: you may need to restart the kernel to use updated packages.


In [4]:
import numpy as np
import sys
import os
from scipy.integrate import solve_ivp
import torch
from torchdiffeq import odeint
import WaveformGenerator
from WaveformGenerator import Waveform

To start, we explore the Hodgkin-Huxley model before defining and implementing our forward model class.

In [1]:


# Define HH model parameters
# Maximal conductances (mS/cm^2)
g_Na = 120.0  # Sodium
g_K = 36.0    # Potassium
g_L = 0.3     # Leak
# Reversal potentials (mV)
E_Na = 50.0    # Sodium
E_K = -77.0    # Potassium
# Leak reversal potential: kept equal to the value used by the HodgkinHuxley
# class so the exploratory script and the class describe the same neuron.
E_L = -54.387
# Membrane capacitance (uF/cm^2)
C_m = 1.0

# Define the HH model equations
def alpha_m(V):
    """Rate alpha_m(V) that multiplies (1 - m) in dm/dt, for voltage V in mV.

    Evaluates 0.1 * (V + 40) / (1 - exp(-(V + 40) / 10)).  That closed form is
    0/0 at V == -40 mV, so close to that point the first-order series value
    1 + (V + 40) / 20 is returned instead of NaN.

    Args:
        V: Membrane voltage as a scalar or array-like.

    Returns:
        A NumPy scalar for scalar input, otherwise an array shaped like V.
    """
    x = np.asarray(V, dtype=float) + 40.0
    near_zero = np.abs(x) < 1e-7
    # Swap in a dummy argument where x ~ 0 so the division never sees 0/0.
    x_safe = np.where(near_zero, 1.0, x)
    # -expm1(-u) equals 1 - exp(-u) but avoids cancellation for small u.
    rate = 0.1 * x_safe / (-np.expm1(-x_safe / 10.0))
    # [()] unwraps a 0-d result to a NumPy scalar; real arrays pass through.
    return np.where(near_zero, 1.0 + x / 20.0, rate)[()]

def beta_m(V):
    """Rate beta_m(V) that multiplies m in dm/dt, for voltage V in mV."""
    offset = V + 65.0
    return 4.0 * np.exp(-offset / 18.0)

def alpha_h(V):
    """Rate alpha_h(V) that multiplies (1 - h) in dh/dt, for voltage V in mV."""
    offset = V + 65.0
    return 0.07 * np.exp(-offset / 20.0)

def beta_h(V):
    """Rate beta_h(V) that multiplies h in dh/dt, for voltage V in mV (a logistic in V)."""
    exponent = -(V + 35.0) / 10.0
    return 1.0 / (1.0 + np.exp(exponent))

def alpha_n(V):
    """Rate alpha_n(V) that multiplies (1 - n) in dn/dt, for voltage V in mV.

    Evaluates 0.01 * (V + 55) / (1 - exp(-(V + 55) / 10)).  That closed form is
    0/0 at V == -55 mV, so close to that point the first-order series value
    0.1 + (V + 55) / 200 is returned instead of NaN.

    Args:
        V: Membrane voltage as a scalar or array-like.

    Returns:
        A NumPy scalar for scalar input, otherwise an array shaped like V.
    """
    x = np.asarray(V, dtype=float) + 55.0
    near_zero = np.abs(x) < 1e-7
    # Swap in a dummy argument where x ~ 0 so the division never sees 0/0.
    x_safe = np.where(near_zero, 1.0, x)
    # -expm1(-u) equals 1 - exp(-u) but avoids cancellation for small u.
    rate = 0.01 * x_safe / (-np.expm1(-x_safe / 10.0))
    # [()] unwraps a 0-d result to a NumPy scalar; real arrays pass through.
    return np.where(near_zero, 0.1 + x / 200.0, rate)[()]

def beta_n(V):
    """Rate beta_n(V) that multiplies n in dn/dt, for voltage V in mV."""
    offset = V + 65
    return 0.125 * np.exp(-offset / 80.0)

# Define the HH model differential equations
def hodgkin_huxley_eq(V, m, h, n, I):
    """Evaluate the Hodgkin-Huxley right-hand side for one state.

    Args:
        V: Membrane voltage (mV).
        m, h, n: Gating variables.
        I: Applied stimulus current.

    Returns:
        Tuple ``(dV/dt, dm/dt, dh/dt, dn/dt)``.
    """
    # Ionic currents: conductance * gating term * driving force.
    sodium_current = g_Na * m**3 * h * (V - E_Na)
    potassium_current = g_K * n**4 * (V - E_K)
    leak_current = g_L * (V - E_L)
    voltage_rate = (I - sodium_current - potassium_current - leak_current) / C_m

    # Every gate follows d(gate)/dt = alpha(V) * (1 - gate) - beta(V) * gate.
    gate_rates = tuple(
        opening(V) * (1 - gate) - closing(V) * gate
        for gate, opening, closing in (
            (m, alpha_m, beta_m),
            (h, alpha_h, beta_h),
            (n, alpha_n, beta_n),
        )
    )
    return (voltage_rate, *gate_rates)

# Define time parameters
dt = 0.01  # Time step (ms)
t = np.arange(0, 50, dt)  # Time vector (ms)

# Define input stimulus waveform (e.g., a step current).
# The step window is specified in ms and converted to sample indices through
# dt, so it stays at the same place in time if the time step is changed.
stim_on_ms, stim_off_ms = 1.0, 5.0  # Step is on for 1 ms <= t < 5 ms
stim_amplitude = 10.0               # Height of the current step
I_stimulus = np.zeros_like(t)
I_stimulus[int(round(stim_on_ms / dt)):int(round(stim_off_ms / dt))] = stim_amplitude

# Initial state: membrane voltage (mV) followed by the gating variables m, h, n
V, m, h, n = -65.0, 0.05, 0.6, 0.32

# Step the HH model forward with explicit Euler updates, recording the voltage
V_record = []
for I in I_stimulus:
    V_record.append(V)  # store the voltage before this step's update
    dVdt, dmdt, dhdt, dndt = hodgkin_huxley_eq(V, m, h, n, I)
    V, m, h, n = V + dVdt * dt, m + dmdt * dt, h + dhdt * dt, n + dndt * dt

    

In [6]:


def hodgkin_huxley_neural_ode(t, z, I_stimulus):
    """Hodgkin-Huxley right-hand side in the ``func(t, y)`` form torchdiffeq calls.

    Args:
        t: Current time, as supplied by the solver.
        z: State holding (V, m, h, n).
        I_stimulus: Stimulus current. Either a constant value, or a callable
            ``I_stimulus(t)`` returning the current at time ``t``.

    Returns:
        1-D tensor ``[dV/dt, dm/dt, dh/dt, dn/dt]``. The solver needs a tensor
        shaped like the state; a Python list is not accepted.
    """
    V, m, h, n = z
    I = I_stimulus(t) if callable(I_stimulus) else I_stimulus
    derivatives = hodgkin_huxley_eq(V, m, h, n, I)
    dtype = z.dtype if torch.is_tensor(z) else torch.get_default_dtype()
    # The rate helpers go through NumPy, so individual derivatives may come
    # back as NumPy scalars rather than tensors; normalise before stacking.
    # NOTE(review): because of that NumPy round-trip, gradients presumably do
    # not flow through the rate functions -- confirm before training with this.
    return torch.stack([torch.as_tensor(d, dtype=dtype) for d in derivatives])

# Create a function to solve the neural ODE
def solve_neural_ode(I_stimulus, t_span):
    """Integrate the Hodgkin-Huxley system with torchdiffeq's ``odeint``.

    Args:
        I_stimulus: Constant stimulus current, or a callable ``I(t)``.
        t_span: Time points at which the solution is wanted (sequence, array
            or tensor); converted to a tensor with the state's dtype.

    Returns:
        The result of ``odeint`` for the 4-element initial state (V, m, h, n).
    """
    z0 = torch.tensor([-65.0, 0.05, 0.6, 0.32])  # Initial conditions for V, m, h, n
    times = torch.as_tensor(t_span, dtype=z0.dtype)

    # torchdiffeq's odeint has no ``args`` keyword (unlike SciPy's), so the
    # stimulus is bound into the right-hand side with a closure instead.
    def rhs(t, z):
        return hodgkin_huxley_neural_ode(t, z, I_stimulus)

    return odeint(rhs, z0, times)


In [7]:
# Demo of the waveform generator: build a Gaussian pulse and print the
# sampled data returned by get_waveform_data().
waveform_generator = Waveform(duration=2, sampling_rate=1000)
waveform_generator.gaussian_pulse(amplitude=1, fwhm_seconds=1)
waveform_data = waveform_generator.get_waveform_data()
print("Gaussian Pulse Data:", waveform_data, sep="\n")

Gaussian Pulse Data:
[[0.00000000e+00 6.26734838e-02]
 [1.00000000e-03 6.30216353e-02]
 [2.00000000e-03 6.33713693e-02]
 ...
 [1.99700000e+00 6.33713693e-02]
 [1.99800000e+00 6.30216353e-02]
 [1.99900000e+00 6.26734838e-02]]


In [None]:
class HodgkinHuxley:
    """Hodgkin-Huxley membrane model driven by a ``Waveform`` stimulus.

    ``generate_pulse`` builds the stimulus, ``model`` evaluates the four-state
    right-hand side (V, m, h, n) with the stimulus linearly interpolated at the
    requested time, and ``simulate`` integrates the system with SciPy.

    NOTE(review): the waveform time axis appears to be in seconds (the
    generator takes ``fwhm_seconds``) while these rate constants look like the
    usual per-millisecond ones -- confirm that the ``t`` given to ``simulate``
    uses the same units as the waveform's time column.
    """

    # Waveform kinds that generate_pulse() knows how to build.
    _WAVEFORM_TYPES = ('sine', 'triangle', 'gaussian', 'square')

    def __init__(self, stimuli_type, frequency, duration=2, sampling_rate=1000, duty_cycle=0.5, amplitude=1):
        """Store the stimulus description and the membrane parameters.

        Args:
            stimuli_type: One of 'sine', 'triangle', 'gaussian' or 'square'.
            frequency: Frequency handed to the periodic waveforms; for
                'gaussian' it is handed to the generator as the pulse FWHM.
            duration: Forwarded to ``Waveform`` as ``duration``.
            sampling_rate: Forwarded to ``Waveform`` as ``sampling_rate``.
            duty_cycle: Stored only; generate_pulse() does not forward it.
                NOTE(review): presumably meant for the square wave -- confirm.
            amplitude: Amplitude handed to the waveform generator.
        """
        #parameters from Noah's waveform generator
        self.waveform_class = stimuli_type
        self.frequency = frequency
        self.duration = duration
        self.sampling_rate = sampling_rate
        self.duty_cycle = duty_cycle
        self.amplitude = amplitude

        #hodgkin Huxley parameters
        self.C_m = 1.0  # membrane capacitance (uF/cm^2)
        self.g_Na = 120.0  # sodium conductance (mS/cm^2)
        self.g_K = 36.0  # potassium conductance (mS/cm^2)
        self.g_L = 0.3  # leak conductance (mS/cm^2)
        self.E_Na = 50.0  # sodium reversal potential (mV)
        self.E_K = -77.0  # potassium reversal potential (mV)
        self.E_L = -54.387  # leak reversal potential (mV)

        # Stimulus samples reused by model(); filled lazily / by simulate().
        self._waveform = None

    def generate_pulse(self):
        """Build the stimulus waveform described by this instance.

        Returns:
            Whatever ``Waveform.get_waveform_data()`` returns; the notebook's
            printed Gaussian-pulse output shows rows of ``[time, value]``.

        Raises:
            ValueError: If ``stimuli_type`` is not a supported waveform name.
        """
        # Validate first so a bad name fails before any generator work is done.
        if self.waveform_class not in self._WAVEFORM_TYPES:
            raise ValueError("Waveform type is not valid for this simulation please pick either sine, triangle, gaussian, or square")

        waveform_generator = Waveform(duration=self.duration, sampling_rate=self.sampling_rate)
        if self.waveform_class == 'sine':
            waveform_generator.sine_wave(self.frequency, self.amplitude)
        elif self.waveform_class == 'triangle':
            waveform_generator.triangular_wave(self.frequency, self.amplitude)
        elif self.waveform_class == 'gaussian':
            # For a Gaussian pulse the "frequency" slot carries the FWHM.
            waveform_generator.gaussian_pulse(amplitude=self.amplitude, fwhm_seconds=self.frequency)
        else:  # 'square'
            waveform_generator.square_wave(self.frequency, self.amplitude)

        return waveform_generator.get_waveform_data()

    def model(self, ic, t):
        """Hodgkin-Huxley right-hand side in SciPy ``odeint`` form ``func(y, t)``.

        Args:
            ic: Current state ``(V, m, h, n)``.
            t: Time at which the stimulus is sampled, in the units of the
                waveform's time column. Outside the sampled range np.interp
                holds the first/last sample value.

        Returns:
            List ``[dV/dt, dm/dt, dh/dt, dn/dt]``.
        """
        V, m, h, n = ic

        # Build the stimulus once and reuse it: regenerating it on every
        # right-hand-side evaluation is loop-invariant work.
        waveform = getattr(self, '_waveform', None)
        if waveform is None:
            waveform = self._waveform = np.asarray(self.generate_pulse())

        # Column 0 is time and column 1 is the stimulus value.
        I_stim = np.interp(t, waveform[:, 0], waveform[:, 1])

        # Hodgkin-Huxley equations
        dVdt = (I_stim - self.g_Na * m**3 * h * (V - self.E_Na)
                - self.g_K * n**4 * (V - self.E_K) - self.g_L * (V - self.E_L)) / self.C_m
        dmdt = self.alpha_m(V) * (1 - m) - self.beta_m(V) * m
        dhdt = self.alpha_h(V) * (1 - h) - self.beta_h(V) * h
        dndt = self.alpha_n(V) * (1 - n) - self.beta_n(V) * n

        return [dVdt, dmdt, dhdt, dndt]

    #helper functions for the HH
    @staticmethod
    def _linear_over_exp(x, coeff):
        """Return coeff * x / (1 - exp(-x / 10)), made continuous at x == 0.

        The closed form is 0/0 at x == 0; near that point the first-order
        series value coeff * (10 + x / 2) is used instead of producing NaN.
        """
        x = np.asarray(x, dtype=float)
        near_zero = np.abs(x) < 1e-7
        # Dummy argument where x ~ 0 so the division never sees 0/0.
        x_safe = np.where(near_zero, 1.0, x)
        rate = coeff * x_safe / (-np.expm1(-x_safe / 10.0))
        # [()] unwraps a 0-d result to a NumPy scalar; arrays pass through.
        return np.where(near_zero, coeff * (10.0 + x / 2.0), rate)[()]

    def alpha_m(self, V):
        """Rate multiplying (1 - m) in dm/dt; finite at V == -40 mV."""
        return self._linear_over_exp(V + 40, 0.1)

    def beta_m(self, V):
        """Rate multiplying m in dm/dt."""
        return 4.0 * np.exp(-(V + 65) / 18)

    def alpha_h(self, V):
        """Rate multiplying (1 - h) in dh/dt."""
        return 0.07 * np.exp(-(V + 65) / 20)

    def beta_h(self, V):
        """Rate multiplying h in dh/dt."""
        return 1.0 / (1 + np.exp(-(V + 35) / 10))

    def alpha_n(self, V):
        """Rate multiplying (1 - n) in dn/dt; finite at V == -55 mV."""
        return self._linear_over_exp(V + 55, 0.01)

    def beta_n(self, V):
        """Rate multiplying n in dn/dt."""
        return 0.125 * np.exp(-(V + 65) / 80)

    def simulate(self, y0, t):
        """Integrate the model from ``y0`` over the time points ``t``.

        Args:
            y0: Initial state ``(V, m, h, n)``.
            t: Increasing sequence of time points; the first is the time of y0.

        Returns:
            Array of shape ``(len(t), 4)`` with one state per time point.
        """
        # model() has SciPy's func(y, t) signature and is NumPy based, so it is
        # integrated with SciPy's odeint. It is imported locally under another
        # name so the notebook-level torchdiffeq ``odeint`` is left untouched.
        from scipy.integrate import odeint as scipy_odeint

        # Refresh the stimulus so changes to the instance's settings are seen.
        self._waveform = np.asarray(self.generate_pulse())
        solution = scipy_odeint(self.model, y0, t)
        return solution
        
        