In [None]:
import torch
import torch.nn.functional as F
import random
import json
from tqdm import tqdm
from andi_datasets.utils_challenge import *
from andi_datasets.datasets_challenge import _get_dic_andi2, challenge_phenom_dataset
from utils.padding import FEATURE_PADDING_VALUE, LABEL_PADDING_VALUE
from utils.features import getFeatures
from utils.plotting import *
from utils.postprocessing import *
from models.models import ClassificationModel, RegressionModel
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import os
import numpy as np
import pickle



# Load predictions and ground truth

In [None]:
ROOT_DIR = "<test dir>"

# Arrays every analysis cell below relies on; each is stored in ROOT_DIR as "<name>.npy".
EXPECTED_ARRAYS = ("gt_a", "pred_a", "gt_k", "pred_k", "gt_state", "pred_state")

# Fail right here, naming the offending files, instead of letting a silently
# skipped file surface as a NameError in some later cell.
missing_files = [
    f"{name}.npy"
    for name in EXPECTED_ARRAYS
    if not os.path.isfile(os.path.join(ROOT_DIR, f"{name}.npy"))
]
if missing_files:
    raise FileNotFoundError(
        f"Missing prediction/ground-truth files in {ROOT_DIR!r}: {missing_files}"
    )

# Ground truth (gt_*) and model predictions (pred_*) for alpha, K and state.
gt_a = np.load(os.path.join(ROOT_DIR, "gt_a.npy"))
pred_a = np.load(os.path.join(ROOT_DIR, "pred_a.npy"))
gt_k = np.load(os.path.join(ROOT_DIR, "gt_k.npy"))
pred_k = np.load(os.path.join(ROOT_DIR, "pred_k.npy"))
gt_state = np.load(os.path.join(ROOT_DIR, "gt_state.npy"))
pred_state = np.load(os.path.join(ROOT_DIR, "pred_state.npy"))

print(gt_a.shape)

# Mean Absolute Error For K for Various CPs

In [None]:
# Per-track mean absolute error of the K prediction, bucketed by the number of
# ground-truth changepoints (keys "0".."5").
cp_labels = [str(n) for n in range(6)]
d_cps_k = {label: [] for label in cp_labels}       # MAE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_k.copy()
ground_truth = gt_k.copy()

for i, gt_row in enumerate(ground_truth):
    # Cut both series where the padding of the ground-truth row begins.
    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)

    # Keep only 200-step tracks with 0-5 changepoints, at most 10000 per bucket.
    if len(g) != 200 or no_of_cps not in range(6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 10000:
        continue

    total_tracks += 1

    # The prediction goes through smooth_series and then median_filter_1d before scoring.
    smoothed = smooth_series(p, lower_limit=0, upper_limit=6)
    p = median_filter_1d(smoothed)
    mae = np.mean(np.abs(p - g))

    d_cps_k[bucket].append(mae)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

plot_mae_cps(d_cps_k, label_type="k")
# mean_mae = {k: np.mean(v) if v else 0 for k, v in d_cps.items()}
# print(mean_mae)
# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


# Mean Absolute Error Alpha for Various CPs

In [None]:
# Per-track mean absolute error of the alpha prediction, bucketed by the number
# of ground-truth changepoints (keys "0".."5").
cp_labels = [str(n) for n in range(6)]
d_cps_alpha = {label: [] for label in cp_labels}   # MAE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_a.copy()
ground_truth = gt_a.copy()

for i, gt_row in enumerate(ground_truth):
    # Cut both series where the padding of the ground-truth row begins.
    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)

    # Keep only 200-step tracks with 0-5 changepoints, at most 10000 per bucket.
    if len(g) != 200 or no_of_cps not in range(6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 10000:
        continue

    total_tracks += 1

    # Alpha is limited to [0, 1.999] by smooth_series before the median filter is applied.
    smoothed = smooth_series(p, lower_limit=0, upper_limit=1.999)
    p = median_filter_1d(smoothed)
    mae = np.mean(np.abs(p - g))

    d_cps_alpha[bucket].append(mae)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

plot_mae_cps(d_cps_alpha)

# mean_mae = {k: np.mean(v) if v else 0 for k, v in d_cps.items()}
# print(mean_mae)
# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


In [None]:
# Plot the K and alpha MAE dictionaries together; both d_cps_k and d_cps_alpha
# must already exist, i.e. the two MAE cells above have to be run first.
plot_mae_cps_combined(d_cps_k, d_cps_alpha)

# RMSE for all combinations

for alpha

In [None]:
# Changepoint-location RMSE of the alpha prediction for tracks with 1-5
# ground-truth changepoints, followed by a NaN/inf-free average per bucket.
cp_labels = [str(n) for n in range(1, 6)]
d_cps = {label: [] for label in cp_labels}         # RMSE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_a.copy()
ground_truth = gt_a.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)

    # Only 200-step tracks whose bucket still has room are evaluated.
    if len(g) != 200 or no_of_cps not in range(1, 6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = cp_pred[1:-1]
    cp_gt = get_cps_gt(g)[1:-1]

    if cp_gt == cp_pred:
        # Identical lists are scored as zero error without calling the metric.
        rmse = 0
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(rmse)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

# Average the finite RMSE values of each bucket; a bucket without any finite value yields NaN.
results = {}
for key, items in d_cps.items():
    finite_items = [x for x in items if np.isfinite(x)]
    avg_rmse = sum(finite_items) / len(finite_items) if finite_items else float('nan')
    results[key] = {'average_rmse': avg_rmse}

print(results)

# plot_rmse_cps(d_cps, label_type="alpha_without_0")

for K

In [None]:
# Changepoint-location RMSE of the K prediction for tracks with 1-5
# ground-truth changepoints.
cp_labels = [str(n) for n in range(1, 6)]
d_cps = {label: [] for label in cp_labels}         # RMSE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_k.copy()
ground_truth = gt_k.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)

    # Only 200-step tracks whose bucket still has room are evaluated.
    if len(g) != 200 or no_of_cps not in range(1, 6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    # K uses upper_limit=6 (the alpha cells use 1.999).
    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=6, threshold=0.05)
    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = cp_pred[1:-1]
    cp_gt = get_cps_gt(g)[1:-1]

    if cp_gt == cp_pred:
        # Identical lists are scored as zero error without calling the metric.
        rmse = 0
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(rmse)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

plot_rmse_cps(d_cps, label_type="k_without_0")


for s

In [None]:
# Changepoint-location RMSE of the state prediction for tracks with 1-5
# ground-truth changepoints.
# NOTE(review): the reference changepoints are taken from gt_k here, while the
# position-binned state cell uses gt_state - confirm this is intentional.
cp_labels = [str(n) for n in range(1, 6)]
d_cps = {label: [] for label in cp_labels}         # RMSE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_state.copy()
ground_truth = gt_k.copy()

