In [15]:
import numpy as np
import pandas as pd
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
from tensorflow.keras.models import Model, load_model, clone_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Dense, Multiply, Add, Lambda, Concatenate, Reshape, Flatten
from tensorflow.keras.initializers import GlorotUniform, RandomUniform, Constant
from tensorflow.keras.callbacks import LambdaCallback
import jax
import jax.numpy as jnp
from jax import random, grad, jit

In [16]:
# Report whether TensorFlow can actually run on a GPU.
# Uses the *visible* device list: the imports cell hides every GPU with
# tf.config.set_visible_devices([], 'GPU'), while list_physical_devices would
# still report GPU hardware that TensorFlow has been told not to use.
gpus = tf.config.get_visible_devices('GPU')
if gpus:
    print("TensorFlow está utilizando la GPU")
else:
    print("TensorFlow no está utilizando la GPU")

TensorFlow no está utilizando la GPU


In [17]:
# Load the four pretrained Keras sub-modules (addition and carry for units and
# tens, judging by their file names), then freeze them (trainable = False) and
# give each one a stable name.
unit_addition_model = load_model('unit_addition_module.keras')
unit_carry_model = load_model('unit_carry_module.keras')
dec_addition_model = load_model('dec_addition_module.keras')
dec_carry_model = load_model('dec_carry_module.keras')

# (name, model) table so freezing and naming are done uniformly.
_pretrained_modules = (
    ('unit_addition_model', unit_addition_model),
    ('unit_carry_model', unit_carry_model),
    ('dec_addition_model', dec_addition_model),
    ('dec_carry_model', dec_carry_model),
)
for _module_name, _module in _pretrained_modules:
    _module.trainable = False
for _module_name, _module in _pretrained_modules:
    _module.name = _module_name
# Drop the loop helpers so they do not linger in the notebook namespace.
del _pretrained_modules, _module_name, _module

In [18]:
# Load the training pairs and the test pairs from disk.
# ast.literal_eval only accepts Python literals (lists, tuples, numbers, ...),
# so a malformed or tampered file cannot execute arbitrary code as eval() could.
# Assumes each file holds a single literal such as a list of (a, b) tuples —
# a file containing an expression (e.g. a comprehension) will now raise.
import ast

with open("train_couples.txt", "r") as file:
    train_couples = ast.literal_eval(file.read())

with open("test_dataset.txt", "r") as file:
    test_dataset = ast.literal_eval(file.read())

def generate_test_dataset(couples=None):
    """Encode (a, b) operand pairs as digit inputs and digit-wise sum targets.

    Args:
        couples: iterable of (a, b) integer pairs. Defaults to the module-level
            ``test_dataset`` loaded from ``test_dataset.txt``. Operands are
            presumably in 0..99 (values >= 100 yield a tens "digit" >= 10).

    Returns:
        (x, y) JAX arrays of shape (n, 4) and (n, 3), also when ``couples`` is
        empty. Each x row is [a_tens, a_units, b_tens, b_units]; each y row is
        [carry out of the tens, tens digit, units digit] of a + b.
    """
    if couples is None:
        couples = test_dataset

    x_data = []
    y_data = []
    for a, b in couples:
        a_dec, a_unit = divmod(a, 10)
        b_dec, b_unit = divmod(b, 10)
        x_data.append([a_dec, a_unit, b_dec, b_unit])

        sum_units = (a_unit + b_unit) % 10
        carry_units = 1 if (a_unit + b_unit) >= 10 else 0
        sum_dec = (a_dec + b_dec + carry_units) % 10
        carry_dec = 1 if (a_dec + b_dec + carry_units) >= 10 else 0
        y_data.append([carry_dec, sum_dec, sum_units])

    # reshape keeps the 2-D layout when the input is empty (no-op otherwise).
    return jnp.array(x_data).reshape(-1, 4), jnp.array(y_data).reshape(-1, 3)


# Build the training set from the (a, b) pairs loaded into `train_couples`.
def generate_train_dataset(couples=None):
    """Encode (a, b) operand pairs as digit inputs and digit-wise sum targets.

    Args:
        couples: iterable of (a, b) integer pairs. Defaults to the module-level
            ``train_couples`` loaded from ``train_couples.txt``. Operands are
            presumably in 0..99 (values >= 100 yield a tens "digit" >= 10).

    Returns:
        (x, y) JAX arrays of shape (n, 4) and (n, 3), also when ``couples`` is
        empty. Each x row is [a_tens, a_units, b_tens, b_units]; each y row is
        [carry out of the tens, tens digit, units digit] of a + b.
    """
    if couples is None:
        couples = train_couples

    x_data = []
    y_data = []
    for a, b in couples:
        a_dec, a_unit = divmod(a, 10)
        b_dec, b_unit = divmod(b, 10)
        x_data.append([a_dec, a_unit, b_dec, b_unit])

        sum_units = (a_unit + b_unit) % 10
        carry_units = 1 if (a_unit + b_unit) >= 10 else 0
        sum_dec = (a_dec + b_dec + carry_units) % 10
        carry_dec = 1 if (a_dec + b_dec + carry_units) >= 10 else 0
        y_data.append([carry_dec, sum_dec, sum_units])

    # reshape keeps the 2-D layout when the input is empty (no-op otherwise).
    return jnp.array(x_data).reshape(-1, 4), jnp.array(y_data).reshape(-1, 3)

In [19]:
# Exhaustive dataset over every pair of two-digit operands.
def generate_final_data():
    """Enumerate all 10,000 additions 00..99 + 00..99 digit by digit.

    Rows are ordered with the first operand's tens digit varying slowest and
    the second operand's units digit fastest, i.e. row k encodes the 4-digit
    string of k as [a_tens, a_units, b_tens, b_units].

    Returns:
        (x, y) JAX arrays of shape (10000, 4) and (10000, 3); each y row is
        [carry out of the tens, tens digit, units digit] of the sum.
    """
    inputs, targets = [], []
    for code in range(10 ** 4):
        digits = [(code // 10 ** place) % 10 for place in (3, 2, 1, 0)]
        a_dec, a_unit, b_dec, b_unit = digits

        unit_total = a_unit + b_unit
        carry_in = int(unit_total >= 10)
        dec_total = a_dec + b_dec + carry_in

        inputs.append(digits)
        targets.append([int(dec_total >= 10), dec_total % 10, unit_total % 10])
    return jnp.array(inputs), jnp.array(targets)

# Trainable mixing weights v0 ... v251.
def init_params(epsilon = 0.1, seed=0):
    """Create the 252 scalar weights, each near its hand-chosen starting value.

    Args:
        epsilon: standard deviation of the Gaussian noise added to each
            starting value (0 gives exactly the starting values).
        seed: integer seed for ``jax.random.PRNGKey``; the default 0
            reproduces the previously hard-coded initialization.

    Returns:
        dict mapping 'v0'..'v251' to float32 arrays of shape (1,).
    """
    # One-hot-like starting point: all zeros except the three 1s below.
    v_values_init = jnp.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                               0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                               0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=jnp.float32)
    key = random.PRNGKey(seed)
    # One independent subkey per weight so each gets its own noise sample.
    keys = random.split(key, 252)
    v_params = {f'v{i}': random.normal(keys[i], (1,)) * epsilon + v_values_init[i] for i in range(252)}
    return v_params

# Loss for a single example.
def loss_fn(params, x, y):
    """Sum of the mean squared errors of the three model outputs.

    Args:
        params: weight dict forwarded to ``model``.
        x: one input row forwarded to ``model``.
        y: target row read as y[0], y[1], y[2]; with the dataset builders in
            this notebook that is presumably [carry, tens digit, units digit].

    Returns:
        Scalar loss: MSE(out1, y[0]) + MSE(out2, y[1]) + MSE(out3, y[2]).
    """
    first_out, second_out, third_out = model(params, x)
    per_output = [
        jnp.mean((first_out - y[0]) ** 2),
        jnp.mean((second_out - y[1]) ** 2),
        jnp.mean((third_out - y[2]) ** 2),
    ]
    return per_output[0] + per_output[1] + per_output[2]
    
# One training step.
def update_params(params, x, y, lr):
    """Take one plain gradient-descent step on a single example.

    Args:
        params: pytree (dict) of trainable arrays, presumably from init_params.
        x, y: one input row and its target row, forwarded to ``loss_fn``.
        lr: learning rate.

    Returns:
        (new_params, step_loss), where step_loss is the loss at the *incoming*
        params (before the update).
    """
    # value_and_grad evaluates the forward pass once and reuses it for the
    # gradient. Calling grad() and loss_fn() separately ran the expensive
    # model (dozens of Keras sub-model calls per example) twice per step.
    step_loss, gradients = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, gradients)
    return new_params, step_loss


