In [None]:
import os

# Process-wide GPU configuration, written before the cell that imports JAX:
#   * CUDA_VISIBLE_DEVICES restricts the process to the GPU with index 3.
#   * XLA_PYTHON_CLIENT_MEM_FRACTION sets the share of GPU memory XLA may claim.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "3",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.95",
    }
)

In [None]:
import functools
import itertools
import math
import time
from collections import defaultdict
from functools import partial
from typing import Generator, Union

import diffrax
import equinox as eqx
import IPython
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import librosa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
from jax import nn as jnn
from scipy.sparse import csc_matrix
# Import specific functions from signax for handling log-signature calculation
from signax.signature import signature  # For computing signatures of paths
from signax.signature_flattened import flatten  # For flattening signatures
from signax.tensor_ops import log  # For converting to log-signatures
from sklearn.metrics import f1_score, precision_score, recall_score
from torchaudio.utils import download_asset


In [None]:
def tkey_to_index(width: int, tkey: Union[int, tuple[int]]) -> int:
    """Map a tensor key to its flat index in the truncated tensor algebra.

    Args:
        width: Number of letters of the alphabet.
        tkey: Either an integer key, returned unchanged after a bounds
            assertion, or a word given as a tuple of letters.

    Returns:
        The word read as a base-``width`` number (Horner evaluation), so the
        empty word maps to ``0``.
    """
    if isinstance(tkey, int):
        assert tkey <= width
        return tkey

    return functools.reduce(lambda index, letter: index * width + letter, tkey, 0)


def tensor_algebra_dimension(width: int, depth: int) -> int:
    """Return ``1 + width + width**2 + ... + width**depth``.

    The value is built with ``depth`` Horner steps ``size -> size * width + 1``
    starting from ``1``; with no steps the result is ``1``.
    """
    return functools.reduce(lambda size, _: size * width + 1, range(depth), 1)


def generate_tensor_keys_level(
    width: int, degree: int
) -> Generator[tuple[int], None, None]:
    """Yield every word of exactly ``degree`` letters over ``1..width``.

    Words are produced in lexicographic order: the first letter varies
    slowest. The recursion bottoms out at ``degree == 1``.
    """
    alphabet = range(1, width + 1)

    if degree == 1:
        for letter in alphabet:
            yield (letter,)
    else:
        for head in alphabet:
            for tail in generate_tensor_keys_level(width, degree - 1):
                yield (head, *tail)


def generate_tensor_keys(width: int, depth: int) -> Generator[tuple[int], None, None]:
    """Yield all words of length ``0..depth`` over the letters ``1..width``.

    The empty word comes first, then the single letters, then each higher
    degree in turn as produced by ``generate_tensor_keys_level``.
    """
    yield ()

    if depth != 0:
        yield from ((letter,) for letter in range(1, width + 1))

        for degree in range(2, depth + 1):
            yield from generate_tensor_keys_level(width, degree)


