In [3]:
import pandas as pd
import numpy as np
import math

import jax
import jax.numpy as jnp
from jax import random, grad
from jax.nn import relu, sigmoid
from functools import partial

import matplotlib.pyplot as plt

import re
import time
import pytz
import os
import random
import sys
import pickle
from datetime import datetime

In [11]:
def initialize_random_weights(mean, std, shape=()):
    """Draw initial weights from a normal distribution.

    Args:
        mean: Centre of the distribution.
        std: Standard deviation of the distribution.
        shape: Output shape; the default ``()`` yields a single scalar draw.

    Returns:
        Samples drawn through NumPy's global random state.
    """
    return np.random.normal(mean, std, shape)

# Smooth (differentiable) surrogate for "round an integer down to the nearest even number".
def lower_even(x):
    """Map even integers to themselves and odd integers to the even number below.

    The correction term ``0.5 * (1 - cos(pi * x))`` is 0 when ``x`` is an even
    integer and 1 when ``x`` is an odd integer, so subtracting it leaves evens
    unchanged and lowers odds by one while staying differentiable in between.
    """
    odd_correction = 0.5 * (1 - jnp.cos(jnp.pi * x))
    return x - odd_correction

# Differentiable parity indicator: a cubic polynomial (not a sinusoid) that
# passes through (0, 0), (1, 1), (2, 0) and (3, 1).
def differentiable_even_or_odd(x):
    """Return 0 for x in {0, 2} and 1 for x in {1, 3}, interpolating smoothly.

    The polynomial is ``(2/3) x^3 - 3 x^2 + (10/3) x``. It only matches the
    parity of integers in the range 0..3; outside that range it is an
    ordinary cubic and no longer encodes even/odd.
    """
    cubic_term = (2 * x ** 3) / 3
    quadratic_term = 3 * x ** 2
    linear_term = (10 * x) / 3
    return cubic_term - quadratic_term + linear_term

folder = 'D:/OneDrive - Universidad Complutense de Madrid (UCM)/Doctorado/Curriculum_Learning/Multidigit_Addition'


def _load_literal_file(filename):
    """Read ``{folder}/{filename}`` and evaluate its text as a Python expression.

    NOTE(security): ``eval`` executes whatever Python code the file contains.
    That is acceptable only while these files are produced locally by this
    project. If every file holds nothing but literals (lists of tuples of
    numbers), ``ast.literal_eval`` would be a safe drop-in replacement --
    TODO confirm the file contents before switching.
    """
    with open(f"{folder}/{filename}", "r") as file:
        return eval(file.read())


# Load the pair collections from disk (each file is expected to hold a list of tuples)
train_couples = _load_literal_file("train_couples.txt")
combinations_with_carry_over = _load_literal_file("combinations_with_carry_over.txt")
real_test_dataset = _load_literal_file("real_test_dataset.txt")
test_dataset = _load_literal_file("test_dataset.txt")
real_test_dataset_with_carry_over = _load_literal_file("real_test_dataset_with_carry_over.txt")
combinations_small_problem_size = _load_literal_file("combinations_small_problem_size.txt")
combinations_with_carry_over_decimal = _load_literal_file("combinations_with_carry_over_decimal.txt")

# Subsets of the training pairs: pairs without a 0 operand, and pairs that also
# appear in each carry-over collection
train_without_zeros = [pair for pair in train_couples if 0 not in pair]
train_with_carry_over = [pair for pair in train_couples if pair in combinations_with_carry_over]
train_with_carry_over_decimal = [pair for pair in train_couples if pair in combinations_with_carry_over_decimal]

# Build an addition dataset sampled from all training pairs (zeros allowed)
def generate_dataset_with_zeros(size):
    """Sample ``size`` training pairs with replacement and attach their sum.

    Pairs come from ``train_couples`` via ``random.choices`` (the standard
    library ``random`` module, which this notebook imports after ``jax.random``).

    Returns:
        A DataFrame with columns ``Column_1`` and ``Column_2`` (the operands)
        and ``Column_3`` (their sum).
    """
    sampled = random.choices(train_couples, k=size)
    operands = pd.DataFrame({
        'Column_1': [couple[0] for couple in sampled],
        'Column_2': [couple[1] for couple in sampled]
    })
    # Third column is the element-wise sum of the two operand columns
    return operands.assign(Column_3=operands['Column_1'] + operands['Column_2'])

def generate_dataset_without_zeros(size):
    """Sample ``size`` zero-free training pairs with replacement and attach their sum.

    Pairs come from ``train_without_zeros`` (the training pairs in which
    neither operand is 0) via ``random.choices``.

    Returns:
        A DataFrame with columns ``Column_1`` and ``Column_2`` (the operands)
        and ``Column_3`` (their sum).
    """
    sampled = random.choices(train_without_zeros, k=size)
    operands = pd.DataFrame({
        'Column_1': [couple[0] for couple in sampled],
        'Column_2': [couple[1] for couple in sampled]
    })
    # Third column is the element-wise sum of the two operand columns
    return operands.assign(Column_3=operands['Column_1'] + operands['Column_2'])

def generate_test_dataset(n_max=100):
    """Enumerate every addition ``a + b`` with ``0 <= a, b < n_max``.

    ``Column_1`` cycles fastest (0..n_max-1 repeated n_max times) while
    ``Column_2`` holds each value n_max times in a row, so the frame has
    ``n_max ** 2`` rows. ``Column_3`` is the sum of the two operands.
    """
    first_operand = [inner for _ in range(n_max) for inner in range(n_max)]
    second_operand = [outer for outer in range(n_max) for _ in range(n_max)]

    grid = pd.DataFrame({
        'Column_1': first_operand,
        'Column_2': second_operand,
    })
    # Expected result of each addition
    return grid.assign(Column_3=grid['Column_1'] + grid['Column_2'])

