In [1]:
import numpy as np
np.random.seed(42)
np.set_printoptions(suppress=True)
from scipy.optimize import least_squares, curve_fit
from matplotlib import pylab as plt
import pandas as pd
import glob
import time
from tqdm.notebook import tqdm
import os
import random

import warnings
from scipy.optimize import OptimizeWarning
warnings.simplefilter("error", OptimizeWarning)

# import seaborn as sns
# sns.set()

In [2]:
def plot_allbands_df(df, title=None, ax=plt):
    """Scatter-plot ``flux`` against ``mjd`` separately for each passband.

    Parameters
    ----------
    df : pandas.DataFrame
        Light-curve table; the ``band``, ``mjd`` and ``flux`` columns are read.
        Rows whose band is not one of u, g, r, i, z, Y are not drawn.
    title : str, optional
        Title to put on the plot; nothing is changed when None.
    ax : matplotlib Axes or the ``plt`` module, default ``plt``
        Target to draw on. When it is the ``plt`` module itself, the pyplot
        state machine (current figure/axes) is used.

    Returns
    -------
    The same ``ax`` object that was passed in (``plt`` by default).
    """
    # One fixed marker colour per band; dict order is the drawing order.
    band_colors = {
        "u": "violet",
        "g": "green",
        "r": "red",
        "i": "darkred",
        "z": "grey",
        "Y": "black",
    }
    for band, color in band_colors.items():
        band_df = df[df["band"] == band]
        ax.scatter(
            band_df["mjd"].to_numpy(),
            band_df["flux"].to_numpy(),
            s=1.2,
            label=band,
            color=color,
        )

    # Set the title once, after drawing. Axes objects expose set_title(),
    # whereas the pyplot module exposes title().
    if title is not None:
        if ax is plt:
            ax.title(title)
        else:
            ax.set_title(title)
    return ax

In [3]:
# glob returns files in a filesystem-dependent order. Sorting makes the
# processing order, and hence the key order of the per-class summary built
# from these files, reproducible across machines and runs.
filels = sorted(glob.glob("csv_data/*.csv"))

In [4]:
def give_detected_info(file, mismatchdict=None, plot=False, print_errratiolessthan5=False,
                       print_mismatchcount=False, window=100):
    """Tally, per class, light curves that have detections outside a window around peak SNR.

    The class label is parsed from the file name: the text after the last
    ``_`` and before the first ``.`` that follows it (e.g. ``obj_1_TDE.csv``
    gives ``TDE``). The ``mjd`` column is shifted so the first row sits at 0.
    The row with the largest |flux / fluxerr| is found, and every row whose
    mjd lies within ``window`` of that row is marked ``sid_bool = 1``
    (bounds inclusive). A light curve counts as a mismatch when at least one
    row has ``detected_bool - sid_bool == 1``, i.e. it is flagged as detected
    but lies outside the window.

    Parameters
    ----------
    file : str
        Path to a CSV with ``mjd``, ``flux``, ``fluxerr`` and
        ``detected_bool`` columns (plus ``band`` when ``plot`` is True).
    mismatchdict : dict, optional
        Running ``{class_label: mismatch_count}`` tally, updated in place.
        A fresh dict is created when omitted.
    plot : bool, default False
        Show the full light curve (top) and the in-window rows coloured by
        band (bottom).
    print_errratiolessthan5 : bool, default False
        Print the file name when the peak |flux / fluxerr| is below 5.
    print_mismatchcount : bool, default False
        Print how many mismatching rows were found for this file.
    window : float, default 100
        Half-width of the window around the peak, in the units of ``mjd``.

    Returns
    -------
    dict
        The updated ``mismatchdict``. The file's class is always present
        as a key.
    """
    # A literal {} default would be one dict shared by every call that omits
    # the argument, silently accumulating counts across calls.
    if mismatchdict is None:
        mismatchdict = {}

    title = os.path.basename(file)
    title_class = title.split("_")[-1].split(".")[0]
    mismatchdict.setdefault(title_class, 0)

    df = pd.read_csv(file)
    if df.empty:
        # A header-only file has no peak to anchor the window on.
        warnings.warn(f"{title} contains no rows; skipped")
        return mismatchdict

    # Use positional access (.iloc) so the lookups do not rely on a default RangeIndex.
    df["mjd"] = df["mjd"] - df["mjd"].iloc[0]
    snr = (df["flux"] / df["fluxerr"]).abs()
    imax = snr.argmax()

    if print_errratiolessthan5 and snr.max() < 5:
        print(title)

    mjdmax = df["mjd"].iloc[imax]
    df["sid_bool"] = df["mjd"].between(mjdmax - window, mjdmax + window).astype(int)
    # Rows flagged as detected but falling outside the window.
    mismatch_count = int(((df["detected_bool"] - df["sid_bool"]) == 1).sum())

    if mismatch_count > 0:
        mismatchdict[title_class] += 1
        if print_mismatchcount:
            print(f"Mismatch Count = {mismatch_count} for {title}")

    if plot:
        window_df = df[df["sid_bool"] == 1].reset_index(drop=True)
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
        # Top panel: the whole light curve, all bands together.
        ax1.scatter(df["mjd"], df["flux"], s=1.2)
        # Bottom panel: only the in-window rows, coloured by band.
        plot_allbands_df(window_df, ax=ax2, title=title)
        fig.suptitle(title)
        plt.show()
    return mismatchdict

In [5]:
# Build the per-class mismatch tally over every light-curve file.
# give_detected_info updates the dict it is given in place (and returns
# that same dict), so the result does not need to be reassigned.
mismatchdict = {}
for file in filels:
    give_detected_info(file, mismatchdict)

In [6]:
# Per-class count (class parsed from the file name) of light curves having at
# least one row with detected_bool - sid_bool == 1, i.e. a detection outside
# the window around the peak |flux / fluxerr| computed by give_detected_info.
mismatchdict

{'SNIa-91bg': 1,
 'TDE': 7,
 'SNII': 3,
 'SNIbc': 0,
 'SNIa': 3,
 'SLSN-I': 15,
 'SNIax': 1,
 'AGN': 59}