# In [3]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Directory holding the exported outputs of this particular training run.
folder_path = 'results_mask_uci/UCI/kdd_model/One_out/linear/pretrain/window_size_4sec/feat_dim_6/freeze_False_epoch_500_lr_0.001_d_hidden_64_d_ff_128_n_heads_8_n_layer_8_pos_encode_learnable_activation_gelu_norm_BatchNorm'

# Read the three text files (mask, model predictions, ground truth), in that
# order, with np.loadtxt.
mask, y_pred, t_true = (
    np.loadtxt(f'{folder_path}/{file_name}')
    for file_name in ('mask.txt', 'pred.txt', 'true.txt')
)

# Channel order inside the flattened files: for each axis, accelerometer then
# gyroscope, i.e. body_acc_x, body_gyro_x, body_acc_y, body_gyro_y,
# body_acc_z, body_gyro_z.
# NOTE(review): this must match the order the channels were written at export
# time -- confirm against the exporting code.
features = [
    modality + suffix
    for suffix in ('_x', '_y', '_z')
    for modality in ('body_acc', 'body_gyro')
]

# De-interleave: element k of each file belongs to channel k % n_channels.
# The stride is derived from the feature list rather than hard-coded, so the
# split stays consistent if channels are added or removed above.
# NOTE(review): assumes the loaded arrays are 1-D (one value per line); if a
# file holds several columns, this slices rows rather than channels -- confirm.
n_channels = len(features)
data_dict = {
    feat: {
        "mask": mask[idx::n_channels],
        "y_pred": y_pred[idx::n_channels],
        "t_true": t_true[idx::n_channels],
    }
    for idx, feat in enumerate(features)
}

# Ground-truth arrays per sensor type (same contents as before; kept in case
# later cells use them).
acc_values = [data_dict[feat]['t_true'] for feat in features if 'body_acc' in feat]
gyro_values = [data_dict[feat]['t_true'] for feat in features if 'body_gyro' in feat]


def _axis_limits(modality):
    """Return (lo, hi) over targets AND predictions of every channel of *modality*.

    Including the predictions matters: limits taken from the targets alone
    would clip any predicted value that overshoots them out of the plot.
    NaN entries are ignored (np.nanmin / np.nanmax).
    """
    values = np.concatenate([
        np.ravel(data_dict[feat][key])
        for feat in features if modality in feat
        for key in ('t_true', 'y_pred')
    ])
    return np.nanmin(values), np.nanmax(values)


# Shared y-axis limits, one pair per sensor type, so every window is drawn on
# the same scale.
acc_min, acc_max = _axis_limits('body_acc')
gyro_min, gyro_max = _axis_limits('body_gyro')

# Samples shown per window, and the most windows the slider will offer.
points_per_subplot = 128
max_subplots = 100

# Never offer windows past the end of the data (they would render as empty
# plots). Use ceil division so a trailing partial window is still reachable,
# and keep at least one window so the slider range stays valid.
n_samples = max(len(channel['t_true']) for channel in data_dict.values())
num_subplots = max(1, min(max_subplots, -(-n_samples // points_per_subplot)))

# Create a function to plot a specific subplot
def plot_subplot(i):
    """Plot true vs. predicted signals of every channel for window *i*.

    Window *i* spans samples ``[i * points_per_subplot, (i + 1) * points_per_subplot)``
    of each channel in ``data_dict``. One row is drawn per entry of ``features``.
    Rows whose name contains ``'body_acc'`` use the y-range ``acc_min``..``acc_max``;
    all other rows use ``gyro_min``..``gyro_max``, so amplitudes stay comparable
    from one window to the next.

    Args:
        i: Zero-based window index (supplied by the slider).
    """
    start = i * points_per_subplot
    end = start + points_per_subplot

    # One row per channel rather than a hard-coded count; squeeze=False keeps
    # ``axs`` 2-D even with a single channel, so ``axs[:, 0]`` always works.
    fig, axs = plt.subplots(len(features), squeeze=False, figsize=(10, 10))

    for ax, feat in zip(axs[:, 0], features):
        true_win = data_dict[feat]["t_true"][start:end]
        pred_win = data_dict[feat]["y_pred"][start:end]

        # Plot against absolute sample indices so the x-axis agrees with the
        # range in the title, instead of restarting at 0 for every window.
        ax.plot(np.arange(start, start + len(true_win)), true_win,
                'r-', label='True', alpha=0.5)
        ax.plot(np.arange(start, start + len(pred_win)), pred_win,
                'b-', label='Predicted', alpha=0.5)

        # The last window may be partial; label the samples actually shown.
        shown_end = start + max(len(true_win), len(pred_win))
        ax.set_title('{} ({}-{})'.format(feat, start, shown_end))

        # Shared per-modality y-limits keep all windows on the same scale.
        if 'body_acc' in feat:
            ax.set_ylim([acc_min, acc_max])
        else:
            ax.set_ylim([gyro_min, gyro_max])

        ax.legend()

    fig.tight_layout()
    plt.show()

# Slider choosing which window to plot; its range follows num_subplots.
# continuous_update=False redraws only when the slider is released, instead
# of rebuilding the multi-panel figure for every intermediate value while
# dragging.
slider = widgets.IntSlider(
    min=0,
    max=num_subplots - 1,
    step=1,
    description='Subplot:',
    continuous_update=False,
)

# Wire the slider to the plotting function and show the viewer explicitly, so
# it renders even when this is not the last expression of the cell.
viewer = widgets.interactive(plot_subplot, i=slider)
display(viewer)


# Output: interactive(children=(IntSlider(value=0, description='Subplot:', max=99), Output()), _dom_classes=('widget-int…