In [1]:
import sys
%matplotlib inline
# !{sys.executable} -m pip install --upgrade pip
# !{sys.executable} -m pip install GPy
# !{sys.executable} -m pip install seaborn
# !{sys.executable} -m pip install ipywidgets

In [2]:
import numpy as np
import os
import matplotlib.pyplot
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker
import matplotlib.dates as mdates
import datetime
from tqdm import tqdm
import GPy
from collections import defaultdict
from pathlib import Path
import seaborn as sns
import scipy.stats as stats
from matplotlib.colors import ListedColormap
import warnings
import time
from itertools import product
from joblib import Parallel, delayed

# Global Matplotlib styling, applied once for every figure in the notebook.
# mpl.rcParams['figure.dpi'] = 300
# mpl.rcParams['axes.spines.right'] = False
# mpl.rcParams['axes.spines.top'] = False

# ``mpl.rcParams`` and ``plt.rcParams`` are the same object, so one update
# covers everything.  The serif family is the one that ends up active; the
# sans-serif font list is still recorded for anyone switching families later.
plt.rcParams.update({
    "legend.frameon": False,
    "figure.autolayout": True,
    "text.usetex": True,
    "font.sans-serif": ["Helvetica"],
    "font.family": "serif",
    "font.serif": ["Palatino"],
})


def utkarshGrid():
    """Draw dotted major and minor grid lines on the current pyplot axes.

    Minor ticks are switched on first; the minor grid is then drawn in grey
    and the major grid in black, both dotted and with the same line width.
    """
    plt.minorticks_on()
    # (tick level, colour) pairs, applied in this order: minor, then major.
    for level, shade in (("minor", "grey"), ("major", "black")):
        plt.grid(color=shade, which=level, linestyle=":", linewidth='0.1')

In [3]:
import random
# Seed Python's built-in ``random`` module so draws made through it repeat
# between runs.
# NOTE(review): this call does not seed NumPy's generator; confirm whether any
# NumPy-based randomness used later also needs an explicit seed.
random.seed(3)

def run_tests(bol):
    """Run the doctests of the current module when ``bol`` is truthy.

    When enabled, prints the ``doctest.testmod()`` summary followed by the
    elapsed wall-clock time rounded to whole seconds.  When disabled, prints a
    single status line and does nothing else.
    """
    if not bol:
        print("[STATUS] Skipping Tests")
        return

    import doctest
    started = time.time()
    print(doctest.testmod())
    elapsed = round(time.time() - started)
    print(f"Time Taken: {elapsed}s")

In [4]:
class AllData:
    """Locate the raw light-curve files and build a reference table from their names.

    Typical use is ``load_path`` -> ``load_raw_data`` -> ``process`` ->
    ``save_reference``.  A later session can skip the directory scan and call
    ``load_reference`` on the saved CSV instead.
    """
    def __init__(self):
        self.folder_name = "/bns_m3_3comp" # Change folder name of data as required. 
    
    def load_path(self, path_to_dir):
        """ Record the user-defined base directory.

        ``path_to_dir`` may be a ``str`` or a ``pathlib.Path``.  It is stored
        unchanged in ``self.folder_path``; ``self.path`` is the base directory
        followed by ``self.folder_name``.
        """
        self.folder_path = path_to_dir
        # str() lets callers pass a Path object as well as a plain string.
        self.path = str(path_to_dir) + self.folder_name
        return None
    
    def load_raw_data(self):
        """ Lists the data files found under ``self.path`` into ``self.raw_data``.

        Only file names are collected (column ``file_name``); file contents are
        not read.  Sub-directories and ``.DS_Store`` entries are skipped.  Rows
        keep the order in which the directory listing yields them, which may
        differ between operating systems.
        >>> data = AllData()
        >>> data.load_path("/Users/utkarsh/PycharmProjects/SURP2021")
        >>> data.load_raw_data()
        >>> print(data.raw_data.file_name.iloc[0])
        nph1.0e+06_mejdyn0.001_mejwind0.130_phi45.txt
        >>> data.raw_data.file_name.iloc[192] == "nph1.0e+06_mejdyn0.001_mejwind0.090_phi0.txt"
        True
        >>> data.raw_data.file_name.iloc[192] == "nph1.0e+06_mejdyn0.005_mejwind0.110_phi0.txt"
        False
        >>> data.raw_data.file_name.iloc[192] == "nph1.0e+06_mejdyn0.005_mejwind0.110_phi0.txt"
        False
        """
        folder_path = Path(self.path)
        # Names are all that is needed, so nothing is opened.  Filtering on
        # is_file() keeps stray directories (for example Jupyter checkpoint
        # folders) from aborting the scan.
        names = [
            entry.name
            for entry in folder_path.iterdir()
            if entry.is_file() and entry.name != ".DS_Store"
        ]
        # Building from an explicit column keeps ``file_name`` present even
        # when the folder holds no data files.
        self.raw_data = pd.DataFrame({"file_name": names})
        return None
        
        
    def process(self):
        """ Processes the data to a readable reference dataframe.

        Each file name is split on ``_`` into four parts, which become the
        numeric columns ``nph``, ``mejdyn``, ``mejwind`` and ``phi``; the
        untouched file name is kept in the ``filename`` column.  The result is
        stored in ``self.reference_data``.
        >>> data = AllData()
        >>> data.load_path("/Users/utkarsh/PycharmProjects/SURP2021")
        >>> data.load_raw_data()
        >>> data.process()
        >>> print(data.reference_data.mejwind.iloc[68])
        0.03
        >>> print(data.reference_data.mejdyn.iloc[173])
        0.02
        >>> data.reference_data.phi.iloc[55] == 75
        False
        >>> data.reference_data.phi.iloc[57] == 75
        False
        >>> data.reference_data.phi.iloc[56] == 75
        True
        """
        temp_df = self.raw_data.file_name.str.split("_", expand=True)
        temp_df["file_name"] = self.raw_data.file_name
        temp_df.columns = ["nph", "mejdyn", "mejwind", "phi", "filename"]

        # Raw string: a plain string literal containing a backslash-d escape
        # triggers an invalid-escape warning on recent Python versions.
        number_pattern = r"(\d*\.?\d+)"
        for column in ("mejdyn", "mejwind", "phi"):
            extracted = temp_df[column].str.extract(number_pattern, expand=False)
            temp_df[column] = pd.to_numeric(extracted)

        # The first part looks like "nph1.0e+06": drop the three-letter prefix.
        temp_df["nph"] = temp_df["nph"].apply(lambda x: float(x[3:]))
        self.reference_data = temp_df.reset_index(drop=True)
        return None        
    
    def save_reference(self):
        """ Saves the reference data to ``reference.csv`` in the working directory.

        Best effort: any failure is reported with a status line rather than
        raised.
        """
        try:
            self.reference_data.to_csv("reference.csv", index = False)
            print("[STATUS] Reference Saved")
        except Exception:
            print("[ERROR] Reference Unsaved")
    
    def load_reference(self, name):
        """ Loads the saved dataframe to save on computing time.

        ``name`` is the path of a CSV previously written by ``save_reference``.
        >>> data = AllData()
        >>> data.load_reference("reference.csv")
        >>> print(data.reference_data.mejwind.iloc[68])
        0.03
        >>> print(data.reference_data.mejdyn.iloc[173])
        0.02
        >>> data.reference_data.phi.iloc[55] == 75
        False
        >>> data.reference_data.phi.iloc[57] == 75
        False
        >>> data.reference_data.phi.iloc[56] == 75
        True
        """
        self.reference_data = pd.read_csv(name)
        
            

