This code takes two types of input data:  
1) The feature data obtained from `data/autoencoder` 
 2) The segment indices from `data/audio_cut`  
To count the number of original audio files in each class, the `count_audio_files` cell additionally requires the paths to the original audio directories as input.

The final output consists of two CSV files:  
1) The first file stores the evaluation metrics obtained on the test set for each seed  
2) The second file contains the predicted labels for each audio file in the test set for each seed  

Note: The indices from `data/audio_cut` are used to map the segment-level predictions back to their corresponding original audio files for final evaluation and reporting.


In [None]:
import os
# Restrict the process to the GPU with index 1 and set the XLA client memory
# fraction to 0.95. These are set in the first cell, before `jax` is imported
# in the following cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [None]:
import math
import time
from functools import partial

import diffrax
import equinox as eqx 
import IPython
import jax 
import jax.nn as jnn
import jax.numpy as jnp  
import jax.random as jr 
import jax.scipy as jsp
import librosa
import matplotlib
import matplotlib.pyplot as plt  
import numpy as np
import optax  
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
from jax import nn as jnn
from sklearn.metrics import f1_score, precision_score, recall_score
from torchaudio.utils import download_asset

In [None]:
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr


def binary_operator_diag(element_i, element_j):
    """Combine two ``(a, b)`` pairs of the diagonal linear recurrence.

    Each element stands for the affine map ``h -> a * h + b``. Applying
    ``element_i`` first and ``element_j`` second yields the pair returned
    here, which is the operator handed to ``jax.lax.associative_scan``.
    """
    decay_first, input_first = element_i
    decay_second, input_second = element_j
    combined_decay = decay_second * decay_first
    combined_input = decay_second * input_first + input_second
    return combined_decay, combined_input


class GLU(eqx.Module):
    """Gated linear unit: a linear branch scaled by a sigmoid-activated gate."""

    # Value branch.
    w1: eqx.nn.Linear
    # Gate branch (passed through a sigmoid).
    w2: eqx.nn.Linear

    def __init__(self, input_dim, output_dim, key):
        """Create both linear branches from independent halves of ``key``."""
        value_key, gate_key = jr.split(key, 2)
        self.w1 = eqx.nn.Linear(input_dim, output_dim, use_bias=True, key=value_key)
        self.w2 = eqx.nn.Linear(input_dim, output_dim, use_bias=True, key=gate_key)

    def __call__(self, x):
        """Return ``w1(x) * sigmoid(w2(x))`` for a single input vector."""
        value = self.w1(x)
        gate = jax.nn.sigmoid(self.w2(x))
        return value * gate


class LRULayer(eqx.Module):
    """Linear recurrent layer with a complex diagonal state matrix.

    The recurrence ``h_t = Lambda * h_{t-1} + B_norm @ u_t`` is evaluated for
    the whole sequence with an associative scan, and the output is
    ``Re(C @ h_t) + D * u_t``.
    """

    # log(-log|lambda|): parametrises the magnitude of the diagonal entries.
    nu_log: jnp.ndarray
    # log(phase) of the diagonal entries.
    theta_log: jnp.ndarray
    # Real / imaginary parts of the (N, H) input projection.
    B_re: jnp.ndarray
    B_im: jnp.ndarray
    # Real / imaginary parts of the (H, N) output projection.
    C_re: jnp.ndarray
    C_im: jnp.ndarray
    # Per-channel skip weight, shape (H,).
    D: jnp.ndarray
    # log of the input normalisation factor sqrt(1 - |lambda|^2).
    gamma_log: jnp.ndarray

    def __init__(self, N, H, r_min=0, r_max=1, max_phase=6.28, *, key):
        """Initialise the layer.

        Args:
            N: state dimension.
            H: model (input/output) dimension.
            r_min, r_max: bounds of the ring on which |lambda| is sampled.
            max_phase: upper bound of the sampled phase.
            key: PRNG key, split into seven sub-keys.
        """
        (
            radius_key,
            phase_key,
            b_re_key,
            b_im_key,
            c_re_key,
            c_im_key,
            d_key,
        ) = jr.split(key, 7)

        # Lambda is drawn uniformly on the ring between r_min and r_max,
        # with phase in [0, max_phase].
        radius_sample = jr.uniform(radius_key, shape=(N,))
        phase_sample = jr.uniform(phase_key, shape=(N,))
        squared_radius = radius_sample * (r_max**2 - r_min**2) + r_min**2
        self.nu_log = jnp.log(-0.5 * jnp.log(squared_radius))
        self.theta_log = jnp.log(max_phase * phase_sample)

        # Scaled-normal input/output projections.
        input_scale = jnp.sqrt(2 * H)
        output_scale = jnp.sqrt(N)
        self.B_re = jr.normal(b_re_key, shape=(N, H)) / input_scale
        self.B_im = jr.normal(b_im_key, shape=(N, H)) / input_scale
        self.C_re = jr.normal(c_re_key, shape=(H, N)) / output_scale
        self.C_im = jr.normal(c_im_key, shape=(H, N)) / output_scale
        self.D = jr.normal(d_key, shape=(H,))

        # Normalisation factor derived from the freshly sampled Lambda.
        initial_lambda = jnp.exp(
            -jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log)
        )
        self.gamma_log = jnp.log(jnp.sqrt(1 - jnp.abs(initial_lambda) ** 2))

    def __call__(self, x):
        """Run the recurrence over ``x`` (leading axis = sequence) and project out."""
        # Materialise the diagonal of Lambda and the complex projections.
        diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log))
        gamma = jnp.exp(self.gamma_log)
        input_proj = (self.B_re + 1j * self.B_im) * jnp.expand_dims(gamma, axis=-1)
        output_proj = self.C_re + 1j * self.C_im

        # One copy of Lambda and one projected input per sequence position.
        num_positions = x.shape[0]
        lambda_per_step = jnp.repeat(diag_lambda[None, ...], num_positions, axis=0)
        projected_inputs = jax.vmap(lambda u: input_proj @ u)(x)

        # The scan returns (cumulative decay, hidden states); only states are used.
        _, hidden_states = jax.lax.associative_scan(
            binary_operator_diag, (lambda_per_step, projected_inputs)
        )

        def read_out(state, u):
            return (output_proj @ state).real + (self.D * u)

        return jax.vmap(read_out)(hidden_states, x)