class HallSet:
    """Hall basis of the free Lie algebra on ``width`` letters.

    Every basis element is identified by an integer key. ``data[key]`` stores
    the pair ``(left_parent, right_parent)`` of keys whose bracket defines the
    element. Letters are stored as ``(0, letter)`` and key ``0`` is the
    placeholder entry ``(0, 0)``.

    Attributes:
        width: Number of letters.
        degree: Highest degree generated so far.
        data: ``key -> (left_parent, right_parent)`` table.
        reverse_map: Inverse lookup ``(left_parent, right_parent) -> key``.
        degree_ranges: ``degree_ranges[d]`` is the half-open key range of the
            elements of degree ``d``.
        sizes: Per-degree size bookkeeping. ``sizes[1]`` is ``width`` whereas
            entries for higher degrees hold ``len(data)`` at that degree, so
            the matrix builders read the Lie dimension from ``degree_ranges``.
        letters: The letters ``1..width``.
        l2k: Mapping from letter to key.

    Note:
        The ``functools.lru_cache`` decorated methods cache on
        ``(self, arguments)``.
    """

    def __init__(self, width, degree=1):
        """Create the letters and grow the set up to ``degree`` if it exceeds 1."""
        self.width = width
        self.degree = 1

        self.data = data = []
        self.reverse_map = reverse_map = {}
        self.degree_ranges = degree_ranges = []
        self.sizes = sizes = []
        self.letters = letters = []
        self.l2k = l2k = {}

        # Key 0 is a placeholder so that letters start at key 1.
        data.append((0, 0))
        degree_ranges.append((0, 1))
        sizes.append(0)

        for letter in range(1, width + 1):
            parents = (0, letter)
            letters.append(letter)
            data.append(parents)
            reverse_map[parents] = letter
            l2k[letter] = letter

        degree_ranges.append((degree_ranges[0][1], len(data)))
        sizes.append(width)

        if degree > self.degree:
            self.grow_up(degree)

    def grow_up(self, degree):
        """Extend the set in place until ``self.degree`` reaches ``degree``."""
        data = self.data
        reverse_map = self.reverse_map
        degree_ranges = self.degree_ranges

        while self.degree < degree:
            next_degree = self.degree + 1
            left = 1
            # Split next_degree as left + right with left <= right.
            while 2 * left <= next_degree:
                right = next_degree - left

                ilower, iupper = degree_ranges[left]
                jlower, jupper = degree_ranges[right]

                i = ilower

                while i < iupper:
                    # Only pairs with i < j are candidates.
                    j = max(jlower, i + 1)
                    while j < jupper:
                        # Accept (i, j) when the left parent of j does not exceed i.
                        if data[j][0] <= i:
                            parents = (i, j)
                            data.append(parents)
                            reverse_map[parents] = len(data) - 1
                        j += 1
                    i += 1
                left += 1

            degree_ranges.append((degree_ranges[-1][1], len(data)))
            self.sizes.append(len(data))
            self.degree += 1

    def _lie_dimension(self, degree: int) -> int:
        """Number of keys (placeholder key 0 included) of degree <= ``degree``.

        Read from ``degree_ranges`` because ``sizes[1]`` does not count the
        placeholder key while the entries for higher degrees do.
        """
        return self.degree_ranges[degree][1]

    @functools.lru_cache
    def key_to_string(self, key: int) -> str:
        """Render a key as nested brackets, e.g. ``"[1, 2]"``; letters render as digits."""
        assert key < len(self.data)

        left, right = self.data[key]

        if left == 0:
            return f"{right}"

        return f"[{self.key_to_string(left)}, {self.key_to_string(right)}]"

    @functools.lru_cache
    def product(self, lhs_key: int, rhs_key: int) -> list[tuple[int, int]]:
        """Bracket two basis elements and express the result in the basis.

        Returns:
            A list of ``(key, coefficient)`` pairs; empty when both keys are equal.
        """
        # Antisymmetry: [b, a] = -[a, b].
        if rhs_key < lhs_key:
            return [(k, -c) for k, c in self.product(rhs_key, lhs_key)]

        if lhs_key == rhs_key:
            return []

        # The pair is itself a basis element.
        if key := self.reverse_map.get((lhs_key, rhs_key)):
            return [(key, 1)]

        # Otherwise rewrite [a, [b, c]] as [[a, b], c] - [[a, c], b].
        lparent, rparent = self.data[rhs_key]

        left_result = [
            (k, c1 * c)
            for (k1, c1) in self.product(lhs_key, lparent)
            for (k, c) in self.product(k1, rparent)
        ]
        right_result = [
            (k, -c1 * c)
            for (k1, c1) in self.product(lhs_key, rparent)
            for (k, c) in self.product(k1, lparent)
        ]
        result = defaultdict(lambda: 0)
        for k, c in left_result:
            result[k] += c
        for k, c in right_result:
            result[k] += c

        return list(result.items())

    @functools.lru_cache
    def expand(self, key: int) -> list[tuple[tuple[int, ...], int]]:
        """Expand a basis element into tensor words.

        Uses ``[l, r] = l r - r l`` on the expansions of both parents.

        Returns:
            A list of ``(word, coefficient)`` pairs, where ``word`` is a tuple
            of letters.
        """
        if key in self.letters:
            return [((key,), 1)]

        assert key < len(self.data)
        lparent, rparent = self.data[key]

        left_expansion = self.expand(lparent)
        right_expansion = self.expand(rparent)

        left_terms = [
            ((*k1, *k2), c1 * c2)
            for (k1, c1), (k2, c2) in itertools.product(left_expansion, right_expansion)
        ]
        right_terms = [
            ((*k1, *k2), c1 * c2)
            for (k1, c1), (k2, c2) in itertools.product(right_expansion, left_expansion)
        ]

        result = defaultdict(lambda: 0)
        for k, c in left_terms:
            result[k] += c
        for k, c in right_terms:
            result[k] -= c

        return list(result.items())

    @functools.lru_cache
    def rbracket(self, tkey: Union[int, tuple[int]]) -> list[tuple[int, int]]:
        """Right-bracket a word: ``(a, b, c)`` becomes ``[a, [b, c]]``.

        Returns:
            ``(key, coefficient)`` pairs in the Hall basis; empty for the
            empty word.
        """
        if isinstance(tkey, int):
            return [(tkey, 1)]

        if len(tkey) == 0:
            return []

        if len(tkey) == 1:
            return [(tkey[0], 1)]

        assert len(tkey) > 1, f"{tkey}"
        first, *remaining = tkey
        return [
            (k, c1 * c)
            for (k1, c1) in self.rbracket(tuple(remaining))
            for k, c in self.product(first, k1)
        ]

    def l2t_matrix(self, degree=None, dtype=np.float32) -> jnp.ndarray:
        """Dense matrix mapping Hall-basis coordinates to tensor coordinates.

        Args:
            degree: Truncation degree; defaults to ``self.degree``.
            dtype: Element type of the matrix.

        Returns:
            Array of shape ``(tensor_algebra_dimension, lie_dimension)`` whose
            column ``k`` holds the tensor expansion of key ``k``. Column 0
            (placeholder key) is zero.
        """
        degree = degree or self.degree
        tensor_alg_size = tensor_algebra_dimension(self.width, degree)
        lie_size = self._lie_dimension(degree)

        # Two leading zeros: CSC start pointer plus an empty column for key 0.
        indptr = [0, 0]
        indices = []
        data = []
        for lkey in range(1, lie_size):
            expansion = self.expand(lkey)
            for k, c in expansion:
                indices.append(tkey_to_index(self.width, k))
                data.append(c)
            indptr.append(indptr[-1] + len(expansion))

        data = np.array(data, dtype=dtype)
        indices = np.array(indices, dtype=np.int64)
        indptr = np.array(indptr, dtype=np.int64)
        return jnp.array(
            csc_matrix(
                (data, indices, indptr),
                shape=(tensor_alg_size, lie_size),
                dtype=dtype,
            ).toarray()
        )

    def t2l_matrix(self, degree=None, dtype=np.float32) -> jnp.ndarray:
        """Dense matrix mapping tensor coordinates to Hall-basis coordinates.

        Each word is right-bracketed and the resulting coefficients are
        divided by the word length.

        Args:
            degree: Truncation degree; defaults to ``self.degree``.
            dtype: Element type of the matrix.

        Returns:
            Array of shape ``(lie_dimension, tensor_algebra_dimension)``.
            Row 0 (placeholder key) and the column of the empty word are zero.
        """
        degree = degree or self.degree
        tensor_alg_size = tensor_algebra_dimension(self.width, degree)
        lie_size = self._lie_dimension(degree)

        indptr = [0]
        indices = []
        data = []
        for tkey in generate_tensor_keys(self.width, degree):
            # The empty word yields no terms, so len(tkey) is never zero here.
            for k, c in self.rbracket(tkey):
                indices.append(k)
                data.append(c / len(tkey))
            indptr.append(len(data))
        data = np.array(data, dtype=dtype)
        indices = np.array(indices, dtype=np.int64)
        indptr = np.array(indptr, dtype=np.int64)

        return jnp.array(
            csc_matrix(
                (data, indices, indptr),
                shape=(lie_size, tensor_alg_size),
                dtype=dtype,
            ).toarray()
        )

In [None]:
def remat_conv(conv_layer, x):
    """Call ``conv_layer`` on ``x`` through a ``jax.remat`` wrapper.

    The forward value is unchanged; ``jax.remat`` only affects how
    intermediates are handled during differentiation.
    """
    checkpointed_layer = jax.remat(conv_layer)
    return checkpointed_layer(x)
class VectorField(eqx.Module):
    """MLP vector field whose initial weights and biases are divided by ``scale``.

    The network uses SiLU hidden activations and a final ``tanh``.
    """

    mlp: eqx.nn.MLP

    def __init__(
        self, in_size, out_size, width, depth, *, key, scale=1000
    ):
        """Build the MLP and shrink every linear layer's parameters.

        Args:
            in_size: Input dimension of the MLP.
            out_size: Output dimension of the MLP.
            width: Hidden layer width.
            depth: Number of hidden layers.
            key: PRNG key used to initialise the MLP.
            scale: Divisor applied to all initial weights and biases.
        """
        base_mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width,
            depth=depth,
            activation=jax.nn.silu,
            final_activation=jax.nn.tanh,
            key=key,
        )

        def _is_linear(node):
            return isinstance(node, eqx.nn.Linear)

        def _linear_layers(model):
            # Treat each Linear as a leaf so the layers themselves are listed.
            return [
                node
                for node in jax.tree_util.tree_leaves(model, is_leaf=_is_linear)
                if _is_linear(node)
            ]

        def _weights(model):
            return [layer.weight for layer in _linear_layers(model)]

        def _biases(model):
            return [layer.bias for layer in _linear_layers(model)]

        scaled_weights = [weight / scale for weight in _weights(base_mlp)]
        scaled_biases = [bias / scale for bias in _biases(base_mlp)]

        rescaled = eqx.tree_at(_weights, base_mlp, scaled_weights)
        rescaled = eqx.tree_at(_biases, rescaled, scaled_biases)

        self.mlp = rescaled

    def __call__(self, y):
        """Evaluate the MLP at ``y``."""
        return self.mlp(y)
        