for gt_row, state_row in zip(ground_truth, predictions):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Drop padded entries from the reference and trim the prediction to the same length.
    gt_valid = gt_row[gt_row != LABEL_PADDING_VALUE]
    state_pred = state_row[:len(gt_valid)]

    # Only 200-step tracks are evaluated.
    if len(gt_valid) != 200:
        continue

    state_pred = replace_short_sequences(state_pred, min_length=3)

    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = state_cps_function(state_pred)[1:-1]
    cp_gt = get_cps_gt(gt_valid)[1:-1]
    no_of_cps = len(cp_gt)

    # Skip counts outside 1-5 and buckets that already hold 100 tracks.
    if no_of_cps not in range(1, 6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    if cp_gt == cp_pred:
        # Identical lists are scored as zero error without calling the metric.
        rmse = 0
    else:
        rmse, _ = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(rmse)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

plot_rmse_cps(d_cps, label_type="state_without_0")

for alpha + K

In [None]:
# Changepoint-location RMSE of the merged detector (combined_cps_k_focused with
# include_state=True, i.e. alpha + K + state) for tracks with 1-5 changepoints.
cp_labels = [str(n) for n in range(1, 6)]
d_cps = {label: [] for label in cp_labels}         # RMSE values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

# Work on copies of every loaded array.
pred_a_data, gt_a_data = pred_a.copy(), gt_a.copy()
pred_k_data, gt_k_data = pred_k.copy(), gt_k.copy()
pred_state_data, gt_state_data = pred_state.copy(), gt_state.copy()

for i, gt_alpha_row in enumerate(gt_a_data):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # The alpha ground truth decides where padding starts for all six series.
    idx_a = padding_starts_index(gt_alpha_row)
    valid = slice(None, idx_a)

    p_alpha = pred_a_data[i][valid]
    g_alpha = gt_alpha_row[valid]

    p_k = pred_k_data[i][valid]
    g_k = gt_k_data[i][valid]

    p_state = pred_state_data[i][valid]
    g_state = gt_state_data[i][valid]

    # Only 200-step tracks are evaluated.
    if len(g_k) != 200:
        continue

    # Tracks whose K and alpha ground truths disagree on the changepoint count are skipped.
    no_of_cps = count_changepoints(g_k)
    no_of_cps_alpha = count_changepoints(g_alpha)
    if no_of_cps_alpha != no_of_cps:
        continue

    # Skip counts outside 1-5 and buckets that already hold 100 tracks.
    if no_of_cps not in range(1, 6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    merged_cps, _, _, _, _, _ = combined_cps_k_focused(p_alpha, p_k, p_state, include_state=True)
    # Alternative without the state series:
    # merged_cps, _, _, _, _, _ = combined_cps_k_focused(p_alpha, p_k, p_state)

    # Prepend 0, then strip the first and last entry.
    merged_cps = ([0] + merged_cps)[1:-1]

    # Reference changepoints come from the K ground truth, without first and last entry.
    cp_gt = get_cps_gt(g_k)[1:-1]

    if cp_gt == merged_cps:
        # Identical lists are scored as zero error without calling the metric.
        rmse = 0
    else:
        rmse, _ = single_changepoint_error(cp_gt, merged_cps)

    d_cps[bucket].append(rmse)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

plot_rmse_cps(d_cps, label_type="alpha_k_state_focused_without_0")

for alpha + K + s

plot all

In [None]:
# Overlay the RMSE curves of every changepoint-detection variant against the
# number of ground-truth changepoints (1-5) and write the figure to an SVG file.
# The five arrays are read from the current working directory.
# NOTE(review): presumably these .npy files are written by plot_rmse_cps in the
# cells above (its label_type strings match the file names) - confirm.
rmse_alpha_k_focused = np.load("rmse_values_alpha_k_focused_without_0_for_graph.npy")
rmse_alpha_k_state_k_focused = np.load("rmse_values_alpha_k_state_focused_without_0_for_graph.npy")
rmse_alpha = np.load("rmse_values_alpha_without_0_for_graph.npy")
rmse_k = np.load("rmse_values_k_without_0_for_graph.npy")
rmse_state = np.load("rmse_values_state_without_0_for_graph.npy")

# x axis: number of changepoints; every array needs one value per entry.
n_cps = list(range(1,6))

# Figure drawn directly on an SVG canvas (no pyplot state involved).
fig = Figure(figsize=(8, 6), dpi=300)
canvas = FigureCanvasSVG(fig)
ax = fig.add_subplot(111)

# (values, color, legend label[, linestyle[, marker face color]])
lines = [
    (rmse_k, K_COLOR, "K"),
    (rmse_alpha, ALPHA_COLOR, "α"),
    (rmse_state, STATE_COLOR, "s"),
    (rmse_alpha_k_state_k_focused, AKS_COLOR, "α+K+s", '--', 'white'),  # dashed line, hollow markers
    (rmse_alpha_k_focused, ALPHA_K_FOCUSED_COLOR, "α+K"),
]

for item in lines:
    rmse_values, color, label, *style_args = item
    # Optional style entries fall back to a solid line with filled markers.
    linestyle = style_args[0] if len(style_args) > 0 else '-'
    markerfacecolor = style_args[1] if len(style_args) > 1 else color
    
    ax.plot(n_cps, rmse_values, 
            color=color,
            linewidth=2,
            alpha=1,
            marker='o',
            markersize=10,
            markerfacecolor=markerfacecolor,
            markeredgecolor=color,
            markeredgewidth=2,
            linestyle=linestyle,
            label=label)

# y limits are derived from the finite values of all curves together.
all_values = np.concatenate([
    rmse_alpha_k_focused,
    rmse_alpha_k_state_k_focused,
    rmse_alpha,
    rmse_k,
    rmse_state
])
valid_values = all_values[~(np.isnan(all_values) | np.isinf(all_values))]

if len(valid_values) > 0:
    ymin = np.min(valid_values)
    ymax = np.max(valid_values)
    margin = (ymax - ymin) * 0.1
    # Extra headroom at the top (1.5x margin versus 1x at the bottom).
    ax.set_ylim(ymin - margin, ymax + 1.5* margin)

ax.set_xticks(n_cps)
ax.tick_params(which='both', direction='out', length=6, width=1,
                colors='black', pad=2, labelsize=32)

for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

# Raw string: "\m" is not a valid Python escape sequence, so the non-raw form
# triggers an invalid-escape warning (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
ax.set_ylabel('RMSE', fontsize=32, rotation=90, va='center', labelpad=20)

legend = ax.legend(fontsize=28, frameon=True,
                    loc='upper right',
                    bbox_to_anchor=(0.98, 0.98),
                    edgecolor='none',
                    facecolor='white',
                    framealpha=0.8)

fig.tight_layout()
# Transparent figure and axes backgrounds.
fig.patch.set_facecolor('none')
ax.set_facecolor('none')

canvas.print_figure('all_methods_rmse_comparison_without_0.svg',
                    bbox_inches='tight',
                    pad_inches=0.1,
                    format='svg')

# Changepoint for Various Positions Alpha

In [None]:
# Jaccard value of the alpha changepoint detection as a function of where the
# single ground-truth changepoint sits; positions are grouped into 20-step bins.
bin_size = 20
num_bins = 200 // bin_size
bin_labels = [str(start) for start in range(0, num_bins * bin_size, bin_size)]
d_positions = {label: [] for label in bin_labels}      # Jaccard values per position bin
tracks_per_bin = {label: 0 for label in bin_labels}    # accepted tracks per position bin
total_tracks = 0

predictions = pred_a.copy()
ground_truth = gt_a.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every position bin has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_bin.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    # Only 200-step tracks are evaluated.
    if len(g) != 200:
        continue

    # Only tracks with exactly one changepoint (after dropping first and last entry).
    cp_gt = get_cps_gt(g)[1:-1]
    if len(cp_gt) != 1:
        continue

    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
    cp_pred = cp_pred[1:-1]

    # Bin label is the lower edge of the bin containing the changepoint.
    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Skip unknown bins and bins that already hold 100 tracks.
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per bin:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_bin.items()), end=" \r")

In [None]:
# Optional: uncomment the two lines below to SAVE d_positions as a pickle
# (mode 'wb' + pickle.dump writes the file; it does not load anything).
# Note that the combined-plot cell further down reads
# 'plots/changepoint_position_jaccard_value_alpha.pkl', so save into plots/ or move the file there.
# with open('changepoint_position_jaccard_value_alpha.pkl', 'wb') as f:
#     pickle.dump(d_positions, f)

# Plot the alpha per-position Jaccard values collected in the cell above.
plot_jaccard_position(d_positions, label_type="alpha", bin_size=20)

In [None]:
# d_positions = {}  # Dictionary to store jaccard values by position bins
# bin_size = 20  # Size of position bins
# predictions = pred_a.copy()
# ground_truth = gt_a.copy()
# tracks = 0

