In [1]:
import csv
import os
import numpy as np
import pandas as pd
from astropy.io import fits
import tensorflow as tf
import matplotlib.pyplot as plt
from collections import Counter
from astropy.table import Table 
import tensorflow as tf
from scipy.interpolate import InterpolatedUnivariateSpline

In [2]:
filekey_lst = {
    1: (2018206045859, 120), 2: (2018234235059, 121),
    3: (2018263035959, 123), 4: (2018292075959, 124),
    5: (2018319095959, 125), 6: (2018349182500, 126),
    7: (2019006130736, 131), 8: (2019032160000, 136),
    9: (2019058134432, 139), 10: (2019085135100, 140),
    11: (2019112060037, 143), 12: (2019140104343, 144),
    13: (2019169103026, 146), 14: (2019198215352, 150),
    15: (2019226182529, 151), 16: (2019253231442, 152),
    17: (2019279210107, 161), 18: (2019306063752, 162),
    19: (2019331140908, 164), 20: (2019357164649, 165),
    21: (2020020091053, 167), 22: (2020049080258, 174),
    23: (2020078014623, 177), 24: (2020106103520, 180),
    25: (2020133194932, 182), 26: (2020160202036, 188),
    27: (2020186164531, 189), 28: (2020212050318, 190),
    29: (2020238165205, 193), 30: (2020266004630, 195),
    31: (2020294194027, 198), 32: (2020324010417, 200),
    33: (2020351194500, 203), 34: (2021014023720, 204),
    35: (2021039152502, 205), 36: (2021065132309, 207),
    37: (2021091135823, 208), 38: (2021118034608, 209),
    39: (2021146024351, 210), 40: (2021175071901, 211),
    41: (2021204101404, 212), 42: (2021232031932, 213),
    43: (2021258175143, 214), 44: (2021284114741, 215),
    45: (2021310001228, 216), 46: (2021336043614, 217),
    47: (2021364111932, 218), 48: (2022027120115, 219),
    49: (2022057073128, 221), 50: (2022085151738, 222),
    51: (2022112184951, 223), 52: (2022138205153, 224),
    53: (2022164095748, 226), 54: (2022190063128, 227),
    55: (2022217014003, 242), 56: (2022244194134, 243),
    57: (2022273165103, 245), 58: (2022302161335, 247),
    59: (2022330142927, 248), 60: (2022357055054, 249),
    61: (2023018032328, 250), 62: (2023043185947, 254),
    63: (2023069172124, 255), 64: (2023096110322, 257),
    65: (2023124020739, 259), 66: (2023153011303, 260),
    67: (2023181235917, 261), 68: (2023209231226, 262),
    69: (2023237165326, 264), 70: (2023263165758, 265),
    71: (2023289093419, 266), 72: (2023315124025, 267),
    73: (2023341045131, 268), 74: (2024003055635, 269),
    75: (2024030031500, 270), 76: (2024058030222, 271),
    77: (2024085201119, 272), 78: (2024114025118, 273),
    79: (2024142205832, 274), 80: (2024170053053, 275),
    81: (2024196212429, 276), 82: (2024223182411, 278),
    83: (2024249191853, 280), 84: (2024274222008, 281),
    85: (2024300212641, 282)
}
#以下函数返回指定的 TIC 编号的目标星在指定的 sector 中的光变曲线的文件名
def get_lc_file(sector, tic):
    timestamp, scid = filekey_lst[sector]
    return 'tess{:13d}-s{:04d}-{:016d}-{:04d}-s_lc.fits'.format(timestamp, sector, tic, scid)

In [3]:
'''
功能：读取已预处理的数据放入模型中预测

输入：扇区名，字符串形式
输出：预测的结果,向量形式,二维
      该扇区数据的tic名称列表

'''
def predict(sector):
    #读取数据
    lc = pd.read_csv('F:/tess/processedData/lc/' + sector + '.csv',header=None)
    GLS = pd.read_csv('F:/tess/processedData/GLS/' + sector + '.csv',header=None)
    #数据第一列存储lc的名称
    ticLst = lc.loc[:,0]
    #除第一列以外的数据点用来预测
    lcData = lc.loc[:,1:]
    GLSData = GLS.loc[:,1:]
    
    #转换格式，方便放入模型中
    lcData = list(np.array(lcData))
    lcData = tf.convert_to_tensor(lcData)

    lcData = tf.expand_dims(lcData, 2)
    GLSData = list(np.array(GLSData))
    GLSData = tf.convert_to_tensor(GLSData)
    GLSData = tf.expand_dims(GLSData, 2)
    #预测
    predArray = model.predict([lcData,GLSData])
    #predArray = model.predict(lcData)
    #保存预测出的各类分数
    predData = pd.DataFrame(predArray)

    return predData,ticLst

In [4]:
'''
功能：将模型预测出来的结果对应到各个类别

输入：predData,经过模型预测后得到的结果数据
      threshold,判定类别的阈值
输出：predLabel,预测的类别
      scoreLst,得到的分数

'''
def getLabel(predData,threshold):
    predLabel = []
    for i in range(len(predData)):
        maxone = np.argmax(predData[i])
        score = predData[i][maxone]
        if score >= threshold:
            label = label_lst[maxone]
        if score < threshold:
            label = 'NOTSURE'
        predLabel.append(label)
    return predLabel

In [5]:
#获取已知的EB数据
EBData = Table.read('./rawData/newclass_samples.dat', format='ascii.fixed_width_two_line')
allTic = set(EBData['TIC'])
#读取模型
model = tf.keras.saving.load_model("./model/model1.keras")

