# Imports and constants

In [None]:
import os
import glob
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from astroquery.gaia import Gaia
import warnings

In [None]:
from bokeh.io import output_notebook, reset_output, show, curdoc
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, Range1d
from bokeh.layouts import gridplot, layout, column
from bokeh.transform import linear_cmap, log_cmap
from bokeh.palettes import Spectral6, RdYlBu11, Spectral11, GnBu9

In [None]:
# Folder (relative to the notebook's working directory) holding every downloaded/derived file
data_path = os.path.normpath(os.path.join(os.curdir, 'data'))

In [None]:
# Create the data folder (and any missing parents); do nothing if it already exists
os.makedirs(data_path, exist_ok=True)

In [None]:
# Seed numpy's global (legacy) random generator for reproducibility
np.random.seed(seed=29)

# Data download <a id='data-download'></a>
Data is downloaded from the Gaia archive. The selected data are the first N objects returned by the archive that satisfy the following requirements:
 - No null values for teff, logg and metallicity
 - They have RVS spectra
 - They have BP_RP spectra
 
NOTE: If you have already executed the notebook, you can skip to [Section 3](#load-data).
 
The query used is shown below:

    SELECT TOP N
      gs.source_id, gs.ra, gs.dec, gs.phot_g_mean_mag+5-5*log10(gedr3_distance.r_med_geo) AS mg, 
      gs.teff_gspphot, gs.logg_gspphot, gs.mh_gspphot,
      xp.bp_n_relevant_bases, xp.rp_n_relevant_bases
    FROM gaiadr3.gaia_source AS gs
    JOIN gaiadr3.xp_summary AS xp
    ON gs.source_id = xp.source_id
    JOIN external.gaiaedr3_distance AS gedr3_distance
    ON gedr3_distance.source_id = gs.source_id
    WHERE gs.has_rvs = '1' AND
          gs.teff_gspphot IS NOT NULL AND
          gs.logg_gspphot IS NOT NULL AND
          gs.mh_gspphot IS NOT NULL 
 
 This section is adapted from the following tutorial; see it for more information: https://www.cosmos.esa.int/web/gaia-users/archive/datalink-products

In [None]:
# Log in to the Gaia archive. Called without arguments, so credentials are presumably
# requested interactively; skipping this cell is fine for the queries below.
Gaia.login() # Not necessary, but it has benefits when you execute long queries

In [None]:
# Number of sources to download. The DataLink request used below accepts at most 5000 ids;
# for more objects see:
# https://www.cosmos.esa.int/web/gaia-users/archive/datalink-products#Tutorial:--Download-DataLink-products-for-%3E5000-sources
N_OBJECTS = 20

In [None]:
# Never request more than the 5000 ids a single DataLink call accepts
N_OBJECTS = min(N_OBJECTS, 5000)

In [None]:
# ADQL query sent verbatim to the Gaia archive (the backslash-continued text is the exact query).
# It selects N_OBJECTS sources that:
#   - have an RVS spectrum (gs.has_rvs),
#   - have a row in gaiadr3.xp_summary (inner JOIN), i.e. BP/RP spectra,
#   - have non-null GSP-Phot teff, logg and [M/H].
# `mg` = phot_g_mean_mag + 5 - 5*log10(r_med_geo): an absolute G magnitude computed from the
# geometric distance of external.gaiaedr3_distance (assumes r_med_geo is in pc - TODO confirm).
# NOTE(review): TOP without ORDER BY lets the server choose which rows are returned, so the
# selected sources may differ between runs.
query = "SELECT TOP {} \
      gs.source_id, gs.ra, gs.dec, gs.phot_g_mean_mag+5-5*log10(gedr3_distance.r_med_geo) AS mg, \
      gs.teff_gspphot, gs.logg_gspphot, gs.mh_gspphot, \
      xp.bp_n_relevant_bases, xp.rp_n_relevant_bases \
    FROM gaiadr3.gaia_source AS gs \
    JOIN gaiadr3.xp_summary AS xp \
      ON gs.source_id = xp.source_id \
    JOIN external.gaiaedr3_distance AS gedr3_distance \
      ON gedr3_distance.source_id = gs.source_id \
    WHERE gs.has_rvs = '1' AND \
          gs.teff_gspphot IS NOT NULL AND \
          gs.logg_gspphot IS NOT NULL AND \
          gs.mh_gspphot IS NOT NULL".format(N_OBJECTS)

In [None]:
# Submit the query as an asynchronous archive job, then fetch its result table
job = Gaia.launch_job_async(query=query)
results = job.get_results()

In [None]:
# Convert the query result table (an astropy-style Table, judging by .to_pandas()) into a DataFrame
gaia_source_df = results.to_pandas()

In [None]:
# Sanity check: (rows, columns) of the query result and its first rows
print(gaia_source_df.shape)
gaia_source_df.head()

## Get RVS

In [None]:
# DataLink request settings.
# retrieval_type options: 'EPOCH_PHOTOMETRY', 'MCMC_GSPPHOT', 'MCMC_MSC', 'XP_SAMPLED',
# 'XP_CONTINUOUS', 'RVS', 'ALL'. 'load_data' accepts a single string, so to download several
# (but not all) product types call it once per type.
retrieval_type = 'RVS'
data_structure = 'COMBINED'  # Options are: 'INDIVIDUAL', 'COMBINED', 'RAW'
data_release = 'Gaia DR3'    # Options are: 'Gaia DR3' (default), 'Gaia DR2'

rvs_datalink = None
with warnings.catch_warnings():
    # 'load_data' may emit a lot of warnings; silence them only inside this block
    warnings.simplefilter("ignore")
    rvs_datalink = Gaia.load_data(
        ids=results['source_id'],
        data_release=data_release,
        retrieval_type=retrieval_type,
        data_structure=data_structure,
        format='csv',
        verbose=True,  # Be careful: many ids generate a large output
        output_file=None,
    )
    
    

In [None]:
print('The following Datalink products have been downloaded:')

# Product names returned by load_data, in sorted order
dl_keys = sorted(rvs_datalink.keys())

print()
for product_name in dl_keys:
    print(' * {}'.format(product_name))

In [None]:
# load_data returns a mapping of product name -> sequence of tables. With retrieval_type='RVS'
# there should be a single key holding a single table, so we take the first table of the first key.
# If retrieval_type were 'ALL' there would be several keys, and each would have to be handled separately.
rvs_df = rvs_datalink[dl_keys[0]][0].to_pandas()

In [None]:
# Raw RVS table: inspect its size and first rows
print(rvs_df.shape)
rvs_df.head()

## Get BP-RP Spectra

In [None]:
# Same DataLink request as for RVS, now for the continuous BP/RP (XP) spectra
retrieval_type = 'XP_CONTINUOUS'
data_structure = 'COMBINED'
data_release = 'Gaia DR3'

bp_rp_datalink = None
with warnings.catch_warnings():
    # Silence the (potentially many) warnings emitted by 'load_data'
    warnings.simplefilter("ignore")
    bp_rp_datalink = Gaia.load_data(
        ids=results['source_id'],
        data_release=data_release,
        retrieval_type=retrieval_type,
        data_structure=data_structure,
        format='csv',
        verbose=True,
        output_file=None,
    )
    
    

In [None]:
print('The following Datalink products have been downloaded:')

# Product names returned by load_data, in sorted order
dl_keys = sorted(bp_rp_datalink.keys())

print()
for product_name in dl_keys:
    print(' * {}'.format(product_name))

In [None]:
# As for RVS: take the first (expected only) table of the first (expected only) product key
bp_rp_df = bp_rp_datalink[dl_keys[0]][0].to_pandas() 

In [None]:
# Raw XP table: inspect its size and first rows
print(bp_rp_df.shape)
bp_rp_df.head()

---

## Join general info + bp-rp + rvs

### RVS
We reformat the RVS dataframe so that each row holds one RVS spectrum.

In [None]:
# Sorted, de-duplicated wavelength grid shared by all RVS spectra
rvs_wavelengths = np.unique(rvs_df['wavelength'].to_numpy())

In [None]:
# Stored as data/rvs_wavelengths.npy (np.save appends the extension)
np.save(file=os.path.join(data_path, 'rvs_wavelengths'), arr=rvs_wavelengths)

In [None]:
def reformat_rvs_df(rvs_df, verbose=0):
    '''
    Transform the dataframe created directly from the Gaia DataLink RVS spectra
    (one row per source and flux sample) into a dataframe with one RVS spectrum per row.

    Parameters:
    - rvs_df: dataframe with at least the columns 'source_id' and 'flux'. If it also has a
      'wavelength' column, the samples of every spectrum are ordered by it.
    - verbose: if truthy, print a "current/total" progress counter

    Returns:
    - dataframe with the columns 'source_id' (one row per source, ascending) and
      'rvs_flux' (list of flux values; missing values replaced by 0)
    '''
    # Order the samples by wavelength so each flux list lines up with the sorted grid that
    # np.unique() produces from the same column. Mergesort is stable, so this is a no-op
    # for input that is already ordered and keeps the original order for ties.
    if 'wavelength' in rvs_df.columns:
        rvs_df = rvs_df.sort_values(by='wavelength', kind='mergesort')

    # Missing values (in any column) become 0 so every spectrum is fully numeric
    _rvs_df = rvs_df.fillna(value=0)

    _grouped_rvs_df = _rvs_df.groupby(by='source_id')
    n_groups = len(_grouped_rvs_df)

    source_ids = []
    fluxes = []
    # groupby keeps the (wavelength-sorted) row order inside each group
    for i, (name, group) in enumerate(_grouped_rvs_df, start=1):
        if verbose:
            print('{}/{}'.format(i, n_groups), end='\r')
        source_ids.append(name)
        fluxes.append(group.flux.values.tolist())

    return pd.DataFrame(data={'source_id': source_ids, 'rvs_flux': fluxes})

In [None]:
# One RVS spectrum per row from now on (with a progress counter)
rvs_df = reformat_rvs_df(rvs_df, verbose=1)

In [None]:
# Reformatted RVS table: one row per source
print(rvs_df.shape)
rvs_df.head()

In [None]:
# Keep a copy of the per-source RVS spectra on disk
rvs_df.to_csv(path_or_buf=os.path.join(data_path, 'simplified_RVS.csv'))

### XP
We load the data into a pandas DataFrame and clean it up, keeping only the essential information.

In [None]:
def _parse_coefficients(value):
    '''Return the coefficients in *value* (a "(a,b,...)" string or a sequence) as a list of floats.'''
    if isinstance(value, str):
        # DataLink CSV cells look like "(1.0,2.0,...)"; tolerate surrounding spaces/brackets
        tokens = value.strip().strip('()[]').split(',')
        return [float(token) for token in tokens if token.strip()]
    # Already parsed (e.g. list or numpy array): just normalise to a list of floats
    return [float(number) for number in value]


def process_xp_df(xp_df):
    '''
    Receive a dataframe with the BP/RP coefficients in the Gaia DataLink format and parse
    the coefficient strings into lists of floats.

    Parameters:
    - xp_df: dataframe with at least the columns 'source_id', 'bp_coefficients' and
      'rp_coefficients'. Coefficients are expected as strings such as '(1.0,2.0,...)';
      cells that are already sequences are converted to lists of floats.

    Returns:
    - a new dataframe with only the columns 'source_id', 'bp_coefficients' and
      'rp_coefficients', the last two holding one list of floats per row
    '''
    _xp_df = xp_df[['source_id', 'bp_coefficients', 'rp_coefficients']]

    _bp_coeff_list = [_parse_coefficients(value) for value in _xp_df['bp_coefficients']]
    _rp_coeff_list = [_parse_coefficients(value) for value in _xp_df['rp_coefficients']]

    # Plain lists are stored (no 2d numpy conversion): rows may have different numbers of
    # coefficients, and lists round-trip through CSV + ast.literal_eval.
    return _xp_df.assign(bp_coefficients=_bp_coeff_list,
                         rp_coefficients=_rp_coeff_list)


In [None]:
# Parse the raw XP coefficient strings into lists of floats
xp_df = process_xp_df(bp_rp_df)

In [None]:
# Parsed XP coefficients: first rows
xp_df.head()

## Absolute G magnitude, Teff and metallicity distributions

In [None]:
fig, ax = plt.subplots(figsize=(10,6))
gaia_source_df.mg.hist(bins=40, ax=ax)

# `mg` is the absolute G magnitude computed in the query
# (phot_g_mean_mag + 5 - 5*log10(r_med_geo)), not the apparent G magnitude
ax.set_xlabel('$M_G$ (absolute G magnitude)')
ax.set_ylabel('Counts')
ax.set_title('Absolute G magnitude histogram')

In [None]:
# Distribution of the GSP-Phot effective temperatures
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 6))
gaia_source_df['teff_gspphot'].hist(ax=ax, bins=40)