# for i in range(len(ground_truth)):
#     idx = padding_starts_index(ground_truth[i])
#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     # Only look at tracks with length 200 and exactly 1 changepoint
#     if len(g) == 200 and no_of_cps == 1:
#         tracks += 1
        
#         if tracks > 2000:
#             break

#         cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
#         cp_gt = get_cps_gt(g)

#         # Calculate Jaccard index
#         cp_pred = cp_pred[1:-1]
#         cp_gt = cp_gt[1:-1]

#         position_bin = (cp_gt[0] - 1) // bin_size * bin_size  # Subtract 1 to start from 0
#         bin_label = str(position_bin)

#         if cp_gt == cp_pred:
#             jaccard_value = 1
#         else:
#             rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

#         # Store in position-based dictionary
#         if bin_label not in d_positions:
#             d_positions[bin_label] = [jaccard_value]
#         else:
#             d_positions[bin_label].append(jaccard_value)

#         print(tracks, end="\r")

In [None]:
# tracks = 0
# for i in range(len(ground_truth)):
#     idx = padding_starts_index(ground_truth[i])
#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     # Only look at tracks with length 200 and exactly 1 changepoint
#     if len(g) == 200 and no_of_cps == 1:
#         tracks += 1
#         if tracks == 200:
#             plt.scatter([i for i in range(len(p))], p)  
#             plt.scatter([i for i in range(len(g))], g)
#             break
# plt.show()      

In [None]:
# Average Jaccard value per changepoint-position bin of d_positions, printed as
# a table and drawn as a scatter plot that is saved to an SVG file.
results = {}
for key, items in d_positions.items():
    if not items:
        # Every bin starts as an empty list and stays empty when no accepted
        # track had its changepoint there; averaging it would raise
        # ZeroDivisionError, so such bins are left out of the summary.
        continue
    results[key] = {
        'average_jaccard': sum(items) / len(items),
        'num_tracks': len(items),
        'all_values': items
    }

# Bin labels are strings, so sort them numerically once and reuse the order.
sorted_keys = sorted(results.keys(), key=int)

print("\nResults by position bin:")
for key in sorted_keys:
    print(f"\nPosition bin {key}-{int(key)+bin_size}:")
    print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
    print(f"Number of tracks: {results[key]['num_tracks']}")

# Create figure
plt.figure(figsize=(10, 6))

# One point per non-empty bin, placed at the bin centre.
x_values = [int(key) + bin_size/2 for key in sorted_keys]
y_values = [results[key]['average_jaccard'] for key in sorted_keys]

# Create scatter plot
plt.scatter(x_values, y_values, color='blue', s=100)

# Add value labels above each point
for x, v in zip(x_values, y_values):
    plt.text(x, v + 0.01, f'{v:.3f}', ha='center')

plt.xlabel('Changepoint Position')
plt.ylabel('Average Jaccard Value')
plt.title('Average Jaccard Values by Changepoint Position For Alpha')
plt.grid(True, linestyle='--', alpha=0.7)

# Set x-ticks to show bin edges
bin_edges = range(0, 201, bin_size)
plt.xticks(bin_edges)

# y axis starts at 0 with some padding at the top; max() needs at least one point.
if y_values:
    plt.ylim(0, max(y_values) + 0.1)

plt.tight_layout()
plt.savefig("jaccard_vs_position_test_set_alpha.svg", format="svg")
plt.show()

# Changepoints for Various Positions K

In [None]:
# Jaccard value of the K changepoint detection as a function of where the
# single ground-truth changepoint sits; positions are grouped into 20-step bins.
bin_size = 20
num_bins = 200 // bin_size
bin_labels = [str(start) for start in range(0, num_bins * bin_size, bin_size)]
d_positions = {label: [] for label in bin_labels}      # Jaccard values per position bin
tracks_per_bin = {label: 0 for label in bin_labels}    # accepted tracks per position bin
total_tracks = 0

predictions = pred_k.copy()
ground_truth = gt_k.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every position bin has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_bin.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    # Only 200-step tracks are evaluated.
    if len(g) != 200:
        continue

    # Only tracks with exactly one changepoint (after dropping first and last entry).
    cp_gt = get_cps_gt(g)[1:-1]
    if len(cp_gt) != 1:
        continue

    # K uses upper_limit=6 (the alpha cell uses 1.999).
    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=6, threshold=0.05)
    cp_pred = cp_pred[1:-1]

    # Bin label is the lower edge of the bin containing the changepoint.
    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Skip unknown bins and bins that already hold 100 tracks.
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per bin:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_bin.items()), end=" \r")

In [None]:
# Optional: uncomment the two lines below to SAVE d_positions as a pickle
# (mode 'wb' + pickle.dump writes the file; it does not load anything).
# Note that the combined-plot cell further down reads
# 'plots/changepoint_position_jaccard_value_k.pkl', so save into plots/ or move the file there.
# with open('changepoint_position_jaccard_value_k.pkl', 'wb') as f:
#     pickle.dump(d_positions, f)

# Plot the K per-position Jaccard values collected in the cell above.
plot_jaccard_position(d_positions, label_type="k", bin_size=20)

# Changepoints for Various Positions State

In [None]:
# Jaccard value of the state changepoint detection as a function of where the
# single ground-truth changepoint sits; positions are grouped into 20-step bins.
bin_size = 20
num_bins = 200 // bin_size
bin_labels = [str(start) for start in range(0, num_bins * bin_size, bin_size)]
d_positions = {label: [] for label in bin_labels}      # Jaccard values per position bin
tracks_per_bin = {label: 0 for label in bin_labels}    # accepted tracks per position bin
total_tracks = 0

predictions = pred_state.copy()
ground_truth = gt_state.copy()

for gt_row, state_row in zip(ground_truth, predictions):
    # Stop as soon as every position bin has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_bin.values()):
        break

    # Drop padded entries from the reference and trim the prediction to the same length.
    gt_valid = gt_row[gt_row != LABEL_PADDING_VALUE]
    state_pred = state_row[:len(gt_valid)]

    # Only 200-step tracks are evaluated.
    if len(gt_valid) != 200:
        continue

    state_pred = replace_short_sequences(state_pred, min_length=3)

    # Only tracks with exactly one changepoint (after dropping first and last entry).
    cp_gt = get_cps_gt(gt_valid)[1:-1]
    if len(cp_gt) != 1:
        continue

    cp_pred = state_cps_function(state_pred)[1:-1]

    # Bin label is the lower edge of the bin containing the changepoint.
    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Skip unknown bins and bins that already hold 100 tracks.
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per bin:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_bin.items()), end=" \r")

In [None]:
# Optional: uncomment the two lines below to SAVE d_positions as a pickle
# (mode 'wb' + pickle.dump writes the file; it does not load anything).
# Note that the combined-plot cell further down reads
# 'plots/changepoint_position_jaccard_value_state.pkl', so save into plots/ or move the file there.
# with open('changepoint_position_jaccard_value_state.pkl', 'wb') as f:
#     pickle.dump(d_positions, f)

# Plot the state per-position Jaccard values collected in the cell above.
plot_jaccard_position(d_positions, label_type="state", bin_size=20)

# Changepoint Positions for All Variables

In [None]:
# Load the pickled per-position Jaccard dictionary for K from the plots/ folder.
# NOTE: pickle.load can execute arbitrary code - only open files you created yourself.
with open('plots/changepoint_position_jaccard_value_k.pkl', 'rb') as f:
    d_positions_k = pickle.load(f)

# Same for the state results.
with open('plots/changepoint_position_jaccard_value_state.pkl', 'rb') as f:
    d_positions_state = pickle.load(f)

# Same for the alpha results.
with open('plots/changepoint_position_jaccard_value_alpha.pkl', 'rb') as f:
    d_positions_alpha = pickle.load(f)

