<div class="alert alert-block alert-success">
    <b>DEPENDENCIES</b>
    <ul>
        <li> numpy
        <li> scipy
        <li> matplotlib
        <li> mat73
        <li> moviepy
    </ul>
</div>

In [1]:
import os

import numpy as np

In [2]:
# import scipy as sc
# import scipy.signal as sg
from scipy.signal import argrelmax
from scipy.io.wavfile import write

In [3]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [4]:
import mat73

In [5]:
from moviepy.editor import VideoClip, VideoFileClip, AudioFileClip, TextClip, CompositeVideoClip, CompositeAudioClip, clips_array, vfx
from moviepy.video.io.bindings import mplfig_to_npimage

---

In [6]:
# Load the electrophysiology struct for this cell. Documented fields:
#   trace       - filtered data (artifact samples replaced with NaNs)
#   times       - time of each recording sample (s)
#   stim_on     - stimulus onset times (s)
#   stim_off    - stimulus offset times (s)
#   pupil_times - time vector of the video (s)
data_dict = mat73.loadmat("../data/cell-ID-6.mat")["ephys"]
data_dict.keys()

dict_keys(['name', 'pupil_times', 'stim_off', 'stim_on', 'times', 'trace'])

In [7]:
# Session video; `spike_video` plays a sub-clip of it above the spike trace.
video = VideoFileClip("../data/GB0002 22-05-31 11-03-34_.avi")

In [8]:
# Recording rate (Hz) from the median sample spacing. The index of the first
# sample at t >= 1 s equals the rate only if `times` starts exactly at 0 s, and
# it can be off by one due to float rounding near 1 s.
# NOTE: later cells still convert times to indices as t * rate, i.e. they
# assume the recording starts at 0 s and is uniformly sampled.
spike_sampling_freq = int(np.rint(1 / np.median(np.diff(data_dict["times"]))))
# Camera frame rate, rounded up to a whole number of frames per second.
video_frame_rate = int(np.ceil(video.fps))

> Recording (sampling) frequency ≈ 25000 Hz  
> Video frame rate ≈ 20 fps
<!-- 102.400264 frames/s -->

In [9]:
# Minimum amplitude (same units as data_dict["trace"]) a local maximum must exceed to count as a spike.
spike_threshold = 5e-4

---

In [10]:
# argrelmax returns a tuple with one index array per axis. Take the (only) axis
# explicitly instead of squeezing, which would give a 0-d array for a single peak.
peak_indices = argrelmax(data_dict["trace"])[0]
# Keep local maxima above the threshold (NaN artifact samples never pass `>`)
# and convert their sample indices to times.
spike_times = data_dict["times"][peak_indices[data_dict["trace"][peak_indices] > spike_threshold]]

---

