In [2]:
import csv
import os
import random
from collections import Counter

import numpy as np
import pandas as pd
from astropy.table import Table
from sklearn.model_selection import train_test_split

In [2]:
# Load the EB sample catalogue; each row gives a TIC id, a sector and a class label
ebCatalogPath = './rawData/newclass_samples.dat'
EBData = Table.read(ebCatalogPath, format='ascii.fixed_width_two_line')
EBData

TIC,sector,class
int32,int32,str2
91961,11,EA
120016,11,EW
120016,38,EW
627436,5,EW
627436,32,EW
671564,5,EW
737546,5,EA
862763,34,EA
927554,34,EA
1045298,4,EB


# 获取训练样本-非EB

In [None]:
# Sector number -> (timestamp, scid): the two per-sector components that
# get_lc_file() inserts into a light-curve file name of the form
# "tess{timestamp}-s{sector:04d}-{tic:016d}-{scid:04d}-s_lc.fits".
# Covers sectors 1-64.
filekey_lst = {
    1: (2018206045859, 120), 2: (2018234235059, 121),
    3: (2018263035959, 123), 4: (2018292075959, 124),
    5: (2018319095959, 125), 6: (2018349182500, 126),
    7: (2019006130736, 131), 8: (2019032160000, 136),
    9: (2019058134432, 139), 10: (2019085135100, 140),
    11: (2019112060037, 143), 12: (2019140104343, 144),
    13: (2019169103026, 146), 14: (2019198215352, 150),
    15: (2019226182529, 151), 16: (2019253231442, 152),
    17: (2019279210107, 161), 18: (2019306063752, 162),
    19: (2019331140908, 164), 20: (2019357164649, 165),
    21: (2020020091053, 167), 22: (2020049080258, 174),
    23: (2020078014623, 177), 24: (2020106103520, 180),
    25: (2020133194932, 182), 26: (2020160202036, 188),
    27: (2020186164531, 189), 28: (2020212050318, 190),
    29: (2020238165205, 193), 30: (2020266004630, 195),
    31: (2020294194027, 198), 32: (2020324010417, 200),
    33: (2020351194500, 203), 34: (2021014023720, 204),
    35: (2021039152502, 205), 36: (2021065132309, 207),
    37: (2021091135823, 208), 38: (2021118034608, 209),
    39: (2021146024351, 210), 40: (2021175071901, 211),
    41: (2021204101404, 212), 42: (2021232031932, 213),
    43: (2021258175143, 214), 44: (2021284114741, 215),
    45: (2021310001228, 216), 46: (2021336043614, 217),
    47: (2021364111932, 218), 48: (2022027120115, 219),
    49: (2022057073128, 221), 50: (2022085151738, 222),
    51: (2022112184951, 223), 52: (2022138205153, 224),
    53: (2022164095748, 226), 54: (2022190063128, 227),
    55: (2022217014003, 242), 56: (2022244194134, 243),
    57: (2022273165103, 245), 58: (2022302161335, 247),
    59: (2022330142927, 248), 60: (2022357055054, 249),
    61: (2023018032328, 250), 62: (2023043185947, 254),
    63: (2023069172124, 255), 64: (2023096110322, 257)
}
def get_lc_file(sector, tic, filekeys=None):
    """Return the light-curve FITS file name of star `tic` observed in `sector`.

    The name follows the pattern
    ``tess<timestamp>-s<sector:04d>-<tic:016d>-<scid:04d>-s_lc.fits``.

    Parameters
    ----------
    sector : int
        Sector number; must be a key of `filekeys`.
    tic : int
        TIC id of the target star (zero-padded to 16 digits in the name).
    filekeys : dict, optional
        Mapping sector -> (timestamp, scid). Defaults to the module-level
        `filekey_lst`, so existing two-argument calls are unchanged.

    Raises
    ------
    KeyError
        If `sector` has no entry in `filekeys`.
    """
    if filekeys is None:
        filekeys = filekey_lst
    try:
        timestamp, scid = filekeys[sector]
    except KeyError:
        # Re-raise with a message that says what was missing, not just the bare key.
        raise KeyError('no file key known for sector {}'.format(sector)) from None
    return 'tess{:13d}-s{:04d}-{:016d}-{:04d}-s_lc.fits'.format(timestamp, sector, tic, scid)

