<a href="https://colab.research.google.com/github/potohodnica/magistrska/blob/main/GALAH_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Load libraries

# System libraries
import os, urllib
import glob

# Astro libraries
import astropy.io.fits as pyfits

# PyTorch Specific libraries
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable

# Data manipulation and visualisation specific libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce

# For splitting the data into Train and Test set
from sklearn.model_selection import train_test_split

# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Output "cuda:0": a GPU is being used. Output "cpu": everything still runs on the CPU.

cuda:0


In [20]:
#@title Settings

# Directory to work in (keep the trailing slash: file paths in this notebook
# are built by appending the file name directly to this string)
working_directory = '/content/'
os.chdir(working_directory)

# True: plot the normalised spectra, False: plot the reduced spectra
normalised = False

# True: save the plot as PNG file
savefig = True

# True: print diagnostic messages
printmsg = False

# CCD number -> letter passed as FILT when requesting a spectrum
ccd_dict = {1: "B", 2: "G", 3: "R", 4: "I"}

# How many spectra to sample from the binaries list (n_bin) and from the
# list of all types (n_all). Use -1 to take every available spectrum.
n_bin = 0
n_all = 2

In [4]:
#@title Download sobject_ids

def download_sobject_ids(bin_or_all):
  """
  Download the sobject_id catalogue into the working directory.

  bin_or_all: suffix of the local file name; the file is stored as
              working_directory + 'galah_' + bin_or_all + '.tsv'.

  Nothing is returned. A failed download is not raised; it is only reported
  (when printmsg is set), so callers must cope with a missing file.
  """
  # Imported here so urllib.request is guaranteed to be loaded; a plain
  # "import urllib" does not load the submodule by itself.
  import urllib.request

  # NOTE(review): the binaries catalogue is fetched for every value of
  # bin_or_all, so galah_all.tsv ends up as a copy of it. Confirm whether a
  # separate URL is intended for 'all'.
  link = 'https://raw.githubusercontent.com/potohodnica/magistrska/main/galah_binaries.tsv'
  target = working_directory + 'galah_' + bin_or_all + '.tsv'
  try:
    urllib.request.urlretrieve(link, target)
  except Exception:
    # Still best effort, but unlike a bare "except:" this lets
    # KeyboardInterrupt / SystemExit through so the cell can be stopped.
    if printmsg:
      print('Download error osubject_ids.')

In [5]:
#@title Sort and merge sobject_ids

def sort_sobject_ids(bin_or_all, n):
  """
  Load the sobject_ids stored in galah_<bin_or_all>.tsv.

  bin_or_all: suffix of the catalogue file name in the working directory.
  n:          number of rows to draw at random (fixed seed); -1 keeps all rows.

  Returns a DataFrame with the single column 'sobject_id'.
  """
  catalogue_path = working_directory + "galah_" + bin_or_all + ".tsv"
  ids = pd.read_csv(
      catalogue_path,
      sep=";",
      skiprows=40,
      usecols=['GALAH'],
      on_bad_lines='skip',
  )
  # Rows labelled 0 and 1 are discarded (presumably the units and separator
  # lines that follow the column header - confirm against the file).
  ids = ids.drop([0, 1]).rename(columns={'GALAH': 'sobject_id'})
  if n == -1:
    return ids
  return ids.sample(n=n, random_state=42)

def merge_sobject_ids():
  """
  Build the table of sobject_ids to work with, flagged as binary or not.

  Downloads the two catalogues if they are not in the working directory yet,
  samples n_bin ids from the binaries list and n_all ids from the list of all
  types, and outer-merges both samples.

  Returns a DataFrame with the columns 'sobject_id' and 'bin_tf', where
  bin_tf is 1.0 for ids that occur in the binaries sample and 0.0 otherwise.
  """
  # download_sobject_ids(x) writes galah_<x>.tsv and sort_sobject_ids(x) reads
  # that same name, so the existence check has to use it too (checking for
  # 'galah_binaries.tsv' never matched and re-downloaded on every call).
  for which in ('bin', 'all'):
    if not os.path.isfile(working_directory + 'galah_' + which + '.tsv'):
      download_sobject_ids(which)

  df_bin = sort_sobject_ids('bin', n_bin)
  df_all = sort_sobject_ids('all', n_all)
  df_merged = pd.merge(df_all, df_bin, how='outer', indicator=True)

  # 'left_only' = only in the all-types sample -> 0.
  # 'both' and 'right_only' both come from the binaries sample -> 1
  # ('right_only' rows used to be left as NaN).
  # Kept as float so the column dtype matches what it was before.
  df_merged['bin_tf'] = (df_merged['_merge'] != 'left_only').astype(float)
  return df_merged.drop(columns=['_merge'])

