In [4]:
from tqdm import tqdm
from utils.postprocessing import smooth_series, median_filter_1d
import ruptures as rpt
from sklearn.metrics import confusion_matrix
from andi_datasets.utils_challenge import *
import torch
from utils.pad_batch import FEATURE_PADDING_VALUE
from utils.features import getFeatures
from utils.postprocessing import replace_short_sequences, combined_cps_k_focused, combined_cps_k_focused_with_state
from models.ClassificationModel import ClassificationModel
from models.RegressionModel import RegressionModel
import torch.nn.functional as F
from andi_datasets.datasets_challenge import _get_dic_andi2, challenge_phenom_dataset
import random
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import seaborn as sns
import os
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from scipy import stats
import pandas as pd 

# Value used to mark padded positions in the label arrays
# (see padding_starts_index below).
LABEL_PADDING_VALUE = 99
NUM_CLASSES = 4
# Run on CUDA device index 1 when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Line colours shared by the alpha / K / state plots.
ALPHA_COLOR = '#E69F00'
K_COLOR = '#1B9E77'
STATE_COLOR = '#9970AB'

AlphaModel = RegressionModel().to(DEVICE)
KModel = RegressionModel().to(DEVICE)
StateModel = ClassificationModel().to(DEVICE)

# map_location makes torch.load deserialize the checkpoint tensors directly
# onto DEVICE. Without it the tensors are restored to the device they were
# saved from, which fails on a CPU-only machine and otherwise stages the
# weights on a GPU the models do not live on.
AlphaModel.load_state_dict(
    torch.load("models/optimal_weights/alpha_weights_with_fixed", map_location=DEVICE)
)
KModel.load_state_dict(
    torch.load("models/optimal_weights/k_weights", map_location=DEVICE)
)
StateModel.load_state_dict(
    torch.load("models/optimal_weights/state_weights", map_location=DEVICE)
)

# Put all three models into evaluation mode.
AlphaModel.eval()
KModel.eval()
StateModel.eval()

def padding_starts_index(array):
    """Return the index at which padding starts in ``array``.

    Args:
        array: sequence of labels that may be right-padded with
            ``LABEL_PADDING_VALUE`` (assumed to be 1-D — TODO confirm).

    Returns:
        int: position of the first ``LABEL_PADDING_VALUE`` entry, or
        ``len(array)`` when the array contains no padding, so that
        ``array[:idx]`` is always the unpadded prefix.

    The previous implementation relied on ``argmax() == 0`` to detect
    "no padding" and then returned a hard-coded 200. That only worked for
    arrays of length 200 and treated an array whose first element is
    padding as if it were not padded at all.
    """
    array = np.asarray(array)
    padded_positions = np.flatnonzero(array == LABEL_PADDING_VALUE)
    if padded_positions.size == 0:
        return len(array)
    return int(padded_positions[0])

def getCP_pred_state(array):
    """Return change points of a predicted state sequence.

    Positions 3 .. len(array) - 3 are scanned; a position is recorded when
    its value differs from the previous one and it lies at least three
    steps after the last recorded change point. The result always starts
    with 0 and ends with ``len(array)``.
    """
    n = len(array)
    min_gap = 3
    change_points = [0]

    # The first three and the last two positions are never reported.
    for pos in range(3, n - 2):
        if array[pos - 1] != array[pos] and pos - change_points[-1] >= min_gap:
            change_points.append(pos)

    change_points.append(n)
    return change_points


def getCP_gt(array):
    """Return every index where ``array`` changes value.

    The list is bracketed by 0 at the front and ``len(array)`` at the end.
    """
    n = len(array)
    switches = [pos for pos in range(1, n) if array[pos] != array[pos - 1]]
    return [0, *switches, n]



In [None]:
# Sanity check of the change-point error metric on a hand-written example:
# two ground-truth change points and two predicted ones.
gt= [100, 109]
pred= [104, 130]

# single_changepoint_error is presumably provided by the
# `andi_datasets.utils_challenge` star import — TODO confirm.
rmse, jaccard_value = single_changepoint_error(gt, pred)
print(rmse)

# Load Predictions and GroundTruth

In [None]:
# Directory holding the saved prediction / ground-truth arrays. Uncomment one
# of the alternatives to analyse a different result set.
# ROOT_DIR = "/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/single_state_alpha_fix_rest_results"
# ROOT_DIR = "/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/single_state_k_values_all_point_based_results"
ROOT_DIR = "/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/test_set"
# ROOT_DIR = "/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/single_state_length_values_results"

# Load whichever of the six expected .npy files are present in ROOT_DIR.
# A file that is missing simply leaves its variable undefined.
for file in os.listdir(ROOT_DIR):
    path = os.path.join(ROOT_DIR, file)

    if file == "gt_a.npy":
        gt_a = np.load(path)
    elif file == "pred_a.npy":
        pred_a = np.load(path)
    elif file == "gt_k.npy":
        gt_k = np.load(path)
    elif file == "pred_k.npy":
        pred_k = np.load(path)
    elif file == "gt_state.npy":
        gt_state = np.load(path)
    elif file == "pred_state.npy":
        pred_state = np.load(path)

def getCP_rpt(array, lower_limit=0, upper_limit=float("inf"), threshold=0.05):
    """Detect change points in a predicted series with the PELT algorithm.

    The series is passed through ``smooth_series`` (with the given limits)
    and ``median_filter_1d``, min-max scaled (a constant series is replaced
    by 0.5 everywhere) and segmented with ``rpt.Pelt`` (l2 cost, minimum
    segment size 3, penalty 0.3). Interior breakpoints whose neighbouring
    segment means, measured on the filtered but unscaled series, differ by
    less than ``threshold`` are discarded.

    Returns:
        tuple: ``(cps, array)`` where ``cps`` is the list of remaining
        breakpoints prefixed with 0 and ``array`` is the filtered series.
    """
    array = smooth_series(array, lower_limit=lower_limit, upper_limit=upper_limit)
    array = median_filter_1d(array)

    lowest = np.min(array)
    highest = np.max(array)
    if highest == lowest:
        # Flat series: nothing to scale, use the neutral value 0.5.
        scaled = np.full(len(array), 0.5)
    else:
        scaled = (array - lowest) / (highest - lowest)

    detector = rpt.Pelt(model="l2", min_size=3, jump=1).fit(scaled)
    cps = [0] + detector.predict(pen=0.3)

    # Collect interior breakpoints that separate two almost identical levels.
    weak = set()
    for before, current, after in zip(cps, cps[1:], cps[2:]):
        left_level = array[before:current].mean()
        right_level = array[current:after].mean()
        if abs(left_level - right_level) < threshold:
            weak.add(current)

    kept = [cp for cp in cps if cp not in weak]
    return kept, array

def count_changepoints(arr):
    """Count the positions at which ``arr`` differs from its predecessor."""
    return sum(1 for pos in range(1, len(arr)) if arr[pos] != arr[pos - 1])

# Quick look at the shape of the gt_a array loaded above.
print(gt_a.shape)


def plot_jaccard_cps(d_cps, label_type="alpha"):
    """Plot the mean Jaccard value against the number of changepoints.

    Args:
        d_cps: mapping from the number of changepoints (int-like key such as
            ``"3"``) to the list of per-track Jaccard values. Keys whose list
            is empty are skipped.
        label_type: ``"alpha"``, ``"k"`` or ``"state"`` selects the line
            colour (anything else is drawn in black); it is also embedded in
            the output file names.

    Side effects:
        Prints a per-key summary, saves the averaged values to
        ``jaccard_values_<label_type>_for_graph.npy`` and writes
        ``jaccard_vs_changepoints_<label_type>.svg`` to the working directory.

    Raises:
        ValueError: if every list in ``d_cps`` is empty.
    """
    # Close figures left open by earlier cells. The plot itself is drawn on a
    # standalone Figure, so no pyplot figure is needed; the former plt.clf()
    # calls only created an empty one.
    plt.close('all')

    COLOR = {
        "alpha": ALPHA_COLOR,
        "k": K_COLOR,
        "state": STATE_COLOR,
    }.get(label_type, "black")

    # Average Jaccard value for each number of changepoints. Empty groups are
    # skipped so they cannot raise a ZeroDivisionError.
    results = {}
    for key, items in d_cps.items():
        if len(items) == 0:
            continue
        results[key] = {
            'average_jaccard': sum(items) / len(items),
            'num_tracks': len(items),
            'all_values': items,
        }
    if not results:
        raise ValueError("d_cps contains no Jaccard values to plot")

    # Sort numerically so that e.g. "10" is placed after "9".
    ordered_keys = sorted(results, key=int)

    print("\nResults by number of changepoints:")
    for key in ordered_keys:
        print(f"\nNumber of changepoints: {key}")
        print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
        print(f"Number of tracks: {results[key]['num_tracks']}")

    n_cps = [int(key) for key in ordered_keys]
    jaccard_values = [results[key]['average_jaccard'] for key in ordered_keys]

    # Save the averaged values for future use
    np.save("jaccard_values_"+str(label_type)+"_for_graph.npy", np.array(jaccard_values))

    # Standalone figure rendered through the SVG canvas
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Line with white-filled markers outlined in the series colour
    ax.plot(n_cps, jaccard_values, color=COLOR,
            linewidth=2, alpha=1,
            marker='o', markersize=8,
            markerfacecolor='white',
            markeredgecolor=COLOR,
            markeredgewidth=2)

    # y-axis limits with a 10 % margin; fall back to a fixed margin when all
    # averages are identical so the axis does not collapse to a single value.
    ymin = min(jaccard_values)
    ymax = max(jaccard_values)
    margin = (ymax - ymin) * 0.1
    if margin == 0:
        margin = 0.05
    ax.set_ylim(ymin - margin, ymax + margin)

    # One tick per changepoint count
    ax.set_xticks(n_cps)

    ax.set_axisbelow(True)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw strings keep the mathtext backslashes from being read as
    # (invalid) Python escape sequences.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    fig.tight_layout(pad=1.0)

    # Transparent background
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('jaccard_vs_changepoints_'+str(label_type)+'.svg', bbox_inches='tight',
                        pad_inches=0.1, format='svg')

    return