# Define the NeuralCDE class
class NeuralCDE(eqx.Module):
    """Neural controlled differential equation binary classifier.

    The input path is interpolated with cubic Hermite splines, the hidden
    state is integrated with Heun's method at a fixed step size, and the final
    hidden state is mapped through a linear layer and a sigmoid.

    Attributes:
        vf: Vector field producing ``hidden_dim * data_dim`` outputs.
        data_dim: Number of channels of the control path (time included).
        hidden_dim: Size of the hidden state.
        ode_solver_stepsize: Constant step ``dt0`` passed to the solver.
        linear1: Maps the initial control value to the initial hidden state.
        linear2: Maps the final hidden state to the label logits.
    """

    vf: VectorField
    data_dim: int
    hidden_dim: int
    ode_solver_stepsize: float
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(
        self,
        hidden_dim,
        data_dim,
        label_dim,
        vf_hidden_dim,
        vf_num_hidden,
        ode_solver_stepsize,
        *,
        key,
    ):
        """Initialise the vector field and the two linear maps.

        Args:
            hidden_dim: Size of the hidden state.
            data_dim: Number of channels of the control path.
            label_dim: Output size of ``linear2``; ``__call__`` unpacks a
                single value, so it must be 1.
            vf_hidden_dim: Width of the vector-field MLP.
            vf_num_hidden: Depth of the vector-field MLP.
            ode_solver_stepsize: Constant solver step size.
            key: PRNG key for parameter initialisation.
        """
        # Four keys are still drawn although only three are used, so that the
        # parameters obtained from a given seed stay the same.
        vf_key, l1key, l2key, _unused_key = jr.split(key, 4)

        # Initialise the MLP vector field
        self.vf = VectorField(
            hidden_dim,
            hidden_dim * data_dim,
            vf_hidden_dim,
            vf_num_hidden,
            scale=1,
            key=vf_key,
        )

        self.linear1 = eqx.nn.Linear(data_dim, hidden_dim, key=l1key)
        self.linear2 = eqx.nn.Linear(hidden_dim, label_dim, key=l2key)
        self.hidden_dim = hidden_dim
        self.data_dim = data_dim
        self.ode_solver_stepsize = ode_solver_stepsize

    # Method to get the ODE term
    def get_ode(self, ts, X):
        """Build the ODE term driven by a cubic interpolation of ``X``.

        Returns:
            ``(ode_term, control)`` where ``control`` is the interpolated path.
        """
        # Compute backward Hermite coefficients for interpolation
        coeffs = diffrax.backward_hermite_coefficients(ts, X)
        # Create a cubic interpolation control term
        control = diffrax.CubicInterpolation(ts, coeffs)
        # Reshape output from vector field to be matrix-valued
        func = lambda t, y, args: jnp.reshape(
            self.vf(y), (self.hidden_dim, self.data_dim)
        )
        # Return the control term converted to an ODE
        return diffrax.ControlTerm(func, control).to_ode(), control

    # Forward pass method
    def __call__(self, X, key, *, inference=False):
        """Return the predicted probability for one sample.

        Args:
            X: Array of shape ``(length, data_dim)`` whose first channel is time.
            key: Unused; kept so callers can keep passing per-sample keys.
            inference: Unused; kept for interface compatibility.

        Returns:
            Scalar sigmoid output.
        """
        ts = X[:, 0]  # Assume time is the first channel

        # X already carries time as its first channel, so it is the control path.
        result = self.get_ode(ts, X)

        # Subclasses may return only the ODE term, without a control path.
        if isinstance(result, tuple):
            ode_term, control = result
            # Initialise the hidden state using control
            h0 = self.linear1(control.evaluate(ts[0]))
        else:
            ode_term = result
            h0 = self.linear1(X[0, :])

        saveat = diffrax.SaveAt(t1=True)
        # Solve the differential equation. The solve is called directly:
        # jax.remat expects a function, so wrapping the already computed
        # solution object in it did not checkpoint anything.
        solution = diffrax.diffeqsolve(
            terms=ode_term,
            solver=diffrax.Heun(),
            t0=ts[0],
            t1=ts[-1],
            dt0=self.ode_solver_stepsize,
            y0=h0,
            saveat=saveat,
            stepsize_controller=diffrax.ConstantStepSize(),
        )

        (prediction,) = jnn.sigmoid(self.linear2(solution.ys[-1]))

        return prediction