def generate_real_test_dataset():
    """Build the evaluation DataFrame from the module-level ``test_dataset`` pairs.

    Despite its name, this currently reads ``test_dataset`` rather than
    ``real_test_dataset`` (the alternative source is kept below as a comment).

    Returns:
        A DataFrame with the operands in ``Column_1`` / ``Column_2`` and their
        sum in ``Column_3``.
    """
    # Alternative source: first_operands, second_operands = zip(*real_test_dataset)
    first_operands, second_operands = zip(*test_dataset)

    pairs_frame = pd.DataFrame({
        'Column_1': first_operands,
        'Column_2': second_operands,
    })
    # Third column is the sum of the first two
    return pairs_frame.assign(Column_3=pairs_frame['Column_1'] + pairs_frame['Column_2'])

def decimal_to_binary(n, bits):
    """Encode a non-negative integer as a fixed-width, most-significant-bit-first vector.

    Args:
        n: Integer to encode; must satisfy ``0 <= n < 2**bits``.
        bits: Width of the zero-padded binary representation.

    Returns:
        ``np.int8`` array of 0/1 digits.

    Raises:
        ValueError: If ``n`` does not fit in ``bits`` bits or is negative.
    """
    if not (0 <= n < 2**bits):
        raise ValueError("Number out of range")
    # Zero-padded binary string, e.g. format(5, '04b') == '0101'
    digits = format(n, f'0{bits}b')
    return np.array([int(digit) for digit in digits], dtype=np.int8)

# Function to convert a binary vector (most significant bit first) to decimal
def binary_to_decimal(binary_vector, bits):
    """Decode a most-significant-bit-first binary vector into a number.

    Args:
        binary_vector: Sequence of ``bits`` digits (0/1).
        bits: Expected length of ``binary_vector``.

    Returns:
        The weighted sum ``sum(bit_i * 2**(bits - 1 - i))``.

    Raises:
        ValueError: If ``binary_vector`` does not have exactly ``bits`` elements.
    """
    # Ensure the vector has the correct number of elements
    if len(binary_vector) != bits:
        raise ValueError(f"The vector must have exactly {bits} elements.")

    decimal = 0
    for position, bit in enumerate(binary_vector):
        # Small NumPy integer types (e.g. the int8 rows produced by
        # decimal_to_binary) cannot be multiplied by weights such as 128 under
        # NumPy 2 promotion rules (OverflowError) and could wrap while
        # accumulating, so integer bits are converted to Python ints first.
        if isinstance(bit, np.integer):
            bit = int(bit)
        decimal += bit * (2 ** (bits - 1 - position))

    return decimal

def transform_to_tridimensional_matrix(dataset, bits_init=7, bits_end=8):
    """Binary-encode the three columns of an addition dataset.

    Args:
        dataset: DataFrame with exactly three columns (operand, operand, result).
        bits_init: Bit width used for the two operand columns.
        bits_end: Bit width used for the result column.

    Returns:
        Tuple of three ``np.int8`` matrices with shapes ``(rows, bits_init)``,
        ``(rows, bits_init)`` and ``(rows, bits_end)``.

    Raises:
        ValueError: If the dataset does not have three columns, or (from
            ``decimal_to_binary``) if a value does not fit in its bit width.
    """
    rows, cols = dataset.shape
    if cols != 3:
        raise ValueError("The dataset must have exactly 3 columns.")

    # One bit width per column, in column order
    widths = (bits_init, bits_init, bits_end)
    encoded = [np.zeros((rows, width), dtype=np.int8) for width in widths]

    # Fill each matrix row with the binary representation of the matching cell
    for row in range(rows):
        for col, width in enumerate(widths):
            encoded[col][row] = decimal_to_binary(dataset.iloc[row, col], width)

    return encoded[0], encoded[1], encoded[2]
    
def prepare_dataset(level, size=1, couples_included=None):
    """Build a training DataFrame of additions for the requested curriculum level.

    Every level returns columns ``Column_1``, ``Column_2`` (operands) and
    ``Column_3`` (their sum). Levels -3 and 0 previously stored the *product*
    in ``Column_3``, which disagreed with every other generator and with the
    test sets of this addition model; they now store the sum as well.

    Args:
        level: Curriculum selector.
            -3: sample uniformly (with replacement) from ``couples_included``.
            -2: sample from all training pairs.
            -1: sample from training pairs without zeros.
             0: random operands in 1..9 that are not in
                ``combinations_with_carry_over``.
             1: sample from training pairs listed as carry-over combinations.
        size: Number of rows to generate.
        couples_included: Pairs to sample from when ``level == -3``.

    Returns:
        The generated DataFrame, or ``None`` (after printing a message) when
        ``level`` is not recognised.

    Raises:
        ValueError: If ``level == -3`` and rows are requested but
            ``couples_included`` is empty.
    """
    if level == -3:
        pairs = [] if couples_included is None else couples_included
        if size > 0 and len(pairs) == 0:
            raise ValueError("couples_included must contain at least one pair when level == -3.")
        column_1 = []
        column_2 = []
        while len(column_1) < size:
            choice = pairs[np.random.choice(len(pairs))]
            column_1.append(choice[0])
            column_2.append(choice[1])
        dataset = pd.DataFrame({'Column_1': column_1, 'Column_2': column_2})
        dataset['Column_3'] = dataset['Column_1'] + dataset['Column_2']
        return dataset

    elif level == -2:
        return generate_dataset_with_zeros(size)

    elif level == -1:
        return generate_dataset_without_zeros(size)

    elif level == 0:
        # Rejection sampling: draw batches of single-digit operands and keep
        # only the pairs that are not listed as carry-over combinations.
        kept_batches = []
        kept_rows = 0
        while kept_rows < size:
            column_1 = np.random.randint(1, 10, size)
            column_2 = np.random.randint(1, 10, size)
            batch = pd.DataFrame({'Column_1': column_1, 'Column_2': column_2})
            has_carry = batch[['Column_1', 'Column_2']].apply(tuple, axis=1).isin(combinations_with_carry_over)
            batch = batch[~has_carry]
            kept_batches.append(batch)
            kept_rows += len(batch)
        if kept_batches:
            dataset = pd.concat(kept_batches).iloc[:size].reset_index(drop=True)
        else:
            # size <= 0: return an empty, correctly-shaped frame
            dataset = pd.DataFrame({'Column_1': [], 'Column_2': []}, dtype=int)
        dataset['Column_3'] = dataset['Column_1'] + dataset['Column_2']
        return dataset

    elif level == 1:
        pairs = random.choices(train_with_carry_over, k=size)
        dataset = pd.DataFrame({
            'Column_1': [pair[0] for pair in pairs],
            'Column_2': [pair[1] for pair in pairs],
        })
        dataset['Column_3'] = dataset['Column_1'] + dataset['Column_2']
        return dataset

    else:
        print('Bad index for the training stage.')
        return None