In [5]:
class LightCurve():
    """ The information regarding KNe light curves and data corresponding to KNe light curves.
    """
    
    def __init__(self, referenceName):
        """ Initializes class, reference is all the light curves, and selected represents ones of interest to be narrowed. 
        """
        self.reference = pd.read_csv(referenceName)
        self.selected = self.reference.copy()
        self.uBand = 365 
        self.bBand = 445
        self.gBand = 464
        self.vBand = 551
        self.rBand = 658
        self.iBand = 806
        self.zBand = 900
        self.yBand = 1020
        self.jBand = 1220
        self.hBand = 1630
        self.kBand = 2190
        self.lBand = 3450
        self.mBand = 4750
        self.nBand = 10500
        self.qBand = 21000
        self.temp_path = "/Users/utkarsh/PycharmProjects/SURP2021"
        
        warnings.filterwarnings( action='ignore', module='matplotlib.figure', category=UserWarning, 
                                message=('This figure includes Axes that are not compatible with tight_layout, '
             'so results might be incorrect.'))

    def _slice(self, typ, Min, Max):
        sliced = self.selected[self.selected[typ] >= Min]
        sliced2 = sliced[sliced[typ] <= Max]
        return sliced2 
    
    def select_curve(self, phiRange = [], mejdynRange = [], mejwindRange = [], nphRange = []):
        """ Select a measurment based on the physics limits required. 
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [30]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> print(data.selected.filename.iloc[0])
        nph1.0e+06_mejdyn0.010_mejwind0.110_phi30.txt
        """
        self.phi_range_single = phiRange
        self.mejdyn_range_single = mejdynRange
        self.mejwind_range_single = mejwindRange
        self.nph_range_single = nphRange
        if len(nphRange) > 0:
            self.selected = self._slice("nph", min(nphRange), max(nphRange))
        if len(phiRange) > 0:
            self.selected = self._slice("phi", min(phiRange), max(phiRange))
        if len(mejdynRange) > 0:
            self.selected = self._slice("mejdyn", min(mejdynRange), max(mejdynRange))   
        if len(mejwindRange) > 0:
            self.selected = self._slice("mejwind", min(mejwindRange), max(mejwindRange))   
        return None
    
    def _set_path(self, path):
        """ Sets the path to the file to be extracted. Chooses first file if there are many. 
        >>> data = LightCurve("reference.csv")
        >>> data._set_path("")
        [WARNING] Many curves in data: First curve has been selected. 
        [CURVE] nph1.0e+06_mejdyn0.001_mejwind0.130_phi45.txt
        >>> print(data.path)
        /bns_m3_3comp/nph1.0e+06_mejdyn0.001_mejwind0.130_phi45.txt
        """
        self.temp_path = path
        self.folder_path = path + "/bns_m3_3comp/"
        self.path = self.folder_path + self.selected.filename.iloc[0]
        if len(self.selected.filename) > 1:
            print(f"[WARNING] Many curves in data: First curve has been selected. \n[CURVE] {self.selected.filename.iloc[0]}")
        
        return None
    
    def extract_curve(self):
        """ Extracts curve based on selected data and converts it into a readable format. 

        Reads the file of the first selected row (path built by ``_set_path``
        from ``self.temp_path``) and stores:

        * ``self.Nobs``  - first line of the file, as an int;
        * ``self.Nwave`` - second line of the file, as a float;
        * ``self.Ntime`` - third line of the file, split into floats;
        * ``self.time_arr`` - ``int(Ntime[0])`` evenly spaced values running
          from ``int(Ntime[1])`` to ``int(Ntime[2])`` inclusive;
        * ``self.curve`` - a frame indexed by ``iobs`` with one column per
          wavelength, where each cell is the list of values found on the
          matching line of the file.

        If ``self.selected`` is empty an error line is printed and the method
        returns without changing any of the attributes above.
        NOTE(review): the layout of the remaining lines (wavelength first,
        then one value per time step, repeated per observer) is inferred from
        the parsing below - confirm against the data format specification.
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        >>> data.curve.shape
        (11, 500)
        >>> zBand = 910
        >>> plotDf = data.curve.loc[:, [zBand]]
        >>> print(plotDf.loc[1,zBand][3])
        0.0028678
        >>> print(data.selected.filename.iloc[0])
        nph1.0e+06_mejdyn0.020_mejwind0.110_phi60.txt
        >>> data.curve.shape
        (11, 500)
        >>> data.Nobs
        11
        >>> data.Nwave
        500.0
        >>> data.Ntime
        [100.0, 0.0, 20.0]
        >>> data.time_arr[13]
        2.6262626262626263
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [61]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        [ERROR] Selected dataframe is empty! Data Compromised.
        """
        
        # Nothing matched the selection: report it and leave earlier results untouched.
        if self.selected.empty:
            print('[ERROR] Selected dataframe is empty! Data Compromised.')
            return None
               
        # Obtain the path to read the curve from.
        self._set_path(self.temp_path)
        
        # Read the txt file containing the light-curve information; every line
        # lands in a single column named "data".
        temp0 = pd.read_csv(self.path, header = None, names = ["data"])
        
        # The first three lines are a header: number of observers, number of
        # wavelengths, and the time grid (count, start, stop).
        self.Nobs = int(temp0.data.iloc[0])
        self.Nwave = float(temp0.data.iloc[1])
        self.Ntime = list(map(float, temp0.data.iloc[2].split()))
        # Start and stop are truncated to integers before the grid is built.
        self.time_arr = np.linspace(int(self.Ntime[1]), int(self.Ntime[2]), int(self.Ntime[0]), endpoint = True)
        
        # Drop the header lines and reset the index.
        temp1 = temp0.iloc[3:].reset_index(drop = True)
        
        # Convert each remaining line from a string to a list of floats.
        temp1["data"] = temp1["data"].apply(lambda x: list(map(float, x.split())))
        
        # The first number on each line is the wavelength; it is divided by 10
        # (the original note says this converts it to nm).
        temp1.loc[:, 'wavelength'] = temp1.data.map(lambda x: x[0]/10)
        
        # Remove the wavelength from the data vector, leaving only the values.
        temp1["data"] = temp1["data"].apply(lambda x: x[1:])
        
        # Pivot so that each wavelength becomes a column; every row then has a
        # value in exactly one column and NaN elsewhere.
        temp1 = temp1.pivot(columns = "wavelength", values = "data")
        
        # Compact each column by dropping its NaN entries and renumbering, so
        # row k holds the k-th occurrence of every wavelength.
        final = pd.concat([temp1[col].dropna().reset_index(drop=True) for col in temp1], axis=1)
        
        # Rename axis titles. 
        final.index.name = "iobs"
        final.columns.name = "wavelength"
        self.curve = final
        
        return None
    
    def _odd(self,x):
        """Rounds to nearest odd numbers
        >>> data = LightCurve("reference.csv")
        >>> data._odd(3)
        3
        >>> data._odd(2.5)
        3
        >>> data._odd(2)
        3
        >>> data._odd(1.999)
        1
        """
        return 2 * int(x/2) + 1
    
    def simple_plot(self, wv):
        """Plots one line per observer index at a single wavelength.

        ``wv`` is snapped with ``_compute_wavelength`` and stored in
        ``self.wavelength``; ``self.time_arr`` is rebuilt from ``self.Ntime``.
        Requires ``extract_curve`` to have been run first.
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        >>> data.simple_plot(900)
        [STATUS] Plotting...
        >>> data.time_arr[11]
        2.2222222222222223
        >>> data.wavelength
        910
        >>> plt.show() #doctest: +SKIP
        >>> plt.close()
        """
        print("[STATUS] Plotting...")
        first, last, count = int(self.Ntime[1]), int(self.Ntime[2]), int(self.Ntime[0])
        self.time_arr = np.linspace(first, last, count, endpoint = True)
        self.wavelength = self._compute_wavelength(wv)

        # Evenly spaced values in [0, 1]; their arccos gives the angle shown in the legend.
        cos_values = np.linspace(0, 1, self.Nobs, endpoint = True)
        plt.figure(dpi = 300)
        plt.gca().set_prop_cycle("color", sns.color_palette("coolwarm_r",self.Nobs))

        band = self.curve.loc[:, [self.wavelength]]
        for iobs, cell in band.iterrows():
            angle = round(np.degrees(np.arccos(cos_values[iobs])), 2)
            plt.plot(self.time_arr, cell.values[0], label = f"{angle}"r"$^o$", linewidth = 1)

        plt.xlabel("Time (Days)")
        plt.ylabel(r"Flux $Erg s^{-1} cm^{-2}A^{-1}$")
        plt.title(f"Lights curves for {self.Nobs} viewing angles at {self.wavelength}nm")
        utkarshGrid()
        plt.legend(title = r"$\Phi$")
        return None
    
    def _compute_wavelength(self,wv):
        """Wavelength helper function. Rounds value to nearest odd 10. 
        >>> data = LightCurve("reference.csv")
        >>> data._compute_wavelength(900)
        910
        >>> data._compute_wavelength(899.99)
        890
        """
        return 10*self._odd(wv/10)
    
    def plot_viewingangle_simple(self):
        """Plots every observer's light curve for eight bands on one log-scaled figure.

        Curves are coloured by band (u, g, r, i, z, y, j, h); only the first
        curve of each band carries a legend label.  ``self.time_arr`` is
        rebuilt from ``self.Ntime`` and ``self.wavelength`` is left at the
        last band drawn.  Requires ``extract_curve`` to have been run first.
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        >>> data.plot_viewingangle_simple()
        [STATUS] Plotting...
        >>> plt.show() #doctest: +SKIP
        >>> plt.close()
        """
        wvList = [self.uBand, self.gBand, self.rBand, self.iBand, 
                  self.zBand, self.yBand, self.jBand, self.hBand]

        plt.figure(dpi=300)
        print("[STATUS] Plotting...")

        # Neither the time grid nor the palette depends on the band, so both
        # are built once instead of on every pass of the loop.
        self.time_arr = np.linspace(int(self.Ntime[1]), int(self.Ntime[2]), int(self.Ntime[0]), endpoint = True)
        colors = sns.color_palette("coolwarm_r",len(wvList))[::-1]

        for k, wv in enumerate(wvList):
            self.wavelength = self._compute_wavelength(wv)
            for i,j in self.curve.loc[:, [self.wavelength]].iterrows():
                # An empty label keeps the remaining curves out of the legend.
                labelStr = f"{self.wavelength}nm" if i == 0 else ""
                plt.plot(self.time_arr, j.values[0], label = labelStr, 
                         linewidth = 1, color = colors[k])

        plt.xlabel("Time (Days)")
        plt.ylabel(r"Log Flux $Erg s^{-1} cm^{-2}A^{-1}$")
        plt.legend(title = r"$\lambda$", ncol=2, loc = "upper right")
        plt.yscale("log")
        plt.title(f"Lights curves for {self.Nobs} viewing angles at varying wavelengths")
        return None

    
    def plot_viewingangle(self): 
        """ Plots LightCurve according to viewing angle as mutiple subplots.  

        Draws a 2 x 4 grid, one panel per band (u, g, r, i, z, y, j, h), with
        one curve per observer index.  Each curve is coloured by its angle,
        ``degrees(arccos(c))`` for ``c`` evenly spaced in [0, 1], using the
        same colour mapping as the colour bar so the two agree.  Sets
        ``self.iobs_range`` (the evenly spaced values), ``self.wavelength``
        (last band drawn) and ``self.counter_final`` (number of panels).
        Requires ``select_curve`` and ``extract_curve`` to have been run; the
        colour-bar tick formatter ``self._fmt_degree`` is defined elsewhere on
        the class.
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        >>> data.plot_viewingangle()
        [STATUS] Plotting for nph: [], mejdyn: [0.02], mejwind: [0.11], phi: [60], viewing_angle: [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
        >>> data.uBand
        365
        >>> data.counter_final
        8
        >>> plt.show() #doctest: +SKIP
        >>> plt.close()
        """
        fig, axes = plt.subplots(nrows = 2, ncols = 4, dpi=300, figsize = (6,3.5))
        plt.tight_layout()
        viewing_angles = np.linspace(0, 1, self.Nobs, endpoint = True)

        # One mappable drives both the curve colours and the colour bar.
        # Sampling the colormap outside [0, 1] clips to its end colours, which
        # would give several curves the same colour.
        my_cmap = "RdBu"
        sm = plt.cm.ScalarMappable(cmap=my_cmap, norm=plt.Normalize(vmin=0, vmax=90))
        colors = sm.to_rgba(np.degrees(np.arccos(viewing_angles)))

        wvList = [self.uBand, self.gBand, self.rBand, self.iBand, 
                  self.zBand, self.yBand, self.jBand, self.hBand]
        namesList = ["uBand", "gBand", "rBand", "iBand", "zBand", "yBand", "jBand", "hBand"]
        ticks = np.arange(min(self.time_arr), max(self.time_arr)+1, 5)
        self.iobs_range = viewing_angles

        counter = 0
        for row in range(0, 2):
            for col in range(0,4):
                ax = axes[row,col]
                self.wavelength = self._compute_wavelength(wvList[counter])

                for i,j in self.curve.loc[:, [self.wavelength]].iterrows():
                    # Only the first curve is labelled, with the band letter;
                    # the legend then acts as a panel title.
                    labelStr = f"{namesList[counter][0].upper()}" if i == 0 else ""
                    ax.plot(self.time_arr, j.values[0], label = labelStr, 
                            linewidth = 1, color = colors[i])

                ax.set_yscale('log')
                ax.legend(handletextpad=-2.0, handlelength=0)
                ax.get_yaxis().set_visible(False)
                ax.set_xticks(ticks)

                counter += 1
        
        self.counter_final = counter
        bottom, top = 0.1, 0.9
        left, right = 0.2, 0.8

        fig.subplots_adjust(top=top, bottom=bottom, left=left, right=right, hspace=0.15, wspace=0.25)
        # Dedicated axes to the right of the panels for the colour bar.
        cbar_ax = fig.add_axes([1, bottom, 0.03, top-bottom])
        cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), extend='both', cax = cbar_ax, format=mpl.ticker.FuncFormatter(self._fmt_degree))
        cbar.set_label(r"Viewing Angle $\theta_0$", size=15, labelpad=10)
        fig.text(0.5, 0, "Time since merger (Days)", ha='center')
        fig.text(-0.01, 0.5, r"Log Flux $Erg s^{-1} cm^{-2}A^{-1}$", va='center', rotation='vertical')
        print(f"[STATUS] Plotting for nph: {self.nph_range_single}, mejdyn: {self.mejdyn_range_single}, mejwind: {self.mejwind_range_single}, phi: {self.phi_range_single}, viewing_angle: {self.iobs_range}")
        return None  
    
    def select_viewingangle(self, phi_range, mejdyn_range, mejwind_range, wv = 0):
        """ Trauncates the selected data (self.curve) by selected wavelength. 
        >>> data = LightCurve("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range, mejwindRange = mejwind_range)
        >>> data.extract_curve()
        >>> data.select_viewingangle(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.viewingangle.iloc[0][3]
        2.6453e-06
        >>> data.viewingangle.iloc[3][0]
        0.0019884
        >>> data.viewingangle.iloc[3][3]
        0.0029003
        >>> data.viewingangle.shape
        (100, 12)
        >>> data.mejdyn_range
        [0.02]
        >>> data.wv_range
        900
        >>> data.phi_range
        [60]
        >>> data.mejwind_range
        [0.11]
        """
        self.mejdyn_range = mejdyn_range
        self.wv_range = wv
        self.phi_range = phi_range
        self.mejwind_range = mejwind_range
        
        wv = self._compute_wavelength(wv)
        self.select_curve(phiRange = phi_range, 
                          mejdynRange = mejdyn_range, 
                          mejwindRange = mejwind_range)
        self.extract_curve()
        
        if wv>0:
            z = self.curve.T[self.curve.T.index == wv]
            z = z.reset_index(drop = True)
            z = z.apply(pd.Series.explode).reset_index(drop = True)
            z["time"] = self.time_arr
            z.index.name = "time_step"
            self.viewingangle = z 
        return None
    
    def select_mejdyn(self, wv_range, iobs_range, phi_range, mejwind_range):
        """Slice reference data set by dynamical ejecta mass (mejdyn). 
        >>> data = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> phi_range = [45]
        >>> mejwind_range = [0.13]
        >>> data.select_mejdyn(wv_range, iobs_range, phi_range, mejwind_range)
        >>> data.mejdyn.iloc[0,3]
        2.8034e-06
        >>> data.mejdyn.iloc[44,2]
        0.00019771
        >>> data.mejdyn.iloc[65,3]
        4.5352e-05
        >>> data.mejdyn.shape
        (100, 4)
        >>> data.phi_range
        [45]
        >>> data.mejdyn_range
        [0.001, 0.005, 0.01, 0.02]
        >>> data.iobs_range
        [0]
        """
        mejdyn_range_list = self.reference.mejdyn.unique()
        
        self.mejdyn_range = sorted(mejdyn_range_list)
        self.wv_range = wv_range
        self.iobs_range = iobs_range
        self.phi_range = phi_range
        self.mejwind_range = mejwind_range
        
        
        mejdyn_range_list = [[x] for x in mejdyn_range_list]
        df = pd.DataFrame()
        self.select_curve(phiRange = phi_range, mejwindRange = mejwind_range)
        tempReference = self.selected

        for i in range(len(mejdyn_range_list)):
            self.select_curve(mejdynRange = [self.mejdyn_range[i]])
            self.extract_curve()
            arr = self.curve[self._compute_wavelength(wv_range[0])][iobs_range[0]] # at wv 900, iobs0, over TIME STEP, for mejdyn 0.01
            assert len(arr) == self.Ntime[0]
            df[self.mejdyn_range[i]] = arr
            self.selected = tempReference

        df.columns.name = "mejdyn"
        df.index.name = "time_step"
        self.mejdyn = df
        self.mejdyn = self.mejdyn.sort_index(axis=1)
        return None
    
    def plot_mejdyn(self, verbose = False):
        """ Plots the data acquired in the dynamical ejecta mass. 

        Draws a 2 x 4 grid, one panel per band (u, g, r, i, z, y, j, h).  For
        each band ``select_mejdyn`` is re-run with the stored ``iobs_range``,
        ``phi_range`` and ``mejwind_range``, and one curve per ``mejdyn``
        value is drawn.  Curve colours come from the same colour mapping as
        the colour bar, so each curve sits at its own mass on the bar even
        when the masses are unevenly spaced.  ``verbose`` turns on a progress
        bar over the panel rows.  ``select_mejdyn`` must have been called
        beforehand; the tick formatter ``self._fmt_solarmass`` is defined
        elsewhere on the class.
        >>> data = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> phi_range = [45]
        >>> mejwind_range = [0.13]
        >>> data.select_mejdyn(wv_range, iobs_range, phi_range, mejwind_range)
        >>> data.plot_mejdyn() #doctest: +SKIP
        """
        numRows = 2
        numCols = 4
        fig, axes = plt.subplots(nrows = numRows, ncols = numCols, dpi=300, figsize = (6,3.5))
        plt.tight_layout()

        # One mappable drives both the curve colours and the colour bar.
        my_cmap = "PiYG"
        sm = plt.cm.ScalarMappable(cmap=my_cmap, norm=plt.Normalize(vmin=min(self.mejdyn_range), vmax=max(self.mejdyn_range)))
        colors = sm.to_rgba(np.asarray(self.mejdyn_range, dtype=float))

        counter = 0
        ticks = np.arange(min(self.time_arr), max(self.time_arr)+1, 5)
        wvList = [self.uBand, self.gBand, self.rBand, self.iBand, 
                          self.zBand, self.yBand, self.jBand, self.hBand]
        namesList = ["uBand", "gBand", "rBand", "iBand", "zBand", "yBand", "jBand", "hBand"]

        for row in tqdm(range(0, numRows), disable= not verbose):
            for col in range(0,numCols):
                ax = axes[row, col]
                wv = wvList[counter]
                # Re-read the curves for this band; refreshes self.mejdyn.
                self.select_mejdyn([wv], self.iobs_range, self.phi_range, self.mejwind_range)
                self.wavelength = self._compute_wavelength(wv)

                for i, mass in enumerate(self.mejdyn_range):
                    # Only the first curve is labelled, with the band letter;
                    # the legend then acts as a panel title.
                    labelStr = f"{namesList[counter][0].upper()}" if i == 0 else ""
                    ax.plot(self.time_arr, self.mejdyn[mass], label = labelStr, 
                            linewidth = 1, color = colors[i])
                ax.set_yscale('log')
                ax.legend(handletextpad=-2.0, handlelength=0)
                ax.get_yaxis().set_visible(False)
                ax.set_xticks(ticks)

                counter += 1

        self.counter_final = counter
        bottom, top = 0.1, 0.9
        left, right = 0.2, 0.8

        fig.subplots_adjust(top=top, bottom=bottom, left=left, right=right, hspace=0.15, wspace=0.25)
        # Dedicated axes to the right of the panels for the colour bar.
        cbar_ax = fig.add_axes([1, bottom, 0.03, top-bottom])
        cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), extend='both', cax = cbar_ax, format=mpl.ticker.FuncFormatter(self._fmt_solarmass))
        cbar.set_label(r"Dynamical Ejecta Mass ($M_{ej}$)", size=15, labelpad=10)
        fig.text(0.5, 0, "Time since merger (Days)", ha='center')
        fig.text(-0.01, 0.5, r"Log Flux $Erg s^{-1} cm^{-2}A^{-1}$", va='center', rotation='vertical')
        print(f"[STATUS] Plotting for mejdyn: {self.mejdyn_range}, mejwind: {self.mejwind_range}, phi: {self.phi_range}, viewing_angle: {self.iobs_range}")
        
    def select_mejwind(self, wv_range, iobs_range, phi_range, mejdyn_range):
        """Slice reference data set by wind ejecta mass (mejwind). 
        >>> data = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> mejdyn_range = [0.01]
        >>> phi_range = [45]
        >>> data.select_mejwind(wv_range, iobs_range, phi_range, mejdyn_range)
        >>> data.mejwind.iloc[0,3]
        2.465e-06
        >>> data.mejwind.iloc[44,2]
        6.077e-05
        >>> data.mejwind.iloc[65,6]
        7.3339e-05
        >>> data.mejwind.shape
        (100, 7)
        >>> data.phi_range
        [45]
        >>> data.mejwind_range
        [0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13]
        >>> data.iobs_range
        [0]
        >>> data.mejdyn_range
        [0.01]
        """
        mejwind_range_list = self.reference.mejwind.unique()
        self.mejwind_range = sorted(mejwind_range_list)
        self.wv_range = wv_range
        self.iobs_range = iobs_range
        self.phi_range = phi_range
        self.mejdyn_range = mejdyn_range
        mejwind_range_list = [[x] for x in mejwind_range_list]
        df = pd.DataFrame()
        self.select_curve(phiRange = phi_range, mejdynRange = mejdyn_range)
        tempReference = self.selected

        for i in range(len(mejwind_range_list)):
            self.select_curve(mejwindRange = [self.mejwind_range[i]])
            self.extract_curve()
            arr = self.curve[self._compute_wavelength(wv_range[0])][iobs_range[0]] # at wv 900, iobs0, over TIME STEP, for mejdyn 0.01
            assert len(arr) == self.Ntime[0]
            df[self.mejwind_range[i]] = arr
            self.selected = tempReference

        df.columns.name = "mejwind"
        df.index.name = "time_step"
        self.mejwind = df
        self.mejwind = self.mejwind.sort_index(axis=1)
    
    def plot_mejwind(self, verbose = False):
        """ Plots the data acquired in the wind ejecta mass. 

        Draws a 2 x 4 grid, one panel per band (u, g, r, i, z, y, j, h).  For
        each band ``select_mejwind`` is re-run with the stored ``iobs_range``,
        ``phi_range`` and ``mejdyn_range``, and one curve per ``mejwind``
        value is drawn.  Curve colours come from the same colour mapping as
        the colour bar, so each curve sits at its own mass on the bar whatever
        the spacing of the masses.  ``verbose`` turns on a progress bar over
        the panel rows.  ``select_mejwind`` must have been called beforehand;
        the tick formatter ``self._fmt_solarmass`` is defined elsewhere on the
        class.
        >>> data = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> mejdyn_range = [0.01]
        >>> phi_range = [45]
        >>> data.select_mejwind(wv_range, iobs_range, phi_range, mejdyn_range)
        >>> data.plot_mejwind() #doctest: +SKIP
        """
        numRows = 2
        numCols = 4
        fig, axes = plt.subplots(nrows = numRows, ncols = numCols, dpi=300, figsize = (6,3.5))
        plt.tight_layout()

        # One mappable drives both the curve colours and the colour bar.
        my_cmap = "BrBG"
        sm = plt.cm.ScalarMappable(cmap=my_cmap, norm=plt.Normalize(vmin=min(self.mejwind_range), vmax=max(self.mejwind_range)))
        colors = sm.to_rgba(np.asarray(self.mejwind_range, dtype=float))

        counter = 0
        ticks = np.arange(min(self.time_arr), max(self.time_arr)+1, 5)
        wvList = [self.uBand, self.gBand, self.rBand, self.iBand, 
                          self.zBand, self.yBand, self.jBand, self.hBand]
        namesList = ["uBand", "gBand", "rBand", "iBand", "zBand", "yBand", "jBand", "hBand"]

        for row in tqdm(range(0, numRows), disable= not verbose):
            for col in range(0,numCols):
                ax = axes[row, col]
                wv = wvList[counter]
                # Re-read the curves for this band; refreshes self.mejwind.
                self.select_mejwind([wv], self.iobs_range, self.phi_range, self.mejdyn_range)
                self.wavelength = self._compute_wavelength(wv)

                for i, wind in enumerate(self.mejwind_range):
                    # Only the first curve is labelled, with the band letter;
                    # the legend then acts as a panel title.
                    labelStr = f"{namesList[counter][0].upper()}" if i == 0 else ""
                    ax.plot(self.time_arr, self.mejwind[wind], label = labelStr, 
                            linewidth = 1, color = colors[i])
                ax.set_yscale('log')
                ax.legend(handletextpad=-2.0, handlelength=0)
                ax.get_yaxis().set_visible(False)
                ax.set_xticks(ticks)

                counter += 1

        self.counter_final = counter
        bottom, top = 0.1, 0.9
        left, right = 0.2, 0.8

        fig.subplots_adjust(top=top, bottom=bottom, left=left, right=right, hspace=0.15, wspace=0.25)
        # Dedicated axes to the right of the panels for the colour bar.
        cbar_ax = fig.add_axes([1, bottom, 0.03, top-bottom])
        cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), extend='both', cax = cbar_ax, format=mpl.ticker.FuncFormatter(self._fmt_solarmass))
        cbar.set_label(r"Wind Ejecta Mass ($M_{ej}$)", size=15, labelpad=10)
        fig.text(0.5, 0, "Time since merger (Days)", ha='center')
        fig.text(-0.01, 0.5, r"Log Flux $Erg s^{-1} cm^{-2}A^{-1}$", va='center', rotation='vertical')
        print(f"[STATUS] Plotting for mejdyn: {self.mejdyn_range}, mejwind: {self.mejwind_range}, phi: {self.phi_range}, viewing_angle: {self.iobs_range}")
    
    def select_phi(self, wv_range, iobs_range, mejwind_range, mejdyn_range):
        """Tabulate the extracted curve against phi for one wavelength and viewing angle.

        ``select_curve`` is first called with the requested wind / dynamical
        ejecta masses and the resulting ``self.selected`` is remembered. Then,
        for every unique phi in ``self.reference`` (ascending), the curve is
        selected and extracted, and the array stored under the first entry of
        ``wv_range`` and of ``iobs_range`` becomes one column of ``self.phi``.
        ``self.selected`` is put back to the remembered selection after each
        extraction.

        Side effects: stores the four arguments as ``self.wv_range``,
        ``self.iobs_range``, ``self.mejwind_range`` and ``self.mejdyn_range``;
        sets ``self.phi_range`` to the sorted unique phi values and ``self.phi``
        to a DataFrame indexed by ``time_step`` with one column per phi.

        >>> data = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.13]
        >>> data.select_phi(wv_range, iobs_range, mejwind_range, mejdyn_range)
        >>> data.phi.iloc[0,3]
        1.203e-06
        >>> data.phi.iloc[44,2]
        0.00023286
        >>> data.phi.iloc[65,4]
        6.9335e-05
        >>> data.phi.shape
        (100, 7)
        >>> data.phi_range
        [0, 15, 30, 45, 60, 75, 90]
        >>> data.mejwind_range
        [0.13]
        >>> data.iobs_range
        [0]
        >>> data.mejdyn_range
        [0.01]
        """
        unique_phi = self.reference.phi.unique()
        self.phi_range = sorted(unique_phi)
        self.wv_range = wv_range
        self.iobs_range = iobs_range
        self.mejwind_range = mejwind_range
        self.mejdyn_range = mejdyn_range

        phi_table = pd.DataFrame()
        self.select_curve(mejwindRange = mejwind_range, mejdynRange = mejdyn_range)
        mass_selection = self.selected

        for index in range(len(unique_phi)):
            angle = self.phi_range[index]
            self.select_curve(phiRange = [angle])
            self.extract_curve()
            # Only the first requested wavelength and viewing angle are tabulated.
            wavelength_key = self._compute_wavelength(wv_range[0])
            flux = self.curve[wavelength_key][iobs_range[0]]
            assert len(flux) == self.Ntime[0]
            phi_table[angle] = flux
            # Undo the phi narrowing so the next angle starts from the mass selection.
            self.selected = mass_selection

        phi_table.columns.name = "phi"
        phi_table.index.name = "time_step"
        self.phi = phi_table.sort_index(axis=1)
    
    def plot_phi(self, verbose = False):
        """Plot the curves for every phi value, one panel per photometric band.

        Draws a 2x4 grid of panels for the u, g, r, i, z, y, j and h bands
        (wavelengths read from ``self.uBand`` ... ``self.hBand``). For each
        panel ``select_phi`` is re-run at that band's wavelength using the
        instance's current ``iobs_range``, ``mejwind_range`` and
        ``mejdyn_range``, so those must already be set (e.g. by an earlier
        ``select_phi`` call, as in the example below). Curves are coloured by
        phi with the ``PuOr`` colormap and share one colorbar; the y axis is
        logarithmic and hidden, and each panel is tagged with its band letter.

        The colour palette is built from the phi list produced by the
        ``select_phi`` call for the panel being drawn, so it always has one
        colour per plotted curve.

        Side effects: overwrites ``self.phi``, ``self.phi_range`` and
        ``self.wv_range`` (through ``select_phi``), sets ``self.wavelength``
        and ``self.counter_final``, and prints a status line.

        verbose: when True, show a tqdm progress bar over the panel rows.

        >>> dat = LightCurve("reference.csv")
        >>> wv_range = [900]
        >>> iobs_range = [0]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.13]
        >>> dat.select_phi(wv_range, iobs_range, mejwind_range, mejdyn_range)
        >>> dat.plot_phi() #doctest: +SKIP
        """
        numRows = 2
        numCols = 4
        fig, axes = plt.subplots(nrows = numRows, ncols = numCols, dpi=300, figsize = (6,3.5))
        plt.tight_layout()
        ticks = np.arange(min(self.time_arr), max(self.time_arr)+1, 5)
        wvList = [self.uBand, self.gBand, self.rBand, self.iBand,
                  self.zBand, self.yBand, self.jBand, self.hBand]
        namesList = ["uBand", "gBand", "rBand", "iBand", "zBand", "yBand", "jBand", "hBand"]
        counter = 0
        for row in tqdm(range(0, numRows), disable= not verbose):
            for col in range(0, numCols):
                ax = axes[row, col]
                wv = wvList[counter]
                self.select_phi([wv], self.iobs_range, self.mejwind_range, self.mejdyn_range)
                self.wavelength = self._compute_wavelength(wv)

                # Size the palette from the phi list select_phi just produced;
                # a phi_range left over from an earlier selection may be shorter.
                colors = plt.cm.PuOr(np.linspace(0, 1, len(self.phi_range)))
                bandLetter = namesList[counter][0].upper()

                for i, phi in enumerate(self.phi_range):
                    # Only the first curve carries a label so the legend shows
                    # the band letter exactly once.
                    labelStr = bandLetter if i == 0 else ""
                    ax.plot(self.time_arr, self.phi[phi], label = labelStr,
                            linewidth = 1, color = colors[i])
                ax.set_yscale('log')
                # Zero-length handle: the legend is used as a plain text tag.
                ax.legend(handletextpad=-2.0, handlelength=0)
                ax.get_yaxis().set_visible(False)
                ax.set_xticks(ticks)

                counter += 1
        self.counter_final = counter
        bottom, top = 0.1, 0.9
        left, right = 0.2, 0.8

        fig.subplots_adjust(top=top, bottom=bottom, left=left, right=right, hspace=0.15, wspace=0.25)
        my_cmap = "PuOr"
        sm = plt.cm.ScalarMappable(cmap=my_cmap, norm=plt.Normalize(vmin=min(self.phi_range), vmax=max(self.phi_range)))
        # Dedicated axes to the right of the grid for the shared colorbar.
        cbar_ax = fig.add_axes([1, bottom, 0.03, top-bottom])
        cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), extend='both', cax = cbar_ax, format=mpl.ticker.FuncFormatter(self._fmt_degree))
        cbar.set_label(r"Phi ($\Phi$)", size=15, labelpad=10)
        fig.text(0.5, 0, "Time since merger (Days)", ha='center')
        fig.text(-0.01, 0.5, r"Log Flux $Erg s^{-1} cm^{-2}A^{-1}$", va='center', rotation='vertical')
        print(f"[STATUS] Plotting for mejdyn: {self.mejdyn_range}, mejwind: {self.mejwind_range}, phi: {self.phi_range}, viewing_angle: {self.iobs_range}")
        
    def _fmt_degree(self, x, pos):
        """Formats the text given into scientific notation. Used as a helper function in plotting. 
        """

        return r'${}^o$'.format(round(x))
    
    def _fmt_solarmass(self, x, pos):
        """Formats the text given into scientific notation. Used as a helper function in plotting. 
        """
        a, b = '{:.2e}'.format(x).split('e')
        b = int(b)
        return r'${} \times 10^{{{}}} M_\odot$'.format(a, b)
    