In [11]:
def plot_spikes(t_b, t_a=0, ax=None, shade_stimulus=False, indicate_spikes=False, plot=True):
    """Plot the recorded trace over the window [t_a, t_b).

    Parameters
    ----------
    t_b : float
        Window end time (s).
    t_a : float, optional
        Window start time (s), by default 0.
    ax : matplotlib.axes.Axes, optional
        Axis to draw on. If None, a new 16x4 figure is created.
    shade_stimulus : bool, optional
        Shade the stimulus on/off intervals that overlap the window, by default False.
    indicate_spikes : bool, optional
        Draw a vertical line at each detected spike inside the window, by default False.
    plot : bool, optional
        Draw the trace itself, by default True. With False only the decorations
        (zero line, stimulus shading, spike markers) are drawn.

    Returns
    -------
    times, trace : np.ndarray
        The slices of data_dict["times"] and data_dict["trace"] covering the window.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(16, 4), dpi=40)

    # Times are converted to sample indices as t * rate, which assumes the
    # recording starts at 0 s and is uniformly sampled.
    a = int(t_a * spike_sampling_freq)
    n = int((t_b - t_a) * spike_sampling_freq)  # window size in samples

    ax.axhline(0, color='k', linestyle='--', linewidth=1)

    if shade_stimulus:
        window_end = (a + n) / spike_sampling_freq
        for son, soff in zip(data_dict["stim_on"], data_dict["stim_off"]):
            # Assumes stim_on is in ascending order: stop at the first onset
            # at or past the end of the window.
            if son * spike_sampling_freq >= a + n:
                break
            # Skip stimuli that ended before the window; shading them (or an
            # unclipped onset) can widen the autoscaled x-range.
            if soff <= t_a:
                continue
            ax.axvspan(max(son, t_a), min(soff, window_end), color='pink', alpha=0.1)

    if indicate_spikes:
        for spike_time in spike_times[(spike_times >= t_a) & (spike_times < t_b)]:
            ax.axvline(spike_time, color='g', linestyle='-.')

    times = data_dict["times"][a:a + n]
    trace = data_dict["trace"][a:a + n]
    if plot:
        ax.plot(times, trace)

    return times, trace

In [12]:
def generate_video(t_b, t_a=0.0, speed=1.0, t_w=1.0, fps=None, transition='roll', save=True, output=".temp/__temp__.mp4"):
    """Generate an animated spike-trace video for the window [t_a, t_b].

    Parameters
    ----------
    t_b : float
        Window end time (s).
    t_a : float, optional
        Window start time (s), by default 0.0.
    speed : float or None, optional
        Playback speed, by default 1.0. If None it is derived from the window
        (rounded to 3 decimals): (t_b - t_a)/5 for 'roll', t_w for 'window'.
    t_w : float, optional
        Width (s) of the sliding window if transition is 'window', by default 1.0.
        Ignored for 'roll'.
    fps : int, optional
        Output frame rate, by default None. Values below 24 are raised to 24.
        If None, min(120, recording rate * speed, camera frame rate * speed) is
        used, also raised to at least 24.
    transition : {'roll', 'window'}, optional
        Type of video animation, by default 'roll'.
        'roll' - the trace grows from t_a to t_b.
        'window' - a window of width t_w slides from t_a towards t_b.
    save : bool, optional
        Save the video as .mp4?, by default True.
    output : str, optional
        Output file name, by default '.temp/__temp__.mp4'. If empty or None,
        '<transition>.mp4' is used.

    Returns
    -------
    video : moviepy VideoClip ('window') or matplotlib FuncAnimation ('roll')
    sfps : int
        Frame rate of the generated video.

    Raises
    ------
    ValueError
        If `transition` is unknown or the resulting video length is not positive.
    """
    if transition not in ('roll', 'window'):
        raise ValueError(f"Unknown transition {transition!r}; expected 'roll' or 'window'.")

    if speed is None:
        if transition == 'roll':
            speed = round((t_b - t_a)/5.0, 3)  # yields a ~5 s video
        else:
            speed = round(t_w/1.0, 3)

    # Never below 24 fps. Without an explicit fps, also do not exceed what the
    # data can resolve (recording rate, camera rate, both scaled by speed) or 120.
    sfps = int(max(24, np.ceil(fps if fps else min(120, spike_sampling_freq*speed, video_frame_rate*speed))))

    if transition == 'roll':
        t_w = 0
    duration = (t_b - t_w - t_a)/speed
    if duration <= 0:
        raise ValueError(f"Non-positive video length ({duration:.3f}s): check t_a, t_b, t_w and speed.")

    path = output if output else f"{transition}.mp4"
    print(f"> Estimated Video Length ~ {duration:.2f}s @ {sfps:>5} fps  | [{t_a:.3f}s, {t_b:.3f}s] @ {speed}x <{transition[0]}> -- `{path}`")

    if save:
        # The default output lives in '.temp/', which may not exist yet.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    fig, ax = plt.subplots(1, 1, figsize=(16, 4), dpi=40)

    def _style_axes():
        # Fixed vertical range and no axis decorations for every frame.
        ax.set_ylim(-0.001, 0.002)
        ax.axis('off')

    if transition == 'window':
        def make_frame(t):
            start = t_a + t*speed
            # ax.clear() resets the limits (and possibly the axis-off state),
            # so the styling is re-applied on every frame.
            ax.clear()
            _style_axes()
            ax.set_xlim(start, start + t_w)
            # plot=True: the trace itself must be drawn in this mode.
            plot_spikes(start + t_w, start, ax=ax, shade_stimulus=True, indicate_spikes=False, plot=True)
            return mplfig_to_npimage(fig)

        clip = VideoClip(make_frame, duration=duration)
        if save:
            clip.write_videofile(path, fps=sfps, verbose=False, logger=None)

    else:  # 'roll'
        _style_axes()
        ax.set_xlim(t_a, t_b)

        x, y = plot_spikes(t_b, t_a, ax=ax, shade_stimulus=True, indicate_spikes=False, plot=False)
        line, = ax.plot(x, y)
        fig.tight_layout()

        def update(frame):
            # Samples covered after `frame` output frames at this playback speed.
            f = int(frame*spike_sampling_freq*speed/sfps)
            line.set_data(x[:f], y[:f])
            return line,

        clip = animation.FuncAnimation(fig, update, frames=int(duration*sfps), interval=int(1000/sfps), blit=True)
        if save:
            clip.save(path, writer=animation.FFMpegWriter(fps=sfps))

    plt.close(fig)

    return clip, sfps

In [13]:
def generate_audio(t_b, t_a=0.0, speed=1.0, tone_frequency=200, tone_duration=0.002, save=True, output='.temp/__temp__.wav'):
    """Generate an audio track with a short tone at every detected spike in [t_a, t_b).

    Parameters
    ----------
    t_b : float
        Window end time (s).
    t_a : float, optional
        Window start time (s), by default 0.0.
    speed : float or None, optional
        Playback speed, by default 1.0. If None, (t_b - t_a)/5 rounded to
        3 decimals (a ~5 s track).
    tone_frequency : float, optional
        Pitch (Hz) heard for each spike, by default 200.
    tone_duration : float, optional
        Duration of each spike tone in recording time (s), by default 0.002.
        It plays for tone_duration/speed seconds.
    save : bool, optional
        Save the audio as .wav?, by default True.
    output : str, optional
        Output file name, by default '.temp/__temp__.wav'.

    Returns
    -------
    audio : np.ndarray (float32)
        Peak-normalised track, one sample per recording sample in the window.
    afps : int
        Sample rate the track is written at (recording rate * speed).

    Raises
    ------
    ValueError
        If tone_frequency is at or above the Nyquist frequency of afps.
    """
    if speed is None:
        speed = round((t_b - t_a)/5.0, 3)  # yields a ~5 s track

    # One output sample per recording sample, played at rate * speed, so
    # speed < 1 slows the recording down.
    afps = int(spike_sampling_freq*speed)
    if tone_frequency >= afps/2:
        raise ValueError(f"Speed too low! Min Speed - {2*tone_frequency/spike_sampling_freq}")

    print(f"> Estimated Audio Length ~ {(t_b - t_a)/speed:.2f}s @ {afps:>5} afps | [{t_a:.3f}s, {t_b:.3f}s] @ {speed}x     -- `{output}`")

    n = int((t_b - t_a) * spike_sampling_freq)  # window size in samples
    chirp_len = int(spike_sampling_freq*tone_duration)
    # Synthesised at tone_frequency/speed in recording time, so that after
    # playback at afps it is heard at tone_frequency. Computed once, reused per spike.
    chirp = np.sin(2*np.pi*tone_frequency/speed*np.linspace(0, tone_duration, chirp_len))

    window_spikes = spike_times[(spike_times >= t_a) & (spike_times < t_b)]
    # Builtin `int`: the `np.int` alias was removed in NumPy 1.24.
    chirp_indices = ((window_spikes - t_a) * spike_sampling_freq).astype(int)

    y = np.zeros(n)
    for start in chirp_indices:
        # Tones that would run past the end of the window are dropped.
        if start + chirp_len < n:
            y[start:start + chirp_len] = chirp

    # Normalise by the absolute peak so samples stay within [-1, 1]. A window
    # without spikes stays silent instead of becoming NaN (0/0).
    peak = np.max(np.abs(y)) if y.size else 0.0
    audio = np.array(y / peak if peak > 0 else y, dtype=np.float32)

    if save:
        # The default output lives in '.temp/', which may not exist yet.
        os.makedirs(os.path.dirname(output) or ".", exist_ok=True)
        write(output, afps, audio)

    return audio, afps

In [14]:
def spike_video(t_b, t_a=0.0, speed=1.0, fps=None, output="spike.mp4"):
    """Render the session video stacked above an animated spike trace with spike audio.

    Parameters
    ----------
    t_b : float
        Window end time (s).
    t_a : float, optional
        Window start time (s), by default 0.0.
    speed : float or None, optional
        Playback speed, by default 1.0. If None, (t_b - t_a)/5 rounded to
        3 decimals (the rule generate_audio and 'roll' generate_video use).
    fps : int, optional
        Frame rate passed to generate_video, by default None.
    output : str, optional
        Output file name, by default 'spike.mp4'.

    Notes
    -----
    Intermediate files are written to '.temp/__temp__.wav' and '.temp/__temp__.mp4'.
    """
    if speed is None:
        # Resolve once so audio, trace and session clip share the same speed;
        # speedx()/TextClip below cannot take None.
        speed = round((t_b - t_a)/5.0, 3)

    temp_audio = ".temp/__temp__.wav"
    temp_video = ".temp/__temp__.mp4"
    generate_audio(t_b, t_a, speed, output=temp_audio)
    _, sfps = generate_video(t_b, t_a, speed, fps=fps, save=True, output=temp_video)

    trace_clip = VideoFileClip(temp_video)
    audio_clip = AudioFileClip(temp_audio)
    try:
        trace_clip.audio = CompositeAudioClip([audio_clip])
        session_clip = video.subclip(t_a, t_b).speedx(speed).set_fps(sfps)

        label = TextClip(f"x{speed}", fontsize=75, color='black')
        label = label.set_position(("right", "top")).set_duration((t_b-t_a)/speed)

        clips_array([[CompositeVideoClip([session_clip, label])], [trace_clip]]).write_videofile(output)
    finally:
        # Release the file readers held by the intermediate clips.
        audio_clip.close()
        trace_clip.close()

In [15]:
# Seconds 0-3 of the session, slowed down to 0.4x.
spike_video(t_b=3, t_a=0, speed=0.4)

> Estimated Audio Length ~ 7.50s @ 10000 afps | [0.000s, 3.000s] @ 0.4x     -- `.temp/__temp__.wav`
> Estimated Video Length ~ 7.50s @    24 fps  | [0.000s, 3.000s] @ 0.4x <r> -- `.temp/__temp__.mp4`


                                                                                                                       

Moviepy - Building video spike.mp4.
MoviePy - Writing audio in spikeTEMP_MPY_wvf_snd.mp3
MoviePy - Done.
Moviepy - Writing video spike.mp4



                                                                                                                       

Moviepy - Done !
Moviepy - video ready spike.mp4