def prepare_outputs(stage, x1, x2, outputs_prev):
    """Select the training targets for a given stage.

    Stage 1 builds, for every pair of 7-bit operand vectors, the chain
    ``z2 .. z8``: starting from the last bit position (index 6) and moving to
    index 0, each value is ``lower_even(bit_1 + bit_2 + previous / 2)`` (the
    first one has no previous term). Each target row is
    ``[z8, z7, z6, z5, z4, z3, z2, 0]``.

    Stages 2 and 3 use ``outputs_prev`` unchanged.

    Returns:
        The target array, or ``None`` (after printing a message) for an
        unknown stage.
    """
    if stage == 1:
        targets = []
        for vec1, vec2 in zip(x1, x2):
            chain = []          # collects z2, z3, ..., z8 in that order
            previous = None
            for bit in range(6, -1, -1):
                total = vec1[bit] + vec2[bit]
                if previous is not None:
                    total = total + previous * 1/2
                previous = lower_even(total)
                chain.append(previous)
            # Reverse so the row reads z8 .. z2, then pad the last slot with 0
            targets.append(chain[::-1] + [0])
        return np.array(targets)

    if stage in (2, 3):
        return outputs_prev

    print('Bad index for the training stage.')
    return None

# Hand-set ("perfect") parameters, used by the stages in which part of the
# network is fixed instead of trained.

# R vectors of shape (14,): R_k selects one bit position from each of the two
# concatenated 7-bit operands.
(R2_perfect, R3_perfect, R4_perfect, R5_perfect,
 R6_perfect, R7_perfect, R8_perfect) = (np.zeros(14) for _ in range(7))

for operand_offset in (0, 7):
    R2_perfect[operand_offset + 6] = 1
    R3_perfect[operand_offset + 5] = 1
    R4_perfect[operand_offset + 4] = 1
    R5_perfect[operand_offset + 3] = 1
    R6_perfect[operand_offset + 2] = 1
    R7_perfect[operand_offset + 1] = 1
    R8_perfect[operand_offset + 0] = 1

# Scalar parameters v2..v7 that scale the previous chain value
v2_perfect = v3_perfect = v4_perfect = v5_perfect = v6_perfect = v7_perfect = 1/2

# Matrix T of shape (14, 8): input bit i of either operand feeds output
# column i + 1; output column 0 receives no direct input.
T_perfect = np.zeros((14, 8))
for operand_offset in (0, 7):
    for bit in range(7):
        T_perfect[operand_offset + bit, bit + 1] = 1

# Scalar parameter v applied to the whole chain vector
v_perfect = 1/2

# Neural network in every stage
def neural_network_1(params, x1, x2):
    """Stage-1 network: compute the chain z2..z8 from trainable R and v.

    Args:
        params: Sequence of 13 items ``(R2..R8, v2..v7)``.
        x1, x2: Operand vectors; they are concatenated along axis 0.

    Returns:
        ``jnp`` array ``[z8, z7, z6, z5, z4, z3, z2, 0]`` where
        ``z2 = lower_even(x . R2)`` and, for k > 2,
        ``z_k = lower_even(x . R_k + z_{k-1} * v_{k-1})``. The original
        comments describe each ``z_k`` as a carry-over value.
    """
    R2, R3, R4, R5, R6, R7, R8, v2, v3, v4, v5, v6, v7 = params
    x = jnp.concatenate((x1, x2), axis=0)

    # First link of the chain has no incoming term
    link = lower_even(jnp.dot(x, R2))
    chain = [link]
    # Remaining links: each consumes the previous one scaled by its v
    for R_k, v_prev in zip((R3, R4, R5, R6, R7, R8), (v2, v3, v4, v5, v6, v7)):
        link = lower_even(jnp.dot(x, R_k) + jnp.dot(link, v_prev))
        chain.append(link)

    # chain holds z2..z8; output order is z8..z2 followed by a constant 0
    return jnp.array(chain[::-1] + [0])

def neural_network_2(params, x1, x2):
    """Stage-2 network: fixed (perfect) chain, trainable output layer.

    Args:
        params: Pair ``(T, v)`` -- the output matrix and the scalar applied to
            the chain vector.
        x1, x2: Operand vectors; they are concatenated along axis 0.

    Returns:
        ``differentiable_even_or_odd(x . T + z * v)`` where ``z`` is the chain
        ``[z8, ..., z2, 0]`` computed with the module-level ``*_perfect``
        parameters.
    """
    T, v = params
    x = jnp.concatenate((x1, x2), axis=0)

    # First link of the chain has no incoming term
    link = lower_even(jnp.dot(x, R2_perfect))
    chain = [link]
    fixed_links = (
        (R3_perfect, v2_perfect),
        (R4_perfect, v3_perfect),
        (R5_perfect, v4_perfect),
        (R6_perfect, v5_perfect),
        (R7_perfect, v6_perfect),
        (R8_perfect, v7_perfect),
    )
    for R_k, v_prev in fixed_links:
        link = lower_even(jnp.dot(x, R_k) + jnp.dot(link, v_prev))
        chain.append(link)

    # chain holds z2..z8; z is ordered z8..z2 followed by a constant 0
    z = jnp.array(chain[::-1] + [0])
    return differentiable_even_or_odd(jnp.dot(x, T) + jnp.dot(z, v))
    