def train_model(params, x_train, y_train, lr=0.01, epochs=100):
    """Run ``epochs`` single-example gradient steps; return params and mean loss.

    Each "epoch" here is one update on one example: step i uses example
    ``i % len(x_train)``, so the data is cycled when ``epochs`` exceeds the
    number of examples. (JAX clamps out-of-range indices instead of raising,
    so indexing with the raw step counter would silently keep reusing the last
    example.)

    Args:
        params: trainable pytree passed to ``update_params``.
        x_train, y_train: inputs and targets, one row per example.
        lr: learning rate.
        epochs: number of update steps.

    Returns:
        (params, mean_loss), where mean_loss averages the per-step losses.

    Raises:
        ValueError: if the training set is empty, x/y lengths differ, or
            epochs < 1.
    """
    x_train = jnp.asarray(x_train)
    y_train = jnp.asarray(y_train)

    n_examples = x_train.shape[0]
    if n_examples == 0:
        raise ValueError("train_model: x_train is empty")
    if y_train.shape[0] != n_examples:
        raise ValueError("train_model: x_train and y_train have different lengths")
    if epochs < 1:
        raise ValueError("train_model: epochs must be >= 1")

    total_loss = 0
    for epoch in range(epochs):
        idx = epoch % n_examples
        params, step_loss = update_params(params, x_train[idx], y_train[idx], lr)
        total_loss += step_loss
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {step_loss}")

    return params, total_loss / epochs

# Per-epoch accuracy report.
def print_predictions_and_loss(epoch, predictions, y_train):
    """Print how many examples are predicted exactly; report whether all are.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        predictions: indexable as ``predictions[j][i]`` = output j (0..2) for
            example i; each entry must round to a single number.
        y_train: targets; ``y_train[i]`` is iterated to get example i's values.

    Returns:
        True if every example matches (caller may stop training), else False.
    """
    pred_count = 0
    # Count from this function's own arguments, not a global x_train.
    total_examples = len(y_train)

    for i in range(total_examples):
        # Compare digit tuples element-wise. Concatenating digits into one
        # integer collides when a rounded output is outside 0-9
        # (e.g. [0, 11, 2] -> 112 == [1, 1, 2]).
        predicted = tuple(int(jnp.round(predictions[j][i])) for j in range(3))
        expected = tuple(int(round(val)) for val in y_train[i])
        if predicted == expected:
            pred_count += 1

    print(f"Epoch {epoch + 1}:")
    print(f"Predicciones correctas: {pred_count} de {total_examples}")
    print("-" * 40)

    if pred_count == total_examples:
        print("¡Todas las combinaciones han sido aprendidas correctamente! Deteniendo entrenamiento.")
        return True
    return False

def predictions(params, x_train, y_train):
    """Run ``model`` on every example and print how many are predicted exactly.

    Args:
        params: weight dict forwarded to ``model``.
        x_train: inputs, one row per example (``x_train.shape[0]`` examples).
        y_train: targets; ``y_train[i]`` is iterated to get example i's values.
    """
    pred_count = 0
    total_examples = x_train.shape[0]

    for i in range(total_examples):
        prediction = model(params, x_train[i])
        # Round each of the three outputs to the nearest integer.
        predicted = tuple(int(jnp.round(prediction[j].item())) for j in range(3))
        expected = tuple(int(round(val)) for val in y_train[i])
        # Element-wise comparison. Concatenating digits into one integer
        # collides when an output is outside 0-9 (e.g. [0, 11, 2] -> 112).
        if predicted == expected:
            pred_count += 1

    print(f"Predicciones correctas: {pred_count} de {total_examples}")

    if pred_count == total_examples:
        print("¡Todas las combinaciones han sido aprendidas correctamente!")

def count_predictions(params, x_train, y_train):
    """Check predictions in order and stop at the first wrong one.

    Prints the input and predicted digits of the first mismatch, then how many
    examples were correct. Because the loop breaks at the first error, that
    count is the number of correct examples *before* the first mismatch, not
    over the whole set.

    Args:
        params: weight dict forwarded to ``model``.
        x_train: inputs, one row per example.
        y_train: 2-D targets indexed as ``y_train[i, k]`` for k in 0..2.
    """
    n_total = x_train.shape[0]
    n_correct = 0

    for idx in range(n_total):
        raw = model(params, x_train[idx])
        guess = [int(jnp.round(raw[k].item())) for k in range(3)]

        matches = (guess[0] == y_train[idx, 0]) and (guess[1] == y_train[idx, 1]) and (guess[2] == y_train[idx, 2])
        if not matches:
            print(f'Error en la suma: {x_train[idx]}')
            print(f'Predicción: {[guess[0], guess[1], guess[2]]}')
            break
        n_correct += 1

    print(f"Predicciones correctas: {n_correct} de {n_total}")

    if n_correct == n_total:
        print("¡Todas las combinaciones han sido aprendidas correctamente!")