class LRUBlock(eqx.Module):
    """Residual block: LRU -> GELU -> dropout -> GLU -> dropout, plus skip."""

    lru: LRULayer
    glu: GLU
    drop: eqx.nn.Dropout

    def __init__(self, N, H, r_min=0, r_max=1, max_phase=6.28, drop_rate=0.1, *, key):
        """Build the recurrent layer and the GLU from two halves of ``key``."""
        recurrent_key, gate_key = jr.split(key, 2)
        self.lru = LRULayer(N, H, r_min, r_max, max_phase, key=recurrent_key)
        self.glu = GLU(H, H, key=gate_key)
        self.drop = eqx.nn.Dropout(p=drop_rate)

    def __call__(self, x, *, key):
        """Apply the block to one sequence; ``key`` drives both dropout calls."""
        first_drop_key, second_drop_key = jr.split(key, 2)
        residual = x
        hidden = jax.nn.gelu(self.lru(x))
        hidden = self.drop(hidden, key=first_drop_key)
        # The GLU acts on a single vector, so map it over the sequence axis.
        hidden = jax.vmap(self.glu)(hidden)
        hidden = self.drop(hidden, key=second_drop_key)
        return residual + hidden


class LRU(eqx.Module):
    """Sequence classifier: linear encoder, a stack of LRU blocks, mean pooling
    over the sequence axis and a single sigmoid output."""

    linear_encoder: eqx.nn.Linear
    blocks: List[LRUBlock]
    linear_layer: eqx.nn.Linear
    classification: bool
    output_step: int
    stateful: bool = True
    nondeterministic: bool = True
    lip2: bool = False

    def __init__(
        self,
        num_blocks,
        data_dim,
        N,
        H,
        output_dim,
        classification,
        output_step,
        r_min=0,
        r_max=1,
        max_phase=6.28,
        drop_rate=0.1,
        *,
        key
    ):
        """Build the model.

        Args:
            num_blocks: number of stacked ``LRUBlock`` instances.
            data_dim: number of input channels fed to the encoder.
            N: state dimension of every LRU layer.
            H: hidden (model) dimension.
            output_dim: accepted for interface compatibility.
                NOTE(review): the output head is created with a fixed size of 1
                and this argument is not used.
            classification, output_step: stored on the module; not read by
                ``__call__``.
            r_min, r_max, max_phase, drop_rate: forwarded to every block.
            key: PRNG key, split into ``num_blocks + 2`` sub-keys
                (encoder, one per block, output head).
        """
        all_keys = jr.split(key, num_blocks + 2)
        linear_encoder_key, *block_keys, linear_layer_key = all_keys

        self.linear_encoder = eqx.nn.Linear(data_dim, H, key=linear_encoder_key)

        stacked_blocks = []
        for block_key in block_keys:
            stacked_blocks.append(
                LRUBlock(N, H, r_min, r_max, max_phase, drop_rate, key=block_key)
            )
        self.blocks = stacked_blocks

        self.linear_layer = eqx.nn.Linear(H, 1, key=linear_layer_key)
        self.classification = classification
        self.output_step = output_step

    def __call__(self, x,key,*,inference=False):
        """Return the sigmoid score for one sequence.

        Args:
            x: 2-D array whose first column is dropped before encoding (the
                data pipeline in this file places the time channel there).
            key: PRNG key, split into one dropout key per block.
            inference: accepted but not read inside this method.
        """
        per_block_keys = jr.split(key, len(self.blocks))
        features = x[:, 1:]
        hidden = jax.vmap(self.linear_encoder)(features)
        for layer, layer_key in zip(self.blocks, per_block_keys):
            hidden = layer(hidden, key=layer_key)
        pooled = jnp.mean(hidden, axis=0)
        logit = self.linear_layer(pooled)
        # The head has exactly one output; unpack it to a scalar.
        (probability,) = jnn.sigmoid(logit)
        return probability

