In [None]:
import keras.models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from keras.optimizers import Adam
import tensorflow
from keras.optimizers.legacy import Adam as LegacyAdam

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import pickle as pkl
from src.utils.data_transform import *
from src.utils.data_io import load_data
import pandas as pd
import os 
import pickle 
import json
import matplotlib.pyplot as plt
from src.utils.data_io import save_data
from src.analysis.viz_training import *

In [None]:
def build_model(input_shape, num_classes=5, learning_rate=1e-3):
    """Build and compile a small 1-D CNN classifier.

    Layers: Conv1D(64, kernel_size=10) -> MaxPooling1D(2) -> Conv1D(128, kernel_size=10)
    -> MaxPooling1D(2) -> Flatten -> Dropout(0.5) -> Dense(num_classes, softmax).
    Both convolutions use ReLU activation and 'same' padding.

    Args:
        input_shape: Shape of a single sample, passed to the first Conv1D layer.
            This notebook uses (20, 6) — presumably (timesteps, channels); confirm
            against the data pipeline.
        num_classes: Number of output classes (softmax units). Defaults to 5, the
            value that used to be hard-coded here.
        learning_rate: Learning rate for the legacy Adam optimizer.

    Returns:
        A compiled Sequential model using categorical cross-entropy loss (so labels
        must be one-hot encoded) and reporting an "accuracy" metric.
    """
    model = Sequential()
    model.add(Conv1D(filters=64, kernel_size=10, activation='relu', input_shape=input_shape, padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Conv1D(filters=128, kernel_size=10, activation='relu', padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Flatten())
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    # The legacy Adam implementation is used instead of keras.optimizers.Adam
    # (which was previously left here commented out); reason not recorded — TODO confirm.
    optimizer = LegacyAdam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=["accuracy"])
    return model

In [None]:
# Sanity check: print the layer stack and parameter counts for a (20, 6) input window.
model = build_model(input_shape=(20, 6))
model.summary()

In [None]:
# Subject id -> sample indices. JSON object keys are always strings, so they are
# converted back to ints while loading.
with open("../../data/dataset-info-json/subject_to_indices.json", "r") as f:
    subject_to_indices = {int(subject): indices for subject, indices in json.load(f).items()}

path_to_data = "../../data/ProcessedSubjects/MajorityLabel/sessions/grav_n_med/full_std_3"
path_to_save = "../../models/cnn/1000"
training_info_path = f"{path_to_save}/training_info"
# makedirs also creates the parent `path_to_save`, where the LOSO loop writes its models.
os.makedirs(training_info_path, exist_ok=True)

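**Train LOSO (leave-one-subject-out)** — for each subject, train a fresh model on all other subjects and evaluate it on the held-out subject.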

In [None]:
for test_subject in subject_to_indices:
    print(f"Training without {test_subject}")

    # A fresh, untrained model for every fold so no weights carry over between folds.
    model = build_model(input_shape=(20,6))
    train_data, train_labels, test_data, test_labels = load_data(test_subject, subject_to_indices, path_to_data)
    history = model.fit(train_data, train_labels, epochs=32, batch_size=64)

    # Each metric is wrapped in a one-element list; the results-loading cell stacks
    # these per-fold lists with np.concatenate.
    results = [model.evaluate(test_data, test_labels)]
    accuracy = [history.history['accuracy']]
    loss = [history.history['loss']]
    model.save(f"{path_to_save}/model_{test_subject}.keras")

    for metric_name, payload in (("results", results), ("accuracy", accuracy), ("loss", loss)):
        save_data(payload, training_info_path, f"{metric_name}_{test_subject}")


**Load and display LOSO results** — aggregate the saved per-fold metrics across all subjects and report the averages.

In [None]:
# Aggregate the per-fold metrics written by the LOSO training loop.
# NOTE(review): this cell previously read from the hard-coded path
# "../models/full_loso/majority_label/processed/mean_std_3/training_info", which is not
# where the LOSO loop saves. Point `info_path` back there if that older run is wanted.
info_path = training_info_path
# assumes save_data writes pickle files named "<name>.pkl" — TODO confirm in src.utils.data_io
tot_acc, tot_res, tot_loss = [], [], []
for subject in sorted(subject_to_indices):
    with open(f"{info_path}/accuracy_{subject}.pkl", "rb") as fh:
        tot_acc.append(pickle.load(fh))
    with open(f"{info_path}/results_{subject}.pkl", "rb") as fh:
        tot_res.append(pickle.load(fh))
    with open(f"{info_path}/loss_{subject}.pkl", "rb") as fh:
        tot_loss.append(pickle.load(fh))

# Each fold saved one-element lists, so concatenation gives one row per fold:
# tot_res -> (n_folds, 2) holding [test loss, test accuracy]; tot_acc/tot_loss -> (n_folds, n_epochs).
tot_res = np.concatenate(tot_res)
tot_acc = np.concatenate(tot_acc)
tot_loss = np.concatenate(tot_loss)

print(f"Mean LOSO test accuracy: {np.mean(tot_res[:, 1]) * 100:.2f}%")
print(f"Mean training accuracy (all folds, all epochs): {np.mean(tot_acc) * 100:.2f}%")
# Cross-entropy loss is not a percentage, so it is reported unscaled.
print(f"Mean LOSO test loss: {np.mean(tot_res[:, 0]):.4f}")


**Train FULL** — a long (1000-epoch) run on every subject except subject 1, which is held out for evaluation.

In [None]:
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
import datetime
import os

path = "../../models/cnn/test_1000/"
# One timestamped TensorBoard log directory per run, so repeated runs do not overwrite each other.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(path, "logs", run_stamp)
checkpoint_dir = os.path.join(path, "checkpoints")
for directory in (log_dir, checkpoint_dir):
    os.makedirs(directory, exist_ok=True)
# Keras substitutes the epoch number when saving, e.g. cp-0001.ckpt.
checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.ckpt")

In [None]:
# As in the LOSO loop, load_data's first argument is the subject held out as the test split.
held_out_subject = 1
model = build_model(input_shape=(20, 6))
(train_data, train_labels,
 test_data, test_labels) = load_data(held_out_subject, subject_to_indices, path_to_data)

In [None]:
# Stack the train and held-out windows (and their labels) along the sample axis.
# NOTE(review): `all_data` / `all_labels` appear unused below — the training cell fits on `train_data` only.
all_data, all_labels = (
    np.concatenate([train_data, test_data]),
    np.concatenate([train_labels, test_labels]),
)

In [None]:
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    update_freq='epoch',   # write scalar summaries once per epoch
    histogram_freq=1,      # weight histograms every epoch
    write_graph=True,      # log the model graph
    write_images=True,     # log weights rendered as images
    profile_batch=2,       # profile only the second batch
    embeddings_freq=1,     # NOTE: build_model has no Embedding layer, so there is likely nothing to visualize here
)


In [None]:
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    # Training accuracy: the fit() call below passes no validation data,
    # so 'val_accuracy' would not be available to monitor.
    monitor='accuracy',
    mode='max',             # higher accuracy is better
    save_best_only=True,    # only write a checkpoint when the monitored accuracy improves
    save_weights_only=True,
    save_freq='epoch',
    verbose=1,
)


In [None]:
# Long run on the training split only; subject 1 (loaded above) is kept for evaluation.
epochs = 1000
training_info_dir = os.path.join(path, "training_info")
figs_dir = os.path.join(path, "figs")
# Neither directory was created anywhere, so create them before save_data / plot_metric
# write into them (the LOSO section likewise creates its training_info directory first).
os.makedirs(training_info_dir, exist_ok=True)
os.makedirs(figs_dir, exist_ok=True)

history = model.fit(
    train_data,
    train_labels,
    epochs=epochs,
    batch_size=64,
    callbacks=[tensorboard_callback, model_checkpoint_callback],
)

model.save(os.path.join(path, f"model_{epochs}_test.keras"))

acc = history.history["accuracy"]
loss = history.history["loss"]
# [test loss, test accuracy], given metrics=["accuracy"] in build_model.
res = model.evaluate(test_data, test_labels)

# The trailing "_1" in these names is the held-out subject id.
save_data(acc, training_info_dir, f"accuracy_{epochs}_1")
save_data(loss, training_info_dir, f"loss_{epochs}_1")
save_data(res, training_info_dir, f"results_{epochs}_1")

# Figures go under `path` instead of a re-hard-coded copy of it.
plot_metric(acc, "Accuracy", os.path.join(figs_dir, "training_acc.svg"))
plot_metric(loss, "Loss", os.path.join(figs_dir, "training_loss.svg"))

**Load checkpoints and evaluate** — score every saved checkpoint on the held-out subject's test data and report the best one.

In [None]:
# Hard-coded here rather than reusing `checkpoint_dir` from the FULL setup cell.
checkpoint_dir = "../../models/cnn/test_1000/checkpoints/"
# A fresh architecture for evaluate_checkpoints to load checkpoint weights into.
ckpt_model = build_model((20, 6))
results = evaluate_checkpoints(checkpoint_dir, test_data, test_labels, ckpt_model)
best_ckpt, best_acc = results['best_ckpt'], results['best_accuracy']
print(f"Best checkpoint: {best_ckpt}")
print(f"Best accuracy: {best_acc:.4f}")

In [None]:
# One accuracy value per evaluated checkpoint, in the order evaluate_checkpoints reported them.
accuracies = [entry['accuracy'] for entry in results['all_results']]

In [None]:
%matplotlib notebook
plot_metric(accuracies, "Easy Accuracy from ckpts - Subject_1", "../../models/cnn/test_1000/figs/easy_acc_from_ckpts.svg")