# Create wavelets

In [3]:
#import yaml

from bokeh.layouts import column
from bokeh.models import ColumnDataSource, Slider
from bokeh.plotting import figure
from bokeh.models.widgets import Tabs
#from bokeh.themes import Theme
from bokeh.io import show, output_notebook

#from bokeh.sampledata.sea_surface_temperature import sea_surface_temperature

# Render bokeh `show(...)` output inline in this notebook.
output_notebook()

In [5]:
import numpy as np
import scipy.signal
#import scipy.fft

In [10]:
def sec2labels(sec, fs, require_odd):
    '''Calculate the x labels (in samples, time) to be used in a plot
    of duration `sec` that goes with a signal sampled at `fs` sample rate.

    Parameters
    ----------
    sec : float
        Duration of the plot window in seconds.
    fs : float
        Sample rate in Hz.
    require_odd : bool
        If truthy, an even sample count is bumped up by one so the labels
        are symmetric around sample 0.

    Returns
    -------
    (x, t) : tuple of ndarray
        `x` holds integer sample offsets centred on 0; `t` holds the same
        offsets converted to seconds (`x / fs`).
    '''
    # Builtin int(): the `np.int` alias was deprecated in NumPy 1.20 and
    # removed in 1.24, where it raises AttributeError.
    numx = int(np.floor(sec * fs))
    # Truthiness rather than `is True`, so numpy booleans (np.True_) and
    # other truthy flags also enforce an odd length.
    if require_odd and numx % 2 == 0:  # Ensure symmetrical around 0
        numx += 1
    x = np.arange(numx) - numx // 2
    t = x / fs
    return (x, t)

def real_sinusoid(freq, time):
    '''Return the real part of exp(2*pi*i*freq*time), i.e. a cosine wave.

    A scalar `freq` yields an array shaped like `time`. For an array `freq`
    a trailing axis is appended first, so every frequency is broadcast
    against `time` (one row per frequency for 1-D inputs).
    '''
    f = freq if np.isscalar(freq) else np.expand_dims(freq, axis=-1)
    phasor = np.exp(2*1j*np.pi*f*time)
    return np.real(phasor)

def gaussian(freq, time, num_cycles):
    '''Calculate Gaussian window for time vector with width dependent on
    number of cycles of the wavelet frequency.

    The standard deviation is s = num_cycles / (2*pi*freq) seconds and the
    window is exp(-time**2 / (2*s**2)), which equals 1 at time == 0.

    `freq` may be a scalar or an array. For an array, a trailing axis is
    appended (the same convention `real_sinusoid` uses), so each frequency
    gets its own window broadcast against `time`. The result then lines up
    elementwise with `real_sinusoid(freq, time)`.
    '''
    if not np.isscalar(freq):
        # Without the extra axis, an array `freq` would be paired
        # element-by-element with `time` (or fail to broadcast when the
        # lengths differ) instead of giving one window per frequency.
        freq = np.expand_dims(freq, axis=-1)
    s = num_cycles / (2*np.pi*freq)
    return np.exp(-(time**2) / (2*s**2))

def wavelet(freq, time, num_cycles):
    '''Build a real-valued Morlet-style wavelet.

    The cosine carrier from `real_sinusoid(freq, time)` is multiplied
    elementwise by the envelope from `gaussian(freq, time, num_cycles)`.
    '''
    carrier = real_sinusoid(freq, time)
    envelope = gaussian(freq, time, num_cycles)
    return carrier * envelope

def wavelet_spectrum(cmw, npoints):
    '''Calculate the peak-normalised frequency spectrum of a morlet wavelet.

    Parameters
    ----------
    cmw : array_like
        Real-valued wavelet samples.
    npoints : int or None
        FFT length passed to `scipy.fft.rfft`. `cmw` is zero-padded or
        truncated to this length.

    Returns
    -------
    ndarray (complex)
        One-sided spectrum scaled so that its largest magnitude is 1.
        Phases are preserved. An all-zero input is returned unscaled (all
        zeros) instead of NaN.
    '''
    # Imported here so the function does not depend on `scipy.fft` having
    # been loaded as a side effect of another scipy import.
    import scipy.fft
    cmwX = scipy.fft.rfft(cmw, n=npoints)
    # Normalise by the largest *magnitude*. `.max()` on a complex array
    # compares lexicographically (real part first), so it does not pick the
    # spectral peak and can even flip the sign of the spectrum.
    peak = np.abs(cmwX).max()
    if peak == 0:
        return cmwX
    return cmwX / peak

