## __Pick up Gaia EDR3 Data__ ##

__You can obtain the required file from the following website:__
https://warwick.ac.uk/fac/sci/physics/research/astro/research/catalogues/gaiaedr3_wd_main.fits.gz  
(from the article $\textit{A catalogue of white dwarfs in Gaia EDR3}$)  

* __The following code creates a simplified *.csv that can be used in place of the full catalogue.__  
(The simplified file is written to a sub-folder, named 'Gaia' by default, next to your copy of gaiaedr3_wd_main.fits.gz.)

In [15]:
# the needed module import
# used to read the fits
from astropy.io import fits 

# used to desgin functions to simplify the code
from tqdm import tqdm, trange # not necessary, just a visible progress bar

# to create a new csv
import csv
import os
from pathlib import Path

In [17]:
# Folder holding the downloaded catalogue; Gaia_EDR3_Data_Split appends the
# file name itself when the path does not already end with it.
path = 'C:/Users/Administrator/Desktop/大创/Data'  # where gaiaedr3_wd_main.fits.gz was saved
# Catalogue columns copied into the simplified CSV ('WDJ_name' is forced to be first).
name_list = ['WDJ_name','ra','dec','Pwd','phot_g_mean_mag_corrected']
Pwd_limitation = 0.75  # rows are kept only when Pwd is strictly greater than this value
mag_limitation = [None, None]  # [lower, upper] limits on phot_g_mean_mag_corrected; None disables a limit

In [18]:
class InputError(Exception):
    """Raised when an argument value is outside what a function accepts."""


def Gaia_EDR3_Data_Split(path,
                        name_list,
                        Pwd_limitation,
                        mag_limitation,
                        next_folder = 'Gaia',
                        file_create = True,
                        return_data = False):
    """Select rows of gaiaedr3_wd_main.fits.gz by Pwd and corrected G magnitude.

    Parameters
    ----------
    path : str
        Either the folder that contains ``gaiaedr3_wd_main.fits.gz`` or the
        full path of that file.
    name_list : list of str
        Catalogue columns to extract. ``'WDJ_name'`` always ends up first in
        the output; the caller's list itself is never modified.
    Pwd_limitation : float
        Value in [0, 1); a row is kept when ``Pwd > Pwd_limitation``.
    mag_limitation : list or None
        ``[lower, upper]``; a row is kept when
        ``lower < phot_g_mean_mag_corrected <= upper``. A ``None`` entry
        disables that limit, and ``None`` instead of a list disables both.
    next_folder : str
        Name of the output folder, forwarded to Gaia_Simplified_File_Create.
    file_create : bool
        When true, write the selection to a CSV and return its path.
    return_data : bool
        When true (and ``file_create`` is false), return the selected rows.

    Returns
    -------
    str, list or None
        The CSV path, the list of rows, or ``None`` (also returned when the
        catalogue file is not found, after printing a message).

    Raises
    ------
    InputError
        If ``Pwd_limitation`` or ``mag_limitation`` is malformed.
    """
    if 0 > Pwd_limitation or Pwd_limitation >= 1:
        raise InputError ('Pwd_limitation should be in [0, 1)')

    # The error message promises that NoneType is accepted, so honour it.
    if mag_limitation is None:
        mag_limitation = [None, None]
    try:
        limit_count = len(mag_limitation)
    except TypeError:
        limit_count = -1
    if limit_count != 2:
        raise InputError ('mag_limitation should be a list with 2 element or NoneType')
    lower, upper = mag_limitation[0], mag_limitation[1]

    # Work on a copy so the caller's list is not reordered behind their back.
    columns = list(name_list)
    if 'WDJ_name' not in columns:
        columns.insert(0, 'WDJ_name')
    elif columns[0] != 'WDJ_name':
        k = columns.index('WDJ_name')
        columns[0], columns[k] = columns[k], columns[0]

    catalogue = 'gaiaedr3_wd_main.fits.gz'
    if not path:
        path = '.'  # an empty string means "current folder"
    if path.endswith('/'):
        path = path + catalogue
    elif Path(path).name != catalogue:
        path = path + '/' + catalogue

    try:
        fits_file = fits.open(path)
    except FileNotFoundError:
        print('Wrong path of gaiaedr3_wd_main.fits.gz')
        return None

    try:
        fits_file_num = fits_file[1].header['NAXIS2']   # total number of rows in the table
        fits_file_data = fits_file[1].data   # touch the data before closing so it is loaded
    finally:
        fits_file.close()

    Pwd = fits_file_data.field('Pwd')
    Mag = fits_file_data.field('phot_g_mean_mag_corrected')
    # Resolve every requested column once instead of once per row.
    column_data = [fits_file_data.field(column) for column in columns]

    data = []
    for j in trange(0 , fits_file_num):
        if not Pwd[j] > Pwd_limitation:
            continue
        if lower is not None and not lower < Mag[j]:
            continue
        if upper is not None and not Mag[j] <= upper:
            continue
        data.append([values[j] for values in column_data])

    if file_create:
        return Gaia_Simplified_File_Create(path, next_folder, [lower, upper], columns, data)
    if return_data:
        return data
    return None