def neural_network_3(params, x1, x2):
    """Stage-3 network: both the chain and the output layer are trainable.

    Args:
        params: Sequence of 15 items ``(R2..R8, v2..v7, T, v)``.
        x1, x2: Operand vectors; they are concatenated along axis 0.

    Returns:
        ``differentiable_even_or_odd(x . T + z * v)`` where ``z`` is the chain
        ``[z8, ..., z2, 0]`` computed from the supplied R and v parameters.
    """
    R2, R3, R4, R5, R6, R7, R8, v2, v3, v4, v5, v6, v7, T, v = params
    x = jnp.concatenate((x1, x2), axis=0)

    # First link of the chain has no incoming term
    link = lower_even(jnp.dot(x, R2))
    chain = [link]
    # Remaining links: each consumes the previous one scaled by its v
    for R_k, v_prev in zip((R3, R4, R5, R6, R7, R8), (v2, v3, v4, v5, v6, v7)):
        link = lower_even(jnp.dot(x, R_k) + jnp.dot(link, v_prev))
        chain.append(link)

    # chain holds z2..z8; z is ordered z8..z2 followed by a constant 0
    z = jnp.array(chain[::-1] + [0])
    return differentiable_even_or_odd(jnp.dot(x, T) + jnp.dot(z, v))

# Loss functions in every stage
def loss_1(params, x1, x2, y):
    """Mean squared error between the stage-1 network output and ``y``."""
    residual = neural_network_1(params, x1, x2) - y
    return jnp.mean(residual ** 2)

def loss_2(params, x1, x2, y):
    """Mean squared error between the stage-2 network output and ``y``."""
    residual = neural_network_2(params, x1, x2) - y
    return jnp.mean(residual ** 2)

def loss_3(params, x1, x2, y):
    """Mean squared error between the stage-3 network output and ``y``."""
    residual = neural_network_3(params, x1, x2) - y
    return jnp.mean(residual ** 2)

# Parameter-update (gradient descent) steps for every stage
@jax.jit
def update_params_1(params, x1, x2, y, lr):
    """Apply one plain gradient-descent step on ``loss_1``.

    Args:
        params: Sequence of stage-1 parameters.
        x1, x2: Operand vectors of a single example.
        y: Target vector for that example.
        lr: Learning rate.

    Returns:
        ``(new_params, step_loss)`` where ``new_params`` is a list with
        ``p - lr * grad`` for every parameter and ``step_loss`` is the loss
        evaluated at the incoming (pre-update) parameters.
    """
    # value_and_grad yields the loss and its gradient from a single pass
    # instead of evaluating the network separately for each.
    step_loss, gradients = jax.value_and_grad(loss_1)(params, x1, x2, y)
    return [(p - lr * g) for p, g in zip(params, gradients)], step_loss

@jax.jit
def update_params_2(params, x1, x2, y, lr):
    """Apply one plain gradient-descent step on ``loss_2``.

    Args:
        params: Sequence of stage-2 parameters.
        x1, x2: Operand vectors of a single example.
        y: Target vector for that example.
        lr: Learning rate.

    Returns:
        ``(new_params, step_loss)`` where ``new_params`` is a list with
        ``p - lr * grad`` for every parameter and ``step_loss`` is the loss
        evaluated at the incoming (pre-update) parameters.
    """
    # value_and_grad yields the loss and its gradient from a single pass
    # instead of evaluating the network separately for each.
    step_loss, gradients = jax.value_and_grad(loss_2)(params, x1, x2, y)
    return [(p - lr * g) for p, g in zip(params, gradients)], step_loss
    
@jax.jit
def update_params_3(params, x1, x2, y, lr):
    """Apply one plain gradient-descent step on ``loss_3``.

    Args:
        params: Sequence of stage-3 parameters.
        x1, x2: Operand vectors of a single example.
        y: Target vector for that example.
        lr: Learning rate.

    Returns:
        ``(new_params, step_loss)`` where ``new_params`` is a list with
        ``p - lr * grad`` for every parameter and ``step_loss`` is the loss
        evaluated at the incoming (pre-update) parameters.
    """
    # value_and_grad yields the loss and its gradient from a single pass
    # instead of evaluating the network separately for each.
    step_loss, gradients = jax.value_and_grad(loss_3)(params, x1, x2, y)
    return [(p - lr * g) for p, g in zip(params, gradients)], step_loss

def decide_training(params, x1, x2, y, lr, stage):
    """Run one update step with the routine that matches ``stage``.

    Returns:
        ``(params, step_loss)`` from ``update_params_1/2/3`` for stages 1-3.
        For any other stage a message is printed and ``None`` is returned
        (callers that unpack the result will then fail).
    """
    if stage not in (1, 2, 3):
        print('Bad index for the training stage.')
        return None

    if stage == 1:
        return update_params_1(params, x1, x2, y, lr)
    if stage == 2:
        return update_params_2(params, x1, x2, y, lr)
    return update_params_3(params, x1, x2, y, lr)
        
