In [None]:
import numpy as np
import matplotlib.pyplot as plt

from astropy import stats
from astropy import constants as const
from astropy import units as u
import astropy.time
from astropy.time import Time
import dateutil.parser
from astropy.coordinates import SkyCoord

import scipy
from scipy import stats

from uncertainties import ufloat
from uncertainties.umath import *
from uncertainties import unumpy as unp
from uncertainties import ufloat_fromstr

import math
import pandas as pd
import pyphot
import pyphot.ezunits.pint as pint
import scipy.constants as const

import extinction #https://extinction.readthedocs.io/en/latest/
from extinction import ccm89, remove, apply

import glob
from astropy.io import fits
from astropy.wcs import WCS


In [None]:
def mjday(day):
    """
    Convert an observation date string such as '20171130' into a Modified Julian Date.

    The string is parsed with ``dateutil.parser.parse`` and converted with
    ``astropy.time.Time`` (astropy's default UTC scale applies to the naive
    datetime that dateutil returns).

    Returns
    -------
    float
        The MJD (i.e. JD - 2400000.5) of the given date.
    """
    # Time.mjd is computed directly from astropy's two-part JD, avoiding a
    # manual subtraction of the 2400000.5 offset.
    return astropy.time.Time(dateutil.parser.parse(day)).mjd

def fluxdens_to_mag(flux, f0):
    """
    Return the magnitude corresponding to flux density ``flux`` for zero point ``f0``.

    Computes ``-2.5 * log10(flux / f0)`` using ``uncertainties.unumpy.log10``,
    so ufloat/uarray inputs carry their uncertainties through. This is the
    inverse of ``mag_to_fluxdens``.
    """
    ratio = flux / f0
    return unp.log10(ratio) * -2.5

def mag_to_fluxdens(mag, f0):
    """
    Return the flux density corresponding to magnitude ``mag`` for zero point ``f0``.

    Computes ``f0 * 10**(-mag / 2.5)``. It is plain arithmetic, so it also
    applies element-wise to array-like inputs (e.g. the pandas Series of
    ufloats passed by ``WISE_Data.bin_data``).
    """
    scale = 10 ** (-mag / 2.5)
    return f0 * scale


In [2]:
# WISE data class: holds one galaxy's (NEO)WISE IRSA table as a pandas DataFrame.
# Methods: position and photometry diagnostics, filtering, binning, plotting and writing.