In [None]:
# Load the precomputed inputs from disk (the paths are placeholders to be edited):
# autoencoder features for the train and test splits ...
features1 = torch.load('/path/to/your/train_features_auto.pt')
features2 = torch.load('/path/to/your/test_features_auto.pt')

# ... the segment indices (column 0 is later used as the id of the source
# audio file each segment belongs to) ...
indices_train = torch.load('/path/to/your/train_indices.pt')
indices_test = torch.load('/path/to/your/test_indices.pt')
# ... and the labels that are paired with the features in the dataloaders.
labels1=torch.load('/path/to/your/train_labels.pt')
labels2=torch.load('/path/to/your/test_labels.pt')

In [None]:
# Convert the loaded objects to JAX arrays.
# NOTE(review): the labels are detached and moved to host NumPy first, while the
# features are handed to `jnp.array` as loaded — confirm that the saved feature
# objects can be converted directly (e.g. CPU tensors or NumPy arrays).
features_np = features1
features_np_test =features2
labels1_np=labels1.detach().cpu().numpy()
labels2_np=labels2.detach().cpu().numpy()

features_jax = jnp.array(features_np)
features_jax_test=jnp.array(features_np_test)
labels_jax=jnp.array(labels1_np)
labels_jax_test=jnp.array(labels2_np)

In [None]:
def preprocess_data(features):
    """Standardise ``features`` using statistics taken over axes 0 and 1.

    The mean and standard deviation are computed jointly over the first two
    axes (kept as size-1 axes for broadcasting), so every remaining channel is
    shifted to zero mean and scaled to unit variance.
    """
    reduce_axes = (0, 1)
    centre = features.mean(reduce_axes, keepdims=True)
    spread = features.std(reduce_axes, keepdims=True)
    # The small constant avoids division by zero for constant channels.
    return (features - centre) / (spread + 1e-8)

def get_data(features):
    """Standardise ``features`` and prepend a time channel.

    ``features`` is indexed as (sample, step, channel). The returned array has
    one extra leading channel on the last axis holding evenly spaced times in
    [0, 1], identical for every sample. Standardisation statistics come from
    the array passed in, so each call uses its own mean and std.
    """
    num_samples = features.shape[0]
    num_steps = features.shape[1]

    time_grid = jnp.linspace(0,1, num_steps)
    time_per_sample = jnp.repeat(time_grid[None, :], num_samples, axis=0)
    time_channel = time_per_sample[:, :, None]

    scaled = preprocess_data(features)
    return jnp.concatenate([time_channel, scaled], axis=2)

In [None]:
# Model inputs: standardised features with the time channel prepended.
# Each split is standardised by `get_data` with its own statistics.
X_train=get_data(features_jax)
X_test=get_data(features_jax_test)

# Targets paired one-to-one with the rows of X_train / X_test.
y_train=labels_jax
y_test=labels_jax_test

In [None]:
def count_audio_files(directory):
    """Return how many audio files live under ``directory`` (recursively).

    A file counts when its lower-cased name ends with one of the known audio
    extensions. A directory that does not exist contributes zero.
    """
    audio_extensions = ('.wav', '.mp3', '.flac', '.aac', '.ogg', '.m4a', '.wma')
    return sum(
        1
        for _root, _dirs, filenames in os.walk(directory)
        for filename in filenames
        if filename.lower().endswith(audio_extensions)
    )

# Number of original audio files per group ("cc" / "cd") and split.
# The directory names are placeholders to be replaced with real paths.
# These counts determine the length and layout of the file-level label vectors.
train_ccn=count_audio_files("your_audio_cc_train")
train_cdn=count_audio_files("your_audio_cd_train")
test_ccn=count_audio_files("your_audio_cc_test")
test_cdn=count_audio_files("your_audio_cd_test")
print("train_ccn:",train_ccn)
print("train_cdn:",train_cdn)
print("test_ccn:",test_ccn)
print("test_cdn:",test_cdn)

In [None]:
# File-level labels, indexed by audio id: zeros for the "cc" files followed by
# ones for the "cd" files.
# NOTE(review): this assumes audio ids enumerate all cc files first, then all cd
# files — confirm against how the segment indices were generated.
labels_test = jnp.concatenate([jnp.zeros(test_ccn), jnp.ones(test_cdn)])
labels_train = jnp.concatenate([jnp.zeros(train_ccn), jnp.ones(train_cdn)])

In [None]:
# Notebook sanity check: display the shape of the training inputs.
X_train.shape