# Draw the three dictionaries in a single figure (argument order: alpha, K, state).
fig, ax = plot_multiple_position_jaccard(d_positions_alpha, d_positions_k, d_positions_state)

# Alpha for Various CPs

In [None]:
# Jaccard value of the alpha changepoint detection, bucketed by the number of
# ground-truth changepoints (keys "0".."5").
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}         # Jaccard values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_a.copy()
ground_truth = gt_a.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)

    # Only 200-step tracks whose bucket still has room are evaluated.
    if len(g) != 200 or no_of_cps not in range(6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = cp_pred[1:-1]
    cp_gt = get_cps_gt(g)[1:-1]

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(jaccard_value)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


In [None]:
# d_cps is reassigned by every per-changepoint-count cell in this notebook, so this
# plots whichever of them ran last; run it directly after the alpha cell above.
plot_jaccard_cps(d_cps, label_type="alpha")

# K for Various CPs

In [None]:
# Jaccard value of the K changepoint detection, bucketed by the number of
# ground-truth changepoints (keys "0".."5").
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}         # Jaccard values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_k.copy()
ground_truth = gt_k.copy()

for i, gt_row in enumerate(ground_truth):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    # Tracks whose K and alpha ground truths disagree on the changepoint count are skipped.
    no_of_cps = count_changepoints(g)
    no_of_cps_alpha = count_changepoints(gt_a[i][:idx])
    if no_of_cps_alpha != no_of_cps:
        continue

    # Only 200-step tracks whose bucket still has room are evaluated.
    if len(g) != 200 or no_of_cps not in range(6):
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    # K uses upper_limit=6 (the alpha cell uses 1.999).
    cp_pred, _ = get_cps_ruptures(p, lower_limit=0, upper_limit=6, threshold=0.05)
    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = cp_pred[1:-1]
    cp_gt = get_cps_gt(g)[1:-1]

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(jaccard_value)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")

In [None]:
# d_cps is reassigned by every per-changepoint-count cell in this notebook, so this
# plots whichever of them ran last; run it directly after the K cell above.
plot_jaccard_cps(d_cps, label_type="k_test")

# State Various CPs

In [None]:
# Jaccard value of the state changepoint detection, bucketed by the number of
# ground-truth changepoints (keys "0".."5").
# NOTE(review): the reference changepoints are taken from gt_k here, while the
# position-binned state cell uses gt_state - confirm this is intentional.
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}         # Jaccard values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

predictions = pred_state.copy()
ground_truth = gt_k.copy()

for gt_row, state_row in zip(ground_truth, predictions):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Drop padded entries from the reference and trim the prediction to the same length.
    gt_valid = gt_row[gt_row != LABEL_PADDING_VALUE]
    state_pred = state_row[:len(gt_valid)]

    # Only 200-step tracks are evaluated.
    if len(gt_valid) != 200:
        continue

    state_pred = replace_short_sequences(state_pred, min_length=3)

    # Both changepoint lists lose their first and last entry before comparison.
    cp_pred = state_cps_function(state_pred)[1:-1]
    cp_gt = get_cps_gt(gt_valid)[1:-1]

    # The bucket is the number of remaining ground-truth changepoints.
    no_of_cps = len(cp_gt)

    # Skip counts of 6 or more and buckets that already hold 100 tracks.
    if no_of_cps >= 6:
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    if cp_gt == cp_pred:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[bucket].append(jaccard_value)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

In [None]:
# d_cps is reassigned by every per-changepoint-count cell in this notebook, so this
# plots whichever of them ran last; run it directly after the state cell above.
plot_jaccard_cps(d_cps, label_type="state")

# Using all variables for changepoint detection

In [None]:
# Jaccard value of the merged detector (combined_cps_k_focused with
# include_state=True, i.e. alpha + K + state), bucketed by the number of
# ground-truth changepoints (keys "0".."5").
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}         # Jaccard values per changepoint count
tracks_per_cp = {label: 0 for label in cp_labels}  # accepted tracks per changepoint count
total_tracks = 0

# Work on copies of every loaded array.
pred_a_data, gt_a_data = pred_a.copy(), gt_a.copy()
pred_k_data, gt_k_data = pred_k.copy(), gt_k.copy()
pred_state_data, gt_state_data = pred_state.copy(), gt_state.copy()

for i, gt_alpha_row in enumerate(gt_a_data):
    # Stop as soon as every changepoint count has collected 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # The alpha ground truth decides where padding starts for all six series.
    idx_a = padding_starts_index(gt_alpha_row)
    valid = slice(None, idx_a)

    p_alpha = pred_a_data[i][valid]
    g_alpha = gt_alpha_row[valid]

    p_k = pred_k_data[i][valid]
    g_k = gt_k_data[i][valid]

    p_state = pred_state_data[i][valid]
    g_state = gt_state_data[i][valid]

    # Only 200-step tracks are evaluated.
    if len(g_k) != 200:
        continue

    # Tracks whose K and alpha ground truths disagree on the changepoint count are skipped.
    no_of_cps = count_changepoints(g_k)
    no_of_cps_alpha = count_changepoints(g_alpha)
    if no_of_cps_alpha != no_of_cps:
        continue

    # Skip counts of 6 or more and buckets that already hold 100 tracks.
    if no_of_cps >= 6:
        continue
    bucket = str(no_of_cps)
    if tracks_per_cp[bucket] >= 100:
        continue

    total_tracks += 1

    # Alternatives tried previously:
    # merged_cps, _, _, _, _, _ = combined_cps_with_state(p_alpha, p_k, p_state)
    # merged_cps, _, _, _, _, _ = combined_cps_k_focused(p_alpha, p_k, window_size=5)
    # merged_cps, _, _, _, _ = combined_cps(p_alpha, p_k)
    merged_cps, _, _, _, _, _ = combined_cps_k_focused(p_alpha, p_k, p_state, include_state=True)

    # Prepend 0, then strip the first and last entry.
    merged_cps = ([0] + merged_cps)[1:-1]

    # Reference changepoints come from the K ground truth, without first and last entry.
    cp_gt = get_cps_gt(g_k)[1:-1]

    if cp_gt == merged_cps:
        # Identical lists are scored as a perfect match without calling the metric.
        jaccard_value = 1
        rmse = 0
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, merged_cps)

    d_cps[bucket].append(jaccard_value)
    tracks_per_cp[bucket] += 1

    # One-line progress readout, overwritten in place by the trailing carriage return.
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    print(*(f"{k}:{v}" for k, v in tracks_per_cp.items()), end=" \r")

# Summary: accepted tracks and stored values per changepoint count.
print("\nFinal counts:")
for k, v in tracks_per_cp.items():
    print(f"Changepoints {k}: {v} tracks")

print("\nFinal d_cps lengths:")
for k, v in d_cps.items():
    print(f"Changepoints {k}: {len(v)} values")

In [None]:
# Plot the Jaccard values currently held in d_cps; label_type names the combined alpha+K+state run.
plot_jaccard_cps(d_cps, label_type="alpha_k_state_k_focused")

Once the Jaccard values for every method have been saved, plot them together here.

In [None]:
# Compare the mean Jaccard value against the number of changepoints for every detection method.
# Each array is plotted against n_cps, so it must hold one value per changepoint count.
j_alpha_k_focused = np.load("plots/jaccard_values_alpha_and_k_focused_for_graph.npy")
j_alpha_k_state_k_focused = np.load("plots/jaccard_values_alpha_k_state_k_focused_for_graph.npy")
j_alpha = np.load("plots/jaccard_values_alpha_for_graph.npy")
j_k = np.load("plots/jaccard_values_k_for_graph.npy")
j_state = np.load("plots/jaccard_values_state_new_for_graph.npy")
# Not plotted at the moment:
# j_alpha_k = np.load("jaccard_values_alpha_and_k_for_graph.npy")
# j_alpha_k_state = np.load("jaccard_values_alpha_k_state_for_graph.npy")

