<a href="https://colab.research.google.com/github/osherlock1/exoplanet-detection-CNN/blob/main/lightcurve_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
!pip install lightkurve

**Imports**

In [None]:
#Standard Libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os


#Handling Fits files
from astropy.timeseries import TimeSeries
from astropy.time import Time
from astropy.io import fits

#Google Colab
from google.colab import drive

#Lightkurve
import lightkurve as lk

#Pytorch
import torch
from torch.utils.data import Dataset

from concurrent.futures import ProcessPoolExecutor

**Google Drive**

In [None]:
# Mount Google Drive so the project data under /content/drive/MyDrive is readable.
# NOTE(review): `drive` is also imported in the imports cell; this re-import is redundant but harmless.
from google.colab import drive
drive.mount('/content/drive')

# Input directory of .fits files (presumably the confirmed Kepler light curves).
# NOTE(review): the K2 processing cell below reassigns fits_directory before it is used.
fits_directory = '/content/drive/MyDrive/ELE391_Final_Project/data/Kepler/Kepler_conf'

## K2 Processing

In [None]:
def _bin_folded_flux(phase, flux, num_bins):
    """
    Average flux into ``num_bins`` equal-width phase bins spanning [0, 1).

    Points with a non-finite phase or flux, or a phase outside [0, 1), are
    ignored. Bins that receive no points are filled by linear interpolation
    from the non-empty bins (np.interp holds the end values constant beyond
    the first/last filled bin).

    Raises:
        ValueError: If no point falls into any bin.
    """
    phase = np.asarray(phase, dtype=float)
    flux = np.asarray(flux, dtype=float)
    finite = np.isfinite(phase) & np.isfinite(flux)
    phase, flux = phase[finite], flux[finite]

    bin_edges = np.linspace(0, 1, num_bins + 1)
    bin_indices = np.digitize(phase, bins=bin_edges) - 1
    # digitize yields -1 below the first edge and num_bins at/after the last one.
    in_range = (bin_indices >= 0) & (bin_indices < num_bins)
    bin_indices, flux = bin_indices[in_range], flux[in_range]

    # One O(N) pass instead of building a boolean mask per bin (O(N * num_bins)).
    counts = np.bincount(bin_indices, minlength=num_bins)
    sums = np.bincount(bin_indices, weights=flux, minlength=num_bins)

    filled = counts > 0
    if not filled.any():
        raise ValueError("No data points fell inside the phase range [0, 1).")

    binned_flux = np.full(num_bins, np.nan)
    binned_flux[filled] = sums[filled] / counts[filled]

    if not filled.all():
        x = np.arange(num_bins)
        binned_flux[~filled] = np.interp(x[~filled], x[filled], binned_flux[filled])
    return binned_flux


