In [None]:
import mne
import neurokit2 as nk
import pyedflib
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Default Matplotlib figure size (width, height in inches) for any plot that
# does not pass its own figsize.
plt.rcParams['figure.figsize'] = (18, 3)

class PSGDataProcessor:
    def __init__(self):
        """
        Initialize the PSGDataProcessor class without loading data immediately.
        """
        self.data = None
        self.raw_data = None
        self.sampling_rate = None
        self.ch_names = None
        self.start_datetime = None
    
    def load_data(self, psg_file):
        """
        Read an EDF recording with MNE and cache its key properties.

        Sets self.data (the MNE Raw object, preloaded), self.raw_data (its
        sample array from get_data()), self.sampling_rate (info['sfreq']),
        self.ch_names and self.start_datetime (derived from info['meas_date']
        via get_datetime_from_info).

        Args:
        psg_file (str): Path to the PSG (EDF) file.
        """
        raw = mne.io.read_raw_edf(psg_file, preload=True)
        info = raw.info
        self.data = raw
        self.raw_data = raw.get_data()
        self.sampling_rate = info['sfreq']
        self.ch_names = raw.ch_names
        self.start_datetime = self.get_datetime_from_info(info['meas_date'])
        
    def psg_plot(self):
        """Call plot() on the loaded MNE recording (self.data, set by load_data)."""
        recording = self.data
        recording.plot()
        
    def print_file_info(self, psg_file, info_name = 'file'):
        """
        Load PSG data from an EDF file using pyEDFlib and print file and signal headers.

        Args:
        psg_file (str): Path to the PSG file.
        """
        try:
            edf_file = pyedflib.EdfReader(psg_file)
            self.retrieve_info(edf_file, info_name)
            edf_file.close()
        except Exception as e:
            print(f"Failed to read EDF file: {e}")

    @staticmethod
    def retrieve_info(edf_file, info_name = 'file'):
        """
        Print detailed information about the EDF file.

        Args:
        edf_file (pyedflib.EdfReader): An instance of EdfReader.
        """
        if info_name == 'label':
            PSGDataProcessor.print_label_and_freq(list(zip(edf_file.getSignalLabels(), edf_file.getSampleFrequencies())))
        elif info_name == 'signal':
            PSGDataProcessor.print_sig_headers(edf_file.getSignalHeaders())
        else:
            PSGDataProcessor.print_file_header(edf_file.getHeader())


    @staticmethod
    def print_label_and_freq(sig_freq):
        """
        Print signal labels and their corresponding sampling frequencies.

        Args:
        sig_freq (list of tuples): List containing signal labels and frequencies.
        """
        print("Signal Labels | Sampling Frequencies")
        print("------------------------------------")
        for label, freq in sig_freq:
            print(f"{label} | {freq}")

    @staticmethod
    def print_sig_headers(signal_headers):
        """
        Print all headers for each signal.

        Args:
        signal_headers (list): List containing headers of each signal.
        """
        for i, header in enumerate(signal_headers):
            print(f"Signal {i+1}:")
            print("Field Name | Value")
            print("------------------")
            for field_name, value in header.items():
                print(f"{field_name} | {value}")
            print("\n")

    @staticmethod
    def print_file_header(file_header):
        """
        Print the main header of the EDF file.

        Args:
        file_header (dict): Header information of the EDF file.
        """
        print("Field Name | Value")
        print("------------------")
        for field_name, value in file_header.items():
            print(f"{field_name} | {value}")
    
    def get_datetime_from_info(self, meas_date):
        """
        Convert an MNE measurement date to a timezone-naive datetime on the UTC clock.

        MNE's info['meas_date'] may be a datetime, None (no date in the file),
        or - in older MNE releases - a (seconds, microseconds) pair counted
        from the Unix epoch. Every form is normalised to the same clock (naive
        UTC) so the branches agree with each other.

        Args:
        meas_date (datetime, tuple/list/np.ndarray, or None): The measurement
            date from MNE info.

        Returns:
        datetime or None: A naive datetime, or None when meas_date is None.

        Raises:
        TypeError: if meas_date has an unsupported type.
        """
        from datetime import timedelta, timezone

        if meas_date is None:
            # No measurement date recorded; let callers decide how to proceed.
            return None
        if isinstance(meas_date, datetime):
            if meas_date.tzinfo is not None:
                # Convert rather than merely strip, so non-UTC offsets land on
                # the UTC clock.
                meas_date = meas_date.astimezone(timezone.utc)
            return meas_date.replace(tzinfo=None)
        if isinstance(meas_date, (tuple, list, np.ndarray)):
            seconds = float(meas_date[0])
            microseconds = float(meas_date[1]) if len(meas_date) > 1 else 0.0
            # Plain epoch arithmetic: datetime.fromtimestamp would apply the
            # machine's local timezone (disagreeing with the branch above) and
            # can fail for pre-1970 values on some platforms.
            return datetime(1970, 1, 1) + timedelta(seconds=seconds, microseconds=microseconds)
        raise TypeError(f"Unsupported meas_date type: {type(meas_date).__name__}")
    
    def extract_segment_by_timestamp(self, start_datetime, end_datetime, data_types):
        """
        Extract specific types of data within a specified time range defined by timestamps.
        """
        start_idx = int((start_datetime - self.start_datetime).total_seconds() * self.sampling_rate)
        end_idx = int((end_datetime - self.start_datetime).total_seconds() * self.sampling_rate)
        return self.extract_data_indices(start_idx, end_idx, data_types)

    def extract_data_indices(self, start_idx, end_idx, data_types):
        """
        Extract specific types of data within a specified index range.
        """
        extracted_data = {}
        for data_type in data_types:
            if data_type in self.ch_names:
                data_array = np.array(self.data[data_type][0][0])
                extracted_data[data_type] = data_array[start_idx:end_idx]
            else:
                raise ValueError(f"Data type {data_type} not found in the dataset.")
        return extracted_data

    def plot_data(self, data, data_type, sampling_rate=None):
        """
        Plot one channel's samples against time (in seconds) with Matplotlib.

        Args:
        data (array-like): 1-D sequence of samples.
        data_type (str): Channel name, used for the legend and the title.
        sampling_rate (float, optional): Sampling rate in Hz. Defaults to
            self.sampling_rate (set by load_data).

        Raises:
        ValueError: if no sampling rate is given and none has been loaded.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if sampling_rate is None:
            raise ValueError("sampling_rate must be given when no data has been loaded.")
        # Sample k is recorded at k / fs, so the last one sits at (n - 1) / fs;
        # np.linspace(0, n / fs, n) would stretch the axis by n / (n - 1).
        time_axis = np.arange(len(data)) / sampling_rate
        plt.figure(figsize=(10, 4))
        plt.plot(time_axis, data, label=data_type)
        plt.xlabel('Time (seconds)')
        plt.ylabel('Amplitude')
        plt.title(f'{data_type} Data Plot')
        plt.legend()
        plt.grid(True)
        plt.show()
        
    def compare_plot(self, data_dict, channel_names, sampling_rate=None):
        """
        Plot several channels in vertically stacked subplots for comparison.

        Args:
        data_dict (dict): Maps channel name -> 1-D array of samples, e.g. the
            dict returned by extract_data_indices.
        channel_names (list of str): Channels to plot, one subplot each,
            top to bottom.
        sampling_rate (float, optional): Sampling rate in Hz. Defaults to
            self.sampling_rate (set by load_data).

        A channel missing from data_dict is reported with a printed message
        and its subplot slot is left empty. Nothing is drawn when
        channel_names is empty.

        Raises:
        ValueError: if no sampling rate is given and none has been loaded.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if sampling_rate is None:
            raise ValueError("sampling_rate must be given when no data has been loaded.")
        num_channels = len(channel_names)
        if num_channels == 0:
            return
        plt.figure(figsize=(10, 4 * num_channels))

        for i, channel in enumerate(channel_names):
            if channel not in data_dict:
                print(f"Data for {channel} not found in the provided data dictionary.")
                continue
            samples = data_dict[channel]
            ax = plt.subplot(num_channels, 1, i + 1)
            # Sample k is recorded at k / fs; np.linspace(0, n / fs, n) would
            # stretch the axis by n / (n - 1).
            time_axis = np.arange(len(samples)) / sampling_rate
            ax.plot(time_axis, samples, label=channel)
            ax.set_xlabel('Time (seconds)')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'{channel} Data Plot')
            ax.legend()
            ax.grid(True)

        plt.tight_layout()
        plt.show()

    def ecg_diagram(self, ecg_slice):
        """
        Process an ECG slice with NeuroKit2 and plot the results.

        The slice is run through nk.ecg_process at self.sampling_rate and the
        output is shown with nk.ecg_plot. The detected R-peaks
        (info["ECG_R_Peaks"]) are then marked on the cleaned signal
        (signals["ECG_Clean"]) with nk.events_plot.

        Args:
        ecg_slice (np.array): The slice of ECG data to process.
        """
        fs = self.sampling_rate
        ecg_signals, ecg_info = nk.ecg_process(ecg_slice, sampling_rate=fs)
        nk.ecg_plot(ecg_signals, ecg_info)
        nk.events_plot(ecg_info["ECG_R_Peaks"], ecg_signals["ECG_Clean"])
        plt.show()
        
    def rsp_diagram(self, rsp_slice, report="text"):
        """
        Process a respiratory (RSP) slice with NeuroKit2 and plot the results.

        The slice is run through nk.rsp_process at self.sampling_rate and the
        output is shown with nk.rsp_plot. The detected respiratory peaks
        (info["RSP_Peaks"]) are then marked on the cleaned signal
        (rsp_signals["RSP_Clean"]) with nk.events_plot.

        Args:
        rsp_slice (np.array): The slice of RSP data to process.
        report (str or None): Forwarded unchanged to nk.rsp_process's `report`
            argument; the default "text" matches the previous hard-coded
            value. See the NeuroKit2 docs for the accepted values.
        """
        rsp_signals, info = nk.rsp_process(rsp_slice, sampling_rate=self.sampling_rate, report=report)

        # Overview plot of the processed respiration signal.
        nk.rsp_plot(rsp_signals, info)

        # Mark the respiratory peaks (not ECG R-peaks) on the cleaned signal.
        cleaned_rsp = rsp_signals["RSP_Clean"]
        peaks = info["RSP_Peaks"]
        nk.events_plot(peaks, cleaned_rsp)
        plt.show()
    