# x-axis: number of changepoints
n_cps = list(range(6))

# Build the figure directly on an SVG canvas (no pyplot state involved)
fig = Figure(figsize=(8, 6), dpi=300)
canvas = FigureCanvasSVG(fig)
ax = fig.add_subplot(111)

# (values, colour, label[, linestyle[, marker face colour]]);
# the optional entries default to a solid line and a marker filled with the line colour
lines = [
    (j_k, K_COLOR, "K"),
    (j_alpha, ALPHA_COLOR, "α"),
    (j_state, STATE_COLOR, "s"),
    (j_alpha_k_state_k_focused, "black", "α+K+s", '--', 'white'),
    (j_alpha_k_focused, "black", "α+K"),
]

for item in lines:
    jaccard_values, color, label, *style_args = item
    linestyle = style_args[0] if len(style_args) > 0 else '-'
    markerfacecolor = style_args[1] if len(style_args) > 1 else color

    ax.plot(
        n_cps, jaccard_values,
        color=color,
        linewidth=2,
        alpha=1,
        marker='o',
        markersize=10,
        markerfacecolor=markerfacecolor,
        markeredgecolor=color,
        markeredgewidth=2,
        linestyle=linestyle,
        label=label,
    )

# The y-axis is fixed to the full Jaccard range, so no data-driven limits are needed
ax.set_ylim(0, 1)
ax.set_yticks([0, 0.5, 1])
ax.set_yticklabels(['0', '0.5', '1'])
# Show only integer changepoint counts on the x-axis
ax.set_xticks(n_cps)

# Configure axis and style
ax.set_axisbelow(True)
ax.tick_params(which='both', direction='out', length=6, width=1,
               colors='black', pad=2)
ax.tick_params(axis='both', labelsize=32)

for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

# Raw strings keep the LaTeX backslashes intact without invalid-escape warnings
ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

ax.legend(fontsize=28, frameon=True,
          loc='upper right',
          bbox_to_anchor=(0.98, 0.98),
          edgecolor='none',
          facecolor='white',
          framealpha=0.8)

# Transparent figure and axes background
fig.tight_layout()
fig.patch.set_facecolor('none')
ax.set_facecolor('none')

canvas.print_figure('all_methods_jaccard_comparison_new.svg',
                    bbox_inches='tight',
                    pad_inches=0.1,
                    format='svg')

# CPs for different models

In [None]:
# Jaccard value of the combined changepoint detector, broken down by diffusion model and by
# number of ground-truth changepoints, using the same number of sampled tracks per category.
# MODELS, NO_OF_CPS_MAX and TRACKS_PER_CATEGORY are not defined in this cell and must
# already exist in the notebook namespace.
d_indices = {model: [[] for _ in range(NO_OF_CPS_MAX + 1)] for model in MODELS}

# Keep only tracks with exactly 200 unpadded steps
sequence_lengths = (gt_k != LABEL_PADDING_VALUE).sum(axis=1)
length_mask = sequence_lengths == 200

# Count value changes in the K ground truth; a count can never be negative, so only
# the upper bound needs checking
changes = gt_k[:, 1:] != gt_k[:, :-1]
changepoints = changes.sum(axis=1)
changepoint_mask = changepoints <= NO_OF_CPS_MAX

final_mask = changepoint_mask & length_mask
valid_indices = np.where(final_mask)[0]

# First pass: file every valid track under (model, number of changepoints)
for idx in valid_indices:
    cp_gt = np.where(gt_k[idx, 1:] != gt_k[idx, :-1])[0] + 1
    no_of_cps = len(cp_gt)
    states = gt_state[idx]

    # The model is decided by which state labels occur in the track, checked in this order
    if 1 in states:
        model = "conf."
    elif 0 in states:
        model = "imm."
    elif 3 in states:
        model = "dir."
    else:
        model = "free"

    d_indices[model][no_of_cps].append(idx)

# Result container: d_cps[model][n] holds the Jaccard values for tracks with n changepoints
d_cps = {model: [[] for _ in range(NO_OF_CPS_MAX + 1)] for model in MODELS}

# Empty categories are skipped below, so only non-empty ones count towards the progress total
non_empty_categories = sum(
    1
    for model_name in MODELS
    for n in range(NO_OF_CPS_MAX + 1)
    if len(d_indices[model_name][n]) > 0
)
progress_bar = tqdm(desc="Processing tracks", total=non_empty_categories * TRACKS_PER_CATEGORY)

for model in MODELS:
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        available_indices = d_indices[model][no_of_cps]

        # Nothing to sample from for this category
        if len(available_indices) == 0:
            continue

        # Draw TRACKS_PER_CATEGORY tracks; sample with replacement only when the category
        # has fewer tracks than that. NOTE: the draw is random, so results vary between
        # runs unless a NumPy seed has been set elsewhere.
        indices_to_process = np.random.choice(
            available_indices,
            size=TRACKS_PER_CATEGORY,
            replace=len(available_indices) < TRACKS_PER_CATEGORY
        )

        for idx in indices_to_process:
            cp_gt = np.where(gt_k[idx, 1:] != gt_k[idx, :-1])[0] + 1
            pk = pred_k[idx]
            pa = pred_a[idx]
            ps = pred_state[idx]

            # Alternative detector tried previously:
            # cp_pred, _ = get_cps_ruptures(pk, lower_limit=0, upper_limit=6, threshold=0.05)
            cp_pred, _, _, _, _, _ = combined_cps_k_focused(pa, pk, ps)
            # Prepend 0, then drop the first and last entries (for a list the net effect
            # is removing the last entry)
            cp_pred = [0] + cp_pred
            cp_pred = cp_pred[1:-1]

            if list(cp_gt) == list(cp_pred):
                jaccard_value = 1
            else:
                _, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

            d_cps[model][no_of_cps].append(jaccard_value)
            progress_bar.update()

progress_bar.close()

# Print statistics about the sampling
print("\nSampling Statistics:")
print("-" * 50)
for model in MODELS:
    print(f"\nModel: {model}")
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        original_count = len(d_indices[model][no_of_cps])
        sampled_count = len(d_cps[model][no_of_cps])
        print(f"  CPs={no_of_cps}: Original={original_count}, Sampled={sampled_count}")
        if original_count < TRACKS_PER_CATEGORY and original_count > 0:
            print(f"    (Sampled with replacement to reach target)")

# Save the per-model Jaccard values
with open('jaccard_for_models.json', 'w') as f:
    json.dump(d_cps, f)


In [None]:
# # load data if already saved 
# with open('jaccard_for_models.json', 'r') as f:
#    d_cps = json.load(f)

In [None]:
# Summary of how many tracks were available and how many were sampled per category.
# The available count is looked up per (model, changepoint count) in d_indices rather than
# reusing a leftover loop variable, which would report the same number for every row.
for model in MODELS:
    print(f"\nModel: {model}")
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        original_count = len(d_indices[model][no_of_cps])
        sampled_count = len(d_cps[model][no_of_cps])
        print(f"  CPs={no_of_cps}: Original={original_count}, Sampled={sampled_count}")
        if original_count < TRACKS_PER_CATEGORY and original_count > 0:
            print(f"    (Sampled with replacement to reach target)")

In [None]:
# Plot the Jaccard values stored in d_cps (at this point in the notebook, the per-model results).
plot_model_jaccard(d_cps)

In [None]:
# Average Jaccard value for each number of changepoints, printed and plotted.
# d_cps may use either layout found in this notebook:
#   flat:      {"<number of changepoints>": [jaccard values]}
#   per-model: {model: [[values for 0 CPs], [values for 1 CP], ...]}
# Per-model values are pooled across models for each changepoint count.
per_cp_values = {}
for key, items in d_cps.items():
    if len(items) > 0 and isinstance(items[0], (list, tuple)):
        for n, values in enumerate(items):
            per_cp_values.setdefault(str(n), []).extend(values)
    else:
        per_cp_values.setdefault(str(key), []).extend(items)

