### Import libraries

In [None]:
import astropy.io.fits as pyfits
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
import os, urllib
import torch
from torch.utils.data import Dataset, DataLoader
from joblib import Parallel, delayed
from scipy.interpolate import interp1d

### Choose band

In [None]:
# Map of CCD number -> photometric band letter.
# Uncomment an entry to include that band; only the red arm is active here.
ccd_dict = {
    # 1: "B",
    # 2: "G",
    3: "R",
    # 4: "I",
}

# CCD numbers of the selected bands, in insertion order
ccd_list_keys = [*ccd_dict]

### Read fits files
This code is mainly based on the existing GALAH tutorial code, found at https://github.com/svenbuder/GALAH_DR3/tree/master/tutorials

In [None]:
def read_spectra(sobject_id, ccd_list_keys, spectra_directory="fits-path"):
    """
    Read the FITS spectra of one object for the requested CCDs.

    Files are looked up as ``<spectra_directory>/<sobject_id><ccd>.fits``.
    Nothing is downloaded: a CCD whose file is missing simply yields empty
    entries.

    Parameters
    ----------
    sobject_id : int or str
        Object identifier; it forms the stem of the FITS file names.
    ccd_list_keys : iterable of int
        CCD numbers to read.
    spectra_directory : str, optional
        Directory containing the FITS files.

    Returns
    -------
    dict
        Per-CCD entries ``wave_norm_<ccd>``, ``sob_red_<ccd>``,
        ``sob_norm_<ccd>`` and ``uob_norm_<ccd>``, plus the concatenation
        over all requested CCDs under ``sob_red``, ``sob_norm``,
        ``wave_norm`` and ``uob_norm``. For a missing CCD the per-CCD
        entries (including ``wave_red_<ccd>`` and ``uob_red_<ccd>``) are
        empty lists.

    Raises
    ------
    ValueError
        From ``np.concatenate`` when ``ccd_list_keys`` is empty.
    """
    spectrum = dict()

    for each_ccd in ccd_list_keys:
        ccd = str(each_ccd)
        # os.path.join keeps the lookup correct whether or not the directory
        # string ends with a path separator.
        matches = glob.glob(os.path.join(spectra_directory, f"{sobject_id}{ccd}.fits"))

        if not matches:
            # Keep every key present so the concatenation below still works.
            for prefix in ('wave_red_', 'wave_norm_', 'sob_red_',
                           'sob_norm_', 'uob_red_', 'uob_norm_'):
                spectrum[prefix + ccd] = []
            continue

        # The context manager closes the file even if an extension or a
        # header keyword is missing.
        with pyfits.open(matches[0]) as fits:
            # Extension 0: Reduced spectrum
            # Extension 1: Relative error spectrum
            # Extension 4: Normalised spectrum, NB: cut for CCD4

            # Linear wavelength grid of the normalised spectrum
            header = fits[4].header
            start_wavelength = header["CRVAL1"]
            dispersion = header["CDELT1"]
            nr_pixels = header["NAXIS1"]
            reference_pixel = header["CRPIX1"]
            if reference_pixel == 0:
                reference_pixel = 1
            # FITS linear axis: value = CRVAL1 + (pixel - CRPIX1) * CDELT1 with
            # 1-based pixel = index + 1. (The previous "--reference_pixel"
            # added the reference pixel instead of subtracting it.)
            pixel_index = np.arange(0, nr_pixels)
            spectrum['wave_norm_' + ccd] = (
                (pixel_index - reference_pixel + 1) * dispersion + start_wavelength
            )

            # Reduced flux, converted to native byte order for pandas use
            # https://stackoverflow.com/questions/30283836/creating-pandas-dataframe-from-numpy-array-leads-to-strange-errors
            # dtype.newbyteorder('=') works on NumPy 1.x and 2.x, whereas
            # ndarray.newbyteorder() was removed in NumPy 2.0.
            reduced = np.array(fits[0].data)
            spectrum['sob_red_' + ccd] = reduced.astype(reduced.dtype.newbyteorder('='))

            # Normalised flux and its error (normalised flux * extension 1)
            spectrum['sob_norm_' + ccd] = np.array(fits[4].data)
            spectrum['uob_norm_' + ccd] = np.array(fits[4].data * fits[1].data)

    # Stitch the per-CCD arrays together in the order of ccd_list_keys
    for name in ('sob_red', 'sob_norm', 'wave_norm', 'uob_norm'):
        spectrum[name] = np.concatenate(
            [spectrum[name + '_' + str(each_ccd)] for each_ccd in ccd_list_keys]
        )

    return spectrum

### Using the code above, convert .fits files to .csv files

In [None]:
# Convert every FITS spectrum found in input_dir into a two-column CSV
# (wave_norm_3, sob_norm_3) in output_dir. The header row is what the later
# cells rely on when they read these columns back with pandas.
# NOTE(review): read_spectra looks files up in its own directory setting, not
# in input_dir - confirm that both point at the same location.
input_dir = r"input-fits-files-path"
output_dir = r"ouput-csv-files-path"

os.makedirs(output_dir, exist_ok=True)
fits_files = glob.glob(os.path.join(input_dir, "*.fits"))
total_files = len(fits_files)

# Loop through the FITS files, convert them to CSV, and save them in the output directory
for i, fits_file in enumerate(fits_files):

    # Drop the extension and the trailing CCD digit to recover the object id
    file_name_without_ext = os.path.splitext(os.path.basename(fits_file))[0][:-1]

    # Reading and the precision estimate live inside the try block so that a
    # single unreadable or empty spectrum does not abort the whole conversion.
    try:
        spectrum = read_spectra(file_name_without_ext, ccd_list_keys)
        wavelength = np.asarray(spectrum['wave_norm_3'], dtype=float).ravel()
        my_array = np.asarray(spectrum['sob_norm_3']).ravel()
        if my_array.size == 0:
            raise ValueError("no normalised CCD 3 spectrum available")

        # Number of decimals needed to resolve the smallest non-zero step
        # between flux values; clamped so the format string is always valid.
        differences = np.diff(np.sort(my_array))
        positive_steps = differences[differences > 0]
        if positive_steps.size:
            decimal_points = int(np.ceil(-np.log10(positive_steps.min())))
            decimal_points = min(max(decimal_points, 0), 17)
        else:
            # Constant (or single-pixel) spectrum: any precision reproduces it
            decimal_points = 6

        csv_file = os.path.join(output_dir, file_name_without_ext + ".csv")
        fmt_str = f'%.{decimal_points}f'
        np.savetxt(
            csv_file,
            np.column_stack((wavelength, my_array)),
            delimiter=',',
            fmt=['%.8f', fmt_str],
            header='wave_norm_3,sob_norm_3',
            comments='',  # write the header without the default '# ' prefix
        )
    except Exception as e:
        print(f"\nError processing file: {fits_file}\nError: {e}\n")

    print(f"Processed file {i + 1} of {total_files}", end='\r')

print("All files have been converted.")

### Calculate logarithm of wavelengths and interpolate intensity
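This is done because a Doppler shift scales wavelengths multiplicatively: on a linear wavelength grid, the minima of doubled lines would be spaced farther and farther apart as the wavelength increases, whereas on a logarithmic wavelength grid the same velocity difference gives the same separation everywhere.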

In [None]:
# Resample every spectrum CSV under source_dir onto a grid that is evenly
# spaced in ln(wavelength) and write the result to the mirrored location
# under target_dir.
source_dir = "csv-files-path"
target_dir = "logged-csv-files-path"

for root, _dirs, files in os.walk(source_dir):
    # Mirror the directory structure via the path relative to source_dir.
    # (str.replace would also rewrite any later occurrence of the source
    # directory's text inside sub-directory names.)
    relative_root = os.path.relpath(root, source_dir)
    target_subdir = os.path.normpath(os.path.join(target_dir, relative_root))

    # A directory that already exists in the target is treated as done
    if os.path.exists(target_subdir):
        continue

    for file in files:
        if not file.endswith('.csv'):
            continue

        try:
            df = pd.read_csv(os.path.join(root, file))
            wavelength = df['wave_norm_3']
            intensity = df['sob_norm_3']

            # Same number of samples, same end points, uniform step in ln(wavelength)
            log_wavelength = np.log(wavelength)
            new_log_wavelength = np.linspace(
                log_wavelength.min(), log_wavelength.max(), num=len(log_wavelength)
            )

            # Linear interpolation of the flux onto the new grid
            interpolation_function = interp1d(log_wavelength, intensity, kind='linear')
            new_intensity = interpolation_function(new_log_wavelength)

            # Note: the 'wave_norm_3' column of the output holds ln(wavelength)
            interpolated_df = pd.DataFrame(
                {'wave_norm_3': new_log_wavelength, 'sob_norm_3': new_intensity}
            )

            # Created lazily so that directories without usable CSVs are not mirrored
            os.makedirs(target_subdir, exist_ok=True)
            interpolated_df.to_csv(os.path.join(target_subdir, file), index=False)

        except Exception as e:
            # Best effort: report the file and carry on with the next one
            print(f"An error occurred with file {file}: {e}")

### Define function for loading logged .csv files into dataframes

In [None]:
def load_data(sobject_id, data_dir="logged-csv-files-path"):
    """
    Load the log-resampled spectrum of one object.

    The CSV is searched for, in this order, at
    ``<data_dir>/<first 6 characters of id>/<id>.csv``,
    ``<data_dir>/<id>.csv`` and finally ``data_dir`` itself (the path the
    previous version always read, regardless of the id).

    NOTE(review): the per-object layout is inferred from the ``folder_name``
    the previous version computed but never used - confirm it against the
    actual directory structure.

    Parameters
    ----------
    sobject_id : int or str
        Object identifier; it becomes the column name of both frames.
    data_dir : str, optional
        Directory (or single file) holding the log-resampled CSVs.

    Returns
    -------
    tuple of pandas.DataFrame or None
        ``(sob_df, wave_df)``, each with one column named ``str(sobject_id)``
        taken from the ``sob_norm_3`` and ``wave_norm_3`` columns, or ``None``
        when no file is found for this object.
    """
    object_name = str(sobject_id)
    folder_name = object_name[:6]
    csv_name = object_name + ".csv"

    candidates = (
        os.path.join(data_dir, folder_name, csv_name),
        os.path.join(data_dir, csv_name),
        data_dir,
    )
    # isfile (not exists): a directory must not be handed to read_csv
    file_path = next((path for path in candidates if os.path.isfile(path)), None)
    if file_path is None:
        return None

    df = pd.read_csv(file_path)

    sob_df = pd.DataFrame({object_name: df['sob_norm_3']})
    wave_df = pd.DataFrame({object_name: df['wave_norm_3']})
    return sob_df, wave_df

### Use joblib's `Parallel` for faster loading of .csv files with the function above

In [None]:
# Object ids to load come from the 'sobject_id' column of the labels table
labels_df = pd.read_csv(r"path-to-labels.csv")
selected_sobject_ids = labels_df['sobject_id'].values

# One load_data call per object, spread over all available cores
_load_tasks = (delayed(load_data)(sobject_id) for sobject_id in selected_sobject_ids)
results = Parallel(n_jobs=-1)(_load_tasks)

# Objects without a file come back as None and are left out; the remaining
# (flux, wavelength) pairs are joined column-wise, one column per object.
_loaded_pairs = [pair for pair in results if pair is not None]
spektri_df = pd.concat([flux for flux, _ in _loaded_pairs], axis=1)
wave_df = pd.concat([wave for _, wave in _loaded_pairs], axis=1)

### Clean the loaded .csv files

In [None]:
# Create a dictionary to map sobject_id to label
id_to_label = dict(zip(labels_df['sobject_id'], labels_df['bin_tf']))

spektri_df = spektri_df.dropna()  # Remove rows containing NaN values
X = spektri_df.T
X.index = X.index.map(lambda x: int(float(x)))  # Remove decimal point and convert to int
X.index = X.index.astype(np.int64)
X = X.sort_index()

id_to_label = {np.int64(k): v for k, v in id_to_label.items()}
X = X[X.index.isin(id_to_label.keys())]
y = [id_to_label[sobject_id] for sobject_id in X.index]


### Split into train, validation and test datasets

In [None]:

# NOTE(review): neither train_test_split nor seed is defined in the visible
# part of this file - presumably sklearn.model_selection.train_test_split and
# a seed set in another cell; confirm before running.

# Hold out 20% of the data as the test set ...
_TEST_FRACTION = 0.2
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=_TEST_FRACTION, random_state=seed
)

# ... then take 25% of the remainder for validation,
# which gives a 60% / 20% / 20% train / validation / test split overall.
_VAL_FRACTION = 0.25
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=_VAL_FRACTION, random_state=seed
)

# Create PyTorch datasets
class SpectraDataset(Dataset):
    """Torch dataset pairing each spectrum with its integer label.

    X must expose ``.values`` (e.g. a pandas DataFrame with one spectrum per
    row); it is stored as float32 with an extra axis inserted at position 1.
    y is stored as an int64 (long) tensor.
    """

    def __init__(self, X, y):
        features = torch.tensor(X.values, dtype=torch.float32)
        # Insert a singleton axis after the sample axis
        self.X = features[:, None]
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(spectrum, label)`` pair stored at ``idx``."""
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

# Wrap each split (train, validation, test) in a SpectraDataset
train_dataset, val_dataset, test_dataset = (
    SpectraDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Create PyTorch dataloaders: batches of 32, only the training loader shuffles
_loader_options = {"batch_size": 32}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

### Save to .h5 file type
This enables much faster load times when running or training the model.

In [None]:
# Write the three feature tables and their label lists (as Series) into one
# HDF5 file. The context manager closes the store even if a write fails, so
# the file handle is never left open.
with pd.HDFStore(r'path-to-file.h5') as store:
    store['X_train'] = X_train
    store['X_val'] = X_val
    store['X_test'] = X_test
    store['y_train'] = pd.Series(y_train)
    store['y_val'] = pd.Series(y_val)
    store['y_test'] = pd.Series(y_test)