<a href="https://colab.research.google.com/github/wilmi94/MasterThesis-AE/blob/main/notebooks/sdo_e2e_ConvLSTM_171.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SDO/AIA-171A End-to-End ConvLSTM Model

The model is based on:

*Predicting Solar Flares Using a Long Short-term Memory Network. Liu, H., Liu, C., Wang, J. T. L., Wang, H., ApJ., 877:121, 2019.*


In [None]:
import pandas as pd
from keras.utils import np_utils
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import class_weight
from keras.models import *
from keras.layers import *
import csv
import sys
import numpy as np
import os
# Set TensorFlow's native log-level environment variable before the import
# below so the setting is in place when the library is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import tensorflow as tf
    # Raise the Python-side logger threshold to ERROR as well.
    _tf_logging = tf.compat.v1.logging
    _tf_logging.set_verbosity(_tf_logging.ERROR)
except Exception:
    # Best effort only: a missing or incompatible TensorFlow is reported
    # and otherwise ignored here.
    print('turn off loggins is not supported')

In [None]:
# Column indices that are inspected for a 0.0 value when flare_label == 'C'.
_ZERO_CHECK_COLUMNS = (5, 7, 9, 10, 11, 12, 14, 15, 18)


def _has_zero_record(row, flare_label):
    """Return True when *row* holds a 0.0 in one of the checked columns.

    The check is only performed for ``flare_label == 'C'``; for any other
    label the row is never considered to have a zero record.
    """
    if flare_label != 'C':
        return False
    return any(float(row[k]) == 0.0 for k in _ZERO_CHECK_COLUMNS)


def load_data(datafile, flare_label, series_len, start_feature, n_features, mask_value):
    """Read a CSV file and build fixed-length feature series with labels.

    Every row that passes the zero-record check becomes one sample. The
    sample consists of that row's features preceded by the features of the
    rows directly above it in the file, as long as those rows carry the same
    integer value in column 3 (called the NOAA number in this code). Earlier
    rows that fail the zero-record check are replaced by a mask row, and the
    series is left-padded with mask rows until it has ``series_len`` steps.

    Args:
        datafile: Path of the CSV file to read with ``pandas.read_csv``.
        flare_label: Flare class being predicted. Only ``'C'`` triggers the
            label remapping and the zero-record check.
        series_len: Number of time steps in every returned series.
        start_feature: Index of the first feature column.
        n_features: Number of consecutive feature columns to read.
        mask_value: Value used to fill masked / padded time steps.

    Returns:
        A tuple ``(X_arr, y_arr)`` of numpy arrays. ``X_arr`` has shape
        ``(n_samples, series_len, n_features)`` when at least one sample was
        produced; ``y_arr`` holds one label per sample.
    """
    df_values = pd.read_csv(datafile).values
    feature_cols = slice(start_feature, start_feature + n_features)
    mask_row = [mask_value] * n_features

    X = []
    y = []
    for idx, row in enumerate(df_values):
        # The label is the first character of the string in column 1.
        label = row[1][0]
        if flare_label == 'C':
            if label in ('X', 'M'):
                label = 'C'
            elif label == 'B':
                label = 'N'

        # Rows with a zero record never become a sample of their own.
        if _has_zero_record(row, flare_label):
            continue

        cur_noaa_num = int(row[3])
        series = [row[feature_cols].tolist()]

        # Walk upwards through the file, prepending earlier rows that share
        # the current row's column-3 value.
        itr_idx = idx - 1
        while itr_idx >= 0 and len(series) < series_len:
            prev_row = df_values[itr_idx]
            if int(prev_row[3]) != cur_noaa_num:
                break
            # The check must look at the earlier row itself. Testing the
            # current row here (which is already known to be clean) would
            # never mask anything.
            if _has_zero_record(prev_row, flare_label):
                series.insert(0, mask_row)
            else:
                series.insert(0, prev_row[feature_cols].tolist())
            itr_idx -= 1

        # Left-pad short series so that every sample has series_len steps.
        while len(series) < series_len:
            series.insert(0, mask_row)

        X.append(np.array(series).reshape(series_len, n_features).tolist())
        y.append(label)

    X_arr = np.array(X)
    y_arr = np.array(y)
    print(X_arr.shape)
    return X_arr, y_arr


def data_transform(data):
    """Turn an array of class labels into a one-hot encoded matrix.

    The labels are first mapped to integer codes with a ``LabelEncoder``
    fitted on *data* itself, then expanded with
    ``np_utils.to_categorical``.
    """
    label_encoder = LabelEncoder()
    label_encoder.fit(data)
    integer_codes = label_encoder.transform(data)
    return np_utils.to_categorical(integer_codes)


