In [1]:
# Move the kernel's working directory up one level exactly once per session:
# the `descended` sentinel makes re-running this cell a no-op, so it does not
# keep climbing directories.
if "descended" not in locals():
    descended = 1
    %cd ".."
    
import utils
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from lilith.lilith_utils import *

from wotan import flatten
import visualize as vis

from dataloading import loading as dl

from detection import bls_detection as bls_det
from detection import rnn_detection as rnn_det

/Users/Yke/Desktop/AI/Thesis/ESA/transit-detection-rnn


In [2]:
def read_dvdic(path):
    """Load and return the object pickled at ``path``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)

In [3]:
# Pipeline detection dictionaries: one top-level file (dv_all) plus one file per sector.
dv_all = read_dvdic("data/lilith/dv/dv_dic.pkl")
dv_sect = dict()
for i in range(1, 5):
    dv_sect[i] = read_dvdic(f"data/lilith/sector{i}/dv/dv_dic.pkl")

In [4]:
(len(dv_all), *(len(dv_sect[i]) for i in [1, 2, 3, 4]))

(1256, 498, 601, 592, 490)

In [6]:
BASE_PATH = "data/lilith"
SECTOR = {i: f"/sector{i}" for i in [1, 2, 3, 4]}  # use as: path = BASE_PATH + SECTOR[i]

# load ground-truth and sector sampleids
sampleids_sector = {i: sampleids_from_curl(BASE_PATH + SECTOR[i] + '/fits') for i in SECTOR}
sampleids = set()
for i in SECTOR:
    sampleids.update(sampleids_sector[i])
sampleids = np.array(list(sampleids))

# gt_sector: ground truth per sector; gt: merged over all sectors, one entry per sampleid.
# "eb" and "beb" are both read with get_eb_data, with backeb=0 and backeb=1 respectively.
gt_sector = {i: {obj: {} for obj in ["pl", "eb", "beb"]} for i in SECTOR}
gt = {obj: {"params": {}, "sampleids": []} for obj in ["pl", "eb", "beb"]}
for i in SECTOR:
    gt_path = BASE_PATH + SECTOR[i] + '/ground_truth'
    gt_sector[i]["pl"]["params"], gt_sector[i]["pl"]["sampleids"] = get_pl_data(gt_path, sampleids_sector[i])
    gt_sector[i]["eb"]["params"], gt_sector[i]["eb"]["sampleids"] = get_eb_data(gt_path, sampleids_sector[i], backeb=0)
    gt_sector[i]["beb"]["params"], gt_sector[i]["beb"]["sampleids"] = get_eb_data(gt_path, sampleids_sector[i], backeb=1)
    for obj in ["pl", "eb", "beb"]:
        for sampleid, obj_data in gt_sector[i][obj]["params"].items():
            # gt[obj]["params"] and gt[obj]["sampleids"] are always extended together, so
            # the dict answers the same membership question as scanning the list, in O(1).
            if sampleid in gt[obj]["params"]:
                # ground truth already saved from prev sector
                if obj_data != gt[obj]["params"][sampleid]:
                    print("WARNING: inconsistent ground-truth data")
            elif obj_data != {}:
                gt[obj]["params"][sampleid] = obj_data
                gt[obj]["sampleids"].append(sampleid)
for obj in ["pl", "eb", "beb"]:
    gt[obj]["sampleids"] = np.array(gt[obj]["sampleids"])
    for sampleid in sampleids:
        # Samples without this object type get an empty params dict.
        # The dict lookup replaces an O(n) scan of the sampleids array.
        gt[obj]["params"].setdefault(sampleid, {})



In [7]:
# select test samples: ids observed in at least two sectors (filled in below)
test_sampleids = {key: set() for key in ("all", "pl", "eb", "beb")}  # not to be used for training

def inter(a, b):
    """Return the sorted unique values present in both ``a`` and ``b``."""
    common = np.intersect1d(a, b, assume_unique=False)
    return common

# For every unordered pair of sectors (i1 < i2), gather the sample ids, and the
# per-class ground-truth ids, that occur in both sectors of the pair.
for i1 in SECTOR:
    for i2 in SECTOR:
        if i2 <= i1:
            continue
        s_ids = inter(sampleids_sector[i1], sampleids_sector[i2])
        pl_ids = inter(gt_sector[i1]["pl"]["sampleids"], gt_sector[i2]["pl"]["sampleids"])
        eb_ids = inter(gt_sector[i1]["eb"]["sampleids"], gt_sector[i2]["eb"]["sampleids"])
        beb_ids = inter(gt_sector[i1]["beb"]["sampleids"], gt_sector[i2]["beb"]["sampleids"])

        test_sampleids["all"].update(s_ids)
        test_sampleids["pl"].update(pl_ids)
        test_sampleids["eb"].update(eb_ids)
        test_sampleids["beb"].update(beb_ids)

In [8]:
# Planet test samples that do not also appear in the "eb" or "beb" test sets.
pl_ids_corrected = np.array([
    s_id for s_id in test_sampleids["pl"]
    if s_id not in test_sampleids["eb"] and s_id not in test_sampleids["beb"]
])

In [9]:
def get_trange(sector, base_path="data/lilith"):
    """Return ``(t_start, t_end)`` for a sector, read from its first test batch.

    Only the first light curve of batch ``00000-00249`` is inspected. The time
    grid is presumably shared by all targets of the sector (TODO confirm).

    NaN-ignoring min/max are used because the time arrays can contain NaNs
    (``retrieve_sector`` relies on ``np.nanmean`` for the same data). A raw first
    or last sample would otherwise turn the whole range into NaN. For a sorted,
    NaN-free array the result equals the original first/last sample.

    Args:
        sector: sector number used in the directory name.
        base_path: root directory holding the ``sector{N}`` folders.

    Returns:
        Tuple of the minimum and maximum valid time stamp.

    Note:
        Uses ``pickle.load``; only point this at trusted files.
    """
    with open(f"{base_path}/sector{sector}/raw_batches/test/00000-00249", "rb") as f:
        b = pickle.load(f)
    time = np.asarray(b["time"][0])
    return (np.nanmin(time), np.nanmax(time))
sector_range = {sector: get_trange(sector) for sector in range(1, 5)}  # sector -> (t_start, t_end)

In [10]:
def consec_sectors(sec_str, only_two=False):
    """Check whether the '1' flags in ``sec_str`` form one contiguous run.

    ``sec_str`` is presumably a per-sector observation mask such as "0110"
    (TODO confirm). Characters other than '0' and '1' are ignored.

    The previous version returned False on any '0' that followed a '1'. That
    wrongly rejected masks like "1100", where the run simply ends before the
    last sector. Now only a '1' that appears after the run has ended breaks
    contiguity.

    Args:
        sec_str: iterable of '0'/'1' characters.
        only_two: additionally require exactly two '1' flags.

    Returns:
        True if all '1' flags are contiguous (vacuously True when there are none,
        unless ``only_two`` is set), else False.
    """
    n_observed = 0
    run_ended = False
    for s in sec_str:
        if s == "1":
            if run_ended:
                return False  # a second, separate run of observed sectors
            n_observed += 1
        elif s == "0" and n_observed:
            run_ended = True
    if only_two and n_observed != 2:
        return False
    return True

def consec_from_data(times, only_two=False, max_gap=10):
    """Decide whether a target's per-sector light curves form one unbroken stretch.

    Args:
        times: sequence of per-sector time arrays, in observation order.
        only_two: additionally reject targets with more than two sectors.
        max_gap: largest allowed jump between successive valid time stamps,
            in the units of ``times`` (default 10, as before).

    Returns:
        False if ``times`` is empty, if any gap between successive non-NaN time
        stamps exceeds ``max_gap``, or if ``only_two`` is set and more than two
        sectors are given. True otherwise.
    """
    if len(times) == 0:
        return False  # np.concatenate would raise on an empty sequence
    if len(times) == 1:
        print("WARNING: test sample with only one sector")
    t_cat = np.concatenate(times)
    # Drop NaNs before differencing. np.max over a diff that contains NaN returns
    # NaN, and `NaN > max_gap` is False, so a real gap would go unnoticed.
    t_valid = t_cat[~np.isnan(t_cat)]
    # Fewer than two valid samples means there is no gap to measure (np.max of an
    # empty diff would raise).
    if t_valid.size > 1 and np.max(np.diff(t_valid)) > max_gap:
        return False
    if only_two and len(times) > 2:
        return False
    return True

def retrieve_sector(time):
    """Return the sector whose time range strictly contains the mean of ``time``.

    The midpoint is the NaN-ignoring mean of ``time``. Sectors are looked up in
    the notebook-level ``sector_range`` ({sector: (t_start, t_end)}). Returns
    None when no range contains the midpoint.
    """
    mid_time = np.nanmean(time)
    for sector, (t_start, t_end) in sector_range.items():
        if t_start < mid_time < t_end:
            return sector
    return None

In [11]:
# For every corrected planet sample, record whether its sectors are consecutive
# (consec). Collect samples that have no light-curve data (excluded).
consec = {}
excluded = []
pl_id_set = set(pl_ids_corrected.tolist())  # O(1) membership instead of scanning the array per sample
# The context manager closes the progress bar on success and on error (replaces bare except).
with tqdm(os.listdir("data/eval/lilith/processed_nn_basic")) as pbar:
    for fname in pbar:
        if fname.startswith("."):  # skip hidden files such as .DS_Store, as the RNN cell does
            continue
        with open("data/eval/lilith/processed_nn_basic/" + fname, "rb") as f:
            b = pickle.load(f)
        for s_id, data in b.items():
            if s_id not in pl_id_set:
                continue
            if len(data["time"]) == 0:
                excluded.append(s_id)
                continue
            # NOTE(review): the previous call `consec2_from_data` is not defined anywhere in
            # this notebook (NameError on a fresh kernel). It is assumed to be the
            # two-sector variant of consec_from_data; confirm.
            consec[s_id] = consec_from_data(data["time"], only_two=True)

  5%|▌         | 1/20 [00:00<00:18,  1.02it/s]



 70%|███████   | 14/20 [00:13<00:05,  1.16it/s]



100%|██████████| 20/20 [00:18<00:00,  1.06it/s]






In [12]:
# Remove samples without light-curve data (collected in `excluded`) in a single vectorised pass.
pl_ids_corrected = pl_ids_corrected[~np.isin(pl_ids_corrected, excluded)]

In [13]:
def fix_t0(t0, period, sectors, warnings=False, sector_ranges=None):
    """Return the first transit epoch strictly after the start of the first sector.

    ``t0`` is shifted by whole periods to the last epoch at or before the start
    of ``sectors[0]``. One period is then added.

    Args:
        t0: any transit epoch of the planet.
        period: orbital period in the same units as ``t0``; must be > 0.
        sectors: a sector number or a list of sector numbers. The first entry
            gives the start time, the last entry the end time.
        warnings: if truthy, print a warning when the returned epoch already lies
            after the end of the last sector.
        sector_ranges: optional {sector: (t_start, t_end)} mapping. Defaults to
            the notebook-level ``sector_range``.

    Returns:
        The smallest epoch ``t0 + k * period`` (k an integer) greater than the
        first sector's start. NaN inputs propagate to a NaN result.

    Raises:
        ValueError: if ``period`` is not positive. The previous stepping loops
            never terminated in that case.
    """
    import math

    ranges = sector_range if sector_ranges is None else sector_ranges
    sectors = sectors if isinstance(sectors, list) else [sectors]
    start = ranges[sectors[0]][0]
    end = ranges[sectors[-1]][1]
    if math.isnan(t0) or math.isnan(period):
        return t0 + period
    if period <= 0:
        raise ValueError(f"period must be positive, got {period!r}")

    # Jump straight to the last epoch <= start. This avoids stepping one period at a
    # time, which was O(distance / period) and accumulated rounding from repeated additions.
    tt = t0 + math.floor((start - t0) / period) * period
    if tt > start:  # floor() can land one period high through float rounding
        tt -= period

    if tt + period > end and warnings:
        print("WARNING: period too large to fix t0")
    return tt + period

def accept_planet(pl_data, sectors):
    """Return True if the planet has at least two transit epochs in the sector span.

    Counting starts at the first epoch after the start of the first sector (via
    ``fix_t0``). It advances by ``pl_data["orb_period"]`` while the epoch is before
    the end of the last sector. Gaps between sectors are not taken into account.

    Args:
        pl_data: ground-truth dict with at least "t0" and "orb_period".
        sectors: a sector number or a list of sector numbers (keys of sector_range).
    """
    sector_list = sectors if isinstance(sectors, list) else [sectors]
    t0 = pl_data["t0"]
    period = pl_data["orb_period"]
    epoch = fix_t0(t0, period, sector_list)
    stop = sector_range[sector_list[-1]][1]

    n_transits = 0
    # Two transits are enough to decide, so stop counting there.
    while epoch < stop and n_transits < 2:
        n_transits += 1
        epoch += period
    return n_transits >= 2

In [47]:
def correct_detection(detection, gt_params, sectors, per_frac=0.01, dur_factor=1):
    """Compare one detection against one ground-truth planet.

    The true t0 is first moved to the first epoch after the start of the first
    sector (``fix_t0``, with warnings on).

    NOTE: a later cell redefines ``correct_detection(params, detection, ...)`` with a
    different signature and a bool return. Once that cell runs, it shadows this one.

    Args:
        detection: dict with "period" and "t0".
        gt_params: ground-truth dict with "orb_period", "t0" and "duration".
        sectors: sector number or list of sector numbers passed to ``fix_t0``.
        per_frac: relative period tolerance (inclusive bounds).
        dur_factor: the t0 window is ``true_t0 +/- 0.5 * dur_factor * duration``
            (inclusive).

    Returns:
        ``(both_correct, per_correct, t0_correct)``.
    """
    period_true = gt_params["orb_period"]
    t0_true = fix_t0(gt_params["t0"], period_true, sectors, warnings=1)
    duration = gt_params["duration"]

    period_pred = detection["period"]
    t0_pred = detection["t0"]

    half_window = 0.5 * dur_factor * duration
    per_correct = (1 - per_frac) * period_true <= period_pred <= (1 + per_frac) * period_true
    t0_correct = t0_true - half_window <= t0_pred <= t0_true + half_window

    return (per_correct and t0_correct), per_correct, t0_correct

In [52]:
# Per sector, split the ground-truth planets of each corrected planet sample into
# targets (accept_planet is True) and rejected ones.
targets_sect = {i: {} for i in [1, 2, 3, 4]}
rejected_sect = {i: {} for i in [1, 2, 3, 4]}

for sector in [1, 2, 3, 4]:
    sector_pl_params = gt_sector[sector]["pl"]["params"]
    for s_id in pl_ids_corrected:
        if s_id not in sector_pl_params:
            continue
        targets_sect[sector][s_id] = {}
        rejected_sect[sector][s_id] = {}
        for pl_i, pl_data in sector_pl_params[s_id].items():
            bucket = targets_sect if accept_planet(pl_data, sector) else rejected_sect
            bucket[sector][s_id][pl_i] = pl_data

In [57]:
# Per sector, map every corrected planet sample with ground truth in that sector to its
# pipeline entry from dv_sect, or to {} when the pipeline has no entry for it.
pipeline_detections_sect = {i: {} for i in [1, 2, 3, 4]}
for sector in [1, 2, 3, 4]:
    sector_pl_params = gt_sector[sector]["pl"]["params"]
    for s_id in pl_ids_corrected:
        if s_id not in sector_pl_params:
            continue
        sector_dv = dv_sect[sector]
        pipeline_detections_sect[sector][s_id] = sector_dv[s_id] if s_id in sector_dv else {}

In [None]:
def correct_detection(params, detection, dur_factor=1, per_frac=0.01):
    """Return True if ``detection`` matches the single ground-truth planet ``params``.

    NOTE: this redefinition shadows the earlier
    ``correct_detection(detection, gt_params, sectors, ...)``. It has a different
    argument order, compares against the raw ground-truth t0 (no ``fix_t0``), and
    returns a bool instead of a tuple.

    Args:
        params: ground-truth dict with "t0", "duration" and the period stored as
            "period" or, as in the ground-truth dicts loaded in this notebook,
            "orb_period".
        detection: dict with "t0" and "period".
        dur_factor: t0 must lie strictly within ``t0 +/- 0.5 * dur_factor * duration``.
        per_frac: the period must lie strictly within ``(1 +/- per_frac) * period``.
    """
    # assumes sample (params) contains single planet
    t0_true = params["t0"]
    dur_true = params["duration"]
    # The ground-truth dicts built here use "orb_period". Accept both keys so they can
    # be passed directly instead of raising KeyError.
    per_true = params["period"] if "period" in params else params["orb_period"]

    t0_correct = (detection["t0"] > (t0_true - 0.5*dur_factor*dur_true)) and \
                 (detection["t0"] < (t0_true + 0.5*dur_factor*dur_true))
    per_correct = (1-per_frac)*per_true < detection["period"] and detection["period"] < (1+per_frac)*per_true
    correct = (t0_correct and per_correct)
    return correct

def evaluate_thresholds(detections, thresholds, targets=None):
    """Count true/false positives and false negatives at each score threshold.

    A detection with ``score >= thr`` is a TP if ``correct_detection`` matches it
    to one of its sample's planets, and an FP otherwise. Planets left unmatched
    count as FN.

    Args:
        detections: {sample_key: {score: detection}}. This layout is assumed from
            the ``.items()`` loop below; each detection must suit ``correct_detection``.
        thresholds: iterable of score thresholds.
        targets: {sample_key: {planet_key: planet_params}}. Defaults to the global
            ``meta`` as before, but ``meta`` is not defined in this notebook, so
            pass it explicitly. Every key of ``targets`` must exist in ``detections``.

    Returns:
        dict with keys "tp", "fp", "fn", "tn". Each value is an np.ndarray with one
        count per threshold. "tn" is always 0 (true negatives are unbounded here).

    Note:
        Several detections matching the same planet each count as a TP. The
        per-sample TP count can then exceed the number of planets; a warning is
        printed in that case.
    """
    if targets is None:
        targets = meta  # previous behaviour: look up the global `meta`
    snames = ["tp", "fp", "fn", "tn"]
    results = {sname: [] for sname in snames}
    # The context manager closes the progress bar on success and on error
    # (replaces the bare `except: pbar.close(); raise`).
    with tqdm(thresholds) as pbar:
        for thr in pbar:
            tp = fp = tn = fn = 0
            for i in targets:
                num_planets = len(targets[i])
                found_planets = 0
                for score, det in detections[i].items():
                    if score >= thr:
                        found_one = False
                        for pl_i in targets[i]:  # no planets -> no match -> counted as FP
                            if correct_detection(targets[i][pl_i], det):
                                found_one = True
                                found_planets += 1
                                tp += 1
                                break
                        if not found_one:
                            fp += 1
                if found_planets > num_planets:
                    print("WARNING: found more planets than possible")
                fn += (num_planets - found_planets)
                # true negative is always inf
            results["tp"].append(tp)
            results["fp"].append(fp)
            results["tn"].append(tn)
            results["fn"].append(fn)
    for sname in snames:
        results[sname] = np.array(results[sname])
    return results  # was `return result` (NameError)


def concatenate_sectors(detections):
    """Flatten {sector: {sample_id: value}} into {"<sector>-<sample_id>": value}.

    Keys keep the sector-then-sample iteration order of the input.
    """
    return {
        f"{sector}-{s_id}": per_sector[s_id]
        for sector, per_sector in detections.items()
        for s_id in per_sector
    }

In [51]:
# # individual sector detections

# pipeline_results = {}
# for sector in [1,2,3,4]:
#     sector_results = {"sampleids":[], "planets":[], "detected":[], "tp":[], "fn":[], "fp":[]}
#     pipeline_results[sector] = sector_results
#     for s_id in pl_ids_corrected:
#         if s_id in gt_sector[sector]["pl"]["params"]:
#             sector_results["sampleids"].append(s_id)
#             accepted_planets = []
#             detected_planets = []
#             tp = fp = fn = 0
#             for pl_i, pl_data in gt_sector[sector]["pl"]["params"][s_id].items():
#                 if accept_planet(pl_data, sector):
#                     detected = False
#                     accepted_planets.append(pl_i)
                    
#                     if s_id in dv_sect[sector]:
#                         if correct_detection(dv_sect[sector][s_id], pl_data, sector)[0]:
#                             detected_planets.append(pl_i)
#                             detected = True
#                             tp += 1
#                     if not detected:
#                         fn += 1
#             fp = 1 if ((s_id in dv_sect[sector]) and len(detected_planets)==0\
#                        and dv_sect[sector][s_id]["n_transit"]>=3) else 0
                        
#             sector_results["planets"].append(accepted_planets)
#             sector_results["detected"].append(detected_planets)
#             sector_results["tp"].append(tp), sector_results["fp"].append(fp), 
#             sector_results["fn"].append(fn)

In [None]:
# `pipeline_results` is only built by the commented-out per-sector evaluation cell
# above. Guard so that Restart & Run All does not stop here with a NameError.
if "pipeline_results" in globals():
    print("s  tp  fp  fn")
    for i in [1, 2, 3, 4]:
        print(i, "", sum(pipeline_results[i]["tp"]), sum(pipeline_results[i]["fp"]), sum(pipeline_results[i]["fn"]))
else:
    print("pipeline_results is not defined: run the per-sector pipeline evaluation cell first")

In [17]:
# load_path = "data/eval/lilith/processed_bls_12h_outlier"
# write_path = "results/bls_multi_3it_outlier_detections"

# fnames = [fnm for fnm in os.listdir(load_path) if not fnm.startswith(".")]
# utils.make_dir(write_path)
# pbar = tqdm(fnames)
# for fname in pbar:
#     try:
#         bls_detections = {i:{} for i in [1,2,3,4,0]}
#         batch = dl.load_data(load_path + "/"+ fname)
#         if batch is None:
#             continue
        
#         for sampleid in batch:
#             if sampleid in pl_ids_corrected:
#                 for i in range(len(batch[sampleid]["time"])):
#                     time = batch[sampleid]["time"][i]
#                     sector = retrieve_sector(time)
#                     flat = batch[sampleid]["flux"][i]
#                     detections = bls_det.algorithm(time, flat, num_iters=3)
#                     bls_detections[sector][sampleid] = detections

# #                 if consec[sampleid]:
# #                     time = np.concatenate(batch[sampleid]["time"])
# #                     flat = np.concatenate(batch[sampleid]["flux"])
# #                     detections = bls_det.algorithm(time, flat, min_transits=5, num_iters=3, freq_fac=6)
# #                     bls_detections[0][sampleid] = detections
#         with open(write_path + "/" +fname, "wb") as f:
#             pickle.dump(bls_detections, f) 
#     except:
#         pbar.close()
#         raise

In [69]:
# # combine bls_detections
# bls_detections_sect = {i:{s_id:{} for s_id in targets_sect[i]} for i in [1,2,3,4]}
# for fnm in os.listdir("results/bls_multi_3it_detections"):
#     with open("results/bls_multi_3it_detections/"+fnm, "rb") as f:
#         b = pickle.load(f)
#     for i in [1,2,3,4]:
#         bls_detections_sect[i] = {**bls_detections_sect[i], **b[i]}
# with open("results/bls_multi_3it_detections.pkl", "wb") as f:
#     pickle.dump(bls_detections_sect, f)

In [None]:
# Run the RNN detection algorithm on every light curve of every sample and store the
# result under the sector that retrieve_sector assigns to that light curve.
rnn_detections_sect = {i: {s_id: {} for s_id in targets_sect[i]} for i in [1, 2, 3, 4]}
# The context manager closes the progress bar on success and on error (replaces bare except).
with tqdm(os.listdir("data/eval/lilith/nn_basic_pts")) as pbar:
    for fnm in pbar:
        if fnm.startswith("."):
            continue
        with open("data/eval/lilith/nn_basic_pts/" + fnm, "rb") as f:
            b_pts = pickle.load(f)
        with open("data/eval/lilith/processed_nn_basic/" + fnm, "rb") as f:
            b = pickle.load(f)

        for s_id in b:
            for i in range(len(b[s_id]["time"])):
                sector = retrieve_sector(b[s_id]["time"][i])
                if sector is None:
                    # retrieve_sector found no sector range containing this light curve's
                    # mean time. Indexing with None would raise KeyError and abort this
                    # very long loop, so warn and skip instead.
                    tqdm.write(f"WARNING: no sector found for sample {s_id}, light curve {i}; skipped")
                    continue
                detected = rnn_det.algorithm1(b_pts[s_id][i].copy(), num_iters=3,
                                              min_transits=3, p_min=2, p_max=None, step_mult=2,
                                              smooth=True, peak_frac=2, show_steps=False)
                rnn_detections_sect[sector][s_id] = detected

 80%|████████  | 16/20 [1:16:16<18:25, 276.42s/it]

In [None]:
# TODO
rnn_detections_path = "results/rnn_multi_3it_detections.pkl"
with open(rnn_detections_path, "wb") as out_file:
    pickle.dump(rnn_detections_sect, out_file)