In [6]:
#@title Download spectra

def download_spectra(sobject_id,ccd):
    """
    Try to download the specific spectrum from Datacentral

    sobject_id: id of the object whose spectrum is requested
    ccd:        CCD number; must be a key of ccd_dict

    Returns a one-element list with the path of the saved FITS file, or an
    empty list if the download failed (a message is printed when printmsg
    is set).
    """
    # Imported here so urllib.request is guaranteed to be loaded; a plain
    # "import urllib" does not load the submodule by itself.
    import urllib.request

    file_stem = str(sobject_id) + str(ccd)
    fits_path = working_directory + file_stem + '.fits'
    try:
        # ccd_dict[ccd] stays inside the try so an unknown CCD number is
        # still reported as "not available" instead of raising KeyError.
        link = ('https://datacentral.org.au/vo/slink/links?ID=' + str(sobject_id)
                + '&DR=galah_dr3&IDX=0&FILT=' + ccd_dict[ccd] + '&RESPONSEFORMAT=fits')
        urllib.request.urlretrieve(link, fits_path)
    except Exception:
        # Still best effort, but unlike a bare "except:" this lets
        # KeyboardInterrupt / SystemExit through so a long run can be stopped.
        if printmsg:
            print('FITS ' + file_stem + ' not available')
        return []
    return [fits_path]

In [7]:
#@title Read spectra

def _wavelength_grid(header):
    """
    Build the linear wavelength grid described by the CRVAL1, CDELT1, NAXIS1
    and CRPIX1 cards of a FITS header.
    """
    start_wavelength = header["CRVAL1"]
    dispersion       = header["CDELT1"]
    nr_pixels        = header["NAXIS1"]
    reference_pixel  = header["CRPIX1"]
    if reference_pixel == 0:
        reference_pixel = 1
    # NOTE(review): the double minus makes this "+ reference_pixel + 1".
    # The arithmetic is kept exactly as it was; the usual FITS linear solution
    # would subtract CRPIX1 instead - confirm which one is intended.
    return (np.arange(0,nr_pixels)--reference_pixel+1)*dispersion+start_wavelength