In [None]:
# Map each sector number to the light-curve file names of the EBs observed in that sector
ebNameDict = {}
for row in EBData:
    lcName = get_lc_file(row['sector'], row['TIC'])
    ebNameDict.setdefault(row['sector'], []).append(lcName)

In [None]:
def dropEB(secData, ebName, errors='raise'):
    """Remove the EB rows from one sector's table and drop its name column.

    Parameters
    ----------
    secData : pandas.DataFrame
        Sector table read with ``header=None``; column 0 holds the light-curve
        file name that identifies each row.
    ebName : list of str
        File names (as produced by ``get_lc_file``) of the rows to remove.
    errors : {'raise', 'ignore'}, optional
        Forwarded to ``DataFrame.drop``. With 'raise' (the default, matching the
        original behaviour) a name absent from `secData` raises KeyError;
        with 'ignore' such names are skipped.

    Returns
    -------
    pandas.DataFrame
        The remaining rows without column 0, re-indexed 0..n-1.
    """
    # Index by the name column so rows can be dropped by label; drop=True then
    # discards that name index instead of restoring it as a column.
    return (secData.set_index(0)
            .drop(ebName, errors=errors)
            .reset_index(drop=True))

In [None]:
# Build the per-sector non-EB ("OTHERS") pool.
# For each sector: remove the known EB light curves, randomly keep `secNum` of the
# remaining stars, and save their lc and GLS rows under 非EB/. The GLS rows are
# selected by the same star names as the sampled lc rows, so row i of both saved
# files is the same star (the training-set cells below read them in parallel).
PROCESSED_DIR = 'F:/tess/processedData'
# s001 ... s064: every sector the training-set cells read back from 非EB/
sectorLst = ['s{:03d}'.format(i) for i in range(1, 65)]
secNum = 1000  # number of non-EB stars kept per sector

os.makedirs(PROCESSED_DIR + '/非EB/lc', exist_ok=True)
os.makedirs(PROCESSED_DIR + '/非EB/GLS', exist_ok=True)

for sec in sectorLst:
    # Column 0 of both files holds the light-curve file name of each star.
    lc = pd.read_csv(PROCESSED_DIR + '/lc/' + sec + '.csv', header=None)
    GLS = pd.read_csv(PROCESSED_DIR + '/GLS/' + sec + '.csv', header=None)

    # EB light-curve names in this sector; a sector without EBs has nothing to remove.
    ebName = ebNameDict.get(int(sec[1:]), [])

    # Sample from the non-EB lc rows, then drop the name column.
    picked = lc[~lc[0].isin(ebName)].sample(n=secNum, random_state=1)
    newlc = picked.drop(columns=0)
    # Take the GLS rows of exactly the same stars, in the same order. .loc raises
    # KeyError if a sampled star has no GLS row, instead of silently misaligning.
    newGLS = GLS.set_index(0).loc[picked[0]].reset_index(drop=True)

    # Save the non-EB samples (no header, no index column)
    newlc.to_csv(PROCESSED_DIR + '/非EB/lc/' + sec + '.csv', header=False, index=False)
    newGLS.to_csv(PROCESSED_DIR + '/非EB/GLS/' + sec + '.csv', header=False, index=False)


# 构造训练集和测试集

## lc

In [5]:
# Assemble the light-curve (lc) dataset.
# Label convention: 0 = eclipsing binary (EA/EB/EW), 1 = any other source.
ebClasses = {'EA', 'EB', 'EW'}
varClasses = ['EA', 'EB', 'EW', 'BY', 'DSCT', 'GDOR', 'M', 'RRAB', 'RRC', 'ELL']

frames = []
allLabel = []  # all labels, row-aligned with allData

# Labelled variable-star samples: EB classes -> 0, other variable classes -> 1
for label in varClasses:
    data = pd.read_csv('./processedData_4000/' + label + '.csv', header=None)
    frames.append(data)
    allLabel.extend([0 if label in ebClasses else 1] * len(data))

# OTHERS: 430 random non-EB stars from each sector s001 ... s064 (label 1).
# n and random_state are presumably kept identical to the GLS cell so both datasets
# stay row-aligned — keep them in sync.
sectorLst = ['s{:03d}'.format(i) for i in range(1, 65)]
for sec in sectorLst:
    data = pd.read_csv('F:/tess/processedData/非EB/lc/' + sec + '.csv', header=None).sample(n=430, random_state=22)
    frames.append(data)
    allLabel.extend([1] * len(data))

