In [1]:
import numpy as np
import os
import re
import random
import json
import sys
import time
from datetime import datetime
import pandas as pd
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
from tensorflow.keras.models import Model, load_model, clone_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Dense, Multiply, Add, Lambda, Concatenate, Reshape, Flatten
from tensorflow.keras.initializers import GlorotUniform, RandomUniform, Constant
from tensorflow.keras.callbacks import LambdaCallback
import jax
import jax.numpy as jnp
from jax import grad, jit

In [10]:
# Value embedded in the `AP_<epsilon>` directory names built in the driver cell
# (results, trained models and parameter folders).
epsilon = 0.2
# Root folder prepended to every data/result path in this notebook.
# NOTE(review): absolute, machine-specific path — consider making it configurable
# (e.g. an environment variable or a path relative to the notebook).
folder = 'D:/OneDrive - Universidad Complutense de Madrid (UCM)/Doctorado/Curriculum_Learning/Multidigit_Addition_Decimal/'

In [11]:
# Report whether TensorFlow can actually use a GPU.
# The import cell hides every GPU with tf.config.set_visible_devices([], 'GPU'),
# and a hidden GPU is still listed as a *physical* device. Querying the visible
# devices keeps the message truthful on machines that do have a GPU installed.
gpus = tf.config.get_visible_devices('GPU')
if gpus:
    print("TensorFlow está utilizando la GPU")
else:
    print("TensorFlow no está utilizando la GPU")

TensorFlow no está utilizando la GPU


In [12]:
# Load the four pre-trained Keras modules (addition and carry, for the units
# and for the tens position). The paths are relative, so the files are looked
# up in the current working directory.
unit_addition_model = load_model('unit_addition_module.keras')
unit_carry_model = load_model('unit_carry_module.keras')
dec_addition_model = load_model('dec_addition_module.keras')
dec_carry_model = load_model('dec_carry_module.keras')

# Freeze the modules: only the v0..v11 combination parameters are trained below.
unit_addition_model.trainable = False
unit_carry_model.trainable = False
dec_addition_model.trainable = False
dec_carry_model.trainable = False

# Give every loaded module an explicit name.
unit_addition_model.name = 'unit_addition_model'
unit_carry_model.name = 'unit_carry_model'
dec_addition_model.name = 'dec_addition_model'
dec_carry_model.name = 'dec_carry_model'

In [13]:
# Load the (a, b) pairs from the text files under `folder`.
# NOTE(review): eval() executes whatever Python expression the file contains.
# If these files only hold literals (lists/tuples of ints), ast.literal_eval
# would be a safer drop-in — confirm the file format before switching.
with open(f"{folder}sorted_train_couples_stimuli.txt", "r") as file:
    train_couples = eval(file.read())

# Pairs looked up by the evaluation helpers to count "test" hits.
with open(f"{folder}stimuli.txt", "r") as file:
    test_couples = eval(file.read())

# Pairs used by generate_test_dataset() to build the evaluation arrays.
with open(f"{folder}test_dataset.txt", "r") as file:
    test_dataset = eval(file.read())

def generate_test_dataset():
    """Build the evaluation arrays from the global ``test_dataset`` pairs.

    Every pair ``(a, b)`` yields the input row
    ``[a // 10, a % 10, b // 10, b % 10]`` and the target row
    ``[carry out of the tens, tens digit of a + b, units digit of a + b]``.

    Returns:
        tuple: ``(x, y)`` JAX arrays with one row per pair.
    """
    inputs, targets = [], []

    for first, second in test_dataset:
        first_tens, first_units = first // 10, first % 10
        second_tens, second_units = second // 10, second % 10
        inputs.append([first_tens, first_units, second_tens, second_units])

        # Column addition: units first, then tens plus the carry from the units.
        units_total = first_units + second_units
        units_carry = int(units_total >= 10)
        tens_total = first_tens + second_tens + units_carry
        targets.append([int(tens_total >= 10), tens_total % 10, units_total % 10])

    return jnp.array(inputs), jnp.array(targets)