# Main function to train the network
def train_stages_neural_network(params, stage, level, lr=0.01, epochs=100):
    """Train on ``epochs`` freshly generated examples, one update per example.

    ``epochs`` is used both as the size of the generated dataset and as the
    number of single-example gradient steps.

    Args:
        params: Initial parameters for the chosen stage.
        stage: Training stage (1, 2 or 3).
        level: Dataset level forwarded to ``prepare_dataset``.
        lr: Learning rate.
        epochs: Number of examples / update steps.

    Returns:
        ``(params, final_loss)`` -- the updated parameters and the mean of the
        per-step losses.
    """
    decimal_dataset = prepare_dataset(level, epochs)
    inputs_1, inputs_2, outputs_prev = transform_to_tridimensional_matrix(decimal_dataset)
    outputs = prepare_outputs(stage, inputs_1, inputs_2, outputs_prev)

    # One parameter update per generated example
    step_losses = []
    for step in range(epochs):
        params, step_loss = decide_training(params, inputs_1[step], inputs_2[step], outputs[step], lr, stage)
        step_losses.append(step_loss)

    final_loss = sum(step_losses) / epochs
    return params, final_loss



def decide_test(params, stage, real_test=0, visualize_errors=0):
    """Evaluate the model and print a per-category accuracy / timing report.

    Args:
        params: Model parameters for ``stage``.
        stage: Stage whose network is evaluated.
        real_test: ``1`` runs ``real_test_stages_neural_network`` (detailed
            report); any other value runs ``test_stages_neural_network``.
        visualize_errors: Forwarded to the evaluation routine; ``1`` makes it
            print every failed addition. (It used to be ignored: a hard-coded
            ``0`` was passed instead.)

    Returns:
        None. Results are only printed.
    """
    if real_test == 1:
        (
            test_count,
            correct_predictions_test_count,
            train_count,
            correct_predictions_train_count,
            test_carry_over_count,
            correct_carry_over_predictions_test_count,
            train_carry_over_count,
            correct_carry_over_predictions_train_count,
            small_train_count,
            correct_predictions_small_train_count,
            small_test_count,
            correct_predictions_small_test_count,
            small_train_carry_count,
            correct_predictions_small_train_carry_count,
            small_test_carry_count,
            correct_predictions_small_test_carry_count,
            test_carry_over_decimal_count,
            correct_carry_over_decimal_predictions_test_count,
            train_carry_over_decimal_count,
            correct_carry_over_decimal_predictions_train_count,
            small_train_carry_decimal_count,
            correct_predictions_small_train_carry_decimal_count,
            small_test_carry_decimal_count,
            correct_predictions_small_test_carry_decimal_count,
            reaction_time_carry,
            reaction_time_carry_decimal,
            reaction_time_train,
            reaction_time_test,
            reaction_time_train_carry,
            reaction_time_train_carry_decimal,
            reaction_time_small_train,
            reaction_time_small_test,
            reaction_time_small_train_carry,
            reaction_time_small_test_carry,
            reaction_time_small_train_carry_decimal,
            reaction_time_small_test_carry_decimal,
        ) = real_test_stages_neural_network(params, stage, visualize_errors=visualize_errors)
        print(f"STAGE {stage}: Out of {train_count}, {correct_predictions_train_count} trained were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {test_count}, {correct_predictions_test_count} tested were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {train_carry_over_count}, {correct_carry_over_predictions_train_count} trained with carry-over were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {test_carry_over_count}, {correct_carry_over_predictions_test_count} tested with carry-over were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_train_count}, {correct_predictions_small_train_count} trained with small problem size were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_test_count}, {correct_predictions_small_test_count} tested with small problem size were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_train_carry_count}, {correct_predictions_small_train_carry_count} trained with small problem size and with carry-over were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_test_carry_count}, {correct_predictions_small_test_carry_count} tested with small problem size and with carry-over were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {train_carry_over_decimal_count}, {correct_carry_over_decimal_predictions_train_count} trained with carry-over decimal were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {test_carry_over_decimal_count}, {correct_carry_over_decimal_predictions_test_count} tested with carry-over decimal were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_train_carry_decimal_count}, {correct_predictions_small_train_carry_decimal_count} trained with small problem size and with carry-over decimal were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {small_test_carry_decimal_count}, {correct_predictions_small_test_carry_decimal_count} tested with small problem size and with carry-over decimal were predicted correctly in the current model.")
        print(f"STAGE {stage}: {reaction_time_carry} reaction time with carry over.")
        print(f"STAGE {stage}: {reaction_time_carry_decimal} reaction time with carry over decimal.")
        print(f"STAGE {stage}: {reaction_time_train} reaction time for train.")
        print(f"STAGE {stage}: {reaction_time_test} reaction time for test.")
        print(f"STAGE {stage}: {reaction_time_train_carry} reaction time for train with carry over .")
        print(f"STAGE {stage}: {reaction_time_train_carry_decimal} reaction time for train with carry over decimal.")
        print(f"STAGE {stage}: {reaction_time_small_train} reaction time for train small.")
        print(f"STAGE {stage}: {reaction_time_small_test} reaction time for test small.")
        print(f"STAGE {stage}: {reaction_time_small_train_carry} reaction time for train with carry over and small.")
        print(f"STAGE {stage}: {reaction_time_small_test_carry} reaction time for test with carry over and small.")
        print(f"STAGE {stage}: {reaction_time_small_train_carry_decimal} reaction time for train with carry over decimal and small.")
        print(f"STAGE {stage}: {reaction_time_small_test_carry_decimal} reaction time for test with carry over decimal and small.")

    else:
        test_size, correct_predictions_tested_count, train_size, correct_predictions_trained_count = test_stages_neural_network(params, stage, visualize_errors=visualize_errors)
        print(f"STAGE {stage}: Out of {train_size}, {correct_predictions_trained_count} trained were predicted correctly in the current model.")
        print(f"STAGE {stage}: Out of {test_size}, {correct_predictions_tested_count} tested were predicted correctly in the current model.")
        