def process_single_file(fits_file, fits_directory, output_directory, num_bins=200):
    """
    Process one .fits light curve into a fixed-length, phase-folded tensor and save it.

    Steps: read the 'T' (time) and 'FCOR' (flux) columns from HDU 1, drop
    non-finite samples, remove 5-sigma outliers, flatten and normalize, fold
    at the period with maximum BLS power, average into ``num_bins`` phase
    bins, standardize to zero mean / unit std (when std > 0), and save the
    result as a float32 tensor in ``<output_directory>/<file stem>.pt``.

    Args:
        fits_file: File name (not a full path) inside ``fits_directory``.
        fits_directory: Directory containing ``fits_file``.
        output_directory: Destination directory for the .pt file; created if missing.
        num_bins: Number of phase bins (length of the saved tensor).

    Returns:
        str: Path of the saved .pt file.

    Raises:
        ValueError: If no folded data point lands in any phase bin.
    """
    file_path = os.path.join(fits_directory, fits_file)

    with fits.open(file_path) as hdul:
        data = hdul[1].data  # light-curve table expected in the second HDU
        # np.array copies into native-endian float64, independent of the open file.
        time = np.array(data['T'], dtype=float)
        flux = np.array(data['FCOR'], dtype=float)

    # np.isfinite is False for NaN and +/-inf, so one mask covers both.
    valid = np.isfinite(time) & np.isfinite(flux)
    lc = lk.LightCurve(time=time[valid], flux=flux[valid])

    # Remove outliers, then flatten and normalize the light curve.
    lc_cleaned = lc.remove_outliers(sigma=5)
    lc_flat = lc_cleaned.flatten().normalize()

    # Fold at the period with maximum BLS power.
    periodogram = lc_flat.to_periodogram(method="bls")
    best_period = periodogram.period_at_max_power
    folded_lc = lc_flat.fold(period=best_period, normalize_phase=True)

    # normalize_phase=True presumably yields phases in [-0.5, 0.5); the +0.5
    # shift maps them onto the [0, 1) bin edges used for binning.
    phase = folded_lc.phase.value + 0.5
    binned_flux = _bin_folded_flux(phase, folded_lc.flux.value, num_bins)

    std = np.std(binned_flux)
    if std > 0:
        binned_flux = (binned_flux - np.mean(binned_flux)) / std

    # torch.save does not create missing directories.
    os.makedirs(output_directory, exist_ok=True)
    binned_tensor = torch.tensor(binned_flux, dtype=torch.float32)
    output_path = os.path.join(output_directory, f"{os.path.splitext(fits_file)[0]}.pt")
    torch.save(binned_tensor, output_path)

    print(f"Processed and saved: {fits_file} -> {output_path}")
    return output_path

def process_lightcurves_parallel(fits_directory, output_directory, num_bins=200):
    """
    Process every .fits file in a directory using parallel worker processes.

    Each file is passed to ``process_single_file``. A failure in one file does
    not stop the others; failures are printed and returned.

    Args:
        fits_directory: Directory scanned (non-recursively) for files whose
            name ends in '.fits' (case-insensitive).
        output_directory: Passed through to ``process_single_file``.
        num_bins: Number of phase bins per output tensor.

    Returns:
        list[tuple[str, Exception]]: ``(file name, exception)`` for every file
        that failed; empty when all files succeeded or none were found.
    """
    from concurrent.futures import as_completed

    fits_files = sorted(f for f in os.listdir(fits_directory) if f.lower().endswith('.fits'))
    if not fits_files:
        print("No .fits files found.")
        return []

    failures = []
    with ProcessPoolExecutor() as executor:
        # Submit the module-level function itself: ProcessPoolExecutor pickles
        # the callable, and lambdas cannot be pickled.
        future_to_file = {
            executor.submit(process_single_file, fits_file, fits_directory, output_directory, num_bins): fits_file
            for fits_file in fits_files
        }
        # Worker exceptions only surface through Future.result(), so read every result.
        for future in as_completed(future_to_file):
            fits_file = future_to_file[future]
            try:
                future.result()
            except Exception as exc:  # report and continue with the remaining files
                failures.append((fits_file, exc))
                print(f"Failed: {fits_file}: {exc!r}")

    print("All files processed.")
    if failures:
        print(f"{len(failures)} of {len(fits_files)} file(s) failed.")
    return failures


In [None]:
from concurrent.futures import as_completed

fits_directory = '/content/drive/MyDrive/ELE391_Final_Project/data/K2/K2_confirmed_names'
output_directory = '/content/drive/MyDrive/ELE391_Final_Project/data_v2/K2_v2/K2_conf_200'

# Case-insensitive match so files named *.FITS are picked up too.
fits_files = [f for f in os.listdir(fits_directory) if f.lower().endswith('.fits')]
print(f"Found {len(fits_files)} files to process.")

# Ensure the output directory exists before the workers save into it.
os.makedirs(output_directory, exist_ok=True)