In [20]:
# Modelo dinámico en JAX
def model(params, x):
    # Extraer unidades de los valores de entrada
    units_input = jnp.array([x[1], x[3]])
    units_input = units_input[None, None, :]
    
    incorrect_units_1_input = jnp.array([x[0], x[1]])
    incorrect_units_1_input = incorrect_units_1_input[None, None, :]
    
    incorrect_units_2_input = jnp.array([x[0], x[2]])
    incorrect_units_2_input = incorrect_units_2_input[None, None, :]
    
    incorrect_units_3_input = jnp.array([x[0], x[3]])
    incorrect_units_3_input = incorrect_units_3_input[None, None, :]
    
    incorrect_units_4_input = jnp.array([x[1], x[2]])
    incorrect_units_4_input = incorrect_units_4_input[None, None, :]
    
    incorrect_units_5_input = jnp.array([x[2], x[3]])
    incorrect_units_5_input = incorrect_units_5_input[None, None, :]
    
    # Llamar a los modelos unit_module y carry_module
    unit_output = jnp.array(unit_addition_model(units_input))
    unit_carry_output = jnp.array(unit_carry_model(units_input)) 
    
    incorrect_units_1_output = jnp.array(unit_addition_model(incorrect_units_1_input))
    incorrect_units_1_carry_output = jnp.array(unit_carry_model(incorrect_units_1_input)) 
    
    incorrect_units_2_output = jnp.array(unit_addition_model(incorrect_units_2_input))
    incorrect_units_2_carry_output = jnp.array(unit_carry_model(incorrect_units_2_input)) 
    
    incorrect_units_3_output = jnp.array(unit_addition_model(incorrect_units_3_input))
    incorrect_units_3_carry_output = jnp.array(unit_carry_model(incorrect_units_3_input)) 
    
    incorrect_units_4_output = jnp.array(unit_addition_model(incorrect_units_4_input))
    incorrect_units_4_carry_output = jnp.array(unit_carry_model(incorrect_units_4_input)) 
    
    incorrect_units_5_output = jnp.array(unit_addition_model(incorrect_units_5_input))
    incorrect_units_5_carry_output = jnp.array(unit_carry_model(incorrect_units_5_input)) 
    
    # Tomar el valor máximo de las predicciones (argmax en JAX)
    unit_val = jnp.argmax(unit_output, axis=-1)
    carry_unit_val = jnp.argmax(unit_carry_output, axis=-1)
    
    incorrect_units_1_val = jnp.argmax(incorrect_units_1_output, axis=-1)
    carry_incorrect_units_1_val = jnp.argmax(incorrect_units_1_carry_output, axis=-1)
    
    incorrect_units_2_val = jnp.argmax(incorrect_units_2_output, axis=-1)
    carry_incorrect_units_2_val = jnp.argmax(incorrect_units_2_carry_output, axis=-1)
    
    incorrect_units_3_val = jnp.argmax(incorrect_units_3_output, axis=-1)
    carry_incorrect_units_3_val = jnp.argmax(incorrect_units_3_carry_output, axis=-1)
    
    incorrect_units_4_val = jnp.argmax(incorrect_units_4_output, axis=-1)
    carry_incorrect_units_4_val = jnp.argmax(incorrect_units_4_carry_output, axis=-1)
    
    incorrect_units_5_val = jnp.argmax(incorrect_units_5_output, axis=-1)
    carry_incorrect_units_5_val = jnp.argmax(incorrect_units_5_carry_output, axis=-1)
    
    
    # Extraer decenas de los valores de entrada
    decs_input = jnp.array([x[0], x[2], carry_unit_val[0]])
    decs_input = decs_input[None, None, :]
    
    incorrect_decs_1_input = jnp.array([x[0], x[2], carry_incorrect_units_1_val[0]])
    incorrect_decs_1_input = incorrect_decs_1_input[None, None, :]
    
    incorrect_decs_2_input = jnp.array([x[0], x[2], carry_incorrect_units_2_val[0]])
    incorrect_decs_2_input = incorrect_decs_2_input[None, None, :]
    
    incorrect_decs_3_input = jnp.array([x[0], x[2], carry_incorrect_units_3_val[0]])
    incorrect_decs_3_input = incorrect_decs_3_input[None, None, :]
    
    incorrect_decs_4_input = jnp.array([x[0], x[2], carry_incorrect_units_4_val[0]])
    incorrect_decs_4_input = incorrect_decs_4_input[None, None, :]
    
    incorrect_decs_5_input = jnp.array([x[0], x[2], carry_incorrect_units_5_val[0]])
    incorrect_decs_5_input = incorrect_decs_5_input[None, None, :]
    
    
    incorrect_decs_6_input = jnp.array([x[0], x[1], carry_unit_val[0]])
    incorrect_decs_6_input = incorrect_decs_6_input[None, None, :]
    
    incorrect_decs_7_input = jnp.array([x[0], x[1], carry_incorrect_units_1_val[0]])
    incorrect_decs_7_input = incorrect_decs_7_input[None, None, :]
    
    incorrect_decs_8_input = jnp.array([x[0], x[1], carry_incorrect_units_2_val[0]])
    incorrect_decs_8_input = incorrect_decs_8_input[None, None, :]
    
    incorrect_decs_9_input = jnp.array([x[0], x[1], carry_incorrect_units_3_val[0]])
    incorrect_decs_9_input = incorrect_decs_9_input[None, None, :]
    
    incorrect_decs_10_input = jnp.array([x[0], x[1], carry_incorrect_units_4_val[0]])
    incorrect_decs_10_input = incorrect_decs_10_input[None, None, :]
    
    incorrect_decs_11_input = jnp.array([x[0], x[1], carry_incorrect_units_5_val[0]])
    incorrect_decs_11_input = incorrect_decs_11_input[None, None, :]
    
    
    incorrect_decs_12_input = jnp.array([x[0], x[3], carry_unit_val[0]])
    incorrect_decs_12_input = incorrect_decs_12_input[None, None, :]
    
    incorrect_decs_13_input = jnp.array([x[0], x[3], carry_incorrect_units_1_val[0]])
    incorrect_decs_13_input = incorrect_decs_13_input[None, None, :]
    
    incorrect_decs_14_input = jnp.array([x[0], x[3], carry_incorrect_units_2_val[0]])
    incorrect_decs_14_input = incorrect_decs_14_input[None, None, :]
    
    incorrect_decs_15_input = jnp.array([x[0], x[3], carry_incorrect_units_3_val[0]])
    incorrect_decs_15_input = incorrect_decs_15_input[None, None, :]
    
    incorrect_decs_16_input = jnp.array([x[0], x[3], carry_incorrect_units_4_val[0]])
    incorrect_decs_16_input = incorrect_decs_16_input[None, None, :]
    
    incorrect_decs_17_input = jnp.array([x[0], x[3], carry_incorrect_units_5_val[0]])
    incorrect_decs_17_input = incorrect_decs_17_input[None, None, :]
    
    
    incorrect_decs_18_input = jnp.array([x[1], x[2], carry_unit_val[0]])
    incorrect_decs_18_input = incorrect_decs_18_input[None, None, :]
    
    incorrect_decs_19_input = jnp.array([x[1], x[2], carry_incorrect_units_1_val[0]])
    incorrect_decs_19_input = incorrect_decs_19_input[None, None, :]
    
    incorrect_decs_20_input = jnp.array([x[1], x[2], carry_incorrect_units_2_val[0]])
    incorrect_decs_20_input = incorrect_decs_20_input[None, None, :]
    
    incorrect_decs_21_input = jnp.array([x[1], x[2], carry_incorrect_units_3_val[0]])
    incorrect_decs_21_input = incorrect_decs_21_input[None, None, :]
    
    incorrect_decs_22_input = jnp.array([x[1], x[2], carry_incorrect_units_4_val[0]])
    incorrect_decs_22_input = incorrect_decs_22_input[None, None, :]
    
    incorrect_decs_23_input = jnp.array([x[1], x[2], carry_incorrect_units_5_val[0]])
    incorrect_decs_23_input = incorrect_decs_23_input[None, None, :]
    
    
    incorrect_decs_24_input = jnp.array([x[1], x[3], carry_unit_val[0]])
    incorrect_decs_24_input = incorrect_decs_24_input[None, None, :]
    
    incorrect_decs_25_input = jnp.array([x[1], x[3], carry_incorrect_units_1_val[0]])
    incorrect_decs_25_input = incorrect_decs_25_input[None, None, :]
    
    incorrect_decs_26_input = jnp.array([x[1], x[3], carry_incorrect_units_2_val[0]])
    incorrect_decs_26_input = incorrect_decs_26_input[None, None, :]
    
    incorrect_decs_27_input = jnp.array([x[1], x[3], carry_incorrect_units_3_val[0]])
    incorrect_decs_27_input = incorrect_decs_27_input[None, None, :]
    
    incorrect_decs_28_input = jnp.array([x[1], x[3], carry_incorrect_units_4_val[0]])
    incorrect_decs_28_input = incorrect_decs_28_input[None, None, :]
    
    incorrect_decs_29_input = jnp.array([x[1], x[3], carry_incorrect_units_5_val[0]])
    incorrect_decs_29_input = incorrect_decs_29_input[None, None, :]
    
    
    incorrect_decs_30_input = jnp.array([x[2], x[3], carry_unit_val[0]])
    incorrect_decs_30_input = incorrect_decs_30_input[None, None, :]
    
    incorrect_decs_31_input = jnp.array([x[2], x[3], carry_incorrect_units_1_val[0]])
    incorrect_decs_31_input = incorrect_decs_31_input[None, None, :]
    
    incorrect_decs_32_input = jnp.array([x[2], x[3], carry_incorrect_units_2_val[0]])
    incorrect_decs_32_input = incorrect_decs_32_input[None, None, :]
    
    incorrect_decs_33_input = jnp.array([x[2], x[3], carry_incorrect_units_3_val[0]])
    incorrect_decs_33_input = incorrect_decs_33_input[None, None, :]
    
    incorrect_decs_34_input = jnp.array([x[2], x[3], carry_incorrect_units_4_val[0]])
    incorrect_decs_34_input = incorrect_decs_34_input[None, None, :]
    
    incorrect_decs_35_input = jnp.array([x[2], x[3], carry_incorrect_units_5_val[0]])
    incorrect_decs_35_input = incorrect_decs_35_input[None, None, :]
    
    
    # Llamar a los modelos unit_module y carry_module
    dec_output = jnp.array(dec_addition_model(decs_input))
    dec_carry_output = jnp.array(dec_carry_model(decs_input)) 
    dec_val = jnp.argmax(dec_output, axis=-1)
    carry_dec_val = jnp.argmax(dec_carry_output, axis=-1)
    
    incorrect_decs_1_output = jnp.array(dec_addition_model(incorrect_decs_1_input))
    incorrect_decs_1_carry_output = jnp.array(dec_carry_model(incorrect_decs_1_input))
    incorrect_decs_1_val = jnp.argmax(incorrect_decs_1_output, axis=-1)
    incorrect_decs_1_carry_val = jnp.argmax(incorrect_decs_1_carry_output, axis=-1)
    
    incorrect_decs_2_output = jnp.array(dec_addition_model(incorrect_decs_2_input))
    incorrect_decs_2_carry_output = jnp.array(dec_carry_model(incorrect_decs_2_input))
    incorrect_decs_2_val = jnp.argmax(incorrect_decs_2_output, axis=-1)
    incorrect_decs_2_carry_val = jnp.argmax(incorrect_decs_2_carry_output, axis=-1)
    
    incorrect_decs_3_output = jnp.array(dec_addition_model(incorrect_decs_3_input))
    incorrect_decs_3_carry_output = jnp.array(dec_carry_model(incorrect_decs_3_input))
    incorrect_decs_3_val = jnp.argmax(incorrect_decs_3_output, axis=-1)
    incorrect_decs_3_carry_val = jnp.argmax(incorrect_decs_3_carry_output, axis=-1)
    
    incorrect_decs_4_output = jnp.array(dec_addition_model(incorrect_decs_4_input))
    incorrect_decs_4_carry_output = jnp.array(dec_carry_model(incorrect_decs_4_input))
    incorrect_decs_4_val = jnp.argmax(incorrect_decs_4_output, axis=-1)
    incorrect_decs_4_carry_val = jnp.argmax(incorrect_decs_4_carry_output, axis=-1)
    
    incorrect_decs_5_output = jnp.array(dec_addition_model(incorrect_decs_5_input))
    incorrect_decs_5_carry_output = jnp.array(dec_carry_model(incorrect_decs_5_input))
    incorrect_decs_5_val = jnp.argmax(incorrect_decs_5_output, axis=-1)
    incorrect_decs_5_carry_val = jnp.argmax(incorrect_decs_5_carry_output, axis=-1)
    
    incorrect_decs_6_output = jnp.array(dec_addition_model(incorrect_decs_6_input))
    incorrect_decs_6_carry_output = jnp.array(dec_carry_model(incorrect_decs_6_input))
    incorrect_decs_6_val = jnp.argmax(incorrect_decs_6_output, axis=-1)
    incorrect_decs_6_carry_val = jnp.argmax(incorrect_decs_6_carry_output, axis=-1)
    
    incorrect_decs_7_output = jnp.array(dec_addition_model(incorrect_decs_7_input))
    incorrect_decs_7_carry_output = jnp.array(dec_carry_model(incorrect_decs_7_input))
    incorrect_decs_7_val = jnp.argmax(incorrect_decs_7_output, axis=-1)
    incorrect_decs_7_carry_val = jnp.argmax(incorrect_decs_7_carry_output, axis=-1)
    
    incorrect_decs_8_output = jnp.array(dec_addition_model(incorrect_decs_8_input))
    incorrect_decs_8_carry_output = jnp.array(dec_carry_model(incorrect_decs_8_input))
    incorrect_decs_8_val = jnp.argmax(incorrect_decs_8_output, axis=-1)
    incorrect_decs_8_carry_val = jnp.argmax(incorrect_decs_8_carry_output, axis=-1)
    
    incorrect_decs_9_output = jnp.array(dec_addition_model(incorrect_decs_9_input))
    incorrect_decs_9_carry_output = jnp.array(dec_carry_model(incorrect_decs_9_input))
    incorrect_decs_9_val = jnp.argmax(incorrect_decs_9_output, axis=-1)
    incorrect_decs_9_carry_val = jnp.argmax(incorrect_decs_9_carry_output, axis=-1)
    
    incorrect_decs_10_output = jnp.array(dec_addition_model(incorrect_decs_10_input))
    incorrect_decs_10_carry_output = jnp.array(dec_carry_model(incorrect_decs_10_input))
    incorrect_decs_10_val = jnp.argmax(incorrect_decs_10_output, axis=-1)
    incorrect_decs_10_carry_val = jnp.argmax(incorrect_decs_10_carry_output, axis=-1)
    
    incorrect_decs_11_output = jnp.array(dec_addition_model(incorrect_decs_11_input))
    incorrect_decs_11_carry_output = jnp.array(dec_carry_model(incorrect_decs_11_input))
    incorrect_decs_11_val = jnp.argmax(incorrect_decs_11_output, axis=-1)
    incorrect_decs_11_carry_val = jnp.argmax(incorrect_decs_11_carry_output, axis=-1)
    
    incorrect_decs_12_output = jnp.array(dec_addition_model(incorrect_decs_12_input))
    incorrect_decs_12_carry_output = jnp.array(dec_carry_model(incorrect_decs_12_input))
    incorrect_decs_12_val = jnp.argmax(incorrect_decs_12_output, axis=-1)
    incorrect_decs_12_carry_val = jnp.argmax(incorrect_decs_12_carry_output, axis=-1)
    
    incorrect_decs_13_output = jnp.array(dec_addition_model(incorrect_decs_13_input))
    incorrect_decs_13_carry_output = jnp.array(dec_carry_model(incorrect_decs_13_input))
    incorrect_decs_13_val = jnp.argmax(incorrect_decs_13_output, axis=-1)
    incorrect_decs_13_carry_val = jnp.argmax(incorrect_decs_13_carry_output, axis=-1)
    
    incorrect_decs_14_output = jnp.array(dec_addition_model(incorrect_decs_14_input))
    incorrect_decs_14_carry_output = jnp.array(dec_carry_model(incorrect_decs_14_input))
    incorrect_decs_14_val = jnp.argmax(incorrect_decs_14_output, axis=-1)
    incorrect_decs_14_carry_val = jnp.argmax(incorrect_decs_14_carry_output, axis=-1)
    
    incorrect_decs_15_output = jnp.array(dec_addition_model(incorrect_decs_15_input))
    incorrect_decs_15_carry_output = jnp.array(dec_carry_model(incorrect_decs_15_input))
    incorrect_decs_15_val = jnp.argmax(incorrect_decs_15_output, axis=-1)
    incorrect_decs_15_carry_val = jnp.argmax(incorrect_decs_15_carry_output, axis=-1)
    
    incorrect_decs_16_output = jnp.array(dec_addition_model(incorrect_decs_16_input))
    incorrect_decs_16_carry_output = jnp.array(dec_carry_model(incorrect_decs_16_input))
    incorrect_decs_16_val = jnp.argmax(incorrect_decs_16_output, axis=-1)
    incorrect_decs_16_carry_val = jnp.argmax(incorrect_decs_16_carry_output, axis=-1)
    
    incorrect_decs_17_output = jnp.array(dec_addition_model(incorrect_decs_17_input))
    incorrect_decs_17_carry_output = jnp.array(dec_carry_model(incorrect_decs_17_input))
    incorrect_decs_17_val = jnp.argmax(incorrect_decs_17_output, axis=-1)
    incorrect_decs_17_carry_val = jnp.argmax(incorrect_decs_17_carry_output, axis=-1)
    
    incorrect_decs_18_output = jnp.array(dec_addition_model(incorrect_decs_18_input))
    incorrect_decs_18_carry_output = jnp.array(dec_carry_model(incorrect_decs_18_input))
    incorrect_decs_18_val = jnp.argmax(incorrect_decs_18_output, axis=-1)
    incorrect_decs_18_carry_val = jnp.argmax(incorrect_decs_18_carry_output, axis=-1)
    
    incorrect_decs_19_output = jnp.array(dec_addition_model(incorrect_decs_19_input))
    incorrect_decs_19_carry_output = jnp.array(dec_carry_model(incorrect_decs_19_input))
    incorrect_decs_19_val = jnp.argmax(incorrect_decs_19_output, axis=-1)
    incorrect_decs_19_carry_val = jnp.argmax(incorrect_decs_19_carry_output, axis=-1)
    
    incorrect_decs_20_output = jnp.array(dec_addition_model(incorrect_decs_20_input))
    incorrect_decs_20_carry_output = jnp.array(dec_carry_model(incorrect_decs_20_input))
    incorrect_decs_20_val = jnp.argmax(incorrect_decs_20_output, axis=-1)
    incorrect_decs_20_carry_val = jnp.argmax(incorrect_decs_20_carry_output, axis=-1)
    
    incorrect_decs_21_output = jnp.array(dec_addition_model(incorrect_decs_21_input))
    incorrect_decs_21_carry_output = jnp.array(dec_carry_model(incorrect_decs_21_input))
    incorrect_decs_21_val = jnp.argmax(incorrect_decs_21_output, axis=-1)
    incorrect_decs_21_carry_val = jnp.argmax(incorrect_decs_21_carry_output, axis=-1)
    
    incorrect_decs_22_output = jnp.array(dec_addition_model(incorrect_decs_22_input))
    incorrect_decs_22_carry_output = jnp.array(dec_carry_model(incorrect_decs_22_input))
    incorrect_decs_22_val = jnp.argmax(incorrect_decs_22_output, axis=-1)
    incorrect_decs_22_carry_val = jnp.argmax(incorrect_decs_22_carry_output, axis=-1)
    
    incorrect_decs_23_output = jnp.array(dec_addition_model(incorrect_decs_23_input))
    incorrect_decs_23_carry_output = jnp.array(dec_carry_model(incorrect_decs_23_input))
    incorrect_decs_23_val = jnp.argmax(incorrect_decs_23_output, axis=-1)
    incorrect_decs_23_carry_val = jnp.argmax(incorrect_decs_23_carry_output, axis=-1)
    
    incorrect_decs_24_output = jnp.array(dec_addition_model(incorrect_decs_24_input))
    incorrect_decs_24_carry_output = jnp.array(dec_carry_model(incorrect_decs_24_input))
    incorrect_decs_24_val = jnp.argmax(incorrect_decs_24_output, axis=-1)
    incorrect_decs_24_carry_val = jnp.argmax(incorrect_decs_24_carry_output, axis=-1)
    
    incorrect_decs_25_output = jnp.array(dec_addition_model(incorrect_decs_25_input))
    incorrect_decs_25_carry_output = jnp.array(dec_carry_model(incorrect_decs_25_input))
    incorrect_decs_25_val = jnp.argmax(incorrect_decs_25_output, axis=-1)
    incorrect_decs_25_carry_val = jnp.argmax(incorrect_decs_25_carry_output, axis=-1)
    
    incorrect_decs_26_output = jnp.array(dec_addition_model(incorrect_decs_26_input))
    incorrect_decs_26_carry_output = jnp.array(dec_carry_model(incorrect_decs_26_input))
    incorrect_decs_26_val = jnp.argmax(incorrect_decs_26_output, axis=-1)
    incorrect_decs_26_carry_val = jnp.argmax(incorrect_decs_26_carry_output, axis=-1)
    
    incorrect_decs_27_output = jnp.array(dec_addition_model(incorrect_decs_27_input))
    incorrect_decs_27_carry_output = jnp.array(dec_carry_model(incorrect_decs_27_input))
    incorrect_decs_27_val = jnp.argmax(incorrect_decs_27_output, axis=-1)
    incorrect_decs_27_carry_val = jnp.argmax(incorrect_decs_27_carry_output, axis=-1)
    
    incorrect_decs_28_output = jnp.array(dec_addition_model(incorrect_decs_28_input))
    incorrect_decs_28_carry_output = jnp.array(dec_carry_model(incorrect_decs_28_input))
    incorrect_decs_28_val = jnp.argmax(incorrect_decs_28_output, axis=-1)
    incorrect_decs_28_carry_val = jnp.argmax(incorrect_decs_28_carry_output, axis=-1)
    
    incorrect_decs_29_output = jnp.array(dec_addition_model(incorrect_decs_29_input))
    incorrect_decs_29_carry_output = jnp.array(dec_carry_model(incorrect_decs_29_input))
    incorrect_decs_29_val = jnp.argmax(incorrect_decs_29_output, axis=-1)
    incorrect_decs_29_carry_val = jnp.argmax(incorrect_decs_29_carry_output, axis=-1)
    
    incorrect_decs_30_output = jnp.array(dec_addition_model(incorrect_decs_30_input))
    incorrect_decs_30_carry_output = jnp.array(dec_carry_model(incorrect_decs_30_input))
    incorrect_decs_30_val = jnp.argmax(incorrect_decs_30_output, axis=-1)
    incorrect_decs_30_carry_val = jnp.argmax(incorrect_decs_30_carry_output, axis=-1)
    
    incorrect_decs_31_output = jnp.array(dec_addition_model(incorrect_decs_31_input))
    incorrect_decs_31_carry_output = jnp.array(dec_carry_model(incorrect_decs_31_input))
    incorrect_decs_31_val = jnp.argmax(incorrect_decs_31_output, axis=-1)
    incorrect_decs_31_carry_val = jnp.argmax(incorrect_decs_31_carry_output, axis=-1)
    
    incorrect_decs_32_output = jnp.array(dec_addition_model(incorrect_decs_32_input))
    incorrect_decs_32_carry_output = jnp.array(dec_carry_model(incorrect_decs_32_input))
    incorrect_decs_32_val = jnp.argmax(incorrect_decs_32_output, axis=-1)
    incorrect_decs_32_carry_val = jnp.argmax(incorrect_decs_32_carry_output, axis=-1)
    
    incorrect_decs_33_output = jnp.array(dec_addition_model(incorrect_decs_33_input))
    incorrect_decs_33_carry_output = jnp.array(dec_carry_model(incorrect_decs_33_input))
    incorrect_decs_33_val = jnp.argmax(incorrect_decs_33_output, axis=-1)
    incorrect_decs_33_carry_val = jnp.argmax(incorrect_decs_33_carry_output, axis=-1)
    
    incorrect_decs_34_output = jnp.array(dec_addition_model(incorrect_decs_34_input))
    incorrect_decs_34_carry_output = jnp.array(dec_carry_model(incorrect_decs_34_input))
    incorrect_decs_34_val = jnp.argmax(incorrect_decs_34_output, axis=-1)
    incorrect_decs_34_carry_val = jnp.argmax(incorrect_decs_34_carry_output, axis=-1)
    
    incorrect_decs_35_output = jnp.array(dec_addition_model(incorrect_decs_35_input))
    incorrect_decs_35_carry_output = jnp.array(dec_carry_model(incorrect_decs_35_input))
    incorrect_decs_35_val = jnp.argmax(incorrect_decs_35_output, axis=-1)
    incorrect_decs_35_carry_val = jnp.argmax(incorrect_decs_35_carry_output, axis=-1)
    
    # Calcular las salidas combinadas con los parámetros v
    
    salida_1 = ((params['v0'] * carry_dec_val) + (params['v1'] * dec_val) + (params['v2'] * carry_unit_val) + (params['v3'] * unit_val)
                    + (params['v4'] * carry_incorrect_units_1_val) + (params['v5'] * incorrect_units_1_val) + (params['v6'] * carry_incorrect_units_2_val) + (params['v7'] * incorrect_units_2_val)
                    + (params['v8'] * carry_incorrect_units_3_val) + (params['v9'] * incorrect_units_3_val) + (params['v10'] * carry_incorrect_units_4_val) + (params['v11'] * incorrect_units_4_val)
                    + (params['v12'] * carry_incorrect_units_5_val) + (params['v13'] * incorrect_units_5_val) + (params['v14'] * incorrect_decs_1_carry_val) + (params['v15'] * incorrect_decs_1_val)
                    + (params['v16'] * incorrect_decs_2_carry_val) + (params['v17'] * incorrect_decs_2_val) + (params['v18'] * incorrect_decs_3_carry_val) + (params['v19'] * incorrect_decs_3_val)
                    + (params['v20'] * incorrect_decs_4_carry_val) + (params['v21'] * incorrect_decs_4_val) + (params['v22'] * incorrect_decs_5_carry_val) + (params['v23'] * incorrect_decs_5_val)
                    + (params['v24'] * incorrect_decs_6_carry_val) + (params['v25'] * incorrect_decs_6_val) + (params['v26'] * incorrect_decs_7_carry_val) + (params['v27'] * incorrect_decs_7_val)
                    + (params['v28'] * incorrect_decs_8_carry_val) + (params['v29'] * incorrect_decs_8_val) + (params['v30'] * incorrect_decs_9_carry_val) + (params['v31'] * incorrect_decs_9_val)
                    + (params['v32'] * incorrect_decs_10_carry_val) + (params['v33'] * incorrect_decs_10_val) + (params['v34'] * incorrect_decs_11_carry_val) + (params['v35'] * incorrect_decs_11_val)
                    + (params['v36'] * incorrect_decs_12_carry_val) + (params['v37'] * incorrect_decs_12_val) + (params['v38'] * incorrect_decs_13_carry_val) + (params['v39'] * incorrect_decs_13_val)
                    + (params['v40'] * incorrect_decs_14_carry_val) + (params['v41'] * incorrect_decs_14_val) + (params['v42'] * incorrect_decs_15_carry_val) + (params['v43'] * incorrect_decs_15_val)
                    + (params['v44'] * incorrect_decs_16_carry_val) + (params['v45'] * incorrect_decs_16_val) + (params['v46'] * incorrect_decs_17_carry_val) + (params['v47'] * incorrect_decs_17_val)
                    + (params['v48'] * incorrect_decs_18_carry_val) + (params['v49'] * incorrect_decs_18_val) + (params['v50'] * incorrect_decs_19_carry_val) + (params['v51'] * incorrect_decs_19_val)
                    + (params['v52'] * incorrect_decs_20_carry_val) + (params['v53'] * incorrect_decs_20_val) + (params['v54'] * incorrect_decs_21_carry_val) + (params['v55'] * incorrect_decs_21_val)
                    + (params['v56'] * incorrect_decs_22_carry_val) + (params['v57'] * incorrect_decs_22_val) + (params['v58'] * incorrect_decs_23_carry_val) + (params['v59'] * incorrect_decs_23_val)
                    + (params['v60'] * incorrect_decs_24_carry_val) + (params['v61'] * incorrect_decs_24_val) + (params['v62'] * incorrect_decs_25_carry_val) + (params['v63'] * incorrect_decs_25_val)
                    + (params['v64'] * incorrect_decs_26_carry_val) + (params['v65'] * incorrect_decs_26_val) + (params['v66'] * incorrect_decs_27_carry_val) + (params['v67'] * incorrect_decs_27_val)
                    + (params['v68'] * incorrect_decs_28_carry_val) + (params['v69'] * incorrect_decs_28_val) + (params['v70'] * incorrect_decs_29_carry_val) + (params['v71'] * incorrect_decs_29_val)
                    + (params['v72'] * incorrect_decs_30_carry_val) + (params['v73'] * incorrect_decs_30_val) + (params['v74'] * incorrect_decs_31_carry_val) + (params['v75'] * incorrect_decs_31_val)
                    + (params['v76'] * incorrect_decs_32_carry_val) + (params['v77'] * incorrect_decs_32_val) + (params['v78'] * incorrect_decs_33_carry_val) + (params['v79'] * incorrect_decs_33_val)
                    + (params['v80'] * incorrect_decs_34_carry_val) + (params['v81'] * incorrect_decs_34_val) + (params['v82'] * incorrect_decs_35_carry_val) + (params['v83'] * incorrect_decs_35_val))   
    
    salida_2 = ((params['v84'] * carry_dec_val) + (params['v85'] * dec_val) + (params['v86'] * carry_unit_val) + (params['v87'] * unit_val)
                    + (params['v88'] * carry_incorrect_units_1_val) + (params['v89'] * incorrect_units_1_val) + (params['v90'] * carry_incorrect_units_2_val) + (params['v91'] * incorrect_units_2_val)
                    + (params['v92'] * carry_incorrect_units_3_val) + (params['v93'] * incorrect_units_3_val) + (params['v94'] * carry_incorrect_units_4_val) + (params['v95'] * incorrect_units_4_val)
                    + (params['v96'] * carry_incorrect_units_5_val) + (params['v97'] * incorrect_units_5_val) + (params['v98'] * incorrect_decs_1_carry_val) + (params['v99'] * incorrect_decs_1_val)
                    + (params['v100'] * incorrect_decs_2_carry_val) + (params['v101'] * incorrect_decs_2_val) + (params['v102'] * incorrect_decs_3_carry_val) + (params['v103'] * incorrect_decs_3_val)
                    + (params['v104'] * incorrect_decs_4_carry_val) + (params['v105'] * incorrect_decs_4_val) + (params['v106'] * incorrect_decs_5_carry_val) + (params['v107'] * incorrect_decs_5_val)
                    + (params['v108'] * incorrect_decs_6_carry_val) + (params['v109'] * incorrect_decs_6_val) + (params['v110'] * incorrect_decs_7_carry_val) + (params['v111'] * incorrect_decs_7_val)
                    + (params['v112'] * incorrect_decs_8_carry_val) + (params['v113'] * incorrect_decs_8_val) + (params['v114'] * incorrect_decs_9_carry_val) + (params['v115'] * incorrect_decs_9_val)
                    + (params['v116'] * incorrect_decs_10_carry_val) + (params['v117'] * incorrect_decs_10_val) + (params['v118'] * incorrect_decs_11_carry_val) + (params['v119'] * incorrect_decs_11_val)
                    + (params['v120'] * incorrect_decs_12_carry_val) + (params['v121'] * incorrect_decs_12_val) + (params['v122'] * incorrect_decs_13_carry_val) + (params['v123'] * incorrect_decs_13_val)
                    + (params['v124'] * incorrect_decs_14_carry_val) + (params['v125'] * incorrect_decs_14_val) + (params['v126'] * incorrect_decs_15_carry_val) + (params['v127'] * incorrect_decs_15_val)
                    + (params['v128'] * incorrect_decs_16_carry_val) + (params['v129'] * incorrect_decs_16_val) + (params['v130'] * incorrect_decs_17_carry_val) + (params['v131'] * incorrect_decs_17_val)
                    + (params['v132'] * incorrect_decs_18_carry_val) + (params['v133'] * incorrect_decs_18_val) + (params['v134'] * incorrect_decs_19_carry_val) + (params['v135'] * incorrect_decs_19_val)
                    + (params['v136'] * incorrect_decs_20_carry_val) + (params['v137'] * incorrect_decs_20_val) + (params['v138'] * incorrect_decs_21_carry_val) + (params['v139'] * incorrect_decs_21_val)
                    + (params['v140'] * incorrect_decs_22_carry_val) + (params['v141'] * incorrect_decs_22_val) + (params['v142'] * incorrect_decs_23_carry_val) + (params['v143'] * incorrect_decs_23_val)
                    + (params['v144'] * incorrect_decs_24_carry_val) + (params['v145'] * incorrect_decs_24_val) + (params['v146'] * incorrect_decs_25_carry_val) + (params['v147'] * incorrect_decs_25_val)
                    + (params['v148'] * incorrect_decs_26_carry_val) + (params['v149'] * incorrect_decs_26_val) + (params['v150'] * incorrect_decs_27_carry_val) + (params['v151'] * incorrect_decs_27_val)
                    + (params['v152'] * incorrect_decs_28_carry_val) + (params['v153'] * incorrect_decs_28_val) + (params['v154'] * incorrect_decs_29_carry_val) + (params['v155'] * incorrect_decs_29_val)
                    + (params['v156'] * incorrect_decs_30_carry_val) + (params['v157'] * incorrect_decs_30_val) + (params['v158'] * incorrect_decs_31_carry_val) + (params['v159'] * incorrect_decs_31_val)
                    + (params['v160'] * incorrect_decs_32_carry_val) + (params['v161'] * incorrect_decs_32_val) + (params['v162'] * incorrect_decs_33_carry_val) + (params['v163'] * incorrect_decs_33_val)
                    + (params['v164'] * incorrect_decs_34_carry_val) + (params['v165'] * incorrect_decs_34_val) + (params['v166'] * incorrect_decs_35_carry_val) + (params['v167'] * incorrect_decs_35_val))
    
    salida_3 = ((params['v168'] * carry_dec_val) + (params['v169'] * dec_val) + (params['v170'] * carry_unit_val) + (params['v171'] * unit_val)
                    + (params['v172'] * carry_incorrect_units_1_val) + (params['v173'] * incorrect_units_1_val) + (params['v174'] * carry_incorrect_units_2_val) + (params['v175'] * incorrect_units_2_val)
                    + (params['v176'] * carry_incorrect_units_3_val) + (params['v177'] * incorrect_units_3_val) + (params['v178'] * carry_incorrect_units_4_val) + (params['v179'] * incorrect_units_4_val)
                    + (params['v180'] * carry_incorrect_units_5_val) + (params['v181'] * incorrect_units_5_val) + (params['v182'] * incorrect_decs_1_carry_val) + (params['v183'] * incorrect_decs_1_val)
                    + (params['v184'] * incorrect_decs_2_carry_val) + (params['v185'] * incorrect_decs_2_val) + (params['v186'] * incorrect_decs_3_carry_val) + (params['v187'] * incorrect_decs_3_val)
                    + (params['v188'] * incorrect_decs_4_carry_val) + (params['v189'] * incorrect_decs_4_val) + (params['v190'] * incorrect_decs_5_carry_val) + (params['v191'] * incorrect_decs_5_val)
                    + (params['v192'] * incorrect_decs_6_carry_val) + (params['v193'] * incorrect_decs_6_val) + (params['v194'] * incorrect_decs_7_carry_val) + (params['v195'] * incorrect_decs_7_val)
                    + (params['v196'] * incorrect_decs_8_carry_val) + (params['v197'] * incorrect_decs_8_val) + (params['v198'] * incorrect_decs_9_carry_val) + (params['v199'] * incorrect_decs_9_val)
                    + (params['v200'] * incorrect_decs_10_carry_val) + (params['v201'] * incorrect_decs_10_val) + (params['v202'] * incorrect_decs_11_carry_val) + (params['v203'] * incorrect_decs_11_val)
                    + (params['v204'] * incorrect_decs_12_carry_val) + (params['v205'] * incorrect_decs_12_val) + (params['v206'] * incorrect_decs_13_carry_val) + (params['v207'] * incorrect_decs_13_val)
                    + (params['v208'] * incorrect_decs_14_carry_val) + (params['v209'] * incorrect_decs_14_val) + (params['v210'] * incorrect_decs_15_carry_val) + (params['v211'] * incorrect_decs_15_val)
                    + (params['v212'] * incorrect_decs_16_carry_val) + (params['v213'] * incorrect_decs_16_val) + (params['v214'] * incorrect_decs_17_carry_val) + (params['v215'] * incorrect_decs_17_val)
                    + (params['v216'] * incorrect_decs_18_carry_val) + (params['v217'] * incorrect_decs_18_val) + (params['v218'] * incorrect_decs_19_carry_val) + (params['v219'] * incorrect_decs_19_val)
                    + (params['v220'] * incorrect_decs_20_carry_val) + (params['v221'] * incorrect_decs_20_val) + (params['v222'] * incorrect_decs_21_carry_val) + (params['v223'] * incorrect_decs_21_val)
                    + (params['v224'] * incorrect_decs_22_carry_val) + (params['v225'] * incorrect_decs_22_val) + (params['v226'] * incorrect_decs_23_carry_val) + (params['v227'] * incorrect_decs_23_val)
                    + (params['v228'] * incorrect_decs_24_carry_val) + (params['v229'] * incorrect_decs_24_val) + (params['v230'] * incorrect_decs_25_carry_val) + (params['v231'] * incorrect_decs_25_val)
                    + (params['v232'] * incorrect_decs_26_carry_val) + (params['v233'] * incorrect_decs_26_val) + (params['v234'] * incorrect_decs_27_carry_val) + (params['v235'] * incorrect_decs_27_val)
                    + (params['v236'] * incorrect_decs_28_carry_val) + (params['v237'] * incorrect_decs_28_val) + (params['v238'] * incorrect_decs_29_carry_val) + (params['v239'] * incorrect_decs_29_val)
                    + (params['v240'] * incorrect_decs_30_carry_val) + (params['v241'] * incorrect_decs_30_val) + (params['v242'] * incorrect_decs_31_carry_val) + (params['v243'] * incorrect_decs_31_val)
                    + (params['v244'] * incorrect_decs_32_carry_val) + (params['v245'] * incorrect_decs_32_val) + (params['v246'] * incorrect_decs_33_carry_val) + (params['v247'] * incorrect_decs_33_val)
                    + (params['v248'] * incorrect_decs_34_carry_val) + (params['v249'] * incorrect_decs_34_val) + (params['v250'] * incorrect_decs_35_carry_val) + (params['v251'] * incorrect_decs_35_val))
    
    return salida_1, salida_2, salida_3