# Main function to test the network
def test_stages_neural_network(params, stage, visualize_errors=0):
    """Evaluate the model on every addition produced by ``generate_test_dataset``.

    A prediction counts as correct when every element of the rounded network
    output equals the expected target. Correct predictions are split into
    pairs that belong to ``train_couples`` and pairs that do not.

    Args:
        params: Model parameters for ``stage``.
        stage: Stage whose network is evaluated.
        visualize_errors: ``1`` prints every addition that was predicted wrongly.

    Returns:
        ``(test_size, correct_tested, train_size, correct_trained)`` where
        ``train_size`` is ``len(train_couples)`` and ``test_size`` is the
        number of evaluated rows minus ``train_size``.
    """
    decimal_dataset = generate_test_dataset()
    inputs_1, inputs_2, outputs_prev = transform_to_tridimensional_matrix(decimal_dataset)
    outputs = prepare_outputs(stage, inputs_1, inputs_2, outputs_prev)

    # Hash-based lookup: avoids scanning the whole training list for every row
    trained_pairs = set(train_couples)

    correct_predictions_tested_count = 0
    correct_predictions_trained_count = 0  # Counter for trained couples
    set_size = inputs_1.shape[0]
    train_size = len(train_couples)
    test_size = set_size - train_size

    for i in range(set_size):
        prediction, _ = predict(params, inputs_1[i], inputs_2[i], stage)
        # Check if the prediction matches the expected output
        if jnp.all(prediction == outputs[i]):
            pair = (decimal_dataset.iloc[i, 0], decimal_dataset.iloc[i, 1])
            if pair in trained_pairs:
                correct_predictions_trained_count += 1  # Increment for trained couples
            else:
                correct_predictions_tested_count += 1  # Increment for tested couples
        elif visualize_errors == 1:
            print(f'{decimal_dataset.iloc[i, 0]} plus {decimal_dataset.iloc[i, 1]} has failed.')

    return test_size, correct_predictions_tested_count, train_size, correct_predictions_trained_count