def spectrum(signal, npoints):
    '''Calculate the peak-normalised frequency spectrum of a signal.

    Parameters
    ----------
    signal : array_like
        Real-valued samples.
    npoints : int or None
        FFT length passed to `scipy.fft.rfft`. `signal` is zero-padded or
        truncated to this length.

    Returns
    -------
    ndarray (complex)
        One-sided spectrum scaled so that its largest magnitude is 1.
        Phases are preserved. An all-zero input is returned unscaled (all
        zeros) instead of NaN.
    '''
    # Imported here so the function does not depend on `scipy.fft` having
    # been loaded as a side effect of another scipy import.
    import scipy.fft
    sigX = scipy.fft.rfft(signal, n=npoints)
    # Normalise by the largest *magnitude*. `.max()` on a complex array
    # compares lexicographically (real part first), not by magnitude.
    peak = np.abs(sigX).max()
    if peak == 0:
        return sigX
    return sigX / peak

def fwhm2s_and_w(h, freq, fs=None):
    '''Map a Gaussian full-width at half-maximum `h` to `morlet2` (s, w).

    `s` is the standard deviation h / sqrt(8 * ln 2). When `fs` is given it
    is additionally multiplied by `fs`. `w` is the angular frequency
    2 * pi * freq.
    '''
    sigma = h / np.sqrt(8 * np.log(2))
    if fs is not None:
        sigma *= fs
    omega = 2 * np.pi * freq
    return sigma, omega

class Wave():
    '''Bokeh line plot of a real sinusoid, or a sum of sinusoids.

    `freq` may be a scalar or an array of frequencies. For an array the
    components are summed into one trace. `fqcallback` has the
    (attr, old, new) signature used by Bokeh `on_change` callbacks.
    '''

    def __init__(self, freq, sec, fs, color):
        self.fs = fs
        self.freq = freq
        self.sec = sec
        self.x, self.t = sec2labels(self.sec, self.fs, require_odd=True)
        self.source = ColumnDataSource(data={'x': self.t, 'y': self._trace()})
        self.plot = figure(title='Sinusoid')
        self.plot.line('x', 'y', source=self.source, color=color)

    def _trace(self):
        '''Y-values for the current `self.freq` (summed over freqs if an array).'''
        components = real_sinusoid(self.freq, self.t)
        return components if np.isscalar(self.freq) else components.sum(axis=0)

    def fqcallback(self, attr, old, new):
        '''Switch to frequency `new` and push the recomputed trace to the plot.'''
        self.freq = new
        self.source.data['y'] = self._trace()
        
class Gauss():
    '''Bokeh line plot of the window returned by `gaussian`.

    Frequency and number of cycles can be changed through `fqcallback` and
    `ncallback`, which use the Bokeh (attr, old, new) callback signature.
    '''

    def __init__(self, num_cycles, freq, sec, fs, color):
        self.num_cycles = num_cycles
        self.freq = freq
        self.sec = sec
        self.fs = fs
        self.x, self.t = sec2labels(self.sec, self.fs, require_odd=True)
        self.source = ColumnDataSource(data={'x': self.t, 'y': self._window()})
        self.plot = figure(title='Gaussian')
        self.plot.line('x', 'y', source=self.source, color=color)

    def _window(self):
        '''Gaussian window for the current frequency and cycle count.'''
        return gaussian(self.freq, self.t, self.num_cycles)

    def _refresh(self):
        '''Push the recomputed window into the plot's data source.'''
        self.source.data['y'] = self._window()

    def ncallback(self, attr, old, new):
        self.num_cycles = new
        self._refresh()

    def fqcallback(self, attr, old, new):
        self.freq = new
        self._refresh()