In [None]:
# Define the LogNeuralCDE class, which is identical to NeuralCDE, except for the get_ode method
class LogNeuralCDE(NeuralCDE):
    """Neural CDE solved with the Log-ODE method at truncation depth 1 or 2.

    The path is cut into windows of ``stepsize`` samples; on each window the
    vector field is combined with the window's log-signature.

    Attributes:
        stepsize: Number of samples per log-signature window.
        depth: Log-signature truncation depth (1 or 2).
        hall_set: Hall basis used to express depth-two log-signatures.
    """

    stepsize: int
    depth: int
    hall_set: HallSet

    def __init__(
        self,
        hidden_dim,
        data_dim,
        label_dim,
        vf_hidden_dim,
        vf_num_hidden,
        ode_solver_stepsize,
        stepsize,
        depth,
        *,
        key,
    ):
        """Initialise the parent model and the Log-ODE settings.

        Raises:
            ValueError: If ``depth`` is neither 1 nor 2.
        """
        # Initialise the parent class NeuralCDE
        super().__init__(
            hidden_dim,
            data_dim,
            label_dim,
            vf_hidden_dim,
            vf_num_hidden,
            ode_solver_stepsize,
            key=key,
        )
        self.stepsize = stepsize
        # Ensure the depth is either 1 or 2
        if depth not in [1, 2]:
            raise ValueError(
                "The Log-ODE method is only implemented for truncation depths one and two"
            )
        self.depth = depth
        self.hall_set = HallSet(data_dim, depth)

    # Method to calculate log-signatures
    def calc_logsigs(self, X):
        """Compute one log-signature per window of ``stepsize`` samples.

        The length of ``X`` must be a multiple of ``stepsize`` for the reshape
        to succeed. Each returned vector has a leading zero entry followed by
        the Hall-basis coordinates.
        """
        # Reshape data
        X = X.reshape(-1, self.stepsize, X.shape[-1])

        # Prepend zero to the first interval and the last element of the previous interval to every other interval
        prepend = jnp.concatenate((jnp.zeros((1, X.shape[-1])), X[:-1, -1, :]))[
            :, None, :
        ]
        X = jnp.concatenate((prepend, X), axis=1)

        # The tensor -> Lie projection does not depend on the window, so it is
        # built once. It is only needed (and only requested) at depth two.
        tensor_to_lie_map = None
        if self.depth != 1:
            tensor_to_lie_map = self.hall_set.t2l_matrix(self.depth)[:, 1:]

        # Define log-signature function
        def logsig(x):
            logsig = flatten(log(signature(x, self.depth)))
            if self.depth == 1:
                return jnp.concatenate((jnp.array([0]), logsig))
            else:
                return tensor_to_lie_map @ logsig

        # Calculate log-signatures over each interval
        logsigs = jax.vmap(logsig)(X)

        return logsigs

    # ODE for depth one Log-ODE method
    def depth_one_ode(self, y, logsig, interval_length):
        """Depth-one vector field: ``vf(y) @ increment / interval_length``."""
        vf_out = jnp.reshape(self.vf(y), (self.hidden_dim, self.data_dim))
        return jnp.dot(vf_out, logsig[1:]) / interval_length

    # ODE for depth two Log-ODE method
    def depth_two_ode(self, y, logsig, interval_length):
        """Depth-two vector field including the Lie-bracket correction."""
        # Reshape output from vector field to be matrix-valued
        vf_out = jnp.reshape(self.vf(y), (self.hidden_dim, self.data_dim))

        # Jacobian-vector products of the flat vector field along each column
        # of vf_out; shape (data_dim, hidden_dim * data_dim).
        flat_jvps = jax.vmap(lambda x: jax.jvp(self.vf, (y,), (x,))[1])(vf_out.T)

        # The flat output is laid out as (hidden_dim, data_dim), exactly like
        # vf_out, so it is reshaped with that layout and only then transposed.
        # jvps[i, j] is then the derivative of column j of vf_out along
        # column i. Reshaping straight to (data_dim, data_dim, hidden_dim)
        # would mix entries belonging to different columns.
        jvps = jnp.transpose(
            jnp.reshape(
                flat_jvps, (self.data_dim, self.hidden_dim, self.data_dim)
            ),
            (0, 2, 1),
        )

        # Compute Lie brackets
        def liebracket(jvps, pair):
            return jvps[pair[0] - 1, pair[1] - 1] - jvps[pair[1] - 1, pair[0] - 1]

        # Keys after the letters are the degree-two elements; each row is a
        # pair of (1-based) letter keys.
        pairs = jnp.asarray(self.hall_set.data[self.data_dim + 1 :])
        lieout = jax.vmap(liebracket, in_axes=(None, 0))(jvps, pairs)

        # Combine Lie brackets with the log-signature
        vf_depth1 = jnp.dot(vf_out, logsig[1 : self.data_dim + 1])
        vf_depth2 = jnp.dot(lieout.T, logsig[self.data_dim + 1 :])

        return (vf_depth1 + vf_depth2) / interval_length

    # Define get_ode using the Log-ODE method
    def get_ode(self, ts, X):
        """Return the piecewise Log-ODE term for the path ``X``.

        Unlike the parent implementation, only the ODE term is returned.
        """
        # Calculate the log-signatures
        logsigs = self.calc_logsigs(X)
        # Calculate intervals, assuming 0<=t<=1
        intervals = (
            jnp.arange(0, X.shape[0] + self.stepsize, self.stepsize) / X.shape[0]
        )
        last_idx = intervals.shape[0] - 1

        # Define ODE function
        def func(t, y, args):
            # searchsorted returns 0 at t == intervals[0]; without the clip,
            # idx - 1 would wrap around to the last window and the interval
            # length would be negative. The upper bound keeps intervals[idx]
            # in range.
            idx = jnp.clip(jnp.searchsorted(intervals, t), 1, last_idx)
            logsig_t = logsigs[idx - 1]
            interval_length = intervals[idx] - intervals[idx - 1]
            if self.depth == 1:
                return self.depth_one_ode(y, logsig_t, interval_length)
            if self.depth == 2:
                return self.depth_two_ode(y, logsig_t, interval_length)

        return diffrax.ODETerm(func)

In [None]:
# Training split: features (NumPy file) plus segment indices and labels
# (objects stored with torch.save).
features1 = np.load('/home/sichengyu/text/NCDE/SimplifiedProgram/autoencode/Pitt/train_features_Pitt_conv_1.npy')
indices_train = torch.load('/home/sichengyu/text/NCDE/feature_tensor/Pittnew/indices_train_Pitt_0.3_whisper_30_new_norm.pt')
labels1 = torch.load('/home/sichengyu/text/NCDE/feature_tensor/Pittnew/labels1_Pitt_0.3_train_whisper_30_new_norm.pt')

# Test split, same layout.
features2 = np.load('/home/sichengyu/text/NCDE/SimplifiedProgram/autoencode/Pitt/test_features_Pitt_conv_1.npy')
indices_test = torch.load('/home/sichengyu/text/NCDE/feature_tensor/Pittnew/indices_test_Pitt_0.3_whisper_30_new_norm.pt')
labels2 = torch.load('/home/sichengyu/text/NCDE/feature_tensor/Pittnew/labels2_Pitt_0.3_test_whisper_30_new_norm.pt')

In [None]:
# NumPy views of the loaded data; labels are detached and moved to the CPU first.
features_np, features_np_test = features1, features2
labels1_np = labels1.detach().cpu().numpy()
labels2_np = labels2.detach().cpu().numpy()

# JAX copies used by the rest of the notebook.
features_jax, features_jax_test = jnp.array(features_np), jnp.array(features_np_test)
labels_jax, labels_jax_test = jnp.array(labels1_np), jnp.array(labels2_np)

In [None]:
def preprocess_data(features):
    """Standardise ``features`` per trailing channel.

    Mean and standard deviation are taken jointly over the first two axes and
    kept as broadcastable dimensions; ``1e-8`` is added to the deviation to
    avoid dividing by zero.
    """
    reduce_axes = (0, 1)
    centre = features.mean(reduce_axes, keepdims=True)
    spread = features.std(reduce_axes, keepdims=True)
    centred = features - centre
    return centred / (spread + 1e-8)

def get_data(features):
    """Standardise ``features`` and prepend a time channel.

    The time channel is a uniform grid on ``[0, 1]`` with one value per step
    along axis 1, repeated for every entry along axis 0.

    Returns:
        Array with one more trailing channel than ``features``; channel 0 is time.
    """
    n_series, n_steps = features.shape[0], features.shape[1]

    time_grid = jnp.linspace(0, 1, n_steps)
    time_channel = jnp.broadcast_to(
        time_grid[None, :, None], (n_series, n_steps, 1)
    )

    return jnp.concatenate([time_channel, preprocess_data(features)], axis=2)