class WISE_Data:
    """
    A class used to contain, process and analyse WISE data tables (i.e. IRSA .tbl files).

    Typical workflow: ``filter_data()`` -> ``bin_data()`` -> ``plot_data()`` / ``write()``.

    Attributes
    ----------
    source : str
        Name of source
    datatable : pandas.DataFrame
        All columns read from the IRSA table, plus a ``sep`` column holding each
        detection's separation (arcsec) from the median position
    allowed_sep : float
        The separation (arcsec) used by the filter_data method to determine good/bad data
    data : pandas.DataFrame
        Columns ``mjd, w1mag, w1sig, w2mag, w2sig`` selected by filter_data
        (or every row of datatable when bin_data runs without prior filtering)
    binned_data : pandas.DataFrame
        Per-bin means produced by bin_data
    binned : str
        'yes' once the bin_data method has completed, otherwise 'no'
    filtered : str
        'yes' once the filter_data method has been run, otherwise 'no'
    f0_wise_3_4 / f0_wise_4_6 : float
        W1/W2 zero points from pyphot (``Vega_zero_Jy`` multiplied by 1e8)
    baddata : str
        'unk' until data are selected; then 'yes' if the selection is empty, else 'no'
    """
    def __init__(self, file, allowed_sep=1.5, source='galname?'):
        """
        Read an IRSA (NEO)WISE table and compute each detection's positional offset.

        Parameters
        ----------
        file : str
            Path to the IRSA data table ('|'-delimited column header)
        allowed_sep : float
            Maximum separation (arcsec) from the median position kept by filter_data
        source : str
            Source name; used for plot titles, output file names and image lookup

        Raises
        ------
        ValueError
            If the file contains no line starting with '|'.
        """
        neowise_header = []
        skiprows = None
        with open(file) as fh:
            for idx, line in enumerate(fh):
                if line.startswith('|'):
                    # The first '|' line holds the column names; it and the next
                    # three lines are skipped when reading the data rows.
                    skiprows = idx + 4
                    neowise_header = [col.strip() for col in line.split('|')[1:-1]]
                    break
        if skiprows is None:
            raise ValueError(f"{file}: no '|'-delimited column header line found")

        # Read in data in Pandas dataframe
        neowise_read_df = pd.read_fwf(file, skiprows=skiprows, header=None, names=neowise_header)
        # Separation (arcsec) of every detection from the median position
        pos_median = SkyCoord(np.median(neowise_read_df['ra']) * u.deg,
                              np.median(neowise_read_df['dec']) * u.deg, frame='icrs')
        neowise_read_df['sep'] = pos_median.separation(
            SkyCoord(neowise_read_df['ra'] * u.deg, neowise_read_df['dec'] * u.deg,
                     frame='icrs')).arcsec
        self.datatable = neowise_read_df
        self.allowed_sep = allowed_sep
        self.filtered = 'no'
        self.binned = 'no'
        lib = pyphot.get_library()
        # NOTE(review): Vega zero points in Jy scaled by 1e8, as in the original
        # analysis; confirm the intended flux unit of the scaled value.
        self.f0_wise_3_4 = lib['WISE_RSR_W1'].Vega_zero_Jy.magnitude * 1e8
        self.f0_wise_4_6 = lib['WISE_RSR_W2'].Vega_zero_Jy.magnitude * 1e8
        self.source = source
        self.baddata = 'unk'
        #self.description = "This shape has not been described yet"
        #self.author = "Erik Kool and Tom Reynolds"

    def position_diag(self, cut=None, contrast=(-200, 500), overlay='yes'):
        """
        Plot every detection's position on the source's FITS image.

        Detections farther than ``allowed_sep`` from the median position are
        drawn in red, the rest in blue.

        Parameters
        ----------
        cut : [[r1, r2], [c1, c2]], optional
            Crop the image (and its WCS) to rows r1:r2 and columns c1:c2,
            i.e. numpy (row, column) order
        contrast : (lower, upper)
            Colour-scale limits passed to imshow
        overlay : str
            Currently unused

        Raises
        ------
        FileNotFoundError
            If no image in ./Data/gal_images/ matches ``self.source``.
        """
        source_image = None
        for filename in glob.glob('./Data/gal_images/*.fits'):
            # NOTE(review): assumes paths like ./Data/gal_images/<name>_<x>_<y>.fits
            # (the '_' in 'gal_images' makes [-3] == 'images/<name>'); confirm naming.
            gal_name = filename.split('_')[-3].split('/')[1]
            if gal_name == self.source:
                source_image = filename  # last match wins
        if source_image is None:
            raise FileNotFoundError(
                f"No image for source {self.source!r} found in ./Data/gal_images/")

        with fits.open(source_image) as hdu:
            wcs = WCS(hdu[0].header)
            image_data = hdu[0].data

        if cut is not None:
            wcs = wcs[cut[0][0]:cut[0][1], cut[1][0]:cut[1][1]]
            image_data = image_data[cut[0][0]:cut[0][1], cut[1][0]:cut[1][1]]

        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(1, 1, 1, projection=wcs)
        ax.imshow(image_data, clim=contrast, cmap='gray_r')

        ra = self.datatable['ra'].to_numpy()
        dec = self.datatable['dec'].to_numpy()
        offset = (self.datatable['sep'] > self.allowed_sep).to_numpy()
        world = ax.get_transform('world')
        # One call per category instead of one Line2D per detection.
        for mask, color in ((~offset, 'blue'), (offset, 'red')):
            ax.plot(ra[mask], dec[mask], color=color, linestyle='none',
                    marker='o', markersize=5, transform=world)

        ax.set_xlabel('ra')
        ax.set_ylabel('dec')
        plt.show()

    def phot_diag(self):
        """
        Plot W2 against W1 magnitude, highlighting detections hit by each quality flag.

        Colours: red = qual_frame == 0, orange = 'X' or 'U' in ph_qual,
        black = separation > allowed_sep, purple = |saa_sep| < 5,
        green = moon_masked != 0, pink = non-zero cc_flags.
        A detection may be drawn in several categories.
        """
        dt = self.datatable
        fig, ax = plt.subplots(1, 1)

        def _errorbar(rows, **style):
            # W1 vs W2 with both error bars for the given subset of rows.
            ax.errorbar(rows['w1mpro'], rows['w2mpro'],
                        xerr=rows['w1sigmpro'], yerr=rows['w2sigmpro'],
                        linestyle='', **style)

        # All detections (error bars only)
        _errorbar(dt)
        flags = (
            # Poor quality
            (dt['qual_frame'] == 0, 'red', 'poor qual'),
            # Flagged as upper limit or no profile-fit
            ([('X' in i or 'U' in i) for i in dt['ph_qual']], 'orange', 'photometry flag'),
            # Too far offset
            (dt['sep'] > self.allowed_sep, 'black', 'offset'),
            # Close to SAA
            ([abs(i) < 5 for i in dt['saa_sep']], 'purple', 'near SAA'),
            # Within the moon-mask area
            (dt['moon_masked'] != 0, 'green', 'moon mask'),
            # Flagged as known artifact
            ([(i != 0.0 and i != '0000') for i in dt['cc_flags']], 'pink', 'artifact flag'),
        )
        for mask, color, label in flags:
            _errorbar(dt[np.asarray(mask, dtype=bool)], color=color, marker='o', label=label)

        ax.legend()
        ax.set_xlabel('W1')
        ax.set_ylabel('W2')
        fig.set_size_inches(7, 7)
        plt.show()

    def _build_data(self, mask=None):
        """Store the mjd/magnitude columns of the rows selected by ``mask`` (all rows if None) in ``data``."""
        dt = self.datatable
        if mask is not None:
            dt = dt[np.asarray(mask, dtype=bool)]
        self.data = pd.DataFrame({
            'mjd': dt['mjd'],
            'w1mag': dt['w1mpro'],
            'w1sig': dt['w1sigmpro'],
            'w2mag': dt['w2mpro'],
            'w2sig': dt['w2sigmpro'],
        })
        # Reset every time so a later, non-empty selection clears an old 'yes'.
        self.baddata = 'yes' if self.data.empty else 'no'

    def filter_data(self, filters='all'):
        """
        Keep only detections passing the quality cuts and store them in ``data``.

        With ``filters='all'`` (currently the only option) a detection is kept
        when it lies within ``allowed_sep`` of the median position, has
        ``qual_frame > 0`` and ``qi_fact > 0``, has no 'X' or 'U' in ``ph_qual``,
        has ``|saa_sep| > 5`` and ``moon_masked == 0``, has ``cc_flags`` equal
        to 0 or '0000', and has non-NaN W1 and W2 magnitudes.

        Raises
        ------
        ValueError
            If ``filters`` is anything other than 'all'.
        """
        # later can code so that you can filter by one element at a time, based on the plots
        if filters != 'all':
            raise ValueError(f"filters={filters!r} is not supported; only 'all' is implemented")

        dt = self.datatable
        constraints = (
            dt['sep'] < self.allowed_sep,
            dt['qual_frame'] > 0,
            dt['qi_fact'] > 0,
            [('X' not in i and 'U' not in i) for i in dt['ph_qual']],
            [abs(i) > 5 for i in dt['saa_sep']],
            dt['moon_masked'] == 0,
            [(i == 0.0 or i == '0000') for i in dt['cc_flags']],
            ~np.isnan(dt['w1mpro']),
            ~np.isnan(dt['w2mpro']),
        )
        # A row survives only if it passes every constraint.
        neowise_mask = np.logical_and.reduce([np.asarray(c, dtype=bool) for c in constraints])
        self._build_data(neowise_mask)
        self.filtered = 'yes'

    def bin_data(self, plot='yes'):
        """
        Average the detections within each (roughly six-monthly) WISE visit.

        If filter_data has not been run, every row of ``datatable`` is binned
        (including any NaN magnitudes). Bin edges are spaced half a year apart,
        starting a quarter year before the first epoch. The result is stored in
        ``binned_data`` with columns ``mjd`` (mean epoch), ``w1mag``/``w2mag``
        (mean magnitude as uncertainties arrays, error = standard error of the
        mean) and ``w1flux``/``w2flux``. The measurements and bin edges are
        plotted when ``plot == 'yes'``.
        """
        if self.filtered == 'no':
            # No quality filtering requested: use the raw table as-is.
            self._build_data()
        if self.baddata == 'yes':
            print(self.source+': No good data!')
            return

        start_epoch = np.min(self.data['mjd'])
        end_epoch = np.max(self.data['mjd'])
        yr = u.year.to(u.d)
        cycles = (end_epoch - start_epoch)/yr * 2
        print(cycles)  # should be close to .5 or .0, as WISE is on a 6 month cycle. Inspect below
        cycles = round(cycles, 0)
        bins = [start_epoch - yr/4 + a*(yr/2) for a in np.arange(cycles + 2)]

        if plot == 'yes':
            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.errorbar(self.data['mjd'], self.data['w1mag'], yerr=self.data['w1sig'],
                         label=r"W1 measurements", color='orange', linestyle='', marker='o', markersize=5)
            ax2.errorbar(self.data['mjd'], self.data['w2mag'], yerr=self.data['w2sig'],
                         label=r"W2 measurements", color='red', linestyle='', marker='o', markersize=5)
            for epoch in bins:
                ax1.axvline(epoch)
                ax2.axvline(epoch)
            fig.set_size_inches(12, 6)
            plt.show()

        mjd = self.data['mjd'].values

        def _binned(values, statistic):
            # Apply `statistic` to `values` within each MJD bin.
            return scipy.stats.binned_statistic(mjd, values, statistic=statistic, bins=bins)[0]

        neowise_bin_df = pd.DataFrame({})
        # Bin mean mjd according to bin edges
        neowise_bin_df['mjd'] = _binned(mjd, np.mean)
        # Bin mean magnitude, with error standard error of mean
        for band in ('w1', 'w2'):
            mags = self.data[band + 'mag'].values
            neowise_bin_df[band + 'mag'] = unp.uarray(_binned(mags, np.mean),
                                                      _binned(mags, scipy.stats.sem))

        neowise_bin_df['w1flux'] = mag_to_fluxdens(neowise_bin_df['w1mag'], self.f0_wise_3_4)
        neowise_bin_df['w2flux'] = mag_to_fluxdens(neowise_bin_df['w2mag'], self.f0_wise_4_6)
        self.binned_data = neowise_bin_df
        self.binned = 'yes'

    def plot_data(self, saveonly='no',
                  path='/home/treynolds/data/LIRGS/WISE/WISE_analysis/Data/WISE_gal_plots/',
                  explosion_epoch='20190111'):
        """
        Plot the W1 and W2 light curves (individual and binned points) and save them.

        The figure is saved to ``path + '<source>_NeoWISE_lightcurve.png'``.

        Parameters
        ----------
        saveonly : str
            If 'yes', only save the figure (then close it) instead of showing it
        path : str
            Output prefix (a directory with trailing slash) for the PNG
        explosion_epoch : str or None
            Date string passed to mjday and marked with a dotted vertical line;
            None disables the marker

        Raises
        ------
        RuntimeError
            If bin_data has not produced ``binned_data`` yet.
        """
        if self.baddata == 'yes':
            print(self.source+': No good data!')
            return
        if not hasattr(self, 'binned_data'):
            raise RuntimeError(f"{self.source}: run bin_data() before plot_data()")

        # Optional: explosion/discovery epoch
        explosion_mjd = mjday(explosion_epoch) if explosion_epoch is not None else None

        # Plot measurements and binned values
        fig, (ax1, ax2) = plt.subplots(1, 2)
        panels = ((ax1, 'w1', 'W1', 'blue'), (ax2, 'w2', 'W2', 'red'))
        for ax, band, label, color in panels:
            mags = self.data[band + 'mag']
            ax.errorbar(self.data['mjd'], mags, yerr=self.data[band + 'sig'],
                        label=label + " measurements", color='black', linestyle='',
                        marker='o', markersize=5, alpha=.1, zorder=0)
            binned_mag = self.binned_data[band + 'mag']
            ax.errorbar(self.binned_data['mjd'], unp.nominal_values(binned_mag),
                        yerr=unp.std_devs(binned_mag),
                        label=label + " mean", color=color, linestyle='',
                        marker='o', markersize=5, capsize=3, elinewidth=1, zorder=1)
            if explosion_mjd is not None:
                ax.axvline(x=explosion_mjd, color='black', linestyle=':', linewidth=2.0)
            # Inverted magnitude axis (bright at top) with 0.1 mag padding; nan-aware
            # so unfiltered data containing NaN does not break the limits.
            ax.set_ylim(bottom=np.nanmax(mags) + 0.1, top=np.nanmin(mags) - 0.1)
            ax.set_title(self.source + ' ' + label)
            ax.set_xlabel(r'MJD')
            ax.legend()
        ax1.set_ylabel(r'mag')

        fig.set_size_inches(14, 7)
        fig.savefig(path + f'{self.source}_NeoWISE_lightcurve.png', bbox_inches='tight', dpi=600)
        if saveonly != 'yes':
            plt.show()
        else:
            # Release the figure so batch runs over many sources don't pile up open figures.
            plt.close(fig)

    def write(self, path='/home/treynolds/data/LIRGS/WISE/WISE_analysis/Data/WISE_gal_processed_data/'):
        """
        Write the selected and binned tables as tab-separated text files.

        Files: ``path + '<source>_NeoWISE_masked.tbl'`` (from ``data``) and
        ``path + '<source>_NeoWISE_binned.tbl'`` (from ``binned_data``).

        Raises
        ------
        RuntimeError
            If bin_data has not produced ``binned_data`` yet (checked before
            anything is written, so no partial output is left behind).
        """
        if self.baddata == 'yes':
            print(self.source+': No good data!')
            return
        if not hasattr(self, 'binned_data'):
            raise RuntimeError(f"{self.source}: run bin_data() before write()")
        # Write masked and binned magnitudes to file
        self.data.to_csv(path + f'{self.source}_NeoWISE_masked.tbl', header=True, index=False, sep='\t')
        self.binned_data.to_csv(path + f'{self.source}_NeoWISE_binned.tbl', header=True, index=False, sep='\t')
        

        


        
    


    