In [1]:
import torch
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from pybela import Streamer
from utils import load_model, get_device, get_all_run_ids, get_models_coordinates, find_closest_model, get_models_range

# Directory holding the trained model runs; passed as `path=` to the utils helpers below.
path = "models/trained"
# Torch device returned by utils.get_device(); tensors and model inputs are moved to it.
device = get_device()

In [2]:
# Names of the eight Bela variables to stream: gFaabSensor_1 ... gFaabSensor_8.
# NOTE(review): `vars` shadows the Python builtin of the same name; the name is kept
# because a later cell passes it to `streamer.start_streaming`.
vars = [f'gFaabSensor_{sensor_number}' for sensor_number in range(1, 9)]

# Open the connection to the Bela board.
streamer = Streamer()
streamer.connect()

ConnectionError: Connection failed: [Errno 111] Connect call failed ('192.168.7.2', 5555).

In [None]:
# --- preload models ---
# Load every trained run once so that switching model during streaming is a dict lookup.
run_ids = get_all_run_ids(path=path)
models = {}
models_running_range = {}
# `_id` is a plain module-level loop variable: after the loop it remains bound to the
# last run id.
for _id in run_ids:
    models[_id] = load_model(_id, path=path)
    # Running min/max of the four model outputs, both starting at zero.
    models_running_range[_id] = {
        "min": torch.zeros(4, dtype=torch.float32, device=device),
        "max": torch.zeros(4, dtype=torch.float32, device=device),
    }

# Length of each sequence fed to the model.
seq_len = 512

# --- model space ---
# min and max of model outputs (after passing the whole dataset)
full_dataset_models_range = get_models_range(path=path)
# model's chosen 4 hyperparameters mapped to values between 0 and 1
models_coordinates = get_models_coordinates(path=path)

# --- streaming / trigger settings ---
num_blocks_to_compute_avg = 10  # sequences averaged before a new model is chosen
trigger_width = 25              # number of leading 1.0 samples in the trigger buffer
trigger_idx = 4                 # Bela buffer index the trigger is sent on

# --- initial state: empty average accumulator and starting model ---
model_avg = torch.empty(0, device=device)

starter_id = run_ids[0]
model = models[starter_id]

# --- settings ---
# True: normalise with the min/max seen during this run;
# False: normalise with the min/max obtained from the full dataset.
running_norm = True

In [None]:
# Previously tuned values, kept for reference:
# gain = [1.35, 1.1, 1.5, 1.5]
gain = [1.0] * 4

# One slider per gain entry, initialised from the current `gain` list.
gain_sliders = [
    widgets.FloatSlider(
        value=initial_gain,
        min=0.0,
        max=2.0,
        step=0.01,
        description=f'Gain {position}',
    )
    for position, initial_gain in enumerate(gain, start=1)
]


def update_gain(change):
    """Widget observer: refresh the global `gain` list from all sliders.

    `change` is the change notification passed by the widget; it is not used,
    every slider is simply re-read.
    """
    global gain
    gain = [gain_slider.value for gain_slider in gain_sliders]
    # Optional: Call a function here to apply the new gain values immediately


# Label followed by the sliders, stacked vertically. It is displayed (and the
# sliders are linked to `update_gain`) in a later cell.
slider_box = widgets.VBox([widgets.Label('Adjust Gain Parameters:'), *gain_sliders])

In [None]:

# Number of sequences processed since the last trigger pulse was sent to Bela.
counter = 0


