# Stimulation from multiple sources
This tutorial builds on the [finite amplitude stimulation](1_finite_amp.ipynb) tutorial. We first recap that example and then move on to stimulating a fiber from multiple sources.

## Create the fiber and set up simulation
As before, we create the fiber, the stimulation waveform, the extracellular potentials, and the stimulation object.
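The waveform is a sampled signal with one value per time step, so its timing follows directly from the sample counts and `dt`. The following minimal sketch does that arithmetic using the same sample counts and time step as the cell below. The helper `pulse_window` is hypothetical and is not part of PyFibers. It assumes that sample `k` covers the interval from `k*dt` to `(k+1)*dt`.

```python
import numpy as np


def pulse_window(waveform: np.ndarray, dt: float) -> tuple[float, float]:
    """Return (start, end) times in ms of the nonzero part of a sampled waveform.

    Hypothetical helper for illustration. Assumes sample k covers [k*dt, (k+1)*dt)
    and that the waveform has at least one nonzero sample.
    """
    nonzero = np.flatnonzero(waveform)  # indices of nonzero samples
    start = nonzero[0] * dt
    end = (nonzero[-1] + 1) * dt
    return start, end


dt = 0.001  # ms, same time step as the simulation below
waveform = np.concatenate((np.zeros(100), np.ones(100), np.zeros(48000)))
start, end = pulse_window(waveform, dt)
print(f'pulse from {start:.3f} ms to {end:.3f} ms ({end - start:.3f} ms long)')
```

With these values, 100 µs of baseline precedes a 100 µs pulse that runs from 0.1 ms to 0.2 ms.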

In [None]:
import numpy as np
from pyfibers import build_fiber, FiberModel, ScaledStim

# create fiber model
n_sections = 265
fiber = build_fiber(FiberModel.MRG_INTERPOLATION, diameter=10, n_sections=n_sections)
print(fiber)

# Set up the simulation. Leading zeros give a pre-stimulus baseline for visualization
waveform = np.concatenate((np.zeros(100), np.ones(100), np.zeros(48000)))  # monophasic rectangular pulse

# potentials from a point source 250 µm from the fiber, located at the middle of its length
fiber.potentials = fiber.point_source_potentials(0, 250, fiber.length / 2, 1, 10)

time_step = 0.001
time_stop = 20

# Create stimulation object
stimulation = ScaledStim(waveform=waveform, dt=time_step, tstop=time_stop)

We can then simulate the fiber's response to a given stimulation amplitude.

In [None]:
stimamp = -1.5  # mA
ap, time = stimulation.run_sim(stimamp, fiber)
print(f'Number of action potentials detected: {ap}')
print(f'Time of last action potential detection: {time} ms')

## Stimulation from multiple sources
In many cases, it is useful to model stimulation of a fiber from multiple sources. There are two ways to do this:

1. If each source delivers the same waveform, we can sum the potentials from each source (superposition) and proceed as before. If the sources deliver the same waveform with different polarities or scales, we weight each source's potentials accordingly before summing.
2. If the sources deliver different waveforms, their potentials cannot be combined in advance. Instead, each source's contribution is scaled by its own waveform during the simulation. We therefore provide the fiber with multiple sets of potentials, one per source, and we provide the `ScaledStim` instance with multiple waveforms, again one per source. This approach also works when every source delivers the same waveform. You may either provide a single stimulation amplitude, which is applied to all sources, or a list of amplitudes, one per source. Note that threshold searches support only a single stimulation amplitude, which is applied to all sources.
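The reasoning behind both approaches can be sketched with plain NumPy. This sketch assumes that the extracellular potential is linear in each source's current, which is what superposition requires. Under that assumption, the potential applied along the fiber at time step `t` is the sum over sources of amplitude × waveform value × that source's potentials. The arrays below are made-up toy values rather than PyFibers output, and `applied_potentials` is a hypothetical helper. It is not PyFibers' internal implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sections = 5

# toy potentials per unit current for two sources (rows = sources)
potentials = rng.normal(size=(2, n_sections))


def applied_potentials(amps, waveforms, potentials):
    """Extracellular potential at every time step and section (hypothetical helper).

    amps: (n_sources,), waveforms: (n_sources, n_steps), potentials: (n_sources, n_sections).
    Returns an (n_steps, n_sections) array: sum_i amps[i] * waveforms[i, t] * potentials[i, :].
    """
    return np.einsum('i,it,is->ts', amps, waveforms, potentials)


# Approach 1: same waveform, sources weighted by polarity and summed in advance
waveform = np.array([0.0, 1.0, 1.0, 0.0])
weights = np.array([1.0, -1.0])  # bipolar: one anode, one cathode
combined = weights @ potentials  # a single (n_sections,) set of potentials
stimamp = -1.5
approach1 = stimamp * np.outer(waveform, combined)

# Approach 2: keep the sources separate, each with its own (here identical) waveform
waveforms = np.vstack((waveform, waveform))
approach2 = applied_potentials(stimamp * weights, waveforms, potentials)

print(np.allclose(approach1, approach2))  # both approaches give the same applied potentials
```

`np.einsum` makes the sum over sources explicit. A loop that accumulates `amps[i] * np.outer(waveforms[i], potentials[i])` over sources would be equivalent. Only the second form can handle sources whose waveforms differ.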

### Superposition of potentials
In this example, we consider two sources that deliver the same waveform. We calculate the potentials from each source and sum them to obtain the total potential along the fiber. We use bipolar stimulation, where one source is the anode and the other is the cathode.
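The shape of the bipolar profile can be illustrated with the textbook point-source model for an infinite, homogeneous, isotropic medium, V = I / (4πσr). This model is an assumption for illustration, and `point_source_potentials` may differ in units or details. The geometry mirrors the cell below, with sources 250 µm from the fiber at 45 % and 55 % of its length. The fiber length, current, and conductivity are made-up values in arbitrary units, and `point_source_potential` is a hypothetical helper.

```python
import numpy as np


def point_source_potential(z, source_z, distance, current, sigma):
    """Potential along a straight fiber (positions z) from one point source.

    Hypothetical helper using the textbook model V = I / (4*pi*sigma*r) in an
    infinite homogeneous medium; r is the distance from the source to each point.
    """
    r = np.sqrt(distance**2 + (z - source_z) ** 2)
    return current / (4 * np.pi * sigma * r)


length = 10_000.0  # hypothetical fiber length
z = np.linspace(0, length, 201)  # evaluation points along the fiber

total = np.zeros_like(z)
for position, polarity in zip([0.45 * length, 0.55 * length], [1, -1]):
    # superpose each source's contribution, signed by its polarity
    total += polarity * point_source_potential(z, position, 250.0, 1.0, 1.0)

# the bipolar profile is antisymmetric about the midpoint between the sources
print(np.allclose(total, -total[::-1]))
```

Because of this antisymmetry, the summed potential crosses zero midway between the two sources and peaks with opposite signs near each source.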

In [None]:
fiber.potentials *= 0  # reset potentials
for position, polarity in zip([0.45 * fiber.length, 0.55 * fiber.length], [1, -1]):
    # add the contribution of one source to the potentials
    fiber.potentials += polarity * fiber.point_source_potentials(0, 250, position, 1, 10)

# plot the potentials
import matplotlib.pyplot as plt

plt.figure()
plt.plot(fiber.longitudinal_coordinates, fiber.potentials[0])
plt.xlabel('Position (μm)')
plt.ylabel('Potential (mV)')
plt.show()

In [None]:
# run simulation
ap, time = stimulation.run_sim(stimamp, fiber)
print(f'Number of action potentials detected: {ap}')
print(f'Time of last action potential detection: {time} ms')

### Sources with different waveforms
In this example, we consider two sources that each deliver a different waveform. We provide the fiber with one set of potentials per source, stacked into a 2D array. We also provide the `ScaledStim` instance with one waveform per source.
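Because each source now needs its own potentials and its own waveform, the two stacks must line up row by row. The following sketch is a pre-flight check that validates the shapes and expands a single amplitude to every source. This mirrors the two amplitude options: one amplitude for all sources, or one amplitude per source. `check_multisource_inputs` is a hypothetical helper and is not a PyFibers function, and the zero arrays only stand in for real potentials and waveforms.

```python
import numpy as np


def check_multisource_inputs(potentials, waveforms, amps):
    """Validate multi-source inputs and return one amplitude per source.

    Hypothetical helper: potentials is (n_sources, n_sections),
    waveforms is (n_sources, n_samples), amps is a scalar or a sequence.
    """
    potentials = np.atleast_2d(potentials)
    waveforms = np.atleast_2d(waveforms)
    n_sources = potentials.shape[0]
    if waveforms.shape[0] != n_sources:
        raise ValueError(f'{n_sources} sets of potentials but {waveforms.shape[0]} waveforms')
    amps = np.atleast_1d(np.asarray(amps, dtype=float))
    if amps.size == 1:
        amps = np.full(n_sources, amps[0])  # same amplitude for every source
    elif amps.size != n_sources:
        raise ValueError(f'expected 1 or {n_sources} amplitudes, got {amps.size}')
    return amps


potentials = np.zeros((2, 265))  # placeholder: two sources
waveforms = np.zeros((2, 49800))  # placeholder: one waveform per source
print(check_multisource_inputs(potentials, waveforms, -1.5))
print(check_multisource_inputs(potentials, waveforms, [-1.5, 1]))
```

With a scalar amplitude, the helper returns that amplitude repeated for both sources. With a list, it returns one amplitude per source. Mismatched stacks raise a `ValueError` before any simulation is attempted.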

In [None]:
potentials = []
# compute the potentials from each source separately
for position in [0.45 * fiber.length, 0.55 * fiber.length]:
    potentials.append(fiber.point_source_potentials(0, 250, position, 1, 1))
fiber.potentials = np.vstack(potentials)  # one row per source
print(fiber.potentials.shape)

plt.figure()
plt.plot(fiber.longitudinal_coordinates, fiber.potentials[0, :], label='Source 1')
plt.plot(fiber.longitudinal_coordinates, fiber.potentials[1, :], label='Source 2')
plt.xlabel('Position (μm)')
plt.ylabel('Potential (mV)')
plt.legend()
plt.show()

In [None]:
# create waveforms and stack them
waveform1 = np.concatenate((-np.ones(50), np.zeros(49750)))  # 0.05 ms negative rectangular pulse (50 samples at dt = 0.001 ms)
waveform2 = np.concatenate((np.ones(200), np.zeros(49600)))  # 0.2 ms positive rectangular pulse (200 samples at dt = 0.001 ms)
waveform = np.vstack((waveform1, waveform2))

# Create instance of ScaledStim class
stimulation = ScaledStim(waveform=waveform, dt=time_step, tstop=time_stop)

# record gating variables and membrane voltage (Vm) during the following simulations
fiber.set_save_gating()
fiber.set_save_vm()

# run simulation with the same amplitude for all waveforms
ap, time = stimulation.run_sim(-1.5, fiber)
print(f'Number of action potentials detected: {ap}')
print(f'Time of last action potential detection: {time} ms')

# Now, run a simulation with different amplitudes for each waveform
ap, time = stimulation.run_sim([-1.5, 1], fiber)
print(f'Number of action potentials detected: {ap}')
print(f'Time of last action potential detection: {time} ms')

# Finally, run a threshold search (a single amplitude is searched for and applied to all waveforms)
amp, ap = stimulation.find_threshold(fiber)

In [None]:
import pandas as pd
import seaborn as sns

# plot waveforms scaled by the threshold amplitude
plt.figure()
# stimulation.time has one more point than the stored waveform; drop the last time point to align them
time_ms = np.array(stimulation.time)[:-1]
plt.plot(time_ms, amp * stimulation.waveform[0, :], label='Waveform 1')
plt.plot(time_ms, amp * stimulation.waveform[1, :], label='Waveform 2')
plt.ylabel('Amplitude (mA)')
plt.xlabel('Time (ms)')
plt.legend()
plt.xlim([0, 1])

# plot a heatmap of membrane voltage at each node over time (first and last nodes excluded)
data = pd.DataFrame(np.array(fiber.vm[1:-1]))
vrest = fiber[0].e_pas
print('Membrane rest voltage:', vrest)
plt.figure()
# color limits are symmetric about the resting potential, so white corresponds to rest
g = sns.heatmap(
    data,
    cbar_kws={'label': '$V_m$ $(mV)$'},
    cmap='seismic',
    vmax=np.amax(data.values) + vrest,
    vmin=-np.amax(data.values) + vrest,
)
plt.ylabel('Node index')
plt.xlabel('Time (ms)')
tick_locs = np.linspace(0, len(np.array(stimulation.time)[:1000]), 9)
labels = [round(np.array(stimulation.time)[int(ind)], 2) for ind in tick_locs]
g.set_xticks(ticks=tick_locs, labels=labels)
plt.title('Membrane voltage over time\nRed=depolarized, Blue=hyperpolarized')
plt.xlim([0, 1000])
# label source locations
for loc, ls, label in zip([0.45, 0.55], [':', '--'], ['source 1', 'source 2']):
    location = loc * (len(fiber)) - 1
    plt.axhline(location, color='black', linestyle=ls, label=label)
plt.legend()
plt.gcf().set_size_inches(8.15, 4)