In [None]:
# Model inputs: standardised features with a leading time channel.
X_train, X_test = get_data(features_jax), get_data(features_jax_test)

# Targets for the two splits.
y_train, y_test = labels_jax, labels_jax_test


In [None]:
def count_audio_files(directory):
    """Count audio files under ``directory``, searching sub-directories too.

    A file counts when its lower-cased name ends with one of the known audio
    extensions.
    """
    audio_extensions = ('.wav', '.mp3', '.flac', '.aac', '.ogg', '.m4a', '.wma')

    return sum(
        1
        for _root, _dirs, filenames in os.walk(directory)
        for filename in filenames
        if filename.lower().endswith(audio_extensions)
    )

# Recordings per split and class folder; each count is printed right after it
# is computed, in the same order as before.
train_ccn = count_audio_files("/home/sichengyu/Downloads/dementiabank/Pitt_new/norm/train_test_split/train/cc_enhence")
print("train_ccn:", train_ccn)

train_cdn = count_audio_files("/home/sichengyu/Downloads/dementiabank/Pitt_new/norm/train_test_split/train/cd_enhence")
print("train_cdn:", train_cdn)

test_ccn = count_audio_files("/home/sichengyu/Downloads/dementiabank/Pitt_new/norm/train_test_split/test/cc_enhence")
print("test_ccn:", test_ccn)

test_cdn = count_audio_files("/home/sichengyu/Downloads/dementiabank/Pitt_new/norm/train_test_split/test/cd_enhence")
print("test_cdn:", test_cdn)

In [None]:
# Recording-level labels: zeros for every `cc_enhence` file followed by ones
# for every `cd_enhence` file of the split.
labels_train = jnp.concatenate([jnp.zeros(train_ccn), jnp.ones(train_cdn)])
labels_test = jnp.concatenate([jnp.zeros(test_ccn), jnp.ones(test_cdn)])

In [None]:
# Display the shape of the training array built by get_data (cell output).
X_train.shape

In [None]:
class Dataloader:
    """In-memory dataset that yields batches of ``(data, labels)`` forever.

    Attributes:
        data: Samples, indexed along the first axis.
        labels: Targets aligned with ``data``.
        size: Number of samples (``len(data)``).
    """

    data: jnp.ndarray
    labels: jnp.ndarray
    size: int

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        self.size = len(data)

    def loop(self, batch_size, *, key=None):
        """Generate batches indefinitely.

        Args:
            batch_size: Samples per batch. When equal to ``size`` the whole
                dataset is yielded unshuffled on every iteration and ``key``
                is not needed.
            key: PRNG key used to draw a new permutation for every pass.

        Yields:
            ``(data_batch, label_batch)`` tuples. Every pass yields all full
            batches of the permutation; a trailing partial batch is dropped.

        Raises:
            ValueError: If ``batch_size`` is not in ``1..size`` or if a
                shuffled loop is requested without a ``key``.
        """
        if batch_size == self.size:
            # Full-batch mode: keep yielding instead of falling through to the
            # shuffled loop, which could never produce a batch of this size.
            while True:
                yield self.data, self.labels

        if batch_size <= 0 or batch_size > self.size:
            raise ValueError(
                f"batch_size must be between 1 and {self.size}, got {batch_size}"
            )
        if key is None:
            raise ValueError("A PRNG key is required when batch_size != size")

        indices = jnp.arange(self.size)
        while True:
            subkey, key = jr.split(key)
            perm = jr.permutation(subkey, indices)
            # The stop value includes the last full batch even when the
            # dataset size is an exact multiple of batch_size.
            for start in range(0, self.size - batch_size + 1, batch_size):
                batch_perm = perm[start : start + batch_size]
                yield self.data[batch_perm], self.labels[batch_perm]

# Initialise dataloaders for training and testing data
train_dataloader, test_dataloader = (
    Dataloader(X_train, y_train),
    Dataloader(X_test, y_test),
)

In [None]:
# Define the classification loss function with gradient calculation
@eqx.filter_jit
@eqx.filter_value_and_grad
def classification_loss(model, X, y, *, key):
    """Mean binary cross-entropy of ``model`` on a batch.

    Through the decorators the call returns ``(loss, grads)`` with respect to
    ``model``. Predictions are clipped to ``[1e-7, 1 - 1e-7]`` before the
    logarithm. A norm penalty over the vector-field layers is computed but
    multiplied by zero, so it does not change the loss value.

    Args:
        model: Callable as ``model(x, key, inference=False)`` on one sample.
        X: Batch of samples, mapped over axis 0.
        y: Targets aligned with ``X``.
        key: PRNG key, split into one key per sample.
    """
    per_sample_keys = jax.random.split(key, X.shape[0])

    pred_y = jax.vmap(lambda x, k: model(x, k, inference=False))(X, per_sample_keys)

    epsilon = 1e-7
    probs = jnp.clip(pred_y, epsilon, 1 - epsilon)
    positive_term = y * jnp.log(probs)
    negative_term = (1 - y) * jnp.log(1 - probs)
    loss = - ( positive_term +  negative_term)

    norm = sum(
        jnp.mean(
            jnp.linalg.norm(layer.weight, axis=-1)
            + jnp.linalg.norm(layer.bias, axis=-1)
        )
        for layer in model.vf.mlp.layers
    )
    # Regularisation weight is zero: the penalty is kept in the graph only.
    norm = norm * 0

    return jnp.mean(loss)+norm

# Define the training step function with JIT compilation
@eqx.filter_jit
def train_step(model, X, y, opt, opt_state, *, key):
    """Run one optimisation step and return the updated state.

    Args:
        model: Model to update.
        X: Batch of inputs.
        y: Batch of targets.
        opt: Optax gradient transformation.
        opt_state: Current optimiser state.
        key: PRNG key; a sub-key is forwarded to the loss.

    Returns:
        ``(model, opt_state, loss)`` after the update.
    """
    _, subkey = jr.split(key)
    loss, grads = classification_loss(model, X, y,key=subkey)
    # Take the parameters from the model being updated. Reading a module-level
    # snapshot instead would be frozen into the compiled function and stay at
    # its initial values, which matters as soon as weight decay is non-zero.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = opt.update(grads, opt_state, params=params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss


In [None]:
import os
import pandas as pd

def _append_row_if_new_seed(csv_path, row, skip_message, saved_message):
    """Append ``row`` to the CSV at ``csv_path`` unless its seed is already there.

    The file (and its parent directory) is created when missing. Returns
    ``True`` when the row was written and ``False`` when it was skipped.
    """
    if os.path.exists(csv_path):
        table = pd.read_csv(csv_path)
        if row["seed"] in table["seed"].values:
            print(skip_message)
            return False
        table = pd.concat([table, pd.DataFrame([row])], ignore_index=True)
    else:
        # First write: the row's keys define the column order.
        table = pd.DataFrame([row], columns=list(row))

    parent_dir = os.path.dirname(csv_path)
    if parent_dir:
        # to_csv does not create missing directories.
        os.makedirs(parent_dir, exist_ok=True)

    table.to_csv(csv_path, index=False)
    print(saved_message)
    return True