async def callback(block):
    """Process one streamed block: encode, normalise, send to Bela, maybe switch model.

    Args:
        block: iterable with one entry per streamed variable; each entry is read as
            ``entry["buffer"]["data"]`` (a sequence of samples). The block is cut
            into non-overlapping sequences of ``seq_len`` samples.

    For every sequence the encoder output of the current ``model`` is normalised
    (running or full-dataset range, depending on ``running_norm``), each output
    feature is sent to Bela on its own buffer index, and one trigger buffer is
    sent on ``trigger_idx``: all zeros, except every
    ``num_blocks_to_compute_avg``-th sequence, which starts with
    ``trigger_width`` ones. Once that many sequences have been averaged, the
    gain-scaled average is mapped to the closest model, which becomes the
    current model.

    State kept between calls lives in the globals ``model_avg``, ``model`` and
    ``counter``; ``models_running_range`` is updated in place.
    """
    # global variables so that the state is kept between callback calls
    global model_avg, model, counter

    with torch.no_grad():

        _raw_data_tensor = torch.stack([torch.FloatTensor(
            buffer["buffer"]["data"]) for buffer in block])  # num_features, block_len
        # split the data into seq_len to feed it into the model
        inputs = _raw_data_tensor.unfold(1, seq_len, seq_len).permute(
            1, 2, 0)  # n, seq_len, num_features

        # for each sequence of seq_len, feed it into the model
        for _input in inputs:
            out = model.forward_encoder(_input.to(device)).permute(
                1, 0)  # num_outputs, seq_len
            # outputs --> [ff_size, num_heads, num_layers, learning_rate]

            # -- normalisation --

            # running normalisation (taking max and min from the current run)
            if running_norm:
                # Key the running range by the model that produced `out`. This
                # assumes `model.id` is the run id used as dictionary key, the same
                # assumption the absolute-normalisation branch makes below.
                _range = models_running_range[model.id]
                _range["min"] = torch.minimum(_range["min"], out.min(dim=1).values)
                _range["max"] = torch.maximum(_range["max"], out.max(dim=1).values)
                _min, _max = _range["min"], _range["max"]

            # absolute normalisation (taking max and min from passing the full dataset)
            else:
                _model_range = full_dataset_models_range[model.id]
                _min, _max = torch.FloatTensor(_model_range["min"]).to(
                    device), torch.FloatTensor(_model_range["max"]).to(device)

            # -- normalise before sending to Bela!! --
            _span = _max - _min
            # A zero-width range would divide by zero and send NaN/inf to Bela;
            # dividing by 1 instead maps such a (constant) output to 0.
            _span = torch.where(_span == 0, torch.ones_like(_span), _span)
            normalised_out = (out - _min.unsqueeze(1)) / _span.unsqueeze(1)

            for idx, feature in enumerate(normalised_out):  # send each feature to Bela
                streamer.send_buffer(idx, 'f', seq_len, feature.tolist())

            # -- trigger --
            # Exactly one trigger buffer per sequence, so the trigger channel receives
            # as many buffers as each feature channel. The pulse coincides with the
            # sequence that completes the averaging window.
            counter += 1
            if counter < num_blocks_to_compute_avg:
                streamer.send_buffer(trigger_idx, 'f', seq_len, seq_len*[0.0])
            else:
                streamer.send_buffer(
                    trigger_idx, 'f', seq_len,
                    trigger_width*[1.0] + (seq_len-trigger_width)*[0.0])
                counter = 0

            model_avg = torch.cat(
                (model_avg, normalised_out.mean(dim=1).unsqueeze(0)), dim=0)

        # `>=` rather than `==`: a block can add several sequences at once, so the
        # length may step over the threshold instead of landing exactly on it.
        if len(model_avg) >= num_blocks_to_compute_avg:
            # -- gain --
            _avg = model_avg.mean(dim=0).detach().cpu().tolist()
            # multiply the final averaged value by a tuned gain
            _avg = [a * g for a, g in zip(_avg, gain)]

            # -- map to model --
            # find the closest model to the _avg coordinates
            closest_model, _ = find_closest_model(_avg, models_coordinates)
            model = models[closest_model]

            # -- reset avg --
            model_avg = torch.empty(0).to(device)

            print(model.id,np.round(_avg, 4))

In [143]:
# Render the gain sliders.
display(slider_box)

# Re-read all slider values into the global `gain` list whenever any slider's value changes.
for slider in gain_sliders:
    slider.observe(update_gain, names='value')