In [6]:
class GP(LightCurve):
    """The Gaussian Process for KNe Light Curves using GPy. 
    """
    
    def __init__(self, referenceName):
        """Initialise the underlying LightCurve and the GP bookkeeping defaults.

        referenceName is forwarded unchanged to ``LightCurve.__init__``.

        >>> gp = GP("reference.csv")
        >>> isinstance(gp, GP)
        True
        >>> isinstance(gp, LightCurve)
        True
        >>> isinstance(gp, float)
        False
        """
        super().__init__(referenceName)
        # Option flags, all off by default.
        # NOTE(review): their consumers are not in this part of the file.
        self.empirical = False
        self.isLog = False
        self.normalizeX = False
        # Time passed to single_time_step by LOOCV and multiple_LOOCV.
        self.extraction_time = 1
        # NOTE(review): presumably a failure counter; confirm against its users.
        self.failed = 0
    
    def range_select_wavelength(self, phi_range, mejdyn_range, mejwind_range, wv):
        """ Trauncate selection of light curve by wavelength. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.viewingangle.iloc[0][3]
        2.6453e-06
        >>> data.viewingangle.iloc[3][0]
        0.0019884
        >>> data.viewingangle.iloc[3][3]
        0.0029003
        >>> data.viewingangle.shape
        (100, 12)
        """
        self.select_viewingangle(phi_range, mejdyn_range, mejwind_range, wv)
        return None

    def single_time_step(self, time_of_interest, delta = 0):
        """ Select single time step and save dataframe of viewing angles. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.time_sliced.shape
        (1, 11)
        >>> data.time_sliced.iloc[0][1]
        0.004539
        >>> data.single_time_step(2, delta = 2)
        >>> data.time_sliced.shape
        (5, 11)
        >>> data.time_sliced.iloc[1][1]
        0.0029853
        >>> data.time_sliced.iloc[0][1]
        0.0036462
        >>> data.time_sliced.iloc[1][0]
        0.0036834
        """
        time_of_int = time_of_interest
        time_ind = np.argmin(np.abs(self.time_arr-time_of_int)) 
        delt = delta
        day = self.viewingangle.iloc[time_ind - delt: time_ind + delt + 1] 
        del day["time"] # dont need time after choosing our time frame
        self.time_sliced = day
        return None
    
    def normedDF(self):
        """ Save the median normal of the dataframe. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.time_sliced_normed.shape
        (1, 11)
        >>> data.time_sliced_normed.iloc[0][1]
        -0.015038083458108309
        >>> data.time_sliced_normed.iloc[0][3]
        -0.07545081700410139
        >>> data.time_sliced_normed.iloc[0][6]
        0.12820345897619512
        """
        med = np.median(self.time_sliced)
        self.time_sliced_normed = self.time_sliced.divide(med) - 1
        return None
    
    def _normedArr(self, arr):
        """ Returns the median normal of the array. 
        >>> data = GP("reference.csv")
        >>> newArr = np.array([0,1,2,3,4,5,6,7,8,9])
        >>> newArr2 = data._normedArr(newArr)
        >>> newArr3 = np.array([-1. , -0.77777778, -0.55555556, -0.33333333, -0.11111111,\
        0.11111111,  0.33333333,  0.55555556,  0.77777778,  1. ])
        >>> np.allclose(newArr2, newArr3)
        True
        >>> data.median
        4.5
        """
        med = np.median(arr)
        self.median = med
        return arr/med - 1
    
    def _undoNormedArr(self, arr):
        """ Undos the median normal of the array. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> np.isclose(data.medianCov, 2.123642889e-05)
        True
        >>> np.isclose(data.median, 0.0046083)
        True
        >>> data.median = data._undoNormedArr(data.time_sliced)
        >>> np.isclose(data.medianCov, 2.123642889e-05)
        True
        >>> np.isclose(float(data.median[0]), 0.004620670059690001)
        True
        >>> np.isclose(float(data.median[3]), 0.0046279341229800005)
        True
        """
        return (arr + 1)*self.median  
    
    def _undoCovNorm(self,arr):
        """Undoes the normalization of the covariance matrix. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> np.isclose(data.medianCov, 2.123642889e-05)
        True
        >>> np.isclose(data.median, 0.0046083)
        True
        >>> preTest = np.random.rand(10,10)
        >>> testMat = data._undoCovNorm(preTest)
        >>> np.isclose(testMat[1,1], (preTest[1,1]+1)*data.medianCov)
        True
        >>> np.isclose(testMat[8,9], (preTest[8,9]+1)*data.medianCov)
        True
        >>> np.isclose(data.median, 0.0046083)
        True
        """
        return (arr+1)*self.medianCov
    
    def setXY_viewingangle(self):
        """ Choose the X and Y training parameters by viewing angle. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.predX.shape
        (40, 1)
        >>> data.predX[3][0]
        0.8461538461538461
        >>> data.predX[-1][0]
        11.0
        """
        N = 40
        self.X = np.arange(0, self.Nobs, 1)
        self.Y = np.array(self.time_sliced_normed.iloc[0])
        self.X = self.X.reshape(len(self.X), 1)
        self.Y = self.Y.reshape(len(self.Y), 1)
        self.predX = np.linspace(0,self.Nobs,N).reshape(N, 1)
        return None
    
    def set_kernel(self, kernel):
        """Sets the kernel of the function
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> type(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        <class 'GPy.kern.src.rbf.RBF'>
        >>> type(data.kernel)
        <class 'GPy.kern.src.rbf.RBF'>
        """
        self.kernel = kernel
        return None
    
    def set_model(self, model): 
        """Set the model for the gaussian process before training. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> type(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        <class 'GPy.models.gp_regression.GPRegression'>
        >>> type(data.model)
        <class 'GPy.models.gp_regression.GPRegression'>
        """
        self.model = model
        return None
    
    def set_predX(self, predX, include_like = False):
        """Sets the trianing data set. 
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.predX[0]
        array([0.])
        >>> data.predX[10]
        array([1.11111111])
        >>> data.predX[90]
        array([10.])
        >>> data.medianCov
        2.123642889e-05
        >>> data.median
        0.0046083
        """
        self.predX = predX