def real_test_stages_neural_network(params, stage, visualize_errors=0):
    """Evaluate the model on the real test set with per-category statistics.

    For every pair returned by ``generate_real_test_dataset`` the prediction is
    timed with ``time.perf_counter_ns`` and the pair is classified as
    small / train / carry / carry-decimal by membership in the module-level
    lists. Category sizes, correct-prediction counts and summed prediction
    times are accumulated.

    Args:
        params: Model parameters for ``stage``.
        stage: Stage whose network is evaluated.
        visualize_errors: ``1`` prints every addition that was predicted wrongly.

    Returns:
        A 36-element tuple: 12 (category size, correct count) pairs followed by
        12 summed reaction times (``perf_counter_ns`` deltas), in the order of
        the ``return`` statement at the end of the function.
    """
    decimal_dataset = generate_real_test_dataset()    
    inputs_1, inputs_2, outputs_prev = transform_to_tridimensional_matrix(decimal_dataset)
    outputs = prepare_outputs(stage, inputs_1, inputs_2, outputs_prev)
    
    # Correct-prediction counters, one per category
    correct_predictions_test_count = 0
    correct_predictions_train_count = 0
    correct_carry_over_predictions_count = 0
    correct_carry_over_predictions_test_count = 0
    correct_carry_over_predictions_train_count = 0
    correct_predictions_small_train_count = 0
    correct_predictions_small_test_count = 0
    correct_predictions_small_train_carry_count = 0
    correct_predictions_small_test_carry_count = 0
    correct_carry_over_decimal_predictions_count = 0
    correct_carry_over_decimal_predictions_test_count = 0
    correct_carry_over_decimal_predictions_train_count = 0
    correct_predictions_small_train_carry_decimal_count = 0
    correct_predictions_small_test_carry_decimal_count = 0

    # Category sizes derived from the lengths of the module-level lists; the
    # "test" sizes are obtained by subtracting the train size from the total.
    set_size = inputs_1.shape[0]
    train_count = len(train_couples)
    test_count = set_size - train_count
    carry_over_count = len(combinations_with_carry_over) 
    carry_over_decimal_count = len(combinations_with_carry_over_decimal) 
    train_carry_over_count = len(train_with_carry_over)
    test_carry_over_count = carry_over_count - train_carry_over_count
    # NOTE(review): small_count is assigned but never used in this function.
    small_count = len(combinations_small_problem_size)
    train_carry_over_decimal_count = len(train_with_carry_over_decimal)
    test_carry_over_decimal_count = carry_over_decimal_count - train_carry_over_decimal_count

    # Additional counters (small-problem categories are counted while iterating)
    small_train_count = 0
    small_test_count = 0
    small_train_carry_count = 0
    small_test_carry_count = 0
    small_train_carry_decimal_count = 0
    small_test_carry_decimal_count = 0
    
    # Summed prediction times per category
    reaction_time_carry = 0
    reaction_time_carry_decimal = 0
    reaction_time_train = 0
    reaction_time_test = 0
    reaction_time_train_carry = 0
    reaction_time_train_carry_decimal = 0
    reaction_time_small_train = 0
    reaction_time_small_test = 0
    reaction_time_small_train_carry = 0
    reaction_time_small_test_carry = 0
    reaction_time_small_train_carry_decimal = 0
    reaction_time_small_test_carry_decimal = 0

    for i in range(set_size):
        pair = (decimal_dataset.iloc[i, 0], decimal_dataset.iloc[i, 1])

        # Only the predict() call is timed; the membership checks below are not
        start_time = time.perf_counter_ns()
        prediction, binary_pred = predict(params, inputs_1[i], inputs_2[i], stage)
        elapsed_time = time.perf_counter_ns() - start_time

        is_small = pair in combinations_small_problem_size
        is_train = pair in train_couples
        # NOTE(review): carry membership is tested against
        # real_test_dataset_with_carry_over, while the carry totals above come
        # from combinations_with_carry_over / train_with_carry_over -- confirm
        # that these collections are meant to be mixed.
        is_carry = pair in real_test_dataset_with_carry_over
        is_carry_decimal = pair in combinations_with_carry_over_decimal

        # Update reaction times and total counters
        if is_small:
            if is_train:
                small_train_count += 1
                reaction_time_small_train += elapsed_time
                if is_carry:
                    small_train_carry_count += 1
                    reaction_time_small_train_carry += elapsed_time
                if is_carry_decimal:
                    small_train_carry_decimal_count += 1
                    reaction_time_small_train_carry_decimal += elapsed_time
            else:
                small_test_count += 1
                reaction_time_small_test += elapsed_time
                if is_carry:
                    small_test_carry_count += 1
                    reaction_time_small_test_carry += elapsed_time
                if is_carry_decimal:
                    small_test_carry_decimal_count += 1
                    reaction_time_small_test_carry_decimal += elapsed_time

        if is_carry:
            reaction_time_carry += elapsed_time
        if is_carry_decimal:
            reaction_time_carry_decimal += elapsed_time
        if is_train:
            reaction_time_train += elapsed_time
        else:
            reaction_time_test += elapsed_time
        if is_train and is_carry:
            reaction_time_train_carry += elapsed_time
        if is_train and is_carry_decimal:
            reaction_time_train_carry_decimal += elapsed_time

        # Count correct predictions (every output element must match)
        if jnp.all(prediction == outputs[i]):
            if is_small:
                if is_train:
                    correct_predictions_small_train_count += 1
                    if is_carry:
                        correct_predictions_small_train_carry_count += 1
                    if is_carry_decimal:
                        correct_predictions_small_train_carry_decimal_count += 1
                else:
                    correct_predictions_small_test_count += 1
                    if is_carry:
                        correct_predictions_small_test_carry_count += 1
                    if is_carry_decimal:
                        correct_predictions_small_test_carry_decimal_count += 1

            # Update the overall (non-small-specific) counters
            if is_carry:
                correct_carry_over_predictions_count += 1 
            if is_carry_decimal:
                correct_carry_over_decimal_predictions_count += 1 
            if is_train:
                correct_predictions_train_count += 1 
            else:
                correct_predictions_test_count += 1
            if is_train and is_carry:
                correct_carry_over_predictions_train_count += 1   
            if is_train and is_carry_decimal:
                correct_carry_over_decimal_predictions_train_count += 1 
                
        elif visualize_errors == 1:
            print(f'{pair[0]} plus {pair[1]} has failed.')

    # Correct test-side carry counts = all correct carry predictions minus the train-side ones
    correct_carry_over_predictions_test_count = correct_carry_over_predictions_count - correct_carry_over_predictions_train_count
    correct_carry_over_decimal_predictions_test_count = correct_carry_over_decimal_predictions_count - correct_carry_over_decimal_predictions_train_count
    
    return (
        test_count,
        correct_predictions_test_count,
        train_count,
        correct_predictions_train_count,
        test_carry_over_count,
        correct_carry_over_predictions_test_count,
        train_carry_over_count,
        correct_carry_over_predictions_train_count,
        small_train_count,
        correct_predictions_small_train_count,
        small_test_count,
        correct_predictions_small_test_count,
        small_train_carry_count,
        correct_predictions_small_train_carry_count,
        small_test_carry_count,
        correct_predictions_small_test_carry_count,
        test_carry_over_decimal_count,
        correct_carry_over_decimal_predictions_test_count,
        train_carry_over_decimal_count,
        correct_carry_over_decimal_predictions_train_count,
        small_train_carry_decimal_count,
        correct_predictions_small_train_carry_decimal_count,
        small_test_carry_decimal_count,
        correct_predictions_small_test_carry_decimal_count,
        reaction_time_carry,
        reaction_time_carry_decimal,
        reaction_time_train,
        reaction_time_test,
        reaction_time_train_carry,
        reaction_time_train_carry_decimal,
        reaction_time_small_train,
        reaction_time_small_test,
        reaction_time_small_train_carry,
        reaction_time_small_test_carry,
        reaction_time_small_train_carry_decimal,
        reaction_time_small_test_carry_decimal
    )

# Predict using the trained neural network
def predict(params, x1, x2, stage):
    """Run the network of the given stage on one pair of operand vectors.

    Returns:
        ``(rounded_pred, binary_pred)`` -- the output rounded with
        ``np.round`` and the raw network output. For an unknown stage a
        message is printed and ``None`` is returned.
    """
    if stage == 1:
        network = neural_network_1
    elif stage == 2:
        network = neural_network_2
    elif stage == 3:
        network = neural_network_3
    else:
        print('Bad index for the training stage.')
        return None

    binary_pred = network(params, x1, x2)
    return np.round(binary_pred), binary_pred

In [12]:
class Tee(object):
    """File-like object that duplicates everything written to it.

    Each ``write`` goes to the ``sys.stdout`` that was current when the
    instance was created and to a file opened at construction time.
    """

    def __init__(self, file, mode='w'):
        """Remember the current ``sys.stdout`` and open ``file`` with ``mode``."""
        self.console = sys.stdout
        self.file = open(file, mode)

    def write(self, data):
        """Write ``data`` to the console first, then to the file."""
        for stream in (self.console, self.file):
            stream.write(data)

    def flush(self):
        """Flush the console first, then the file."""
        for stream in (self.console, self.file):
            stream.flush()

    def close(self):
        """Close the file; the console stream is left open."""
        self.file.close()
        