results = {}
for key, values in per_cp_values.items():
    if len(values) == 0:
        continue  # no tracks for this changepoint count; skip it rather than divide by zero
    results[key] = {
        'average_jaccard': sum(values) / len(values),
        'num_tracks': len(values),
        'all_values': values
    }

# Order the categories numerically, not as strings
ordered_keys = sorted(results.keys(), key=int)

print("\nResults by number of changepoints:")
for key in ordered_keys:
    print(f"\nNumber of changepoints: {key}")
    print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
    print(f"Number of tracks: {results[key]['num_tracks']}")

# Fresh lists for plotting, rebuilt on every run so repeated runs do not accumulate points
n_changepoints = [int(key) for key in ordered_keys]
jaccard_values = [results[key]['average_jaccard'] for key in ordered_keys]

plt.figure(figsize=(10, 6))
plt.plot(n_changepoints, jaccard_values, marker='o')
plt.xlabel('Number of Changepoints')
plt.ylabel('Average Jaccard Value')
plt.title('Jaccard Values vs Number of Changepoints for K')
plt.grid(True)
plt.savefig("jaccard_vs_changepoints_test_set_for_k_2.svg", format="svg")
plt.show()

# Error in K and Alpha Per State/Model

- Given the ground truth state, works out error in K and alpha

In [None]:
# Mean absolute error of the K and alpha predictions, grouped by diffusion model.
model_labels = ('imm.', 'conf.', 'free', 'dir.')
errors_k = {label: [] for label in model_labels}
errors_a = {label: [] for label in model_labels}

# A track is used only if none of its state labels is the padding value
valid_mask = np.all(gt_state != LABEL_PADDING_VALUE, axis=1)

# Clip predictions before measuring the error: K to [0, 6], alpha to [0, 2]
pk2 = np.clip(pred_k, 0, 6)
pa2 = np.clip(pred_a, 0, 2)

# Per-track error: summed absolute difference divided by 200
k_error = np.abs(gt_k - pk2).sum(axis=1)/200
a_error = np.abs(gt_a - pa2).sum(axis=1)/200

# File each valid track's errors under the model implied by the state labels it contains
for idx in np.flatnonzero(valid_mask):
    states = gt_state[idx]

    if 1 in states:
        label = 'conf.'
    elif 0 in states:
        label = 'imm.'
    elif 3 in states:
        label = 'dir.'
    else:
        label = 'free'

    errors_k[label].append(k_error[idx])
    errors_a[label].append(a_error[idx])

# Average the per-track errors within each model
mae_k = {label: np.mean(values) for label, values in errors_k.items()}
mae_a = {label: np.mean(values) for label, values in errors_a.items()}

plot_errors_by_model(mae_k, mae_a)

In [None]:
# # Initialize error lists
# errors_k = [[] for _ in range(NUM_CLASSES)]
# errors_alpha = [[] for _ in range(NUM_CLASSES)]

# # Clip predictions
# # pred_k = np.clip(pred_k, 0, 6)
# # pred_a = np.clip(pred_a, 0, 2)

# pk = median_filter_1d(smooth_series(pred_k, lower_limit=0, upper_limit=6))
# pa = median_filter_1d(smooth_series(pred_a, lower_limit=0, upper_limit=1.999))

# # Calculate errors
# k_error = gt_k - pred_k 
# alpha_error = gt_a - pred_a 

# # Collect errors by state
# for class_idx in range(NUM_CLASSES):
#     valid_positions = (gt_state == class_idx) & (gt_state != LABEL_PADDING_VALUE)
#     errors_k[class_idx].extend(k_error[valid_positions])
#     errors_alpha[class_idx].extend(alpha_error[valid_positions])

# # Calculate statistics
# mean_k = [np.mean(err) if len(err) > 0 else np.nan for err in errors_k]
# mean_alpha = [np.mean(err) if len(err) > 0 else np.nan for err in errors_alpha]
# std_k = [np.std(err) if len(err) > 0 else np.nan for err in errors_k]
# std_alpha = [np.std(err) if len(err) > 0 else np.nan for err in errors_alpha]

# # Calculate 95% confidence intervals
# ci_k = np.array([stats.sem(err) * stats.t.ppf((1 + 0.95) / 2, len(err)-1) 
#                  if len(err) > 1 else np.nan for err in errors_k])
# ci_alpha = np.array([stats.sem(err) * stats.t.ppf((1 + 0.95) / 2, len(err)-1) 
#                     if len(err) > 1 else np.nan for err in errors_alpha])

# # Absolute error statistics
# abs_mean_k = [np.mean(np.abs(err)) if len(err) > 0 else np.nan for err in errors_k]
# abs_mean_alpha = [np.mean(np.abs(err)) if len(err) > 0 else np.nan for err in errors_alpha]

# # Print statistics
# print("\nError Statistics by State:")
# print("-" * 50)
# for state in range(NUM_CLASSES):
#     print(f"\nState {state}:")
#     print(f"K error      - Mean ± Std: {mean_k[state]:.3f} ± {std_k[state]:.3f}")
#     print(f"K error      - Mean ± 95% CI: {mean_k[state]:.3f} ± {ci_k[state]:.3f}")
#     print(f"K error      - Abs Mean: {abs_mean_k[state]:.3f}")
#     print(f"Alpha error  - Mean ± Std: {mean_alpha[state]:.3f} ± {std_alpha[state]:.3f}")
#     print(f"Alpha error  - Mean ± 95% CI: {mean_alpha[state]:.3f} ± {ci_alpha[state]:.3f}")
#     print(f"Alpha error  - Abs Mean: {abs_mean_alpha[state]:.3f}")

# mean_k = np.array(mean_k)
# mean_alpha = np.array(mean_alpha)
# std_k = np.array(std_k)
# std_alpha = np.array(std_alpha)

# # Create figure using Figure and Canvas
# fig = Figure(figsize=(10, 6), dpi=300)
# canvas = FigureCanvasSVG(fig)
# ax = fig.add_subplot(111)
# states = np.arange(NUM_CLASSES)

# # Create plots with filled circles and dashed error bars
# k_line = ax.errorbar(states, mean_k, yerr=std_k, 
#                     label='K',
#                     fmt='o',
#                     markerfacecolor=K_COLOR,
#                     markeredgecolor=K_COLOR,
#                     markeredgewidth=2,
#                     capsize=CAPSIZE,
#                     # alpha=0.9,
#                     color=K_COLOR,
#                     markersize=MARKER_SIZE,
#                     capthick=LINE_WIDTH,
#                     elinewidth=LINE_WIDTH,
#                     ls='none')

# k_line[-1][0].set_linestyle('--')

# alpha_line = ax.errorbar(states, mean_alpha, yerr=std_alpha,
#                         label='α',
#                         fmt='o',
#                         markerfacecolor=ALPHA_COLOR,
#                         markeredgecolor=ALPHA_COLOR,
#                         markeredgewidth=2,
#                         capsize=CAPSIZE,
#                         # alpha=0.7,
#                         color=ALPHA_COLOR,
#                         markersize=MARKER_SIZE,
#                         capthick=LINE_WIDTH,
#                         elinewidth=LINE_WIDTH,
#                         ls='none')

# alpha_line[-1][0].set_linestyle('--')

# # Calculate y-axis limits with padding
# y_min = min(min(mean_k - std_k), min(mean_alpha - std_alpha))
# y_max = max(max(mean_k + std_k), max(mean_alpha + std_alpha))
# y_range = y_max - y_min
# padding = 0.3 * y_range  # 10% padding
# ax.set_ylim(y_min - padding, y_max + padding)

# # Baseline
# ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5, alpha=0.5)