#         currMean, currCov = self.model.predict(self.predX,  full_cov=True, include_likelihood = include_like)
#         self._currMean, self._currCov = currMean, currCov
        self.median = np.median(self.time_sliced)
        self.medianCov = self.median ** 2
        return None

    
    def model_train(self, verbose = False, optimize_method = "lbfgs"):
        """Model training
        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.model.rbf.lengthscale[0]
        2.0
        >>> data.model_train()
        >>> data.model.rbf.lengthscale[0]
        5.104570976151195
        """
        if verbose:
            display(self.model)
        else:
            pass
           
        self.model.optimize(optimizer = optimize_method)
        
#         self.model.optimize_restarts(verbose=False)
        
        if verbose:
            print(self.model)
        return None

        
    def plot_prior(self, manual = False, sig = 1, randomDraws = True, title = None):
        """Plot the mean, a confidence band and random draws before training.

        The mean always comes from ``self.model.predict`` on ``self.predX``
        (without the likelihood term). The covariance is the kernel's own
        ``K(predX)`` when ``manual`` is True, otherwise the covariance returned
        by the same ``predict`` call. The covariance used is saved as
        ``self.cov``.

        manual: use the bare kernel covariance instead of the model's.
        sig: half-width of the shaded band in standard deviations.
        randomDraws: overlay the nine sampled functions when True (they are
            sampled either way).
        title: figure title; a default is used when empty or None.

        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.plot_prior(manual = True, sig = 2)
        >>> plt.close()
        >>> data.kernel.lengthscale[0]
        2.0
        >>> data.model.Gaussian_noise.variance[0]
        1.0
        >>> data.cov[0,0]
        2.0
        >>> data.cov[19,21]
        1.9876923466528822
        >>> data.cov[21,19]
        1.9876923466528822
        >>> data.cov[1,89]
        1.2910493813265383e-05
        """
        grid = self.predX.reshape(1, len(self.predX))[0]

        predY_mean, model_cov = self.model.predict(self.predX,  full_cov=True, include_likelihood = False)
        cov = self.kernel.K(self.predX) if manual else model_cov

        variances = np.diag(cov)
        mean_arr = predY_mean.reshape(1, len(predY_mean))[0]

        # Sampled unconditionally so the global RNG advances the same way
        # whether or not the draws end up on the figure.
        draws = np.random.multivariate_normal(mean_arr, cov, size = 9)

        plt.figure(dpi = 300)
        half_width = sig * np.sqrt(variances)

        plt.fill_between(grid, mean_arr + half_width, mean_arr - half_width, alpha = 0.1, label = f"Confidence: {sig}"r"$\sigma$")
        plt.plot(grid, mean_arr, label = "Mean", color = "Blue")

        if randomDraws:
            first_draw, *other_draws = draws
            # Only the first draw is labelled so the legend gets one entry.
            plt.plot(self.predX, first_draw, color = "red", linewidth = 0.3, label = "Random Draws")
            for draw in other_draws:
                plt.plot(self.predX, draw, color = "red", linewidth = 0.3)

        plt.xlabel(r"Viewing Angle ($\Phi$)")
        plt.ylabel("Deviation from Median")
        plt.scatter(self.X, self.Y, color = "green", facecolors='none', label = "Training Data")
        plt.title(title if title else "GP Prior Distribution with random draws")
        utkarshGrid()
        plt.legend()
        self.cov = cov
    
    def plot_posterior(self, manual = False, sig = 1, randomDraws = True, include_like = False):
        """Plot the posterior mean, confidence band and random draws in flux units.

        The posterior over ``self.predX`` is obtained either from
        ``self.model.predict`` (default) or, with ``manual=True``, from the
        closed-form expressions evaluated directly with ``self.kernel``, the
        training data ``self.X`` / ``self.Y`` and the model's Gaussian noise
        variance. Mean and covariance are then un-normalised with
        ``_undoNormedArr`` / ``_undoCovNorm`` before plotting.

        Side effects: sets ``self.cov`` and ``self.posterior_mean`` (both
        un-normalised) and ``self.unnormedY``; draws three samples from the
        global NumPy RNG whether or not they are plotted.

        manual: compute the posterior by hand instead of calling ``predict``.
        sig: half-width of the shaded band in standard deviations.
        randomDraws: overlay the three sampled functions when True.
        include_like: passed to ``predict`` as ``include_likelihood``; ignored
            when ``manual`` is True.

        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.model_train(verbose = False)
        >>> data.plot_posterior(manual = True)
        >>> plt.close()
        >>> data.kernel.lengthscale[0]
        5.104570976151195
        >>> data.model.Gaussian_noise.variance[0]
        0.0343451657295366
        >>> data.cov[0,0]
        2.1414049353615357e-05
        >>> data.cov[19,21]
        2.1346097498941584e-05
        >>> data.cov[21,19]
        2.1346097498941584e-05
        >>> data.cov[1,89]
        2.122508391774732e-05
        >>> data.posterior_mean[0]
        0.003982024942205805
        >>> data.posterior_mean[50]
        0.004937711001749544
        >>> np.isclose((data.posterior_mean[-1])/data.median - 1, 0.1718569467329627)
        True
        """

        plotX = self.predX.reshape(1, len(self.predX))[0]

        if manual:
            # Closed-form GP posterior:
            #   mean = K(X*, X) [K(X, X) + noise*I]^-1 Y
            #   cov  = K(X*, X*) - K(X*, X) [K(X, X) + noise*I]^-1 K(X, X*)
            trainX = self.X
            kernel = self.kernel
            noise = self.model.Gaussian_noise.variance
            bracket = kernel.K(trainX, trainX) + noise*np.identity(len(trainX))
            bracket_inv = np.linalg.inv(bracket)
            A = kernel.K(self.predX, trainX)
            C = kernel.K(trainX, self.predX)
            cov = kernel.K(self.predX, self.predX) - np.linalg.multi_dot([A, bracket_inv, C])
            # The cross-covariance A is reused for the mean instead of being
            # evaluated a second time.
            posterior_mean = np.linalg.multi_dot([A, bracket_inv, self.Y])
            posterior_mean = posterior_mean.reshape(1,len(posterior_mean))[0]
            predY_mean = np.array(posterior_mean, dtype=float)
        else:
            predY_mean, cov = self.model.predict(self.predX,  full_cov=True, include_likelihood = bool(include_like))

        # Convert from deviation-from-median back to flux units.
        cov = self._undoCovNorm(cov)
        predY_mean = self._undoNormedArr(predY_mean)
        self.unnormedY = self._undoNormedArr(self.Y)
        var = np.diag(cov)
        mean_arr = predY_mean.reshape(1, len(predY_mean))[0]
        plotY = mean_arr

        # Sampled unconditionally so the global RNG advances the same way
        # whether or not the draws end up on the figure.
        F = np.random.multivariate_normal(mean_arr, cov, size = 3)

        plt.figure(dpi = 300)
        numVar = sig * np.sqrt(var)

        plt.fill_between(plotX, plotY + numVar, plotY - numVar, alpha = 0.1, label = f"Confidence: {sig}"r"$\sigma$")
        plt.plot(plotX, plotY, label = "Mean", color = "Blue")

        if randomDraws:
            # Only the first draw is labelled so the legend gets one entry.
            plt.plot(self.predX, F[0], color = "red", linewidth = 0.3, label = "Random Draws")
            for draw in F[1:]:
                plt.plot(self.predX, draw, color = "red", linewidth = 0.3)

        plt.xlabel(r"Viewing Angle ($\Phi$)")
        plt.ylabel(r"Flux $Erg s^{-1} cm^{-2}A^{-1}$")
        plt.scatter(self.X, self.unnormedY, color = "green", facecolors='none', label = "Training Data")
        plt.title("GP Posterior Distribution with random draws")
        utkarshGrid()
        plt.legend()
        self.cov = cov
        self.posterior_mean = predY_mean
    
    def plot_covariance(self):
        """Show the covariance matrix currently stored in ``self.cov`` as an image.

        ``self.cov`` must already have been set, e.g. by ``plot_prior`` or
        ``plot_posterior``. A new figure is opened; nothing on the instance is
        modified.

        >>> data = GP("reference.csv")
        >>> phi_range = [60]
        >>> mejdyn_range = [0.02]
        >>> mejwind_range = [0.11]
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, 900)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.model_train(verbose = False)
        >>> data.plot_posterior(manual = False)
        >>> plt.close()
        >>> data.plot_covariance()
        >>> plt.close()
        >>> np.isclose(data.cov[0,0], 2.1414049384491493e-05)
        True
        >>> np.isclose(data.cov[19,21], 2.1346097523916208e-05)
        True
        >>> np.isclose(data.cov[21,19], 2.1346097523916208e-05)
        True
        >>> np.isclose(data.cov[1,89], 2.122508391318888e-05)
        True
        >>> np.isclose(data.posterior_mean[0], 0.00398203)[0]
        True
        >>> np.isclose(data.posterior_mean[50], 0.00493771)[0]
        True
        >>> np.isclose(data.posterior_mean[-1], 0.00540027)[0]
        True
        """
        plt.figure(figsize = (4,4), dpi = 150)
        # interpolation="none" keeps one flat-coloured cell per matrix entry.
        plt.imshow(self.cov, cmap = "inferno", interpolation = "none")
        plt.colorbar()
        plt.title(f"Covariance Matrix between {len(self.cov)} sampled points", fontsize = 10)
        
        
    def LOOCV(self, manual = True, include_like = True):
        """Leave-one-out cross validation over viewing angle for the current selection.

        manual=True (default): for each viewing angle the stored model is
        trained and asked to predict that angle; the un-normalised predictions
        go to ``self.looMean``, their standard deviations to ``self.sigmaList``
        and the residuals in units of sigma to ``self.looList``.

        manual=False: trains the model once, uses GPy's built-in ``LOO`` of the
        inference method, prints the summed value and draws a scatter plot.
        ``looMean`` / ``sigmaList`` stay empty in this mode.

        On exit ``self.viewingangle``, ``self.kernel`` and ``self.model`` are
        set back to the objects they referred to on entry (the model object
        itself has been optimised in place).

        include_like: passed to ``predict`` as ``include_likelihood`` in the
            manual branch.

        NOTE(review): in the manual branch the reduced inputs are only written
        to ``self.X`` / ``self.Y``; ``self.model`` is the object captured on
        entry and is never rebuilt from them, so every "fold" re-optimises and
        queries that same model. Presumably it was built on the full training
        set (as in the example below) -- confirm whether a per-fold refit was
        intended. The numbers pinned by the doctests depend on the current
        behaviour, so it is left unchanged here.

        >>> data = GP("reference.csv")
        >>> phi_range = [45]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.11]
        >>> wv = 900
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, wv)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> compareVA = data.viewingangle
        >>> kerLen = data.kernel.lengthscale[0]
        >>> compareM = data.model.Gaussian_noise.variance[0]
        >>> data.LOOCV()
        >>> np.isclose(data.looMean[0], 0.0025592363242248643)
        True
        >>> np.isclose(data.looMean[6], 0.00478646958190752)
        True
        >>> np.isclose(data.sigmaList[1], 0.00420134)[0]
        True
        >>> np.isclose(data.sigmaList[5], 0.00419738)[0]
        True
        >>> np.allclose(np.array(compareVA, dtype=float), np.array(data.viewingangle, dtype=float))
        True
        >>> np.isclose(data.looList[2], -0.072787432912073)
        True
        >>> np.isclose(data.looList[7], -0.04193157845918918)
        True
        >>> print(kerLen)
        2.0
        >>> np.isclose(data.kernel.lengthscale[0], 10.237731281822857)
        True
        >>> print(compareM)
        1.0
        >>> np.isclose(data.model.Gaussian_noise.variance[0], 0.019581516431695353)
        True

        """

        self.looMean = []
        self.sigmaList = []
        self.range_select_wavelength(self.phi_range, self.mejdyn_range, self.mejwind_range, self.wv_range)
        # Snapshot the state that the loop below overwrites.
        tempViewingAngle = self.viewingangle
        tempKernel = self.kernel
        tempModel = self.model
        tempX = self.X.copy()
        tempY = self.Y.copy()
        compareY = self.Y.copy()
        tempPredX = self.predX

        if manual:
            for i in range(self.Nobs):
                # Rebuild the per-iteration state from the snapshot.
                self.viewingangle = tempViewingAngle
                self.single_time_step(self.extraction_time)
                self.normedDF()
                self.setXY_viewingangle()
                self.set_kernel(tempKernel)
                self.set_model(tempModel)
                test_pointX = np.array([tempX[i]])
                # Training arrays with point i removed (see NOTE(review) in the
                # docstring: the model object is not rebuilt from these).
                self.X = np.delete(tempX, i)
                self.X = self.X.reshape(len(self.X), 1)
                self.Y = np.delete(tempY, i)
                # Restores the caller's grid (setXY_viewingangle reset it) and
                # refreshes median / medianCov used for un-normalising below.
                self.set_predX(tempPredX)
                self.model_train(verbose = False)
                self.X = tempX

                mean, var = self.model.predict(test_pointX, include_likelihood = include_like)
                var = self._undoCovNorm(var)
                sigma = np.sqrt(var)
                mean, sigma = mean[0], sigma[0]
                mean = self._undoNormedArr(mean)[0]
                self.looMean.append(mean)
                self.sigmaList.append(sigma)


            # Un-normalise the held-out targets in place on the private copy.
            tempY.T[0] = self._undoNormedArr(tempY.T[0])
            tempY.T[0] = np.array(tempY.T[0], dtype=float)

            compareY = np.array(compareY, dtype = float)
            tempY = np.array(tempY, dtype = float)

            self.Y = compareY

            # Sanity check: un-normalising Y must give back the sliced data.
            arr1 = np.array(tempY.T[0], dtype=float)
            arr2 = np.array(self.time_sliced.to_numpy()[0], dtype=float)
            assert(np.allclose(arr1, arr2))
            # Residual of each prediction in units of its own sigma.
            self.looList = (self.looMean - tempY.T[0])/np.array(self.sigmaList).T[0]

        else:
            self.normedDF()
            self.setXY_viewingangle()
            # setXY_viewingangle replaced predX with its default grid; put the
            # caller's grid back (set_predX returns None, so its result must
            # not be fed back in as the new grid).
            self.set_predX(tempPredX)
            self.model_train(verbose = False)
            # GPy's predict returns (mean, variance), in that order.
            _, var = self.model.predict(self.X)
            var = self._undoCovNorm(var)


            #Calculate LOO
            loos  = self.model.inference_method.LOO(self.kernel, self.X, self.Y, self.model.likelihood, self.model.posterior)
            loo_error = np.sum(loos)
            print(f"Leave one out density: {loo_error}")
            plt.figure(figsize = (6,3), dpi = 300)
            # NOTE(review): the plotted quantity is kept from the original;
            # its units are uncertain, as the title says.
            plt.scatter(self.X, self._undoNormedArr((self.Y - loos)/np.sqrt(var)), facecolor = "none", edgecolor = "dodgerblue")
            plt.xlabel("Viewing Angle Left Out")
            plt.ylabel(r"Leave-One-Out Accuracy $\sigma$")
            plt.title("GPy LOO (Possibly not in Sigma Units)")
            utkarshGrid()

        self.viewingangle = tempViewingAngle
        self.set_kernel(tempKernel)
        self.set_model(tempModel)
        return None
    
    def plot_loocv(self, plot_type = "single"):
        """ Plot of the errors in the LOO optimization. 
        >>> data = GP("reference.csv")
        >>> phi_range = [45]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.11]
        >>> wv = 900
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, wv)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.LOOCV()
        >>> data.plot_loocv()
        >>> plt.close()
        >>> np.isclose(data.plot_loocv_limitY, 0.2381935005242566)
        True
        """
        if plot_type == "multiple":
            self.plotLooList = np.mean(self.loo_list_multiple, axis = 0)
        elif plot_type == "single":
            self.plotLooList = self.looList
        else:
            print("[ERROR] Plot type not selected")
        
        plt.figure(figsize = (6,3), dpi = 300)
        plt.scatter(self.X.T[0], self.plotLooList, facecolor = "none", edgecolor = "dodgerblue")
        plt.xlabel("Viewing Angle Left Out")
        plt.ylabel(r"Accuracy (Units $\sigma$)")
        plt.title("Leave-One-Out Cross Validation")
        utkarshGrid()
        limitY = max(abs(min(self.plotLooList))*1.1, abs(max(self.plotLooList))*1.1)
        self.plot_loocv_limitY = limitY
        plt.ylim(-limitY, limitY)
        return None
        
    def plot_loocv_simple(self, include_like = True):
        """ Simple Plot of the LOOCV on the posterior distribution. 
        >>> data = GP("reference.csv")
        >>> phi_range = [45]
        >>> mejdyn_range = [0.01]
        >>> mejwind_range = [0.11]
        >>> wv = 900
        >>> data.range_select_wavelength(phi_range, mejdyn_range, mejwind_range, wv)
        >>> data.single_time_step(1)
        >>> data.normedDF()
        >>> data.setXY_viewingangle()
        >>> data.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
        >>> data.set_model(GPy.models.GPRegression(data.X, data.Y, data.kernel))
        >>> data.set_predX(np.linspace(0,data.Nobs,100).reshape(100, 1))
        >>> data.LOOCV()
        >>> data.plot_loocv_simple()
        >>> plt.close()
        >>> data.model.Gaussian_noise.variance[0]
        0.019581516431695353
        >>> data.cov[0,0]
        1.7704562471056736e-05
        >>> data.cov[19,21]
        1.7286183692651927e-05
        >>> data.cov[21,19]
        1.7286183692651927e-05
        >>> data.cov[1,89]
        1.722470910238539e-05
        >>> np.isclose(data.posterior_mean[0], 0.00255924)[0]
        True
        >>> np.isclose(data.posterior_mean[50], 0.00459252)[0]
        True
        """
        self.plot_posterior(include_like = True)
        plt.scatter(self.X, self.looMean, label = "Loo Prediction", color = "black", marker='X', zorder = 1)
        plt.legend()
        return None
    
    def multiple_LOOCV(self, N = 3, verbose = True):
        """Run ``LOOCV`` on the first ``N`` parameter sets of the reference table.

        For each of the first ``N`` rows of ``self.reference`` the (phi,
        mejdyn, mejwind) triple is selected at wavelength 900, a fresh RBF
        kernel and GPRegression model are built, and ``LOOCV`` is run. The
        resulting ``looList`` arrays are stacked into
        ``self.loo_list_multiple`` (one row per run that was not skipped).

        Selections whose ``Nobs`` is not 11 are skipped so that all stacked
        rows have the same length. If ``N`` exceeds the number of reference
        rows, only the available rows are processed.

        Side effects: sets ``phi_range_list``, ``mejdyn_range_list``,
        ``mejwind_range_list``, ``wv_range_list`` and ``loo_list_multiple``;
        ``self.kernel`` / ``self.model`` are left as built for the last
        processed set; ``self.selected`` is restored after every set.

        N: number of leading reference rows to process.
        verbose: show a progress bar and report skipped selections.

        >>> data = GP("reference.csv")
        >>> data.multiple_LOOCV(N = 5, verbose = False)
        >>> data.loo_list_multiple.shape
        (5, 11)
        >>> data.phi_range_list
        [45, 15, 75, 75, 60]
        >>> data.mejdyn_range_list
        [0.001, 0.01, 0.001, 0.005, 0.02]
        >>> data.wv_range_list
        [900, 900, 900, 900, 900]
        >>> np.isclose(data.loo_list_multiple[0,1], 4.767686375520658e-10)
        True
        >>> np.isclose(data.loo_list_multiple[1,0], 0.06800734333543058)
        True
        >>> np.isclose(data.loo_list_multiple[4,3], 0.03651750786814857)
        True
        >>> np.isclose(data.kernel.lengthscale[0], 5.104570976151195)
        True
        >>> np.isclose(data.model.Gaussian_noise.variance[0], 0.0343451657295366)
        True
        """
        self.phi_range_list = list(self.reference.phi)[0:N]
        self.mejdyn_range_list = list(self.reference.mejdyn)[0:N]
        self.mejwind_range_list = list(self.reference.mejwind)[0:N]
        self.wv_range_list = [900] * N

        # zip stops at the shortest list, so asking for more sets than the
        # reference table holds no longer indexes past the end.
        parameter_sets = list(zip(self.phi_range_list,
                                  self.mejdyn_range_list,
                                  self.mejwind_range_list,
                                  self.wv_range_list))

        self.loo_list_multiple = []
        for phi, mejdyn, mejwind, wv in tqdm(parameter_sets, disable= not verbose):
            tempSelected = self.selected

            self.range_select_wavelength([phi], [mejdyn], [mejwind], wv)
            # Rows are stacked into one 2-D array below, so only selections
            # with the expected 11 viewing angles are kept.
            if self.Nobs != 11:
                if verbose:
                    print(f"[STATUS] File selected at {self.selected.filename.iloc[0]} \ndoes not have the correct number of viewing angles. Skipping...")
                self.selected = tempSelected
                continue
            self.single_time_step(self.extraction_time)
            self.normedDF()
            self.setXY_viewingangle()
            self.set_kernel(GPy.kern.RBF(input_dim=1, variance = 2, lengthscale=2))
            self.set_model(GPy.models.GPRegression(self.X, self.Y, self.kernel))
            self.LOOCV()
            self.loo_list_multiple.append(self.looList)
            self.selected = tempSelected

        self.loo_list_multiple = np.array(self.loo_list_multiple)
        return None
    
    def gaussian(self, x,x0,sigma):
        """ Unnormalized Gaussian bump ``exp(-((x - x0)/sigma)**2 / 2)``.

        The peak value is 1 at ``x == x0``; no ``1/(sigma*sqrt(2*pi))`` factor
        is applied.
        """
        standardized = (x - x0)/sigma
        exponent = -np.power(standardized, 2.)/2.
        return np.exp(exponent)
    
    def plot_loocv_histogram(self, edge = 2.5, mu = 0, sigma = 1, binning = 30):
        """ Histogram and KDE of the pooled leave-one-out deviations.

        Flattens ``self.loo_list_multiple``, keeps the finite values strictly
        inside (-5, 5) and plots them as a density histogram with a KDE on top.
        When ``self.empirical`` is False a Gaussian reference curve is drawn as
        well.

        Parameters
        ----------
        edge : float, optional
            Half-width of the interval on which the reference Gaussian is drawn.
        mu, sigma : float, optional
            Mean and standard deviation of the reference Gaussian.
        binning : int, optional
            Passed as ``bins`` to the histogram.

        Returns
        -------
        None
        """
        fig, ax = plt.subplots(dpi = 300)
        utkarshGrid()

        all_values = self.loo_list_multiple.flatten()
        hist_arr = all_values[np.isfinite(all_values)]
        hist_arr = hist_arr[(hist_arr > -5) & (hist_arr < 5)]
        print(f"Inside 5x: {len(hist_arr)}, Total: {len(all_values)}")

        df = pd.DataFrame(hist_arr, columns = ["hist"])

        gauss_line = None
        if not self.empirical:
            x_gauss = np.linspace(-edge, edge, 100, endpoint = True)
            # self.gaussian() peaks at 1; the histogram and the KDE are densities,
            # so the reference curve must be a normalized pdf to be comparable.
            y_gauss = self.gaussian(x_gauss, mu, sigma) / (abs(sigma) * np.sqrt(2 * np.pi))
            gauss_line, = ax.plot(x_gauss, y_gauss, label = "Unit Gaussian", color ="purple", zorder = 3)

        df.plot.hist(density = True, bins = binning, ax = ax, label = "Count",
                     facecolor = '#2ab0ff', edgecolor='#169acf', zorder = 1)
        hist_handle = ax.patches[0]   # any bar represents the histogram in the legend
        df.plot.kde(ax = ax, label = "LOO Distribution", alpha = 1, zorder = 2)
        kde_line = ax.lines[-1]       # the KDE is the line that was just added
        ax.set_ylabel("Count Intensity")

        # Labels are bound to explicit handles: a bare list of labels is matched
        # to artists by their order on the axes, which mislabels the curves.
        if self.empirical:
            ax.legend([kde_line, hist_handle], ["LOO Distribution ", "Count"])
            ax.set_title(r"Ratio = $\frac{Truth - Predictive}{Truth}$")

            if self.isLog:
                ax.set_xlabel(r"Deviation Error (Units Log Flux)")
            else:
                ax.set_xlabel(r"Deviation Error (Units Flux)")
        else:
            ax.set_xlabel(r"Deviation Error (Units $\sigma$)")
            ax.legend([kde_line, gauss_line, hist_handle],
                      ["LOO Distribution ", "Unit Gaussian", "Count"])
            ax.set_title(r"Ratio = $\frac{{Truth - Predictive}}{\sigma}$")

        ax.set_ylim(bottom=-0.1)
    
    def plot_hist_lengthscale(self, arr_hist, typ0 = None, tol = 5):
        """ Plots the histogram distribution of the lengthscales.

        Values that are not below ``tol`` are dropped before plotting, and the
        share of values that were kept is printed.

        Parameters
        ----------
        arr_hist : array-like
            Lengthscale values.
        typ0 : optional
            Name of the dimension used in the printed summary; ``None`` prints
            "UNSPECIFIED". Falsy names such as ``0`` are printed as given.
        tol : float, optional
            Exclusive upper bound on the values that are plotted.

        Returns
        -------
        None
        """
        all_values = np.asarray(arr_hist)
        kept = all_values[all_values < tol]

        if typ0 is None:
            typ0 = "UNSPECIFIED"

        # Guard the ratio so an empty input reports 0% instead of dividing by zero.
        percent = round(100*kept.size/all_values.size) if all_values.size else 0
        print(f"Dimension {typ0}: {percent}% within lengthscale {tol}.")

        fig, ax = plt.subplots(dpi = 100)
        hist_arr = kept.flatten()
        utkarshGrid()
        df = pd.DataFrame(hist_arr, columns = ["hist"])
        df.plot.hist(density = True, bins = 20, ax = ax, label = "Count",
                     facecolor = '#23de6b', edgecolor='#18b855', zorder = 1)

        try:
            df.plot.kde(ax = ax, label = "Lengthscale Distribution", alpha = 1, zorder = 2)
        except Exception:
            # The KDE cannot be estimated from degenerate data; report it and keep
            # the histogram. KeyboardInterrupt/SystemExit are no longer swallowed.
            if kept.size and np.max(kept) == np.min(kept):
                print("[WARNING] Are all values of this array are the same!")
            else:
                print("[ERROR] Something went wrong")
        plt.ylabel("Count Intensity")
        ax.legend(["Lengthscale Distribution", "Count"])
        plt.xlabel(r"Lengthscale")
        ax.set_ylim(bottom=-1e-9)
        return None