# Build the training arrays from the already-loaded `train_couples` pairs.
def generate_train_dataset():
    """Turn the global ``train_couples`` pairs into training arrays.

    Inputs are the four digits ``[tens(a), units(a), tens(b), units(b)]``;
    targets are ``[carry out of the tens, tens digit, units digit]`` of
    ``a + b``. Rows keep the order of ``train_couples``.

    Returns:
        tuple: ``(x, y)`` JAX arrays with one row per pair.
    """
    def split_digits(number):
        # (tens part, units digit) of a number
        return number // 10, number % 10

    x_rows = []
    y_rows = []

    for left, right in train_couples:
        left_tens, left_units = split_digits(left)
        right_tens, right_units = split_digits(right)
        x_rows.append([left_tens, left_units, right_tens, right_units])

        carry_from_units = 0 if (left_units + right_units) < 10 else 1
        digit_units = (left_units + right_units) % 10
        tens_sum = left_tens + right_tens + carry_from_units
        carry_from_tens = 0 if tens_sum < 10 else 1
        y_rows.append([carry_from_tens, tens_sum % 10, digit_units])

    return jnp.array(x_rows), jnp.array(y_rows)

# Build the exhaustive dataset of all two-digit + two-digit additions.
def generate_final_data():
    """Enumerate every combination of four decimal digits with its sum.

    Rows are ordered like nested loops over ``a_dec, a_unit, b_dec, b_unit``
    (last digit varying fastest), i.e. 10 000 rows in total.

    Returns:
        tuple: ``(x, y)`` JAX arrays; ``x`` rows are
        ``[a_dec, a_unit, b_dec, b_unit]`` and ``y`` rows are
        ``[carry out of the tens, tens digit, units digit]`` of the sum.
    """
    inputs, targets = [], []

    # Counting 0..9999 and reading the four decimal digits reproduces the
    # nested-loop order exactly.
    for number in range(10 * 10 * 10 * 10):
        a_dec, a_unit = divmod(number // 100, 10)
        b_dec, b_unit = divmod(number % 100, 10)
        inputs.append([a_dec, a_unit, b_dec, b_unit])

        units_total = a_unit + b_unit
        carry_units = int(units_total >= 10)
        tens_total = a_dec + b_dec + carry_units
        targets.append([int(tens_total >= 10), tens_total % 10, units_total % 10])

    return jnp.array(inputs), jnp.array(targets)

In [14]:
# Create the trainable combination parameters v0..v11.
def init_params(epsilon = 0.1):
    """Return randomly perturbed initial values for ``v0`` .. ``v11``.

    Each parameter starts from a fixed baseline (1 for ``v0``, ``v5`` and
    ``v11``, 0 elsewhere) plus noise drawn from Python's global ``random``
    generator, uniform in ``[-10 * epsilon, 10 * epsilon]``.

    Args:
        epsilon: scale factor applied to the ``uniform(-10, 10)`` noise.

    Returns:
        dict: maps ``'v0'`` .. ``'v11'`` to scalar values.
    """
    baseline = jnp.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], dtype=jnp.float32)

    perturbed = {}
    for index in range(12):
        # One draw per parameter, in index order.
        noise = random.uniform(-10, 10) * epsilon
        perturbed[f'v{index}'] = noise + baseline[index]
    return perturbed

# Loss function
def loss_fn(params, x, y):
    """Sum of the per-output mean squared errors of ``model(params, x)``.

    Args:
        params: parameter dict forwarded to ``model``.
        x: input batch forwarded to ``model``.
        y: targets; column ``k`` is compared with the ``k``-th model output.

    Returns:
        Scalar: MSE of output 1 + MSE of output 2 + MSE of output 3.
    """
    first, second, third = model(params, x)

    total = jnp.mean((first - y[:, 0]) ** 2)
    for column, prediction in ((1, second), (2, third)):
        total = total + jnp.mean((prediction - y[:, column]) ** 2)
    return total
    
# Parameter update (one gradient-descent step)
def update_params(params, x, y, lr):
    """Apply one plain gradient-descent step on ``loss_fn``.

    Args:
        params: pytree (here a dict) of parameters to update.
        x: input batch passed to ``loss_fn``.
        y: target batch passed to ``loss_fn``.
        lr: learning rate.

    Returns:
        A new pytree with ``p - lr * dLoss/dp`` for every leaf ``p``.
    """
    def descend(value, slope):
        return value - lr * slope

    # Gradient of the loss with respect to its first argument (the parameters).
    slopes = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(descend, params, slopes)

