In [1]:
import numpy as np

# Directory holding the artifacts written by one pre-training run.
folder_path = 'results/limu_model/One_out/pretrain/window_size_5sec/epoch_6600_lr_0.0001_d_hidden_72_d_ff_144_n_heads_4_n_layer_4_embNorm_False'

# Masked positions: one row per time series, each row listing the masked
# sample indices of that series.
# ndmin=2 keeps the array 2-D even when the file has a single row (one series)
# or a single column (one masked index per series). Without it np.loadtxt
# squeezes the result to 1-D: len() then counts the wrong axis and iterating
# a "row" yields scalars instead of position lists.
mask_pos = np.loadtxt(f'{folder_path}/mask_pos.txt', ndmin=2)

# Number of samples per time series; each series occupies this many
# consecutive rows of origin_seq.
window_size = 150

# Number of time series (one per row of mask_pos).
num_time_series = len(mask_pos)

# 2-D (series, sample) arrays for the ground-truth and predicted x/y values.
true_x = np.zeros((num_time_series, window_size))
true_y = np.zeros((num_time_series, window_size))
pred_x = np.zeros((num_time_series, window_size))
pred_y = np.zeros((num_time_series, window_size))

# origin_seq: concatenated full sequences; columns 0 and 1 are used as x and y.
# true.txt / pred.txt: one row per masked position (columns 0/1 = x/y),
# consumed in mask_pos order when the series are assembled.
# ndmin=2 preserves the (rows, columns) shape even for single-row files.
origin_seq = np.loadtxt(f'{folder_path}/origin_seq.txt', ndmin=2)
t_true = np.loadtxt(f'{folder_path}/true.txt', ndmin=2)
y_pred = np.loadtxt(f'{folder_path}/pred.txt', ndmin=2)


# Rebuild full-length series for plotting.
# - Every sample of true_x/true_y starts from origin_seq.
# - Masked samples are then overwritten from true.txt.
# - pred_x/pred_y are filled only at masked positions and stay zero elsewhere.
#
# Rows of t_true and y_pred are consumed sequentially: one row per masked
# position, in the order the positions appear in mask_pos. Both files advance
# together, so a single row cursor indexes both.
total_masked = mask_pos.size
needed_rows = num_time_series * window_size

# Fail early with a readable message instead of a broadcast/IndexError
# deep inside the loop.
if len(origin_seq) < needed_rows:
    raise ValueError(
        f'origin_seq has {len(origin_seq)} rows, but {num_time_series} series '
        f'x window_size {window_size} requires {needed_rows}'
    )
if len(t_true) < total_masked or len(y_pred) < total_masked:
    raise ValueError(
        f'mask_pos lists {total_masked} masked positions, but true.txt has '
        f'{len(t_true)} rows and pred.txt has {len(y_pred)}'
    )

masked_row = 0
for i, masked_positions in enumerate(mask_pos):
    # Series i occupies rows [i * window_size, (i + 1) * window_size) of
    # origin_seq. Use window_size rather than a hard-coded length so the two
    # cannot drift apart.
    start = i * window_size
    end = start + window_size

    true_x[i] = origin_seq[start:end, 0]
    true_y[i] = origin_seq[start:end, 1]

    for pos in masked_positions:
        col = int(pos)  # positions are loaded as floats by np.loadtxt
        true_x[i, col] = t_true[masked_row, 0]
        true_y[i, col] = t_true[masked_row, 1]

        pred_x[i, col] = y_pred[masked_row, 0]
        pred_y[i, col] = y_pred[masked_row, 1]

        masked_row += 1

# Kept for backward compatibility with any cells that read these counters.
true_counter = pred_counter = masked_row

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

def plot_time_series(i):
    """Plot ground truth and predictions for one time series.

    Draws two stacked subplots, x-coordinates on top and y-coordinates below.
    Each subplot shows three things:

    - the full ground-truth series as a translucent red line;
    - the ground-truth values at the masked positions as large red markers;
    - the predicted values at those same positions as smaller blue markers.

    Args:
        i: Row index into ``mask_pos`` / ``true_x`` / ``true_y`` /
            ``pred_x`` / ``pred_y`` selecting the series to show.
    """
    positions = mask_pos[i]
    # Integer copy for indexing, computed once. The float positions are still
    # used as the scatter x-values, as before.
    idx = positions.astype(int)

    _, axs = plt.subplots(2, figsize=(10, 5))

    # (axis, ground truth, prediction, coordinate name) for each subplot.
    panels = (
        (axs[0], true_x[i], pred_x[i], 'X'),
        (axs[1], true_y[i], pred_y[i], 'Y'),
    )
    for ax, truth, pred, coord in panels:
        ax.plot(truth, 'r-', label=f'Ground Truth {coord}-coordinates', alpha=0.5)
        ax.scatter(positions, truth[idx], c='r', s=100, alpha=0.5,
                   label=f'Masked {coord}-coordinates')
        ax.scatter(positions, pred[idx], c='b', s=50,
                   label=f'Predicted {coord}-coordinates')
        ax.set_title(f'{coord}-coordinates for Time Series {i}')
        ax.legend()

    plt.tight_layout()
    plt.show()

# An empty mask_pos would give the slider max=-1, which ipywidgets rejects
# with an opaque "max < min" error; report the real cause instead.
if num_time_series == 0:
    raise ValueError('mask_pos.txt contains no time series to plot')

# Slider over the valid series indices 0 .. num_time_series - 1.
slider = widgets.IntSlider(min=0, max=num_time_series - 1, step=1, description='Time Series:')

# Re-run plot_time_series whenever the slider moves. Show the viewer with an
# explicit display() call (imported above) rather than relying on it being
# the cell's final expression, so it still renders if code is added after it.
display(widgets.interactive(plot_time_series, i=slider))


interactive(children=(IntSlider(value=0, description='Time Series:', max=51), Output()), _dom_classes=('widget…