In [None]:
# Define the Dataloader class
class Dataloader:
    """In-memory batch generator over paired ``data`` / ``labels`` arrays."""

    data: jnp.ndarray
    labels: jnp.ndarray
    size: int

    def __init__(self, data, labels):
        """Store the arrays; ``size`` is the length of ``data``'s leading axis."""
        self.data = data
        self.labels = labels
        self.size = len(data)

    def loop(self, batch_size, *, key=None):
        """Yield ``(data_batch, label_batch)`` pairs.

        * ``batch_size == size``: the whole dataset is yielded exactly once and
          the generator then finishes (no key needed).
        * ``batch_size < size``: an endless stream of shuffled mini-batches. A
          slice is emitted only while its end index is strictly below the
          dataset size; the rest of the permutation is skipped and the data is
          reshuffled.

        Args:
            batch_size: number of samples per batch, between 1 and ``size``.
            key: JAX PRNG key driving the shuffling; required for mini-batches.

        Raises:
            ValueError: if ``batch_size`` is outside ``[1, size]`` or if
                mini-batches are requested without a key.
        """
        if batch_size < 1 or batch_size > self.size:
            raise ValueError(
                f"batch_size must be between 1 and {self.size}, got {batch_size}"
            )

        if batch_size == self.size:
            yield self.data, self.labels
            # Stop here: falling through would either split a missing key or
            # spin forever, because no slice can end before the dataset does.
            return

        if key is None:
            raise ValueError("A PRNG key is required to draw shuffled mini-batches")

        indices = jnp.arange(self.size)
        while True:
            subkey, key = jr.split(key)
            perm = jr.permutation(subkey, indices)
            start = 0
            end = batch_size
            while end < self.size:
                batch_perm = perm[start:end]
                yield self.data[batch_perm], self.labels[batch_perm]
                start = end
                end = start + batch_size

# Initialise dataloaders for training and testing data.
# Both are read as module-level globals by the training routine.
train_dataloader = Dataloader(X_train, y_train)
test_dataloader = Dataloader(X_test, y_test)

In [None]:
# Binary cross-entropy loss; the decorators make it return (loss, grads).
@eqx.filter_jit
@eqx.filter_value_and_grad
def classification_loss(model, X, y, *, key):
    """Mean binary cross-entropy of ``model`` on the batch ``(X, y)``.

    Every sample gets its own PRNG key (used by the model's dropout). The
    predicted probabilities are clipped away from 0 and 1 before the logs.
    """
    per_sample_keys = jax.random.split(key, X.shape[0])

    def predict_one(sample, sample_key):
        return model(sample, sample_key, inference=False)

    probabilities = jax.vmap(predict_one)(X, per_sample_keys)

    epsilon = 1e-7
    safe_probabilities = jnp.clip(probabilities, epsilon, 1 - epsilon)
    positive_term = y * jnp.log(safe_probabilities)
    negative_term = (1 - y) * jnp.log(1 - safe_probabilities)
    per_sample_loss = - ( positive_term +  negative_term)
    return jnp.mean(per_sample_loss)

