# Imports

In [1]:
import tarfile
import numpy as np

from astropy.io import fits
import matplotlib.pyplot as plt

from scipy.io import readsav

import os

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model


In [2]:
import sunpy
import sunpy.map

In [13]:
import matplotlib.pyplot as plt

In [3]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Useful General Functions

In [5]:
def get_file_paths(directory_path):
    """
    Recursively collect the full paths of every file below a directory.

    The returned list is sorted. os.walk reports entries in an arbitrary,
    filesystem-dependent order, while the training code in this notebook pairs
    the image list and the mask list by position; sorting makes that pairing
    reproducible from run to run and from machine to machine.

    Parameters:
    - directory_path (str): The path to the directory from which to list files.

    Returns:
    - List[str]: Full file paths (sub-directories included), sorted
      lexicographically.

    Raises:
    - ValueError: If `directory_path` is not an existing directory.
    """
    # Fail early with a readable message instead of silently returning []
    # (os.walk yields nothing for a missing directory).
    if not os.path.isdir(directory_path):
        raise ValueError(f"The provided directory path does not exist: {directory_path}")

    # Join each file name with the directory it was found in.
    file_paths = [
        os.path.join(root, file_name)
        for root, _, files in os.walk(directory_path)
        for file_name in files
    ]

    # Deterministic ordering (see docstring).
    file_paths.sort()
    return file_paths

"""
# Test
im = get_file_paths("G:\\BMR_Identification")
img = image_data_generator(im)
image = next(img)

plt.figure(figsize=(10, 10))
plt.imshow(image, cmap='gray', origin='lower')
plt.colorbar()
plt.title('Image Data from IDL .sav file')
plt.show()
"""

'\n# Test\nim = get_file_paths("G:\\BMR_Identification")\nimg = image_data_generator(im)\nimage = next(img)\n\nplt.figure(figsize=(10, 10))\nplt.imshow(image, cmap=\'gray\', origin=\'lower\')\nplt.colorbar()\nplt.title(\'Image Data from IDL .sav file\')\nplt.show()\n'

In [6]:
def combined_generator(x_generator, y_generator, batch_size):
    """
    Zip two sample iterators into batches of (inputs, targets).

    Parameters:
    - x_generator: Iterator yielding one input sample per `next()` call.
    - y_generator: Iterator yielding the matching target sample.
    - batch_size (int): Number of samples per batch.

    Yields:
    - Tuple[np.ndarray, np.ndarray]: `(X_batch, Y_batch)`, each stacked along a
      new leading axis.

    Notes:
    - The generator runs for as long as both sources keep producing samples,
      so it is endless when the sources are endless.
    - When either source runs dry, the samples collected so far are emitted as
      one final (shorter) batch and the generator then finishes normally.
      Letting StopIteration escape from `next()` inside a generator would
      instead surface as a RuntimeError (PEP 479).
    - X and Y are only appended as a pair, so a batch never contains more
      inputs than targets.
    """
    while True:
        X_batch = []
        Y_batch = []

        for _ in range(batch_size):
            try:
                x_item = next(x_generator)
                y_item = next(y_generator)
            except StopIteration:
                # One of the sources is exhausted: stop filling this batch.
                break
            X_batch.append(x_item)
            Y_batch.append(y_item)

        if not X_batch:
            # Nothing left to emit.
            return

        # Convert to numpy arrays and yield the batch
        yield (np.array(X_batch), np.array(Y_batch))

        if len(X_batch) < batch_size:
            # A short batch means a source ran dry; end cleanly.
            return


# Sun Images

In [7]:
def load_fits_image_sunpy(fits_file_path, vmin=-1500, vmax=1500):
    """
    Load a FITS image through sunpy and rescale it to the range [0, 1].

    Parameters:
    - fits_file_path: Path handed to `sunpy.map.Map`.
    - vmin: Lower clipping bound; maps to 0 after rescaling.
    - vmax: Upper clipping bound; maps to 1 after rescaling.

    Returns:
    - np.ndarray: float32 array of the map's data, clipped to [vmin, vmax] and
      linearly rescaled. NaN pixels survive clipping and rescaling and are
      replaced by 0 at the very end.
    """
    # Read the pixel array and force float32
    pixels = sunpy.map.Map(fits_file_path).data.astype(np.float32)

    # Clip, then shift/scale so that vmin -> 0 and vmax -> 1
    value_span = vmax - vmin
    rescaled = (np.clip(pixels, vmin, vmax) - vmin) / value_span

    # Remaining NaNs become 0
    return np.nan_to_num(rescaled, nan=0)

def sun_data_generator(file_paths):
    """Lazily yield `load_fits_image_sunpy(path)` for each path, in the order given (single pass)."""
    yield from (load_fits_image_sunpy(path) for path in file_paths)