ax.set_xlabel('Teff')
ax.set_ylabel('Counts')
ax.set_title('Teff Histogram')

In [None]:
# Distribution of the GSP-Phot metallicities [M/H]
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 6))
gaia_source_df['mh_gspphot'].hist(ax=ax, bins=40)

ax.set_xlabel('Metallicity')
ax.set_ylabel('Counts')
ax.set_title('Mh Histogram')

## Join the three dataframes

In [None]:
# Check the type of source_id column
# Check the type of source_id column: the three frames are joined on it below, so the dtypes should match
gaia_source_df.source_id.dtype == xp_df.source_id.dtype == rvs_df.source_id.dtype

In [None]:
# Align the source_id dtype of the DataLink-derived frames with the query result
xp_df = xp_df.astype(dtype={'source_id': 'int64'})
rvs_df = rvs_df.astype(dtype={'source_id': 'int64'})

In [None]:
# Check the type of source_id column
# Check the type of source_id column again after the cast
gaia_source_df.source_id.dtype == xp_df.source_id.dtype == rvs_df.source_id.dtype

In [None]:
# We set the source_id column as a index in the three dataframes
# We set the source_id column as a index in the three dataframes
# NOTE: done in place, so re-running this cell fails (source_id is no longer a column)
gaia_source_df.set_index(keys='source_id', inplace=True)
xp_df.set_index(keys='source_id', inplace=True)
rvs_df.set_index(keys='source_id', inplace=True)