# Define the training step function with JIT compilation
@eqx.filter_jit
def train_step(model, X, y, opt, opt_state, *, key):
    """Run one optimisation step and return ``(model, opt_state, loss)``.

    Args:
        model: the Equinox model to update.
        X, y: one mini-batch of inputs and targets.
        opt: the optax gradient transformation.
        opt_state: the optimiser state matching ``model``'s inexact arrays.
        key: PRNG key for the stochastic forward pass.
    """
    key, subkey = jr.split(key)
    loss, grads = classification_loss(model, X, y,key=subkey)
    # Pass the *current* parameters to the optimiser. Reading a module-level
    # snapshot of the initial model here would apply weight decay to stale
    # values and tie this function to hidden global state.
    current_params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = opt.update(grads, opt_state, params=current_params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

In [None]:
import jax.numpy as jnp
from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, accuracy_score
import matplotlib.pyplot as plt

def evaluate_and_plot_roc_pr_curves(labels, predictions, plot_title_prefix=""):
    """Print the ROC AUC and return three candidate decision thresholds.

    Despite its name, this function does not draw anything.

    Args:
        labels: binary ground-truth labels, one per sample.
        predictions: predicted scores, one per sample.
        plot_title_prefix: text printed in front of the AUC line.

    Returns:
        Tuple ``(best_threshold_roc, best_threshold_pr, best_threshold_accuracy)``:
        the thresholds maximising Youden's J on the ROC curve, F1 on the
        precision-recall curve, and plain accuracy, respectively.

    Raises:
        ValueError: if ``labels`` and ``predictions`` differ in length.
    """
    # scikit-learn works on host arrays; convert once up front.
    labels_np = np.asarray(labels)
    predictions_np = np.asarray(predictions)

    if labels_np.shape[0] != predictions_np.shape[0]:
        raise ValueError(f"Found input variables with inconsistent numbers of samples: {labels_np.shape[0]}, {predictions_np.shape[0]}")

    fpr, tpr, roc_thresholds = roc_curve(labels_np, predictions_np)
    roc_auc = auc(fpr, tpr)
    print(f"{plot_title_prefix} AUC (ROC):", roc_auc)

    youden_j = tpr - fpr
    best_threshold_roc = roc_thresholds[int(np.argmax(youden_j))]

    precision, recall, pr_thresholds = precision_recall_curve(labels_np, predictions_np)
    # precision/recall carry one extra trailing point (precision=1, recall=0)
    # that has no threshold. Drop that point so every remaining entry lines up
    # with its threshold, instead of discarding a valid threshold.
    precision = precision[:-1]
    recall = recall[:-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        f1_scores = 2 * (precision * recall) / (precision + recall)
    # 0/0 occurs when precision and recall are both zero; treat it as F1 = 0.
    f1_scores = np.nan_to_num(f1_scores, nan=0.0)
    best_threshold_pr = pr_thresholds[int(np.argmax(f1_scores))]

    accuracies = np.array([
        accuracy_score(labels_np, (predictions_np >= threshold).astype(int))
        for threshold in pr_thresholds
    ])
    best_threshold_accuracy = pr_thresholds[int(np.argmax(accuracies))]

    return best_threshold_roc, best_threshold_pr, best_threshold_accuracy


In [None]:
import os
import pandas as pd

def _append_seed_row(csv_path, columns, row, seed, exists_message, saved_message):
    """Append ``row`` to the CSV at ``csv_path`` unless ``seed`` is already listed.

    Returns True when the file was written, False when the seed was skipped.
    """
    new_row_df = pd.DataFrame([row], columns=columns)

    if os.path.exists(csv_path):
        existing = pd.read_csv(csv_path)
        if seed in existing["seed"].values:
            print(exists_message)
            return False
        if existing.empty:
            # Header-only file: keep its column layout without concatenating an
            # empty frame (which pandas deprecates).
            all_columns = list(existing.columns) + [
                column for column in columns if column not in existing.columns
            ]
            combined = new_row_df.reindex(columns=all_columns)
        else:
            combined = pd.concat([existing, new_row_df], ignore_index=True)
    else:
        # First write: the new row is the whole table.
        combined = new_row_df

    combined.to_csv(csv_path, index=False)
    print(saved_message)
    return True


def save_model_results(
    seed,
    acc_test,
    f1,
    precision,
    recall,
    acc_test_vote,
    f1_vote,
    precision_vote,
    recall_vote,
    test_predictions1,
    test_predictions2,
    metrics_csv="metrics.csv",
    preds_csv="predictions.csv"
):
    """Record the results of one seed in two CSV files.

    ``metrics_csv`` receives one row of scalar metrics and ``preds_csv`` one row
    with the string form of both prediction lists. Each file is created on
    first use. A seed that is already present in a file is left untouched and
    an informational message is printed instead.

    Args:
        seed: identifier of the run; used as the de-duplication key.
        acc_test, f1, precision, recall: metrics of the mean-based predictions.
        acc_test_vote, f1_vote, precision_vote, recall_vote: metrics of the
            vote-based predictions.
        test_predictions1, test_predictions2: prediction lists, stored via ``str``.
        metrics_csv: path of the metrics file.
        preds_csv: path of the predictions file.
    """
    metric_columns = ["seed","acc_test","f1","precision","recall","acc_test_vote","f1_vote","precision_vote","recall_vote"]
    metrics_row = {
        "seed": seed,
        "acc_test": acc_test,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "acc_test_vote": acc_test_vote,
        "f1_vote": f1_vote,
        "precision_vote": precision_vote,
        "recall_vote": recall_vote
    }
    _append_seed_row(
        metrics_csv,
        metric_columns,
        metrics_row,
        seed,
        f"[INFO] Metrics for Seed={seed} already exist in {metrics_csv}, skipping save.",
        f"[SUCCESS] Saved metrics for Seed={seed} to {metrics_csv}.",
    )

    prediction_columns = ["seed","test_predictions1","test_predictions2"]
    predictions_row = {
        "seed": seed,
        "test_predictions1": str(test_predictions1),
        "test_predictions2": str(test_predictions2)
    }
    _append_seed_row(
        preds_csv,
        prediction_columns,
        predictions_row,
        seed,
        f"[INFO] Prediction list for Seed={seed} already exists in {preds_csv}, skipping save.",
        f"[SUCCESS] Saved prediction list for Seed={seed} to {preds_csv}.",
    )




In [None]:
import gc
import jax
import optax
import math
import equinox as eqx
import jax.numpy as jnp

def get_trainable_params(model):
    """Return the part of ``model`` made of inexact (floating/complex) arrays.

    All other leaves are replaced by ``None`` by ``eqx.filter``.
    """
    inexact_leaves = eqx.filter(model, eqx.is_inexact_array)
    return inexact_leaves

def _group_segment_predictions(segment_indices, segment_predictions):
    """Collect segment-level predictions per audio id (column 0 of ``segment_indices``)."""
    # Copy the predictions to host once; iterating a device array element by
    # element is slow.
    host_predictions = np.asarray(segment_predictions)
    grouped = {}
    for raw_idx, pred in zip(segment_indices[:, 0], host_predictions):
        idx_array = np.asarray(raw_idx)
        if idx_array.size != 1:
            raise ValueError(f"Unexpected idx size: {idx_array.size}, idx: {raw_idx}")
        grouped.setdefault(int(idx_array.item()), []).append(pred)
    return grouped


def _mean_prediction_per_audio(grouped, audio_ids):
    """Average the segment scores of every audio id, in the order of ``audio_ids``."""
    return {idx: jnp.mean(jnp.array(grouped[idx])) for idx in audio_ids}


def _majority_vote(segment_scores):
    """Return 1 when strictly more segments score above 0.5 than at or below it, else 0."""
    scores = np.asarray(segment_scores)
    return 1 if np.sum(scores > 0.5) > np.sum(scores <= 0.5) else 0


def train_model(
    model,
    num_steps=415, 
    print_steps=40,  
    batch_size=32, 
    base_lr=3.5e-4, 
    warmup_steps = 84,
    weight_decay=0, 
    *,
    key,
    seed,
):
    """Train ``model`` and evaluate it at audio-file level after every epoch.

    Reads the module-level ``X_train``, ``train_dataloader``,
    ``test_dataloader``, ``indices_train``, ``indices_test``, ``labels_train``
    and ``labels_test``. After the last epoch the metrics and predictions are
    written through ``save_model_results`` to ``your_solution.csv`` and
    ``your_pred.csv``.

    Segment scores are mapped back to their audio file through column 0 of the
    index arrays. All file-level lists are ordered by ascending audio id and
    compared with the label of that id, so predictions and labels stay aligned
    regardless of the order in which segments appear.

    Args:
        model: the Equinox model to train.
        num_steps: nominal number of optimisation steps (rounded up to whole epochs).
        print_steps: evaluate and log every this many steps within an epoch
            (also on the first and last step of each epoch).
        batch_size: mini-batch size.
        base_lr: peak learning rate reached after warm-up.
        warmup_steps: length of the linear warm-up.
        weight_decay: AdamW weight decay.
        key: PRNG key for shuffling and dropout.
        seed: identifier stored with the saved results.

    Returns:
        ``(acc_test, acc_test_vote, test_accs, test_accs_vote,
        train_predictions1, test_predictions1, test_predictions2)`` where the
        first two are the last epoch's file-level accuracies (mean-based and
        vote-based), the next two their per-epoch histories, and the last three
        lists of ``(audio_id, label)`` pairs.

    Raises:
        ValueError: if the configuration yields no training epoch, or if a
            segment index is not a scalar.
    """
    global train_predictions1, test_predictions1,test_predictions2,trainable_params
    # Kept as a module-level global for code that still reads it.
    trainable_params = get_trainable_params(model)

    warmup_schedule = optax.linear_schedule(
        init_value=0.0,
        end_value=base_lr,
        transition_steps=warmup_steps,
    )
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=num_steps - warmup_steps,
        alpha=0.01
    )
    lr_schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps]
    )

    opt = optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)
    opt_state = opt.init(trainable_params)

    test_accs = []
    test_accs_vote = []

    dataset_size = X_train.shape[0]
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    total_epochs = math.ceil(num_steps / steps_per_epoch)
    if total_epochs < 1:
        raise ValueError(f"num_steps={num_steps} results in no training epoch")

    # Host copies of the file-level labels: NumPy indexing raises on an
    # out-of-range audio id instead of silently clamping it.
    train_labels_by_id = np.asarray(labels_train)
    test_labels_by_id = np.asarray(labels_test)

    for epoch in range(total_epochs):
        print(f"Epoch: {epoch + 1}")
        trainloopkey, key = jax.random.split(key)

        for step, data in zip(
            range(steps_per_epoch), train_dataloader.loop(batch_size, key=trainloopkey)
        ):
            start_time = time.time()

            X, y = data
            key, subkey = jr.split(key)

            model, opt_state, loss = train_step(model, X, y, opt, opt_state, key=subkey)

            is_last_step = step == steps_per_epoch - 1
            if step == 0 or (step + 1) % print_steps == 0 or is_last_step:
                inference_model = eqx.nn.inference_mode(model)
                inference_model = eqx.Partial(inference_model, inference=True)

                # Full-split evaluation. The last step of every epoch always
                # evaluates, so pre_train / pre_test hold the end-of-epoch
                # scores once the inner loop finishes.
                X_eval, y_eval = train_dataloader.data, train_dataloader.labels
                eval_keys = jax.random.split(jr.PRNGKey(0), X_eval.shape[0])
                pre_train = jax.vmap(inference_model)(X_eval, eval_keys)
                train_acc = jnp.mean((pre_train > 0.5) == (y_eval == 1))

                X_eval, y_eval = test_dataloader.data, test_dataloader.labels
                eval_keys = jax.random.split(jr.PRNGKey(0), X_eval.shape[0])
                pre_test = jax.vmap(inference_model)(X_eval, eval_keys)
                test_acc = jnp.mean((pre_test > 0.5) == (y_eval == 1))

                elapsed_time = time.time() - start_time
                print(f"Step: {step + 1}, Loss: {loss}, Train Acc: {train_acc}, Test Acc: {test_acc}, Time: {elapsed_time:.4f} seconds")

        # ---- audio-file level aggregation (train) ----
        audio_segments_train = _group_segment_predictions(indices_train, pre_train)
        train_ids = sorted(audio_segments_train)
        audio_predictions_train = _mean_prediction_per_audio(audio_segments_train, train_ids)
        predictions1 = jnp.array([audio_predictions_train[idx] for idx in train_ids])
        train_predictions1 = [
            (idx, 1 if audio_predictions_train[idx] >= 0.5 else 0) for idx in train_ids
        ]

        # ---- audio-file level aggregation (test) ----
        audio_segments_test = _group_segment_predictions(indices_test, pre_test)
        test_ids = sorted(audio_segments_test)
        audio_predictions_test = _mean_prediction_per_audio(audio_segments_test, test_ids)
        audio_predictions_test_vote = {
            idx: _majority_vote(audio_segments_test[idx]) for idx in test_ids
        }
        values = jnp.array([audio_predictions_test[idx] for idx in test_ids])

        # Labels are looked up per audio id so they line up with the predictions.
        train_true = train_labels_by_id[np.asarray(train_ids, dtype=int)]
        test_true = test_labels_by_id[np.asarray(test_ids, dtype=int)]

        evaluate_and_plot_roc_pr_curves(train_true, predictions1, plot_title_prefix="Train")
        evaluate_and_plot_roc_pr_curves(test_true, values, plot_title_prefix="Test")

        predict_label_array = np.array(
            [bool(audio_predictions_test[idx] > 0.5) for idx in test_ids]
        )
        predict_label_vote_array = np.array(
            [audio_predictions_test_vote[idx] for idx in test_ids]
        )

        acc_test = int(np.sum(predict_label_array == test_true)) / len(test_ids)
        acc_test_vote = int(np.sum(predict_label_vote_array == test_true)) / len(test_ids)

        precision = precision_score(test_true, predict_label_array)
        precision_vote = precision_score(test_true, predict_label_vote_array)

        recall = recall_score(test_true, predict_label_array)
        recall_vote = recall_score(test_true, predict_label_vote_array)

        f1 = f1_score(test_true, predict_label_array)
        f1_vote = f1_score(test_true, predict_label_vote_array)

        print('acc_test:', acc_test)
        print('acc_test_vote:', acc_test_vote)
        test_accs.append(acc_test)
        test_accs_vote.append(acc_test_vote)

        if epoch == total_epochs - 1:
            for idx in test_ids:
                print(f"Audio segment test {idx}: Prediction value {audio_predictions_test[idx]}")
            test_predictions1 = [
                (idx, 1 if audio_predictions_test[idx] >= 0.5 else 0) for idx in test_ids
            ]
            test_predictions2 = [
                (idx, 1 if audio_predictions_test_vote[idx] == 1 else 0) for idx in test_ids
            ]
            save_model_results(
                seed=seed,
                acc_test=acc_test,
                f1=f1,
                precision=precision,
                recall=recall,
                acc_test_vote=acc_test_vote,
                f1_vote=f1_vote,
                precision_vote=precision_vote,
                recall_vote=recall_vote,
                test_predictions1=test_predictions1,
                test_predictions2=test_predictions2,
                metrics_csv="your_solution.csv",
                preds_csv="your_pred.csv"
            )

    return acc_test,acc_test_vote,test_accs,test_accs_vote,train_predictions1, test_predictions1,test_predictions2