def save_model_results(
    seed,
    acc_test,
    f1,
    precision,
    recall,
    acc_test_vote,
    f1_vote,
    precision_vote,
    recall_vote,
    test_predictions1,
    test_predictions2,
    metrics_csv="metrics.csv",
    preds_csv="predictions.csv"
):
    """Record the metrics and predictions of one seed in two CSV files.

    A seed that is already present in a file is left untouched and a message
    is printed instead. The prediction lists are stored as their ``str``
    representation.

    Args:
        seed: Identifier of the run; used as the uniqueness key.
        acc_test, f1, precision, recall: Metrics of the mean-based decision.
        acc_test_vote, f1_vote, precision_vote, recall_vote: Metrics of the
            vote-based decision.
        test_predictions1: Predictions of the mean-based decision.
        test_predictions2: Predictions of the vote-based decision.
        metrics_csv: Path of the metrics file.
        preds_csv: Path of the predictions file.
    """
    metrics_row = {
        "seed": seed,
        "acc_test": acc_test,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "acc_test_vote": acc_test_vote,
        "f1_vote": f1_vote,
        "precision_vote": precision_vote,
        "recall_vote": recall_vote,
    }
    _append_row_if_new_seed(
        metrics_csv,
        metrics_row,
        skip_message=f"[INFO] Metrics for Seed={seed} already exist in {metrics_csv}, skipping save.",
        saved_message=f"[SUCCESS] Saved metrics for Seed={seed} to {metrics_csv}.",
    )

    predictions_row = {
        "seed": seed,
        "test_predictions1": str(test_predictions1),
        "test_predictions2": str(test_predictions2),
    }
    _append_row_if_new_seed(
        preds_csv,
        predictions_row,
        skip_message=f"[INFO] Prediction list for Seed={seed} already exists in {preds_csv}, skipping save.",
        saved_message=f"[SUCCESS] Saved prediction list for Seed={seed} to {preds_csv}.",
    )



In [None]:
import gc
import jax
import optax
import math
import equinox as eqx
import jax.numpy as jnp

def get_trainable_params(model):
    """Return the part of ``model`` selected by ``eqx.is_inexact_array``."""
    trainable = eqx.filter(model, eqx.is_inexact_array)
    return trainable

def _evaluate_full_batch(inference_model, dataloader):
    """Run ``inference_model`` on all samples of ``dataloader`` in one batch.

    Returns:
        ``(output, accuracy)`` where ``accuracy`` is the mean agreement
        between ``output > 0.5`` and ``label == 1``.
    """
    X, y = next(dataloader.loop(dataloader.size))
    keys = jax.random.split(jr.PRNGKey(0), X.shape[0])
    output = jax.vmap(inference_model)(X, keys)
    return output, jnp.mean((output > 0.5) == (y == 1))


def _group_by_recording(indices, predictions):
    """Collect segment predictions per recording id.

    The id of a segment is read from column 0 of ``indices``. Ids are
    converted through NumPy so that NumPy values and CPU torch tensors are
    both accepted.

    Raises:
        ValueError: If an id entry does not hold exactly one element.
    """
    grouped = {}
    for raw_idx, pred in zip(indices[:, 0], predictions):
        idx_array = np.asarray(raw_idx)
        if idx_array.size != 1:
            raise ValueError(f"Unexpected idx size: {idx_array.size}, idx: {raw_idx}")
        grouped.setdefault(int(idx_array.item()), []).append(pred)
    return grouped