class Wavelet():
    '''Bokeh plot overlaying a sinusoid, its Gaussian window and their product.

    Columns of `self.source`: 'x' (time in seconds, from `sec2labels`),
    'wvy' (sinusoid), 'gsy' (Gaussian window) and 'wvlty' (wavelet).
    `ncallback` and `fqcallback` use the Bokeh (attr, old, new) signature.
    '''

    def __init__(self, num_cycles, freq, sec, fs, colors):
        self.num_cycles = num_cycles
        self.freq = freq
        self.sec = sec
        self.fs = fs
        self.x, self.t = sec2labels(self.sec, self.fs, require_odd=True)
        self.source = ColumnDataSource(data={'x': self.t, **self._curves()})
        self.plot = figure(title='Sinusoid * Gaussian = wavelet')
        self.plot.line('x', 'wvy', source=self.source, color=colors[0], alpha=0.5)
        self.plot.line('x', 'gsy', source=self.source, color=colors[1], alpha=0.7)
        self.plot.line('x', 'wvlty', source=self.source, color=colors[2])

    def _curves(self):
        '''Compute the three y-columns for the current freq / num_cycles.'''
        return {
            'wvy': real_sinusoid(self.freq, self.t),
            'gsy': gaussian(self.freq, self.t, self.num_cycles),
            'wvlty': wavelet(self.freq, self.t, self.num_cycles),
        }

    def set_source_data(self):
        '''Recompute all curves and push them to the data source in one update.

        A single `data.update(...)` issues one 'data' change notification.
        Listeners registered with `source.on_change('data', ...)` (such as a
        spectrum reading 'wvlty') therefore run once per parameter change
        and never see a half-updated source. Assigning the columns one at a
        time notified them three times, twice with a stale 'wvlty'.
        '''
        self.source.data.update(self._curves())

    def ncallback(self, attr, old, new):
        self.num_cycles = new
        self.set_source_data()

    def fqcallback(self, attr, old, new):
        self.freq = new
        self.set_source_data()

class SpectrumPlot():
    '''Bokeh line plot of the peak-normalised magnitude spectrum of `signal`.

    Only bins whose frequency lies within `freqrng` (inclusive, in Hz) are
    drawn. `scallback(attr, old, new)` replaces the signal with `new` and
    redraws.
    '''

    def __init__(self, signal, fs, npoints, freqrng, color):
        self.signal = signal
        self.fs = fs
        self.npoints = npoints
        self.freqrng = freqrng
        self.source = ColumnDataSource(data=self._spectrum_columns())
        self.plot = figure(title='Frequency spectrum')
        self.plot.line('x', 'y', source=self.source, color=color)

    def _spectrum_columns(self):
        '''Return {'x': bin frequencies in Hz, 'y': |spectrum|} within `freqrng`.'''
        spec = np.abs(spectrum(self.signal, self.npoints))
        # rfft bin k lies at k * fs / n. rfftfreq is exact for odd n too.
        # The previous linspace(0, fs/2, len(spec)) assumed the last bin is
        # exactly Nyquist, which only holds for even n.
        nfft = self.npoints if self.npoints is not None else len(self.signal)
        hz = np.fft.rfftfreq(nfft, d=1.0 / self.fs)
        keep = (hz >= self.freqrng[0]) & (hz <= self.freqrng[1])
        return {'x': hz[keep], 'y': spec[keep]}

    def set_source_data(self):
        self.source.data = self._spectrum_columns()

    def scallback(self, attr, old, new):
        self.signal = new
        self.set_source_data()

class WaveletConstructor():
    '''Interactive Bokeh layout showing how a wavelet is assembled.

    A frequency slider and a number-of-cycles slider drive four plots: the
    sinusoid, the Gaussian window, their product (the wavelet) and the
    wavelet's frequency spectrum. The assembled layout is `self.column`.
    '''

    def __init__(self):
        freq = 10.0
        freqrng = (1.0, 30.0)
        freqstep = 1.0
        sec = 2.0
        num_cycles = 10
        cyclesrng = (1, 20)
        cyclesstep = 1
        fs = 1024
        npoints = 1024
        # Spectrum x-range: the slider range padded by 1 Hz on each side, so
        # the slider's end frequencies do not sit on the plot edge.
        specrng = (freqrng[0]-1, freqrng[1]+1)
        colors = ['lightblue', 'cornflowerblue', 'midnightblue']
        self.wv = Wave(freq=freq, sec=sec, fs=fs, color=colors[0])
        self.gs = Gauss(num_cycles=num_cycles, freq=freq, sec=sec, fs=fs, color=colors[1])
        self.wvlt = Wavelet(num_cycles=num_cycles, freq=freq, sec=sec, fs=fs, colors=colors)
        self.wvltspec = SpectrumPlot(
            signal=self.wvlt.source.data['wvlty'], fs=fs, npoints=npoints,
            freqrng=specrng, color=colors[2],
        )
        self.fqslider = Slider(start=freqrng[0], end=freqrng[1], value=freq, step=freqstep, title='Sinusoid frequency (Hz)')
        # Every frequency-dependent view follows the frequency slider.
        self.fqslider.on_change('value', self.wv.fqcallback, self.gs.fqcallback, self.wvlt.fqcallback)
        self.nslider = Slider(start=cyclesrng[0], end=cyclesrng[1], value=num_cycles, step=cyclesstep, title='Number of cycles in Gaussian window')
        self.nslider.on_change('value', self.gs.ncallback, self.wvlt.ncallback)
        # Recompute the spectrum whenever the wavelet's data source changes.
        self.wvlt.source.on_change('data', self.update_wvlt_spectrum)
        # `height` replaces `plot_height`, which Bokeh deprecated in 2.4 and
        # removed in 3.0.
        self.wv.plot.height = 125
        self.gs.plot.height = 125
        self.wvlt.plot.height = 200
        self.wvltspec.plot.height = 125
        self.column = column(
            self.fqslider, self.nslider, self.wv.plot, self.gs.plot, self.wvlt.plot, self.wvltspec.plot
        )

    def update_wvlt_spectrum(self, attr, old, new):
        '''`on_change('data')` callback: redraw the spectrum from `new['wvlty']`.'''
        self.wvltspec.scallback('value', None, new['wvlty'])
    