In [None]:
import warnings
# Silence FutureWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)


In [None]:
import jax.numpy as jnp
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gc
def train_with_seeds(seeds):
    """Train one fresh ``LRU`` model per seed and collect its predictions.

    The per-seed prediction lists are stored in the module-level globals
    ``all_train_predictions``, ``all_test_predictions`` and
    ``all_test_predictions_vote`` (reset at the start of every call, together
    with ``hyperparameters`` and ``acc_seeds``). Model hyper-parameters are read
    from module-level variables. JAX caches are cleared after every seed,
    whether or not training succeeded.
    """
    global hyperparameters, all_train_predictions, all_test_predictions_vote, all_test_predictions, acc_seeds
    hyperparameters = []
    all_train_predictions = []
    all_test_predictions = []
    all_test_predictions_vote = []
    acc_seeds = []
    # Per-epoch accuracy histories; collected locally only.
    all_test_accuracies = []
    all_test_accuracies_vote = []

    for seed in seeds:
        key = jax.random.PRNGKey(seed)
        modelkey, key = jax.random.split(key)
        trainkey, key = jax.random.split(key)
        # NOTE(review): `modelkey` is split off but the model is initialised
        # from the remaining `key`.
        LRU1 = LRU(
            num_blocks,
            data_dim,
            ssm_dim,
            hidden_dim,
            label_dim,
            classification,
            output_step,
            key=key,
        )

        try:
            print(f"Training with seed: {seed}")
            train_predictions1 = []
            test_predictions1 = []
            test_predictions2 = []

            results = train_model(LRU1, key=trainkey, seed=seed)
            (
                acc_test,
                acc_test_vote,
                test_accs,
                test_accs_vote,
                train_predictions1,
                test_predictions1,
                test_predictions2,
            ) = results

            # Host-side copies of the returned values.
            acc_test_cpu = np.array(acc_test)
            acc_test_vote_cpu = np.array(acc_test_vote)
            test_accs_cpu = np.array(test_accs)
            test_accs_vote_cpu = np.array(test_accs_vote)

            test_predictions1_cpu = [np.array(pair) for pair in test_predictions1]
            test_predictions2_cpu = [np.array(pair) for pair in test_predictions2]

            all_train_predictions.append(train_predictions1)
            all_test_predictions.append(test_predictions1_cpu)
            all_test_predictions_vote.append(test_predictions2_cpu)
            all_test_accuracies.append(test_accs_cpu)
            all_test_accuracies_vote.append(test_accs_vote_cpu)

        finally:
            # Drop the keys and release cached compilations before the next seed.
            del key, modelkey, trainkey
            jax.device_put(None)
            gc.collect()
            jax.clear_caches()