In [21]:
# Evaluate freshly initialised (untrained) parameters on the full training set.
x_train, y_train = generate_final_data()
params = init_params()

count_predictions(
    params,
    x_train,
    y_train,
)

Error en la suma: [0 0 0 1]
Predicción: [0, 0, 0]
Predicciones correctas: 1 de 10000


In [16]:
# Sanity check: push one example (01 + 99 = 100) through the frozen
# pre-trained modules. Combine only the four "correct-path" features, using the
# same parameter keys the model function above applies to them for each output.
x = [0, 1, 9, 9]  # [a_dec, a_unit, b_dec, b_unit]
# Expected output digits for this example. Stored under its own name so this
# debug cell does not overwrite the global training labels `y_train`.
y_expected = [1, 0, 0]
params = init_params()

# Units step: feed both unit digits to the unit addition / unit carry modules.
units_input = jnp.array([x[1], x[3]])
units_input = units_input[None, None, :]  # add two leading singleton axes

unit_output = jnp.array(unit_addition_model(units_input))  # per-class outputs for the unit digit
unit_carry_output = jnp.array(unit_carry_model(units_input))  # per-class outputs for the unit carry
# Take the most likely class of each prediction (argmax in JAX).
unit_val = jnp.argmax(unit_output, axis=-1)
carry_unit_val = jnp.argmax(unit_carry_output, axis=-1)

