### Visualize the outputs of BioMime6 and BioMime7

In [2]:
import torch
from BioMime.utils.basics import update_config, load_model, load_generator
from BioMime.models.generator import Generator

import matplotlib.pyplot as plt
import seaborn as sns

from utils import plot_muap_simple

import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual

In [3]:
import numpy as np
import csv

In [14]:
# Load the BioMime6 and BioMime7 weights onto the CPU.
# bm6 is a bare state_dict (it is passed straight to load_state_dict later);
# bm7 is a training checkpoint whose 'generator' entry holds the state_dict.
bm6 = torch.load('biomime_weights/model_linear.pth', map_location=torch.device('cpu'))
bm7 = torch.load('biomime_weights/epoch-8_checkpoint.pth', map_location=torch.device('cpu'))
bm6_keys = [key for key in bm6.keys()]
bm7_keys = [key for key in bm7['generator'].keys()]

# Index arrays that put each key list in lexicographic order, which makes
# lining up the two naming schemes by hand easier.
sortorder6, sortorder7 = (np.argsort(keys) for keys in (bm6_keys, bm7_keys))

In [15]:
# The sorted keys were reordered by hand in a CSV so that each BioMime7
# parameter name lines up with its BioMime6 counterpart. The file was
# originally saved under ../BioMime/ckp/ and is read here from
# biomime_weights/. Load it as an array of strings, one row per parameter.
with open('biomime_weights/keys_mapping_biomime6-7.csv', mode='r', newline='') as mapping_file:
    keys_mapping = np.array([row for row in csv.reader(mapping_file)])

In [16]:
# Rename the BioMime7 parameters to the BioMime6 naming convention.
# Assumes row i of keys_mapping corresponds to the i-th BioMime7 key in
# sorted order (the CSV was hand-built from the sorted key list), and that
# its second-to-last column holds the BioMime6-style name.
bm7_keys_old = np.array(bm7_keys)[sortorder7]
bm7_keys_new = list(keys_mapping[:, -2])

# Guard against a stale or header-bearing CSV. A length mismatch would
# silently drop extra rows or shift every name after the first difference.
assert len(bm7_keys_new) == len(bm7_keys_old), (
    f'{len(bm7_keys_new)} mapped names for {len(bm7_keys_old)} BioMime7 keys')
# Duplicate target names would make later tensors overwrite earlier ones.
assert len(set(bm7_keys_new)) == len(bm7_keys_new), 'duplicate target names in key mapping'

old_state_dict = bm7['generator']
new_state_dict = {new_key: old_state_dict[old_key]
                  for old_key, new_key in zip(bm7_keys_old, bm7_keys_new)}

In [19]:
# Build the BioMime7 generator: the shared config.yaml with the number of
# condition inputs set to 7, loaded with the renamed BioMime7 weights.
# Note: `config` is reassigned by the BioMime6 cell that follows.
config = update_config('biomime_weights/config.yaml')
config['Model']['Generator']['num_conds'] = 7
biomime7 = Generator(config.Model.Generator)
biomime7.load_state_dict(new_state_dict)

<All keys matched successfully>

In [22]:
# Build the BioMime6 generator from a freshly loaded config (num_conds as
# given in config.yaml) and load its weights directly. This `config` is also
# the one generate_plot_muap reads the latent size from.
config = update_config('biomime_weights/config.yaml')
biomime6 = Generator(config.Model.Generator)
biomime6.load_state_dict(bm6)

<All keys matched successfully>

In [24]:
# Interactive explorer for BioMime6. Each condition gets one slider, all
# sharing the range [0.5, 1.0] with default 0.75. generate_plot_muap feeds
# six conditions to biomime6, so no slider is wired up for a seventh (fat)
# condition.
def _condition_slider(value=0.75, vmin=0.5, vmax=1.0, step=0.01):
    """Return a FloatSlider configured for one generator condition."""
    return widgets.FloatSlider(value=value, min=vmin, max=vmax, step=step)


fdensity = _condition_slider()
depth = _condition_slider()
angle = _condition_slider()
izone = _condition_slider()
cvel = _condition_slider()
flength = _condition_slider()


def generate_plot_muap(fd, d, a, iz, cv, fl):
    """Sample BioMime6 MUAPs for one set of conditions and plot their mean.

    Draws a single latent vector and runs ``biomime6.sample`` ``n_steps``
    times with it. Prints the spread across those runs and the overall mean
    value, then plots the averaged MUAP with ``plot_muap_simple``.

    Parameters
    ----------
    fd, d, a, iz, cv, fl : float
        Condition values, stacked in this order into the (1, 6) condition
        tensor passed to the generator.

    Returns
    -------
    None; results are printed and plotted.
    """
    n_MU = 1
    n_steps = 10

    # NOTE(review): z is drawn once and reused for every step. If
    # `biomime6.sample` is deterministic given z, all steps are identical and
    # the reported std is 0. Confirm whether z should be redrawn per step.
    z = torch.rand(n_MU, config.Model.Generator.Latent)
    c = torch.tensor((fd, d, a, iz, cv, fl), dtype=torch.float32)[None, :]

    sim_muaps = []
    for _ in range(n_steps):
        sim = biomime6.sample(n_MU, c, c.device, z)
        sim_muaps.append(sim.permute(1, 2, 0).detach().numpy())

    sim_muaps = np.stack(sim_muaps)
    mean_muap = sim_muaps.mean(axis=0)

    print(f'Average std across steps and channels: {np.mean(np.std(sim_muaps, axis=0))}')
    print(f'Mean value of the averaged MUAP: {np.mean(mean_muap)}')
    # Only every second index along axis 1 is plotted, presumably to thin the
    # display. Confirm against plot_muap_simple's expected layout.
    plot_muap_simple(mean_muap[:, ::2, :])


# The function only runs when the "Run Interact" button is pressed.
# The trailing ';' hides the returned function's repr in the cell output.
interact_manual(generate_plot_muap, fd=fdensity, d=depth, a=angle, iz=izone, cv=cvel, fl=flength);

interactive(children=(FloatSlider(value=0.75, description='fd', max=1.0, min=0.5, step=0.01), FloatSlider(valu…

<function __main__.generate_plot_muap(fd, d, a, iz, cv, fl)>

In [None]:
# TODO: sanity checks to run next, to make sure the results make sense.
# (1) Test 100 true samples and check how close the predictions from both bm6 and bm7 are.
# (2) Make sure that the effect of fat is really just amplitude flattening (or a change in
#     some frequency metrics) and nothing more drastic. (Revisit the MUAP features examined earlier.)
# (3) Decide which training epoch's checkpoint to use.