def vote_and_evaluate():
    """Combine the per-seed predictions by majority vote and print the metrics.

    Uses the module-level prediction lists filled by ``train_with_seeds`` and
    the file-level labels. Prints accuracy for the train set and for both test
    variants (mean-based and vote-based), followed by precision, recall and F1
    for both test variants, and finally the voted label arrays.
    """
    voted_train = aggregate_votes(all_train_predictions)
    evaluate_accuracy(voted_train, labels_train, "Train")

    voted_test = aggregate_votes(all_test_predictions)
    evaluate_accuracy(voted_test, labels_test, "Test")
    test_votes = jnp.array(list(voted_test.values()))

    voted_test_vote = aggregate_votes(all_test_predictions_vote)
    evaluate_accuracy(voted_test_vote, labels_test, "Test_vote")
    test_vote_votes = jnp.array(list(voted_test_vote.values()))

    # Metrics of the mean-based predictions.
    mean_precision = precision_score(labels_test, test_votes)
    print(f'Precision: {mean_precision}')

    mean_recall = recall_score(labels_test, test_votes)
    print(f'Recall: {mean_recall}')

    mean_f1 = f1_score(labels_test, test_votes)
    print(f'F1 Score: {mean_f1}')

    # Metrics of the vote-based predictions.
    vote_precision = precision_score(labels_test, test_vote_votes)
    print(f'Precision_vote: {vote_precision}')

    vote_recall = recall_score(labels_test, test_vote_votes)
    print(f'Recall_vote: {vote_recall}')

    vote_f1 = f1_score(labels_test, test_vote_votes)
    print(f'F1 Score_vote: {vote_f1}')

    print('testvotes:',test_votes)
    print('test_vote_votes:',test_vote_votes)