def debug_wrapper(fits_file):
    """Print the file name, then run process_single_file with this cell's settings.

    A named module-level function (unlike a lambda) can be pickled by
    ProcessPoolExecutor for dispatch to worker processes.
    """
    print(f"Passing to process_single_file: {fits_file}")
    process_single_file(fits_file, fits_directory, output_directory, num_bins=200)

failed_files = []
with ProcessPoolExecutor() as executor:
    future_to_file = {executor.submit(debug_wrapper, f): f for f in fits_files}
    # Exceptions raised in a worker only surface through Future.result(),
    # so read every result instead of discarding them.
    for future in as_completed(future_to_file):
        try:
            future.result()
        except Exception as exc:
            failed_files.append(future_to_file[future])
            print(f"Failed: {future_to_file[future]}: {exc!r}")

print(f"Done: {len(fits_files) - len(failed_files)} succeeded, {len(failed_files)} failed.")

## Plotting Binned Light Curves

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

def plot_binned_lightcurve(pt_file, num_bins=200):
    """
    Read a binned light curve from a .pt file and plot it against phase.

    Args:
        pt_file (str): Path to the .pt file containing the binned light curve tensor.
        num_bins (int): Expected number of bins. The phase axis is built from the
            tensor's actual length. If the two differ, a warning is printed
            instead of plotting failing with mismatched x/y lengths.

    Raises:
        ValueError: If the file does not contain a PyTorch tensor.
    """
    # map_location='cpu' lets tensors saved from a GPU load on CPU-only machines.
    binned_flux = torch.load(pt_file, map_location='cpu')

    if not isinstance(binned_flux, torch.Tensor):
        raise ValueError("The loaded file does not contain a PyTorch tensor.")

    # detach() so .numpy() also works on tensors that require grad.
    flux = binned_flux.detach().numpy()
    name = os.path.basename(pt_file)

    actual_bins = len(flux)
    if actual_bins != num_bins:
        print(f"Warning: {name} has {actual_bins} bins, expected {num_bins}; using {actual_bins}.")

    # Phase axis: midpoints of equal-width bins on [0, 1].
    bin_edges = np.linspace(0, 1, actual_bins + 1)
    binned_phase = (bin_edges[:-1] + bin_edges[1:]) / 2

    plt.figure(figsize=(8, 5))
    plt.plot(binned_phase, flux, '-o', label=f'Binned Light Curve: {name}', color='blue')
    plt.xlabel("Phase")
    plt.ylabel("Flux")
    plt.title(f"Binned Light Curve: {name}")
    plt.grid(alpha=0.5)
    plt.legend()
    plt.show()


def plot_all_lightcurves_in_dir(directory, num_bins=200):
    """
    Plot every .pt light curve found directly in ``directory``, in file-name order.

    Files are sorted so repeated runs plot in the same, reproducible order
    (os.listdir order is arbitrary). An error while plotting one file is
    printed and that file is skipped, so the remaining files are still plotted.

    Args:
        directory (str): Path to the directory containing .pt files.
        num_bins (int): Number of bins expected in each binned light curve.
    """
    pt_files = sorted(
        os.path.join(directory, file_name)
        for file_name in os.listdir(directory)
        if file_name.endswith('.pt')
    )

    if not pt_files:
        print("No .pt files found in the directory.")
        return

    for pt_file in pt_files:
        print(f"Plotting: {pt_file}")
        try:
            plot_binned_lightcurve(pt_file, num_bins)
        except Exception as e:  # best-effort: report and move to the next file
            print(f"Error plotting {pt_file}: {e}")




In [None]:
# Directory containing .pt files
directory = '/content/drive/MyDrive/ELE391_Final_Project/data_v4/K2_fp_pt'

# Plot all light curves in the directory
# NOTE(review): num_bins=400 should match the length of the saved tensors. The K2
# processing cell above writes 200-bin tensors (to a data_v2 directory) — confirm
# that the data_v4 files really contain 400 bins.
plot_all_lightcurves_in_dir(directory, num_bins=400)