# In [1]:
from astropy.io import fits # , ascii
import numpy as np
import matplotlib.pyplot as plt


def preprocess_bino(fname_data, fname_err, data_dir):
    """
    Load a pair of Binospec data/error FITS files and return cleaned 2D spectra.

    Preprocessing applied to every extension i >= 1 (after unit conversion):
        - Pixels whose data or error is NaN, whose error is <= 0, or whose data
          exceeds 1e4 get data ZERO and error "infinity" (1e60).
        - Columns outside the populated range reported by ``index_edges``, plus
          an additional 50-column margin on each side, get the same treatment.
        - Only rows 4..24 are copied into the output; all other rows keep
          data ZERO / error infinity.
        - An extension without data keeps data ZERO / error infinity; its
          header is still recorded.

    Output:
        - A numpy array of shape (Ntargets+1, 2, 32, Ncols), where Ntargets+1 is
          the number of HDUs in the data file and Ncols is taken from extension 1.
            - Though the native data have different numbers of rows, a single
              fixed number (32) is used here. Extensions with fewer rows or a
              different column count are copied as far as they overlap.
            - Channel 0: Data
            - Channel 1: Error
        - List of headers (element 0 is None).

    The native data unit is documented as ergs/cm^2/s/nm, and the preprocessor
    rescales by 1e18 with the intent of producing 10^-17 ergs/cm^2/s/Angstrom.
    NOTE(review): 1 erg/cm^2/s/nm equals 1e16 x (1e-17 erg/cm^2/s/Angstrom), so
    the stated units imply a factor of 1e16. Confirm the native units before
    relying on absolute flux values.

    The first spectrum in the native data is stored in HDU 1. The output uses
    the same index, and index 0 is a blank placeholder.
    """
    infinity = 1e60
    unit_conversion = 10**18
    n_rows_out = 32
    row_lo, row_hi = 4, 25  # Data is usually crap outside this row range.
    trim = 50  # Extra columns masked beyond the detected edges.

    list_headers = [None]  # Zeroth element is a blank.

    # Context managers ensure both files are closed even if processing fails.
    path_data = data_dir + fname_data
    path_err = data_dir + fname_err
    with fits.open(path_data) as data, fits.open(path_err) as err:
        # ---- Place holder for the output array
        Ncols = data[1].data.shape[1]
        data_err = np.zeros((len(data), 2, n_rows_out, Ncols))
        data_err[:, 1, :, :] = infinity  # Data start at zero, errors at infinity.

        for i in range(1, len(data)):
            header_tmp = data[i].header
            if data[i].data is None or err[i].data is None:
                # Nothing to copy: leave this slot fully masked.
                list_headers.append(header_tmp)
                continue

            # ---- Import data (multiplication makes copies, so the FITS arrays
            # are never modified and remain usable after the files close).
            data_tmp = data[i].data * unit_conversion
            err_tmp = err[i].data * unit_conversion

            # ---- Apply preprocessing
            # NaNs are expected in the comparisons below, so silence the
            # "invalid value" warnings numpy may emit for them.
            with np.errstate(invalid="ignore"):
                bad = (np.isnan(err_tmp) | np.isnan(data_tmp)
                       | (err_tmp <= 0.) | (data_tmp > 10**4))
            data_tmp[bad] = 0
            err_tmp[bad] = infinity

            # ---- Trim the data
            idx_min, idx_max = index_edges(data_tmp)
            left = idx_min + trim
            # Clamp at 0: a negative slice start would wrap around and mask
            # only the right-hand end of the spectrum instead of all of it.
            right = max(idx_max - trim, 0)
            data_tmp[:, :left] = 0
            data_tmp[:, right:] = 0
            err_tmp[:, :left] = infinity
            err_tmp[:, right:] = infinity

            # ---- Save data (only the overlap with the output grid, so short or
            # differently sized extensions do not raise a broadcast error).
            rows = min(row_hi, data_tmp.shape[0])
            cols = min(Ncols, data_tmp.shape[1])
            if rows > row_lo:
                data_err[i, 0, row_lo:rows, :cols] = data_tmp[row_lo:rows, :cols]
                data_err[i, 1, row_lo:rows, :cols] = err_tmp[row_lo:rows, :cols]

            # ---- Save header
            list_headers.append(header_tmp)

    return data_err, list_headers