def aggregate_votes(predictions_list_all_seeds):
    """Majority-vote the per-seed labels of every audio file.

    Args:
        predictions_list_all_seeds: one list per seed, each holding
            ``(audio_id, label)`` pairs.

    Returns:
        Dict mapping audio id to the voted label, in ascending id order. The
        label is 1 when strictly more than half of the seeds predicted 1,
        otherwise 0. Votes are matched by the stored audio id rather than by
        list position, so differently ordered per-seed lists stay aligned. An
        empty input yields an empty dict.
    """
    labels_per_audio = {}
    for seed_predictions in predictions_list_all_seeds:
        for audio_id, label in seed_predictions:
            labels_per_audio.setdefault(int(audio_id), []).append(label)

    aggregated_predictions = {}
    for audio_id in sorted(labels_per_audio):
        preds = labels_per_audio[audio_id]
        positive_votes = sum(1 for pred in preds if pred == 1)
        aggregated_predictions[audio_id] = 1 if positive_votes > len(preds) / 2 else 0
    return aggregated_predictions

def evaluate_accuracy(aggregated_predictions, labels, dataset_name):
    """Print and return the share of voted labels that match ``labels``.

    ``aggregated_predictions`` maps an index into ``labels`` to the predicted
    label for that entry.
    """
    num_correct = sum(
        1
        for idx, pred in aggregated_predictions.items()
        if pred == labels[idx]
    )
    accuracy = num_correct / len(aggregated_predictions)
    print(f'{dataset_name} Accuracy (after voting): {accuracy:.2f}')
    return accuracy


In [None]:
# Hyper-parameters read as module-level globals by `train_with_seeds`.
hidden_dim = 128      # model dimension H of the LRU
data_dim = 768        # number of input channels fed to the encoder
label_dim = 1         # passed as `output_dim`
# NOTE(review): the following four values are not referenced by the visible cells.
vf_hidden_dim = 512
vf_num_hidden = 3
ode_solver_stepsize = 1 / 500
stepsize = 15
num_blocks=6          # number of stacked LRU blocks
ssm_dim=256           # state dimension N of every LRU layer
classification=True
output_step=1
dt0 = None            # NOTE(review): not referenced by the visible cells
num_seeds = 50        # NOTE(review): not referenced; the run length comes from `seeds`
# Train one model per seed, for seeds 1001 through 1050 inclusive.
seeds = np.arange(1001, 1050 + 1).tolist()
train_with_seeds(seeds)