
# 1. Import Required Libraries

In [None]:

import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from sunpy.map import Map
from spacepy import pycdf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, LSTM, Dense, Concatenate, Flatten
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
from sklearn.model_selection import train_test_split
from concurrent.futures import ProcessPoolExecutor

# Pin the global NumPy and TensorFlow seeds so that repeated runs of the
# notebook draw the same random numbers.
np.random.seed(42)
tf.random.set_seed(42)


# 2. Global Configurations

In [None]:

# --- Data locations ---
# Directory containing the VELC .fits images.
VELC_DIR = 'velc_data/'
# Directory containing the ASPEX .cdf files.
ASPEX_DIR = 'aspex_data/'

# --- Preprocessing ---
# Target size every image is resized to before it enters the CNN branch.
IMG_SIZE = (128, 128)
# Upper bound on how many files are taken from each directory (keeps runs short).
NUM_SAMPLES = 1000

# --- Batching / training ---
# Used both as the file-loading chunk size and as the model.fit batch size.
BATCH_SIZE = 32
EPOCHS = 10


# 3. Data Preprocessing Functions

In [None]:

def preprocess_velc(file_path):
    """Load one VELC image and return it as a normalised 2-D float32 array.

    The file is opened with ``sunpy.map.Map``, resized to ``IMG_SIZE`` and
    rescaled to the range [0, 1] using the image's own minimum and maximum.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A float32 NumPy array of shape ``IMG_SIZE`` with values in [0, 1].
        A constant image (max == min) comes back as all zeros.
    """
    velc_map = Map(file_path)

    # Non-finite pixels would otherwise spread through the resize and end up
    # as NaN inputs (and a NaN loss) during training.
    data = np.nan_to_num(
        velc_map.data.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0
    )

    # tf.image.resize needs (batch, height, width, channels); add the two
    # singleton axes for the call and strip them again afterwards.
    resized = tf.image.resize(data[None, ..., None], IMG_SIZE).numpy()[0, ..., 0]

    # FITS pixel values are not confined to 0-255, so dividing by 255 does not
    # normalise them. Scale by the actual value range of the image instead.
    lowest = float(resized.min())
    highest = float(resized.max())
    if highest <= lowest:
        return np.zeros_like(resized, dtype=np.float32)
    return ((resized - lowest) / (highest - lowest)).astype(np.float32)

def preprocess_aspex(file_path, maxlen=100):
    """Read one ASPEX CDF file and return a fixed-length flux sequence.

    Args:
        file_path: Path of the .cdf file to read.
        maxlen: Length of the returned sequence. Defaults to 100, the length
            the rest of the notebook reshapes to.

    Returns:
        A 1-D float32 NumPy array of length ``maxlen``. ``pad_sequences`` is
        called with its default padding/truncation settings, so shorter series
        are zero-padded and longer ones are cut down to ``maxlen``.
    """
    # Open the CDF in a ``with`` block so the file handle is released even if
    # reading one of the variables raises.
    with pycdf.CDF(file_path) as cdf_data:
        df = pd.DataFrame(cdf_data['data'])  # Adjust variable name as needed
        # Pull the values out while the file is still open.
        flux = df['proton_flux'].values      # Replace with real key
    return tf.keras.preprocessing.sequence.pad_sequences(
        [flux], maxlen=maxlen, dtype='float32'
    )[0]


# 4. Batch Loader

In [None]:

def load_data(batch_start, batch_size):
    """Load and preprocess one batch of paired VELC / ASPEX samples.

    Files in ``VELC_DIR`` and ``ASPEX_DIR`` are listed, sorted by name and
    capped at ``NUM_SAMPLES``; the i-th image is paired with the i-th CDF file.

    Args:
        batch_start: Index of the first pair in the batch.
        batch_size: Maximum number of pairs to load.

    Returns:
        ``((velc_array, aspex_array), labels)`` where all three have the same
        leading dimension. When no pair falls inside the requested range the
        arrays are empty but keep their trailing dimensions
        (``(0, *IMG_SIZE)`` and ``(0, 100)``).
    """
    velc_files = sorted(os.listdir(VELC_DIR))[:NUM_SAMPLES]
    aspex_files = sorted(os.listdir(ASPEX_DIR))[:NUM_SAMPLES]

    # Pairing is purely positional, so only indices present in BOTH listings
    # are usable; otherwise images, sequences and labels end up with
    # different lengths.
    pair_count = min(len(velc_files), len(aspex_files))
    batch_stop = min(batch_start + batch_size, pair_count)

    velc_batch = [os.path.join(VELC_DIR, name) for name in velc_files[batch_start:batch_stop]]
    aspex_batch = [os.path.join(ASPEX_DIR, name) for name in aspex_files[batch_start:batch_stop]]

    if not velc_batch:
        # np.array([]) would be 1-D and could not be concatenated with real
        # batches, so build empty arrays with the proper trailing shape.
        # 100 is the sequence length produced by preprocess_aspex by default.
        empty_images = np.empty((0, *IMG_SIZE), dtype=np.float32)
        empty_series = np.empty((0, 100), dtype=np.float32)
        return (empty_images, empty_series), np.random.randint(0, 2, size=0)

    # NOTE(review): the worker processes must be able to import the
    # preprocess_* functions; confirm this works for the start method in use.
    with ProcessPoolExecutor() as executor:
        velc_data = list(executor.map(preprocess_velc, velc_batch))
        aspex_data = list(executor.map(preprocess_aspex, aspex_batch))

    labels = np.random.randint(0, 2, size=len(velc_batch))  # Synthetic binary labels
    return (np.array(velc_data), np.array(aspex_data)), labels