def bit_from_header(header):
    """
    Translate a header's SLITOBJ entry into an integer bit code.

    "stars" maps to 2 (2**1) and "gal" maps to 4 (2**2). Any other value is
    passed through int(), which raises for non-numeric strings.
    """
    slit_type = header["SLITOBJ"]
    for label, power in (("stars", 1), ("gal", 2)):
        if slit_type == label:
            return int(2**power)
    return int(slit_type)

def extract_single_data(data_err, list_headers, specnum):
    """
    Pull one target's data, error and header out of the preprocessor output.

    - data_err: array from preprocess_bino (channel 0 = data, channel 1 = error).
    - list_headers: header list from preprocess_bino.
    - specnum: Target object number in a file. Ranges from 1 through approx. 140.

    Returns (data, err, header).
    """
    header = list_headers[specnum]
    return data_err[specnum, 0], data_err[specnum, 1], header

def ivar_from_err(err):
    """Convert a 1-sigma error (scalar or array) to inverse variance, 1 / err**2."""
    variance = np.square(err)
    return 1. / variance


def index_edges(data, num_thres=20):
    """
    Given a long 2D postage stamp of data, return the edges of its populated region.

    A column counts as populated when it contains at most ``num_thres``
    entries exactly equal to zero (zeros are counted down axis 0).

    Parameters
    ----------
    data : 2D array
        Postage stamp; columns are indexed by axis 1.
    num_thres : int
        Maximum number of zero entries a column may hold and still count as
        populated.

    Returns
    -------
    (idx_min, idx_max) : tuple of int
        Indices of the first and last populated columns (both inclusive).
        If no column is populated, returns ``(data.shape[1], -1)`` instead of
        running off the end of the array.
    """
    tally = np.sum(data == 0., axis=0)
    populated = np.flatnonzero(tally <= num_thres)
    if populated.size == 0:
        return data.shape[1], -1
    return int(populated[0]), int(populated[-1])


def wavegrid_from_header(header, Ncols):
    """
    Construct a linear wavelength grid from a FITS header and a column count.

    Applies the FITS linear WCS relation for axis 1,
        wave(p) = CRVAL1 + (p - CRPIX1) * CDELT1,   p = 1..Ncols (1-based),
    and multiplies by 10 to convert the header's wavelength unit (nm, per the
    native data description) to Angstrom.

    Parameters
    ----------
    header : mapping of FITS keywords (e.g. astropy Header or dict)
        ``CRPIX1`` defaults to 1 when absent, which anchors CRVAL1 at the
        first column. ``CD1_1`` is used as the step when ``CDELT1`` is absent.
    Ncols : int
        Number of grid points to generate.

    Returns
    -------
    1D float array of length Ncols.
    """
    crpix = header.get("CRPIX1", 1)
    step = header["CDELT1"] if "CDELT1" in header else header["CD1_1"]
    x0 = header["CRVAL1"] * 10
    dx = step * 10
    # np.arange gives 0-based columns; +1 converts to FITS 1-based pixels.
    return x0 + (np.arange(0, Ncols, 1.) + 1 - crpix) * dx