# # Configure spines
# for spine in ax.spines.values():
#     spine.set_visible(True)
#     spine.set_linewidth(1)
#     spine.set_color('black')

# ax.set_xlabel('State (s)', fontsize=FONT_SIZE, labelpad=20)  # Updated x-label
# ax.set_ylabel(r'$\overline{Error}$ ± σ', fontsize=FONT_SIZE, labelpad=60)  # Updated y-label with sigma

# # Configure ticks
# ax.tick_params(which='both', direction='out', length=6, width=1,
#               colors='black', pad=2, labelsize=FONT_SIZE)
# ax.set_xticks(states)
# ax.set_xticklabels([str(i) for i in states])

# # Legend
# legend = ax.legend(fontsize=FONT_SIZE,
#                   frameon=True,
#                   edgecolor='black',
#                   loc='upper right')
# legend.get_frame().set_alpha(1)

# # Set transparent background
# fig.patch.set_facecolor('none')
# ax.set_facecolor('none')

# # Layout adjustments
# fig.tight_layout(pad=1.0)

# # Save using canvas.print_figure
# canvas.print_figure('error_by_state.svg',
#                    bbox_inches='tight',
#                    pad_inches=0.1,
#                    format='svg')

# Track Length Error

In [None]:
# Plot the error against track length using the smoothed variant; the last argument ("./")
# is passed where the commented-out non-smooth call passed ROOT_DIR.
# fig, ax = create_loss_length_plot(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, ROOT_DIR)
fig, ax = create_loss_length_plot_smooth(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, "./")


# Plot Heatmaps 

In [None]:
def mae(pred, gt, padding_value=LABEL_PADDING_VALUE):
    """Mean over sequences of the per-sequence mean absolute error, ignoring padding.

    Args:
        pred: Predicted values; must broadcast against ``gt`` and is reduced along axis 1.
        gt: Ground-truth values. Positions equal to ``padding_value`` are treated as
            padding and excluded from the error.
        padding_value: Marker used for padded positions in ``gt``.

    Returns:
        The average of the per-sequence mean absolute errors, taken over sequences that
        contain at least one non-padded position; NaN if there is no such sequence.
    """
    valid = (gt != padding_value)

    # Select rather than multiply by the mask, so NaN/inf at padded positions cannot leak in
    abs_err = np.where(valid, np.abs(pred - gt), 0.0)

    sequence_lengths = valid.sum(axis=1)
    has_data = sequence_lengths > 0
    if not has_data.any():
        return float("nan")

    # Sequences that are entirely padding are skipped instead of contributing 0/0
    sequence_averages = abs_err.sum(axis=1)[has_data] / sequence_lengths[has_data]

    return sequence_averages.mean()

# mae() does not modify its inputs, so the arrays are passed directly.
# The name male_k is kept as-is because the heatmap cell below reads it.
male_k = mae(pred_k, gt_k)
mae_alpha = mae(pred_a, gt_a)

print(mae_alpha, male_k)

In [None]:
# # Usage example:
# # For alpha:
# fig, ax = create_heatmap(
#     gt_a.flatten(), 
#     pred_a.flatten(), 
#     mae_alpha,
#     param_name='alpha',
#     value_range=(0, 2),
#     output_dir=ROOT_DIR,
#     color="alpha_color"
# )

# Heatmap for K: flattened ground truth vs. flattened predictions, with the K MAE (male_k)
# passed as the third argument.
# NOTE(review): the arrays are flattened without removing padded positions here —
# confirm that create_heatmap handles padding itself.
fig, ax = create_heatmap(
    gt_k.flatten(), 
    pred_k.flatten(), 
    male_k,
    param_name='K',
    value_range=(0, 3),
    output_dir=ROOT_DIR,
    color="k_color"
)

# Confusion Matrix 

In [None]:
# Normalized confusion matrix for the state predictions, its plot (ROOT_DIR is passed to the
# plotting helper), and the metrics computed from the same prediction/ground-truth pair.
cm_normalized = calculate_normalized_confusion_matrix(pred_state, gt_state)
ax = plot_confusion_matrix(cm_normalized, ROOT_DIR)
metrics = calculate_metrics(pred_state, gt_state)

# Jaccard Plots Alpha/K Publication Level

In [None]:
# Placeholder paths: replace both with real file locations before enabling the call below.
file_path = 'path to jaccard.json'
counter_path = 'path to counter.json'
# fig, ax = create_jaccard_plot_new(file_path, counter_path, parameter_type='alpha')
# plt.show()

# Example Tracks by model

In [None]:
# Pick the compute device: prefer the second GPU, fall back to the first when only one
# is present, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")

AlphaModel = RegressionModel().to(DEVICE)
KModel = RegressionModel().to(DEVICE)
StateModel = ClassificationModel().to(DEVICE)

# map_location remaps the stored tensors onto DEVICE, so weights saved on a different
# device (e.g. a GPU that does not exist on this machine) still load.
AlphaModel.load_state_dict(torch.load("models/optimal_weights/alpha_weights_with_fixed", map_location=DEVICE))
KModel.load_state_dict(torch.load("models/optimal_weights/k_weights", map_location=DEVICE))
StateModel.load_state_dict(torch.load("models/optimal_weights/state_weights", map_location=DEVICE))

# Inference only: switch every model to evaluation mode
AlphaModel.eval()
KModel.eval()
StateModel.eval()