# Train the combination parameters
def train_model(params, x_train, y_train, lr=0.01, epochs=100, stop=1, batch_size=16):
    """Run sequential mini-batch gradient descent over the training data.

    Despite the name, each "epoch" is a single ``update_params`` step on the
    next ``batch_size`` consecutive rows (step ``k`` uses rows
    ``k * batch_size`` up to ``(k + 1) * batch_size``). The data is neither
    shuffled nor revisited, so its order defines the training curriculum.

    Progress is evaluated with ``correct_predictions_and_loss`` before the
    first update and then every 10 steps; each evaluation is printed.

    Args:
        params: dict of parameters to optimise.
        x_train: training inputs, one row per example.
        y_train: training targets, aligned with ``x_train``.
        lr: learning rate forwarded to ``update_params``.
        epochs: maximum number of update steps.
        stop: early-stopping mode checked at each 10-step evaluation:
            ``1`` stops when the test-couple hit count equals 192,
            ``0`` stops when the overall hit count equals 10000.
        batch_size: number of consecutive rows used per step.

    Returns:
        tuple: ``(params, step_loss)`` where ``step_loss`` is the loss from
        the most recent evaluation.
    """
    # Convert first so plain lists are accepted and `.shape`/slicing work.
    x_train = jnp.array(x_train)
    y_train = jnp.array(y_train)

    # Stays None only when the loop body never runs (epochs <= 0).
    step_loss = None

    for epoch in range(epochs):
        if epoch == 0:
            # Baseline evaluation before any update.
            pred_count, pred_count_test, step_loss = correct_predictions_and_loss(params)
            print(f"Epoch {epoch}, Loss: {step_loss}, Correct predictions: {pred_count}, Correct predictions test: {pred_count_test}")

        start = epoch * batch_size
        end = (epoch + 1) * batch_size
        x_batch = x_train[start:end]
        y_batch = y_train[start:end]

        # Past the end of the data the slice is empty; the mean of an empty
        # batch is NaN and would poison every parameter, so stop instead.
        if x_batch.shape[0] == 0:
            print(f"Epoch {epoch}: training data exhausted, stopping early")
            break

        params = update_params(params, x_batch, y_batch, lr)

        # Report statistics every 10 steps.
        if epoch % 10 == 0 and epoch != 0:
            pred_count, pred_count_test, step_loss = correct_predictions_and_loss(params)
            print(f"Epoch {epoch}, Loss: {step_loss}, Correct predictions: {pred_count}, Correct predictions test: {pred_count_test}")
            if stop == 1 and pred_count_test == 192:
                break

            if stop == 0 and pred_count == 10000:
                break

            # Diverged: give up on this run.
            if step_loss >= 1000000:
                break

    if step_loss is None:
        # No step ran, so no evaluation happened yet; evaluate once so the
        # return value is always defined (previously an UnboundLocalError).
        _, _, step_loss = correct_predictions_and_loss(params)

    return params, step_loss

def correct_predictions_and_loss(params):
    """Evaluate ``params`` on the dataset from ``generate_test_dataset``.

    A row counts as correct when all three model outputs, rounded to the
    nearest integer, equal the three target digits. Non-finite predictions
    (NaN/inf) simply count as wrong instead of raising.

    Args:
        params: parameter dict forwarded to ``model``.

    Returns:
        tuple: ``(pred_count, pred_count_test, loss)`` where ``pred_count`` is
        the number of correct rows, ``pred_count_test`` the number of correct
        rows whose ``(a, b)`` pair is in the global ``test_couples``, and
        ``loss`` the summed per-output mean squared error.
    """
    x_test, y_test = generate_test_dataset()
    pred_hundreds, pred_tens, pred_units = model(params, x_test)
    loss = (jnp.mean((pred_hundreds - y_test[:, 0]) ** 2)
            + jnp.mean((pred_tens - y_test[:, 1]) ** 2)
            + jnp.mean((pred_units - y_test[:, 2]) ** 2))

    # Move everything to NumPy once; indexing JAX arrays element by element
    # inside a Python loop is very slow.
    x_np = np.asarray(x_test)
    y_np = np.asarray(y_test)
    preds = np.stack(
        [np.asarray(p).reshape(-1) for p in (pred_hundreds, pred_tens, pred_units)],
        axis=1,
    )

    # np.rint rounds half to even, like jnp.round.
    correct = np.all(np.rint(preds) == y_np, axis=1)
    pred_count = int(correct.sum())

    # Rebuild (a, b) from the digit columns only for the correct rows.
    pred_count_test = 0
    for a_dec, a_unit, b_dec, b_unit in x_np[correct]:
        a = int(a_dec) * 10 + int(a_unit)
        b = int(b_dec) * 10 + int(b_unit)
        if (a, b) in test_couples:
            pred_count_test += 1

    return pred_count, pred_count_test, loss