# BMR Images

In [8]:
def process_sav_file(sav_file_path):
    """
    Turn an IDL .sav file into a binary mask image.

    The file is expected to contain:
    - `hdr_los`: a record whose `naxis1` / `naxis2` fields give the image
      width and height (first element of each field is used);
    - `bmr_ind`: flat pixel indices of the region(s) to mark, where
      index = row * naxis1 + col.

    Parameters:
    - sav_file_path: Path passed to `scipy.io.readsav`.

    Returns:
    - np.ndarray: float32 array of shape (naxis2, naxis1, 1) holding 1.0 at the
      indexed pixels and 0.0 everywhere else.

    Notes:
    - Negative indices are ignored. IDL's WHERE returns -1 when nothing
      matches; NumPy would interpret that as "last element" and light up a
      spurious pixel in the final row. (Assumes `bmr_ind` originates from
      WHERE -- confirm against the IDL code that wrote the files.)
    - Indices beyond the image still raise IndexError, as before.
    """
    # Read the .sav file
    idl_data = readsav(sav_file_path)

    # Extract image dimensions as plain Python ints
    hdr_los = idl_data['hdr_los']
    naxis1 = int(hdr_los.naxis1[0])
    naxis2 = int(hdr_los.naxis2[0])

    # Accept scalar or array input, and drop the "no match" sentinel / any
    # other negative value before it can wrap around.
    flat_indices = np.atleast_1d(np.asarray(idl_data['bmr_ind'])).ravel()
    flat_indices = flat_indices[flat_indices >= 0].astype(np.int64)

    # Convert flat indices to 2D indices
    rows, cols = np.divmod(flat_indices, naxis1)

    # Create an empty image and set the indexed pixels to 1
    image_data = np.zeros((naxis2, naxis1), dtype=np.float32)
    image_data[rows, cols] = 1.0

    # Trailing channel axis
    return np.expand_dims(image_data, axis=-1)

def image_data_generator(file_paths):
    """Lazily yield `process_sav_file(path)` for each path, in the order given (single pass)."""
    yield from map(process_sav_file, file_paths)

# Model

In [17]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, Lambda, Cropping2D
from tensorflow.keras.applications import MobileNetV2
import tensorflow as tf

def create_small_unet(input_size=(1024, 1024, 1), pretrained_weights=None):
    """
    Build a U-Net style binary-segmentation model on a frozen MobileNetV2 encoder.

    Parameters:
    - input_size: (height, width, 1) of the single-channel input image. Height
      and width must be multiples of 32 so that every decoder stage lines up
      with its encoder skip connection.
    - pretrained_weights: Optional path passed to `model.load_weights` after
      the model has been built.

    Returns:
    - tensorflow.keras.Model mapping (H, W, 1) images to (H, W, 1) sigmoid
      probabilities.

    Raises:
    - ValueError: If the input is not single-channel, or if a decoder stage and
      its skip connection end up with different spatial sizes.

    Notes:
    - The encoder is tapped at four depths (1/2, 1/4, 1/8 and 1/16 of the input
      resolution) and each decoder stage is concatenated with the skip map of
      the *same* resolution. An earlier revision tapped only three depths and
      centre-cropped each skip map to half its size, which paired a cropped
      region of the image with full-frame decoder features.
    - Because the decoder layout changed, weight files produced for the earlier
      layout cannot be loaded through `pretrained_weights`.
    """
    if input_size[-1] != 1:
        raise ValueError(
            f"create_small_unet expects a single-channel input, got input_size={input_size}"
        )

    inputs = Input(input_size)
    # The ImageNet weights expect 3 channels: replicate the single channel.
    # A built-in merge layer is used instead of a Lambda so that the saved
    # model does not depend on a serialised Python lambda.
    x = concatenate([inputs, inputs, inputs])

    base_model = MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_tensor=x,
        # Follow input_size rather than a hard-coded 1024x1024.
        input_shape=tuple(input_size[:2]) + (3,),
    )

    # Freeze the layers of MobileNetV2
    for layer in base_model.layers:
        layer.trainable = False

    # Encoder activations used as skip connections, shallow -> deep.
    layer_names = [
        "block_1_expand_relu",    # 1/2 resolution  (512x512 for a 1024 input)
        "block_3_expand_relu",    # 1/4 resolution  (256x256)
        "block_6_expand_relu",    # 1/8 resolution  (128x128)
        "block_13_expand_relu",   # 1/16 resolution (64x64)
    ]
    skip_layers = [base_model.get_layer(name).output for name in layer_names]

    # Decoder: start from the 1/32-resolution bottleneck and double the
    # resolution at every stage, merging the matching skip map each time.
    x = base_model.output
    for filters, skip_layer in zip([512, 256, 128, 64], skip_layers[::-1]):
        x = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(x)
        if tuple(x.shape[1:3]) != tuple(skip_layer.shape[1:3]):
            raise ValueError(
                "Decoder/skip size mismatch "
                f"({tuple(x.shape[1:3])} vs {tuple(skip_layer.shape[1:3])}); "
                "input height and width must be multiples of 32."
            )
        x = concatenate([x, skip_layer])
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    # Last upsampling step: 1/2 resolution -> full resolution
    x = Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same')(x)

    # Output layer for binary segmentation
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model