# Tens step: both tens digits plus the carry produced by the units step.
decs_input = jnp.array([x[0], x[2], carry_unit_val[0]])
decs_input = decs_input[None, None, :]

dec_output = jnp.array(dec_addition_model(decs_input))  # per-class outputs for the tens digit
dec_carry_output = jnp.array(dec_carry_model(decs_input))  # per-class outputs for the tens carry

dec_val = jnp.argmax(dec_output, axis=-1)
carry_dec_val = jnp.argmax(dec_carry_output, axis=-1)

# Combine with the v parameters. The model weights
# (carry_dec_val, dec_val, carry_unit_val, unit_val) with v0-v3 for output 1,
# v84-v87 for output 2 and v168-v171 for output 3. v4-v11 belong to the
# "incorrect units" features, so using them here would not mirror the model.
# The "incorrect" module terms are intentionally left out of this check.
salida_1 = (params['v0'] * carry_dec_val) + (params['v1'] * dec_val) + (params['v2'] * carry_unit_val) + (params['v3'] * unit_val)
salida_2 = (params['v84'] * carry_dec_val) + (params['v85'] * dec_val) + (params['v86'] * carry_unit_val) + (params['v87'] * unit_val)
salida_3 = (params['v168'] * carry_dec_val) + (params['v169'] * dec_val) + (params['v170'] * carry_unit_val) + (params['v171'] * unit_val)