def read_spectra(sobject_id):
    """
    Read in all available CCDs and give back a dictionary
    Download them if not already in working directory

    For every CCD n in 1..4 the dictionary may hold
      wave_red_n / sob_red_n / uob_red_n     wavelength, flux, flux error (reduced)
      wave_norm_n / sob_norm_n / uob_norm_n  the same for the normalised spectrum
    If no FITS file could be obtained for a CCD, all six entries are empty
    lists. If a file lacks an extension, the entries that depend on it are
    not set.

    The entries without the CCD suffix are the four CCDs concatenated in
    order, with nothing contributed by a CCD that lacks the data.
    'wave_red' and 'sob_red' are always present; the other combined entries
    are present only if at least one CCD provided the required extension.
    """
    ccds = [1,2,3,4]
    per_ccd_keys = ['wave_red_', 'wave_norm_', 'sob_red_', 'sob_norm_', 'uob_red_', 'uob_norm_']

    spectrum = dict()

    # Track availability over ALL CCDs. Reading the flags of whichever CCD
    # happened to be opened last raised NameError when no file existed at all
    # and KeyError when the CCDs differed in their extensions.
    any_ext1 = False
    any_ext4 = False

    for each_ccd in ccds:
        ccd = str(each_ccd)

        # Check if FITS file already available in working directory,
        # if not, try to download
        fits_files = glob.glob(working_directory+str(sobject_id)+ccd+'.fits')
        if fits_files == []:
            fits_files = download_spectra(sobject_id,each_ccd)

        if fits_files == []:
            for key in per_ccd_keys:
                spectrum[key+ccd] = []
            continue

        # "with" closes the file even if one of the reads below raises.
        # Everything stored in the dictionary is a copy, so it stays valid
        # after the file has been closed.
        with pyfits.open(fits_files[0]) as fits:

            # Extension 0: Reduced spectrum
            # Extension 1: Relative error spectrum
            # Extension 4: Normalised spectrum, NB: cut for CCD4

            if len(fits) == 5:
                ext1 = True
                ext4 = True
            elif len(fits) == 2:
                ext1 = True
                ext4 = False
                if printmsg:
                    print('Normalised spectrum missing in',str(each_ccd),'ccd.')
            else:
                ext1 = False
                ext4 = False
                if printmsg:
                    print('Relative error spectrum and normalised spectrum missing in',str(each_ccd),'ccd.')
            any_ext1 = any_ext1 or ext1
            any_ext4 = any_ext4 or ext4

            # Wavelength grid of the reduced (and, if present, normalised) spectrum
            spectrum['wave_red_'+ccd] = _wavelength_grid(fits[0].header)
            if ext4:
                spectrum['wave_norm_'+ccd] = _wavelength_grid(fits[4].header)

            # Extract flux and flux error of reduced spectrum.
            # The flux is converted to the native byte order for Pandas use ----> https://stackoverflow.com/questions/30283836/creating-pandas-dataframe-from-numpy-array-leads-to-strange-errors
            # (astype replaces byteswap().newbyteorder(): ndarray.newbyteorder
            # no longer exists in NumPy 2.0; the values are unchanged).
            flux = np.array(fits[0].data)
            spectrum['sob_red_'+ccd] = flux.astype(flux.dtype.newbyteorder('='))
            if ext1:
                spectrum['uob_red_'+ccd] = np.array(fits[0].data * fits[1].data)

            if ext4 and ext1:
                # Extract flux and flux error of normalised spectrum
                spectrum['sob_norm_'+ccd] = np.array(fits[4].data)
                if each_ccd != 4:
                    spectrum['uob_norm_'+ccd] = np.array(fits[4].data * fits[1].data)
                else:
                    # for normalised error of CCD4, only used appropriate parts of error spectrum
                    spectrum['uob_norm_4'] = np.array(fits[4].data * (fits[1].data)[-len(spectrum['sob_norm_4']):])

    def joined(prefix):
        # A CCD without this entry contributes nothing to the combined array
        return np.concatenate([spectrum.get(prefix+str(each_ccd), []) for each_ccd in ccds])

    spectrum['wave_red'] = joined('wave_red_')
    spectrum['sob_red'] = joined('sob_red_')
    if any_ext4:
        spectrum['sob_norm'] = joined('sob_norm_')
        spectrum['wave_norm'] = joined('wave_norm_')
    if any_ext1:
        spectrum['uob_red'] = joined('uob_red_')
    if any_ext1 and any_ext4:
        spectrum['uob_norm'] = joined('uob_norm_')

    return spectrum

In [30]:
#@title Create training set

def create_training_set():
  """
  Collect the reduced flux of every selected object into one array.

  Returns (X, y):
    X  float array of shape (number of sobject_ids, 4096, 4); the last axis
       holds CCD 1..4 in that order. A CCD with fewer than 4096 flux values
       (or none at all) is filled up with NaN.
    y  array of the sobject_ids, in the row order of X.

  Raises ValueError if a CCD delivers more than 4096 flux values.
  """
  n_pixels = 4096
  ccds = (1, 2, 3, 4)

  df_sobject_ids = merge_sobject_ids()
  sobject_ids = df_sobject_ids["sobject_id"].tolist()

  # One row per object actually selected. The former fixed size of 2 raised
  # IndexError for more objects and left uninitialised rows for fewer.
  # Starting from NaN makes the NaN padding implicit.
  X = np.full((len(sobject_ids), n_pixels, len(ccds)), np.nan)

  for row, sobject_id in enumerate(sobject_ids):
      spectrum = read_spectra(sobject_id)
      for column, each_ccd in enumerate(ccds):
          flux = np.asarray(spectrum['sob_red_'+str(each_ccd)], dtype=float)
          X[row, :flux.size, column] = flux

  # NOTE(review): y holds the sobject_ids, not the bin_tf flag produced by
  # merge_sobject_ids - confirm that this is the intended target.
  y = df_sobject_ids.sobject_id.values
  return X, y

In [31]:
X, y = create_training_set()

print(X)

[[[0.91046986        nan 0.97166464 0.95877877]
  [0.91004311        nan 0.97940986 0.95867275]
  [0.90906717        nan 0.98719891 0.95458164]
  ...
  [0.95334746        nan 1.00603439 1.01045392]
  [0.96801829        nan 0.99656965 0.98196131]
  [0.97017019        nan 0.97361207 0.97745693]]

 [[0.91206494        nan 0.99545968 0.86882504]
  [0.97498449        nan 0.99786732 0.86768742]
  [0.9892108         nan 1.00298626 0.91121545]
  ...
  [1.06529619        nan 1.02180527 0.97278031]
  [1.05334701        nan 1.0163711  0.97777313]
  [1.09046395        nan 1.00340976 0.9523951 ]]]