def wavelet_app(doc):
    '''Bokeh app handler: add the wavelet-explorer layout to `doc` and return `doc`.'''
    layout = WaveletConstructor().column
    doc.add_root(layout)
    return doc

In [20]:
# Scratch check of the FWHM -> (s, w) conversion; same formulas as
# fwhm2s_and_w(h, freq) with fs=None.
h = 0.1
freq = 60

s = h / np.sqrt(8 * np.log(2))
w = 2 * np.pi * freq
print(s)
print(w)
# NOTE(review): % and * share precedence and associate left to right, so this
# is ((w % 2) * np.pi) * freq, not w % (2 * np.pi * freq). Add parentheses
# if the latter was intended.
w % 2 * np.pi * freq

0.04246609001440096
376.99111843077515


186.82142285763845

In [14]:
# FWHM-to-standard-deviation factor: sigma = FWHM / sqrt(8 ln 2) (as in fwhm2s_and_w).
1 / np.sqrt(8 * np.log(2))

0.42466090014400953

In [398]:
# Run the interactive wavelet explorer; show() is passed the app function itself.
show(wavelet_app)

In [301]:
# Scratch: check how expand_dims broadcasts a frequency vector against time
# (the same trick real_sinusoid uses); the result is (n_freqs, n_times).
# NOTE(review): relies on a `wv` created in another cell. The `wv` defined in
# a later cell has 401 samples, not the 201 shown in the output, so this
# depends on execution order.
# `plot` is created but never used here.
freq = np.array([10, 15])
time = wv.t
plot = figure()
wvs = np.expand_dims(freq, axis=-1)*time
wvs.shape

(2, 201)

In [364]:
# Scratch: 3 Hz wavelet with 3 Gaussian cycles over 4 s at 100 Hz (built, not shown in this cell).
wvlt = Wavelet(num_cycles=3, freq=3, sec=4.0, fs=100, colors=['blue', 'blue', 'midnightblue'])

In [399]:
# Sum of 1, 2, 3, 5 and 7 Hz sinusoids, 4 s sampled at 100 Hz.
wv = Wave(freq=np.array([1, 2, 3, 5, 7]), sec=4.0, fs=100, color='midnightblue')
# `height` replaces `plot_height` (deprecated in Bokeh 2.4, removed in 3.0).
wv.plot.height = 250
show(wv.plot)

In [402]:
# Spectrum of the summed sinusoid from the previous cell. Use that signal's own
# sample rate: no notebook-level `fs` is defined in these cells (`fs` is a
# local inside WaveletConstructor), and `wv` was sampled at wv.fs = 100 Hz.
sp = SpectrumPlot(signal=wv.source.data['y'], fs=wv.fs, npoints=1024, freqrng=(0,30), color='midnightblue')
show(sp.plot)

In [265]:
% NOTE(review): this cell is MATLAB, not Python. `as` is a Python keyword, and
% `.*`, `end` and 1-based ranges are MATLAB syntax, so it cannot run in this
% kernel as written.
% nConv-point inverse FFT of the elementwise product cmwX .* dataX
% (presumably wavelet and data spectra, i.e. convolution via the frequency domain).
as = ifft(cmwX.*dataX,nConv);
% Drop half_wave samples from each end (MATLAB ranges are 1-based and `end` is inclusive).
as = as(half_wave+1:end-half_wave);
% Reshape to EEG.pnts x EEG.trials. MATLAB reshape is column-major, so a
% NumPy port would need order='F'.
as = reshape(as,EEG.pnts,EEG.trials);

NoneType