def load_trainable_model(model, current_time, training_type, stage = 0):
    """Unpickle a saved model and publish it as a module-level global.

    Args:
        model: Base name used to build the pickle file name.
        current_time: Timestamp string embedded in the file name.
        training_type: Sub-folder name inserted into the path.
        stage: ``0`` (default) reads from the ``Parameters`` tree and stores
            the result under the global name ``trainable_model``; any other
            value reads from ``Trained_models/Stages/.../Stage_<stage>`` and
            stores the result under ``<model>_<stage>``.

    Returns:
        The unpickled object (the same object that was stored in
        ``globals()``).
    """
    folder = 'D:/OneDrive - Universidad Complutense de Madrid (UCM)/Doctorado/Curriculum_Learning/Multidigit_Addition'

    # Only the location, the global name and the log line differ between the
    # two cases; the load itself is shared below.
    if stage == 0:
        model_path = f'{folder}/Parameters/{training_type}/{model}_{current_time}.pkl'
        global_name = 'trainable_model'
        message = f'Model trainable_model_{current_time} loaded successfully.'
    else:
        model_path = f'{folder}/Trained_models/Stages/{training_type}/Stage_{stage}/{model}_{stage}-{current_time}.pkl'
        global_name = f'{model}_{stage}'
        message = f'Model {model}_{stage}_{current_time} loaded successfully.'

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at files produced by this project.
    with open(model_path, 'rb') as handle:
        loaded = pickle.load(handle)

    globals()[global_name] = loaded
    print(message)
    return loaded

In [13]:
# Evaluation driver: for every Stage_3 checkpoint found on disk, load the
# Stage_1..Stage_3 models that share its timestamp and run decide_test on
# each, mirroring everything printed into a log file.

# 'yes' / 'no' switches that select which directory tree is used.
validation_performance = 'no'   # 'yes' adds the Validation_performance level
easy_examples = 'no'            # 'yes' adds the Easy_examples level / prefix
type_training = 'Stages'
file_name = 'Generated_model'

# NOTE(review): absolute, machine-specific location; consider moving it to a
# single configuration cell.
project_dir = 'D:/OneDrive - Universidad Complutense de Madrid (UCM)/Doctorado/Curriculum_Learning/Multidigit_Addition'

# Optional path pieces; each collapses to '' when its switch is not 'yes',
# which reproduces the four path variants without repeating the literals.
validation_part = '/Validation_performance' if validation_performance == 'yes' else ''
easy_dir_part = '/Easy_examples' if easy_examples == 'yes' else ''
easy_name_part = 'Easy_examples_' if easy_examples == 'yes' else ''

output_file = f'{project_dir}/Tests{validation_part}/tests-{easy_name_part}{type_training}-{file_name}.txt'
models_root = f'{project_dir}/Trained_models{validation_part}{easy_dir_part}/{type_training}/{file_name}'
folder_path = f'{models_root}/Stage_3'

os.makedirs(os.path.dirname(output_file), exist_ok=True)

# From here on everything printed goes to both the notebook and output_file.
tee = Tee(output_file, 'w')
sys.stdout = tee

try:
    visualize_errors = 0
    real_test = 1
    model = 'trainable_model_stage'

    # The dot before "pkl" is escaped so it only matches a literal '.'.
    date_pattern = r'trainable_model_stage_3-(\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2})\.pkl'

    def _checkpoint_sort_key(name):
        """Return the matched checkpoint name, or '' when the name does not match."""
        found = re.search(date_pattern, name)
        return found.group(0) if found else ''

    files = sorted(
        # Keep .pkl files and drop hidden ones.
        (f for f in os.listdir(folder_path) if f.endswith('.pkl') and not f.startswith('.')),
        key=_checkpoint_sort_key,
    )

    for filename in files:
        match = re.search(date_pattern, filename)
        if not match:
            # Non-matching names sort first (key ''), so aborting here would
            # stop the run before any checkpoint is evaluated; skip instead.
            print(f'Skipping {filename}: name does not match the expected checkpoint pattern.')
            continue
        current_time = match.group(1)

        for stage in range(1, 4):
            file_path = f'{models_root}/Stage_{stage}/{model}_{stage}-{current_time}.pkl'

            # NOTE(review): pickle.load can execute arbitrary code; only
            # load checkpoints written by this project.
            with open(file_path, 'rb') as file:
                # Stored as a global named e.g. trainable_model_stage_1.
                globals()[f"{model}_{stage}"] = pickle.load(file)

            print(f'Loaded {model}_{stage}-{current_time}.pkl')

            # decide_test is defined in another cell of this notebook.
            decide_test(params=globals()[f"{model}_{stage}"], stage=stage, real_test=real_test, visualize_errors=visualize_errors)

finally:
    # Always hand stdout back to the notebook and close the log file, even
    # if loading or testing raised.
    sys.stdout = tee.console
    tee.close()

print(f'Finished, file {file_name} created')

Loaded trainable_model_stage_1-2024_11_19_13_22_02.pkl
STAGE 1: Out of 8000, 8000 trained were predicted correctly in the current model.
STAGE 1: Out of 2000, 2000 tested were predicted correctly in the current model.
STAGE 1: Out of 6524, 6524 trained with carry-over were predicted correctly in the current model.
STAGE 1: Out of 1631, 1631 tested with carry-over were predicted correctly in the current model.
STAGE 1: Out of 1536, 1536 trained with small problem size were predicted correctly in the current model.
STAGE 1: Out of 364, 364 tested with small problem size were predicted correctly in the current model.
STAGE 1: Out of 904, 904 trained with small problem size and with carry-over were predicted correctly in the current model.
STAGE 1: Out of 207, 207 tested with small problem size and with carry-over were predicted correctly in the current model.
STAGE 1: Out of 5576, 5576 trained with carry-over decimal were predicted correctly in the current model.
STAGE 1: Out of 1399, 139