In [7]:
class GP2D(GP):
    """The Gaussian Process for KNe Light Curves using GPy. This class handles the two dimensional case.  
    """
    
    def __init__(self, referenceName):
        """ Build the two-dimensional GP by delegating all setup to ``GP.__init__``.

        >>> gp = GP2D("reference.csv")
        >>> isinstance(gp, GP)
        True
        >>> isinstance(gp, LightCurve)
        True
        >>> isinstance(gp, float)
        False
        >>> isinstance(gp, GP2D)
        True
        """
        # Explicit base-class call: no state of its own is added by this class.
        GP.__init__(self, referenceName)
    
    def set_selection_range(self, typ, phi_range = [], mejwind_range = [], mejdyn_range = [], wv_range = [900], verbose = False):
        """ 
        Selection method that converts the range desired into a two dimensional training vector. 
        First dimension is the viewing angle, second dimension is selected type. 
        The data vector returned is not normalized. 

        The parameter named by ``typ`` ("mejdyn", "mejwind" or "phi") is left
        free: its range argument is ignored and every distinct value found in
        the selection becomes one row of ``self.training2D``. Any other ``typ``
        prints an error and returns without building the training data.

        Sets ``self.training2D`` (rows sorted by the free parameter, with 1e-15
        added to every entry), ``self.curr_range_list`` (sorted values of the
        free parameter), ``self.iobs_range`` and ``self.Nobs`` (fixed to 11).

        Raises ``ValueError`` when the selection contains no value for ``typ``.

        >>> data = GP2D("reference.csv")
        >>> phi_range = [60]
        >>> mejwind_range = [0.05]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [1500]
        >>> data.set_selection_range(typ = "mejdn", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range,wv_range = wv_range, verbose = False)
        [ERROR] Incorrect selection of 2D parameters. Try Again...
        >>> data = GP2D("reference.csv")
        >>> phi_range = [60]
        >>> mejwind_range = [0.05]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [1500]
        >>> data.set_selection_range(typ = "mejwind", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range,wv_range = wv_range, verbose = False)
        """
        self.typ = typ
        self.wv_range = wv_range
        self.phi_range = phi_range
        self.mejwind_range = mejwind_range
        self.mejdyn_range = mejdyn_range
        self.wv = self.wv_range[0]

        # Validate once up front; the later per-branch error paths were unreachable.
        if typ not in ("mejdyn", "mejwind", "phi"):
            print("[ERROR] Incorrect selection of 2D parameters. Try Again...")
            return None

        # The free parameter is unconstrained while selecting the curves.
        range_attr = f"{typ}_range"
        setattr(self, range_attr, [])

        self.select_curve(phiRange = self.phi_range,
                          mejwindRange = self.mejwind_range,
                          mejdynRange = self.mejdyn_range)

        curr_range_list = getattr(self.selected, typ).unique()
        self.curr_range_list = np.sort(curr_range_list)

        # One time slice per value of the free parameter, collected in selection
        # order and concatenated once (instead of re-concatenating every pass).
        previous_selection = self.selected
        slices = []
        for value in tqdm(curr_range_list, disable = not verbose):
            setattr(self, range_attr, [value])
            self.range_select_wavelength(self.phi_range, self.mejdyn_range, self.mejwind_range, self.wv)
            self.single_time_step(self.extraction_time) # Want distribution at 1 day
            slices.append(self.time_sliced)
            self.selected = previous_selection

        if not slices:
            raise ValueError(f"No curves match the requested selection for typ = {typ!r}")
        df = pd.concat(slices)

        if df.shape[1] == 1:
            # A single column is repeated so the frame always has 11 columns.
            df = pd.concat([df]*11, axis = 1)

        # Rows are labelled by the free parameter. The former reset_index() call
        # discarded its result, so it never had an effect and is dropped.
        df.index = curr_range_list
        df.index.name = self.typ
        df = df.ffill(axis = 1)   # non-deprecated spelling of fillna(method='ffill')
        self.Nobs = 11
        epsilon = 1e-15
        self.iobs_range = np.linspace(0, self.Nobs-1, self.Nobs, endpoint = True)
        self.training2D = df.add(epsilon).sort_index()
        self.curr_range_list = np.sort(curr_range_list)
        return None
    
    def normalize_training2D(self):
        """ Store the median-normalized copy of ``self.training2D``.

        The result of ``self._normedDF(self.training2D)`` is kept in
        ``self.training2D_normalized``; ``self.training2D`` itself is not
        reassigned here.

        >>> data = GP2D("reference.csv")
        >>> phi_range = [60]
        >>> mejwind_range = [0.05]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [1500]
        >>> data.set_selection_range(typ = "mejdyn", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range,wv_range = wv_range, verbose = False)
        >>> data.normalize_training2D()
        >>> checkDF = data.training2D
        >>> mat1 = checkDF.to_numpy(dtype = float)
        >>> mat2 = data._undoNormedDF(data.training2D_normalized).to_numpy(dtype = float)
        >>> np.allclose(mat1, mat2)
        True
        """
        normalized = self._normedDF(self.training2D)
        self.training2D_normalized = normalized
        
    def unnormalize_training2D(self, verbose = True):
        """ Undo the median normalization on the stored posterior mean and covariance.

        ``self.posterior_mean`` and ``self.posterior_cov`` are replaced by the
        results of ``self._undoNormedArr`` and ``self._undoCovNorm``. If no
        median has been stored (``self.median`` is missing) nothing is changed.

        Parameters
        ----------
        verbose : bool, optional
            Only controls whether the "not normalized" error is printed; the
            check itself always runs.

        Returns
        -------
        None

        >>> gp = GP2D("reference.csv")
        >>> phi_range = [30]
        >>> mejwind_range = [0.05]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [900]
        >>> gp.set_selection_range(typ = "mejwind", phi_range = phi_range,mejwind_range = mejwind_range, mejdyn_range = mejdyn_range, wv_range = wv_range, verbose = False)
        >>> gp.setXY()
        >>> gp.set_kernel(GPy.kern.RBF(input_dim=2, variance = 1, lengthscale=10))
        >>> gp.set_model(GPy.models.GPRegression(gp.X,gp.Y,gp.kernel))
        >>> gp.model_train(verbose = False)
        >>> gp.model_predict(N = 5)
        >>> gp.unnormalize_training2D()
        [ERROR] Data has not been normalized, so it cannot be unnormalized
        """
        # The guard used to run only when verbose was True, so a silent call on
        # un-normalized data fell through to the un-normalization helpers.
        if not hasattr(self, "median"):
            if verbose:
                print("[ERROR] Data has not been normalized, so it cannot be unnormalized")
            return None

        self.posterior_mean = self._undoNormedArr(self.posterior_mean)
        self.posterior_cov = self._undoCovNorm(self.posterior_cov)
        return None
    
    def _normedDF(self, df):
        med = np.median(df)
        self.median = med
        self.medianCov = self.median**2
        return df.divide(med) - 1
    
    def _undoNormedDF(self, df):
        return (df + 1).multiply(self.median)
    
    def _check_normalization_fixY(self):
        try:
            self.training2D_normalized
        except:
            self.training2D_normalized = self.training2D
            
        return None
    
    def setXY(self):
        """ Set X and Y training vectors in the format of GPy

        ``self.X`` has one row per (viewing angle, free parameter) grid point,
        with the viewing angle varying fastest; ``self.Y`` is the matching
        column of ``self.training2D_normalized`` values (falling back to the
        un-normalized table through ``_check_normalization_fixY``). When
        ``self.normalizeX`` is true both axes are min-max scaled to [0, 1]; an
        axis with a single distinct value is mapped to 0 rather than NaN.

        >>> data = GP2D("reference.csv")
        >>> phi_range = [60]
        >>> mejwind_range = [0.05]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [1500]
        >>> data.set_selection_range(typ = "mejdyn", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range,wv_range = wv_range, verbose = False)
        >>> data.normalize_training2D()
        >>> X1 = data.iobs_range
        >>> X2 = np.array(data.curr_range_list)
        >>> Y = data.training2D_normalized.to_numpy(dtype = float).flatten()
        >>> Y = Y.reshape(len(Y), 1)
        >>> _X1, _X2 = np.meshgrid(X1, X2)
        >>> _X1 = _X1.flatten()
        >>> _X1 = _X1.reshape(len(_X1), 1)
        >>> _X2 = _X2.flatten()
        >>> _X2 = _X2.reshape(len(_X2), 1)
        >>> X = np.hstack([_X1, _X2])
        >>> data.setXY()
        >>> np.allclose(X, data.X)
        True
        >>> np.allclose(Y, data.Y)
        True
        >>> np.allclose(X1, data.X1)
        True
        >>> np.allclose(X2, data.X2)
        True
        """
        def _unit_scale(values):
            # Min-max scale to [0, 1]; a zero span would divide by zero.
            lowest, highest = min(values), max(values)
            span = highest - lowest
            if span == 0:
                return np.zeros_like(values, dtype = float)
            return (values - lowest)/span

        X1 = self.iobs_range
        X2 = np.array(self.curr_range_list)

        if self.normalizeX:
            X1, X2 = _unit_scale(X1), _unit_scale(X2)

        self._check_normalization_fixY()

        Y = self.training2D_normalized.to_numpy(dtype = float).flatten()
        Y = Y.reshape(len(Y), 1)

        # Row-major flattening of the grid matches the row-major flattening of Y.
        _X1, _X2 = np.meshgrid(X1, X2)
        _X1 = _X1.flatten()
        _X1 = _X1.reshape(len(_X1), 1)
        _X2 = _X2.flatten()
        _X2 = _X2.reshape(len(_X2), 1)
        X = np.hstack([_X1, _X2])

        self.Y = Y
        self.X1 = X1
        self.X2 = X2
        self.X = X
        return None
    
    def model_predict(self, N = 50, include_like = True, make_cov = True, same_dimension = False):
        """ Predict the posterior on an ``N`` x ``N`` grid over both input dimensions.

        The grid spans the range of ``self.iobs_range`` along the first axis
        and the range of ``self.curr_range_list`` along the second. When
        ``self.normalizeX`` is true both axes are min-max scaled to [0, 1]; an
        axis of zero width is mapped to 0 rather than NaN.

        Parameters
        ----------
        N : int, optional
            Number of grid points per axis.
        include_like : bool, optional
            Passed to ``self.model.predict`` as ``include_likelihood``.
        make_cov : bool, optional
            Passed to ``self.model.predict`` as ``full_cov``.
        same_dimension : bool, optional
            Accepted for compatibility; not used by this method.

        Sets ``self.posterior_mean``, ``self.posterior_cov``, ``self.predX1``,
        ``self.predX2``, ``self.predX`` and ``self.N``. Returns None.
        """
        def _unit_scale(values):
            # Min-max scale to [0, 1]; a zero span would divide by zero.
            lowest, highest = min(values), max(values)
            span = highest - lowest
            if span == 0:
                return np.zeros_like(values, dtype = float)
            return (values - lowest)/span

        curr_range_list = self.curr_range_list

        PredX1 = np.linspace(min(self.iobs_range), max(self.iobs_range), N, endpoint = True)
        PredX2 = np.linspace(min(curr_range_list), max(curr_range_list), N, endpoint = True)

        if self.normalizeX:
            PredX1, PredX2 = _unit_scale(PredX1), _unit_scale(PredX2)

        # Same grid layout as the training inputs: first axis varies fastest.
        _X1Pred, _X2Pred = np.meshgrid(PredX1, PredX2)
        _X1Pred = _X1Pred.flatten()
        _X1Pred = _X1Pred.reshape(len(_X1Pred), 1)
        _X2Pred = _X2Pred.flatten()
        _X2Pred = _X2Pred.reshape(len(_X2Pred), 1)
        predX = np.hstack([_X1Pred, _X2Pred])

        mean, cov = self.model.predict(predX, full_cov = make_cov, include_likelihood = include_like)
        self.posterior_mean = mean
        self.posterior_cov = cov
        self.predX1 = PredX1
        self.predX2 = PredX2
        self.predX = predX
        self.N = N
        return None
    
    def plot_covariance2D(self):
        """ Show ``self.posterior_cov`` as an image with a colorbar.

        The title reports ``len(self.posterior_cov)`` as the number of sampled
        points. Returns None.
        """
        plt.figure(figsize = (4,4), dpi = 150)
        plt.imshow(self.posterior_cov, cmap = "inferno", interpolation = "none")
        plt.colorbar()
        # Title spelling fixed ("Covarance" -> "Covariance").
        plt.title(f"Covariance Matrix between {len(self.posterior_cov)} sampled points", fontsize = 10)
        return None
    
    def plot_posterior2D(self, verbose = False, lev= 20):
        """ Filled-contour plot of the posterior mean over the prediction grid.

        ``self.posterior_mean`` is reshaped to ``(self.N, self.N)`` and drawn
        against ``self.predX1`` (x axis, labelled "Viewing Angle") and
        ``self.predX2`` (y axis, labelled according to ``self.typ``). An
        unknown ``self.typ`` prints an error instead of setting the y label.

        Parameters
        ----------
        verbose : bool, optional
            When true, print the parameter ranges the plot refers to.
        lev : optional
            Passed to ``plt.contourf`` as ``levels``.

        Returns
        -------
        None
        """
        plt.figure(dpi = 200)
        plt.tight_layout()
        surface = self.posterior_mean.reshape(self.N, self.N)
        plt.contourf(self.predX1, self.predX2, surface, cmap = "plasma", levels = lev)
        plt.colorbar()
        plt.xlabel("Viewing Angle")

        axis_titles = {
            "mejdyn": "Dynamical Ejecta Mass",
            "mejwind": "Wind Ejecta Mass",
            "phi": r"Half Opening Angle $\Phi$",
        }
        if self.typ in axis_titles:
            plt.ylabel(axis_titles[self.typ])
        else:
            print("[ERROR] Incorrect selection of 2D parameters. Try Again...")

        if not verbose:
            return None

        # The free parameter is reported through curr_range_list, the fixed
        # ones through their own *_range attributes.
        if self.typ == "mejdyn":
            print(f"[STATUS] Plotting for: \n[STATUS] mejdyn: {self.curr_range_list} \n[STATUS] mejwind: {self.mejwind_range} \n[STATUS] phi: {self.phi_range} \n[STATUS] viewing_angle: {self.iobs_range} \n[STATUS] wavelength: {self.wv_range}")
        elif self.typ == "mejwind":
            print(f"[STATUS] Plotting for: \n[STATUS] mejdyn: {self.mejdyn_range} \n[STATUS] mejwind: {self.curr_range_list} \n[STATUS] phi: {self.phi_range} \n[STATUS] viewing_angle: {self.iobs_range} \n[STATUS] wavelength: {self.wv_range}")
        elif self.typ == "phi":
            print(f"[STATUS] Plotting for: \n[STATUS] mejdyn: {self.mejdyn_range} \n[STATUS] mejwind: {self.mejwind_range} \n[STATUS] phi: {self.curr_range_list} \n[STATUS] viewing_angle: {self.iobs_range} \n[STATUS] wavelength: {self.wv_range}")
        return None
    
    def log_trainingND(self):
        """ Replace ``self.training2D`` by its element-wise base-10 logarithm.

        Index and columns are preserved and ``self.isLog`` is set to True.
        Every call applies the logarithm again to whatever ``self.training2D``
        currently holds.

        >>> data = GP2D("reference.csv")
        >>> phi_range = [45]
        >>> mejwind_range = []
        >>> mejdyn_range = [0.01]
        >>> wv_range = [900]
        >>> data.set_selection_range(typ = "mejwind", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range, wv_range = wv_range, verbose = False)
        >>> arr1 = np.log10(data.training2D.to_numpy(dtype = float))
        >>> data.log_trainingND()
        >>> arr2 = data.training2D.to_numpy(dtype = float)
        >>> np.allclose(arr1, arr2)
        True
        """
        # Vectorized ufunc instead of the deprecated per-element applymap; the
        # cast keeps it working when the table is stored with object dtype.
        self.training2D = np.log10(self.training2D.astype(float))
        self.isLog = True
        return None
    
    def set_normalizeX(self):
        """ Switch on input scaling: sets ``self.normalizeX`` to True.

        The flag is read by ``setXY`` and ``model_predict``.
        """
        self.normalizeX = True
    
    def LOOCV_2D(self, include_like = True, make_cov = False, verbose = True):
        """ Leave one out cross-validation for two-dimensional case. 

        The training table is log-transformed, median-normalized and laid out
        with scaled inputs. Each training point is then held out in turn: a
        model with a copy of the current kernel is trained on the remaining
        points and used to predict the held-out point. For every fold that
        trains successfully the method records

        * ``self.looList``: (prediction - truth) / predicted sigma,
        * ``self.looList_empirical``: (prediction - truth) / truth,
        * ``self.lengthscaleList``: the two RBF lengthscales of the fold model,

        with predictions and truths converted back from the median
        normalization. Folds whose training raises are counted and skipped.

        Parameters
        ----------
        include_like : bool, optional
            Passed to ``self.model.predict`` as ``include_likelihood``.
        make_cov : bool, optional
            Passed to ``self.model.predict`` as ``full_cov``.
        verbose : bool, optional
            Show a progress bar over the folds.

        Returns
        -------
        tuple
            ``(self.isLog, failed)`` where ``failed`` is the number of skipped folds.

        >>> import warnings
        >>> warnings.filterwarnings("ignore")
        >>> data = GP2D("reference.csv")
        >>> phi_range = [45]
        >>> mejwind_range = [0.03]
        >>> mejdyn_range = [0.01]
        >>> wv_range = [900]
        >>> data.set_selection_range(typ = "mejwind", phi_range = phi_range, mejwind_range = mejwind_range, mejdyn_range = mejdyn_range, wv_range = wv_range, verbose = False)
        >>> data.kernel = GPy.kern.RBF(input_dim=2, variance = 1, lengthscale=10, ARD = True)
        >>> data.LOOCV_2D(verbose = False)
        (True, 0)
        >>> np.isclose(data.Y.T[0][20], 0.02809886951660512)
        True
        >>> np.isclose(data.Y.T[0][60],  -0.1008036214554312)
        True
        >>> np.isclose(data.looList[30], -0.00494856)[0][0]
        True
        >>> np.isclose(data.looList[50], 0.00248408)[0][0]
        True
        >>> old_med = data.median
        >>> np.isclose(old_med, -2.447331783887685)
        True
        >>> data.tempY = (data.tempY + 1)*old_med 
        >>> arr1 = np.array(data.tempY.T[0], dtype=float)
        >>> arr2 = np.array(data.training2D.to_numpy().flatten(), dtype=float)
        >>> np.allclose(arr1, arr2)
        True
        >>> np.isclose(data.model.rbf.lengthscale[0], 10.0)
        True
        >>> np.isclose(data.model.rbf.lengthscale[1], 10.0)     
        True
        """
        failed = 0
        self.log_trainingND()

        # Full-data setup: normalized targets, scaled inputs, reference model.
        self.normalize_training2D()
        self.set_normalizeX()
        self.setXY()
        self.set_model(GPy.models.GPRegression(self.X,self.Y,self.kernel))

        # Leave-one-out bookkeeping.
        self.looMean = []
        self.sigmaList = []
        tempKernel = self.kernel.copy()
        originalX = self.X.copy()
        originalY = self.Y.copy()
        self.looList = []
        self.looList_empirical = []
        self.lengthscaleList = []
        for i in tqdm(range(len(originalX)), disable = not verbose):
            test_pointX = np.array([originalX[i]])
            test_pointY = np.array([originalY[i]])
            self.X = np.delete(originalX, i, 0)
            self.Y = np.delete(originalY, i, 0)
            # Every fold starts from an untouched copy of the initial kernel.
            self.set_kernel(tempKernel.copy())
            self.set_model(GPy.models.GPRegression(self.X,self.Y,self.kernel))
            try:
                self.model_train(verbose = False, optimize_method = "lbfgs")
            except Exception:
                # Best effort: a fold that cannot be trained is counted and
                # skipped (interrupts are no longer swallowed).
                failed += 1
                continue
            self.lengthscaleList.append([self.model.rbf.lengthscale[0], self.model.rbf.lengthscale[1]])
            mean, var = self.model.predict(test_pointX, full_cov = make_cov, include_likelihood = include_like)
            mean = self._undoNormedArr(mean)[0][0]
            sigma = np.sqrt(self._undoCovNorm(var))[0][0]

            test_pointY = (test_pointY + 1)*self.median # Unnormalize
            difference = mean - test_pointY

            self.looList.append(difference/sigma)
            self.looList_empirical.append(difference/test_pointY)

        # Restore the full training set even when the last fold failed.
        self.X = originalX.copy()
        self.Y = originalY.copy()

        self.looList = np.array(self.looList, dtype = float)
        self.looList_empirical = np.array(self.looList_empirical, dtype = float)
        self.tempY = originalY.copy()
        self.originalY = originalY
        return self.isLog, failed 
    
    def multiple_LOOCV_2D(self, typ, verbose = True, trauncate = None, include_like = True, empirical = False, n_jobs = 8):
        """ Run ``LOOCV_2D`` for every combination of the two fixed parameters, in parallel.

        ``typ`` ("mejwind", "phi" or "mejdyn") names the free parameter; the
        unique values of the other two in ``self.reference`` are paired and each
        pair is processed by one joblib task. Any other ``typ`` prints an error
        and returns. The pooled results are stored in ``self.loo_list_multiple``
        (empirical or sigma-scaled deviations, depending on ``empirical``),
        ``self.lengthscaleList_multiple1`` / ``self.lengthscaleList_multiple2``,
        ``self.counter``, ``self.isLog``, ``self.failed``, ``self.empirical``
        and ``self.results`` (the raw per-task tuples).

        Parameters
        ----------
        typ : str
            Free parameter of the two-dimensional model.
        verbose : bool or int, optional
            Passed to joblib; values above 10 also enable the per-fold progress
            bar, values above 0 print a final status line.
        trauncate : int, optional
            Process only the first ``trauncate`` pairs (capped at the number of
            pairs available); ``None`` processes all of them.
        include_like : bool, optional
            Forwarded to ``LOOCV_2D``.
        empirical : bool, optional
            Collect ``looList_empirical`` instead of ``looList``.
        n_jobs : int, optional
            Number of joblib workers.

        >>> data = GP2D("reference.csv")
        >>> ref = data.reference
        >>> typ = "phi"
        >>> data.phi_range = [45]
        >>> data.mejdyn_range = [0.01]
        >>> data.mejwind_range = [0.05]
        >>> data.wv_range = [900]
        >>> data.kernel = GPy.kern.RBF(input_dim=2, variance = 1, lengthscale=10, ARD = True)
        >>> data.multiple_LOOCV_2D(typ, verbose = 0, trauncate = 2)
        >>> data.empirical
        False
        >>> arr1 = data.loo_list_multiple.flatten()
        >>> arr1 = arr1[~np.isnan(arr1)]
        >>> np.isclose(arr1[0], -0.03725916875580017, atol = 1e-04)
        True
        >>> np.isclose(arr1[54], 2.954963105890655)
        True
        >>> data.isLog
        True
        >>> data = GP2D("reference.csv")
        >>> ref = data.reference
        >>> typ = "phi"
        >>> data.phi_range = [45]
        >>> data.mejdyn_range = [0.01]
        >>> data.mejwind_range = [0.05]
        >>> data.wv_range = [900]
        >>> data.kernel = GPy.kern.RBF(input_dim=2, variance = 1, lengthscale=10, ARD = True)
        >>> data.multiple_LOOCV_2D(typ, verbose = 0, trauncate = 2, empirical = True)
        >>> arr2 = data.loo_list_multiple.flatten()
        >>> arr2 = arr2[~np.isnan(arr2)]
        >>> np.isclose(arr2[0], 0.04774528166115224, atol = 1e-04)
        True
        >>> np.isclose(arr2[54],  -0.7653357822915535)
        True
        >>> print(arr2.shape)
        (154,)
        >>> ref.equals(data.reference)
        True
        >>> ref.equals(data.selected)
        True
        >>> data.phi_range
        []
        >>> data.wv_range
        [900]
        >>> data.mejwind_range
        [0.05]
        >>> data.mejdyn_range
        [0.01]
        >>> data.empirical
        True
        >>> data.isLog
        True
        >>> data.lengthscaleList_multiple1.shape
        (154,)
        >>> data.lengthscaleList_multiple2.shape
        (154,)
        """
        self.loo_list_multiple = np.array([], dtype = float)
        self.lengthscaleList_multiple1 = np.array([], dtype = float)
        self.lengthscaleList_multiple2 = np.array([], dtype = float)

        # a and b hold the unique values of the two parameters that stay fixed.
        if typ == "mejwind":
            a = self.reference.phi.unique()
            b = self.reference.mejdyn.unique()
            self.mejwind_range = []
        elif typ == "phi":
            a = self.reference.mejwind.unique()
            b = self.reference.mejdyn.unique()
            self.phi_range = []
        elif typ == "mejdyn":
            a = self.reference.mejwind.unique()
            b = self.reference.phi.unique()
            self.mejdyn_range = []
        else:
            # Same reporting as set_selection_range instead of a NameError below.
            print("[ERROR] Incorrect selection of 2D parameters. Try Again...")
            return None

        curr_pair = list(product(a, b))
        self.counter = 0
        wv_range = self.wv_range

        # Never ask for more pairs than exist.
        if trauncate is None:
            loop_length = len(curr_pair)
        else:
            loop_length = min(trauncate, len(curr_pair))

        looCV_verbose = verbose > 10

        def parallel_helper(self, i):
            """Run LOOCV_2D for pair ``i`` and return only that pair's results."""
            first, second = curr_pair[i]
            if typ == "mejwind":
                phi_range = [first]
                mejwind_range = self.selected.mejwind.unique()
                mejdyn_range = [second]
            elif typ == "phi":
                phi_range = self.selected.phi.unique()
                mejwind_range = [first]
                mejdyn_range = [second]
            else:
                phi_range = [second]
                mejwind_range = [first]
                mejdyn_range = self.selected.mejdyn.unique()

            self.set_selection_range(typ = typ, phi_range = phi_range,
                                     mejwind_range = mejwind_range,
                                     mejdyn_range = mejdyn_range,
                                     wv_range = wv_range, verbose = False)

            isLog, failed = self.LOOCV_2D(verbose = looCV_verbose, include_like = include_like)

            # Per-task values only: nothing is accumulated on ``self`` here, so
            # the result does not depend on each task getting its own copy.
            loo_source = self.looList_empirical if empirical else self.looList
            loo_values = np.asarray(loo_source, dtype = float).flatten()

            self.selected = self.reference

            lengthscales1 = np.array([x[0] for x in self.lengthscaleList], dtype = float).flatten()
            lengthscales2 = np.array([x[1] for x in self.lengthscaleList], dtype = float).flatten()
            return loo_values, 1, lengthscales1, lengthscales2, isLog, failed

        self.results  = Parallel(n_jobs=n_jobs, verbose = verbose)(delayed(parallel_helper)(self, i) for i in range(loop_length))

        for x in self.results:
            self.loo_list_multiple = np.append(self.loo_list_multiple, x[0])
            self.lengthscaleList_multiple1 = np.append(self.lengthscaleList_multiple1, x[2])
            self.lengthscaleList_multiple2 = np.append(self.lengthscaleList_multiple2, x[3])

        self.counter = sum(np.array([x[1] for x in self.results], dtype = int))
        temp = np.array([x[4] for x in self.results], dtype = float)
        # Plain bool so the value prints as True/False on every numpy version.
        self.isLog = bool(np.all(temp))
        self.failed = np.array([x[5] for x in self.results], dtype = int)

        self.empirical = empirical

        # NOTE(review): presumably lets joblib's progress output finish before
        # the status line is printed - confirm before removing.
        time.sleep(0.5)
        if verbose > 0:
            print(f"[STATUS] Used {self.counter}/{len(curr_pair)} items")

        return None