In [None]:
# Left join on the shared source_id index: every queried source keeps its row, with XP and RVS columns appended
final_df = gaia_source_df.join([xp_df, rvs_df], how='left')

In [None]:
# Combined table: query columns + XP coefficients + RVS flux
final_df.head()

In [None]:
# Sort rows by the source_id index so the saved CSV and every array derived from final_df share one fixed order
final_df.sort_index(inplace=True)

In [None]:
# List-valued columns are written as their Python repr; the "Load data" section parses them back with ast.literal_eval
final_df.to_csv(path_or_buf=os.path.join(data_path, 'final_df.csv'))

# Load data <a id='load-data'></a>
Only necessary if you have not run Section 2 in this session but already have its data on disk. If you have never run [Section 2](#data-download), go back and execute it. If you have just run it, you can jump to [Section 4](#t-sne).

In [None]:
import ast

In [None]:
# Wavelength grid saved during the download section
rvs_wavelengths = np.load(file=os.path.join(data_path, 'rvs_wavelengths.npy'))

In [None]:
# Restore the combined table with source_id as the index again
final_df = pd.read_csv(filepath_or_buffer=os.path.join(data_path, 'final_df.csv'),
                       index_col='source_id')

In [None]:
# CSV stored the lists as text; literal_eval safely turns each cell back into a list
final_df['bp_coefficients'] = final_df['bp_coefficients'].map(ast.literal_eval)

In [None]:
# Same text -> list conversion for the RP coefficients
final_df['rp_coefficients'] = final_df['rp_coefficients'].map(ast.literal_eval)

In [None]:
# Same text -> list conversion for the RVS flux spectra
final_df['rvs_flux'] = final_df['rvs_flux'].map(ast.literal_eval)

In [None]:
# Reloaded combined table (list columns parsed back from text)
final_df.head()

# T-SNE

## Dataset creation <a id='t-sne'></a>

In [None]:
def create_dataset(df, n_coeffs=None):
    '''
    Returns the xp and rvs numpy arrays used for the T-SNE

    Parameters:
    - df: dataframe with the columns 'bp_coefficients', 'rp_coefficients' and 'rvs_flux',
      each cell a 1d sequence of numbers (list or numpy array)
    - n_coeffs: number of leading coefficients kept from each of the BP and RP spectra
      (None keeps all of them)

    Returns:
    - xp array: one row per source with the selected BP coefficients followed by the selected
      RP coefficients (2d when every row has the same length)
    - rvs array: one RVS flux spectrum per row (2d when every spectrum has the same length)
    '''
    # XP spectra. np.concatenate joins the two pieces for both list and numpy-array cells;
    # `bp + rp` would instead add numpy arrays element-wise.
    _xp_spectra = np.array([
        np.concatenate((np.asarray(bp)[:n_coeffs], np.asarray(rp)[:n_coeffs]))
        for bp, rp in zip(df['bp_coefficients'], df['rp_coefficients'])
    ])

    # RVS spectra
    _rvs_spectra = np.array([np.array(rvs) for rvs in df['rvs_flux'].values])

    return _xp_spectra, _rvs_spectra
        



In [None]:
# Keep only the first 3 BP and first 3 RP coefficients per source for the XP input
xp_X, rvs_X = create_dataset(final_df, n_coeffs=3)

In [None]:
# Persist the T-SNE inputs (np.save appends the .npy extension)
np.save(file=os.path.join(data_path, 'xp'), arr=xp_X)
np.save(file=os.path.join(data_path, 'rvs'), arr=rvs_X)

In [None]:
# One shape per line: XP input first, then RVS input
print(xp_X.shape, rvs_X.shape, sep='\n')

## T-SNE

In [None]:
# 2d T-SNE embedding of the XP coefficients. scikit-learn requires perplexity < n_samples,
# so it is capped for small downloads (e.g. the default N_OBJECTS = 20), which would otherwise raise.
xp_X_embedded = TSNE(perplexity=min(90, xp_X.shape[0] - 1), n_components=2, learning_rate='auto',
                     init='random', random_state=0, verbose=2).fit_transform(xp_X)

In [None]:
# 2d T-SNE embedding of the RVS spectra; perplexity capped below n_samples as scikit-learn requires
rvs_X_embedded = TSNE(perplexity=min(90, rvs_X.shape[0] - 1), n_components=2, learning_rate='auto',
                     init='random', random_state=0, verbose=2).fit_transform(rvs_X)

## Saving T-SNE data

In [None]:
# Persist the 2d embeddings (np.save appends the .npy extension)
np.save(file=os.path.join(data_path, 'rvs_embedded'), arr=rvs_X_embedded)
np.save(file=os.path.join(data_path, 'xp_embedded'), arr=xp_X_embedded)

# Visualizations combining basic information with T-SNE

In [None]:
# 11-colour diverging palette running from dark blue ('#313695') to dark red ('#a50026')
teff_palette = [
    '#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8', '#ffffbf',
    '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026',
]

In [None]:
# Reverse in place so the lowest Teff maps to red and the highest to blue.
# NOTE: re-running this cell flips the order back.
teff_palette[:] = teff_palette[::-1]

In [None]:
def get_magnitude_sizes(magnitude_range, min_size=1, max_size=10):
    '''
    Returns a marker size for each object based on its magnitude: the brightest object
    (smallest magnitude) gets max_size, the faintest gets min_size, linear in between.

    Parameters:
    - magnitude_range: 1d list/array with the object magnitudes
    - min_size, max_size: range of the returned sizes (defaults give the original 1-10 range)

    Returns:
    - 1d float numpy array, same length as magnitude_range. NaN magnitudes give NaN sizes;
      if all (non-NaN) magnitudes are equal every object gets min_size.
    '''
    _m_r = np.asarray(magnitude_range, dtype=float)
    if _m_r.size == 0:
        return np.empty(0)

    brightest = np.nanmin(_m_r)
    faintest = np.nanmax(_m_r)
    span = faintest - brightest
    if span == 0:
        # All objects equally bright: avoid 0/0 (NaN sizes) and use the lower bound
        return np.full(_m_r.shape, float(min_size))

    # Invert the magnitude and normalize between 0-1: 1 = brightest, 0 = faintest
    sizes = (faintest - _m_r) / span

    # Move the range to min_size-max_size
    return sizes * (max_size - min_size) + min_size

In [None]:
# Marker sizes scaled from the absolute G magnitude (bright stars -> big markers)
mg_sizes = get_magnitude_sizes(final_df['mg'].to_numpy())

In [None]:
# Use Bokeh's built-in 'dark_minimal' theme for the current document (affects figures shown after this cell)
curdoc().theme = 'dark_minimal'

In [None]:
def get_low_resolution_rvs_spectra(rvs_spectra, rvs_wavelengths, reduction_factor=2):
    '''
    Downsample RVS spectra together with their wavelength grid by keeping one sample out of
    every `reduction_factor`, so fluxes and wavelengths stay aligned.

    Parameters:
    - rvs_spectra: 2d numpy array, one spectrum per row
    - rvs_wavelengths: 1d numpy array/list with the wavelength of each column of rvs_spectra
    - reduction_factor: the size of the reduction 2=2x, 3=3x...

    Returns:
    - (downsampled spectra, downsampled wavelengths)
    '''
    every_nth = slice(None, None, reduction_factor)
    reduced_spectra = rvs_spectra[:, every_nth]
    reduced_wavelengths = rvs_wavelengths[every_nth]
    return reduced_spectra, reduced_wavelengths


In [None]:
# Keep one RVS sample out of 64 so the interactive multi_line plot stays light
low_rvs, low_rvs_wavelengths = get_low_resolution_rvs_spectra(
    rvs_spectra=np.array(final_df['rvs_flux'].tolist()),
    rvs_wavelengths=rvs_wavelengths,
    reduction_factor=64,
)

In [None]:
# One row per star. All glyphs below read from this single ColumnDataSource, which is what
# lets Bokeh link selections: a box_select in one figure should highlight the same stars in the others.
source = ColumnDataSource(data={'source_id':final_df.index.values,
                                'g':final_df.mg.values,
                                'sizes': mg_sizes, 
                                'teff':final_df.teff_gspphot.values,
                                'logg':final_df.logg_gspphot.values,
                                'mh':final_df.mh_gspphot.values,
                                'rvs':low_rvs.tolist(),
                                # multi_line needs one xs list per row: repeat the shared (downsampled) grid
                                'rvs_wavelengths':[low_rvs_wavelengths]*len(low_rvs.tolist()),
                                'bp_rp':xp_X.tolist(),
                                # x positions (coefficient indices) for the BP/RP multi_line
                                'n_coeffs':[list(range(len(coeffs))) for coeffs in xp_X],
                                'x_xp_tsne':xp_X_embedded[:,0],
                                'y_xp_tsne':xp_X_embedded[:,1],
                                'x_rvs_tsne':rvs_X_embedded[:,0],
                                'y_rvs_tsne':rvs_X_embedded[:,1]})

# Creating custom palette
# Colour mappers; only teff_mapper is used (via color_mapper). mh_mapper / logg_mapper are
# alternatives: assign one of them to color_mapper to recolour every figure.
teff_mapper = log_cmap(field_name='teff', palette=teff_palette, low=min(final_df.teff_gspphot), high=max(final_df.teff_gspphot))
mh_mapper = linear_cmap(field_name='mh', palette=Spectral11, low=min(final_df.mh_gspphot), high=max(final_df.mh_gspphot))
logg_mapper = linear_cmap(field_name='logg', palette=Spectral11, low=min(final_df.logg_gspphot), high=max(final_df.logg_gspphot))

color_mapper = teff_mapper

TOOLS = "pan,box_zoom,box_select,help,reset"


figures = list()

hr_figure = figure(title='HR-diagram', tools=TOOLS, output_backend="webgl")
figures.append(hr_figure)

bp_rp_figure = figure(title='BP-RP spectra', tools=TOOLS, output_backend="webgl") 
figures.append(bp_rp_figure)

rvs_figure = figure(title='RVS spectra', tools=TOOLS, output_backend="webgl")
figures.append(rvs_figure)

xp_tsne_figure = figure(title='XP T-SNE', tools=TOOLS, output_backend="webgl")
figures.append(xp_tsne_figure)

rvs_tsne_figure = figure(title='RVS T-SNE', tools=TOOLS, output_backend="webgl")
figures.append(rvs_tsne_figure)


# HR: Teff on x running from 15000 down to 2000 (hot stars on the left), absolute
# magnitude on a flipped y axis (bright stars at the top)
hr_figure.circle(x='teff', y='g', size='sizes', source=source, line_color='black', color=color_mapper)
hr_figure.y_range.flipped = True
hr_figure.x_range = Range1d(15000, 2000)

# BP/RP: coefficient value vs coefficient index, one line per star
bp_rp_figure.multi_line(xs='n_coeffs', ys='bp_rp', source=source, color=color_mapper)

# RVS: downsampled flux vs wavelength, one line per star
rvs_figure.multi_line(xs='rvs_wavelengths', ys='rvs', source=source, color=color_mapper)

# XP TSNE
xp_tsne_figure.circle(x='x_xp_tsne', y='y_xp_tsne', size='sizes', source=source, line_color='black', color=color_mapper)

# RVS TSNE
rvs_tsne_figure.circle(x='x_rvs_tsne', y='y_rvs_tsne', size='sizes', source=source, line_color='black', color=color_mapper)

# Set autohide parameter for every figure to True
for fig in figures:
    fig.toolbar.autohide = True


# Layout: HR diagram next to the stacked spectra on top, the two T-SNE maps below
xp_figure = column([bp_rp_figure, rvs_figure], sizing_mode='stretch_both')

p = layout([[hr_figure, xp_figure],
            [xp_tsne_figure, rvs_tsne_figure]], sizing_mode='stretch_both')

show(p)

# T-SNE 3D

In [None]:
import plotly.express as px

In [None]:
# 3d T-SNE embedding of the XP coefficients. scikit-learn requires perplexity < n_samples,
# so it is capped for small downloads, which would otherwise raise.
xp_X_embedded_3d = TSNE(perplexity=min(60, xp_X.shape[0] - 1), n_components=3, learning_rate='auto',
                     init='random', random_state=0, verbose=2, n_jobs=-1).fit_transform(xp_X)

In [None]:
# 3d T-SNE embedding of the RVS spectra; perplexity capped below n_samples as scikit-learn requires
rvs_X_embedded_3d = TSNE(perplexity=min(100, rvs_X.shape[0] - 1), n_components=3, learning_rate='auto',
                     init='random', random_state=0, verbose=2).fit_transform(rvs_X)

In [None]:
# Interactive 3d view of the XP embedding: colour = Teff, marker size grows with brightness
fig = px.scatter_3d(
    x=xp_X_embedded_3d[:, 0],
    y=xp_X_embedded_3d[:, 1],
    z=xp_X_embedded_3d[:, 2],
    color=final_df.teff_gspphot,
    color_continuous_scale='RdYlBu',
    size=3 * mg_sizes,
    size_max=max(3 * mg_sizes),
    opacity=1,
    template='plotly_dark',
    width=900,
    height=600,
)
fig.show()

In [None]:
# Interactive 3d view of the RVS embedding, styled like the XP one
fig = px.scatter_3d(
    x=rvs_X_embedded_3d[:, 0],
    y=rvs_X_embedded_3d[:, 1],
    z=rvs_X_embedded_3d[:, 2],
    color=final_df.teff_gspphot,
    color_continuous_scale='RdYlBu',
    size=3 * mg_sizes,
    size_max=max(3 * mg_sizes),
    opacity=1,
    template='plotly_dark',
    width=900,
    height=600,
)
fig.show()

---