def correct_predictions(params):
    """Count correct predictions of ``params`` on the evaluation dataset.

    This used to be a line-for-line copy of ``correct_predictions_and_loss``
    without the loss; it now delegates to that function so the two can never
    drift apart.

    Args:
        params: parameter dict forwarded to ``model``.

    Returns:
        tuple: ``(pred_count, pred_count_test)`` — the number of fully correct
        rows, and how many of those belong to ``test_couples``.
    """
    pred_count, pred_count_test, _ = correct_predictions_and_loss(params)
    return pred_count, pred_count_test

In [15]:
# Dynamic model in JAX
def model(params, x):
    """Combine the frozen digit modules linearly with parameters ``v0..v11``.

    The units columns of ``x`` (1 and 3) go through the unit addition/carry
    modules; the tens columns (0 and 2) plus the predicted units carry go
    through the tens addition/carry modules. Each module output is reduced
    with ``argmax`` over its last axis, and the four resulting values are
    mixed into three outputs.

    Args:
        params: dict with keys ``'v0'`` .. ``'v11'``.
        x: array whose columns are ``[a_dec, a_unit, b_dec, b_unit]``.

    Returns:
        tuple: three arrays; output ``k`` (0-based) is
        ``v[4k]*carry_dec + v[4k+1]*dec + v[4k+2]*carry_unit + v[4k+3]*unit``.
    """
    # Units digits of both addends, with an extra axis: (N, 1, 2).
    unit_digits = jnp.array(x[:, [1, 3]])[:, None, :]

    unit_val = jnp.argmax(jnp.array(unit_addition_model(unit_digits)), axis=-1)
    carry_unit_val = jnp.argmax(jnp.array(unit_carry_model(unit_digits)), axis=-1)

    # Tens digits of both addends plus the units carry, with an extra axis: (N, 1, 3).
    ten_digits = jnp.concatenate(
        [jnp.array(x[:, [0, 2]]), carry_unit_val[:, None]], axis=-1
    )[:, None, :]

    dec_val = jnp.argmax(jnp.array(dec_addition_model(ten_digits)), axis=-1)
    carry_dec_val = jnp.argmax(jnp.array(dec_carry_model(ten_digits)), axis=-1)

    features = (carry_dec_val, dec_val, carry_unit_val, unit_val)

    def combine(first_index):
        # Weighted sum of the four features using four consecutive v parameters.
        total = params[f'v{first_index}'] * features[0]
        for shift in (1, 2, 3):
            total = total + params[f'v{first_index + shift}'] * features[shift]
        return total

    return combine(0), combine(4), combine(8)

In [16]:
# Read parameters from a JSON file
def load_params_from_file(filename):
    """Parse the JSON document stored at ``filename`` and return it.

    Args:
        filename: path of the JSON file to read.

    Returns:
        The decoded JSON content.
    """
    with open(filename, 'r') as source:
        contents = source.read()
    return json.loads(contents)

# Save the trained parameters
def save_trained_model(params, filename, model_dir):
    """Write ``params`` as JSON to ``model_dir/filename``.

    The directory is created if needed. Every value is converted with its
    ``.tolist()`` method before serialisation.

    Args:
        params: dict mapping names to array-like values exposing ``.tolist()``.
        filename: name of the JSON file to write.
        model_dir: destination directory.
    """
    os.makedirs(model_dir, exist_ok=True)
    destination = os.path.join(model_dir, filename)

    payload = {}
    for name, array in params.items():
        payload[name] = array.tolist()

    with open(destination, 'w') as sink:
        json.dump(payload, sink)

class Tee:
    """File-like object that duplicates writes to the console and to a file.

    ``console`` is whatever ``sys.stdout`` was when the instance was created;
    ``file`` is the opened log file.
    """

    def __init__(self, file, mode='w'):
        """Open ``file`` with ``mode`` and remember the current ``sys.stdout``."""
        self.file = open(file, mode)
        self.console = sys.stdout

    def _streams(self):
        # Console first, then the file — the order writes and flushes happen in.
        return (self.console, self.file)

    def write(self, data):
        """Send ``data`` to the console and to the file."""
        for stream in self._streams():
            stream.write(data)

    def flush(self):
        """Flush the console and the file."""
        for stream in self._streams():
            stream.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.file.close()