# Training Loop

In [19]:
def _repeat_forever(make_generator, file_paths):
    """Yield samples from `make_generator(file_paths)` endlessly, restarting after each full pass."""
    while True:
        yield from make_generator(file_paths)


def training_loop(X_train_file_path, Y_train_file_path, X_validation_file_path, Y_validation_file_path, Yash_path,
                  batch_size=8, epochs=10):
    """
    Train the U-Net on paired image / mask files, then save the model and its learning curves.

    Parameters:
    - X_train_file_path: Directory with the training input images (FITS).
    - Y_train_file_path: Directory with the training masks (IDL .sav).
    - X_validation_file_path: Directory with the validation input images.
    - Y_validation_file_path: Directory with the validation masks.
    - Yash_path: Output directory; created if missing. Receives `my_model.h5`
      and `my_plot.png`.
    - batch_size (int): Samples per batch (default 8).
    - epochs (int): Maximum number of epochs (default 10); early stopping on
      `val_loss` with patience 5 may end training sooner.

    Raises:
    - ValueError: If a split has a different number of image and mask files,
      or fewer files than `batch_size`.

    Notes:
    - Images and masks are paired by position after sorting both path lists;
      this presumes the file names of a pair sort identically -- verify
      against the dataset's naming scheme.
    - The sample generators are restarted after every full pass. Single-pass
      generators would be exhausted after the first epoch and stop training.
    """
    # Sort so that the X/Y pairing does not depend on directory listing order.
    list_of_x_train_files = sorted(get_file_paths(X_train_file_path))
    list_of_y_train_files = sorted(get_file_paths(Y_train_file_path))
    list_of_x_test_files = sorted(get_file_paths(X_validation_file_path))
    list_of_y_test_files = sorted(get_file_paths(Y_validation_file_path))

    # Validate before the (expensive) model is built.
    for split_name, x_files, y_files in (
        ("training", list_of_x_train_files, list_of_y_train_files),
        ("validation", list_of_x_test_files, list_of_y_test_files),
    ):
        if len(x_files) != len(y_files):
            raise ValueError(
                f"The {split_name} set has {len(x_files)} image files but {len(y_files)} mask files"
            )
        if len(x_files) < batch_size:
            raise ValueError(
                f"The {split_name} set has {len(x_files)} files, fewer than batch_size={batch_size}"
            )

    # Create and compile the U-Net model
    model = create_small_unet()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Endless, lock-stepped sample streams combined into batches
    train_generator = combined_generator(
        _repeat_forever(sun_data_generator, list_of_x_train_files),
        _repeat_forever(image_data_generator, list_of_y_train_files),
        batch_size=batch_size,
    )
    validation_generator = combined_generator(
        _repeat_forever(sun_data_generator, list_of_x_test_files),
        _repeat_forever(image_data_generator, list_of_y_test_files),
        batch_size=batch_size,
    )

    # Whole batches per epoch (total_samples // batch_size)
    steps_per_epoch = len(list_of_x_train_files) // batch_size
    validation_steps = len(list_of_x_test_files) // batch_size

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=1,
        mode='min'
    )

    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        callbacks=[early_stopping]
    )

    # Save the model first so that a plotting problem cannot lose the result.
    os.makedirs(Yash_path, exist_ok=True)
    model.save(os.path.join(Yash_path, 'my_model.h5'))

    # Accuracy and loss curves side by side in one figure
    fig, (accuracy_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    accuracy_ax.plot(history.history['accuracy'])
    accuracy_ax.plot(history.history['val_accuracy'])
    accuracy_ax.set_title('Model accuracy')
    accuracy_ax.set_ylabel('Accuracy')
    accuracy_ax.set_xlabel('Epoch')
    accuracy_ax.legend(['Train', 'Validation'], loc='upper left')

    loss_ax.plot(history.history['loss'])
    loss_ax.plot(history.history['val_loss'])
    loss_ax.set_title('Model loss')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.legend(['Train', 'Validation'], loc='upper left')

    fig.tight_layout()

    # Write the file BEFORE plt.show(): once the figure has been shown, a
    # following plt.savefig() would write an empty canvas.
    fig.savefig(os.path.join(Yash_path, 'my_plot.png'))
    plt.show()
    