VBox(children=(Label(value='Adjust Gain Parameters:'), FloatSlider(value=1.0, description='Gain 1', max=2.0, s…

In [147]:
# Stream the eight sensor variables from Bela; `callback` is passed as the per-block handler.
streamer.start_streaming(vars, on_block_callback=callback)

[94mStarted streaming variables ['gFaabSensor_1', 'gFaabSensor_2', 'gFaabSensor_3', 'gFaabSensor_4', 'gFaabSensor_5', 'gFaabSensor_6', 'gFaabSensor_7', 'gFaabSensor_8']... Run stop_streaming() to stop streaming.[0m


altyemr7 [0.0726 0.0314 0.9249 0.8757]
lt03ebc6 [0.9998 0.0897 0.533  0.7037]
h0o65m8s [0.9152 0.1352 1.     0.3095]
b2kr7jlj [0.5474 0.0329 0.4593 0.7012]
altyemr7 [0.0008 0.0618 0.5276 0.7779]
kyi2wb9t [0.9998 0.0897 0.2909 0.7037]
8kwjo4md [0.8349 0.0025 0.3306 0.6869]
5li6ry52 [0.4142 0.2323 0.3282 0.6233]
jsrbrjo7 [0.5775 0.4583 0.0086 0.9396]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9999]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0.1082 0.9998]
jsrbrjo7 [0.4683 0.5197 0

In [146]:
# Stop the streaming session started above.
streamer.stop_streaming()

[94mStopped streaming variables ['gFaabSensor_1', 'gFaabSensor_2', 'gFaabSensor_3', 'gFaabSensor_4', 'gFaabSensor_5', 'gFaabSensor_6', 'gFaabSensor_7', 'gFaabSensor_8']...[0m


# plot

In [8]:
import os
import bokeh
import bokeh.plotting
import bokeh.io
import bokeh.driving
from bokeh.resources import INLINE

# Sets the BOKEH_ALLOW_WS_ORIGIN environment variable for this process.
# NOTE(review): the value is hard-coded and looks specific to one hosted notebook
# session — confirm, and consider reading it from configuration instead.
os.environ["BOKEH_ALLOW_WS_ORIGIN"]="0j0t0jnmqu776ei6png0k89bho2qg0m6ia345511p3uleiqq2kep"

In [None]:
# Stream only the two sensors that are plotted below; no block callback is passed here.
streamer.start_streaming(variables=["gFaabSensor_1", "gFaabSensor_2"])

In [None]:
# Plot both sensors with the streamer's own plot_data method, y axis fixed to [0, 1].
streamer.plot_data(x_var="gFaabSensor_1", y_vars=["gFaabSensor_1", "gFaabSensor_2"], y_range=[0, 1], rollover=10000)

In [None]:
# Hand-written sample buffers in the {"timestamps": [...], "data": [...]} layout that
# _bokeh_plot_data_app reads.
# NOTE(review): the name `plot_data` is rebound by the `def plot_data(...)` further
# down, so this dict is no longer reachable once that cell has run.
plot_data = {
    
    "out_1": {"timestamps": [0,1,2], "data" : [0.1, 0.2, 0.3]},
    "out_2": {"timestamps": [0,1,2], "data" : [0.1, 0.2, 0.3]},
}

In [None]:
from itertools import cycle
import asyncio

def _bokeh_plot_data_app(
                            data,
                            x_var,
                            y_vars,
                            y_range=None,
                            rollover=None,
                            plot_update_delay=90):
    """Build a Bokeh application function that live-plots buffered variables.

    Args:
        data: mapping ``variable name -> buffer dict``. Each buffer dict is read
            for ``"data"`` (a list or a single value) and, for ``x_var``, for
            either ``"timestamp"`` (single value) or ``"timestamps"`` (list).
            The mapping is read again on every update, so it must be kept up to
            date by the caller for the plot to move.
        x_var: key in ``data`` whose timestamps are used as the x axis.
        y_vars: keys in ``data`` to draw, one line each.
        y_range: optional ``[low, high]`` to fix the y axis.
        rollover: forwarded to ``ColumnDataSource.stream``.
        plot_update_delay: forwarded to ``doc.add_periodic_callback``.

    Returns:
        A function ``_app(doc)`` that adds the figure and the periodic update
        callback to the given Bokeh document.
    """

    def _app(doc):
        # Instantiate figures
        p = bokeh.plotting.figure(
            frame_width=500,
            frame_height=175,
            x_axis_label="timestamps",
            y_axis_label="value",
        )

        if y_range is not None:
            p.y_range = bokeh.models.Range1d(y_range[0], y_range[1])

        # No padding on x_range makes data flush with end of plot
        p.x_range.range_padding = 0

        # One column per plotted variable plus the shared x column. The columns
        # are derived from `y_vars` (not from the keys of `data`) because
        # `update` streams exactly these columns and `stream` requires new
        # values for every existing column.
        template = {"timestamps": [], **{y_var: [] for y_var in y_vars}}
        source = bokeh.models.ColumnDataSource(template)

        # Create line glyphs for each y_var
        colors = cycle([
            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
            "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
            "#bcbd22", "#17becf", "#1a55FF", "#FF1A1A"
        ])
        for y_var in y_vars:
            p.line(source=source, x="timestamps",
                    y=y_var, line_color=next(colors), legend_label=y_var)

        @bokeh.driving.linear()
        def update(step):
            # Update plot by streaming in data
            x_buffer = data[x_var]
            if "timestamp" in x_buffer:
                # single-sample buffer: wrap the scalar so every column is a list
                timestamps = [x_buffer["timestamp"]]
            else:
                timestamps = x_buffer["timestamps"]

            new_data = {"timestamps": timestamps}
            for y_var in y_vars:
                values = data[y_var]["data"]
                new_data[y_var] = values if isinstance(values, list) else [values]
            source.stream(new_data, rollover)

        doc.add_root(p)
        doc.add_periodic_callback(update, plot_update_delay)
    return _app

def plot_data(x_var, y_vars, y_range=None, plot_update_delay=100, rollover=500):
    """Show a live Bokeh plot of variables currently streamed by `streamer`.

    Args:
        x_var: variable whose timestamps are used as the x axis. It is looked up
            among the buffers selected by ``y_vars``, so it has to be one of them.
        y_vars: names of the streamed variables to plot.
        y_range: optional ``[low, high]`` to fix the y axis.
        plot_update_delay: forwarded to ``_bokeh_plot_data_app``.
        rollover: forwarded to ``_bokeh_plot_data_app``.

    Blocks until every selected buffer holds data, then displays the plot inline
    in the notebook.

    NOTE(review): ``asyncio.run`` is called from inside the notebook; this
    presumably relies on the event loop tolerating nested runs — confirm.
    """

    def _selected_buffers():
        # The buffers live on the streamer instance (there is no module-level
        # `last_streamed_buffer`). They are looked up on every call so the wait
        # loop sees new data.
        return {var: _buffer
                for var, _buffer in streamer.last_streamed_buffer.items()
                if var in y_vars}

    # wait until streaming buffers have been populated
    async def wait_for_streaming_buffers_to_arrive():
        while not all(_buffer['data'] for _buffer in _selected_buffers().values()):
            await asyncio.sleep(0.1)
    asyncio.run(wait_for_streaming_buffers_to_arrive())

    bokeh.io.output_notebook(INLINE)
    bokeh.io.show(_bokeh_plot_data_app(
        data=_selected_buffers(), x_var=x_var,
        y_vars=y_vars, y_range=y_range, plot_update_delay=plot_update_delay, rollover=rollover))


In [None]:
# Same plot as the streamer.plot_data call above, drawn with the notebook's own plot_data function.
plot_data(x_var="gFaabSensor_1", y_vars=["gFaabSensor_1", "gFaabSensor_2"], y_range=[0, 1], rollover=10000)

In [None]:
# Stop the streaming session that fed the plots.
streamer.stop_streaming()