In [17]:
# Driver: for every saved initial-parameter file, train in two stages and
# store the resulting parameters plus a text log of the run.
x_train, y_train = generate_train_dataset()

# Output/input locations, all keyed by the current epsilon value.
save_dir = f"{folder}Results_models/AP_{epsilon}"
save_model_dir = f"{folder}Trained_models/AP_{epsilon}"
save_model_dir_2 = f"{folder}Super_trained_models/AP_{epsilon}"
folder_path = f'{folder}Parameters/AP_{epsilon}'
# File names carry a YYYY_MM_DD_HH_MM_SS timestamp; it is used both as the
# sort key and as the run identifier.
date_pattern = r'trainable_model_(\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}).json'
files = sorted(
    (f for f in os.listdir(folder_path) if not f.startswith('.')),  # Skip hidden files
    key=lambda x: re.search(date_pattern, x).group(1) if re.search(date_pattern, x) else ''
)

for filename in files:
    match = re.search(date_pattern, filename)
    if match:
        current_time = match.group(1)
    else:
        # NOTE(review): non-matching names sort first (key ''), so a single stray
        # file in the folder aborts the whole run — confirm `break` (rather than
        # `continue`) is intended.
        print('Error')
        break
    
    # Load the initial parameters for this run and convert them to JAX arrays.
    file_path = f"{folder_path}/trainable_model_{current_time}.json"
    with open(file_path, 'rb') as file:
        trainable_model = json.load(file)

    trainable_model_jnp = {key: jnp.array(value) for key, value in trainable_model.items()}
    print(f'Loaded trainable_model_{current_time}.json')
    
    # From here on, everything printed is also written to the per-run results file.
    os.makedirs(save_dir, exist_ok=True) 
    results_file = os.path.join(save_dir, f"Results_{current_time}.txt") 
    tee = Tee(results_file, 'w') 
    sys.stdout = tee
    
    try: 
        # Stage 1: up to 100 steps with the default stop mode (stop=1).
        new_params, average_loss = train_model(trainable_model_jnp, x_train, y_train, lr=0.01, epochs=100)
        pred_count, pred_count_test = correct_predictions(new_params)

        trained_model_filename = f"trained_model_{current_time}.json"
        save_trained_model(new_params, trained_model_filename, save_model_dir)
        print(f'Saved trained_model_{current_time}.json')

        # NOTE(review): train_model's own stop target is 10000 correct predictions,
        # so comparing against 10001 looks always true (stage 2 always runs) —
        # confirm whether that is deliberate.
        if pred_count != 10001:
            # Stage 2: continue from the stage-1 parameters for up to 613 steps, stop=0.
            new_params_2, average_loss_2 = train_model(new_params, x_train, y_train, lr=0.01, epochs=613, stop=0)
            trained_model_filename_2 = f"super_trained_model_{current_time}.json"
            save_trained_model(new_params_2, trained_model_filename_2, save_model_dir_2)
            print(f'Saved super_trained_model_{current_time}.json')

    finally:
        # Always restore the real stdout and close the log, even if training fails.
        sys.stdout = tee.console
        tee.close()

Loaded trainable_model_2024_12_29_15_41_17.json
Epoch 0, Loss: 179.31431579589844, Correct predictions: 1, Correct predictions test: 0
Epoch 10, Loss: 21.18964195251465, Correct predictions: 257, Correct predictions test: 14
Epoch 20, Loss: 9.510498046875, Correct predictions: 656, Correct predictions test: 6
Epoch 30, Loss: 13.880317687988281, Correct predictions: 347, Correct predictions test: 14
Epoch 40, Loss: 6.867625713348389, Correct predictions: 669, Correct predictions test: 8
Epoch 50, Loss: 3.045828342437744, Correct predictions: 1226, Correct predictions test: 26
Epoch 60, Loss: 4.175603866577148, Correct predictions: 1617, Correct predictions test: 32
Epoch 70, Loss: 3.144599437713623, Correct predictions: 718, Correct predictions test: 4
Epoch 80, Loss: 8.346903800964355, Correct predictions: 487, Correct predictions test: 28
Epoch 90, Loss: 5.717255592346191, Correct predictions: 629, Correct predictions test: 14
Saved trained_model_2024_12_29_15_41_17.json
Epoch 0, Loss