def plot_jaccard_position(d_positions, label_type="alpha", bin_size=20):
    """Draw the mean Jaccard value of each position bin as an SVG line plot.

    ``d_positions`` maps the start of a bin (int-like key) to a list of
    Jaccard values; bins with no values are left out. Points are placed at
    the centre of their bin (``start + bin_size / 2``). The averaged values
    are also saved to ``jaccard_values_position_<label_type>_for_graph.npy``
    and the figure to ``jaccard_vs_position_<label_type>.svg``.
    """
    plt.close('all')
    plt.clf()

    # Series colour: one of the module palette entries, black otherwise.
    palette = {"alpha": ALPHA_COLOR, "k": K_COLOR, "state": STATE_COLOR}
    line_color = palette.get(label_type, "black")

    # Mean Jaccard value per non-empty bin.
    bin_means = {}
    for bin_start, values in d_positions.items():
        if items_present := values:
            bin_means[bin_start] = sum(items_present) / len(items_present)

    ordered_bins = sorted(bin_means.keys(), key=int)
    x_values = [int(bin_start) + bin_size/2 for bin_start in ordered_bins]
    y_values = [bin_means[bin_start] for bin_start in ordered_bins]

    np.save(f"jaccard_values_position_{label_type}_for_graph.npy", np.array(y_values))

    # Wider than the other plots to leave room for the rotated tick labels.
    fig = Figure(figsize=(10, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    marker_style = dict(marker='o', markersize=8,
                        markerfacecolor='white',
                        markeredgecolor=line_color,
                        markeredgewidth=2)
    ax.plot(x_values, y_values, color=line_color, linewidth=2, alpha=1, **marker_style)

    # Vertical range padded by 10 % of the data span.
    lowest, highest = min(y_values), max(y_values)
    pad = (highest - lowest) * 0.1
    ax.set_ylim(lowest - pad, highest + pad)

    # Ticks on the bin edges, rotated for legibility.
    edges = list(range(0, 201, bin_size))
    ax.set_xticks(edges)
    ax.set_xticklabels(edges, rotation=45, ha='right')

    ax.set_axisbelow(True)
    ax.grid(True, linestyle='--', alpha=0.7)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for border in ax.spines.values():
        border.set_visible(True)
        border.set_linewidth(1)
        border.set_color('black')

    ax.set_xlabel('Position', fontsize=32, labelpad=20)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    # Reserve the bottom 10 % of the figure for the rotated labels.
    fig.tight_layout(pad=1.0, rect=[0, 0.1, 1, 1])

    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure(f'jaccard_vs_position_{label_type}.svg',
                        bbox_inches='tight',
                        pad_inches=0.1,
                        format='svg')

    return


def plot_mae_cps(d_cps, label_type="alpha"):
    """Plot the mean per-track error against the number of changepoints.

    Args:
        d_cps: mapping from the number of changepoints (int-like key) to the
            list of per-track error values. Keys with an empty list are
            skipped.
        label_type: ``"alpha"``, ``"k"`` or ``"state"`` selects the colour
            (anything else is black). ``"k"`` additionally fixes the y-range
            to 0 .. 0.06 and switches the y-label.

    Side effects:
        Saves the averages to ``mae_values_<label_type>_for_graph.npy`` and
        writes ``plots/mae_vs_changepoints_<label_type>.svg`` (the ``plots``
        directory is created when missing).

    Raises:
        ValueError: if every list in ``d_cps`` is empty.
    """
    # Close figures left open by earlier cells. The plot is drawn on a
    # standalone Figure, so the former plt.clf() (which only created an empty
    # pyplot figure) is not needed.
    plt.close('all')

    COLOR = {
        "alpha": ALPHA_COLOR,
        "k": K_COLOR,
        "state": STATE_COLOR,
    }.get(label_type, "black")

    # Average error for each number of changepoints; empty groups are skipped
    # so they cannot raise a ZeroDivisionError.
    results = {}
    for key, items in d_cps.items():
        if len(items) == 0:
            continue
        results[key] = {
            'average_mae': sum(items) / len(items),
            'num_tracks': len(items),
            'all_values': items,
        }
    if not results:
        raise ValueError("d_cps contains no error values to plot")

    # Sort numerically so that e.g. "10" is placed after "9".
    ordered_keys = sorted(results, key=int)
    n_cps = [int(key) for key in ordered_keys]
    mae_values = [results[key]['average_mae'] for key in ordered_keys]

    np.save(f"mae_values_{label_type}_for_graph.npy", np.array(mae_values))

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    ax.plot(n_cps, mae_values, color=COLOR,
            linewidth=2, alpha=1,
            marker='o', markersize=8,
            markerfacecolor=COLOR,
            markeredgecolor=COLOR,
            markeredgewidth=2)

    # Axis limits and y-label
    if label_type == "k":
        ymin, ymax = 0, 0.06
        # NOTE(review): plot_mae_cps_combined labels the same series
        # 'MALE(K)' — confirm which name is intended.
        ax.set_ylabel(r'MSLE(K)', fontsize=32, rotation=90, va='center', labelpad=20)
    else:
        ymin, ymax = min(mae_values), max(mae_values)
        # \alpha must be inside $...$ to be rendered as the Greek letter.
        ax.set_ylabel(r'MAE($\alpha$)', fontsize=32, rotation=90, va='center', labelpad=20)

    margin = (ymax - ymin) * 0.1
    if margin == 0:
        # All averages identical: keep a visible range around the value.
        margin = 0.05
    ax.set_ylim(ymin - margin, ymax + margin)

    ax.set_xticks(n_cps)

    # Style
    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw string keeps the mathtext backslash from being read as an
    # (invalid) Python escape sequence.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    # The output directory may not exist yet on a fresh checkout.
    os.makedirs('plots', exist_ok=True)
    canvas.print_figure(f'plots/mae_vs_changepoints_{label_type}.svg',
                        bbox_inches='tight', pad_inches=0.1, format='svg')

    return


def plot_mae_cps_combined(d_cps_k, d_cps_alpha):
    """Plot the mean K and alpha errors against the number of changepoints.

    Args:
        d_cps_k: mapping from number of changepoints (int-like key) to the
            list of per-track error values for K.
        d_cps_alpha: same mapping for alpha.
        Keys with an empty list are skipped in both mappings.

    Returns:
        tuple: the ``(fig, ax)`` pair the plot was drawn on.

    Side effects:
        Saves ``mae_values_k_for_graph.npy`` and
        ``mae_values_alpha_for_graph.npy`` and writes
        ``plots/mae_vs_changepoints_combined_k_alpha_new.svg`` (the ``plots``
        directory is created when missing).
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # x positions used by either series; ticks are placed on their union
    # instead of only on the positions of the series plotted last.
    all_n_cps = set()

    for data, label_type in [(d_cps_k, "k"), (d_cps_alpha, "alpha")]:
        COLOR = K_COLOR if label_type == "k" else ALPHA_COLOR

        # Average per changepoint count; empty groups are skipped so they
        # cannot raise a ZeroDivisionError.
        averages = {
            key: sum(items) / len(items)
            for key, items in data.items()
            if len(items) > 0
        }

        # Sort numerically so that e.g. "10" is placed after "9".
        ordered_keys = sorted(averages, key=int)
        n_cps = [int(key) for key in ordered_keys]
        mae_values = [averages[key] for key in ordered_keys]
        all_n_cps.update(n_cps)

        np.save(f"mae_values_{label_type}_for_graph.npy", np.array(mae_values))

        # NOTE(review): plot_mae_cps labels the K series 'MSLE(K)' —
        # confirm which name is intended.
        ax.plot(n_cps, mae_values, color=COLOR,
                linewidth=2, alpha=1,
                marker='o', markersize=8,
                markerfacecolor=COLOR,
                markeredgecolor=COLOR,
                markeredgewidth=2,
                label=r'MALE(K)' if label_type == "k" else r'MAE($\alpha$)')

    ax.set_ylim(0, 0.14)
    ax.set_yticks([0, 0.04, 0.08, 0.12])
    ax.set_xticks(sorted(all_n_cps))

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw string keeps the mathtext backslash from being read as an
    # (invalid) Python escape sequence.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    # ax.set_ylabel('MAE', fontsize=32, rotation=90, va='center', labelpad=20)

    ax.legend(fontsize=28, frameon=False)

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    # The output directory may not exist yet on a fresh checkout.
    os.makedirs('plots', exist_ok=True)
    canvas.print_figure('plots/mae_vs_changepoints_combined_k_alpha_new.svg',
                        bbox_inches='tight', pad_inches=0.1, format='svg')

    return fig, ax

# Mean Absolute Error For K for Various CPs

In [None]:
# Per-track mean absolute error of the K predictions, grouped by the number
# of changepoints (0-5) found in the ground truth.
d_cps_k = {str(i): [] for i in range(6)}  # Error values per changepoint count

predictions = pred_k.copy()
ground_truth = gt_k.copy()
tracks_per_cp = {str(i): 0 for i in range(6)}  # Track count for each number of changepoints
total_tracks = 0

for i in range(len(ground_truth)):
    # Early stop once every changepoint count has 100 tracks (disabled)
    # if all(count >= 100 for count in tracks_per_cp.values()):
    #     break

    # Strip the padding from both series
    idx = padding_starts_index(ground_truth[i])
    p = predictions[i][:idx]
    g = ground_truth[i][:idx]

    no_of_cps = count_changepoints(g)

    # Only full-length (200) tracks with 0-5 changepoints are used, capped at
    # 10000 tracks per changepoint count
    # if len(g) == 200: 
    if len(g)==200 and no_of_cps in range(6) and tracks_per_cp[str(no_of_cps)] < 10000:
        total_tracks += 1
        # Same smoothing / median filtering as in getCP_rpt, limited to [0, 6]
        p = median_filter_1d(smooth_series(p, lower_limit=0, upper_limit=6))
        mae = np.mean(np.abs(p - g))

        d_cps_k[str(no_of_cps)].append(mae)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

# mean_mae = {k: np.mean(v) if v else 0 for k, v in d_cps.items()}
# print(mean_mae)
plot_mae_cps(d_cps_k, label_type="k")
# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


# Mean Absolute Error Alpha for Various CPs

In [None]:
# Per-track mean absolute error of the alpha predictions, grouped by the
# number of changepoints (0-5) found in the ground truth.
d_cps_alpha = {str(i): [] for i in range(6)}  # Error values per changepoint count

predictions = pred_a.copy()
ground_truth = gt_a.copy()
tracks_per_cp = {str(i): 0 for i in range(6)}  # Track count for each number of changepoints
total_tracks = 0

for i in range(len(ground_truth)):
    # Early stop once every changepoint count has 100 tracks (disabled)
    # if all(count >= 100 for count in tracks_per_cp.values()):
    #     break

    # Strip the padding from both series
    idx = padding_starts_index(ground_truth[i])
    p = predictions[i][:idx]
    g = ground_truth[i][:idx]

    no_of_cps = count_changepoints(g)

    # Only full-length (200) tracks with 0-5 changepoints are used, capped at
    # 10000 tracks per changepoint count
    if len(g) == 200 and no_of_cps in range(6) and tracks_per_cp[str(no_of_cps)] < 10000:
        total_tracks += 1
        # Same smoothing / median filtering as in getCP_rpt, limited to [0, 1.999]
        p = median_filter_1d(smooth_series(p, lower_limit=0, upper_limit=1.999))
        mae = np.mean(np.abs(p - g))

        d_cps_alpha[str(no_of_cps)].append(mae)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

# mean_mae = {k: np.mean(v) if v else 0 for k, v in d_cps.items()}
# print(mean_mae)
plot_mae_cps(d_cps_alpha)
# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


In [None]:
# Draw the K and alpha error curves computed in the two cells above on one
# shared figure.
plot_mae_cps_combined(d_cps_k, d_cps_alpha)

# Alpha for Various CPs

In [None]:
# d_cps = {}
# delta_alpha = []
# jaccard_values = []

# def count_changepoints(arr):
#     # Initialize a counter for changepoints
#     changepoints = 0
    
#     # Loop through the array starting from the second element
#     for i in range(1, len(arr)):
#         # Check if the current element is different from the previous one
#         if arr[i] != arr[i - 1]:
#             changepoints += 1
    
#     return changepoints


# predictions = pred_a.copy()
# ground_truth = gt_a.copy()
# tracks = 0

# for i in range(len(ground_truth)):

#     idx = padding_starts_index(ground_truth[i])

#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     if len(g) == 200 and no_of_cps in [0,1,2,3,4,5]:
#         tracks += 1
        
#         if tracks > 2000:
#             break

#         cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
#         cp_gt = getCP_gt(g)

#         cp_pred = cp_pred[1:-1]
#         cp_gt = cp_gt[1:-1]

#         if cp_gt == cp_pred:
#             # no need to call the function if they are the same
#             jaccard_value = 1
#         else:
#             rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred) 
#         if str(no_of_cps) not in d_cps:
#             d_cps[str(no_of_cps)] = [jaccard_value]
#         else:
#             d_cps[str(no_of_cps)].append(jaccard_value)

#         print(tracks, end="\r")

# print(d_cps)

In [None]:
# Jaccard value of the alpha change-point detection, grouped by the number of
# ground-truth changepoints (0-5), using up to 100 tracks per group.
d_cps = {str(i): [] for i in range(6)}  # Jaccard values per changepoint count

predictions = pred_a.copy()
ground_truth = gt_a.copy()
tracks_per_cp = {str(i): 0 for i in range(6)}  # Track count for each number of changepoints
total_tracks = 0

for i in range(len(ground_truth)):
    # Stop once every changepoint count has collected 100 tracks
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Strip the padding from both series
    idx = padding_starts_index(ground_truth[i])
    p = predictions[i][:idx]
    g = ground_truth[i][:idx]

    no_of_cps = count_changepoints(g)

    # Only full-length (200) tracks whose group still needs more samples
    if len(g) == 200 and no_of_cps in range(6) and tracks_per_cp[str(no_of_cps)] < 100:
        total_tracks += 1

        cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
        cp_gt = getCP_gt(g)

        # Drop the first and last entry of each list (getCP_gt always puts 0
        # first and len(g) last) so only interior changepoints are compared
        cp_pred = cp_pred[1:-1]
        cp_gt = cp_gt[1:-1]

        # Identical lists (including two empty ones) count as a perfect match
        # without calling the metric
        if cp_gt == cp_pred:
            jaccard_value = 1
        else:
            rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

        d_cps[str(no_of_cps)].append(jaccard_value)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")


In [None]:
# Plot the Jaccard values collected in d_cps by the previous cell.
plot_jaccard_cps(d_cps, label_type="alpha")

In [None]:
# Count ground-truth changepoints that appear in the state series but in
# neither the alpha nor the K series.
# gt_a_data = gt_a.copy()
# gt_k_data = gt_k.copy()
# gt_state_data = gt_state.copy()
total_cps = 0
special_cases = 0  # number of state-only changepoints over all tracks
no_of_tracks_with_this = 0  # number of tracks with at least one of them
total_tracks = len(gt_a)

for i in range(len(gt_a)):    
    print(i, end="\r")
    # The padding offset of the alpha series is applied to all three series
    idx_a = padding_starts_index(gt_a[i])
    g_alpha = gt_a[i][:idx_a]
    g_k = gt_k[i][:idx_a]
    g_state = gt_state[i][:idx_a]

    cp_a = set(getCP_gt(g_alpha))
    cp_k = set(getCP_gt(g_k))
    cp_s = set(getCP_gt(g_state))

    # The boundary entries added by getCP_gt (0 and the series length) are in
    # all three sets and therefore cancel out here
    state_only_cps = cp_s - (cp_a | cp_k)

    if state_only_cps:  # If there are any such changepoints
        special_cases += len(state_only_cps)
        no_of_tracks_with_this += 1

    # NOTE(review): cp_k still contains the boundary entries that getCP_gt
    # adds (0 and the series length), so this total is larger than the number
    # of real K changepoints — confirm whether that is intended.
    total_cps += len(cp_k) 

print("special tracks", no_of_tracks_with_this, "out of", total_tracks)
print("cp places", special_cases, " out of cps", total_cps)


# RMSE for all combinations

In [None]:
def plot_rmse_cps(d_cps, label_type="alpha"):
    """Plot the mean change-point RMSE against the number of changepoints.

    Args:
        d_cps: mapping from the number of changepoints (int-like key) to the
            list of per-track RMSE values. NaN and infinite entries are
            ignored; a key without any finite value is plotted as NaN.
        label_type: tag embedded in the output file names. ``"alpha"``,
            ``"k"``, ``"state"`` and their ``"*_without_0"`` variants select
            the matching palette colour; any other tag is drawn in black.

    Side effects:
        Saves the averages to ``rmse_values_<label_type>_for_graph.npy`` and
        writes ``rmse_vs_changepoints_<label_type>.svg`` to the working
        directory.
    """
    # The plain names are accepted as well so that the default label_type
    # ("alpha") gets its palette colour instead of falling through to black.
    palette = {
        "alpha": ALPHA_COLOR, "alpha_without_0": ALPHA_COLOR,
        "k": K_COLOR, "k_without_0": K_COLOR,
        "state": STATE_COLOR, "state_without_0": STATE_COLOR,
    }
    COLOR = palette.get(label_type, "black")

    # Mean of the finite values per changepoint count (NaN when none exist).
    results = {}
    for key, items in d_cps.items():
        valid_items = [x for x in items if np.isfinite(x)]
        if valid_items:
            avg_rmse = sum(valid_items) / len(valid_items)
        else:
            avg_rmse = float('nan')
        results[key] = {'average_rmse': avg_rmse}

    # Sort numerically so that e.g. "10" is placed after "9".
    ordered_keys = sorted(results, key=int)
    n_cps = [int(key) for key in ordered_keys]
    rmse_values = [results[key]['average_rmse'] for key in ordered_keys]

    np.save("rmse_values_"+str(label_type)+"_for_graph.npy", np.array(rmse_values))

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    ax.plot(n_cps, rmse_values, color=COLOR,
            linewidth=2, alpha=1,
            marker='o', markersize=8,
            markerfacecolor='white',
            markeredgecolor=COLOR,
            markeredgewidth=2)

    # Derive the limits from the finite averages only: a NaN average must
    # not reach set_ylim, which rejects NaN limits.
    finite_values = [value for value in rmse_values if np.isfinite(value)]
    if finite_values:
        ymin = min(finite_values)
        ymax = max(finite_values)
        margin = (ymax - ymin) * 0.1
        if margin == 0:
            # All averages identical: keep a visible range around the value.
            margin = 0.05
        ax.set_ylim(ymin - margin, ymax + margin)

    ax.set_xticks(n_cps)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw string keeps the mathtext backslash from being read as an
    # (invalid) Python escape sequence.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    ax.set_ylabel('RMSE', fontsize=32, rotation=0, va='center', labelpad=20)

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('rmse_vs_changepoints_'+str(label_type)+'.svg',
                        bbox_inches='tight', pad_inches=0.1, format='svg')

    return

for alpha

In [None]:
# Change-point RMSE of the alpha predictions, grouped by the number of
# ground-truth changepoints (1-5), using up to 100 tracks per group.
d_cps = {str(i): [] for i in range(1, 6)}  # RMSE values per changepoint count

predictions = pred_a.copy()
ground_truth = gt_a.copy()
tracks_per_cp = {str(i): 0 for i in range(1, 6)}  # Track count for each number of changepoints
total_tracks = 0

for i in range(len(ground_truth)):
    # Stop once every changepoint count has collected 100 tracks
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Strip the padding from both series
    idx = padding_starts_index(ground_truth[i])
    p = predictions[i][:idx]
    g = ground_truth[i][:idx]

    no_of_cps = count_changepoints(g)

    # Only full-length (200) tracks whose group still needs more samples
    if len(g) == 200 and no_of_cps in range(1, 6) and tracks_per_cp[str(no_of_cps)] < 100:
        total_tracks += 1

        cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
        cp_gt = getCP_gt(g)

        # Drop the first and last entry of each list (getCP_gt always puts 0
        # first and len(g) last) so only interior changepoints are compared
        cp_pred = cp_pred[1:-1]
        cp_gt = cp_gt[1:-1]

        # Identical lists count as zero error without calling the metric
        if cp_gt == cp_pred:
            rmse = 0
        else:
            rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

        d_cps[str(no_of_cps)].append(rmse)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

# Mean of the finite RMSE values per changepoint count (NaN when none exist);
# this repeats the averaging done inside plot_rmse_cps.
results = {}
for key, items in d_cps.items():
    valid_items = [x for x in items if not (np.isnan(x) or np.isinf(x))]   
    if valid_items:
        avg_rmse = sum(valid_items) / len(valid_items)
    else:
        avg_rmse = float('nan')  # or handle empty case differently

    results[key] = {
        'average_rmse': avg_rmse,
    }

print(results)

# NOTE(review): with the call below commented out this cell does not write
# rmse_values_alpha_without_0_for_graph.npy, which plot_multiple_rmse_values
# loads — confirm the file is produced elsewhere.
# plot_rmse_cps(d_cps, label_type="alpha_without_0")

for K

In [None]:
# Change-point RMSE of the K predictions, grouped by the number of
# ground-truth changepoints (1-5), using up to 100 tracks per group.
d_cps = {str(i): [] for i in range(1, 6)}  # RMSE values per changepoint count

predictions = pred_k.copy()
ground_truth = gt_k.copy()
tracks_per_cp = {str(i): 0 for i in range(1, 6)}  # Track count for each number of changepoints
total_tracks = 0

for i in range(len(ground_truth)):
    # Stop once every changepoint count has collected 100 tracks
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Strip the padding from both series
    idx = padding_starts_index(ground_truth[i])
    p = predictions[i][:idx]
    g = ground_truth[i][:idx]

    no_of_cps = count_changepoints(g)

    # Only full-length (200) tracks whose group still needs more samples
    if len(g) == 200 and no_of_cps in range(1, 6) and tracks_per_cp[str(no_of_cps)] < 100:
        total_tracks += 1

        cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=6, threshold=0.05)
        cp_gt = getCP_gt(g)

        # Drop the first and last entry of each list (getCP_gt always puts 0
        # first and len(g) last) so only interior changepoints are compared
        cp_pred = cp_pred[1:-1]
        cp_gt = cp_gt[1:-1]

        # Identical lists count as zero error without calling the metric
        if cp_gt == cp_pred:
            rmse = 0
        else:
            rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

        d_cps[str(no_of_cps)].append(rmse)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

plot_rmse_cps(d_cps, label_type="k_without_0")


for s

In [None]:
# Change-point RMSE of the state predictions, grouped by the number of
# ground-truth changepoints (1-5), using up to 100 tracks per group.
d_cps = {str(i): [] for i in range(1, 6)}  # RMSE values per changepoint count
tracks_per_cp = {str(i): 0 for i in range(1, 6)}  # For counting tracks per changepoint
total_tracks = 0

predictions = pred_state.copy()
# NOTE(review): the state predictions are scored against the K ground truth
# (gt_k), not gt_state — confirm this is intended.
ground_truth = gt_k.copy()

# The loop variables are deliberately not named `gt` / `pd`: `pd` would
# overwrite the pandas alias imported at the top of the notebook and `gt`
# the list defined in an earlier cell.
for gt_track, pred_track in zip(ground_truth, predictions):
    # Stop once every changepoint count has collected 100 tracks
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Remove the padding and cut the prediction to the same length
    gt_track = gt_track[gt_track != LABEL_PADDING_VALUE]
    pred_track = pred_track[:len(gt_track)]

    # Only process sequences of length 200
    if len(gt_track) != 200:
        continue

    pred_track = replace_short_sequences(pred_track, min_length=3)

    # Keep only the interior changepoints (drop the leading 0 and the
    # trailing sequence length added by both helpers)
    cp_pred = getCP_pred_state(pred_track)[1:-1]
    cp_gt = getCP_gt(gt_track)[1:-1]
    no_of_cps = len(cp_gt)

    # Only tracks whose group still needs more samples are scored
    if no_of_cps in range(1, 6) and tracks_per_cp[str(no_of_cps)] < 100:
        total_tracks += 1

        # Identical lists count as zero error without calling the metric
        if cp_gt == cp_pred:
            rmse = 0
        else:
            rmse, _ = single_changepoint_error(cp_gt, cp_pred)

        d_cps[str(no_of_cps)].append(rmse)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

plot_rmse_cps(d_cps, label_type="state_without_0")

for alpha + K

In [None]:
# Initialize dictionaries to store results and track counts
# Change-point RMSE of the merged (alpha, K, state) changepoints, grouped by
# the number of ground-truth changepoints (1-5), up to 100 tracks per group.
d_cps = {str(i): [] for i in range(1, 6)}  # RMSE values per changepoint count
tracks_per_cp = {str(i): 0 for i in range(1, 6)}  # For counting tracks per changepoint
total_tracks = 0

# Create copies of all predictions and ground truth arrays
pred_a_data = pred_a.copy()
gt_a_data = gt_a.copy()
pred_k_data = pred_k.copy()
gt_k_data = gt_k.copy()
pred_state_data = pred_state.copy()
gt_state_data = gt_state.copy()

for i in range(len(gt_a_data)):
    # Check if we have 100 tracks for all changepoint numbers
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # The padding offset of the alpha ground truth is applied to every series
    idx_a = padding_starts_index(gt_a_data[i])

    # Get predictions and ground truth for each model
    p_alpha = pred_a_data[i][:idx_a]
    g_alpha = gt_a_data[i][:idx_a]

    p_k = pred_k_data[i][:idx_a]
    g_k = gt_k_data[i][:idx_a]

    p_state = pred_state_data[i][:idx_a]
    g_state = gt_state_data[i][:idx_a]

    # Only process if sequence length is 200
    if len(g_k) != 200:
        continue

    # Count changepoints in the K and alpha ground truth; tracks where the
    # two counts disagree are skipped
    no_of_cps = count_changepoints(g_k)
    no_of_cps_alpha = count_changepoints(g_alpha)

    if no_of_cps_alpha != no_of_cps:
        continue

    # Only tracks whose group still needs more samples are scored
    if no_of_cps in range(1, 6) and tracks_per_cp[str(no_of_cps)] < 100:

        total_tracks += 1

        # Merge the changepoints of the three predicted series
        merged_cps, _, _, _, _, _ = combined_cps_k_focused_with_state(p_alpha, p_k, p_state)
        # merged_cps, _, _, _, _, _ = combined_cps_k_focused(p_alpha, p_k, p_state)

        # Prepending 0 and then slicing [1:-1] leaves the returned list
        # without its final element
        merged_cps = [0] + merged_cps
        merged_cps = merged_cps[1:-1]

        # Ground-truth changepoints are taken from the K series
        cp_gt = getCP_gt(g_k)
        cp_gt = cp_gt[1:-1]  # Remove first and last points
        # Identical lists count as zero error without calling the metric
        if cp_gt == merged_cps:
            rmse = 0
        else:
            rmse, _ = single_changepoint_error(cp_gt, merged_cps)

        # Store results
        d_cps[str(no_of_cps)].append(rmse)
        tracks_per_cp[str(no_of_cps)] += 1

        # Print progress on a single, continuously overwritten line
        print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
        for k, v in tracks_per_cp.items():
            print(f"{k}:{v}", end=" ")
        print("", end="\r")

plot_rmse_cps(d_cps, label_type="alpha_k_state_focused_without_0")

for alpha + K + s

plot all

In [None]:
def plot_multiple_rmse_values():
    """Plot changepoint RMSE against the number of changepoints for each method.

    Loads five pre-computed RMSE arrays (K, α, s, α+K+s and α+K) from ``.npy``
    files in the current working directory, plots each one against
    x = 1..5 on a single axis, and writes the figure to
    ``all_methods_rmse_comparison_without_0.svg``.

    Returns:
        tuple: The ``(fig, ax)`` pair that was drawn and saved.

    Raises:
        Propagates whatever ``np.load`` raises when an input file is missing.
    """
    K_COLOR = '#1B9E77'
    ALPHA_COLOR = '#E69F00'
    STATE_COLOR = '#9970AB'
    ALPHA_K_FOCUSED_COLOR = 'black'
    AKS_COLOR = 'black'

    rmse_alpha_k_focused = np.load("rmse_values_alpha_k_focused_without_0_for_graph.npy")
    rmse_alpha_k_state_k_focused = np.load("rmse_values_alpha_k_state_focused_without_0_for_graph.npy")
    rmse_alpha = np.load("rmse_values_alpha_without_0_for_graph.npy")
    rmse_k = np.load("rmse_values_k_without_0_for_graph.npy")
    rmse_state = np.load("rmse_values_state_without_0_for_graph.npy")

    n_cps = list(range(1, 6))

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # (values, colour, legend label, linestyle, marker fill). The α+K+s line
    # shares its colour with α+K, so it is told apart by a dashed line and
    # hollow (white-filled) markers.
    lines = [
        (rmse_k, K_COLOR, "K", '-', K_COLOR),
        (rmse_alpha, ALPHA_COLOR, "α", '-', ALPHA_COLOR),
        (rmse_state, STATE_COLOR, "s", '-', STATE_COLOR),
        (rmse_alpha_k_state_k_focused, AKS_COLOR, "α+K+s", '--', 'white'),
        (rmse_alpha_k_focused, ALPHA_K_FOCUSED_COLOR, "α+K", '-', ALPHA_K_FOCUSED_COLOR),
    ]

    for rmse_values, color, label, linestyle, markerfacecolor in lines:
        ax.plot(n_cps, rmse_values,
                color=color,
                linewidth=2,
                alpha=1,
                marker='o',
                markersize=10,
                markerfacecolor=markerfacecolor,
                markeredgecolor=color,
                markeredgewidth=2,
                linestyle=linestyle,
                label=label)

    # Derive the y-limits from the finite values only, so NaN/inf entries in
    # the loaded arrays cannot poison the axis range.
    all_values = np.concatenate([
        rmse_alpha_k_focused,
        rmse_alpha_k_state_k_focused,
        rmse_alpha,
        rmse_k,
        rmse_state
    ])
    valid_values = all_values[np.isfinite(all_values)]

    if len(valid_values) > 0:
        ymin = np.min(valid_values)
        ymax = np.max(valid_values)
        margin = (ymax - ymin) * 0.1
        # A zero range would request identical lower/upper limits; leave the
        # axis on autoscale in that case.
        if margin > 0:
            # Extra headroom at the top leaves space for the legend.
            ax.set_ylim(ymin - margin, ymax + 1.5 * margin)

    ax.set_xticks(n_cps)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw string: "\m" is not a valid Python escape sequence and triggers a
    # SyntaxWarning on recent interpreters when written in a normal literal.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    ax.set_ylabel('RMSE', fontsize=32, rotation=90, va='center', labelpad=20)

    ax.legend(fontsize=28, frameon=True,
              loc='upper right',
              bbox_to_anchor=(0.98, 0.98),
              edgecolor='none',
              facecolor='white',
              framealpha=0.8)

    fig.tight_layout()
    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('all_methods_rmse_comparison_without_0.svg',
                        bbox_inches='tight',
                        pad_inches=0.1,
                        format='svg')

    return fig, ax

plot_multiple_rmse_values()

# Changepoint for Various Positions Alpha

In [None]:
# Gather Jaccard scores for alpha tracks that have exactly one ground-truth
# changepoint, bucketed by the position bin that changepoint falls into.
# Collection stops once every bin holds 100 tracks.
bin_size = 20  # Assuming this is your bin size
num_bins = 200 // bin_size
bin_labels = [str(b * bin_size) for b in range(num_bins)]
d_positions = {label: [] for label in bin_labels}
tracks_per_bin = dict.fromkeys(bin_labels, 0)
total_tracks = 0

predictions = pred_a.copy()
ground_truth = gt_a.copy()

for i, gt_row in enumerate(ground_truth):
    # Every bin is full: nothing left to collect.
    if not any(count < 100 for count in tracks_per_bin.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    # Only full-length (200 step) tracks are evaluated.
    if len(g) != 200:
        continue

    # Drop the first and last entries returned by getCP_gt.
    cp_gt = getCP_gt(g)[1:-1]
    if len(cp_gt) != 1:
        continue

    cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
    cp_pred = cp_pred[1:-1]

    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Skip unknown bins and bins that already have their 100 tracks.
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # Single-line progress readout, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per bin:",
          *(f"{k}:{v}" for k, v in tracks_per_bin.items()),
          "", end="\r")

In [5]:
# Save the current d_positions dict to disk as the alpha result.
# NOTE(review): a later cell loads this filename from the plots/ directory,
# while this writes to the working directory — confirm the file is moved or
# the two paths are made to agree.
import pickle 
with open('changepoint_position_jaccard_value_alpha.pkl', 'wb') as f:
    pickle.dump(d_positions, f)

In [None]:
# Plot the per-bin Jaccard values currently held in d_positions, tagged "alpha".
# NOTE(review): plot_jaccard_position is defined outside this section —
# confirm what it saves and how label_type is used.
plot_jaccard_position(d_positions, label_type="alpha", bin_size=20)

In [None]:
# d_positions = {}  # Dictionary to store jaccard values by position bins
# bin_size = 20  # Size of position bins
# predictions = pred_a.copy()
# ground_truth = gt_a.copy()
# tracks = 0

# for i in range(len(ground_truth)):
#     idx = padding_starts_index(ground_truth[i])
#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     # Only look at tracks with length 200 and exactly 1 changepoint
#     if len(g) == 200 and no_of_cps == 1:
#         tracks += 1
        
#         if tracks > 2000:
#             break

#         cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=1.999, threshold=0.05)
#         cp_gt = getCP_gt(g)

#         # Calculate Jaccard index
#         cp_pred = cp_pred[1:-1]
#         cp_gt = cp_gt[1:-1]

#         position_bin = (cp_gt[0] - 1) // bin_size * bin_size  # Subtract 1 to start from 0
#         bin_label = str(position_bin)

#         if cp_gt == cp_pred:
#             jaccard_value = 1
#         else:
#             rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

#         # Store in position-based dictionary
#         if bin_label not in d_positions:
#             d_positions[bin_label] = [jaccard_value]
#         else:
#             d_positions[bin_label].append(jaccard_value)

#         print(tracks, end="\r")

In [None]:
# tracks = 0
# for i in range(len(ground_truth)):
#     idx = padding_starts_index(ground_truth[i])
#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     # Only look at tracks with length 200 and exactly 1 changepoint
#     if len(g) == 200 and no_of_cps == 1:
#         tracks += 1
#         if tracks == 200:
#             plt.scatter([i for i in range(len(p))], p)  
#             plt.scatter([i for i in range(len(g))], g)
#             break
# plt.show()      

In [None]:
# Summarise d_positions (bin label -> list of Jaccard values) as one average
# per position bin, print the summary, and save a scatter plot of it.
results = {}
for key, items in d_positions.items():
    # Bins are pre-created with empty lists and may never receive a track;
    # skip them instead of dividing by zero.
    if not items:
        continue
    results[key] = {
        'average_jaccard': sum(items) / len(items),
        'num_tracks': len(items),
        'all_values': items
    }

# Bin labels are strings, so order them numerically rather than lexically.
sorted_keys = sorted(results, key=int)

print("\nResults by position bin:")
for key in sorted_keys:
    print(f"\nPosition bin {key}-{int(key)+bin_size}:")
    print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
    print(f"Number of tracks: {results[key]['num_tracks']}")

# Create figure
plt.figure(figsize=(10, 6))

# Extract x and y values for plotting
x_values = [int(key) + bin_size/2 for key in sorted_keys]  # Use bin centers
y_values = [results[key]['average_jaccard'] for key in sorted_keys]

# Create scatter plot
plt.scatter(x_values, y_values, color='blue', s=100)

# Add value labels above each point
for x, v in zip(x_values, y_values):
    plt.text(x, v + 0.01, f'{v:.3f}', ha='center')

plt.xlabel('Changepoint Position')
plt.ylabel('Average Jaccard Value')
plt.title('Average Jaccard Values by Changepoint Position For Alpha')
plt.grid(True, linestyle='--', alpha=0.7)

# Set x-ticks to show bin edges
bin_edges = range(0, 201, bin_size)
plt.xticks(bin_edges)

# Set y-axis limits to start from 0 and have some padding at the top.
# default=0 keeps this working when no bin contained any data.
plt.ylim(0, max(y_values, default=0) + 0.1)

plt.tight_layout()
plt.savefig("jaccard_vs_position_test_set_alpha.svg", format="svg")
plt.show()

# Changepoints for Various Positions K

In [None]:
# Gather Jaccard scores for K tracks that have exactly one ground-truth
# changepoint, bucketed by the position bin that changepoint falls into.
# Collection stops once every bin holds 100 tracks.
bin_size = 20
num_bins = 200 // bin_size
bin_labels = [str(b * bin_size) for b in range(num_bins)]
d_positions = {label: [] for label in bin_labels}
tracks_per_bin = dict.fromkeys(bin_labels, 0)
total_tracks = 0

predictions = pred_k.copy()
ground_truth = gt_k.copy()

for i, gt_row in enumerate(ground_truth):
    # Every bin is full: nothing left to collect.
    if not any(count < 100 for count in tracks_per_bin.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    # Only full-length (200 step) tracks are evaluated.
    if len(g) != 200:
        continue

    # Drop the first and last entries returned by getCP_gt.
    cp_gt = getCP_gt(g)[1:-1]
    if len(cp_gt) != 1:
        continue

    # K uses a wider value range than alpha (upper_limit=6).
    cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=6, threshold=0.05)
    cp_pred = cp_pred[1:-1]

    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Skip unknown bins and bins that already have their 100 tracks.
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # Single-line progress readout, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per bin:",
          *(f"{k}:{v}" for k, v in tracks_per_bin.items()),
          "", end="\r")

In [13]:
# Save the current d_positions dict to disk as the K result.
# NOTE(review): a later cell loads this filename from the plots/ directory,
# while this writes to the working directory — confirm the file is moved or
# the two paths are made to agree.
import pickle 
with open('changepoint_position_jaccard_value_k.pkl', 'wb') as f:
    pickle.dump(d_positions, f)

In [None]:
# Plot the per-bin Jaccard values currently held in d_positions, tagged "k".
# NOTE(review): plot_jaccard_position is defined outside this section —
# confirm what it saves and how label_type is used.
plot_jaccard_position(d_positions, label_type="k", bin_size=20)

# Changepoints for Various Positions State

In [None]:
# Collect Jaccard values for state tracks with exactly one ground-truth
# changepoint, grouped by the position bin (width bin_size) of that
# changepoint. At most 100 tracks are kept per bin.
bin_size = 20
num_bins = 200 // bin_size
d_positions = {str(i * bin_size): [] for i in range(num_bins)}
tracks_per_bin = {str(i * bin_size): 0 for i in range(num_bins)}
total_tracks = 0

predictions = pred_state.copy()
ground_truth = gt_state.copy()

# NOTE: the loop variables are deliberately not named `gt`/`pd`. Binding `pd`
# at notebook top level would overwrite the `import pandas as pd` alias for
# every cell executed afterwards.
for gt_seq, pred_seq in zip(ground_truth, predictions):
    # Stop once all position bins have 100 tracks.
    if all(count >= 100 for count in tracks_per_bin.values()):
        break

    # Remove padded labels and trim the prediction to the same length.
    gt_seq = gt_seq[gt_seq != LABEL_PADDING_VALUE]
    pred_seq = pred_seq[:len(gt_seq)]

    # Only process if sequence length is 200
    if len(gt_seq) != 200:
        continue

    pred_seq = replace_short_sequences(pred_seq, min_length=3)

    # Drop the first and last entries returned by getCP_gt.
    cp_gt = getCP_gt(gt_seq)[1:-1]

    # Only process sequences with exactly one changepoint
    if len(cp_gt) != 1:
        continue

    cp_pred = getCP_pred_state(pred_seq)[1:-1]

    cp_position = cp_gt[0]
    bin_start = str((cp_position // bin_size) * bin_size)

    # Only process if we need more tracks for this position bin
    if bin_start not in tracks_per_bin or tracks_per_bin[bin_start] >= 100:
        continue

    # Calculate Jaccard index (identical lists are a perfect match).
    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    # Store the result
    d_positions[bin_start].append(jaccard_value)
    tracks_per_bin[bin_start] += 1
    total_tracks += 1

    # Print progress on a single line, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per bin:", end=" ")
    for k, v in tracks_per_bin.items():
        print(f"{k}:{v}", end=" ")
    print("", end="\r")

In [17]:
# Save the current d_positions dict to disk as the state result.
# NOTE(review): a later cell loads this filename from the plots/ directory,
# while this writes to the working directory — confirm the file is moved or
# the two paths are made to agree.
import pickle 
with open('changepoint_position_jaccard_value_state.pkl', 'wb') as f:
    pickle.dump(d_positions, f)

In [None]:
# Plot the per-bin Jaccard values currently held in d_positions, tagged "state".
# NOTE(review): plot_jaccard_position is defined outside this section —
# confirm what it saves and how label_type is used.
plot_jaccard_position(d_positions, label_type="state", bin_size=20)

# Changepoints at Various Positions for All Variables

In [10]:
import pickle

def plot_multiple_position_jaccard(d_positions_alpha, d_positions_k, d_positions_state, bin_size=20):
    """Plot the mean Jaccard value per changepoint-position bin for K, α and s.

    Each ``d_positions_*`` argument maps a bin label (the bin's start position
    as a string) to the list of Jaccard values collected for that bin. Empty
    bins are skipped; the remaining bins are plotted at their centre
    (``start + bin_size / 2``). The figure is written to
    ``position_jaccard_comparison_all.svg``.

    Args:
        d_positions_alpha: Per-bin Jaccard values for α.
        d_positions_k: Per-bin Jaccard values for K.
        d_positions_state: Per-bin Jaccard values for the state.
        bin_size: Width of one position bin.

    Returns:
        tuple: The ``(fig, ax)`` pair that was drawn and saved.
    """
    # Close any pyplot figures left open by earlier cells. (No plt.clf() here:
    # with no figure open it would create a fresh empty pyplot figure, and
    # this function draws on its own Figure object anyway.)
    plt.close('all')

    # Create figure
    fig = Figure(figsize=(10, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    def process_data(d_positions):
        """Return (bin centres, mean Jaccard) for the non-empty bins."""
        results = {}
        for key, items in d_positions.items():
            if items:
                avg_jaccard = sum(items) / len(items)
                results[key] = {'average_jaccard': avg_jaccard}
        # Bin labels are strings, so order them numerically.
        ordered_keys = sorted(results.keys(), key=int)
        x_values = [int(key) + bin_size/2 for key in ordered_keys]
        y_values = [results[key]['average_jaccard'] for key in ordered_keys]
        return x_values, y_values

    # Process and plot each dataset
    datasets = [
        (d_positions_k, K_COLOR, "K"),
        (d_positions_alpha, ALPHA_COLOR, "α"),
        (d_positions_state, STATE_COLOR, "s")
    ]

    for data, color, label in datasets:
        x_values, y_values = process_data(data)
        ax.plot(x_values, y_values,
                color=color,
                linewidth=2,
                alpha=1,
                marker='o',
                markersize=10,
                markerfacecolor=color,  # Fill the circles
                markeredgecolor=color,
                markeredgewidth=2,
                label=label)

    # Set y-axis limits
    ax.set_ylim(0, 1.0)

    # Set x-axis ticks at the bin edges
    bin_edges = list(range(0, 201, bin_size))
    ax.set_xticks(bin_edges)
    ax.set_xticklabels(bin_edges, rotation=45, ha='right')

    # Configure axis and style
    ax.set_axisbelow(True)
    ax.grid(True, linestyle='--', alpha=0.7)

    # Configure tick parameters
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    # Configure spines
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Set labels
    ax.set_xlabel('Position', fontsize=32, labelpad=20)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    # Add legend
    ax.legend(fontsize=28,
              frameon=True,
              loc='lower left',
              bbox_to_anchor=(0.02, 0.02),
              edgecolor='none',
              facecolor='white',
              framealpha=0.8)

    # Adjust layout
    fig.tight_layout(pad=1.0, rect=[0, 0.1, 1, 1])

    # Set transparent background
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    # Save as SVG
    canvas.print_figure('position_jaccard_comparison_all.svg',
                        bbox_inches='tight',
                        pad_inches=0.1,
                        format='svg')

    return fig, ax


def plot_multiple_position_jaccard_reformatted(d_positions_alpha, d_positions_k, d_positions_state, bin_size=20):
    """Draw the compact mean-Jaccard-vs-position figure for K, α and s.

    Each ``d_positions_*`` argument maps a bin label (the bin's start position
    as a string) to a list of Jaccard values. Empty bins are left out; every
    other bin is plotted at its centre. The result is written to
    ``position_jaccard_comparison_all.svg`` and returned as ``(fig, ax)``.
    """
    def bin_means(binned):
        # Non-empty bins only, ordered numerically by their start position.
        keys = sorted((k for k, vals in binned.items() if vals), key=int)
        centres = [int(k) + bin_size/2 for k in keys]
        means = [sum(binned[k]) / len(binned[k]) for k in keys]
        return centres, means

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Line/marker settings shared by all three curves.
    shared_style = dict(linewidth=2, alpha=1, marker='o',
                        markersize=10, markeredgewidth=2)
    series = (
        (d_positions_k, K_COLOR, "K"),
        (d_positions_alpha, ALPHA_COLOR, "α"),
        (d_positions_state, STATE_COLOR, "s"),
    )
    for binned, colour, name in series:
        centres, means = bin_means(binned)
        ax.plot(centres, means,
                color=colour,
                markerfacecolor=colour,
                markeredgecolor=colour,
                label=name,
                **shared_style)

    # y axis: fixed 0..1 range with three labelled ticks.
    ax.set_ylim(0, 1.0)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '0.5', '1'])

    # x axis: a tick every 20 positions, labelled only at 20, 100 and 200.
    ax.set_xlim(0, 200)
    ax.set_xticks(list(range(20, 201, 20)))
    ax.set_xticklabels(['20', '', '', '', '100', '', '', '', '', '200'])
    ax.set_axisbelow(True)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for side in ax.spines.values():
        side.set_visible(True)
        side.set_linewidth(1)
        side.set_color('black')

    ax.set_xlabel('Position', fontsize=32)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    ax.legend(fontsize=28,
              frameon=True,
              loc='upper right',
              bbox_to_anchor=(0.98, 0.98),
              edgecolor='none',
              facecolor='white',
              framealpha=0.8)

    fig.tight_layout()
    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('position_jaccard_comparison_all.svg',
                        bbox_inches='tight',
                        pad_inches=0.1,
                        format='svg')

    return fig, ax


def _read_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Reload the per-bin Jaccard values saved for each variable, then draw the
# combined comparison figure.
d_positions_k = _read_pickle('plots/changepoint_position_jaccard_value_k.pkl')
d_positions_state = _read_pickle('plots/changepoint_position_jaccard_value_state.pkl')
d_positions_alpha = _read_pickle('plots/changepoint_position_jaccard_value_alpha.pkl')


# fig, ax = plot_multiple_position_jaccard(d_positions_alpha, d_positions_k, d_positions_state)
fig, ax = plot_multiple_position_jaccard_reformatted(d_positions_alpha, d_positions_k, d_positions_state)



# K for Various CPs

In [None]:
# Gather Jaccard scores for K predictions, grouped by the number of
# ground-truth changepoints (0-5). At most 100 tracks are kept per group.
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}

predictions = pred_k.copy()
ground_truth = gt_k.copy()
tracks_per_cp = dict.fromkeys(cp_labels, 0)
total_tracks = 0

for i, gt_row in enumerate(ground_truth):
    # Every group is full: nothing left to collect.
    if not any(count < 100 for count in tracks_per_cp.values()):
        break

    idx = padding_starts_index(gt_row)
    p = predictions[i][:idx]
    g = gt_row[:idx]

    no_of_cps = count_changepoints(g)
    no_of_cps_alpha = count_changepoints(gt_a[i][:idx])

    # Keep only tracks whose K and alpha ground truths agree on the count.
    if no_of_cps_alpha != no_of_cps:
        continue

    # Skip tracks that are not full length, have an out-of-range count, or
    # belong to a group that already has its 100 tracks.
    if len(g) != 200 or no_of_cps not in range(6) or tracks_per_cp[str(no_of_cps)] >= 100:
        continue

    total_tracks += 1

    # K uses upper_limit=6.
    cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=6, threshold=0.05)
    cp_gt = getCP_gt(g)

    # Drop the first and last entries of both changepoint lists.
    cp_pred = cp_pred[1:-1]
    cp_gt = cp_gt[1:-1]

    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[str(no_of_cps)].append(jaccard_value)
    tracks_per_cp[str(no_of_cps)] += 1

    # Single-line progress readout, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per CP:",
          *(f"{k}:{v}" for k, v in tracks_per_cp.items()),
          "", end="\r")

# print("\nFinal counts:")
# for k, v in tracks_per_cp.items():
#     print(f"Changepoints {k}: {v} tracks")

# print("\nFinal d_cps lengths:")
# for k, v in d_cps.items():
#     print(f"Changepoints {k}: {len(v)} values")

In [None]:
# Plot the Jaccard values currently held in d_cps, tagged "k_test".
# NOTE(review): plot_jaccard_cps is defined outside this section — confirm
# what it saves and how label_type is used.
plot_jaccard_cps(d_cps, label_type="k_test")

In [None]:
# d_cps = {}
# predictions = pred_k.copy()
# ground_truth = gt_k.copy()
# tracks = 0

# for i in range(len(ground_truth)):

#     idx = padding_starts_index(ground_truth[i])

#     p = predictions[i][:idx]
#     g = ground_truth[i][:idx]
    
#     no_of_cps = count_changepoints(g)

#     if len(g) == 200 and no_of_cps in [0,1,2,3,4,5]:
#         tracks += 1
        
#         if tracks > 2000:
#             break

#         cp_pred, _ = getCP_rpt(p, lower_limit=0, upper_limit=6, threshold=0.05)
#         cp_gt = getCP_gt(g)

#         cp_pred = cp_pred[1:-1]
#         cp_gt = cp_gt[1:-1]

#         if cp_gt == cp_pred:
#             jaccard_value = 1
#         else:
#             rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred) 
#         if str(no_of_cps) not in d_cps:
#             d_cps[str(no_of_cps)] = [jaccard_value]
#         else:
#             d_cps[str(no_of_cps)].append(jaccard_value)

#         print(tracks, end="\r")

# print(d_cps)


In [None]:
# # Print results
# plt.close('all')  # Close all figures
# plt.clf()         # Clear current figure
# delta_k = []  # Clear your data lists
# jaccard_values = []

# plt.clf()

# # Calculate average Jaccard value for each number of changepoints
# results = {}
# for key, items in d_cps.items():
#     avg_jaccard = sum(items) / len(items)
#     num_tracks = len(items)
#     results[key] = {
#         'average_jaccard': avg_jaccard,
#         'num_tracks': num_tracks,
#         'all_values': items
#     }
    
# print("\nResults by number of changepoints:")
# for key in sorted(results.keys()):
#     print(f"\nNumber of changepoints: {key}")
#     print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
#     print(f"Number of tracks: {results[key]['num_tracks']}")

# # Optionally, create lists for plotting
# for key in sorted(results.keys()):
#     delta_alpha.append(int(key))
#     jaccard_values.append(results[key]['average_jaccard'])
# # You can now use delta_alpha and jaccard_values for plotting if needed
# # # Example plot:
# # plt.figure(figsize=(10, 6))
# # plt.plot(delta_alpha, jaccard_values, marker='o')
# # plt.xlabel('Number of Changepoints')
# # plt.ylabel('Average Jaccard Value')
# # plt.title('Jaccard Values vs Number of Changepoints for K')
# # plt.grid(True)
# # plt.savefig("jaccard_vs_changepoints_test_set_for_k_2.svg", format="svg")
# # plt.show()

In [None]:
# # Step 1: Create a function to bin delta_alpha into 0.05 intervals
# def bin_alpha(alpha, bin_size=0.05):
#     return round(alpha / bin_size) * bin_size

# # Step 2: Create a dictionary to store the Jaccard values for each binned delta_alpha
# jaccard_per_binned_alpha = defaultdict(list)

# # Step 3: Bin the delta_alpha values and populate the dictionary with corresponding Jaccard values
# for alpha, jaccard in zip(delta_alpha, jaccard_values):
#     binned_alpha = bin_alpha(alpha, bin_size=0.05)
#     jaccard_per_binned_alpha[binned_alpha].append(jaccard)

# # Step 4: Calculate the average Jaccard value and count the number of values contributing to each binned delta_alpha
# average_jaccard_binned = {alpha: np.mean(jaccards) for alpha, jaccards in jaccard_per_binned_alpha.items()}
# count_per_bin = {alpha: len(jaccards) for alpha, jaccards in jaccard_per_binned_alpha.items()}
# # Step 5: Sort binned delta_alpha values to ensure proper plotting
# sorted_binned_alpha = sorted(average_jaccard_binned.keys())
# sorted_jaccard = [average_jaccard_binned[alpha] for alpha in sorted_binned_alpha]
# sorted_counts = [count_per_bin[alpha] for alpha in sorted_binned_alpha]
# print(sorted_counts)

# # Step 6: Plot the data with count annotations
# plt.figure(figsize=(8, 6))
# plt.plot(sorted_binned_alpha, sorted_jaccard, marker='o', linestyle='-', color='b', label='Avg Jaccard Value')

# # Adding count labels for each bin on the plot
# for alpha, jaccard, count in zip(sorted_binned_alpha, sorted_jaccard, sorted_counts):
#     plt.text(alpha, jaccard, f'n={count}', fontsize=9, ha='center', va='bottom')

# # Adding labels and title
# plt.xlabel('Binned Delta Alpha (0.05 Spacing)', fontsize=14)
# plt.ylabel('Average Jaccard Value', fontsize=14)
# plt.title('Binned Delta Alpha vs Average Jaccard Value', fontsize=16)
# plt.grid(True)
# plt.legend()

# # Show the plot
# plt.show()


# State New Various CPs

In [None]:
# Collect Jaccard values for the state predictions, grouped by the number of
# ground-truth changepoints (0-5). At most 100 tracks are kept per group.
d_cps = {str(i): [] for i in range(6)}  # For storing Jaccard values
tracks_per_cp = {str(i): 0 for i in range(6)}  # For counting tracks per changepoint
total_tracks = 0

predictions = pred_state.copy()
# NOTE(review): the ground truth here is gt_k, not gt_state — confirm this is
# the intended difference from the "State for Various CPs" cell.
ground_truth = gt_k.copy()

# NOTE: the loop variables are deliberately not named `gt`/`pd`. Binding `pd`
# at notebook top level would overwrite the `import pandas as pd` alias for
# every cell executed afterwards.
for gt_seq, pred_seq in zip(ground_truth, predictions):
    # Stop once all changepoint counts have 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Remove padded labels and trim the prediction to the same length.
    gt_seq = gt_seq[gt_seq != LABEL_PADDING_VALUE]
    pred_seq = pred_seq[:len(gt_seq)]

    # Only process sequences of length 200
    if len(gt_seq) != 200:
        continue

    pred_seq = replace_short_sequences(pred_seq, min_length=3)

    # Drop the first and last entries of both changepoint lists.
    cp_pred = getCP_pred_state(pred_seq)[1:-1]
    cp_gt = getCP_gt(gt_seq)[1:-1]

    # Count the number of changepoints in ground truth
    no_of_cps = len(cp_gt)

    # Skip out-of-range counts and groups that already have enough tracks.
    if no_of_cps >= 6 or tracks_per_cp[str(no_of_cps)] >= 100:
        continue

    total_tracks += 1

    # Identical changepoint lists are a perfect match.
    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[str(no_of_cps)].append(jaccard_value)
    tracks_per_cp[str(no_of_cps)] += 1

    # Print progress on a single line, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    for k, v in tracks_per_cp.items():
        print(f"{k}:{v}", end=" ")
    print("", end="\r")

In [None]:
# Plot the Jaccard values currently held in d_cps, tagged "state_new".
# NOTE(review): plot_jaccard_cps is defined outside this section — confirm
# what it saves and how label_type is used.
plot_jaccard_cps(d_cps, label_type="state_new")

# State for Various CPs

In [None]:
# Collect Jaccard values for the state predictions against the state ground
# truth, grouped by the number of ground-truth changepoints (0-5). At most
# 100 tracks are kept per group.
d_cps = {str(i): [] for i in range(6)}  # For storing Jaccard values
tracks_per_cp = {str(i): 0 for i in range(6)}  # For counting tracks per changepoint
total_tracks = 0

predictions = pred_state.copy()
ground_truth = gt_state.copy()

# NOTE: the loop variables are deliberately not named `gt`/`pd`. Binding `pd`
# at notebook top level would overwrite the `import pandas as pd` alias for
# every cell executed afterwards.
for gt_seq, pred_seq in zip(ground_truth, predictions):
    # Stop once all changepoint counts have 100 tracks.
    if all(count >= 100 for count in tracks_per_cp.values()):
        break

    # Remove padded labels and trim the prediction to the same length.
    gt_seq = gt_seq[gt_seq != LABEL_PADDING_VALUE]
    pred_seq = pred_seq[:len(gt_seq)]

    # Only process sequences of length 200
    if len(gt_seq) != 200:
        continue

    pred_seq = replace_short_sequences(pred_seq, min_length=3)

    # Drop the first and last entries of both changepoint lists.
    cp_pred = getCP_pred_state(pred_seq)[1:-1]
    cp_gt = getCP_gt(gt_seq)[1:-1]

    # Count the number of changepoints in ground truth
    no_of_cps = len(cp_gt)

    # Skip out-of-range counts and groups that already have enough tracks.
    if no_of_cps >= 6 or tracks_per_cp[str(no_of_cps)] >= 100:
        continue

    total_tracks += 1

    # Identical changepoint lists are a perfect match.
    if cp_gt == cp_pred:
        jaccard_value = 1
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

    d_cps[str(no_of_cps)].append(jaccard_value)
    tracks_per_cp[str(no_of_cps)] += 1

    # Print progress on a single line, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per CP:", end=" ")
    for k, v in tracks_per_cp.items():
        print(f"{k}:{v}", end=" ")
    print("", end="\r")

In [None]:
# Plot the Jaccard values currently held in d_cps, tagged "state".
# NOTE(review): plot_jaccard_cps is defined outside this section — confirm
# what it saves and how label_type is used.
plot_jaccard_cps(d_cps, label_type="state")

# Using All Variables for Changepoint Detection

In [None]:
# Score changepoints merged from the alpha, K and state predictions against
# the K ground truth, grouped by the number of ground-truth changepoints
# (0-5). At most 100 tracks are kept per group.
cp_labels = [str(n) for n in range(6)]
d_cps = {label: [] for label in cp_labels}
tracks_per_cp = dict.fromkeys(cp_labels, 0)
total_tracks = 0

# Work on copies of every prediction / ground-truth array.
pred_a_data = pred_a.copy()
gt_a_data = gt_a.copy()
pred_k_data = pred_k.copy()
gt_k_data = gt_k.copy()
pred_state_data = pred_state.copy()
gt_state_data = gt_state.copy()

for i, gt_a_row in enumerate(gt_a_data):
    # Every group is full: nothing left to collect.
    if not any(count < 100 for count in tracks_per_cp.values()):
        break

    # The alpha ground truth decides where padding starts for all series.
    idx_a = padding_starts_index(gt_a_row)

    p_alpha = pred_a_data[i][:idx_a]
    g_alpha = gt_a_row[:idx_a]

    p_k = pred_k_data[i][:idx_a]
    g_k = gt_k_data[i][:idx_a]

    p_state = pred_state_data[i][:idx_a]
    g_state = gt_state_data[i][:idx_a]

    # Only full-length (200 step) tracks are evaluated.
    if len(g_k) != 200:
        continue

    no_of_cps = count_changepoints(g_k)
    no_of_cps_alpha = count_changepoints(g_alpha)

    # Keep only tracks whose K and alpha ground truths agree on the count.
    if no_of_cps_alpha != no_of_cps:
        continue

    # Skip out-of-range counts and groups that already have enough tracks.
    if no_of_cps >= 6 or tracks_per_cp[str(no_of_cps)] >= 100:
        continue

    total_tracks += 1

    # Alternatives tried earlier: combined_cps_with_state,
    # combined_cps_k_focused(..., window_size=5), combined_cps.
    merged_cps, _, _, _, _, _ = combined_cps_k_focused_with_state(p_alpha, p_k, p_state)

    # Prepend 0, then drop the first and last entries.
    merged_cps = [0] + merged_cps
    merged_cps = merged_cps[1:-1]

    # Ground-truth changepoints, also without their first and last entries.
    cp_gt = getCP_gt(g_k)
    cp_gt = cp_gt[1:-1]

    if cp_gt == merged_cps:
        jaccard_value = 1
        rmse = 0
    else:
        rmse, jaccard_value = single_changepoint_error(cp_gt, merged_cps)

    d_cps[str(no_of_cps)].append(jaccard_value)
    tracks_per_cp[str(no_of_cps)] += 1

    # Single-line progress readout, overwritten in place via "\r".
    print(f"Total tracks: {total_tracks} | Tracks per CP:",
          *(f"{k}:{v}" for k, v in tracks_per_cp.items()),
          "", end="\r")

# Final summary of what was collected.
print("\nFinal counts:")
for label, count in tracks_per_cp.items():
    print(f"Changepoints {label}: {count} tracks")

print("\nFinal d_cps lengths:")
for label, values in d_cps.items():
    print(f"Changepoints {label}: {len(values)} values")

In [None]:
# Plot the Jaccard values currently held in d_cps, tagged "alpha_k_state_k_focused".
# NOTE(review): plot_jaccard_cps is defined outside this section — confirm
# what it saves and how label_type is used.
plot_jaccard_cps(d_cps, label_type="alpha_k_state_k_focused")

In [None]:
# def plot_multiple_jaccard_values():
#     # Define colors
#     ALPHA_COLOR = '#E69F00'  # Orange
#     K_COLOR = '#1B9E77'      # Green
#     STATE_COLOR = '#9970AB'  # Purple
#     ALL_EXCEPT_STATE_COLOR = 'black'  # Dark orange/red
#     ALL_COLOR = '#56B4E9'    # Blue

#     # Load all numpy arrays
#     j_alpha_k_focused = np.load("jaccard_values_alpha_and_k_focused_for_graph.npy")
#     j_alpha_k = np.load("jaccard_values_alpha_and_k_for_graph.npy")

#     j_alpha_k_state_k_focused = np.load("jaccard_values_alpha_k_state_k_focused_for_graph.npy")
#     j_alpha_k_state = np.load("jaccard_values_alpha_k_state_for_graph.npy")

#     j_alpha = np.load("jaccard_values_alpha_for_graph.npy")
#     j_k = np.load("jaccard_values_k_for_graph.npy")
#     j_state = np.load("jaccard_values_state_for_graph.npy")
    
#     # Create x-axis values
#     n_cps = [i for i in range(6)]
    
#     # Create figure with specific DPI
#     fig = Figure(figsize=(8, 6), dpi=300)
#     canvas = FigureCanvasSVG(fig)
#     ax = fig.add_subplot(111)
    
#     # Plot each line with different colors
#     lines = [
#         (j_k, K_COLOR, "K"),
#         (j_alpha, ALPHA_COLOR, "α"),
#         (j_state, STATE_COLOR, "State"),
#         (j_all, ALL_EXCEPT_STATE_COLOR, "α+K"),
#         (j_all_state, ALL_COLOR, "α+K+State")
#     ]
    
#     for jaccard_values, color, label in lines:
#         ax.plot(n_cps, jaccard_values, 
#                 color=color,
#                 linewidth=2,
#                 alpha=1,
#                 marker='o',
#                 markersize=8,
#                 markerfacecolor='white',
#                 markeredgecolor=color,
#                 markeredgewidth=2,
#                 label=label)
    
#     # Set y-axis limits with margin
#     all_values = np.concatenate([j_all, j_all_state, j_alpha, j_k, j_state])
#     ymin = np.min(all_values)
#     ymax = np.max(all_values)
#     margin = (ymax - ymin) * 0.1
#     ax.set_ylim(ymin - margin, ymax + margin)
    
#     # Set x-axis to show only integer values
#     ax.set_xticks(n_cps)
    
#     # Configure axis and style
#     ax.set_axisbelow(True)
    
#     # Configure tick parameters
#     ax.tick_params(which='both', direction='out', length=6, width=1,
#                   colors='black', pad=2)
    
#     # Set tick label sizes
#     ax.tick_params(axis='both', labelsize=32)
    
#     # Configure spines
#     for spine in ax.spines.values():
#         spine.set_visible(True)
#         spine.set_linewidth(1)
#         spine.set_color('black')
    
#     # Set labels
#     ax.set_xlabel('$N_{\mathrm{CP}}$', fontsize=32)
#     ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)
    
#     # Add legend with same font size as labels
#     # Add legend inside the plot
#     legend = ax.legend(fontsize=24, frameon=True, 
#                       loc='upper right',
#                       bbox_to_anchor=(0.98, 0.98),
#                       edgecolor='none',
#                       facecolor='white',
#                       framealpha=0.8)
    
#     # Adjust layout to accommodate legend
#     fig.tight_layout(rect=[0, 0, 0.85, 1])
    
#     # Set transparent background
#     fig.patch.set_facecolor('none')
#     ax.set_facecolor('none')
    
#     # Save as SVG
#     canvas.print_figure('combined_jaccard_comparison.svg', 
#                        bbox_inches='tight',
#                        pad_inches=0.1,
#                        format='svg')
    
#     return fig, ax

# plot_multiple_jaccard_values()

def plot_multiple_jaccard_values():
    """Plot the mean Jaccard index against the number of changepoints per method.

    Loads five precomputed curves from ``plots/*.npy`` (K only, alpha only,
    state only, alpha+K and alpha+K+state), draws them on a single axes and
    writes the figure to ``all_methods_jaccard_comparison_new.svg`` in the
    current working directory.

    Returns:
        tuple: ``(fig, ax)`` -- the matplotlib ``Figure`` and its only axes.
    """
    j_alpha_k_focused = np.load("plots/jaccard_values_alpha_and_k_focused_for_graph.npy")
    j_alpha_k_state_k_focused = np.load("plots/jaccard_values_alpha_k_state_k_focused_for_graph.npy")
    j_alpha = np.load("plots/jaccard_values_alpha_for_graph.npy")
    j_k = np.load("plots/jaccard_values_k_for_graph.npy")
    j_state = np.load("plots/jaccard_values_state_new_for_graph.npy")

    # One x position per changepoint count; every curve is plotted against
    # these positions, so each loaded array needs one value per position.
    n_cps = list(range(6))

    # Figure/canvas are created directly (no pyplot) and rendered with the SVG backend.
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # (values, colour, legend label, line style, marker fill colour)
    curves = [
        (j_k, K_COLOR, "K", '-', K_COLOR),
        (j_alpha, ALPHA_COLOR, "α", '-', ALPHA_COLOR),
        (j_state, STATE_COLOR, "s", '-', STATE_COLOR),
        # Dashed line with hollow markers distinguishes the two black curves.
        (j_alpha_k_state_k_focused, "black", "α+K+s", '--', 'white'),
        (j_alpha_k_focused, "black", "α+K", '-', "black"),
    ]

    for jaccard_values, color, label, linestyle, markerfacecolor in curves:
        ax.plot(
            n_cps, jaccard_values,
            color=color,
            linewidth=2,
            alpha=1,
            marker='o',
            markersize=10,
            markerfacecolor=markerfacecolor,
            markeredgecolor=color,
            markeredgewidth=2,
            linestyle=linestyle,
            label=label,
        )

    # Fixed y range with three labelled ticks.
    ax.set_ylim(0, 1)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '0.5', '1'])
    # One tick per changepoint count.
    ax.set_xticks(n_cps)

    ax.set_axisbelow(True)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw strings: the mathtext backslashes must not be read as Python escapes.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    ax.legend(fontsize=28, frameon=True,
              loc='upper right',
              bbox_to_anchor=(0.98, 0.98),
              edgecolor='none',
              facecolor='white',
              framealpha=0.8)

    fig.tight_layout()
    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('all_methods_jaccard_comparison_new.svg',
                        bbox_inches='tight',
                        pad_inches=0.1,
                        format='svg')

    return fig, ax

plot_multiple_jaccard_values()

# CPs for different models

In [None]:
import numpy as np
from tqdm import tqdm
from collections import defaultdict

# Track categories, in the order they are reported and plotted.
MODELS = ['imm.', 'conf.', 'free', 'dir.']
NO_OF_CPS_MAX = 5
TRACKS_PER_CATEGORY = 100

# d_indices[model][n] collects the row indices of the tracks assigned to
# `model` that have exactly n changepoints.
d_indices = {model: [[] for _ in range(NO_OF_CPS_MAX + 1)] for model in MODELS}

# Keep only rows whose gt_k has exactly 200 non-padding entries.
sequence_lengths = (gt_k != LABEL_PADDING_VALUE).sum(axis=1)
length_mask = sequence_lengths == 200

# A changepoint is any position where gt_k differs from its predecessor.
changes = gt_k[:, 1:] != gt_k[:, :-1]
changepoints = changes.sum(axis=1)
# The count cannot be negative, so only the upper bound needs checking.
changepoint_mask = changepoints <= NO_OF_CPS_MAX

final_mask = changepoint_mask & length_mask
valid_indices = np.where(final_mask)[0]

# First pass: assign every valid track to a (model, changepoint count) bucket.
for idx in valid_indices:
    cp_gt = np.where(gt_k[idx, 1:] != gt_k[idx, :-1])[0] + 1
    no_of_cps = len(cp_gt)
    states = gt_state[idx]

    # The first matching state label decides the category; the order of the
    # checks therefore sets the precedence when a track contains several labels.
    if 1 in states:
        model = "conf."
    elif 0 in states:
        model = "imm."
    elif 3 in states:
        model = "dir."
    else:
        model = "free"

    d_indices[model][no_of_cps].append(idx)

# d_cps[model][n] will hold one Jaccard value per sampled track.
d_cps = {model: [[] for _ in range(NO_OF_CPS_MAX + 1)] for model in MODELS}

# Empty buckets are skipped below, so only non-empty ones count towards the
# progress-bar total (otherwise the bar can never reach 100%).
non_empty_buckets = sum(
    1 for bucket_model in MODELS for bucket in d_indices[bucket_model] if bucket
)
progress_bar = tqdm(desc="Processing tracks",
                    total=non_empty_buckets * TRACKS_PER_CATEGORY)

for model in MODELS:
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        available_indices = d_indices[model][no_of_cps]

        # Nothing to sample from.
        if not available_indices:
            continue

        # Draw exactly TRACKS_PER_CATEGORY tracks; sample with replacement
        # only when the bucket holds fewer tracks than that.
        indices_to_process = np.random.choice(
            available_indices,
            size=TRACKS_PER_CATEGORY,
            replace=len(available_indices) < TRACKS_PER_CATEGORY
        )

        for idx in indices_to_process:

            cp_gt = np.where(gt_k[idx, 1:] != gt_k[idx, :-1])[0] + 1
            pk = pred_k[idx]
            pa = pred_a[idx]
            ps = pred_state[idx]

            # Only the first of the six returned values (the predicted
            # changepoints) is used here.
            cp_pred, _, _, _, _, _ = combined_cps_k_focused(pa, pk, ps)
            # Prepend 0, then drop the first and last entries.
            # NOTE(review): presumably the last entry is the end-of-track
            # marker -- confirm against combined_cps_k_focused.
            cp_pred = [0] + cp_pred
            cp_pred = cp_pred[1:-1]

            # Identical changepoint lists (including two empty ones) score 1
            # without calling the metric.
            if list(cp_gt) == list(cp_pred):
                jaccard_value = 1
            else:
                rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)

            d_cps[model][no_of_cps].append(jaccard_value)
            progress_bar.update()

progress_bar.close()

# Report how many tracks each bucket held and how many values were collected.
print("\nSampling Statistics:")
print("-" * 50)
for model in MODELS:
    print(f"\nModel: {model}")
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        original_count = len(d_indices[model][no_of_cps])
        sampled_count = len(d_cps[model][no_of_cps])
        print(f"  CPs={no_of_cps}: Original={original_count}, Sampled={sampled_count}")
        if 0 < original_count < TRACKS_PER_CATEGORY:
            print("    (Sampled with replacement to reach target)")

In [44]:
import json

# Write the per-model Jaccard lists to disk as JSON.
with open('jaccard_for_models.json', 'w') as json_file:
    json_file.write(json.dumps(d_cps))


In [8]:
import json

# Read the per-model Jaccard lists back from the JSON file.
with open('jaccard_for_models.json', 'r') as json_file:
    d_cps = json.loads(json_file.read())

In [None]:
# Per-bucket summary of how many Jaccard values d_cps holds.
# The pre-sampling counts come from `d_indices`, which is not part of the JSON
# file; it only exists if the categorisation cell ran in this session, so the
# "Original" figure is printed only when it is available.
_category_indices = globals().get('d_indices')

for model in MODELS:
    print(f"\nModel: {model}")
    for no_of_cps in range(NO_OF_CPS_MAX + 1):
        sampled_count = len(d_cps[model][no_of_cps])
        if _category_indices is None:
            print(f"  CPs={no_of_cps}: Sampled={sampled_count}")
            continue
        # Look the count up for this bucket instead of reusing whatever value
        # a previous cell happened to leave behind.
        original_count = len(_category_indices[model][no_of_cps])
        print(f"  CPs={no_of_cps}: Original={original_count}, Sampled={sampled_count}")
        if 0 < original_count < TRACKS_PER_CATEGORY:
            print("    (Sampled with replacement to reach target)")

In [None]:
def plot_model_jaccard(d_cps):
    """Plot the mean Jaccard index per changepoint count, one line per model.

    Each line is labelled with its model name just right of its last point,
    and the figure is written to ``jaccard_vs_models.svg``.

    Args:
        d_cps: mapping indexed as ``d_cps[model][n]``, giving the Jaccard
            values collected for ``model`` at ``n`` changepoints. It is read
            for every entry of the global ``MODELS`` and is expected to hold
            ``NO_OF_CPS_MAX + 1`` lists per model.

    Returns:
        tuple: ``(fig, ax)`` -- the matplotlib ``Figure`` and its only axes.
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    n_cps = list(range(NO_OF_CPS_MAX + 1))

    for model in MODELS:
        # Buckets without any value are drawn at 0.
        jaccard_values = [np.mean(cps) if cps else 0 for cps in d_cps[model]]
        ax.plot(n_cps, jaccard_values,
                color='black',
                linewidth=2,
                linestyle='-',
                marker='o',
                markersize=8,
                markerfacecolor='black',
                markeredgecolor='black',
                markeredgewidth=2)

        # Name the line next to its last point instead of using a legend.
        ax.text(n_cps[-1] + 0.1, jaccard_values[-1], model,
                fontsize=28, va='center')

    ax.set_ylim(0, 1.05)
    # Extra room on the right for the inline model labels.
    ax.set_xlim(-0.1, NO_OF_CPS_MAX + 1)
    ax.set_xticks(n_cps)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Raw strings: the mathtext backslashes must not be read as Python escapes.
    ax.set_xlabel(r'$N_{\mathrm{CP}}$', fontsize=32)
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('jaccard_vs_models.svg',
                        bbox_inches='tight', pad_inches=0.1, format='svg')

    return fig, ax


plot_model_jaccard(d_cps)

In [None]:
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import numpy as np

# Grouped bar chart: one group per model, one bar per changepoint count,
# saved to 'jaccard_by_model_bars.svg'.

# Constants
FONT_SIZE = 32
BAR_WIDTH = 0.25
LINE_WIDTH = 2
GREY_COLOR = '0.8'

# Define models in the correct order
MODELS = ['imm.', 'conf.', 'free', 'dir.']

# Mean Jaccard value per model and changepoint count; NaN marks a bucket
# without data, which is then skipped when drawing.
means = {
    model: [
        np.mean(d_cps[model][i]) if len(d_cps[model][i]) else np.nan
        for i in range(NO_OF_CPS_MAX + 1)
    ]
    for model in MODELS
}

# Create figure using Figure and Canvas
fig = Figure(figsize=(10, 6), dpi=300)
canvas = FigureCanvasSVG(fig)
ax = fig.add_subplot(111)

# Space the groups far enough apart that the bars of one model never land on
# the bars of the next one. With up to three bars per group this is the
# original unit spacing.
group_spacing = max(1.0, (NO_OF_CPS_MAX + 2) * BAR_WIDTH)
x = np.arange(len(MODELS)) * group_spacing
valid_cp_positions = []  # x positions where a bar was actually drawn
valid_cp_labels = []     # changepoint count shown under each of those bars

# Plot bars and collect valid positions
for i in range(NO_OF_CPS_MAX + 1):
    if i == 0:
        # Hatched white bars for 0 changepoints
        bar_style = {'color': 'white', 'hatch': '///'}
    elif i == 1:
        # Grey bars for 1 changepoint
        bar_style = {'color': GREY_COLOR}
    else:
        # Black bars for 2 or more changepoints
        bar_style = {'color': 'black'}

    for j, model in enumerate(MODELS):
        value = means[model][i]
        if np.isnan(value):
            continue

        pos = x[j] + (i - 1)*BAR_WIDTH
        ax.bar(pos, value, BAR_WIDTH,
               edgecolor='black',
               linewidth=LINE_WIDTH,
               **bar_style)

        valid_cp_positions.append(pos)
        valid_cp_labels.append(str(i))

# Configure spines
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

# Labels
ax.set_xlabel('Number of Changepoints by Model', fontsize=FONT_SIZE, labelpad=40)
ax.set_ylabel(r'$\overline{J}$', fontsize=FONT_SIZE, labelpad=60, rotation=0)

# Generic tick styling goes first: applied afterwards it would overwrite the
# smaller major-label font and the larger minor-label pad set below.
ax.tick_params(which='both', direction='out', length=6, width=1, colors='black', pad=2, labelsize=FONT_SIZE)

# Major ticks: the changepoint count under every bar.
ax.set_xticks(valid_cp_positions)
ax.set_xticklabels(valid_cp_labels, fontsize=FONT_SIZE-10)

# Minor ticks: the model name centred under each group. Models without any
# bar get no tick, so positions and labels are collected together.
model_centers = []
model_labels = []
for j, model in enumerate(MODELS):
    valid_positions = [
        x[j] + (i - 1)*BAR_WIDTH
        for i in range(NO_OF_CPS_MAX + 1)
        if not np.isnan(means[model][i])
    ]
    if valid_positions:
        model_centers.append(np.mean(valid_positions))
        model_labels.append(model)

ax.set_xticks(model_centers, minor=True)
ax.set_xticklabels(model_labels, minor=True)
# Larger pad pushes the model names below the changepoint-count labels.
ax.tick_params(axis='x', which='minor', pad=30, labelsize=FONT_SIZE)

# Set y-axis limits
ax.set_ylim(0, 1)

# Set transparent background
fig.patch.set_facecolor('none')
ax.set_facecolor('none')

# Add grid
ax.grid(True, linestyle='--', alpha=0.3, color='gray')
ax.set_axisbelow(True)

# Layout adjustments
fig.tight_layout(pad=1.2)

# Save using canvas.print_figure
canvas.print_figure('jaccard_by_model_bars.svg', bbox_inches='tight', pad_inches=0.15, format='svg')

In [None]:
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import numpy as np

# Draws an empty axes whose only x tick is labelled N_CP and saves it to
# 'empty_model_plot.svg'.

# Constants
FONT_SIZE = 32
# Raw string: the mathtext backslash must not be read as a Python escape.
# NOTE(review): this rebinds the global MODELS that other cells index d_cps
# with -- re-run the cell defining the model names before using those cells.
MODELS = [r'$N_{\mathrm{CP}}$']

# Create figure using Figure and Canvas
fig = Figure(figsize=(10, 6), dpi=300)
canvas = FigureCanvasSVG(fig)
ax = fig.add_subplot(111)

# One x position per label
x = np.arange(len(MODELS))

# Configure spines
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

# Configure ticks
ax.set_xticks(x)
ax.set_xticklabels(MODELS, fontsize=FONT_SIZE)
ax.tick_params(axis='y', labelsize=FONT_SIZE)

# Layout adjustments
fig.tight_layout(pad=1.2)

# Save using canvas.print_figure
canvas.print_figure('empty_model_plot.svg',
                    bbox_inches='tight',
                    pad_inches=0.15,
                    format='svg')

In [None]:
# no_of_cps_max = 2

# MODELS = ["bound", "single", "confined", "multi"]
# d_cps = {model: [[] for _ in range(no_of_cps_max + 1)] for model in MODELS}

# sequence_lengths = (gt_k != LABEL_PADDING_VALUE).sum(axis=1)
# length_mask = sequence_lengths == 200

# changes = gt_k[:, 1:] != gt_k[:, :-1]
# changepoints = changes.sum(axis=1)
# changepoint_mask = (changepoints >= 0) & (changepoints <= no_of_cps_max)

# final_mask = (changepoint_mask) & (length_mask)
# final_indices = np.where(final_mask)[0]

# progress_bar = tqdm(desc="Looping", total=len(final_indices))

# for i in final_indices:

#     cp_gt = np.where(gt_k[i, 1:] != gt_k[i, :-1])[0] + 1
#     no_of_cps = len(cp_gt)
#     states = gt_state[i]
#     pk = pred_k[i]

#     if 1 in states:
#         model = "confined"
#     elif 0 in states:
#         model = "bound"
#     elif no_of_cps == 0:
#         model = "single"
#     else:
#         model = "multi"
    
#     cp_pred, _ = getCP_rpt(pk, lower_limit=0, upper_limit=6, threshold=0.05)
#     cp_pred = cp_pred[1:-1]
    
#     # Calculate Jaccard similarity
#     if list(cp_gt) == list(cp_pred):
#         jaccard_value = 1
#     else:
#         rmse, jaccard_value = single_changepoint_error(cp_gt, cp_pred)
    
#     # Store results
#     d_cps[model][no_of_cps].append(jaccard_value)
#     progress_bar.update()

# progress_bar.close()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Marker plot of the mean Jaccard index per model, one marker style per
# changepoint count.

# `no_of_cps_max` is only assigned by a cell that is currently commented out;
# use it when it exists in this session, otherwise fall back to NO_OF_CPS_MAX.
max_cps = globals().get('no_of_cps_max', NO_OF_CPS_MAX)

# Calculate means for each model and changepoint number
means = {
    model: [np.mean(d_cps[model][i]) for i in range(max_cps + 1)]
    for model in MODELS
}

# Set up the plot with publication quality
plt.style.use('seaborn-v0_8-paper')
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)

# Plot settings
# One marker per changepoint count; indexed modulo its length below so a
# larger max_cps cannot run past the end of the list.
markers = ['o', 's', '^', 'D', 'v', 'P']
ms = 8  # Marker size
lw = 1.5  # Line width

# Create x positions for each model
x = np.arange(len(MODELS))
width = 0.2  # Horizontal offset between consecutive changepoint counts

# Plot points for each number of changepoints
for cp_num in range(max_cps + 1):
    y_values = [means[model][cp_num] for model in MODELS]
    ax.plot(x + (cp_num - 1)*width, y_values,
            marker=markers[cp_num % len(markers)],
            color='black',
            label=f'{cp_num} CP',
            linestyle='none',
            markersize=ms,
            markeredgewidth=lw)

# Styling
ax.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)

# Configure spines
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

# Labels and ticks
ax.set_xlabel('Model', fontsize=16)
ax.set_ylabel('Average Jaccard Index', fontsize=16)
ax.tick_params(which='both', direction='out', length=6, width=1,
               colors='black', pad=2, labelsize=16)

# Set x-ticks at model positions
ax.set_xticks(x)
ax.set_xticklabels(MODELS)

# Set y-axis limits
ax.set_ylim(0, 1)

# Legend
legend = ax.legend(title='Changepoints',
                   fontsize=14,
                   title_fontsize=14,
                   frameon=True,
                   edgecolor='black',
                   loc='upper right')
legend.get_frame().set_alpha(1)

# Layout
plt.tight_layout()

# Save figure
# plt.savefig("jaccard_by_model.svg", format="svg", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Average the Jaccard values per changepoint count, print them and plot them.
# NOTE(review): this cell treats d_cps as {changepoint count: flat list of
# Jaccard values}; the per-model d_cps built elsewhere in this notebook
# ({model: [list per count]}) does not have that shape -- confirm which d_cps
# is meant before running.

# Start from empty plotting lists so re-running the cell does not append
# duplicates (and so delta_alpha exists at all).
delta_alpha = []
jaccard_values = []

# Calculate average Jaccard value for each number of changepoints
results = {}
for key, items in d_cps.items():
    if not items:
        # Nothing to average for this count; skipping avoids a division by zero.
        continue
    results[key] = {
        'average_jaccard': sum(items) / len(items),
        'num_tracks': len(items),
        'all_values': items
    }

# Order numerically: keys read back from JSON are strings, and sorting those
# as text would put "10" before "2".
ordered_keys = sorted(results, key=int)

print("\nResults by number of changepoints:")
for key in ordered_keys:
    print(f"\nNumber of changepoints: {key}")
    print(f"Average Jaccard value: {results[key]['average_jaccard']:.4f}")
    print(f"Number of tracks: {results[key]['num_tracks']}")

# x/y lists for the plot
for key in ordered_keys:
    delta_alpha.append(int(key))
    jaccard_values.append(results[key]['average_jaccard'])

plt.figure(figsize=(10, 6))
plt.plot(delta_alpha, jaccard_values, marker='o')
plt.xlabel('Number of Changepoints')
plt.ylabel('Average Jaccard Value')
plt.title('Jaccard Values vs Number of Changepoints for K')
plt.grid(True)
plt.savefig("jaccard_vs_changepoints_test_set_for_k_2.svg", format="svg")
plt.show()

# Error in K and Alpha per State/Model (Publication Level)

- Given the ground-truth state, computes the error in K and alpha.

In [66]:
# Per-track mean absolute error of K and alpha, grouped by model category.
# The key order fixes the order in which the categories are plotted later.
errors_k = {name: [] for name in ('imm.', 'conf.', 'free', 'dir.')}
errors_a = {name: [] for name in errors_k}

# A track is usable only if none of its state labels is padding.
valid_mask = np.all(gt_state != LABEL_PADDING_VALUE, axis=1)

# Clamp the predictions to [0, 6] for K and [0, 2] for alpha.
pk2 = np.clip(pred_k, 0, 6)
pa2 = np.clip(pred_a, 0, 2)

# Absolute error summed along each track and divided by 200.
k_error = np.abs(gt_k - pk2).sum(axis=1) / 200
a_error = np.abs(gt_a - pa2).sum(axis=1) / 200

# Assign every usable track to one category; the first matching state label
# wins, so the order of these checks sets the precedence.
for idx in np.flatnonzero(valid_mask):
    states = gt_state[idx]

    if 1 in states:
        category = 'conf.'
    elif 0 in states:
        category = 'imm.'
    elif 3 in states:
        category = 'dir.'
    else:
        category = 'free'

    errors_k[category].append(k_error[idx])
    errors_a[category].append(a_error[idx])

# Average the per-track errors within each category.
mae_k = {}
mae_a = {}
for state in errors_k:
    mae_k[state] = np.mean(errors_k[state])
    mae_a[state] = np.mean(errors_a[state])


In [None]:
def plot_errors_by_model(mae_k, mae_a):
    """Plot the K and alpha error of each model as two marker lines.

    The figure is written to ``errors_by_model_new.svg``.

    Args:
        mae_k: mapping of model name to the K error; its key order defines
            the x-axis order and the tick labels.
        mae_a: mapping whose values form the alpha series; they are plotted
            positionally, so its order has to match ``mae_k``.

    Returns:
        tuple: ``(fig, ax)`` -- the matplotlib ``Figure`` and its only axes.
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    model_names = list(mae_k)
    positions = np.arange(len(model_names))

    # NOTE(review): the K series is labelled "MSLE" although the parameter is
    # called mae_k -- confirm which metric the caller passes in.
    series = (
        (mae_k, K_COLOR, 'MSLE(K)'),
        (mae_a, ALPHA_COLOR, r'MAE($\alpha$)'),
    )
    line_style = {'marker': 'o', 'linewidth': 2}
    for errors, colour, legend_label in series:
        ax.plot(positions, list(errors.values()),
                color=colour, label=legend_label, **line_style)

    ax.set_xticks(positions)
    ax.set_xticklabels(model_names, fontsize=32)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for frame_line in ax.spines.values():
        frame_line.set_visible(True)
        frame_line.set_linewidth(1)
        frame_line.set_color('black')

    # Upper limit sits above the highest tick to leave room for the legend.
    ax.set_ylim(-0.005, 0.16)
    ax.set_yticks([0, 0.04, 0.08, 0.12])

    ax.legend(fontsize=28, frameon=False,
              loc='upper center',
              ncol=2,
              labelcolor='black')

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    canvas.print_figure('errors_by_model_new.svg',
                        bbox_inches='tight', pad_inches=0.1, format='svg')

    return fig, ax

plot_errors_by_model(mae_k, mae_a)

In [None]:
# # Initialize error lists
# errors_k = [[] for _ in range(NUM_CLASSES)]
# errors_alpha = [[] for _ in range(NUM_CLASSES)]

# # Clip predictions
# # pred_k = np.clip(pred_k, 0, 6)
# # pred_a = np.clip(pred_a, 0, 2)

# pk = median_filter_1d(smooth_series(pred_k, lower_limit=0, upper_limit=6))
# pa = median_filter_1d(smooth_series(pred_a, lower_limit=0, upper_limit=1.999))

# # Calculate errors
# k_error = gt_k - pred_k 
# alpha_error = gt_a - pred_a 

# # Collect errors by state
# for class_idx in range(NUM_CLASSES):
#     valid_positions = (gt_state == class_idx) & (gt_state != LABEL_PADDING_VALUE)
#     errors_k[class_idx].extend(k_error[valid_positions])
#     errors_alpha[class_idx].extend(alpha_error[valid_positions])

# # Calculate statistics
# mean_k = [np.mean(err) if len(err) > 0 else np.nan for err in errors_k]
# mean_alpha = [np.mean(err) if len(err) > 0 else np.nan for err in errors_alpha]
# std_k = [np.std(err) if len(err) > 0 else np.nan for err in errors_k]
# std_alpha = [np.std(err) if len(err) > 0 else np.nan for err in errors_alpha]

# # Calculate 95% confidence intervals
# ci_k = np.array([stats.sem(err) * stats.t.ppf((1 + 0.95) / 2, len(err)-1) 
#                  if len(err) > 1 else np.nan for err in errors_k])
# ci_alpha = np.array([stats.sem(err) * stats.t.ppf((1 + 0.95) / 2, len(err)-1) 
#                     if len(err) > 1 else np.nan for err in errors_alpha])

# # Absolute error statistics
# abs_mean_k = [np.mean(np.abs(err)) if len(err) > 0 else np.nan for err in errors_k]
# abs_mean_alpha = [np.mean(np.abs(err)) if len(err) > 0 else np.nan for err in errors_alpha]

# # Print statistics
# print("\nError Statistics by State:")
# print("-" * 50)
# for state in range(NUM_CLASSES):
#     print(f"\nState {state}:")
#     print(f"K error      - Mean ± Std: {mean_k[state]:.3f} ± {std_k[state]:.3f}")
#     print(f"K error      - Mean ± 95% CI: {mean_k[state]:.3f} ± {ci_k[state]:.3f}")
#     print(f"K error      - Abs Mean: {abs_mean_k[state]:.3f}")
#     print(f"Alpha error  - Mean ± Std: {mean_alpha[state]:.3f} ± {std_alpha[state]:.3f}")
#     print(f"Alpha error  - Mean ± 95% CI: {mean_alpha[state]:.3f} ± {ci_alpha[state]:.3f}")
#     print(f"Alpha error  - Abs Mean: {abs_mean_alpha[state]:.3f}")

In [19]:
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import numpy as np


# Error-bar plot of the mean K and alpha error per state, saved to
# 'error_by_state.svg'. mean_k, mean_alpha, std_k and std_alpha must already
# exist in the session (one entry per state); this cell does not compute them.
mean_k = np.array(mean_k)
mean_alpha = np.array(mean_alpha)
std_k = np.array(std_k)
std_alpha = np.array(std_alpha)
# Constants
NUM_CLASSES = 4
CAPSIZE = 8
LINE_WIDTH = 2
MARKER_SIZE = 8
FONT_SIZE = 32

# Create figure using Figure and Canvas
fig = Figure(figsize=(10, 6), dpi=300)
canvas = FigureCanvasSVG(fig)
ax = fig.add_subplot(111)
states = np.arange(NUM_CLASSES)

# Create plots with filled circles and dashed error bars
k_line = ax.errorbar(states, mean_k, yerr=std_k,
                     label='K',
                     fmt='o',
                     markerfacecolor=K_COLOR,
                     markeredgecolor=K_COLOR,
                     markeredgewidth=2,
                     capsize=CAPSIZE,
                     color=K_COLOR,
                     markersize=MARKER_SIZE,
                     capthick=LINE_WIDTH,
                     elinewidth=LINE_WIDTH,
                     ls='none')

# The last element of the errorbar container holds the bar line collections.
k_line[-1][0].set_linestyle('--')

alpha_line = ax.errorbar(states, mean_alpha, yerr=std_alpha,
                         label='α',
                         fmt='o',
                         markerfacecolor=ALPHA_COLOR,
                         markeredgecolor=ALPHA_COLOR,
                         markeredgewidth=2,
                         capsize=CAPSIZE,
                         color=ALPHA_COLOR,
                         markersize=MARKER_SIZE,
                         capthick=LINE_WIDTH,
                         elinewidth=LINE_WIDTH,
                         ls='none')

alpha_line[-1][0].set_linestyle('--')

# Calculate y-axis limits with padding. NaN entries (a state without data)
# are ignored: the builtin min/max would give order-dependent results with
# NaN and could hand NaN limits to set_ylim.
y_min = np.nanmin(np.concatenate([mean_k - std_k, mean_alpha - std_alpha]))
y_max = np.nanmax(np.concatenate([mean_k + std_k, mean_alpha + std_alpha]))
y_range = y_max - y_min
padding = 0.3 * y_range  # 30% padding on each side
ax.set_ylim(y_min - padding, y_max + padding)

# Baseline
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5, alpha=0.5)

# Configure spines
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

ax.set_xlabel('State (s)', fontsize=FONT_SIZE, labelpad=20)
ax.set_ylabel(r'$\overline{Error}$ ± σ', fontsize=FONT_SIZE, labelpad=60)

# Configure ticks
ax.tick_params(which='both', direction='out', length=6, width=1,
               colors='black', pad=2, labelsize=FONT_SIZE)
ax.set_xticks(states)
ax.set_xticklabels([str(i) for i in states])

# Legend
legend = ax.legend(fontsize=FONT_SIZE,
                   frameon=True,
                   edgecolor='black',
                   loc='upper right')
legend.get_frame().set_alpha(1)

# Set transparent background
fig.patch.set_facecolor('none')
ax.set_facecolor('none')

# Layout adjustments
fig.tight_layout(pad=1.0)

# Save using canvas.print_figure
canvas.print_figure('error_by_state.svg',
                    bbox_inches='tight',
                    pad_inches=0.1,
                    format='svg')

# Accuracy Track Length Publication Level Plot

In [24]:
def _mean_class_accuracy_by_length(gt_state, pred_state, min_len=20, max_len=200, n_classes=4):
    """Compute the class-averaged state accuracy for every track length.

    For each length, the accuracy of a class is the number of points of that
    class predicted correctly divided by the number of points of that class,
    pooled over all tracks of that length. The result for a length is the
    mean of those per-class accuracies over the classes that occur at it.

    Args:
        gt_state: 2-D array of ground-truth labels, one row per track, with
            ``LABEL_PADDING_VALUE`` marking padded positions.
        pred_state: array of predicted labels, same shape as ``gt_state``.
        min_len: shortest track length to bin.
        max_len: longest track length to bin.
        n_classes: labels ``0 .. n_classes - 1`` are evaluated.

    Returns:
        tuple: ``(lengths, accuracy)``, two 1-D arrays with one entry per
        length from ``min_len`` to ``max_len``. ``accuracy`` is NaN where no
        track of that length contributed any point.
    """
    n_bins = max_len - min_len + 1
    correct = np.zeros((n_bins, n_classes), dtype=np.float64)
    total = np.zeros((n_bins, n_classes), dtype=np.float64)

    not_padding = gt_state != LABEL_PADDING_VALUE
    sequence_lengths = not_padding.sum(axis=1)
    # Ignore tracks outside the binned range: a shorter track would give a
    # negative bin index and silently be counted in a bin at the far end.
    in_range = (sequence_lengths >= min_len) & (sequence_lengths <= max_len)
    bins = sequence_lengths[in_range] - min_len

    for class_idx in range(n_classes):
        class_mask = (gt_state == class_idx) & not_padding
        hit_mask = (pred_state == class_idx) & class_mask
        # np.add.at accumulates correctly when several tracks share a bin.
        np.add.at(correct[:, class_idx], bins, hit_mask.sum(axis=1)[in_range])
        np.add.at(total[:, class_idx], bins, class_mask.sum(axis=1)[in_range])

    # Per-class accuracy; classes absent at a length stay 0 and are excluded
    # from the average through `present`.
    present = total > 0
    per_class = np.divide(correct, total, out=np.zeros_like(correct), where=present)
    n_present = present.sum(axis=1)

    accuracy = np.full(n_bins, np.nan, dtype=np.float64)
    has_data = n_present > 0
    accuracy[has_data] = per_class[has_data].sum(axis=1) / n_present[has_data]

    return np.arange(min_len, max_len + 1), accuracy


def create_accuracy_length_plot(gt_state, pred_state, output_dir):
    """Scatter the class-averaged state accuracy against track length.

    Writes ``accuracy_per_length.svg`` into ``output_dir``.

    Args:
        gt_state: 2-D array of ground-truth labels padded with
            ``LABEL_PADDING_VALUE``.
        pred_state: array of predicted labels, same shape as ``gt_state``.
        output_dir: directory the SVG is written to.

    Returns:
        tuple: ``(fig, ax)`` -- the matplotlib ``Figure`` and its only axes.
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Lengths without data are NaN and therefore simply not drawn.
    x_values, accuracy = _mean_class_accuracy_by_length(gt_state, pred_state)

    ax.scatter(x_values, accuracy, color=STATE_COLOR,
               label='Mean Class Accuracy', alpha=1, s=60)

    ax.set_xlim(15, 205)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Traj. Length', fontsize=55)
    ax.set_ylabel('Mean Class Accuracy', fontsize=55)

    # Lay out and save the figure created above. It is not managed by pyplot,
    # so the pyplot helpers would act on a different (current) figure.
    fig.tight_layout(pad=1.0)

    output_path = os.path.join(output_dir, 'accuracy_per_length.svg')
    canvas.print_figure(output_path, bbox_inches='tight', pad_inches=0.1, format='svg')

    return fig, ax


def create_accuracy_length_plot_smooth(gt_state, pred_state, output_dir):
    """Plot a LOWESS-smoothed curve of class-averaged accuracy against track length.

    Writes ``accuracy_per_length_smooth_new.svg`` into ``output_dir``.

    Args:
        gt_state: 2-D array of ground-truth labels padded with
            ``LABEL_PADDING_VALUE``.
        pred_state: array of predicted labels, same shape as ``gt_state``.
        output_dir: directory the SVG is written to.

    Returns:
        tuple: ``(fig, ax)`` -- the matplotlib ``Figure`` and its only axes.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    x_values, accuracy = _mean_class_accuracy_by_length(gt_state, pred_state)

    # Smooth only over lengths that actually have data.
    valid_points = ~np.isnan(accuracy)
    smoothed = lowess(accuracy[valid_points], x_values[valid_points], frac=0.3)

    # Column 0 of the result is x, column 1 the smoothed y.
    ax.plot(smoothed[:, 0], smoothed[:, 1], color=STATE_COLOR,
            linewidth=4, alpha=1)

    ax.set_xlim(15, 205)

    # A tick every 20 points, labelled only at 20, 100 and 200.
    x_ticks = list(range(20, 201, 20))
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(['20', '', '', '', '100', '', '', '', '', '200'])
    ax.set_yticks([0.85, 0.9, 0.95])

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Traj. Length', fontsize=32, labelpad=20)
    ax.set_ylabel('Mean Class Accuracy', fontsize=32, labelpad=20)

    fig.tight_layout(pad=1.0, rect=[0.1, 0.1, 0.9, 0.9])

    output_path = os.path.join(output_dir, 'accuracy_per_length_smooth_new.svg')
    canvas.print_figure(output_path, bbox_inches='tight', pad_inches=0.1, format='svg')

    return fig, ax

# Usage
# fig, ax = create_accuracy_length_plot(gt_state, pred_state, output_dir=ROOT_DIR)
fig, ax = create_accuracy_length_plot_smooth(gt_state, pred_state, output_dir="./")

# Track Length Error Publication Level

In [25]:
def calculate_losses(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state):
    """Calculate the mean absolute error per point for each sequence length.

    Sequences are grouped by their number of non-padded points, where a point
    counts as padding when ``gt_a == LABEL_PADDING_VALUE`` (the same mask is
    applied to K and state). All inputs must be 2-D arrays of identical shape,
    one row per sequence.

    Returns:
        Tuple ``(loss_a, loss_k, loss_state)`` of float64 arrays with 181
        entries; entry ``i`` is the loss for sequences of length ``i + 20``.
        Lengths for which no sequence exists are NaN, so that callers can
        filter them out with ``np.isnan`` instead of plotting a spurious 0.
    """
    min_len = 20
    num_lengths = 181

    # Padding mask (True on real points) and resulting per-sequence lengths.
    valid = (gt_a != LABEL_PADDING_VALUE)
    sequence_lengths = valid.sum(axis=1)
    indices = sequence_lengths - min_len

    # Sequences shorter than 20 points would yield negative indices (which
    # np.add.at silently wraps to the end of the array) and sequences longer
    # than 200 would raise IndexError, so only in-range lengths are binned.
    in_range = (indices >= 0) & (indices < num_lengths)
    indices = indices[in_range]

    point_count = np.zeros(num_lengths, dtype=np.float64)
    np.add.at(point_count, indices, sequence_lengths[in_range])
    has_data = point_count > 0

    losses = []
    for pred, gt in ((pred_a, gt_a), (pred_k, gt_k), (pred_state, gt_state)):
        # Absolute error summed over the valid points of each sequence.
        per_sequence = (np.abs(pred - gt) * valid).sum(axis=1)

        totals = np.zeros(num_lengths, dtype=np.float64)
        np.add.at(totals, indices, per_sequence[in_range])

        mean_loss = np.full(num_lengths, np.nan, dtype=np.float64)
        mean_loss[has_data] = totals[has_data] / point_count[has_data]
        losses.append(mean_loss)

    return tuple(losses)

# def create_loss_length_plot(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, output_dir):
#     """
#     Create publication quality loss vs length plot with exact styling specifications
#     """
#     # Use plt instead of Figure directly
#     # Set fixed height in pixels and calculate width based on aspect ratio
#     height_pixels = 1000  # or whatever height you want
#     aspect_ratio = 4/3    # width/height ratio
    
#     # Convert to inches for figure creation (assuming 100 DPI)
#     dpi = 100
#     height_inches = height_pixels / dpi
#     width_inches = height_inches * aspect_ratio
    
#     # Create figure
#     plt.figure(figsize=(width_inches, height_inches), dpi=dpi)

#     ax = plt.gca()
    
#     # Calculate losses
#     loss_a, loss_k, loss_state = calculate_losses(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state)
    
#     # Plot data
#     x_values = np.arange(20, 201)
#     colors = {'alpha': '#E69F00', 'K': '#1B9E77', 'state': '#9970AB'}
    
#     # Create scatter plots
#     scatter_params = {
#         'alpha': 1,
#         's': 60,
#     }
    
#     ax.scatter(x_values, loss_a, color=colors['alpha'], label=r'$L_{\alpha}$', **scatter_params)
#     ax.scatter(x_values, loss_k, color=colors['K'], label=r'$L_{K}$', **scatter_params)
#     ax.scatter(x_values, loss_state, color=colors['state'], label=r'$L_{s}$', **scatter_params)
    
#     # Set axis limits
#     ax.set_xlim(15, 205)
#     y_max = max(np.nanmax(loss_a), np.nanmax(loss_k), np.nanmax(loss_state))
#     ax.set_ylim(0, y_max * 1.1)
    
#     # Configure axes
#     ax.tick_params(which='both', direction='out', length=6, width=1,
#                   colors='black', pad=2, labelsize=32)
    
#     # Configure spines
#     for spine in ax.spines.values():
#         spine.set_visible(True)
#         spine.set_linewidth(1)
#         spine.set_color('black')
    
#     # Set labels
#     ax.set_xlabel('Track Length', fontsize=55)
#     ax.set_ylabel(r'$Loss$', fontsize=55)
    
#     # Configure legend outside the plot at the top
#     ax.legend(fontsize=55, frameon=False, 
#             bbox_to_anchor=(0, 1.02, 1, 0.2),
#             loc='lower center', 
#             ncol=3, 
#             mode="expand",
#             handletextpad=0.5,  # Reduces space between marker and text
#             columnspacing=0.6,  # Reduces space between legend columns
#             markerscale=2.0)  # Makes legend markers bigger
    
#     # Adjust layout
#     plt.tight_layout(pad=1.0)
    
#     # Save if needed
#     output_path = os.path.join(output_dir, 'loss_per_length.svg')
#     plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, format='svg')
    
#     # Show the plot
#     plt.show()
    
#     return plt.gcf(), ax


def create_loss_length_plot_smooth(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, output_dir):
    """Plot LOWESS-smoothed alpha, K and state losses against trajectory length.

    The per-length losses come from ``calculate_losses``; each curve is smoothed
    with its own LOWESS fraction (0.1, 0.2 and 0.3 respectively). The figure is
    written to ``<output_dir>/loss_per_length_smooth_new.svg``.

    Returns:
        The ``(fig, ax)`` pair of the rendered figure.
    """
    # Imported lazily: statsmodels is only needed by the smoothed plots.
    from statsmodels.nonparametric.smoothers_lowess import lowess

    figure = Figure(figsize=(8, 6), dpi=300)
    svg_canvas = FigureCanvasSVG(figure)
    axes = figure.add_subplot(111)

    per_length_losses = calculate_losses(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state)
    lengths = np.arange(20, 201)

    # (loss values, line colour, LOWESS fraction, legend label) for each curve,
    # listed in legend order.
    curve_specs = (
        (per_length_losses[0], '#E69F00', 0.1, r'$\mathcal{L}_{\alpha}$'),
        (per_length_losses[1], '#1B9E77', 0.2, r'$\mathcal{L}_{K}$'),
        (per_length_losses[2], '#9970AB', 0.3, r'$\mathcal{L}_{s}$'),
    )
    for values, line_color, fraction, legend_label in curve_specs:
        keep = ~np.isnan(values)
        fitted = lowess(values[keep], lengths[keep], frac=fraction)
        axes.plot(fitted[:, 0], fitted[:, 1], color=line_color,
                  linewidth=4, alpha=1, label=legend_label)

    # Major ticks every 20 points, but only 20, 100 and 200 carry a label.
    axes.set_xlim(15, 205)
    axes.set_xticks(list(range(20, 201, 20)))
    axes.set_xticklabels(['20', '', '', '', '100', '', '', '', '', '200'])

    axes.tick_params(which='both', direction='out', length=6, width=1,
                     colors='black', pad=2)
    axes.tick_params(axis='both', labelsize=32)

    # Thin black frame on all four sides.
    for border in axes.spines.values():
        border.set_visible(True)
        border.set_linewidth(1)
        border.set_color('black')

    axes.set_xlabel('Traj. Length', fontsize=32, labelpad=20)
    axes.set_ylabel('Loss', fontsize=32, labelpad=20)

    axes.legend(fontsize=28, frameon=False,
                loc='upper right',
                handletextpad=0.5,
                columnspacing=0.6,
                markerscale=2.0)

    figure.tight_layout(pad=1.0)

    svg_path = os.path.join(output_dir, 'loss_per_length_smooth_new.svg')
    svg_canvas.print_figure(svg_path, bbox_inches='tight', pad_inches=0.1, format='svg')

    return figure, axes

# Usage
# fig, ax = create_loss_length_plot(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, ROOT_DIR)
fig, ax = create_loss_length_plot_smooth(gt_a, gt_k, gt_state, pred_a, pred_k, pred_state, "./")


# Plot Heatmaps Publication Level

In [None]:
def mae(pred, gt, padding_value=LABEL_PADDING_VALUE):
    """Return the mean absolute error, averaged per sequence and then over sequences.

    Args:
        pred: 2-D array of predictions, one row per sequence.
        gt: 2-D array of ground-truth values with the same shape; entries equal
            to ``padding_value`` are treated as padding and ignored.
        padding_value: Marker used for padded positions in ``gt``.

    Returns:
        The mean of the per-sequence mean absolute errors. Sequences without a
        single valid point are skipped (they would otherwise contribute 0/0 =
        NaN and turn the whole result into NaN). NaN is returned only when no
        sequence has any valid point.
    """
    valid = (gt != padding_value)
    # np.where instead of multiplying by the mask, so that a non-finite value
    # at a padded position cannot leak into the sum (inf * 0 is NaN).
    abs_error = np.where(valid, np.abs(pred - gt), 0.0)

    sequence_lengths = valid.sum(axis=1)
    sequence_sums = abs_error.sum(axis=1)

    non_empty = sequence_lengths > 0
    if not np.any(non_empty):
        return float('nan')

    sequence_averages = sequence_sums[non_empty] / sequence_lengths[non_empty]
    return sequence_averages.mean()

# Store the results under their own names: assigning the result back to `mae`
# would replace the function with a float and make this cell fail when re-run.
mae_k_value = mae(pred_k.copy(), gt_k.copy())
mae_alpha_value = mae(pred_a.copy(), gt_a.copy())

print(mae_k_value, mae_alpha_value)

In [3]:
def mae_alpha(pred_a, gt_a, padding_value=LABEL_PADDING_VALUE):
    """Return the mean absolute error of alpha, averaged per sequence and then over sequences.

    Args:
        pred_a: 2-D array of predicted alpha values, one row per sequence.
        gt_a: 2-D array of true alpha values with the same shape; entries equal
            to ``padding_value`` are treated as padding and ignored.
        padding_value: Marker used for padded positions in ``gt_a``.

    Returns:
        The mean of the per-sequence mean absolute errors. Rows without a
        single valid point are skipped; previously they produced 0/0 = NaN and
        made the whole result NaN (e.g. when point-wise data containing padding
        is passed as ``(-1, 1)`` rows). NaN is returned only when no row has
        any valid point.
    """
    valid = (gt_a != padding_value)
    # np.where instead of multiplying by the mask, so that a non-finite value
    # at a padded position cannot leak into the sum (inf * 0 is NaN).
    abs_error = np.where(valid, np.abs(pred_a - gt_a), 0.0)

    sequence_lengths = valid.sum(axis=1)
    sequence_sums = abs_error.sum(axis=1)

    non_empty = sequence_lengths > 0
    if not np.any(non_empty):
        return float('nan')

    sequence_averages = sequence_sums[non_empty] / sequence_lengths[non_empty]
    return sequence_averages.mean()

def mre_k(pred_k, gt_k, padding_value=LABEL_PADDING_VALUE):
    """Return the mean relative error of K, averaged per sequence and then over sequences.

    Both inputs are mapped through ``10**x - 1`` before the relative error
    ``|pred - true| / true`` is computed.

    Args:
        pred_k: 2-D array of predicted values, one row per sequence.
        gt_k: 2-D array of true values with the same shape; entries equal to
            ``padding_value`` are treated as padding and ignored.
        padding_value: Marker used for padded positions in ``gt_k``.

    Returns:
        The mean of the per-sequence mean relative errors. Rows without a
        single valid point are skipped; previously they produced 0/0 = NaN and
        made the whole result NaN. NaN is returned only when no row has any
        valid point.
    """
    valid = (gt_k != padding_value)

    true_k = 10**gt_k - 1
    predicted_k = 10**pred_k - 1
    # NOTE(review): a valid point with gt_k == 0 has a zero denominator here and
    # still yields inf/NaN - confirm whether such points should be excluded.
    relative_errors = np.abs(predicted_k - true_k) / true_k
    # np.where instead of multiplying by the mask, so that a non-finite value
    # at a padded position cannot leak into the sum (inf * 0 is NaN).
    relative_errors = np.where(valid, relative_errors, 0.0)

    sequence_lengths = valid.sum(axis=1)
    sequence_sums = relative_errors.sum(axis=1)

    non_empty = sequence_lengths > 0
    if not np.any(non_empty):
        return float('nan')

    sequence_averages = sequence_sums[non_empty] / sequence_lengths[non_empty]
    return sequence_averages.mean()


def create_publication_heatmap(true_values, pred_values, param_name, value_range, output_dir, color):
    """
    Create publication quality heatmap of predicted versus true values.

    Parameters:
    -----------
    true_values, pred_values : np.ndarray
        Ground-truth and predicted values of identical shape. Entries whose
        ground truth equals LABEL_PADDING_VALUE are discarded.
    param_name : str
        'alpha' (case-insensitive) annotates the plot with the MAE; any other
        value is treated as K and annotated with the MRE.
    value_range : tuple
        (min, max) used for the axis extent and tick positions.
    output_dir : str
        Accepted for signature compatibility with
        create_publication_heatmap_new; this variant does not write the figure
        to disk.
    color : str
        Hex colour equal (case-insensitively) to ALPHA_COLOR, STATE_COLOR or
        K_COLOR; selects the white-to-colour colormap.

    Returns:
    --------
    fig, ax : tuple
        Matplotlib figure and axis objects

    Raises:
    -------
    ValueError
        If `color` is none of the three known colours.
    """
    # Drop padded points first: passing them on as (-1, 1) rows gives the error
    # helpers rows with zero valid points, which can turn the score into NaN.
    valid = true_values != LABEL_PADDING_VALUE
    true_values = true_values[valid]
    pred_values = pred_values[valid]

    if param_name.lower() == 'alpha':
        error = mae_alpha(pred_values.reshape(-1, 1), true_values.reshape(-1, 1))
        error_text = f'MAE = {error:.3f}'
    else:  # K
        error = mre_k(pred_values.reshape(-1, 1), true_values.reshape(-1, 1))
        error_text = f'MRE = {error:.3f}'

    print("Max:", np.max(true_values), "Min:", np.min(true_values))

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Calculate histogram
    # NOTE(review): the bins end at the largest true value while the image
    # extent below ends at max_val - confirm this mismatch is intended.
    num_bins = 40
    min_val, max_val = value_range
    true_bins = np.linspace(min_val, np.max(true_values), num_bins + 1)
    pred_bins = np.linspace(min_val, np.max(true_values), num_bins + 1)
    hist, xedges, yedges = np.histogram2d(true_values, pred_values,
                                         bins=[true_bins, pred_bins])

    # The colour constants contain upper-case hex digits, so both sides must be
    # lower-cased; comparing color.lower() to the raw constant never matches.
    colormaps = {
        ALPHA_COLOR.lower(): ('white_orange', ALPHA_COLOR),
        STATE_COLOR.lower(): ('white_purple', STATE_COLOR),
        K_COLOR.lower(): ('white_green', K_COLOR),
    }
    requested = color.lower()
    if requested not in colormaps:
        raise ValueError('This color does not exist')
    cmap_name, end_color = colormaps[requested]
    cmap = mcolors.LinearSegmentedColormap.from_list(cmap_name, ['white', end_color])

    # Plot heatmap. histogram2d bins its first argument (true values) along
    # axis 0; after the transpose the true values run along the x-axis and the
    # predictions along the y-axis.
    ax.imshow(hist.T, cmap=cmap, aspect='equal',
              extent=[min_val, max_val, min_val, max_val], origin='lower')

    ax.text(0.05, 0.95, error_text,
            transform=ax.transAxes,
            fontsize=32,
            verticalalignment='top',
            bbox=dict(facecolor='white', alpha=1, edgecolor='none', pad=5))

    # Ticks at the minimum, midpoint and maximum of the value range
    mid_val = (min_val + max_val) / 2
    tick_positions = [min_val, mid_val, max_val]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)

    # Integers are printed without decimals, everything else with one decimal
    tick_labels = [f'{int(x)}' if float(x).is_integer() else f'{x:.1f}' for x in tick_positions]
    ax.set_xticklabels(tick_labels, fontsize=32)
    ax.set_yticklabels(tick_labels, fontsize=32)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                  colors='black', pad=2)

    # Thin black frame on all four sides
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Axis labels follow the plotted orientation: x = true, y = predicted
    if param_name.lower() == 'alpha':
        xlabel = r'$\alpha_{\mathrm{true}}$'
        ylabel = r'$\alpha_{\mathrm{pred}}$'
    else:  # K
        xlabel = r'$K_{\mathrm{true}}$'
        ylabel = r'$K_{\mathrm{pred}}$'

    ax.set_xlabel(xlabel, fontsize=55)
    ax.set_ylabel(ylabel, fontsize=55)

    fig.tight_layout(pad=1.0)

    return fig, ax


def create_publication_heatmap_new(true_values, pred_values, param_name, value_range, output_dir, color):
    """
    Create and save a publication quality heatmap of predicted versus true values.

    Parameters:
    -----------
    true_values, pred_values : np.ndarray
        Ground-truth and predicted values of identical shape. Entries whose
        ground truth equals LABEL_PADDING_VALUE are discarded.
    param_name : str
        'alpha' (case-insensitive) annotates the plot with the MAE and saves
        'alpha_heatmap_new.svg'; any other value is treated as K, annotated
        with the MRE and saved as 'K_heatmap_new.svg'.
    value_range : tuple
        (min, max) used for the axis extent and tick positions.
    output_dir : str
        Directory the SVG file is written to.
    color : str
        One of 'alpha_color', 'state_color' or 'k_color' (case-insensitive);
        any other value raises KeyError.

    Returns:
    --------
    fig, ax : tuple
        Matplotlib figure and axis objects
    """
    is_alpha = param_name.lower() == 'alpha'

    # Drop padded points first: passing them on as (-1, 1) rows gives the error
    # helpers rows with zero valid points, which can turn the score into NaN.
    valid = true_values != LABEL_PADDING_VALUE
    true_values = true_values[valid]
    pred_values = pred_values[valid]

    # Calculate error
    if is_alpha:
        error = mae_alpha(pred_values.reshape(-1, 1), true_values.reshape(-1, 1))
        error_text = f'MAE = {error:.3f}'
    else:
        error = mre_k(pred_values.reshape(-1, 1), true_values.reshape(-1, 1))
        error_text = f'MRE = {error:.3f}'

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Calculate histogram
    # NOTE(review): the bins end at the largest true value while the image
    # extent below ends at max_val - confirm this mismatch is intended.
    num_bins = 40
    min_val, max_val = value_range
    true_bins = np.linspace(min_val, np.max(true_values), num_bins + 1)
    pred_bins = np.linspace(min_val, np.max(true_values), num_bins + 1)
    hist, xedges, yedges = np.histogram2d(true_values, pred_values,
                                         bins=[true_bins, pred_bins])

    # Create colormap
    cmap_dict = {
        "alpha_color": ['white', ALPHA_COLOR],
        "state_color": ['white', STATE_COLOR],
        "k_color": ['white', K_COLOR]
    }
    cmap = mcolors.LinearSegmentedColormap.from_list('custom', cmap_dict[color.lower()])

    # Plot heatmap. histogram2d bins its first argument (true values) along
    # axis 0; after the transpose the true values run along the x-axis and the
    # predictions along the y-axis.
    ax.imshow(hist.T, cmap=cmap, aspect='equal',
              extent=[min_val, max_val, min_val, max_val], origin='lower')

    # Add error text
    ax.text(0.05, 0.95, error_text,
            transform=ax.transAxes,
            fontsize=32,
            verticalalignment='top',
            bbox=dict(facecolor='white', alpha=1, edgecolor='none', pad=5))

    # Ticks at the minimum, midpoint and maximum of the value range
    mid_val = (min_val + max_val) / 2
    tick_positions = [min_val, mid_val, max_val]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)

    tick_labels = [f'{int(x)}' if float(x).is_integer() else f'{x:.1f}' for x in tick_positions]
    ax.set_xticklabels(tick_labels, fontsize=32)
    ax.set_yticklabels(tick_labels, fontsize=32)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                  colors='black', pad=2)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Axis labels follow the plotted orientation: x = true, y = predicted
    if is_alpha:
        xlabel = r'$\alpha_{\mathrm{true}}$'
        ylabel = r'$\alpha_{\mathrm{pred}}$'
        filename = "alpha_heatmap_new.svg"
    else:  # K
        xlabel = r'$K_{\mathrm{true}}$'
        ylabel = r'$K_{\mathrm{pred}}$'
        filename = "K_heatmap_new.svg"

    ax.set_xlabel(xlabel, fontsize=32)
    ax.set_ylabel(ylabel, fontsize=32)

    fig.tight_layout(pad=1.0)
    # Transparent figure and axes backgrounds in the exported SVG
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    output_path = os.path.join(output_dir, filename)
    canvas.print_figure(output_path, bbox_inches='tight',
                       pad_inches=0.1, format='svg')

    return fig, ax

In [None]:
# Usage example for alpha (disabled; uncomment to regenerate 'alpha_heatmap_new.svg'):
# fig, ax = create_publication_heatmap_new(
#     gt_a.flatten(),
#     pred_a.flatten(),
#     param_name='alpha',
#     value_range=(0, 2),
#     output_dir=ROOT_DIR,
#     color="alpha_color"
# )

# For K: writes 'K_heatmap_new.svg' into ROOT_DIR.
fig, ax = create_publication_heatmap_new(
    gt_k.flatten(), 
    pred_k.flatten(), 
    param_name='K',
    value_range=(0, 3),
    output_dir=ROOT_DIR,
    color="k_color"
)

# Publication Level Confusion Matrix Done!!!

In [21]:
def plot_confusion_matrix_new(cm_normalized, output_dir):
    """Render a row-normalized confusion matrix as a percentage heatmap and save it.

    The matrix is multiplied by 100 and annotated with whole percentages. The
    figure is written to ``<output_dir>/confusion_matrix.svg``.

    Returns:
        The ``(fig, ax)`` pair of the rendered figure.
    """
    figure = Figure(figsize=(8, 6), dpi=300)
    svg_canvas = FigureCanvasSVG(figure)
    axes = figure.add_subplot(111)

    percentages = cm_normalized * 100  # fractions -> percentages
    class_names = ['imm.', 'conf.', 'free', 'dir.']
    white_to_purple = mcolors.LinearSegmentedColormap.from_list(
        'white_purple', ['white', STATE_COLOR])

    heatmap_options = dict(
        annot=True,
        fmt='.0f',
        cmap=white_to_purple,
        xticklabels=class_names,
        yticklabels=class_names,
        square=True,
        annot_kws={'size': 32},
        cbar=False,
        linewidths=1,
        linecolor='black',
        robust=True,
        ax=axes,
    )
    sns.heatmap(percentages, **heatmap_options)

    axes.set_xlabel('Predicted', fontsize=32, labelpad=20)
    axes.set_ylabel('Ground Truth', fontsize=32, labelpad=20)

    axes.tick_params(labelsize=32, which='both', direction='out',
                     length=6, width=1, colors='black', pad=2)

    # Keep the class names horizontal on both axes.
    for tick_labels in (axes.get_xticklabels(), axes.get_yticklabels()):
        plt.setp(tick_labels, rotation=0)

    for border in axes.spines.values():
        border.set_visible(True)
        border.set_linewidth(1)
        border.set_color('black')

    # Outline the matrix along its four outer edges.
    for y_edge in (0, percentages.shape[0]):
        axes.axhline(y=y_edge, color='black', linewidth=1)
    for x_edge in (0, percentages.shape[1]):
        axes.axvline(x=x_edge, color='black', linewidth=1)

    figure.tight_layout(pad=1.0)

    svg_path = os.path.join(output_dir, "confusion_matrix.svg")
    svg_canvas.print_figure(svg_path, bbox_inches='tight', pad_inches=0.1, format='svg')

    return figure, axes


def plot_confusion_matrix(cm_normalized, output_dir):
    """
    Create publication quality confusion matrix plot with outlines.

    The matrix is annotated with two-decimal fractions on a new pyplot figure
    and saved to ``<output_dir>/confusion_matrix.svg``.

    Returns the seaborn heatmap axes.
    """
    # Create figure with specific size
    plt.figure(figsize=(12, 10))

    # Class labels; the second class is abbreviated 'conf.' to match
    # plot_confusion_matrix_new.
    labels = ['imm.', 'conf.', 'free', 'dir.']

    # Create custom purple colormap
    cmap = mcolors.LinearSegmentedColormap.from_list('white_purple',
                                                    ['white', STATE_COLOR])

    # Create heatmap with specific styling
    ax = sns.heatmap(cm_normalized,
                     annot=True,
                     fmt='.2f',
                     cmap=cmap,
                     xticklabels=labels,
                     yticklabels=labels,
                     square=True,
                     annot_kws={'size': 32},
                     cbar=False,
                     linewidths=1,          # Add lines between cells
                     linecolor='black',     # Make the lines black
                     robust=True)

    # Set font sizes for axis labels
    ax.set_xlabel('Predicted', fontsize=55, labelpad=20)
    ax.set_ylabel('Ground Truth', fontsize=55, labelpad=20)

    # Set font sizes for tick labels
    ax.tick_params(labelsize=32)

    # Keep the tick labels horizontal
    plt.xticks(rotation=0)
    plt.yticks(rotation=0)

    # Frame the matrix with visible black spines
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_color('black')

    # Draw lines along the four outer edges of the matrix
    ax.axhline(y=0, color='black', linewidth=1)
    ax.axhline(y=cm_normalized.shape[0], color='black', linewidth=1)
    ax.axvline(x=0, color='black', linewidth=1)
    ax.axvline(x=cm_normalized.shape[1], color='black', linewidth=1)

    # Make the plot tight
    plt.tight_layout()

    # Save figure
    plt.savefig(os.path.join(output_dir, "confusion_matrix.svg"),
                format="svg",
                bbox_inches='tight',
                pad_inches=0.1,
                dpi=300)

    return ax

def calculate_normalized_confusion_matrix(pred_states, gt_states, num_classes=4):
    """Return the confusion matrix normalized over the ground-truth rows.

    Points whose ground truth equals LABEL_PADDING_VALUE are ignored. Each row
    is divided by its number of ground-truth points, so a row sums to 1. A
    class that never occurs in the ground truth gets an all-zero row instead
    of the NaN row that a 0/0 division would produce.
    """
    # Flatten the arrays and remove padding
    pred_flat = pred_states.flatten()
    gt_flat = gt_states.flatten()
    valid = (gt_flat != LABEL_PADDING_VALUE)
    pred_flat = pred_flat[valid]
    gt_flat = gt_flat[valid]

    # Rows are ground-truth classes, columns are predicted classes
    cm = confusion_matrix(gt_flat, pred_flat, labels=range(num_classes))

    # Normalize by row, leaving rows without any sample at zero
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals > 0)

    return cm_normalized


def calculate_metrics(pred_states, gt_states, num_classes=4):
    """
    Calculate F1 Score, Recall, Precision and average class accuracy.

    Points whose ground truth equals LABEL_PADDING_VALUE are ignored. The
    per-class accuracy is the diagonal of the confusion matrix divided by the
    number of ground-truth points of that class; a class without any
    ground-truth point gets accuracy 0, in line with ``zero_division=0`` used
    for the other metrics. A report is printed and the same numbers are
    returned as a dict with the keys 'per_class' and 'averages'.
    """
    # Flatten and remove padding
    pred_flat = pred_states.flatten()
    gt_flat = gt_states.flatten()
    valid = (gt_flat != LABEL_PADDING_VALUE)
    pred_flat = pred_flat[valid]
    gt_flat = gt_flat[valid]

    class_ids = list(range(num_classes))

    # Calculate metrics for each class
    precision, recall, f1, _ = precision_recall_fscore_support(gt_flat, pred_flat,
                                                               labels=class_ids,
                                                               zero_division=0)

    # Calculate per-class accuracy from confusion matrix (guarding against 0/0)
    cm = confusion_matrix(gt_flat, pred_flat, labels=class_ids)
    row_totals = cm.sum(axis=1)
    class_accuracy = np.divide(np.diag(cm), row_totals,
                               out=np.zeros(num_classes, dtype=float),
                               where=row_totals > 0)
    avg_class_accuracy = np.mean(class_accuracy)

    # Class names ('conf.' matches plot_confusion_matrix_new); classes beyond
    # the four known ones get a generic name so num_classes is always honoured.
    known_labels = ['imm.', 'conf.', 'free', 'dir.']
    labels = [known_labels[i] if i < len(known_labels) else f'class {i}'
              for i in class_ids]

    print("\nDetailed Classification Metrics:")
    print("--------------------------------")
    for i, label in enumerate(labels):
        print(f"\n{label}:")
        print(f"F1 Score:   {f1[i]:.3f}")
        print(f"Precision:  {precision[i]:.3f}")
        print(f"Recall:     {recall[i]:.3f}")
        print(f"Accuracy:   {class_accuracy[i]:.3f}")

    print("\nAverages:")
    print("--------------------------------")
    print(f"Average F1 Score:           {np.mean(f1):.3f}")
    print(f"Average Precision:          {np.mean(precision):.3f}")
    print(f"Average Recall:             {np.mean(recall):.3f}")
    print(f"Average Class Accuracy:     {avg_class_accuracy:.3f}")

    return {
        'per_class': {
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'accuracy': class_accuracy
        },
        'averages': {
            'f1': np.mean(f1),
            'precision': np.mean(precision),
            'recall': np.mean(recall),
            'accuracy': avg_class_accuracy
        }
    }

In [None]:
cm_normalized = calculate_normalized_confusion_matrix(pred_state, gt_state)
# plot_confusion_matrix_new returns (fig, ax); unpack it so that `ax` is the
# axes object rather than the whole tuple.
fig, ax = plot_confusion_matrix_new(cm_normalized, ROOT_DIR)
metrics = calculate_metrics(pred_state, gt_state)

# Jaccard Plots Alpha/K Publication Level

In [52]:
def create_jaccard_plot(file_path, counter_path, parameter_type='alpha', output_dir=None):
    """
    Create publication quality plot for Jaccard index vs either delta alpha or delta K

    The summed Jaccard values are divided by their counters, averaged inside
    bins of fixed width (0.1 for alpha, 0.5 for K) and plotted at the bin
    centres.

    Parameters:
    -----------
    file_path : str
        Path to the JSON file containing the Jaccard index data
    counter_path : str
        Path to the JSON file containing the counter data; it must contain
        every key of the Jaccard file
    parameter_type : str
        Either 'alpha' or 'K' to specify which parameter is being plotted
    output_dir : str or None
        Directory to save the plot. If None, saves in current directory

    Returns:
    --------
    fig, ax : tuple
        Matplotlib figure and axis objects

    Raises:
    -------
    ValueError
        If parameter_type is neither 'alpha' nor 'K'
    KeyError
        If a key of the Jaccard file is missing from the counter file
    """
    # Validate parameter type
    if parameter_type not in ['alpha', 'K']:
        raise ValueError("parameter_type must be either 'alpha' or 'K'")

    # Set parameters based on type
    if parameter_type == 'alpha':
        bin_width = 0.1
        x_limits = (-2, 2)
        xlabel = r'$\Delta\alpha$'
        output_filename = 'jaccard_delta_alpha.svg'
        color_variable = ALPHA_COLOR
    else:  # K
        bin_width = 0.5
        x_limits = (-5, 5)
        xlabel = r'$\Delta K$'
        output_filename = 'jaccard_delta_k.svg'
        color_variable = K_COLOR

    # Load data
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    with open(counter_path, 'r', encoding='utf-8') as file:
        counter_data = json.load(file)

    # Pair every value with the counter stored under the same key, so the
    # result does not depend on both files listing their keys in the same order.
    keys = list(data)
    x_values = np.array([float(key) for key in keys])
    y_values = np.array([data[key] for key in keys])
    counter = np.array([counter_data[key] for key in keys])
    y_values = y_values / counter

    # Create bins. The stop is pushed out by one extra bin width so that an
    # edge beyond the largest x exists; otherwise the last half-open bin ends
    # below max(x_values) and the largest value is dropped.
    lowest_edge = min(x_values) - bin_width / 2
    bins = np.arange(lowest_edge, max(x_values) + 1.5 * bin_width, bin_width)

    # Bin the data: mean of the values inside each half-open bin [start, end)
    binned_data = {}
    for bin_start, bin_end in zip(bins[:-1], bins[1:]):
        in_bin = (x_values >= bin_start) & (x_values < bin_end)
        if np.any(in_bin):
            binned_data[round((bin_start + bin_end)/2, 3)] = np.mean(y_values[in_bin])

    # Convert binned data to sorted lists
    sorted_x = sorted(binned_data.keys())
    sorted_y = [binned_data[x] for x in sorted_x]

    # Create figure with specific dimensions and DPI to match heatmap
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Plot data with thicker lines and larger markers
    ax.plot(sorted_x, sorted_y, marker='o', markersize=8,
            color=color_variable, linewidth=3, alpha=1)

    # Set axis limits
    ax.set_xlim(x_limits)
    ax.set_ylim(0, 1)

    # Only the y ticks are set explicitly; the x ticks are left to matplotlib
    y_tick_positions = [0, 0.5, 1]
    ax.set_yticks(y_tick_positions)

    ax.set_axisbelow(True)

    # Configure tick parameters to match heatmap
    ax.tick_params(which='both', direction='out', length=6, width=1,
                  colors='black', pad=2)

    # Set y tick label sizes to match heatmap
    ax.set_yticklabels([f'{y:g}' for y in y_tick_positions], fontsize=32)

    # Configure spines to match heatmap
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    # Set labels with matching font sizes
    ax.set_xlabel(xlabel, fontsize=55)
    ax.set_ylabel('J', fontsize=55, rotation=0, labelpad=16, va='center')

    # Adjust layout
    fig.tight_layout(pad=1.0)

    # Save the plot
    if output_dir:
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = output_filename

    canvas.print_figure(output_path, bbox_inches='tight', pad_inches=0.1, format='svg')
    # `fig` is built from matplotlib.figure.Figure directly rather than through
    # pyplot; this call is kept only for any figures pyplot itself is tracking.
    plt.show()

    return fig, ax

In [51]:
def create_jaccard_plot_new(file_path, counter_path, parameter_type='alpha', output_dir=None):
    """
    Plot the mean Jaccard index against delta alpha or delta K without binning.

    Every key of the Jaccard file is one x position; its summed Jaccard value
    is divided by the counter stored under the same key.

    Parameters:
    -----------
    file_path : str
        Path to the JSON file containing the Jaccard index data
    counter_path : str
        Path to the JSON file containing the counter data; it must contain
        every key of the Jaccard file
    parameter_type : str
        Either 'alpha' or 'K' to specify which parameter is being plotted
    output_dir : str or None
        Directory to save the plot. If None, saves in current directory

    Returns:
    --------
    fig, ax : tuple
        Matplotlib figure and axis objects

    Raises:
    -------
    ValueError
        If parameter_type is neither 'alpha' nor 'K'
    """
    if parameter_type not in ['alpha', 'K']:
        raise ValueError("parameter_type must be either 'alpha' or 'K'")

    if parameter_type == 'alpha':
        x_limits = (-2, 2)
        xlabel = r'$\Delta\alpha$'
        output_filename = 'jaccard_delta_alpha_new.svg'
        color_variable = ALPHA_COLOR
        x_tick_positions = [-2, -1, 0, 1, 2]
    else:
        x_limits = (-0.5, 0.5)
        xlabel = r'$\Delta K$'
        output_filename = 'jaccard_delta_k_new_log_space.svg'
        color_variable = K_COLOR
        x_tick_positions = [-0.5, -0.4, -0.3, -0.25, -0.1, 0, 0.1, 0.25, 0.3, 0.4, 0.5]

    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    with open(counter_path, 'r', encoding='utf-8') as file:
        counter_data = json.load(file)

    # Sort the keys numerically and look both files up by key, so values and
    # counters stay paired regardless of the order inside the JSON files.
    keys = sorted(data.keys(), key=float)
    x_values = np.array([float(key) for key in keys])
    y_values = np.array([data[key] for key in keys])
    counter = np.array([counter_data[key] for key in keys])

    # Summed Jaccard index -> mean Jaccard index per x position
    y_values = y_values / counter

    print(x_values)

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    ax.plot(x_values, y_values,
            color=color_variable,
            linewidth=2,
            alpha=1,
            marker='o',
            markersize=10,
            markerfacecolor=color_variable,
            markeredgecolor=color_variable,
            markeredgewidth=2)

    ax.set_xlim(x_limits)
    ax.set_ylim(0, 1.05)

    ax.set_xticks(x_tick_positions)
    ax.set_yticks([0, 0.5, 1])

    ax.set_axisbelow(True)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                    colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel(xlabel, fontsize=32)
    # Raw string: '\o' is not a valid escape sequence in a normal string literal.
    ax.set_ylabel(r'$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)

    fig.tight_layout()
    # Transparent figure and axes backgrounds in the exported SVG
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    if output_dir:
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = output_filename

    canvas.print_figure(output_path,
                        bbox_inches='tight',
                        pad_inches=0,
                        format='svg')

    return fig, ax

  ax.set_ylabel('$\overline{J}$', fontsize=32, rotation=0, va='center', labelpad=20)


In [None]:
# Inputs for the delta-alpha Jaccard plot: summed Jaccard values and their counters.
# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
file_path = '/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/jaccard_simulations_for_alpha_fixedL_singleCP_results/jaccard_new_tracks.json'
counter_path = '/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/jaccard_simulations_for_alpha_fixedL_singleCP_results/counter.json'

# The alpha plot is currently disabled; uncomment to regenerate it.
# fig, ax = create_jaccard_plot_new(file_path, counter_path, parameter_type='alpha')
# plt.show()

In [52]:
# Inputs for the delta-K Jaccard plot: summed Jaccard values and their counters.
# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
file_path = '/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/K_jaccard_single_CP_more_sampling_results/jaccard_more_sampling_new_tracks_0_5.json'
counter_path = '/home/haidiri/Desktop/AnDiChallenge2024/plots/results_for_plotting/K_jaccard_single_CP_more_sampling_results/counter_more_sampling_0_5.json'

# Writes 'jaccard_delta_k_new_log_space.svg' into the current working directory.
fig, ax = create_jaccard_plot_new(file_path, counter_path, parameter_type='K')

[-0.49 -0.46 -0.44 -0.41 -0.39 -0.36 -0.34 -0.31 -0.29 -0.26 -0.24 -0.21
 -0.19 -0.16 -0.14 -0.11 -0.09 -0.06 -0.04 -0.01  0.01  0.04  0.06  0.09
  0.11  0.14  0.16  0.19  0.21  0.24  0.26  0.29  0.31  0.34  0.36  0.39
  0.41  0.44  0.46  0.49  0.51]


# Example Tracks

In [3]:
def getPredictions(df, max_size=200):
    """
    Run the alpha, K and state models on one trajectory and post-process the output.

    Parameters:
    -----------
    df : DataFrame
        Must provide "x" and "y" columns; their values are passed to getFeatures.
    max_size : int
        Sequence length fed to the models. Shorter inputs are padded with
        FEATURE_PADDING_VALUE, longer inputs are truncated (a note is printed).

    Returns:
    --------
    (pred_alpha, pred_k, pred_states) as numpy arrays. Alpha and K are passed
    through smooth_series and median_filter_1d and clipped to [0, 2] and [0, 6];
    the states are the argmax over the last dimension of the state model output.
    """
    raw_features = getFeatures(df["x"].values, df["y"].values)
    cleaned = np.nan_to_num(raw_features, nan=0.0, posinf=0.0, neginf=0.0)
    batch = torch.tensor(cleaned, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    # NOTE(review): dimension 1 is treated as the time axis — confirm against getFeatures.
    n_steps = batch.size(1)

    if n_steps > max_size:
        batch = batch[:, :max_size]
        print(f"Note that the input series is longer than the maximum size. The input series has been truncated to the first {max_size} values.")
    elif n_steps < max_size:
        batch = F.pad(batch, (0, 0, 0, max_size - n_steps), value=FEATURE_PADDING_VALUE)

    def _as_series(tensor):
        # Flatten the model output and drop the positions that only hold padding.
        return tensor.cpu().numpy().flatten().squeeze()[:n_steps]

    with torch.no_grad():
        alpha_raw = _as_series(AlphaModel(batch))
        k_raw = _as_series(KModel(batch))
        state_raw = _as_series(torch.argmax(StateModel(batch), dim=-1))

    alpha_out = np.clip(median_filter_1d(smooth_series(np.array(alpha_raw))), 0, 2)
    k_out = np.clip(median_filter_1d(smooth_series(np.array(k_raw))), 0, 6)

    return alpha_out, k_out, np.array(state_raw)

def create_plots_for_track_examples(example_track, pred_a, pred_k, pred_state):
    """
    Show the x/y track and one prediction-vs-ground-truth panel each for alpha, K and state.

    Parameters:
    -----------
    example_track : DataFrame
        Provides the "x", "y", "alpha", "D" and "state" columns.
    pred_a, pred_k, pred_state : array-like
        Per-step predictions; they are trimmed, smoothed and clipped here before plotting.

    The K ground truth is plotted as log10(D + 1).
    """
    # Trajectory in the x/y plane.
    plt.figure(figsize=(10, 6))
    plt.plot(example_track["x"], example_track["y"], marker='o')
    plt.figure()

    # Position of the first padded label; an argmax() result of 0 is replaced by 200.
    n_valid = (example_track["alpha"] == LABEL_PADDING_VALUE).argmax()
    if n_valid == 0:
        n_valid = 200
    valid = slice(None, n_valid)

    # Post-process the predictions restricted to the unpadded part.
    alpha_est = np.clip(median_filter_1d(smooth_series(pred_a[valid])), 0, 2)
    k_est = np.clip(median_filter_1d(smooth_series(pred_k[valid])), 0, 6)
    state_est = replace_short_sequences(pred_state[valid])

    alpha_truth = example_track["alpha"][valid]
    k_truth = np.log10(example_track["D"][valid] + 1)
    state_truth = example_track["state"][valid]

    steps = np.arange(len(alpha_est))

    # (title, ground truth, prediction, marker colour, upper y-limit)
    panels = (
        ("Alpha", alpha_truth, alpha_est, ALPHA_COLOR, 2),
        ("K", k_truth, k_est, K_COLOR, 3),
        ("State", state_truth, state_est, STATE_COLOR, 4),
    )
    last_panel = len(panels) - 1
    for position, (title, truth, estimate, colour, y_top) in enumerate(panels):
        plt.figure(figsize=(10, 6))
        plt.plot(steps, truth, linewidth=2)
        plt.scatter(steps, estimate, color=colour, alpha=0.6, s=30)
        plt.title(title)
        plt.ylim(0, y_top)
        if position < last_panel:
            plt.figure()
    plt.show()

In [3]:
def plot_prediction(pred, true, label_type, save_path=None):
    """
    Plot per-step predictions as open circles over piecewise-constant ground truth.

    Parameters:
    -----------
    pred : array-like
        Predicted value for each time step.
    true : array-like
        Ground-truth value for each time step; every run of equal values is
        drawn as one horizontal black segment.
    label_type : str
        One of 'alpha', 'k' or 'state'; selects marker colour, y-range and y-label.
    save_path : str or None
        If given, the figure is written there as SVG.

    Returns:
    --------
    (fig, ax)

    Raises:
    -------
    ValueError
        If label_type is not one of the supported names.
    """
    # label_type -> (marker colour, y lower bound, y upper bound, y-axis label)
    styles = {
        "alpha": ('#E69F00', 0, 2, r'$α$'),   # orange
        "k": ('#1B9E77', 0, 3, r'$K$'),       # green
        "state": ('#9970AB', 0, 4, r'$s$'),   # purple
    }
    if label_type not in styles:
        raise ValueError(f"label_type must be one of {list(styles)}, got {label_type!r}")
    color, ymin, ymax, label = styles[label_type]

    # Positional indexing below must not depend on a pandas index.
    true = np.asarray(true)

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    time = np.arange(len(pred))
    # The marker edge follows the variable's own colour (open circles).
    ax.plot(time, pred, color=color, linestyle='none',
            marker='o', markersize=6, markerfacecolor='none',
            markeredgecolor=color, markeredgewidth=1.5,
            alpha=0.7)

    # Segment boundaries: 0, every index where the value changes, and len(true).
    boundaries = [0]
    boundaries += [i for i in range(1, len(true)) if true[i - 1] != true[i]]
    boundaries.append(len(true))

    last_index = len(true) - 1
    for start_idx, end_idx in zip(boundaries, boundaries[1:]):
        if start_idx > last_index:
            continue  # empty ground truth: nothing to draw
        # The final segment ends on the last sample rather than one past it.
        end_idx = min(end_idx, last_index)
        segment_value = true[start_idx]
        ax.plot([start_idx, end_idx], [segment_value, segment_value],
                color='black', linewidth=2, alpha=1)

    margin = (ymax - ymin) * 0.1  # 10% margin
    ax.set_ylim(ymin - margin, ymax + margin)
    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(label, fontsize=32)

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

Single State

In [36]:
def single_state(index=0):
    """
    Simulate a dataset with model index 1 and return one short trajectory.

    The rows of the first trajectory (in order of appearance) that has at most
    25 rows are returned; if there is none, the trajectory whose traj_idx
    equals `index` is returned instead.
    NOTE(review): model index 1 is presumably the single-state model — confirm
    against andi_datasets.
    """
    params = _get_dic_andi2(1)
    params['T'] = 200
    params['N'] = 100
    # params['alphas'], params['Ds'] = np.array([0.01, 0]), np.array([500, 0])

    simulated, _, _ = challenge_phenom_dataset(save_data=False,
                                               dics=[params],
                                               return_timestep_labs=True,
                                               get_video=False)

    df = simulated[0]
    traj_ids = df["traj_idx"]
    chosen = next(
        (tid for tid in traj_ids.unique() if (traj_ids == tid).sum() <= 25),
        index,
    )
    return df[traj_ids == chosen]

In [None]:
# Draw one single-state example, preview its x/y path, save it to CSV and
# plot the model predictions against the ground truth.
example_track = single_state() 
plt.plot(example_track["x"], example_track["y"])
plt.figure()
example_track.to_csv(f"short_trajectory_single_example.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

MultiState

In [39]:
def multi_state():
    """
    Simulate a dataset with model index 2 and return one short trajectory with two change points.

    A trajectory qualifies when its "D" column has exactly two interior change
    points (as reported by getCP_gt) and at most 40 rows. The first qualifying
    trajectory, in order of appearance, is returned.
    NOTE(review): model index 2 is presumably the multi-state model — confirm
    against andi_datasets. getCP_gt is defined in a later notebook cell and
    must already be in scope.

    Raises:
    -------
    RuntimeError
        If no simulated trajectory qualifies (previously this surfaced as an
        UnboundLocalError on `index`).
    """
    dic = _get_dic_andi2(2)
    dic['T'] = 200
    dic['N'] = 100
    dic['M'] = np.array([[0.5, 0.5],[0.5, 0.5]])
    dic['alphas'] = np.array([[1,0],[0.5,0]])
    dic["Ds"] = np.array([[1,0],[0.1,0]])

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)

    df = dfs_traj[0]

    for i in df["traj_idx"].unique():
        traj = df[df["traj_idx"] == i]
        d_values = traj["D"].values
        # [1:-1] drops the leading 0 and trailing len(array) that getCP_gt always adds.
        if len(d_values) <= 40 and len(getCP_gt(d_values)[1:-1]) == 2:
            return traj

    raise RuntimeError(
        "multi_state: no simulated trajectory has exactly 2 change points and "
        "at most 40 steps; re-run to draw a new dataset"
    )

In [None]:
# Draw one multi-state example, predict on it, save it to CSV and plot
# the predictions against the ground truth.
example_track = multi_state() 
pa, pk, ps = getPredictions(example_track)
example_track.to_csv(f"short_bad_trajectory_multi_example.csv", index=False)
create_plots_for_track_examples(example_track, pa, pk, ps)

Confined

In [21]:
def confined():
    """
    Simulate a dataset with model index 5 and return one trajectory with a single change point.

    'Nc' and 'r' are drawn at random on every call. The first trajectory (in
    ascending traj_idx order, as produced by groupby) whose "D" column has
    exactly one interior change point is returned.
    NOTE(review): model index 5 is presumably the confinement model — confirm
    against andi_datasets. getCP_gt is defined in a later notebook cell and
    must already be in scope.

    Raises:
    -------
    RuntimeError
        If no simulated trajectory qualifies (previously this surfaced as an
        UnboundLocalError on `index`).
    """
    dic = _get_dic_andi2(5)
    dic['T'] = 200
    dic['N'] = 100

    dic['Ds'] = np.array([[1.5, 0.1],[0.3,0.1]])
    dic['alphas'] = np.array([[0.9, 0.1],[0.5,0.1]])
    dic['trans'] = 0
    dic['Nc'] = random.randint(30,35)
    dic['r'] = random.randint(5, 10)

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)
    df = dfs_traj[0]

    for traj_id, group in df.groupby("traj_idx"):
        # [1:-1] drops the leading 0 and trailing len(array) that getCP_gt always adds.
        if len(getCP_gt(group["D"].values)[1:-1]) == 1:
            return df[df["traj_idx"] == traj_id]

    raise RuntimeError(
        "confined: no simulated trajectory has exactly 1 change point; "
        "re-run to draw a new dataset"
    )

In [None]:
# Draw one confined example, predict on it, save it to CSV and plot
# the predictions against the ground truth.
example_track = confined() 
pa, pk, ps = getPredictions(example_track)
example_track.to_csv(f"bad_trajectory_confined_example.csv", index=False)
create_plots_for_track_examples(example_track, pa, pk, ps)

Immobile

In [31]:
def immobile():
    """
    Simulate a dataset with model index 3 and return one trajectory with five change points.

    'Pu', 'r' and 'Nt' are drawn at random on every call. The first trajectory
    (in ascending traj_idx order, as produced by groupby) whose "D" column has
    exactly five interior change points is returned.
    NOTE(review): model index 3 is presumably the immobile-traps model —
    confirm against andi_datasets. getCP_gt is defined in a later notebook
    cell and must already be in scope.

    Raises:
    -------
    RuntimeError
        If no simulated trajectory qualifies (previously this surfaced as an
        UnboundLocalError on `index`).
    """
    dic = _get_dic_andi2(3)
    dic['T'] = 200
    dic['N'] = 100

    dic['Ds'], dic['alphas'] = np.array([1, 0]), np.array([1, 0])
    dic['Pu'] = random.uniform(0, 0.1)
    dic['Pb'] = 1
    dic['r'] = random.uniform(0.5, 2)
    dic['Nt'] = random.randint(100, 300)

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)
    df = dfs_traj[0]

    for traj_id, group in df.groupby("traj_idx"):
        # [1:-1] drops the leading 0 and trailing len(array) that getCP_gt always adds.
        if len(getCP_gt(group["D"].values)[1:-1]) == 5:
            return df[df["traj_idx"] == traj_id]

    raise RuntimeError(
        "immobile: no simulated trajectory has exactly 5 change points; "
        "re-run to draw a new dataset"
    )


In [None]:
# Draw one immobile example, save it to CSV, predict on it and plot
# the predictions against the ground truth.
# NOTE(review): the CSV name contains "confined" although the track comes from
# immobile() — looks like a copy-paste leftover; confirm before relying on it.
example_track = immobile() 
example_track.to_csv(f"bad_trajectory_confined_immobile.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

Dimerised

In [50]:
def dimerised(index=0):
    """
    Simulate a dataset with model index 4 and return the trajectory numbered `index`.

    Both rows of 'Ds' and of 'alphas' are set to [1, 0]; 'Pu' and 'r' are
    drawn at random on every call.
    NOTE(review): model index 4 is presumably the dimerisation model — confirm
    against andi_datasets.
    """
    params = _get_dic_andi2(4)
    params['T'] = 200
    params['N'] = 100
    # Two identical rows; separate float arrays for Ds and alphas.
    params['Ds'] = np.array([[1.0, 0.0], [1.0, 0.0]])
    params['alphas'] = np.array([[1.0, 0.0], [1.0, 0.0]])
    params['Pu'] = random.uniform(0, 0.1)
    params['Pb'] = 1
    params['r'] = random.uniform(0.5, 5)

    simulated, _, _ = challenge_phenom_dataset(save_data=False,
                                               dics=[params],
                                               return_timestep_labs=True,
                                               get_video=False)
    tracks = simulated[0]
    wanted = tracks["traj_idx"] == index
    return tracks[wanted]

In [None]:
# Draw one dimerised example, save it to CSV, predict on it and plot
# the predictions against the ground truth.
# NOTE(review): the CSV name contains "confined" although the track comes from
# dimerised() — looks like a copy-paste leftover; confirm before relying on it.
example_track = dimerised() 
example_track.to_csv(f"trajectory_confined_dimerised.csv", index=False)
pa, pk, ps = getPredictions(example_track)
create_plots_for_track_examples(example_track, pa, pk, ps)

# Plots from Training 

In [23]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def plot_tensorboard_metrics(train_csv, val_csv, save_path=None, loss_type='alpha'):
    """
    Draw training and validation loss curves from two TensorBoard CSV exports.

    Parameters:
    -----------
    train_csv : str
        CSV with the training loss; the 'Step' and 'Value' columns are used.
    val_csv : str
        CSV with the validation loss; the 'Step' and 'Value' columns are used.
    save_path : str or None
        If given, the figure is written there as SVG.
    loss_type : str
        'alpha' (case-insensitive) labels the y-axis L_alpha; any other value
        labels it L_K.

    Returns:
    --------
    (fig, ax)
    """
    train_log = pd.read_csv(train_csv)
    val_log = pd.read_csv(val_csv)
    train_x, train_y = train_log['Step'], train_log['Value']
    val_x, val_y = val_log['Step'], val_log['Value']

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Both curves are black; training is solid, validation dashed.
    for xs, ys, name, dash in ((train_x, train_y, 'Train', '-'),
                               (val_x, val_y, 'Val', '--')):
        ax.plot(xs, ys, color='black', linewidth=3, label=name, linestyle=dash)

    ax.set_axisbelow(True)
    last_step = max(max(train_x), max(val_x))
    ax.set_xlim(0, last_step)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for side in ax.spines.values():
        side.set_visible(True)
        side.set_linewidth(1)
        side.set_color('black')

    ax.set_xlabel('Epoch', fontsize=55)
    y_text = r'$L_{\alpha}$' if loss_type.lower() == 'alpha' else r'$L_{K}$'
    ax.set_ylabel(y_text, fontsize=55)

    ax.legend(fontsize=32, frameon=True, loc='upper right')

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    plt.show()

    return fig, ax
# Example usage:
# plot_tensorboard_metrics('train_loss.csv', 'val_loss.csv', 'training_plot.svg')
# plot_tensorboard_metrics("Losses_Training_K_Loss.csv", "Losses_Validation_K_Loss.csv", save_path="k_train_plots.svg")


In [None]:
# Render the training curves; only the alpha-loss call on the last line is active,
# the others are kept as disabled alternatives.
# plot_tensorboard_metrics("Losses_Training_K_Loss.csv", "Losses_Validation_K_Loss.csv", save_path="k_train_plots.svg")
# plot_tensorboard_metrics("Losses_Training_Alpha_Loss.csv", "Losses_Validation_Alpha_Loss.csv", save_path="alpha_train_plots.svg", loss_type="alpha")
# plot_tensorboard_metrics("Losses_Training_K_Loss.csv", "Losses_Validation_K_Loss.csv", save_path="k_train_plots.svg", loss_type="k")
plot_tensorboard_metrics("Losses_Training_Alpha_Loss.csv", "Losses_Validation_Alpha_Loss.csv", save_path="alpha_train_plots.svg", loss_type="alpha")

# "Y:\Rasched_2\thesis_figures\andi_figures_for_thesis\Losses_Validation Alpha Loss.csv"
# "Y:\Rasched_2\thesis_figures\andi_figures_for_thesis\Losses_Validation_K_Loss.csv"
# "Y:\Rasched_2\thesis_figures\andi_figures_for_thesis\Losses_Training_K_Loss.csv"
# "Y:\Rasched_2\thesis_figures\andi_figures_for_thesis\Losses_Training Alpha Loss.csv"

# Example track ruptures

In [22]:
# Re-create the three networks on DEVICE and load their trained weights.
# map_location remaps the stored tensors onto DEVICE, so checkpoints saved on a
# GPU also load when DEVICE falls back to the CPU.
AlphaModel = RegressionModel().to(DEVICE)
KModel = RegressionModel().to(DEVICE)
StateModel = ClassificationModel().to(DEVICE)

AlphaModel.load_state_dict(torch.load("models/optimal_weights/alpha_weights_with_fixed", map_location=DEVICE))
KModel.load_state_dict(torch.load("models/optimal_weights/k_weights", map_location=DEVICE))
StateModel.load_state_dict(torch.load("models/optimal_weights/state_weights", map_location=DEVICE))

# Inference only: switch all models to evaluation mode.
AlphaModel.eval()
KModel.eval()
StateModel.eval()

def getPredictions(df, max_size=200):
    """
    Run the alpha, K and state models on one trajectory and return the raw outputs.

    Parameters:
    -----------
    df : DataFrame
        Must provide "x" and "y" columns; their values are passed to getFeatures.
    max_size : int
        Sequence length fed to the models. Shorter inputs are padded with
        FEATURE_PADDING_VALUE, longer inputs are truncated (a note is printed).

    Returns:
    --------
    (pred_alpha, pred_k, pred_states) as flat numpy arrays without the padded
    positions. No smoothing or clipping is applied here; the states are the
    argmax over the last dimension of the state model output.
    """
    raw_features = getFeatures(df["x"].values, df["y"].values)
    cleaned = np.nan_to_num(raw_features, nan=0.0, posinf=0.0, neginf=0.0)
    batch = torch.tensor(cleaned, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    # NOTE(review): dimension 1 is treated as the time axis — confirm against getFeatures.
    n_steps = batch.size(1)

    if n_steps > max_size:
        batch = batch[:, :max_size]
        print(f"Note that the input series is longer than the maximum size. The input series has been truncated to the first {max_size} values.")
    elif n_steps < max_size:
        batch = F.pad(batch, (0, 0, 0, max_size - n_steps), value=FEATURE_PADDING_VALUE)

    def _as_series(tensor):
        # Flatten the model output and drop the positions that only hold padding.
        return tensor.cpu().numpy().flatten().squeeze()[:n_steps]

    with torch.no_grad():
        alpha_series = _as_series(AlphaModel(batch))
        k_series = _as_series(KModel(batch))
        state_series = _as_series(torch.argmax(StateModel(batch), dim=-1))

    return alpha_series, k_series, state_series

def multi_state(index=0):
    """
    Simulate a dataset with model index 2 and return one full-length two-level trajectory.

    The first trajectory (in order of appearance) that has exactly 200 rows
    and exactly two distinct "D" values is returned; if there is none, the
    trajectory whose traj_idx equals `index` is returned instead.
    NOTE(review): model index 2 is presumably the multi-state model — confirm
    against andi_datasets.
    """
    dic = _get_dic_andi2(2)
    dic['T'] = 200
    dic['N'] = 100
    dic['M'] = np.array([[0.7, 0.3],[0.4, 0.6]])
    dic['alphas'] = np.array([[1.5,0],[0.5,0]])
    dic["Ds"] = np.array([[0.6,0],[1,0]])

    dfs_traj, _, _ = challenge_phenom_dataset(save_data=False,
                                        dics=[dic],
                                        return_timestep_labs=True,
                                        get_video=False)

    df = dfs_traj[0]

    for i in df["traj_idx"].unique():
        traj = df[df["traj_idx"] == i]
        # Count the distinct D levels of THIS trajectory. The previous version
        # indexed df with the traj_idx column and took len() of a boolean array.
        if len(traj) == 200 and len(np.unique(traj["D"].values)) == 2:
            index = i
            break

    return df[df["traj_idx"] == index]


def getCP_gt(array):
    """
    Return the change points of a sequence, framed by 0 and len(array).

    A change point is every index i >= 1 whose element differs from the one
    before it. The result always starts with 0 and ends with len(array), so an
    empty input gives [0, 0].
    """
    n = len(array)
    jumps = [pos for pos in range(1, n) if array[pos - 1] != array[pos]]
    return [0, *jumps, n]

def getCP_rpt(array, lower_limit=0, upper_limit=float("inf"), threshold=0.05):
    """
    Detect change points in a predicted series with ruptures' Pelt search.

    Parameters:
    -----------
    array : array-like
        Series to segment; it is first passed through smooth_series (with the
        given limits) and median_filter_1d.
    lower_limit, upper_limit : float
        Forwarded to smooth_series.
    threshold : float
        A detected change point is dropped when the means of the smoothed
        series on its two neighbouring segments differ by less than this.

    Returns:
    --------
    (cps, smoothed) : the retained change points, starting with 0, and the
    smoothed series the detection ran on.
    """
    smoothed = median_filter_1d(smooth_series(array, lower_limit=lower_limit, upper_limit=upper_limit))

    # Min-max scale for the detector; a constant series is mapped to 0.5 throughout.
    lowest, highest = np.min(smoothed), np.max(smoothed)
    if highest == lowest:
        scaled = np.ones(len(smoothed)) * 0.5
    else:
        scaled = (smoothed - lowest) / (highest - lowest)

    detector = rpt.Pelt(model="l2", min_size=3, jump=1).fit(scaled)
    boundaries = [0] + detector.predict(pen=0.3)

    # Neighbouring segments are taken from the unfiltered boundary list.
    weak = set()
    for before, cp, after in zip(boundaries, boundaries[1:], boundaries[2:]):
        mean_gap = abs(smoothed[before:cp].mean() - smoothed[cp:after].mean())
        if mean_gap < threshold:
            weak.add(cp)

    kept = [cp for cp in boundaries if cp not in weak]
    return kept, smoothed


def plot_time_series_prediction(time, pred, truth, variable_type, save_path=None):
    """
    Plot predictions (markers) and ground truth (solid line) over time.

    Parameters:
    -----------
    time : array-like
        Time points.
    pred : array-like
        Prediction values.
    truth : array-like
        Ground truth values.
    variable_type : str
        'alpha' or 'K' (case-insensitive); any other value is styled as state.
    save_path : str, optional
        If given, the figure is written there as SVG.

    Returns:
    --------
    (fig, ax)
    """
    kind = variable_type.lower()
    # Orange for alpha, green for K, purple for anything else (state).
    colour = {'alpha': '#E69F00', 'k': '#1B9E77'}.get(kind, '#9970AB')

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    ax.plot(time, pred, color=colour, linestyle='none',
            marker='o', markersize=6, alpha=0.7,
            label='Prediction')
    ax.plot(time, truth, color=colour, linestyle='-',
            linewidth=2, alpha=1,
            label='Ground Truth')

    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for side in ax.spines.values():
        side.set_visible(True)
        side.set_linewidth(1)
        side.set_color('black')

    ax.set_xlabel('Time', fontsize=55)
    y_text = {'alpha': r'$\alpha$', 'k': r'$K$'}.get(variable_type.lower(), 'State')
    ax.set_ylabel(y_text, fontsize=55)

    ax.legend(fontsize=32, frameon=True, loc='best')

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    plt.show()

    return fig, ax

In [None]:
def plot_k_prediction_with_ruptures(df, save_path=None):
    """
    Plot K predictions against the "D" ground truth and mark the ground-truth change points.

    Parameters:
    -----------
    df : DataFrame
        One trajectory; passed to getPredictions, and its "D" column is used
        as ground truth.
    save_path : str or None
        If given, the pyplot figure is saved there as SVG.

    Returns:
    --------
    (fig, ax)

    NOTE(review): the ground truth plotted here is the raw "D" column, whereas
    create_plots_for_track_examples compares pred_k against log10(D + 1) —
    confirm which scale is intended.
    """
    _, pred_k, _ = getPredictions(df)

    # getPredictions truncates its output to max_size, so align prediction,
    # ground truth and time axis on the common length to avoid a shape mismatch.
    n_steps = min(len(pred_k), len(df))
    pred_k = pred_k[:n_steps]
    ground_truth_k = df["D"].values[:n_steps]
    time = np.arange(n_steps)

    gt_cps = getCP_gt(ground_truth_k)

    plt.figure(figsize=(8, 6), dpi=300)
    ax = plt.gca()

    color = '#1B9E77'  # K colour (green)

    ax.plot(time, pred_k, color=color, linestyle='none',
            marker='o', markersize=6, alpha=0.7,
            label='Prediction')
    ax.plot(time, ground_truth_k, color=color, linestyle='-',
            linewidth=2, alpha=1,
            label='Ground Truth')

    # Ground-truth change points (includes the 0 and end markers from getCP_gt).
    for cp in gt_cps:
        ax.axvline(x=cp, color='black', linestyle='--',
                   linewidth=2, alpha=0.5)

    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Time', fontsize=55)
    ax.set_ylabel(r'$K$', fontsize=55)

    ax.legend(fontsize=32, frameon=True, loc='best')

    plt.tight_layout(pad=1.0)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight',
                    pad_inches=0.1, format='svg')

    fig = plt.gcf()
    return fig, ax

def plot_k_prediction_with_ruptures_new(df, save_path=None):
    """
    SVG-canvas variant of the K prediction plot with ground-truth change points.

    Parameters:
    -----------
    df : DataFrame
        One trajectory; passed to getPredictions, and its "D" column is used
        as ground truth.
    save_path : str or None
        If given, the figure is written there as SVG.

    Returns:
    --------
    (fig, ax) with transparent figure and axes backgrounds.

    NOTE(review): the ground truth plotted here is the raw "D" column, whereas
    create_plots_for_track_examples compares pred_k against log10(D + 1) —
    confirm which scale is intended.
    """
    _, pred_k, _ = getPredictions(df)

    # getPredictions truncates its output to max_size, so align prediction,
    # ground truth and time axis on the common length to avoid a shape mismatch.
    n_steps = min(len(pred_k), len(df))
    pred_k = pred_k[:n_steps]
    ground_truth_k = df["D"].values[:n_steps]
    time = np.arange(n_steps)

    gt_cps = getCP_gt(ground_truth_k)

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Predictions as filled markers.
    ax.plot(time, pred_k, color=K_COLOR, linestyle='none',
            marker='o', markersize=8,
            markerfacecolor=K_COLOR,
            markeredgecolor=K_COLOR,
            markeredgewidth=2,
            alpha=1,
            label='Prediction')

    ax.plot(time, ground_truth_k, color=K_COLOR, linestyle='-',
            linewidth=2, alpha=1,
            label='Ground Truth')

    # Ground-truth change points (includes the 0 and end markers from getCP_gt).
    for cp in gt_cps:
        ax.axvline(x=cp, color='black', linestyle='--',
                   linewidth=2, alpha=0.5)

    ax.set_xticks([0, 200])
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(r'$K$', fontsize=32, rotation=0, va='center', labelpad=20)

    ax.legend(fontsize=32, frameon=False, loc='best')

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax



# Run the function
# df = multi_state()
# data = np.load("plots/true_k_ruptures_example.npy")
# print(len(data))
# fig, ax = plot_k_prediction_with_ruptures_new(df, save_path='k_prediction_ruptures_new.svg')

# Run the function
# df = multi_state()
# fig, ax = plot_k_prediction_with_ruptures(df, save_path='k_prediction_ruptures.svg')
# print(f"{fig}, {ax}")

In [None]:
def plot_k_prediction_with_ruptures_from_data(true_k, save_path=None):
    """
    Plot a smoothed K series together with the change points found by getCP_rpt.

    Parameters:
    -----------
    true_k : array-like
        Series handed to getCP_rpt; the smoothed series it returns is what
        gets plotted (labelled 'Predictions').
    save_path : str or None
        If given, the figure is written there as SVG.

    Returns:
    --------
    (fig, ax) with transparent figure and axes backgrounds.

    The legend also lists 'Ground Truth' and 'Predicted CPs' through empty
    proxy lines; no ground-truth curve is drawn by this function.
    """
    steps = np.arange(len(true_k))
    detected_cps, series = getCP_rpt(true_k)

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    marker_style = dict(marker='o', markersize=8,
                        markerfacecolor=K_COLOR,
                        markeredgecolor=K_COLOR,
                        markeredgewidth=2)
    ax.plot(steps, series, color=K_COLOR, linestyle='none',
            alpha=1, label='Predictions', **marker_style)

    for position in detected_cps:
        ax.axvline(x=position, color='black', linestyle='--',
                   linewidth=2, alpha=1)

    ax.set_xticks([0, 200])
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    for side in ax.spines.values():
        side.set_visible(True)
        side.set_linewidth(1)
        side.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(r'$K$', fontsize=32, rotation=0, va='center', labelpad=20)

    # Empty lines that exist only to create legend entries.
    for dash, text in (('-', 'Ground Truth'), ('--', 'Predicted CPs')):
        ax.plot([], [], color='black', linestyle=dash,
                linewidth=2, alpha=1,
                label=text)

    ax.legend(fontsize=28, frameon=False, loc='best')

    fig.tight_layout(pad=1.0)
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax


# Load a stored K series from disk and render the change-point figure to SVG.
data = np.load("plots/true_k_ruptures_example.npy")
plot_k_prediction_with_ruptures_from_data(data, save_path="ruptures_example_from_data.svg")

In [None]:
# # For alpha variable
# plot_time_series_prediction(time_points, alpha_pred, alpha_truth, 'alpha', 
#                           save_path='alpha_ruptures.svg')

# # For K variable
# plot_time_series_prediction(time_points, k_pred, k_truth, 'K',
#                           save_path='k_ruptures.svg')

# # For state variable
# plot_time_series_prediction(time_points, state_pred, state_truth, 'state',
#                           save_path='state_ruptures.svg')

# Example tracks

In [55]:
def generate_alpha_data():
    """
    Build a synthetic alpha series with one change point plus noisy "predictions".

    The ground truth has 200 points: 1.5 for t < 100 and 0.5 from t = 100 on.
    The predictions add Gaussian noise (standard deviation 0.01) drawn after
    seeding NumPy's global random generator with 42, so the output is
    reproducible and the global random state is reset as a side effect.

    Returns:
    --------
    (pred_alpha, true_alpha) : two float arrays of length 200.
    """
    n_points, changepoint = 200, 100

    true_alpha = np.where(np.arange(n_points) < changepoint, 1.5, 0.5)

    np.random.seed(42)  # For reproducibility
    noise = np.random.normal(0, 0.01, size=n_points)

    return true_alpha + noise, true_alpha

# Generate the data
pred_alpha, true_alpha = generate_alpha_data()

# Create the plot
# NOTE(review): plot_track_variable is not defined in the cells visible here —
# it must already exist in the notebook session for this call to work.
fig, ax = plot_track_variable(pred_alpha, true_alpha, label_type="alpha", 
                            save_path="synthetic_alpha.svg")

In [4]:
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import numpy as np

def plot_colored_trajectory(df, save_path=None):
    """
    Plot an x/y trajectory with every segment coloured by its destination point's state.

    A segment is blue (#362cfc) when the point it leads to has state 2 and
    red (#ff0000) for every other state value.

    Parameters:
    -----------
    df : DataFrame
        Provides the 'x', 'y' and 'state' columns.
    save_path : str or None
        If given, the figure is written there as SVG.

    Returns:
    --------
    (fig, ax)
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Pull the columns out once instead of doing .iloc lookups per segment.
    xs = df['x'].to_numpy()
    ys = df['y'].to_numpy()
    ends_in_state_2 = df['state'].to_numpy()[1:] == 2

    for start, is_blue in enumerate(ends_in_state_2):
        ax.plot(xs[start:start + 2], ys[start:start + 2],
                color='#362cfc' if is_blue else '#ff0000',
                linewidth=1)

    ax.set_xlabel('X', fontsize=32)
    ax.set_ylabel('Y', fontsize=32)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_axisbelow(True)

    # No legend: none of the plotted segments carries a label, so a legend call
    # would only produce an empty legend and a "no labelled artists" warning.

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

# Usage:
# df = pd.read_csv(file)
# fig, ax = plot_colored_trajectory(df, save_path='trajectory_colored.svg')

In [5]:
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG
import numpy as np

def plot_colored_trajectory(df, save_path=None, multi=False):
    """
    Plot an x/y trajectory with segments coloured by state or by alpha changes.

    Parameters:
    -----------
    df : DataFrame
        Provides 'x' and 'y', plus 'state' (when multi is False) or 'alpha'
        (when multi is True).
    save_path : str or None
        If given, the figure is written there as SVG.
    multi : bool
        False (default): a segment is blue (#362cfc) when the point it leads
        to has state 2, red (#ff0000) otherwise.
        True: colouring starts blue and switches between blue and red every
        time the alpha value changes from one point to the next.

    Returns:
    --------
    (fig, ax)
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Pull the columns out once instead of doing .iloc lookups per segment.
    xs = df['x'].to_numpy()
    ys = df['y'].to_numpy()

    if multi:
        # Only 'alpha' is needed in this mode. A segment is blue while the
        # number of alpha changes seen so far (including its own end point) is even.
        alphas = df['alpha'].to_numpy()
        changed = alphas[1:] != alphas[:-1]
        segment_is_blue = np.cumsum(changed) % 2 == 0
    else:
        segment_is_blue = df['state'].to_numpy()[1:] == 2

    for start, is_blue in enumerate(segment_is_blue):
        ax.plot(xs[start:start + 2], ys[start:start + 2],
                color='#362cfc' if is_blue else '#ff0000',
                linewidth=1)

    ax.set_xlabel('X', fontsize=32)
    ax.set_ylabel('Y', fontsize=32)

    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_axisbelow(True)

    # No legend: none of the plotted segments carries a label.

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

# Usage:
# df = pd.read_csv(file)
# Normal state-based coloring:
# fig, ax = plot_colored_trajectory(df, save_path='trajectory_state_colored.svg')
# Alpha-change based coloring:
# fig, ax = plot_colored_trajectory(df, save_path='trajectory_alpha_colored.svg', multi=True)

In [None]:
def plot_track_variable_old(pred, true, label_type, save_path=None):
    """
    Draw every predicted value as a dot on top of the piecewise-constant
    ground truth (legacy variant of ``plot_track_variable``).

    Args:
        pred: Sequence of predicted values; plotted against its own indices.
        true: Indexable sequence of ground-truth values; each run of equal
            consecutive values is drawn as one horizontal black line.
        label_type: One of ``"alpha"``, ``"k"`` or ``"state"``; selects the
            marker color, y-limits, y-ticks and y-axis label.
        save_path: If truthy, the figure is written there as SVG.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.

    Raises:
        ValueError: If ``label_type`` is not a supported key.
    """
    # Per-variable styling: color, fixed y-range, axis label and y-ticks.
    styles = {
        "alpha": dict(color=ALPHA_COLOR, ymin=0, ymax=2,
                      label=r'$α$', yticks=[0, 2]),
        "k": dict(color=K_COLOR, ymin=0, ymax=1,
                  label=r'$K$', yticks=[0, 1]),
        "state": dict(color=STATE_COLOR, ymin=0, ymax=3,
                      label=r'$s$', yticks=[0, 1, 2, 3]),
    }

    style = styles.get(label_type)
    if style is None:
        raise ValueError(f"label_type must be one of {list(styles)}")

    tint = style["color"]

    # Stand-alone figure bound to an SVG canvas (no pyplot state involved).
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Predictions: one filled marker per time step, no connecting line.
    marker_style = dict(
        color=tint,
        linestyle='none',
        marker='o',
        markersize=6,
        markerfacecolor=tint,
        markeredgecolor=tint,
        markeredgewidth=3,
        alpha=1,
    )
    ax.plot(np.arange(len(pred)), pred, **marker_style)

    # Ground truth: boundaries are 0, every index whose value differs from
    # its predecessor, and len(true).
    n_true = len(true)
    edges = [0]
    edges.extend(i for i in range(1, n_true) if true[i - 1] != true[i])
    edges.append(n_true)

    for left, right in zip(edges[:-1], edges[1:]):
        # The closing boundary is one past the data; pull it back onto the
        # last sample index so the final plateau ends on a real point.
        if right == n_true:
            right = right - 1
        level = true[left]
        ax.plot([left, right], [level, level],
                color='black', linewidth=3, alpha=1)

    # Fixed y-range (no margin) with the variable-specific ticks.
    ax.set_ylim(style["ymin"], style["ymax"])
    ax.set_axisbelow(True)
    ax.set_yticks(style["yticks"])

    # Tick appearance first, then the tick-label size in a separate call.
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    # Thin black frame on all four sides.
    for border in ax.spines.values():
        border.set_visible(True)
        border.set_linewidth(1)
        border.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(style["label"], fontsize=32)
    # X ticks are pinned to 0 and 200 regardless of the input length.
    ax.set_xticks([0, 200])

    fig.tight_layout(pad=1.0)

    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

def plot_track_variable(pred, true, label_type, save_path=None):
    """
    Plot sparse, enlarged prediction markers over the piecewise-constant
    ground truth of one track variable.

    For every run of equal consecutive values in ``true`` a thick black
    horizontal line is drawn, and ``pred`` is sampled (by linear
    interpolation over its indices) at a few positions inside that run.

    Args:
        pred: Sequence of predicted values, indexed by time step.
        true: Indexable sequence of ground-truth values.
        label_type: One of ``"alpha"``, ``"k"`` or ``"state"``; selects the
            marker color, y-range, y-ticks and y-axis label.
        save_path: If truthy, the figure is written there as SVG.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.

    Raises:
        ValueError: If ``label_type`` is not a supported key.
    """
    # Per-variable styling: color, y-range, axis label and y-ticks.
    styles = {
        "alpha": dict(color=ALPHA_COLOR, ymin=0, ymax=2,
                      label=r'$α$', yticks=[0, 2]),
        "k": dict(color=K_COLOR, ymin=0, ymax=1,
                  label=r'$D$', yticks=[0, 1]),
        "state": dict(color=STATE_COLOR, ymin=0, ymax=3,
                      label=r'$s$', yticks=[0, 1, 2, 3]),
    }

    style = styles.get(label_type)
    if style is None:
        raise ValueError(f"label_type must be one of {list(styles)}")

    tint = style["color"]

    # Stand-alone figure bound to an SVG canvas (no pyplot state involved).
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    # Segment boundaries: 0, every index whose value differs from its
    # predecessor, and len(true).
    n_true = len(true)
    edges = [0]
    edges.extend(i for i in range(1, n_true) if true[i - 1] != true[i])
    edges.append(n_true)

    for left, right in zip(edges[:-1], edges[1:]):
        # The closing boundary is one past the data; pull it back onto the
        # last sample index.
        if right == n_true:
            right = right - 1
        level = true[left]

        # Choose where to place prediction markers inside this segment.
        if right - left > 6:
            # Longer segments: one marker every 12 steps, plus the segment end.
            sample_at = np.arange(left, right, 12)
            if right not in sample_at:
                sample_at = np.append(sample_at, right)
        else:
            # Short segments: three evenly spaced positions, ends included.
            sample_at = np.linspace(left, right, 3)

        # Read the prediction at those (possibly fractional) positions.
        sampled = np.interp(sample_at, np.arange(len(pred)), pred)

        # Markers are drawn first, then the ground-truth line for the segment.
        ax.plot(sample_at, sampled, color=tint, linestyle='none',
                marker='o', markersize=30, markerfacecolor=tint,
                markeredgecolor=tint, markeredgewidth=3,
                alpha=1)

        ax.plot([left, right], [level, level],
                color='black', linewidth=6, alpha=1)

    # Pad the y-range by 5% of its span on both sides.
    pad_y = (style["ymax"] - style["ymin"]) * 0.05
    ax.set_ylim(style["ymin"] - pad_y, style["ymax"] + pad_y)
    ax.set_axisbelow(True)
    ax.set_yticks(style["yticks"])
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2, labelsize=32)

    # Thin black frame on all four sides.
    for border in ax.spines.values():
        border.set_visible(True)
        border.set_linewidth(1)
        border.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(style["label"], fontsize=32)
    # X ticks only at the start and at the length of the prediction series.
    ax.set_xticks([0, len(pred)])

    fig.tight_layout(pad=1.0)
    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor('none')
    ax.set_facecolor('none')

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

# Driver cell: render the example trajectory and the three per-step variable
# plots (alpha, k, state) for one CSV file, saving each figure as SVG.
file = "plots/short_trajectory_single_example.csv"
# Output directory: the CSV path with ".csv" replaced by "_bigger".
save_path = file.replace(".csv", "_bigger")
os.makedirs(save_path, exist_ok=True)
df = pd.read_csv(file)

# Model predictions for the trajectory; unpacked as (alpha, k, state) series.
pred_alpha_list, pred_k_list, pred_states_list = getPredictions(df, max_size=200)

# NOTE(review): the CSV is re-read here, presumably so the plots below use an
# unmodified frame in case getPredictions altered `df` -- confirm.
df = pd.read_csv(file)
# multi=True: alpha-change based coloring, per the usage notes for
# plot_colored_trajectory.
fig, ax = plot_colored_trajectory(df, save_path=os.path.join(save_path,'trajectory.svg'), multi=True)

plot_track_variable(pred_alpha_list, df["alpha"].values, label_type="alpha", save_path=os.path.join(save_path,'alpha.svg'))
# Ground truth for the "k" panel is log10(D + 1).
# NOTE(review): presumably this matches the scale of the K model output -- confirm.
gt_k = np.log10(df["D"].values + 1)
plot_track_variable(pred_k_list, gt_k, label_type="k", save_path=os.path.join(save_path,'k.svg'))
plot_track_variable(pred_states_list, df["state"].values, label_type="state", save_path=os.path.join(save_path,'state.svg'))