def attention_3d_block(hidden_states, series_len):
    """Build the attention sub-graph on top of a sequence of hidden states.

    NOTE(review): assumes ``hidden_states`` is a 3-D Keras tensor laid out as
    (batch, series_len, hidden_size); the hidden size is read from
    ``shape[2]`` -- confirm against the caller.

    Args:
        hidden_states: Keras tensor holding one hidden state per time step.
        series_len: Number of time steps in the sequence.

    Returns:
        The output tensor of the final ``attention_vector`` Dense layer,
        which has ``hidden_size`` units and a tanh activation.
    """
    n_hidden = int(hidden_states.shape[2])

    # Swap the last two axes and pin the resulting static shape.
    states_t = Permute((2, 1), name='attention_input_t')(hidden_states)
    states_t = Reshape((n_hidden, series_len), name='attention_input_reshape')(states_t)

    # Bias-free projection of the transposed states, transposed back.
    projected = Dense(series_len, use_bias=False, name='attention_score_vec')(states_t)
    projected_t = Permute((2, 1), name='attention_score_vec_t')(projected)

    # Slice out the last entry along the final axis of the transposed states.
    last_state = Lambda(lambda x: x[:, :, -1], output_shape=(n_hidden, 1), name='last_hidden_state')(states_t)

    # Score, normalise with softmax, then weight the states.
    raw_score = dot([projected_t, last_state], [2, 1], name='attention_score')
    weights = Activation('softmax', name='attention_weight')(raw_score)
    context = dot([states_t, weights], [2, 1], name='context_vector')

    # Flatten both parts to (hidden_size,) per sample before joining them.
    context = Reshape((n_hidden,))(context)
    last_state = Reshape((n_hidden,))(last_state)
    joined = concatenate([context, last_state], name='attention_output')

    return Dense(n_hidden, use_bias=False, activation='tanh', name='attention_vector')(joined)


def lstm(nclass, n_features, series_len):
    """Assemble the LSTM + attention classifier.

    The network is: input -> LSTM(10 units, full sequence returned,
    dropout 0.5) -> attention block -> Dense(200, relu) -> Dense(500, relu)
    -> Dense(nclass, softmax, L2 activity regulariser of 1e-4).

    Args:
        nclass: Number of output classes (units of the softmax layer).
        n_features: Number of features per time step.
        series_len: Number of time steps per input sample.

    Returns:
        An uncompiled Keras ``Model`` mapping inputs of shape
        ``(series_len, n_features)`` to ``nclass`` class probabilities.
    """
    inputs = Input(shape=(series_len, n_features,))
    lstm_out = LSTM(10, return_sequences=True, dropout=0.5)(inputs)
    attention_mul = attention_3d_block(lstm_out, series_len)
    layer1_out = Dense(200, activation='relu')(attention_mul)
    layer2_out = Dense(500, activation='relu')(layer1_out)
    output = Dense(nclass, activation='softmax', activity_regularizer=regularizers.l2(0.0001))(layer2_out)
    # Keras 2 names these keyword arguments `inputs` / `outputs`; the old
    # singular Keras 1 spellings are rejected by current releases.
    model = Model(inputs=[inputs], outputs=output)
    return model


if __name__ == '__main__':
    # Command line: <flare_label> <train_again>
    #   flare_label  flare class to predict (only 'C' configures features)
    #   train_again  1 = train and save a new model, otherwise load ./model.h5
    if len(sys.argv) < 3:
        sys.exit('usage: <script> <flare_label> <train_again: 0|1>')
    flare_label = sys.argv[1]
    train_again = int(sys.argv[2])

    filepath = './'
    n_features = 0
    if flare_label == 'C':
        n_features = 14
    start_feature = 5
    mask_value = 0
    series_len = 10
    epochs = 7
    batch_size = 256
    nclass = 2
    result_file = './output.csv'
    test_file = filepath + 'normalized_testing.csv'

    if train_again == 1:
        # Train
        X_train, y_train = load_data(datafile=filepath + 'normalized_training.csv',
                                     flare_label=flare_label, series_len=series_len,
                                     start_feature=start_feature, n_features=n_features,
                                     mask_value=mask_value)
        y_train_tr = data_transform(y_train)

        # `classes` and `y` are keyword-only in current scikit-learn; the
        # keyword form is also accepted by older releases.
        class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                          classes=np.unique(y_train),
                                                          y=y_train)
        class_weight_ = {0: class_weights[0], 1: class_weights[1]}

        model = lstm(nclass, n_features, series_len)
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        model.fit(X_train, y_train_tr,
                  epochs=epochs, batch_size=batch_size,
                  verbose=False, shuffle=True, class_weight=class_weight_)
        model.save('./model.h5')
    else:
        model = load_model('./model.h5')

    # Test
    X_test, _ = load_data(datafile=test_file,
                          flare_label=flare_label, series_len=series_len,
                          start_feature=start_feature, n_features=n_features,
                          mask_value=mask_value)

    classes = model.predict(X_test, batch_size=batch_size, verbose=0, steps=None)

    # Copy the testing CSV to the result file with the predicted label
    # prepended to every row. newline='' is what the csv module asks for so
    # that line endings are handled by the reader/writer themselves.
    #
    # NOTE(review): predictions are looked up by CSV row position, but
    # load_data can skip rows (zero records when flare_label == 'C'). In that
    # case `classes` has fewer entries than the CSV has data rows, so labels
    # end up on the wrong rows and the loop eventually raises IndexError.
    # Fixing this needs load_data to report which rows it kept.
    with open(result_file, 'w', encoding='UTF-8', newline='') as result_csv:
        w = csv.writer(result_csv)
        with open(test_file, encoding='UTF-8', newline='') as data_csv:
            reader = csv.reader(data_csv)
            header = next(reader, None)
            if header is not None:
                w.writerow(['Predicted Label'] + header)
            for i, line in enumerate(reader):
                # Column 0 of the prediction is compared against 0.5.
                predicted = 'Positive' if classes[i][0] >= 0.5 else 'Negative'
                w.writerow([predicted] + line)