In [None]:
import numpy as np
import keras
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from pathlib import Path
import warnings
import tensorflow as tf
tf.keras.utils.disable_interactive_logging()
# Suppress specific warning
warnings.filterwarnings("ignore", category=FutureWarning)
from models.resnet import ClassifierRESNET
from tqdm import tqdm
from sktime.datasets import load_UCR_UEA_dataset

## Load Data

In [None]:
dataset = 'Epilepsy'

# Load the predefined train/test splits.
# X arrives with shape (instances, dimensions, timesteps); y holds the raw class labels.
X_train, y_train = load_UCR_UEA_dataset(dataset, split="train", return_type="numpy3d")
X_test, y_test = load_UCR_UEA_dataset(dataset, split="test", return_type="numpy3d")

# Swap the last two axes -> (instances, timesteps, dimensions).
# This is an axis transpose, not a reshape: values stay attached to their (timestep, dimension) pair.
X_train = X_train.transpose(0, 2, 1)
X_test = X_test.transpose(0, 2, 1)

# Integer-encode the labels.
# The encoder is fitted on the union of train and test labels so that it agrees
# with the class count below. Fitting on train only would make transform(y_test)
# raise on a label that is absent from the training split. When both splits share
# the same label set, the resulting encoding is identical to a train-only fit,
# because LabelEncoder sorts its classes.
le = preprocessing.LabelEncoder()
le.fit(np.concatenate([y_train, y_test]))
y_train = le.transform(y_train)
y_test = le.transform(y_test)

# Number of distinct classes across both splits (one output unit per class).
nb_classes = len(le.classes_)
# Per-instance input shape for the network: (timesteps, dimensions).
input_shape = X_train.shape[1:]

# Hold out 20% of the training split for validation, stratified by class so the
# class proportions are preserved; fixed seed for a reproducible split.
X_tr, X_val, y_tr_int, y_val_int = train_test_split(
    X_train,
    y_train,
    test_size=0.20,
    stratify=y_train,
    random_state=42,
)

# One-hot encode the labels for the ResNet.
y_tr = keras.utils.to_categorical(y_tr_int, nb_classes)
y_val = keras.utils.to_categorical(y_val_int, nb_classes)

# Integer versions kept for metrics/logging later -> these are the y_true arrays.
y_true_val = y_val_int   # validation ground truth
y_true_test = y_test     # test ground truth

## Model

### Train ResNet

In [None]:
# Directory where the trained model artifacts for this dataset are written.
# Built from the current user's home directory instead of a hard-coded
# '/Users/<name>/...' prefix, so the notebook is not tied to one machine.
# For the original author the resulting string is unchanged.
trained_models_root = Path.home() / "Desktop" / "confetti" / "models" / "trained_models"

# Passed as a string with a trailing slash, exactly as before.
# NOTE(review): presumably ClassifierRESNET appends file names to this string - confirm.
model_output_path = f"{trained_models_root / dataset}/"

model = ClassifierRESNET(
    output_directory=model_output_path,
    input_shape=input_shape,   # (timesteps, dimensions)
    nb_classes=nb_classes,
    dataset_name=dataset,
    verbose=True,
)

# Train on the 80% training portion and validate on the 20% hold-out.
# y_train / y_val are one-hot encoded; y_true carries the integer labels of the
# validation set (not the test set).
# NOTE(review): how fit() uses y_true is defined in models.resnet - presumably
# to score predictions on x_val; confirm.
model.fit(
    x_train=X_tr,
    y_train=y_tr,
    x_val=X_val,
    y_val=y_val,
    y_true=y_val_int,
)