# Concatenate once rather than growing a DataFrame inside the loop (quadratic copying).
allData = pd.concat(frames)
fluxData = np.array(allData)
fluxLabel = np.array(allLabel)


## gls

In [6]:
# Assemble the GLS dataset in the same row order as the lc dataset.
# Label convention: 0 = eclipsing binary (EA/EB/EW), 1 = any other source.
ebClasses = {'EA', 'EB', 'EW'}
varClasses = ['EA', 'EB', 'EW', 'BY', 'DSCT', 'GDOR', 'M', 'RRAB', 'RRC', 'ELL']

frames = []
allLabel = []  # all labels, row-aligned with allData

# Labelled variable-star samples: EB classes -> 0, other variable classes -> 1
for label in varClasses:
    data = pd.read_csv('./GLSdata_1000/' + label + '.csv', header=None)
    frames.append(data)
    allLabel.extend([0 if label in ebClasses else 1] * len(data))

# OTHERS: 430 random non-EB stars from each sector s001 ... s064 (label 1).
# n and random_state are presumably kept identical to the lc cell so both datasets
# stay row-aligned — keep them in sync.
sectorLst = ['s{:03d}'.format(i) for i in range(1, 65)]
for sec in sectorLst:
    data = pd.read_csv('F:/tess/processedData/非EB/GLS/' + sec + '.csv', header=None).sample(n=430, random_state=22)
    frames.append(data)
    allLabel.extend([1] * len(data))

# Concatenate once rather than growing a DataFrame inside the loop (quadratic copying).
allData = pd.concat(frames)
GLSData = np.array(allData)
GLSLabel = np.array(allLabel)

## split

In [12]:
# Class counts over the whole dataset
print('所有数据中的类别数量情况：{}'.format(Counter(fluxLabel)))

# The lc and GLS datasets are built in the same row order, so they must have the
# same length and identical labels; verify this before splitting.
assert len(fluxData) == len(GLSData), 'lc and GLS datasets have different lengths'
assert np.array_equal(fluxLabel, GLSLabel), 'lc and GLS labels are not row-aligned'

# Split all four arrays in ONE call so the lc and GLS train/test sets are guaranteed
# to use identical row indices (same stars, same order). 80% train / 20% test.
(f_trainX, f_testX, f_trainY, f_testY,
 g_trainX, g_testX, g_trainY, g_testY) = train_test_split(
    fluxData, fluxLabel, GLSData, GLSLabel, test_size=0.2, random_state=10)

# Class counts in the training and test sets
print('训练集中的类别数量情况：{}'.format(Counter(f_trainY)))
print('测试集中的类别数量情况：{}'.format(Counter(f_testY)))

所有数据中的类别数量情况：Counter({1: 36810, 0: 36651})
训练集中的类别数量情况：Counter({1: 29441, 0: 29327})
测试集中的类别数量情况：Counter({1: 7369, 0: 7324})


## save

In [7]:
def save(data, fileName, outDir='./train_test_data/'):
    """Write `data` to ``<outDir>/<fileName>.csv`` with no header row and no index column.

    Parameters
    ----------
    data : array-like
        Anything ``pd.DataFrame`` accepts (here: 1-D label arrays or 2-D feature arrays).
    fileName : str
        Output file name without the ``.csv`` extension.
    outDir : str, optional
        Output directory, created if it does not exist. Defaults to './train_test_data/'.
    """
    os.makedirs(outDir, exist_ok=True)
    pd.DataFrame(data).to_csv(os.path.join(outDir, fileName + '.csv'), index=False, header=False)

In [8]:
# Persist the full, unsplit datasets (features X and labels Y for lc and GLS)
fullSets = [(fluxData, 'lc_X'), (fluxLabel, 'lc_Y'), (GLSData, 'gls_X'), (GLSLabel, 'gls_Y')]
for arr, outName in fullSets:
    save(arr, outName)

In [38]:
# Persist the train/test splits of both datasets
splitSets = {
    'lc_trainX': f_trainX, 'lc_testX': f_testX, 'lc_trainY': f_trainY, 'lc_testY': f_testY,
    'gls_trainX': g_trainX, 'gls_testX': g_testX, 'gls_trainY': g_trainY, 'gls_testY': g_testY,
}
for outName, arr in splitSets.items():
    save(arr, outName)