print(salida_1, salida_2, salida_3)

[1.] [0.] [0.]


In [38]:
# Regenerate the training set, start from parameters initialised with
# epsilon=0.05 and train for 250 epochs at a learning rate of 0.01.
x_train, y_train = generate_final_data()
params = init_params(epsilon=0.05)

new_params, average_loss = train_model(
    params,
    x_train,
    y_train,
    lr=0.01,
    epochs=250,
)

Epoch 0, Loss: 0.0
Epoch 10, Loss: 0.013582281768321991
Epoch 20, Loss: 0.04610239714384079
Epoch 30, Loss: 0.06182209029793739
Epoch 40, Loss: 0.041956741362810135
Epoch 50, Loss: 0.01702243834733963
Epoch 60, Loss: 0.005225929897278547
Epoch 70, Loss: 0.001522018457762897
Epoch 80, Loss: 0.0004784985212609172
Epoch 90, Loss: 0.0001815989671740681
Epoch 100, Loss: 2.85101441477309e-06
Epoch 110, Loss: 8.297998647321947e-06
Epoch 120, Loss: 0.0001423328067176044
Epoch 130, Loss: 0.0009731091558933258
Epoch 140, Loss: 0.002816295251250267
Epoch 150, Loss: 0.005074201617389917
Epoch 160, Loss: 0.007598884403705597
Epoch 170, Loss: 0.010955652222037315
Epoch 180, Loss: 0.01573774218559265
Epoch 190, Loss: 0.023639684543013573
Epoch 200, Loss: 0.008644448593258858
Epoch 210, Loss: 0.0007096093613654375
Epoch 220, Loss: 0.0006683083483949304
Epoch 230, Loss: 1.9875062207574956e-05
Epoch 240, Loss: 0.0017874983604997396


In [39]:
# Re-evaluate on the training set with the parameters returned by train_model.
count_predictions(
    new_params,
    x_train,
    y_train,
)

Predicciones correctas: 10000 de 10000
¡Todas las combinaciones han sido aprendidas correctamente!