def getPredictions(df, max_size=200):
    """Run the alpha, K and state models on one trajectory and post-process the outputs.

    Args:
        df: Table with "x" and "y" columns holding the trajectory coordinates.
        max_size: Length the feature sequence is padded or truncated to before inference.

    Returns:
        Tuple ``(pred_alpha, pred_k, pred_states)`` of NumPy arrays. Alpha and K are
        smoothed, median-filtered and clipped to [0, 2] and [0, 6] respectively; the
        states are the argmax over the last dimension of the state model output.
    """
    raw_features = getFeatures(df["x"].values, df["y"].values)
    # Non-finite feature values are replaced by 0 before building the tensor
    clean_features = np.nan_to_num(raw_features, nan=0.0, posinf=0.0, neginf=0.0)
    features = torch.tensor(clean_features, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    length = features.size(1)

    if length > max_size:
        features = features[:, :max_size]
        print(f"Note that the input series is longer than the maximum size. The input series has been truncated to the first {max_size} values.")
    elif length < max_size:
        # Pad the end of dimension 1 up to max_size with the feature padding value
        features = F.pad(features, (0, 0, 0, max_size - length), value=FEATURE_PADDING_VALUE)

    def _as_track_array(output):
        # Flatten the model output to a NumPy array and keep the first `length` entries
        return output.cpu().numpy().flatten().squeeze()[:length]

    with torch.no_grad():
        alpha_raw = _as_track_array(AlphaModel(features))
        k_raw = _as_track_array(KModel(features))
        state_raw = _as_track_array(torch.argmax(StateModel(features), dim=-1))

    alpha_smoothed = median_filter_1d(smooth_series(np.array(alpha_raw)))
    k_smoothed = median_filter_1d(smooth_series(np.array(k_raw)))

    pred_alpha_list = np.clip(alpha_smoothed, 0, 2)
    pred_k_list = np.clip(k_smoothed, 0, 6)
    pred_states_list = np.array(state_raw)

    return pred_alpha_list, pred_k_list, pred_states_list



Single State

In [36]:
def single_state(index=0):
    """Generate a dataset from `_get_dic_andi2(1)` and return one short trajectory.

    Args:
        index: traj_idx to return when no trajectory has 25 rows or fewer.

    Returns:
        The rows of the first trajectory (in order of appearance) with at most 25 rows,
        or the rows of trajectory ``index`` if there is none.
    """
    dic = _get_dic_andi2(1)
    dic['T'] = 200
    dic['N'] = 100
    # dic['alphas'], dic['Ds'] = np.array([0.01, 0]), np.array([500, 0])

    trajectories, _, _ = challenge_phenom_dataset(
        save_data=False,
        dics=[dic],
        return_timestep_labs=True,
        get_video=False,
    )
    df = trajectories[0]

    # First trajectory id whose track has at most 25 rows; keep the given index otherwise
    short_ids = (
        traj_id
        for traj_id in df["traj_idx"].unique()
        if len(df[df["traj_idx"] == traj_id]) <= 25
    )
    index = next(short_ids, index)

    return df[df["traj_idx"] == index]

In [None]:
# Single-state example: generate a track, plot its x-y path, save it to CSV, then run the
# models on it and plot the predictions.
example_track = single_state() 
plt.plot(example_track["x"], example_track["y"])
plt.figure()  # opens a new, empty figure after the path plot
example_track.to_csv(f"short_trajectory_single_example.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

MultiState

In [None]:
def multi_state():
    """Generate a dataset from `_get_dic_andi2(2)` and return one short two-changepoint track.

    Returns:
        The rows of the first trajectory (in order of appearance) whose "D" series has
        exactly two changepoints (interior entries of `get_cps_gt`) and at most 40 rows.

    Raises:
        ValueError: If no generated trajectory satisfies both conditions.
    """
    dic = _get_dic_andi2(2)
    dic['T'] = 200
    dic['N'] = 100
    dic['M'] = np.array([[0.5, 0.5],[0.5, 0.5]])
    dic['alphas'] = np.array([[1,0],[0.5,0]])
    dic["Ds"] = np.array([[1,0],[0.1,0]])

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)

    df = dfs_traj[0]

    for traj_idx in df["traj_idx"].unique():
        # Select the track once and reuse it for both checks
        track = df[df["traj_idx"] == traj_idx]
        d_values = track["D"].values
        if len(get_cps_gt(d_values)[1:-1]) == 2 and len(d_values) <= 40:
            return track

    # Fail with a clear message instead of an unbound-variable error
    raise ValueError(
        "No generated trajectory has exactly 2 changepoints and at most 40 steps; "
        "re-run to generate a new dataset."
    )

In [None]:
# Multi-state example: generate a track, run the models on it, save it to CSV and plot
# the predictions.
example_track = multi_state() 
pa, pk, ps = getPredictions(example_track)
example_track.to_csv(f"short_bad_trajectory_multi_example.csv", index=False)
create_plots_for_track_examples(example_track, pa, pk, ps)

Confined

In [None]:
def confined():
    """Generate a dataset from `_get_dic_andi2(5)` and return one single-changepoint track.

    Returns:
        The rows of the first trajectory (in `groupby("traj_idx")` order) whose "D" series
        has exactly one changepoint (interior entries of `get_cps_gt`).

    Raises:
        ValueError: If no generated trajectory has exactly one changepoint.
    """
    dic = _get_dic_andi2(5)
    dic['T'] = 200
    dic['N'] = 100

    Ds_array = np.array([[1.5, 0.1],[0.3,0.1]])
    alphas_array = np.array([[0.9, 0.1],[0.5,0.1]])
    dic['Ds'] = Ds_array
    dic['alphas'] = alphas_array
    dic['trans'] = 0
    dic['Nc'] = random.randint(30,35)
    dic['r'] = random.randint(5, 10)

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)
    df = dfs_traj[0]

    for traj_idx, group in df.groupby("traj_idx"):
        if len(get_cps_gt(group["D"].values)[1:-1]) == 1:
            return df[df["traj_idx"] == traj_idx]

    # Fail with a clear message instead of an unbound-variable error
    raise ValueError(
        "No generated trajectory has exactly 1 changepoint; re-run to generate a new dataset."
    )

In [None]:
# Confined example: generate a track, run the models on it, save it to CSV and plot
# the predictions.
example_track = confined() 
pa, pk, ps = getPredictions(example_track)
example_track.to_csv(f"bad_trajectory_confined_example.csv", index=False)
create_plots_for_track_examples(example_track, pa, pk, ps)

Immobile

In [None]:
def immobile():
    """Generate a dataset from `_get_dic_andi2(3)` and return one five-changepoint track.

    Returns:
        The rows of the first trajectory (in `groupby("traj_idx")` order) whose "D" series
        has exactly five changepoints (interior entries of `get_cps_gt`).

    Raises:
        ValueError: If no generated trajectory has exactly five changepoints.
    """
    dic = _get_dic_andi2(3)
    dic['T'] = 200
    dic['N'] = 100

    dic['Ds'], dic['alphas'] = np.array([1, 0]), np.array([1, 0])
    dic['Pu'] = random.uniform(0, 0.1)
    dic['Pb'] = 1
    dic['r'] = random.uniform(0.5, 2)
    dic['Nt'] = random.randint(100, 300)

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)
    df = dfs_traj[0]

    for traj_idx, group in df.groupby("traj_idx"):
        if len(get_cps_gt(group["D"].values)[1:-1]) == 5:
            return df[df["traj_idx"] == traj_idx]

    # Fail with a clear message instead of an unbound-variable error
    raise ValueError(
        "No generated trajectory has exactly 5 changepoints; re-run to generate a new dataset."
    )


In [None]:
# Immobile example: generate a track, save it to CSV, run the models on it and plot
# the predictions.
example_track = immobile() 
example_track.to_csv(f"bad_trajectory_confined_immobile.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

Dimerised

In [50]:
def dimerised(index=0):
    """Generate a dataset from `_get_dic_andi2(4)` and return one trajectory.

    Args:
        index: traj_idx of the trajectory to return.

    Returns:
        The rows of the generated table whose "traj_idx" equals ``index``.
    """
    dic = _get_dic_andi2(4)
    dic['T'] = 200
    dic['N'] = 100

    # Both rows of each parameter table are [1, 0]; the two tables are separate arrays
    dic['Ds'] = np.array([[1.0, 0.0], [1.0, 0.0]])
    dic['alphas'] = np.array([[1.0, 0.0], [1.0, 0.0]])

    # Random draws are made in this order: Pu first, then r
    dic['Pu'] = random.uniform(0, 0.1)
    dic['Pb'] = 1
    dic['r'] = random.uniform(0.5, 5)

    trajectories, _, _ = challenge_phenom_dataset(
        save_data=False,
        dics=[dic],
        return_timestep_labs=True,
        get_video=False,
    )
    table = trajectories[0]
    selected = table["traj_idx"] == index
    return table[selected]

In [None]:
# Dimerised example: generate a track, save it to CSV, run the models on it and plot
# the predictions.
example_track = dimerised() 
example_track.to_csv(f"trajectory_confined_dimerised.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

# Plots from Tensorboard 

In [None]:
# Example usage:
# plot_tensorboard_metrics('train_loss.csv', 'val_loss.csv', 'training_plot.svg')
# plot_tensorboard_metrics("Losses_Training_K_Loss.csv", "Losses_Validation_K_Loss.csv", save_path="k_train_plots.svg")

# Example track ruptures

In [None]:
# Run the function
# df = multi_state()
# data = np.load("plots/true_k_ruptures_example.npy")
# print(len(data))
# fig, ax = plot_k_prediction_with_ruptures(df, save_path='k_prediction_ruptures_new.svg')

# Run the function
# df = multi_state()
# fig, ax = plot_k_prediction_with_ruptures(df, save_path='k_prediction_ruptures.svg')
# print(f"{fig}, {ax}")

In [None]:
# # For alpha variable
# plot_time_series_prediction(time_points, alpha_pred, alpha_truth, 'alpha', 
#                           save_path='alpha_ruptures.svg')

# # For K variable
# plot_time_series_prediction(time_points, k_pred, k_truth, 'K',
#                           save_path='k_ruptures.svg')

# # For state variable
# plot_time_series_prediction(time_points, state_pred, state_truth, 'state',
#                           save_path='state_ruptures.svg')