def train_model(
    model,
    num_steps=415, 
    print_steps=40,  
    batch_size=32, 
    base_lr=3.5e-4, 
    warmup_steps = 84,
    weight_decay=0, 
    *,
    key,
    seed,
):
    """Train ``model`` and evaluate it per recording after every epoch.

    The learning rate warms up linearly for ``warmup_steps`` and then follows
    a cosine decay. After each epoch, segment predictions are pooled per
    recording (mean and majority vote), ROC/PR curves are plotted, and
    accuracy, precision, recall and F1 are computed on the test recordings.
    After the last epoch the results are written with ``save_model_results``.

    Reads the notebook globals ``X_train``, ``train_dataloader``,
    ``test_dataloader``, ``indices_train``, ``indices_test``,
    ``labels_train`` and ``labels_test``, and calls
    ``evaluate_and_plot_roc_pr_curves``, which must be defined before this
    function runs. Sets the globals ``train_predictions1``,
    ``test_predictions1``, ``test_predictions2`` and ``trainable_params``.

    Args:
        model: Model to train.
        num_steps: Target number of optimisation steps; rounded up to whole epochs.
        print_steps: Evaluate and print every this many steps within an epoch.
        batch_size: Training batch size.
        base_lr: Peak learning rate.
        warmup_steps: Length of the linear warm-up.
        weight_decay: AdamW weight decay.
        key: PRNG key for shuffling and training.
        seed: Run identifier stored with the saved results.

    Returns:
        ``(acc_test, acc_test_vote, test_accs, test_accs_vote,
        train_predictions1, test_predictions1, test_predictions2)``.
    """
    global train_predictions1, test_predictions1,test_predictions2,trainable_params
    trainable_params = get_trainable_params(model)

    warmup_schedule = optax.linear_schedule(
        init_value=0.0,
        end_value=base_lr,
        transition_steps=warmup_steps,
    )
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=num_steps - warmup_steps,
        alpha=0.01,
    )
    lr_schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps],
    )

    opt = optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)
    opt_state = opt.init(trainable_params)

    test_accs = []
    test_accs_vote = []

    dataset_size = X_train.shape[0]
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    total_epochs = math.ceil(num_steps / steps_per_epoch)

    for epoch in range(total_epochs):
        print(f"Epoch: {epoch + 1}")
        trainloopkey, key = jax.random.split(key)

        for step, data in zip(
            range(steps_per_epoch), train_dataloader.loop(batch_size, key=trainloopkey)
        ):
            start_time = time.time()

            X, y = data
            key, subkey = jr.split(key)

            model, opt_state, loss = train_step(model, X, y, opt, opt_state, key=subkey)

            is_last_step = step == steps_per_epoch - 1
            if step == 0 or (step + 1) % print_steps == 0 or is_last_step:
                inference_model = eqx.nn.inference_mode(model)
                inference_model = eqx.Partial(inference_model,inference=True)

                # The last evaluation of the epoch happens on the final step,
                # so these hold the end-of-epoch predictions after the loop.
                pre_train, train_acc = _evaluate_full_batch(inference_model, train_dataloader)
                pre_test, test_acc = _evaluate_full_batch(inference_model, test_dataloader)

                elapsed_time = time.time() - start_time
                print(f"Step: {step + 1}, Loss: {loss}, Train Acc: {train_acc}, Test Acc: {test_acc}, Time: {elapsed_time:.4f} seconds")

        # Pool segment predictions per training recording.
        audio_segments_train = _group_by_recording(indices_train, pre_train)
        audio_predictions_train = {idx: jnp.mean(jnp.array(preds)) for idx, preds in audio_segments_train.items()}
        predictions1 = jnp.array(list(audio_predictions_train.values()))
        train_predictions1 = [(idx, 1 if pred >= 0.5 else 0) for idx, pred in audio_predictions_train.items()]

        # Pool segment predictions per test recording: mean and majority vote.
        audio_segments_test = _group_by_recording(indices_test, pre_test)
        audio_predictions_test = {idx: jnp.mean(jnp.array(preds)) for idx, preds in audio_segments_test.items()}
        audio_predictions_test_vote = {idx: 1 if jnp.sum(jnp.array(preds) > 0.5) > jnp.sum(jnp.array(preds) <= 0.5) else 0 for idx, preds in audio_segments_test.items()}

        test_ids = list(audio_predictions_test)
        values = jnp.array(list(audio_predictions_test.values()))
        # Labels looked up by recording id, so they line up with the
        # predictions whatever order the ids were first encountered in.
        labels_test_aligned = labels_test[jnp.asarray(test_ids)]

        evaluate_and_plot_roc_pr_curves(labels_train, predictions1, plot_title_prefix="Train")
        evaluate_and_plot_roc_pr_curves(labels_test_aligned, values, plot_title_prefix="Test")

        predict_label = [pred > 0.5 for pred in audio_predictions_test.values()]
        predict_label_vote = list(audio_predictions_test_vote.values())

        correct_predictions_test = sum(
            1 for idx, hit in zip(test_ids, predict_label) if hit == labels_test[idx]
        )
        acc_test = correct_predictions_test / len(audio_predictions_test)

        correct_predictions_test_vote = sum(
            1 for idx, vote in zip(test_ids, predict_label_vote) if vote == labels_test[idx]
        )
        acc_test_vote = correct_predictions_test_vote / len(audio_predictions_test)

        predict_label_array=np.array(predict_label)
        predict_label_vote_array=np.array(predict_label_vote)
        labels_test_array=np.asarray(labels_test_aligned)

        precision = precision_score(labels_test_array,predict_label_array)
        precision_vote=precision_score(labels_test_array,predict_label_vote_array)

        recall = recall_score(labels_test_array,predict_label_array)
        recall_vote=recall_score(labels_test_array,predict_label_vote_array)

        f1 = f1_score(labels_test_array,predict_label_array)
        f1_vote=f1_score(labels_test_array,predict_label_vote_array)

        print('acc_test:', acc_test)
        print('acc_test_vote:', acc_test_vote)
        test_accs.append(acc_test)
        test_accs_vote.append(acc_test_vote)

        if epoch==total_epochs-1:
            predictions_list_test = sorted(
                ((idx, jnp.mean(jnp.array(preds))) for idx, preds in audio_segments_test.items()),
                key=lambda item: item[0],
            )

            for idx, mean_pred in predictions_list_test:
                print(f"Audio segment test {idx}: Prediction value {mean_pred}")
            test_predictions1 = [(idx, 1 if pred >= 0.5 else 0) for idx, pred in predictions_list_test]
            test_predictions2 = [(idx, 1 if pred == 1 else 0) for idx, pred in audio_predictions_test_vote.items()]
            save_model_results(
                seed=seed,
                acc_test=acc_test,
                f1=f1,
                precision=precision,
                recall=recall,
                acc_test_vote=acc_test_vote,
                f1_vote=f1_vote,
                precision_vote=precision_vote,
                recall_vote=recall_vote,
                test_predictions1=test_predictions1,
                test_predictions2=test_predictions2,
                metrics_csv="solution/Pitt_log_50seed_acc_h128_v512_0norm_ode250_step60x.csv",
                preds_csv="solution/Pitt_log_50seed_predict_h128_v512_0norm_ode250_step60x.csv"
            )

    return acc_test,acc_test_vote,test_accs,test_accs_vote,train_predictions1, test_predictions1,test_predictions2



In [None]:
import warnings
# Suppress every FutureWarning raised from this point on.
warnings.filterwarnings("ignore", category=FutureWarning)


In [None]:
import jax.numpy as jnp
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gc
def _zero_fill_failed_rows(rows, default_len=10):
    """Replace ``None`` rows (failed seeds) with zero rows matching the successful ones.

    The row length is taken from the first non-``None`` entry; ``default_len``
    is used only when every row is ``None``.
    """
    row_len = next((len(row) for row in rows if row is not None), default_len)
    return [np.zeros(row_len) if row is None else row for row in rows]