In [6]:
#二分类
def crossValidate(sec,th):
    #如果没有做过预测，则预测，得到预测分数表
    if(not os.path.exists('F:/tess/predict/predScore/' + sec + '.csv')):
        predData, nameLst = predict(sec)
        ticLst = [int(name.split("-")[2]) for name in nameLst]
        predScore = pd.concat([pd.DataFrame(ticLst),predData],axis= 1)
        predScore.columns=["tic","EBscore","OHTERSscore"]
        predScore = predScore.drop_duplicates(subset='tic',keep='first').reset_index(drop=True)
        predScore.to_csv('F:/tess/predict/predScore/' + sec + '.csv',sep=',',index=False)
    #如果预测过，直接读取已保存的预测分数表
    else:
        predScore = pd.read_csv('F:/tess/predict/predScore/' + sec + '.csv')
    #根据阈值做类别预测
    predEBTic = set(predScore[predScore['EBscore'] >= th]['tic'])
    trueEBTic = set(EBData[EBData["sector"] == int(sec.strip('s'))]['TIC'])
    crossEBTic = predEBTic & allTic
    
    print("--------------" + sec + "----------------")
    print("预测数量：{}".format(len(predEBTic)))
    print("已有数量：{}".format(len(trueEBTic)))
    print("召回数量：{}".format(len(crossEBTic)))
    # print("召回率：{:.4}".format(len(crossEBTic)/len(trueEBTic)))

    return predEBTic,trueEBTic,crossEBTic

In [7]:
sectorLst = ['s071','s072','s073','s074']
            # 's001','s002','s003','s004','s005','s006','s007','s008','s009','s010'
            # 's011','s012','s013','s014','s015','s016','s017','s018','s019','s020'
            # 's021','s022','s023','s024','s025','s026','s027','s028','s029','s030'
            # 's031','s032','s033','s034','s035','s036','s037','s038','s039','s040'
            # 's041','s042','s043','s044','s045','s046','s047','s048','s049','s050'
            # 's051','s052','s053','s054','s055','s056','s057','s058','s059','s060'
            # 's061','s062','s063','s064','s065','s066','s067','s068','s069','s070'
for sec in sectorLst:
    pred,true,cross = crossValidate(sec,0.90)

--------------s071----------------
预测数量：482
已有数量：0
召回数量：93
--------------s072----------------
预测数量：501
已有数量：0
召回数量：75
--------------s073----------------
预测数量：1029
已有数量：0
召回数量：317
--------------s074----------------
预测数量：1035
已有数量：0
召回数量：348


In [11]:
import matplotlib
matplotlib.use('Agg')

sectorLst = ['s074']
            # 's001','s002','s003','s004','s005','s006','s007','s008','s009','s010'
            # 's011','s012','s013','s014','s015','s016','s017','s018','s019','s020',
            # 's021','s022','s023','s024','s025','s026','s027','s028','s029','s030',
            # 's031','s032','s033','s034','s035','s036','s037','s038','s039','s040']
            # 's041','s042','s043','s044','s045','s046','s047','s048','s049','s050',
            # 's051','s052','s053','s054','s055','s056','s057','s058','s059','s060',
            # 's061','s062','s063','s064','s065','s066','s067','s068','s069','s070']
            # 's072','s073','s074'
for sec in sectorLst:
    pred,true,cross = crossValidate(sec,0.90)
    new = list(pred - true - cross)

    #画图
    for tic in new:
        name = get_lc_file(int(sec.strip('s')),tic)
        filename = 'F:/tess/lc/' + sec + '/' + name
        table = fits.getdata(filename)
        time = table['TIME']
        flux = table['PDCSAP_FLUX']
        q_lst = table['QUALITY']
        m = q_lst == 0
        time = time[m]
        flux = flux[m]
        m2 = ~np.isnan(flux)
        t = time[m2]
        f = flux[m2]

        #文件保存的文件夹位置
        filepath = 'F:/tess/predict/images/'+ sec
        if not os.path.exists(filepath):
            os.makedirs(filepath)
        #画图保存 
        fig = plt.figure(figsize = (40,10),dpi = 100)
        plt.ion()
        plt.plot(t,f,'o',color = 'C0',ms = 3,alpha = 1)
        plt.ioff()
        plt.savefig(filepath + '/' + str(tic) + '.png')

--------------s074----------------
预测数量：1035
已有数量：0
召回数量：348


  fig = plt.figure(figsize = (40,10),dpi = 100)


In [11]:
## 统计全部的召回率
allPredEBTic = set()
th = 0.90
sectorLst = ['s001','s002','s003','s004','s005','s006','s007','s008','s009','s010',
            's011','s012','s013','s014','s015','s016','s017','s018','s019','s020',
            's021','s022','s023','s024','s025','s026','s027','s028','s029','s030',
            's031','s032','s033','s034','s035','s036','s037','s038','s039','s040',
            's041','s042','s043','s044','s045','s046','s047','s048','s049','s050',
            's051','s052','s053','s054','s055','s056','s057','s058','s059','s060',
            's061','s062','s063','s064','s065','s066','s067','s068','s069','s070']

for sec in sectorLst:
    predScore = pd.read_csv('F:/tess/predict/predScore/' + sec + '.csv')
    predEBTic = set(predScore[predScore['EBscore'] >= th]['tic'])
    allPredEBTic = allPredEBTic | predEBTic
    
allCrossEBTic = allPredEBTic & allTic

print("预测数量：{}".format(len(allPredEBTic)))
print("已有数量：{}".format(len(allTic)))
print("召回数量：{}".format(len(allCrossEBTic)))
print("召回率：{}".format(len(allCrossEBTic) / len(allTic)))

预测数量：14445
已有数量：4225
召回数量：4187
召回率：0.9910059171597633