def idx_peaks(wavegrid, redz, idx_min=0, idx_max=None):
    """
    Given a wavelength grid and a redshift, return the indices corresponding to
    the following emission line peaks: OII, Hb, OIII (1, 2), Ha.

    Parameters
    ----------
    wavegrid : 1D array
        Wavelength grid, in the same units as the rest-frame wavelengths below
        (Angstrom).
    redz : float
        Redshift; observed wavelength = rest wavelength * (1 + redz).
    idx_min, idx_max : int
        Accepted index range ``idx_min <= idx < idx_max`` (upper bound
        exclusive). ``idx_max`` defaults to ``wavegrid.size`` (whole grid).

    Returns
    -------
    names : list of str
        ["OII", "Hb", "OIII1", "OIII2", "Ha"].
    index_list : list
        For each line, the nearest grid index, or -1 if the redshifted line
        lies outside the wavelength span of ``wavegrid`` or outside the
        accepted index range.
    """
    names = ["OII", "Hb", "OIII1", "OIII2", "Ha"]
    # Rest-frame wavelengths (Angstrom), in the same order as ``names``.
    peak_list = [3727, 4861, 4959, 5007, 6563]

    if wavegrid.size == 0:
        return names, [-1] * len(peak_list)

    if idx_max is None:
        idx_max = wavegrid.size

    # find_nearest_idx clamps to the grid ends, so a line far outside the grid
    # would otherwise come back as the first/last pixel. Reject those by
    # wavelength before looking up the index.
    wave_lo, wave_hi = np.min(wavegrid), np.max(wavegrid)

    index_list = []
    for rest in peak_list:
        observed = rest * (1 + redz)
        if not (wave_lo <= observed <= wave_hi):
            index_list.append(-1)
            continue
        idx = find_nearest_idx(wavegrid, observed)
        index_list.append(idx if idx_min <= idx < idx_max else -1)

    return names, index_list

def find_nearest_idx(arr, x):
    """Return the index of the element of ``arr`` closest to ``x`` (first one on ties)."""
    distance = np.abs(arr - x)
    return distance.argmin()

N_peaks = 5
peak2int = {"OII": 0, "Hb": 1, "OIII1": 2, "OIII2": 3, "Ha": 4}
name = ["OII", "Hb", "OIII1", "OIII2", "Ha"]

# ---- Mask directory names
data_dir = "../../data/"
# Example directory names
dir_blue = data_dir + "/2-8h30m-270/"
dir_red = data_dir + "/2-8h30m-600/"

# These are file names
fname_data = "obj_abs_slits_lin.fits"
fname_err = "obj_abs_err_slits_lin.fits"


# ---- Load the data
# Blue data
data_err_1, list_headers_1 = preprocess_bino(fname_data, fname_err, dir_blue)
# Red data
data_err_2, list_headers_2 = preprocess_bino(fname_data, fname_err, dir_red)

# ---- Extract number of objects
Nobjs = data_err_1.shape[0]  # Including the blank image at index 0.

# --- Extract data of the first object
data1, err1, header1 = extract_single_data(data_err_1, list_headers_1, 1)
wavegrid1 = wavegrid_from_header(header1, data1.shape[1])  # Wavegrid corresponding to the blue dataset
data2, err2, header2 = extract_single_data(data_err_2, list_headers_2, 1)
wavegrid2 = wavegrid_from_header(header2, data2.shape[1])  # Wavegrid corresponding to the red dataset

z = 0.9  # For a given redshift
# Find the indices corresponding to the line locations in the 2D data.
peaks, indices = idx_peaks(wavegrid1, z, idx_min=0, idx_max=wavegrid1.size)
# Plot a 32-column cutout of the image around each peak
half_width = 16
fig, ax_list = plt.subplots(1, 5, figsize=(10, 3))
for peak, idx in zip(peaks, indices):
    ax = ax_list[peak2int[peak]]
    # Draw only when the whole cutout lies inside the spectrum. idx == -1 means
    # "outside the data", and a window start below 0 would be a negative slice
    # start (wrap-around or an empty image).
    if half_width <= idx <= data1.shape[1] - half_width:
        ax.imshow(data1[:, idx - half_width:idx + half_width], cmap="gray", interpolation="none")
        # With imshow's default extent, pixel centers sit at integer x, so the
        # peak pixel (cutout column ``half_width``) is centered at x = half_width.
        ax.axvline(x=half_width, c="red", ls="--", lw=1)
    ax.set_title(peak)
    ax.axis("off")
plt.show()
plt.close()