def train_with_seeds(seeds):
    """Train one depth-2 ``LogNeuralCDE`` per seed, then vote and tabulate accuracies.

    For each seed a fresh model is built from the module-level hyperparameters
    (``hidden_dim``, ``data_dim``, ``label_dim``, ``vf_hidden_dim``,
    ``vf_num_hidden``, ``ode_solver_stepsize``, ``stepsize``) and handed to
    ``train_model``. A seed whose training raises is reported on stdout and
    contributes an all-zero row to the accuracy tables instead of aborting the
    whole sweep.

    Side effects:
        Resets the module-level lists ``hyperparameters``,
        ``all_train_predictions``, ``all_test_predictions``,
        ``all_test_predictions_vote`` and ``acc_seeds``; the three prediction
        lists are filled with one entry per successful seed and are then read
        by ``vote_and_evaluate``.

    Args:
        seeds: Iterable of integer seeds passed to ``jax.random.PRNGKey``.

    Returns:
        pandas.DataFrame of per-epoch test accuracies with one row per seed
        (index ``Seed_<seed>``) and one column per epoch (numbered from 1).
    """
    global hyperparameters, all_train_predictions, all_test_predictions_vote, all_test_predictions, acc_seeds
    hyperparameters = []
    all_train_predictions = []
    all_test_predictions = []
    all_test_predictions_vote = []
    acc_seeds = []
    all_test_accuracies = []
    all_test_accuracies_vote = []

    # Materialise once: ``seeds`` is walked again below to label the rows.
    seeds = list(seeds)

    for seed in seeds:
        key = jax.random.PRNGKey(seed)
        modelkey, key = jax.random.split(key)
        trainkey, key = jax.random.split(key)

        LogNCDE_Depth2 = LogNeuralCDE(
            hidden_dim=hidden_dim,
            data_dim=data_dim,
            label_dim=label_dim,
            vf_hidden_dim=vf_hidden_dim,
            vf_num_hidden=vf_num_hidden,
            ode_solver_stepsize=ode_solver_stepsize,
            stepsize=stepsize,
            depth=2,
            key=modelkey,
        )

        try:
            print(f"Training with seed: {seed}")
            (
                _acc_test,
                _acc_test_vote,
                test_accs,
                test_accs_vote,
                train_predictions1,
                test_predictions1,
                test_predictions2,
            ) = train_model(LogNCDE_Depth2, key=trainkey, seed=seed)

            # Convert everything to NumPy before touching the result lists so a
            # conversion error cannot leave them partially updated.
            test_accs_cpu = np.array(test_accs)
            test_accs_vote_cpu = np.array(test_accs_vote)
            test_predictions1_cpu = [np.array(p) for p in test_predictions1]
            test_predictions2_cpu = [np.array(p) for p in test_predictions2]

            all_train_predictions.append(train_predictions1)
            all_test_predictions.append(test_predictions1_cpu)
            all_test_predictions_vote.append(test_predictions2_cpu)
            all_test_accuracies.append(test_accs_cpu)
            all_test_accuracies_vote.append(test_accs_vote_cpu)

        except Exception as e:
            print(f"Error encountered with seed {seed}: {e}")
            # The epoch count may not be known yet (e.g. the very first seed
            # failed), so mark the row and size it after the loop.
            all_test_accuracies.append(None)
            all_test_accuracies_vote.append(None)

        finally:
            del LogNCDE_Depth2
            del key
            del modelkey
            del trainkey
            # NOTE(review): device_put(None) looks like a no-op; kept as part of
            # the original cleanup sequence.
            jax.device_put(None)
            gc.collect()
            jax.clear_caches()

    # Give failed seeds zero rows of the same length as the successful ones so
    # the tables below are rectangular.
    all_test_accuracies = _zero_fill_failed_rows(all_test_accuracies)
    all_test_accuracies_vote = _zero_fill_failed_rows(all_test_accuracies_vote)

    print("test_predictions:", all_test_predictions)
    print("test_predictions_vote:", all_test_predictions_vote)

    if all_test_predictions:
        vote_and_evaluate()
    else:
        # Voting indexes the first seed's predictions, which do not exist here.
        print("No seed finished training; skipping vote aggregation.")

    n_epochs = len(all_test_accuracies[0]) if all_test_accuracies else 0
    epochs = range(1, n_epochs + 1)
    row_labels = [f"Seed_{s}" for s in seeds]
    df_test = pd.DataFrame(all_test_accuracies, columns=epochs, index=row_labels)
    df_test_vote = pd.DataFrame(all_test_accuracies_vote, columns=epochs, index=row_labels)
    print("\nTest Accuracies per Epoch per Seed:")
    print(df_test)
    print("\nTest Accuracies vote per Epoch per Seed:")
    print(df_test_vote)
    return df_test


def vote_and_evaluate():
    """Majority-vote the per-seed predictions and report ensemble metrics.

    Reads the module-level lists ``all_train_predictions``,
    ``all_test_predictions`` and ``all_test_predictions_vote`` together with
    ``labels_train`` and ``labels_test``. Each prediction list is reduced with
    ``aggregate_votes`` and scored with ``evaluate_accuracy`` (which prints the
    accuracy). Precision, recall and F1 are computed on the votes derived from
    ``all_test_predictions_vote``.

    Returns:
        dict with the keys ``train_accuracy``, ``test_accuracy``,
        ``test_vote_accuracy``, ``precision``, ``recall`` and ``f1``.
        Previously these values were computed and then dropped.
    """
    train_votes = aggregate_votes(all_train_predictions)
    train_accuracy = evaluate_accuracy(train_votes, labels_train, "Train")

    test_votes = aggregate_votes(all_test_predictions)
    test_accuracy = evaluate_accuracy(test_votes, labels_test, "Test")
    test_votes = jnp.array(list(test_votes.values()))

    test_vote_votes = aggregate_votes(all_test_predictions_vote)
    test_vote_accuracy = evaluate_accuracy(test_vote_votes, labels_test, "Test_vote")
    test_vote_votes = jnp.array(list(test_vote_votes.values()))

    precision = precision_score(labels_test, test_vote_votes)
    recall = recall_score(labels_test, test_vote_votes)
    f1 = f1_score(labels_test, test_vote_votes)
    print('testvotes:', test_votes)
    print('test_vote_votes:', test_vote_votes)

    return {
        "train_accuracy": train_accuracy,
        "test_accuracy": test_accuracy,
        "test_vote_accuracy": test_vote_accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def aggregate_votes(predictions_list_all_seeds):
    """Majority-vote binary predictions across seeds, position by position.

    Args:
        predictions_list_all_seeds: One list per seed. Each inner list holds
            indexable pairs of which only element ``[1]`` (the prediction) is
            read. Entries are matched across seeds by their position in the
            list, and the first seed's list determines how many positions are
            voted on.

    Returns:
        dict mapping each position (0-based) to ``1`` when strictly more than
        half of the seeds predicted ``1`` there, otherwise ``0`` (so a tie
        yields ``0``). An empty input yields an empty dict.
    """
    n_seeds = len(predictions_list_all_seeds)
    if n_seeds == 0:
        # No seed produced predictions: nothing to vote on.
        return {}

    n_items = len(predictions_list_all_seeds[0])
    aggregated_predictions = {}
    for pos in range(n_items):
        ones = sum(seed_preds[pos][1] == 1 for seed_preds in predictions_list_all_seeds)
        aggregated_predictions[pos] = 1 if ones > n_seeds / 2 else 0
    return aggregated_predictions


def evaluate_accuracy(aggregated_predictions, labels, dataset_name):
    """Print and return the fraction of voted predictions that match the labels.

    Args:
        aggregated_predictions: Mapping from index to predicted label; each key
            is used to look up the reference value in ``labels``.
        labels: Indexable collection of reference labels.
        dataset_name: Prefix used in the printed report line.

    Returns:
        The accuracy as a float, or NaN when ``aggregated_predictions`` is
        empty (instead of failing on a division by zero).
    """
    total = len(aggregated_predictions)
    if total == 0:
        # Nothing to score; NaN keeps the report line honest.
        accuracy = float("nan")
    else:
        hits = sum(
            1 for idx, pred in aggregated_predictions.items() if pred == labels[idx]
        )
        accuracy = hits / total
    print(f'{dataset_name} Accuracy (after voting): {accuracy:.2f}')
    return accuracy

In [None]:
# Hyperparameters. train_with_seeds reads these module-level names when it
# constructs each LogNeuralCDE.
hidden_dim = 128
vf_hidden_dim = 512
vf_num_hidden = 3
data_dim = 33
label_dim = 1
stepsize = 60
ode_solver_stepsize = 1 / 250

# Randomly drawn seeds; the run below uses the fixed list instead.
num_seeds = 5
seeds = np.random.randint(0, 10000, size=num_seeds).tolist()

# Fixed seeds 1001..1005 for the actual sweep.
test_seeds = list(range(1001, 1006))
train_with_seeds(test_seeds)