In [1]:
import numpy as np
import tensorflow_addons as tfa
from elephas.ml_model import ElephasEstimator
from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd
from pyspark.ml.linalg import Vectors
from pyspark.sql import Row
from pyspark.sql import SparkSession
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Sequential

# Create the Spark session used by the rest of the notebook
# (getOrCreate hands back an already-running session when there is one).
spark = (
    SparkSession.builder
    .appName("Galaxy Classification")
    .getOrCreate()
)

2023-04-20 17:41:16.473586: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-20 17:41:16.585790: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-04-20 17:41:17.067669: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/hadoop/lib/native:
2023-04-20 17:41:17.067725: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot

In [None]:
# Function that reads and processes the FITS and CAT files
def read_fits_and_cat(file_path):
    """Read one FITS/CAT file and turn it into training samples.

    This is still a placeholder: the parsing logic (the original note
    suggests doing it with Astropy) has not been written yet.

    Args:
        file_path: Path or URI of the file to read.

    Returns:
        Once implemented, a list of ``(input, output)`` tuples, one per
        sample extracted from the file.

    Raises:
        NotImplementedError: Always, until the parsing logic is filled in.
    """
    # Fail loudly instead of silently returning None: a caller that iterates
    # over the result (for example through flatMap) would otherwise only
    # report an opaque "'NoneType' object is not iterable" error.
    raise NotImplementedError(
        "read_fits_and_cat is a placeholder: implement the FITS/CAT parsing "
        "and return a list of (input, output) tuples"
    )

# Change this path so that it points to a directory containing the files
fits_and_cat_files = "../data/"

In [None]:
def _row_to_pair(row):
    """Turn a DataFrame row into a ``(features_array, label)`` tuple."""
    return row.features.toArray(), row.label


def _collect_as_arrays(pair_rdd, split_name):
    """Collect an RDD of ``(features, label)`` pairs into two NumPy arrays.

    Args:
        pair_rdd: RDD whose elements are ``(features, label)`` tuples.
        split_name: Human-readable name of the split, used in the error message.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays.

    Raises:
        ValueError: If the RDD holds no elements. Unpacking ``zip(*[])`` would
            raise the same exception type, but with a message that does not
            say which split was empty.
    """
    pairs = pair_rdd.collect()
    if not pairs:
        raise ValueError(
            "The {} split is empty; check the input directory and the split "
            "weights".format(split_name)
        )
    features, labels = zip(*pairs)
    return np.array(features), np.array(labels)


# Read and process the FITS and CAT files. Only the path of each file is
# handed to the reader, and FITS payloads are binary, so the directory is
# listed with binaryFiles rather than decoding every file as text.
raw_records = spark.sparkContext.binaryFiles(fits_and_cat_files).flatMap(
    lambda path_and_content: read_fits_and_cat(path_and_content[0])
)

# Convert the records into a Spark DataFrame
schema = ["features", "label"]
labeled_rows = raw_records.map(
    lambda record: Row(features=Vectors.dense(record[0]), label=record[1])
)
data = spark.createDataFrame(labeled_rows, schema)

# Split the data into training and validation sets
train, val = data.randomSplit([0.8, 0.2], seed=42)

# Convert the Spark DataFrames to RDDs of (features, label) pairs
train_rdd = train.rdd.map(_row_to_pair)
val_rdd = val.rdd.map(_row_to_pair)

# Bring both splits to the driver as NumPy arrays for use with Keras
x_train, y_train = _collect_as_arrays(train_rdd, "training")
x_val, y_val = _collect_as_arrays(val_rdd, "validation")

In [None]:
# Create the Keras model
def create_keras_model(input_shape=(2048, 2048, 1)):
    """Build and compile the convolutional encoder/decoder network.

    The network applies two conv + instance-norm + 2x2 max-pool stages, one
    conv at the reduced resolution, two conv + 2x2 upsampling stages, and a
    final single-channel sigmoid convolution. Because of the two "same"-padded
    poolings followed by two 2x upsamplings, the output height and width match
    the input only when both are multiples of 4.

    Args:
        input_shape: ``(height, width, channels)`` of one input sample.
            Defaults to ``(2048, 2048, 1)``.

    Returns:
        The model, compiled with the Adam optimizer, binary cross-entropy loss
        and the accuracy metric.
    """
    model = Sequential([
        # Downsampling path
        Conv2D(32, (3, 3), activation="relu", padding="same", input_shape=input_shape),
        tfa.layers.InstanceNormalization(),
        MaxPooling2D((2, 2), padding="same"),
        Conv2D(32, (3, 3), activation="relu", padding="same"),
        tfa.layers.InstanceNormalization(),
        MaxPooling2D((2, 2), padding="same"),
        # Upsampling path
        Conv2D(32, (3, 3), activation="relu", padding="same"),
        UpSampling2D((2, 2)),
        Conv2D(32, (3, 3), activation="relu", padding="same"),
        UpSampling2D((2, 2)),
        # One sigmoid value per pixel
        Conv2D(1, (3, 3), activation="sigmoid", padding="same"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model

# Train the model with Elephas.
#
# The data is distributed as a simple RDD of NumPy pairs, so the matching
# Elephas entry point is SparkModel. ElephasEstimator is a Spark ML estimator
# whose fit() expects a DataFrame, and Keras' Model.to_yaml() (which the
# estimator configuration relied on) was removed in TensorFlow 2.6.
# SparkModel takes the compiled model directly, so the optimizer, loss and
# metrics set in create_keras_model() are the ones used for training.

# Hyper-parameters forwarded to Keras' fit() on every worker. The batch size
# is kept small because each sample is a full-resolution image; tune both
# values for the available hardware.
EPOCHS = 10
BATCH_SIZE = 4

keras_model = create_keras_model()

# The features were stored as flat Spark vectors, so restore the tensor shapes
# the network was built for before distributing the data again.
# NOTE(review): assumes every label is a per-pixel mask with as many values as
# the network output — confirm against read_fits_and_cat.
sample_input_shape = tuple(keras_model.input_shape[1:])
sample_output_shape = tuple(keras_model.output_shape[1:])
x_train_img = x_train.reshape((-1,) + sample_input_shape)
y_train_img = y_train.reshape((-1,) + sample_output_shape)
x_val_img = x_val.reshape((-1,) + sample_input_shape)
y_val_img = y_val.reshape((-1,) + sample_output_shape)

# Convert the training data into an RDD
train_rdd = to_simple_rdd(spark.sparkContext, x_train_img, y_train_img)

# NOTE(review): custom_objects is passed so that the workers can rebuild the
# tensorflow_addons layer when they deserialize the model — confirm on the
# target cluster.
spark_model = SparkModel(
    keras_model,
    mode="synchronous",
    custom_objects={"InstanceNormalization": tfa.layers.InstanceNormalization},
)

# Train the model
spark_model.fit(train_rdd, epochs=EPOCHS, batch_size=BATCH_SIZE)
fitted_model = spark_model

# Evaluate the model; the validation arrays are passed to evaluate() directly.
score = fitted_model.evaluate(x_val_img, y_val_img)
print("Accuracy: ", score[1])

# Save the model
fitted_model.save("cnn_spark.h5")