# 5. Aggregate Data

In [None]:

# Load the dataset chunk by chunk and collect the per-batch arrays.
all_velc_data, all_aspex_data, all_labels = [], [], []
for i in range(0, NUM_SAMPLES, BATCH_SIZE):
    (velc_batch, aspex_batch), labels = load_data(i, BATCH_SIZE)
    if len(labels) == 0:
        # Fewer than NUM_SAMPLES files are available: this batch starts past
        # the end of the listing, so every later batch would be empty too.
        # Appending empty batches would break np.concatenate below.
        break
    all_velc_data.append(velc_batch)
    all_aspex_data.append(aspex_batch)
    all_labels.append(labels)

if not all_labels:
    raise ValueError(
        f"No samples could be loaded from {VELC_DIR!r} / {ASPEX_DIR!r}"
    )

velc_data = np.concatenate(all_velc_data)
aspex_data = np.concatenate(all_aspex_data)
labels = np.concatenate(all_labels)

# Reshape ASPEX input for LSTM: (samples, timesteps, features)
aspex_data = aspex_data.reshape((-1, 100, 1))


# 6. Train-Test Split

In [None]:

# Hold out 20% of the samples. The three arrays are split together so that
# image, sequence and label of each sample stay aligned.
(
    velc_train, velc_test,
    aspex_train, aspex_test,
    y_train, y_test,
) = train_test_split(velc_data, aspex_data, labels, random_state=42, test_size=0.2)


# 7. ATSFusion Model

In [None]:

def build_atsfusion(image_shape=(128, 128, 1), sequence_shape=(100, 1), learning_rate=0.001):
    """Build and compile the two-branch ATSFusion binary classifier.

    A small CNN encodes the VELC image, an LSTM encodes the ASPEX sequence,
    and the two 64-unit feature vectors are concatenated and passed through
    a dense head with a single sigmoid output.

    Args:
        image_shape: Per-sample shape of the image input (height, width,
            channels). Must be fully defined because the CNN features are
            flattened before the dense layer.
        sequence_shape: Per-sample shape of the sequence input
            (timesteps, features).
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        A compiled ``Model`` taking ``[image_batch, sequence_batch]`` and
        trained with binary cross-entropy, reporting accuracy.
    """
    # VELC CNN branch: two conv + pool stages, then a dense projection.
    velc_input = Input(shape=image_shape)
    image_features = Conv2D(32, (3, 3), activation='relu', padding='same')(velc_input)
    image_features = MaxPooling2D((2, 2))(image_features)
    image_features = Conv2D(64, (3, 3), activation='relu', padding='same')(image_features)
    image_features = MaxPooling2D((2, 2))(image_features)
    image_features = Flatten()(image_features)
    image_features = Dense(64, activation='relu')(image_features)

    # ASPEX LSTM branch: only the final LSTM state is kept.
    aspex_input = Input(shape=sequence_shape)
    sequence_features = LSTM(64)(aspex_input)
    sequence_features = Dense(64, activation='relu')(sequence_features)

    # Fusion: concatenate both embeddings and classify.
    fused = Concatenate()([image_features, sequence_features])
    fused = Dense(32, activation='relu')(fused)
    output = Dense(1, activation='sigmoid')(fused)

    model = Model(inputs=[velc_input, aspex_input], outputs=output)
    model.compile(
        optimizer=Adam(learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Instantiate the fusion model with its default settings and print the
# layer-by-layer summary.
atsfusion = build_atsfusion()
atsfusion.summary()


# 8. Training

In [None]:

# Fit on the training split, with the held-out split evaluated after every
# epoch. The returned History object feeds the plots further down.
history = atsfusion.fit(
    x=[velc_train, aspex_train],
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=([velc_test, aspex_test], y_test),
)


# 9. Visualization

In [None]:

# Side-by-side training curves: accuracy on the left, loss on the right.
plt.figure(figsize=(12, 4))

for panel, (train_key, val_key, heading) in enumerate(
    (
        ('accuracy', 'val_accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Loss'),
    ),
    start=1,
):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[train_key], label='Train')
    plt.plot(history.history[val_key], label='Val')
    plt.title(heading)
    plt.xlabel('Epoch')
    plt.ylabel(heading)
    plt.legend()

plt.tight_layout()
plt.show()


# 10. Evaluation

In [None]:

# Score the model on the held-out split (loss first, then the accuracy metric).
test_loss, test_accuracy = atsfusion.evaluate(
    x=[velc_test, aspex_test],
    y=y_test,
    verbose=0,
)
print(f"Test Accuracy: {test_accuracy:.4f}")

# Sample prediction: slices (not plain indexing) keep the batch dimension.
sample_idx = 0
pred = atsfusion.predict(
    [
        velc_test[sample_idx:sample_idx+1],
        aspex_test[sample_idx:sample_idx+1],
    ]
)
print(f"Predicted CME Probability: {pred[0][0]:.4f} | Actual: {y_test[sample_idx]}")