In [21]:
def Gaia_Simplified_File_Create(path,
                                next_folder,
                                mag_limitation,
                                name_list,
                                data): #This function only can work with Gaia_EDR3_Data_Split Function
    """Write the selected rows to ``selected_data_<lower>_<upper>.csv``.

    ``path`` is the catalogue path built by Gaia_EDR3_Data_Split; the output
    folder is obtained by substituting ``next_folder`` for the catalogue file
    name. ``name_list`` becomes the header row and ``data`` the body.
    Returns the path of the CSV that was written.
    """
    out_dir = path.replace('gaiaedr3_wd_main.fits.gz' , next_folder)
    if os.path.exists(out_dir) is False:
        os.makedirs(out_dir)

    file_name = '/selected_data_{0}_{1}.csv'.format(mag_limitation[0] , mag_limitation[1])
    csv_path = out_dir + file_name

    # Open for writing with newline="" -- otherwise a blank line is inserted
    # between every two rows of data.
    with open(csv_path, mode="w", encoding="utf-8-sig", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(name_list)
        out.writerows(data)
    return csv_path

In [20]:
# Build the simplified CSV from the settings above. With the default
# file_create=True the call returns the path of the CSV it wrote.
Gaia_EDR3_Data_Split(path, name_list, Pwd_limitation, mag_limitation)

## __Cross-match with LAMOST by Radius Search__ ##

* __Only a *.csv file containing `ra` and `dec` columns is supported.__  
(The result file is written to the same folder as your input file.)

In [16]:
# the needed module import
import pandas as pd   
import os

# the needed module import
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.edge.options import Options
import time

import urllib as url

import re
from pathlib import Path
from tqdm import tqdm,trange

In [17]:
class InputError(Exception):
    """Raised when an argument value is outside what a function accepts."""


class HTTPError(Exception):
    """Raised when the LAMOST search page does not yield a usable query id."""


def Csv_File_Cross_Match(file_dir,
                         sec = 2.0,
                         txt_save = False):
    """Cross-match the ``ra``/``dec`` columns of a CSV against LAMOST DR8 v2.0.

    The coordinates are written to a temporary ``.txt`` upload file (one
    ``ra,dec,sec`` line per row), submitted through the LAMOST search form
    with a headless Edge browser, and the CSV result is downloaded.

    Parameters
    ----------
    file_dir : str
        Path of the input ``.csv``; it must contain ``ra`` and ``dec`` columns.
    sec : float
        Third field written on every upload line (presumably the search
        radius in arcseconds -- confirm against the LAMOST upload format).
    txt_save : bool
        Keep the temporary upload file when true.

    Returns
    -------
    str
        Path of the result file, ``<input stem>_cross_match.csv``.

    Raises
    ------
    InputError
        If ``file_dir`` does not end in ``.csv``.
    HTTPError
        If no query id can be extracted from the result page URL.
    """
    # ``import urllib`` alone does not guarantee that the sub-modules are
    # loaded, so import exactly what is needed here.
    from urllib.parse import urlparse
    from urllib.request import urlopen

    if not Path(file_dir).name.endswith('.csv'):
        raise InputError('Wrong file_dir is inputted')

    # Strip only the trailing extension; str.replace would also rewrite any
    # '.csv' that happens to occur in a folder name.
    stem = file_dir[:-len('.csv')]
    temp_txt = stem + '.txt'
    result_file = stem + '_cross_match.csv'

    # Parse the catalogue once and take both coordinate columns from it.
    coords = pd.read_csv(file_dir, usecols = ['ra', 'dec'])
    ra_values = coords['ra'].tolist()
    dec_values = coords['dec'].tolist()

    with open(temp_txt , mode='w') as work:
        work.writelines('{0},{1},{2}\n'.format(ra, dec, sec)
                        for ra, dec in zip(ra_values, dec_values))

    try:
        # Run the browser without showing a window.
        opt = Options()
        opt.add_argument('--headless')
        driver = webdriver.Edge(options = opt)
        try:
            driver.get('http://www.lamost.org/dr8/v2.0/search')
            # Upload the coordinate list.
            driver.find_element(By.XPATH,"//html/body/div[2]/form/div[2]/div[1]/div/table/tbody/tr[4]/td[3]/input").send_keys(temp_txt)
            # Select the proximity search.
            driver.find_element(By.XPATH,"//html/body/div[2]/form/div[2]/div[1]/div/table/tbody/tr[4]/td[1]/span/input").click()
            # Click "search".
            driver.find_element(By.XPATH,"//html/body/div[2]/form/div[2]/div[2]/input[1]").click()
            # Switch to the page opened by the search.
            driver.switch_to.window(driver.window_handles[-1])

            # The URL query carries the sqlid needed to export the matches.
            query = urlparse(driver.current_url).query
        finally:
            # quit() ends the whole session, even when a step above failed.
            driver.quit()
        time.sleep(3)
    finally:
        if not txt_save and os.path.exists(temp_txt):
            os.remove(temp_txt)

    sqlids = re.findall(r"\d+", query)
    if query == '' or not sqlids:
        raise HTTPError('the connection to LAMOST is failed')
    sqlid = sqlids[0]

    # Download the match result and store it directly under its final name,
    # overwriting a result left over from an earlier run.
    with urlopen('http://www.lamost.org/dr8/v2.0/sqlidall/{0}?output.fmt=csv&'.format(sqlid)) as response:
        data = response.read()

    with open(result_file , 'wb') as fh:
        fh.write(data)
    return result_file
        

In [19]:
# Cross-match the simplified Gaia CSV; the third field of every uploaded
# line is 2.0 and the temporary upload .txt is deleted afterwards.
Csv_File_Cross_Match('C:/Users/Administrator/Desktop/大创/Data/Gaia/selected_data_None_None.csv',
                         sec = 2.0,
                         txt_save = False)

## __Download Spectrum Data by Obsid__ ##

In [6]:
# necessary module
import urllib as url
from urllib.parse import urlparse

import threading

import pandas as pd
import re
import os
from tqdm import tqdm, trange

In [8]:
def Souce_Result_Match(cross_match_file, 
                       source_file, 
                       name = 'WDJ_name'):
    """Pair every cross-match row with the name of the source it came from.

    Parameters
    ----------
    cross_match_file : str
        '|'-separated cross-match result containing the columns
        ``combined_obsid`` and ``inputobjs_input_id``.
    source_file : str
        Comma-separated source catalogue that was uploaded for the match.
    name : str
        Column of ``source_file`` that holds the source identifier.

    Returns
    -------
    list
        ``[match_name, obsid]``: two lists of equal length, the source
        identifier and the obsid of each cross-match row.
    """
    # Parse the cross-match file once and take both columns from it.
    cross = pd.read_csv(cross_match_file, sep = '|',
                        usecols = ['combined_obsid', 'inputobjs_input_id'])
    obsid = cross['combined_obsid'].tolist()
    input_id = cross['inputobjs_input_id'].tolist()
    source_id = pd.read_csv(source_file, usecols = [name])[name].tolist()

    # input_id is used directly as a 0-based row position in source_file.
    # NOTE(review): confirm that LAMOST numbers the uploaded rows from 0.
    match_name = [source_id[position] for position in input_id]
    return [match_name, obsid]

In [52]:
def Download_Fits(path,
                  obsid):
    """Download LAMOST DR8 v2.0 spectra by obsid into the folder ``path``.

    Parameters
    ----------
    path : str
        Destination folder; created when missing.
    obsid : int or list
        A single obsid or a list/tuple of obsids. Every spectrum is saved
        under the FILENAME recorded in its primary header.
    """
    # exist_ok avoids the check-then-create race when several download
    # threads share a parent folder.
    os.makedirs(path, exist_ok=True)

    # Anything that is not a list/tuple (including numpy integers) is
    # treated as one obsid.
    obsids = obsid if isinstance(obsid, (list, tuple)) else [obsid]
    for one_obsid in obsids:
        link = 'http://www.lamost.org:80/dr8/v2.0/spectrum/fits/{0}?token='.format(one_obsid)
        # Open the remote file once (it used to be fetched twice) and close
        # it as soon as the local copy is written.
        with fits.open(link) as hdul:
            hdul.writeto(path + '/' + hdul[0].header['FILENAME'], output_verify='ignore')

In [53]:
def Cross_Match_Result_Fits_Download(match,
                                     path,
                                     thread_num = 50,
                                     thread_flag = True):
    """Download the spectrum of every cross-match row.

    Each spectrum is stored in ``<path>/<source id>/<obsid>/``.

    Parameters
    ----------
    match : list
        ``[source_ids, obsids]`` as returned by Souce_Result_Match.
    path : str
        Root folder for the downloads; created when missing.
    thread_num : int
        Maximum number of downloads running at the same time.
    thread_flag : bool
        When false, download sequentially instead of using threads.

    Raises
    ------
    InputError
        If ``match`` is malformed or ``thread_num`` is smaller than 1.

    Notes
    -----
    An exception raised inside a download thread does not propagate to the
    caller; the remaining downloads continue.
    """
    if len(match) != 2:
        raise InputError('the match should be a list with 2 lists of souce_id and obsid')
    if len(match[0]) != len(match[1]):
        raise InputError('the lenth of souce_id is not the same as obsid')
    thread_num = int(thread_num)
    if thread_num < 1:
        raise InputError('thread_num should be a positive integer')
    os.makedirs(path, exist_ok=True)

    # One (target folder, obsid) pair per cross-match row.
    jobs = [(path + '/' + str(source) + '/' + str(obs), int(obs))
            for source, obs in zip(match[0], match[1])]

    if not thread_flag:
        for k in trange(0, len(jobs)):
            Download_Fits(*jobs[k])
        return

    # Process the jobs in batches of at most thread_num threads; every batch
    # is joined before the next one starts. Each thread receives the obsid
    # that belongs to its own folder.
    for start in trange(0, len(jobs), thread_num):
        threads = [threading.Thread(target = Download_Fits, args = job)
                   for job in jobs[start : start + thread_num]]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

In [55]:
# NOTE(review): these files are read from ...\Data\temp\ while the earlier
# cells write to .../Data/Gaia/ -- confirm which folder is intended.
cross_match_file = r'C:\Users\Administrator\Desktop\大创\Data\temp\selected_data_None_None_cross_match.csv'
source_file = r'C:\Users\Administrator\Desktop\大创\Data\temp\selected_data_None_None.csv'
# NOTE(review): the simplified CSV is written with a 'WDJ_name' column;
# 'EDJ_name' looks like a typo -- confirm against the header of source_file.
match = Souce_Result_Match(cross_match_file, source_file, name = 'EDJ_name')

In [54]:
# Root folder for the downloaded spectra (this rebinds the notebook-level `path`).
path = r'C:\Users\Administrator\Desktop\大创\Data\Gaia\test'
# Download every matched spectrum, at most 50 threads at a time.
Cross_Match_Result_Fits_Download(match, path, thread_num = 50, thread_flag = True)

## __Spectral Line Determination through Average Method__ ##

In [1]:
import math
import os
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

import numpy as np
from astropy.io import fits

import scipy.interpolate as spi
from sklearn.preprocessing import MinMaxScaler
import cv2

from pathlib import Path
from tqdm import tqdm, trange

In [2]:
def Delta_cal(x,
              y,
              line,
              center_broading,
              peripheral_broading
             ):
    """Return the normalised depth of ``y`` around ``line``.

    The mean of ``y`` in the central window (within ``center_broading`` of
    ``line``) is subtracted from the mean of ``y`` in the two side windows
    (each ``peripheral_broading`` wide, directly outside the central window)
    and the difference is divided by ``max(y) - min(y)``. All window edges
    are exclusive. ``x`` and ``y`` are indexed with boolean masks, so they
    are expected to be array-like objects supporting that.
    """
    core_low = line - center_broading
    core_high = line + center_broading

    core = y[(core_low < x)&(x < core_high)]
    right_wing = y[(core_high < x)&(x < core_high + peripheral_broading)]
    left_wing = y[(core_low - peripheral_broading < x)&(x < core_low)]

    ave_c = sum(core) / len(core)
    wing_total = sum(right_wing) + sum(left_wing)
    wing_count = len(right_wing) + len(left_wing)
    ave_p = wing_total / wing_count

    return (ave_p - ave_c) / (max(y) - min(y))

In [30]:
def Determination_of_Spectral_Lines(x, 
                                    y,
                                    z,
                                    lines,
                                    delta,
                                    center_broading,
                                    peripheral_broading
                                   ):
    """Tell whether every listed line that falls inside the spectrum is present.

    Each wavelength in ``lines`` is divided by ``z + 1``. Lines outside
    ``[min(x), max(x)]`` are ignored. Returns True when at least one line is
    in range and Delta_cal is at least ``delta`` for every in-range line up
    to that point; returns False as soon as one in-range line falls short, or
    when no line is in range. ``lines`` itself is not modified.
    """
    # NOTE(review): rest wavelengths are normally multiplied by (1 + z) to
    # reach the observed frame; confirm that dividing is intended here.
    shifted = [wavelength / (z + 1) for wavelength in lines]
    if not shifted:
        return False

    x_low, x_high = min(x), max(x)
    found_in_range = False
    for wavelength in shifted:
        if not (x_low <= wavelength <= x_high):
            continue
        found_in_range = True
        depth = Delta_cal(x, y, wavelength, center_broading, peripheral_broading)
        if not (depth >= delta):
            return False
    return found_in_range

In [31]:
import re

# Optional sign, digits with an optional decimal point, optional exponent.
_NUMBER_PATTERN = re.compile(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$')


def is_number(string):
    """Return True when ``string`` is written like a decimal number.

    At least one digit must follow the optional decimal point, so ``'1.'``
    is rejected while ``'.5'`` and ``'1e-3'`` are accepted.
    """
    return _NUMBER_PATTERN.match(string) is not None

In [38]:
class InputError(Exception):
    """Raised when an argument has an unsupported type."""


def _as_line_list(value, message):
    """Return ``value`` as a list of wavelengths, or raise InputError(message)."""
    import numbers  # local import: not provided by the notebook's import cell

    if isinstance(value, (list, tuple)):
        return list(value)
    if isinstance(value, numbers.Real):
        return [value]
    raise InputError(message)


def File_Spectral_Lines_Check(file,
                              lines_included,
                              delta = 3.5e-06,
                              center_broading = 5,
                              peripheral_broading = 15,
                              lines_excluded = None,
                              sn = 0):
    """Check one spectrum file for required and forbidden spectral lines.

    The test is run on the raw flux and on a smoothed copy (cubic spline
    followed by a Gaussian blur); the two results have to agree.

    Parameters
    ----------
    file : str
        FITS file providing 'WAVELENGTH' and 'FLUX' in the first row of
        extension 1 and 'Z', 'SNRU', 'OBSID', 'FILENAME' in the primary header.
    lines_included : list or number
        Wavelength(s) that must be detected.
    delta, center_broading, peripheral_broading
        Forwarded to Determination_of_Spectral_Lines.
    lines_excluded : list, number or None
        Wavelength(s) that must not be detected.
    sn : number
        The spectrum is only examined when its 'SNRU' value exceeds ``sn``.

    Returns
    -------
    str or bool
        ``'<OBSID>/<FILENAME>'`` when raw and smoothed flux both show the
        included lines; ``'Please inspect visually'`` when the two disagree
        (for the excluded or the included lines); ``False`` otherwise.

    Raises
    ------
    InputError
        If ``lines_included`` or ``lines_excluded`` is neither a list nor a
        number (``lines_excluded`` may also be None).
    """
    # The previous check was inverted: it rejected plain numbers and let
    # every other type through.
    lines_in = _as_line_list(lines_included,
                             'lines_included should be a form among list and number')
    lines_ex = None
    if lines_excluded is not None:
        lines_ex = _as_line_list(lines_excluded,
                                 'lines_excluded should be a form among Nonetype, list and number')

    # The context manager closes the file even when a keyword is missing.
    with fits.open(file) as file_fits:
        header = file_fits[0].header
        record = file_fits[1].data[0]
        # Copy so the arrays stay valid after the file is closed.
        x = np.array(record.field('WAVELENGTH'))
        y = np.array(record.field('FLUX'))
    z = header['Z']
    sng = header['SNRU']
    if z == -9999:
        z = 0   # -9999 is replaced by 0 (presumably a "no redshift" marker)

    # Skip the smoothing entirely for spectra below the S/N threshold.
    if not sng > sn:
        return False

    ipo = spi.splrep(x , y , k = 3)
    Y = spi.splev(x, ipo)
    # NOTE(review): a 2-D 29x29 kernel is applied to a 1-D signal; confirm
    # that this is the intended smoothing.
    Y = cv2.GaussianBlur(src = Y, ksize = (29, 29), sigmaX = 5)

    def raw_and_smoothed(lines):
        raw = Determination_of_Spectral_Lines(x, y, z, lines, delta, center_broading, peripheral_broading)
        smoothed = Determination_of_Spectral_Lines(x, Y, z, lines, delta, center_broading, peripheral_broading)
        return raw, smoothed

    if lines_ex is not None:
        flag0, flag1 = raw_and_smoothed(lines_ex)
        if flag0 and flag1:
            return False
        if flag0 or flag1:
            return 'Please inspect visually'

    flag0, flag1 = raw_and_smoothed(lines_in)
    if flag0 and flag1:
        return str(header['OBSID']) + '/' + header['FILENAME']
    if flag0 or flag1:
        return 'Please inspect visually'
    return False

In [48]:
def Multiply_Files_Spectral_Lines_Check(path_dir, 
                                        lines_included,
                                        delta = 3.5e-06,
                                        center_broading = 5,
                                        peripheral_broading = 15,
                                        lines_excluded = None,
                                        sn = 0):
    """Run File_Spectral_Lines_Check on one file or on every file under a folder.

    ``path_dir`` may be a single file or a directory, which is walked
    recursively. The remaining arguments are forwarded unchanged.

    Returns ``[name, visual_inspect]``: the texts returned for matching files
    and the paths of the files reported as 'Please inspect visually'.
    """
    if os.path.isfile(path_dir):
        targets = [path_dir]
    else:
        targets = [os.path.join(folder, file_name)
                   for folder, _subdirs, file_names in os.walk(path_dir)
                   for file_name in file_names]

    matched = []
    needs_review = []
    for idx in trange(0,len(targets)):
        target = targets[idx]
        verdict = File_Spectral_Lines_Check(target,
                                            lines_included,
                                            delta = delta,
                                            center_broading = center_broading,
                                            peripheral_broading = peripheral_broading,
                                            lines_excluded = lines_excluded,
                                            sn = sn)
        if verdict == 'Please inspect visually':
            needs_review.append(target)
        elif verdict:
            matched.append(verdict)
    return [matched, needs_review]

In [39]:
# Line tables: wavelength -> label. The keys are passed as the wavelength
# lists to the line checks; the values are LaTeX-style labels.
# Hydrogen Balmer lines.
# NOTE(review): H-zeta is usually tabulated near 3889; confirm that 3891 is intended.
DA = {
    3970:r'H$\epsilon$',
    3891:r'H$\zeta$',
    4101.7:r'H$\delta$',
    4340.4:r'H$\gamma$',
    4861.3:r'H$\beta$',
    6562.7:r'H$\alpha$'}
# Neutral helium line.
DB = {
    4471 : 'He I'
};
# Ca II H and K lines.
DZ = {
    3933.7:'Ca II K',
    3968.5:'Ca II H',
};

In [49]:
# Check every file under the test folder for the DA wavelengths.
# a[0] holds the returned texts of matching files, a[1] the paths needing visual inspection.
a = Multiply_Files_Spectral_Lines_Check(r'C:\Users\Administrator\Desktop\大创\Data\temp\test', list(DA.keys()))

100%|██████████| 9270/9270 [14:42<00:00, 10.51it/s]


In [51]:
# Number of files for which the raw and the smoothed flux both showed the included lines.
len(a[0])

4976