In [8]:
class GP5D(GP2D):
    """The Gaussian Process for KNe Light Curves using GPy. This class handles the five dimensional case.  

    Derives from GP2D, whose initializer it already relied on, so the
    two-dimensional helpers are available on GP5D instances as well.
    """
    
    def __init__(self, referenceName):
        """ Instantiates class of both Gaussian Process and KNe Light Curve

        >>> gp = GP5D("reference.csv")
        >>> isinstance(gp, GP)
        True
        >>> isinstance(gp, LightCurve)
        True
        >>> isinstance(gp, float)
        False
        >>> isinstance(gp, GP2D)
        True
        >>> isinstance(gp, GP5D)
        True
        """
        # The class used to derive from GP while calling GP2D.__init__, and its
        # doctest built a GP2D, so the GP5D membership check could never pass.
        super().__init__(referenceName)
        return None
    
    

In [9]:
if __name__ == "__main__":
    # The project directory can be supplied through the SURP_DATA_DIR environment
    # variable; the previous hard-coded location stays as the fallback so existing
    # runs behave the same.
    project_dir = os.environ.get("SURP_DATA_DIR", "/Users/utkarsh/PycharmProjects/SURP2021")

    # Rebuild and save the reference table, then run every doctest in the notebook.
    initial = AllData()
    initial.load_path(project_dir)
    initial.load_raw_data()
    initial.process()
    initial.save_reference()
    run_tests(True)

[STATUS] Reference Saved
**********************************************************************
File "__main__", line 360, in __main__.GP2D.LOOCV_2D
Failed example:
    np.isclose(data.Y.T[0][20], 0.02809886951660512)
Expected:
    True
Got:
    False
**********************************************************************
File "__main__", line 362, in __main__.GP2D.LOOCV_2D
Failed example:
    np.isclose(data.Y.T[0][60],  -0.1008036214554312)
Expected:
    True
Got:
    False
**********************************************************************
File "__main__", line 364, in __main__.GP2D.LOOCV_2D
Failed example:
    np.isclose(data.looList[30], -0.00494856)[0][0]
Expected:
    True
Got:
    False
**********************************************************************
File "__main__", line 366, in __main__.GP2D.LOOCV_2D
Failed example:
    np.isclose(data.looList[50], 0.00248408)[0][0]
Expected:
    True
Got:
    False
*********************************************************************