# In[3]:
import import_ipynb
import numpy as np
from models import BaseGRU, MultiHeadCnnRnn


# In[1]:
def cross_validate(model, num_folds, train_x_list, train_y_list, test_x_list, test_y_list, epochs=10,
                   *, reset_weights=True):
    """Cross-validate ``model`` over pre-split folds and print averaged metrics.

    For each fold ``i`` in ``range(num_folds)`` the model is fitted on
    ``train_x_list[i]`` / ``train_y_list[i]`` for ``epochs`` epochs, with
    ``(test_x_list[i], test_y_list[i])`` passed as ``validation_data``. The
    final-epoch training/validation accuracy and loss of every fold are then
    averaged and printed.

    Args:
        model: Model exposing a Keras-style
            ``fit(x, y, epochs=..., validation_data=...)`` that returns an object
            whose ``history`` dict holds per-epoch lists under ``'accuracy'``,
            ``'val_accuracy'``, ``'loss'`` and ``'val_loss'`` (presumably a Keras
            model compiled with ``metrics=['accuracy']`` -- confirm).
        num_folds: Number of folds to run; each of the four fold lists must have
            at least this many entries.
        train_x_list, train_y_list: Per-fold training inputs and labels.
        test_x_list, test_y_list: Per-fold held-out inputs and labels.
        epochs: Number of epochs to train on each fold.
        reset_weights: When True (default) and ``model`` has
            ``get_weights``/``set_weights``, the weights captured before the first
            fold are restored at the start of every later fold, so each fold trains
            from the same starting point. Without this, fold ``k`` would start from
            weights already trained on earlier folds, whose training data overlaps
            fold ``k``'s held-out data, inflating the validation scores. Only the
            model weights are restored; optimizer state is not touched here.

    Returns:
        ``(all_train_accuracies, all_test_accuracies)``: one entry per fold, each
        being that fold's per-epoch training / validation accuracy list.

    Raises:
        IndexError: If any fold list has fewer than ``num_folds`` entries. This is
            checked up front, before any training happens.
    """
    import warnings

    # Fail fast instead of training several folds and then hitting an IndexError.
    fold_lists = (train_x_list, train_y_list, test_x_list, test_y_list)
    if any(len(fold_list) < num_folds for fold_list in fold_lists):
        raise IndexError(
            f"num_folds={num_folds} exceeds the number of folds provided "
            f"(lengths: {[len(fold_list) for fold_list in fold_lists]})"
        )

    train_accuracies = []
    test_accuracies = []
    train_losses = []
    test_losses = []

    all_train_accuracies = []
    all_test_accuracies = []

    if isinstance(model, BaseGRU):
        print("Testing on Base GRU class: ")
    elif isinstance(model, MultiHeadCnnRnn):
        print("Testing on Multi-head model: ")

    # Snapshot the starting weights so every fold begins from the same state.
    initial_weights = None
    if reset_weights and hasattr(model, "get_weights") and hasattr(model, "set_weights"):
        initial_weights = model.get_weights()
        if not initial_weights:
            # Nothing to snapshot (presumably the model is not built yet); restoring
            # an empty list after the first fit would fail, so skip resetting.
            warnings.warn(
                "cross_validate: model has no weights before training; weights will "
                "not be reset between folds.",
                RuntimeWarning,
            )
            initial_weights = None

    for i in range(num_folds):
        print(f"Training on fold {i+1}/{num_folds}")

        # Undo the previous fold's training so its data cannot leak into this fold.
        if initial_weights is not None and i > 0:
            model.set_weights(initial_weights)

        # Extract the training and test data for this fold
        train_X = train_x_list[i]
        train_y = train_y_list[i]
        test_X = test_x_list[i]
        test_y = test_y_list[i]

        # Fit the model; keep only the per-epoch metrics dict.
        metrics = model.fit(train_X, train_y, epochs=epochs, validation_data=(test_X, test_y)).history

        # Storing each epoch's training and testing accuracies
        all_train_accuracies.append(metrics['accuracy'])
        all_test_accuracies.append(metrics['val_accuracy'])

        # Storing the accuracy and loss from the last epoch for each fold
        train_accuracies.append(metrics['accuracy'][-1])
        test_accuracies.append(metrics['val_accuracy'][-1])
        train_losses.append(metrics['loss'][-1])
        test_losses.append(metrics['val_loss'][-1])

    # Calculate the average accuracies and losses across all folds
    average_train_accuracy = np.mean(train_accuracies)
    average_test_accuracy = np.mean(test_accuracies)
    average_train_loss = np.mean(train_losses)
    average_test_loss = np.mean(test_losses)

    print(f"Average training accuracy: {average_train_accuracy:.4f}")
    print(f"Average testing accuracy: {average_test_accuracy:.4f}")
    print(f"Average training loss: {average_train_loss:.4f}")
    print(f"Average testing loss: {average_test_loss:.4f}")

    return all_train_accuracies, all_test_accuracies