In [1]:
#util 
import os
from astropy.io import fits
import numpy as np
from scipy.linalg import toeplitz
from lightkurve.correctors import load_tess_cbvs
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
import lightkurve as lk
from info import cadence_bounds


def safe_div(n, d):
    return n / d if d else 0

def load_cbv(sector: int, camera: int, ccd: int, directory: str = "."):
    sector_str = f"{sector:04d}"
    # Find filename pattern
    for fname in os.listdir(directory):
        if (
            f"-s{sector_str}-" in fname
            and f"{camera}-{ccd}" in fname
            and fname.endswith("_cbv.fits")
        ):
            filepath = os.path.join(directory, fname)
            print(f"Opening: {filepath}")
            with fits.open(filepath) as hdul:
                N_vecs = 0
                cbv_matrix = []
                cadence = CADENCENO
                with fits.open(dir+'cbv_sector'+str(sector)+'/'+cbv_names[sector-1] % (1, 1), memmap=True) as hdulist2:
                    for j in range(30):
                        k = j+1
                        try:
                            evec = hdulist2[1].data['VECTOR_%s' % k]
                            cbv_matrix.append(evec)
                            if np.any(evec): N_vecs += 1
                        except: continue
                        raise FileNotFoundError(f"No CBV FITS found for sector {sector}, cam {camera}, ccd {ccd}")
    return cbv_matrix, cadence, N_vecs

# Example usage:
# cbv_matrix, cadence, N_vec = load_cbv(sector=70, camera=4, ccd=4)

def check_symmetric(a, rtol=1e-05):
    return (np.sum(a-a.T) < rtol)

def median_normal(lightcurve):
    """
    Median normalize lightcurve
    
    Args:
    lightcurve : lightcurve to be normalized.
        
    Returns:
    Median-normalized lightcurve.
    """
    lightcurve -= np.nanmedian(lightcurve)
    return lightcurve / np.nanmedian(np.abs(lightcurve))

def mag_normal_med(lightcurve):
    """
    Normalize lightcurve by magnitude (to be used if calculating pairwise correlation with corr_comp)
    
    Args:
    lightcurve :lightcurve to be normalized.
        
    Returns:
    magnitude-normalized data.
    """
    lightcurve -= np.nanmedian(lightcurve) #Subtract median as slightly more robust to outliers
    return lightcurve / np.linalg.norm(lightcurve)

def mag_normal_mean(lightcurve):
    """
    Normalize lightcurve by magnitude (to be used if calculating pairwise correlation with corr_comp)
    
    Args:
    lightcurve :lightcurve to be normalized.
        
    Returns:
    magnitude-normalized data.
    """
    lightcurve -= np.nanmean(lightcurve)
    return lightcurve / np.linalg.norm(lightcurve)

def calc_CDPP(lightcurve, scale, offset):
    """
    CDPP in PPM
    """
    lightcurve *= scale
    smooth_reg = lightcurve - savgol_filter(lightcurve, 97, 2)
    smooth_reg = threshold_data(smooth_reg)
    mean_bin = np.zeros(len(smooth_reg)-14)
    for j in range(len(mean_bin)):
        mean_bin[j] = np.mean(smooth_reg[j:j+13])
    cdpp_reg = ( np.std(mean_bin)*1.168 / np.median(offset) ) / (1e-6) # CDPP in PPM
    return cdpp_reg

def linear_detrend(lightcurve):
    """
    Linearly detrend an individual lightcurve

    Args:
    lightcurve : Length N time series
    
    Returns:
    lightcurve with linear trend removed
    
    """
    z = np.polyfit(range(len(lightcurve)), lightcurve, 1)
    p = np.poly1d(z)
    return lightcurve - p(range(len(lightcurve)))

def threshold_data(data, base_data = None, level=4):
    """
    Threshold outliers (flux samples) at level*std dev and replace with Gaussian random samples.
    
    The filtering is performed in a two-step procedure by first applying a coarse threshold
    to remove extremal values and using this data to calculate the std dev. 
    
    Args:
    data : 1D array containing the data to be thresholded.
    base_data (optional) : base the calculation of points to be thresholded on base_data if supplied (thresholding applied to data)
    level (optional) : Factor by which the standard deviation is multiplied to set the threshold level. Default is 5.
        
    Returns:
    1D array containing the thresholded data.
        
    """
    if base_data is None: base_data = data
    std_ = np.nanstd(base_data)
    diff = np.diff(base_data, prepend=base_data[0])
    thresh = level*std_
    mask = np.ones(len(base_data), dtype=bool)

    mask[np.abs(base_data) > thresh] = False
    mask[np.abs(diff) > thresh] = False

    std_clean = np.nanstd(base_data[mask])
    thresh = level*std_clean

    mask = np.zeros(len(data), dtype=bool)    
    mask[np.abs(base_data) > thresh] = True
    mask[np.abs(diff) > thresh] = True

    data[mask] = np.random.normal(0, std_clean, size=mask.sum())
    return data

def nan_linear_gapfill(data):
    """
    Fill NaN gaps in data using linear interpolation.
    
    The function identifies groups of consecutive NaNs in the data and fills them using 
    a linear interpolation approach based on the values immediately adjacent to the gaps.
    
    Args:
    data : 1D array containing the data with NaN gaps to be filled.
        
    Returns:
    1D array where NaN gaps have been filled using linear interpolation.
        
    """
    goodind = np.where(~np.isnan(data))
    badind = np.where(np.isnan(data))
    gaps = [list(group) for group in mit.consecutive_groups(badind[0])]
    for g in gaps:
        if len(g) == 1:
            data[g[0]] = data[g[0]-1]
            continue
        else:
            grad = (data[g[len(g)-1]+1]-data[g[0]-1])/(len(g)+2)
            data[g] = (np.arange(len(g))*grad) + data[g[0]-1]
    return data


def cbv_matrix(lc_cadence, sector, cam, ccd, model_order = None):
    '''
    Load and mask the TESS cotrending basis vectors (otherwise fit own basis)
    '''
    cbvs = load_tess_cbvs(sector=sector, camera=cam, ccd=ccd, cbv_type='SingleScale')
    if model_order == None: model_order = 16
    cbv_dm = cbvs.to_designmatrix(cbv_indices=np.arange(1, model_order+1))
    V = cbv_dm.values
    V_masked = V[np.in1d(cbvs.cadenceno, lc_cadence)]
    return V_masked.T

def covariance_stellar(lc, cadence_data, N_cadence):
    '''
    Estimate stellar covariance model (stationary toeplitz model), using spectral estimator on detrended light curve
    '''
    filled_lc = np.zeros(N_cadence)
    mask = np.zeros(N_cadence, dtype=bool)
    mask[cadence_data] = True

    filled_lc[mask] = threshold_data(lc, level=3)
    std_ = np.std(filled_lc[cadence_data])
    filled_lc[~mask] = np.random.normal(0, std_, N_cadence - len(cadence_data))

    zp_lc = np.zeros((2*N_cadence) - 1)
    zp_lc[:N_cadence] = filled_lc
    p_noise_smooth = smooth_p(zp_lc, K=3)
    ac = np.real(np.fft.ifft(p_noise_smooth)).astype('float32')
    ac = ac[:N_cadence]
    
    cov_stellar = toeplitz(ac, r = ac)
    masked_cov_stellar = cov_stellar[cadence_data]
    masked_cov_stellar = masked_cov_stellar[:, cadence_data]
    return cov_stellar, masked_cov_stellar
    
def covariance_sector(tid, sector):
    (lc_data, processed_lc_data, detrend_data, norm_offset, quality_data, time_data, cam_data, ccd_data, coeff_ls, centroid_xy_data, pos_xy_corr) = pickle.load(open(os.path.expanduser('TESS/data/light_curves/%s.p' % (tic_id)), 'rb')) 
    cov_c = pickle.load( open(f"TESS/data/priors/{sector}/cov_c%s_%s_%s.p" % (sector, cam, ccd), "rb" ))
    lc_cadence_zero =  time_data[sector] - cadence_bounds[sector][0] - 1
    N_cadence = cadence_bounds[sector][1]-cadence_bounds[sector][0]
    _, cov_s = covariance_stellar(detrend_data[sector].unmasked, cadence_data = lc_cadence_zero, N_cadence = N_cadence)
    V = pickle.load(open(f"TESS/data/priors/{sector}/evec_matrix_%s_%s_%s.p" % (sector, cam_data[sector], ccd_data[sector]), "rb" )) #add correct path!
    V = V[:, lc_cadence_zero]
    cov_z = np.dot(V.T, np.dot(cov_c, V)) + cov_s
    cov_inv_z = jax.numpy.linalg.pinv(cov_z)    
    #cov_inv_z = np.linalg.inv(cov_z)

    print ('symmetry check', check_symmetric(cov_inv_z))
    print ('nan check',  np.sum(np.isnan(cov_inv_z)))
    print (np.sum(cov_inv_z))
    
    lc_detrend_full = np.zeros(N_cadence)
    lc_detrend_full[lc_cadence_zero] = detrend_data[sector]
    cov_inv_z_full = np.zeros((N_cadence, N_cadence))
    cov_inv_z_full[np.ix_(lc_cadence_zero, lc_cadence_zero)] = cov_inv_z
    return lc_detrend_full, cov_inv_z_full
    
def covariance_model(lc, lc_cadence, sector, cam, ccd, model_order = None, full=True):
    '''
    H1: y = z + t + n, H0: y = z + n
    z ~ N(V*mu_c, V cov_c V.T + cov_*)
    Estimate joint covariance of z: Cov_stellar (stationary) + Cov_systematics (low rank)

    args
    model_order : systematics model order
    full: return the covariance and light curve, with uniform time sampling and zero's at masked/missing samples
    '''
    V = cbv_matrix(lc_cadence, sector, cam, ccd, model_order = model_order)
    ls_fit = np.dot(V, lc.T).dot(V) 
    lc_detrend = lc - ls_fit

    # load the systematic noise covariance (estimated from collection of light curves on the same sensor)
    #cov_c = pickle.load( open("cov_c_diag%s_%s_%s.p" % (sector, cam, ccd), "rb" ))
    cov_c = pickle.load( open("cov_c%s_%s_%s.p" % (sector, cam, ccd), "rb" ))

    cov_s = covariance_stellar(np.copy(lc_detrend), lc_cadence)

    cov_z = np.dot(V.T, np.dot(cov_c, V)) + cov_s
    #cov_inv_z = jax.numpy.linalg.pinv(cov_z)
    
    cov_inv_z = np.linalg.inv(cov_z)

    print ('symmetry check', check_symmetric(cov_inv_z))
    print ('nan check',  np.sum(np.isnan(cov_inv_z)))
    print (np.sum(cov_inv_z))
    if full:
        lc_cadence_zero = lc_cadence - lc_cadence[0]
        cadence_len = lc_cadence_zero[-1]
        lc_detrend_full = np.zeros(cadence_len+1)
        lc_detrend_full[lc_cadence_zero] = lc_detrend
        cov_inv_z_full = np.zeros((cadence_len+1, cadence_len+1))
        cov_inv_z_full[np.ix_(lc_cadence_zero, lc_cadence_zero)] = cov_inv_z
        print (np.sum(cov_inv_z_full))
        return lc_detrend_full, cov_inv_z_full
    else:
        return lc_detrend, cov_inv_z

def smooth_p(noise, K=3):
    '''
    Computes the smoothed periodogram
    '''
    N = len(noise)
    p_noise = (1/float(N)) * (np.abs(np.fft.fft(noise))**2)
    integ_periodogram = np.zeros(N)
    for i in range(N):
        if i<K: integ_periodogram[i] = np.sum(p_noise[:i+K])
        elif i>N-K: integ_periodogram[i] = np.sum(p_noise[i-K:])
        else: integ_periodogram[i] = np.sum(p_noise[i-K:i+K])
    integ_periodogram *= (1/float(2*K))
    return integ_periodogram

def box_transit(times_, period, dur, t0, alpha=1):
    """
    Generate a transit signal time-series with box function evaluated at given times.
    
    Args:
    times_ :  Array of time points at which to evaluate the transit time-series
    period : Period of the transit
    dur : Duration of the transit.
    t0 : Epoch.
    alpha : Transit depth. Default is 1.
        
    Returns:
    Transit time series evaluated at `times_`.
    """

    return np.piecewise(times_, [((times_-t0+(dur/2))%period) > dur, ((times_-t0+(dur/2))%period) <= dur], [0, 1])*(-alpha)

def nd_argsort(x):
    return np.array(np.unravel_index(np.argsort(x, axis=None), x.shape)).T[::-1]


def covariance_sector_test(tic_id, sector):
    (lc_data, processed_lc_data, detrend_data, norm_offset, quality_data, time_data, cam_data, ccd_data, coeff_ls, centroid_xy_data, pos_xy_corr) = pickle.load(open('TESS/data/light_curves/%s.p' % (tic_id), 'rb')) 
    lc_cadence_zero =  time_data[sector] - cadence_bounds[sector][0] - 1
    N_cadence = cadence_bounds[sector][1]-cadence_bounds[sector][0]

    _, cov_s = covariance_stellar(detrend_data[sector].unmasked, cadence_data = lc_cadence_zero, N_cadence = N_cadence)
    cov_inv_s = jax.numpy.linalg.pinv(cov_s)    
    #cov_inv_s = np.linalg.pinv(cov_s)    

    print ('symmetry check', check_symmetric(cov_inv_s))
    print ('nan check',  np.sum(np.isnan(cov_inv_s)))
    print (np.sum(cov_inv_s))

    lc_detrend_full = np.zeros(N_cadence)
    lc_detrend_full[lc_cadence_zero] = detrend_data[sector].unmasked
    
    cov_inv_s_full = np.zeros((N_cadence, N_cadence))
    cov_inv_s_full[np.ix_(lc_cadence_zero, lc_cadence_zero)] = cov_inv_s
    return lc_detrend_full, cov_inv_s_full
    

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [2]:
#polynomial_detrend
'''
import numpy as np
from numpy.polynomial.polynomial import Polynomial
from info import cadence_bounds

# Identify gap indices (likely downlinks)
gap_threshold = 360 #1/2 day
dt = np.diff(cadence_data)
gap_indices = np.where(dt > gap_threshold)[0]

# Include start and end for segmentation
split_indices = np.concatenate(([0], gap_indices + 1, [len(time)]))

# Detrend each segment
flux_detrended = np.zeros_like(flux)
for i in range(len(split_indices) - 1):
    start, end = split_indices[i], split_indices[i+1]
    start = np.max(split_indices[i], split_indices[i+1] - 360)
    t_seg = time[start:end]
    f_seg = flux[start:end]

    # Remove NaNs
    valid = np.isfinite(t_seg) & np.isfinite(f_seg)
    if np.count_nonzero(valid) < 5:
        flux_detrended[start:end] = f_seg  # too few points to fit
        continue

    t_valid = t_seg[valid]
    f_valid = f_seg[valid]

    # Fit polynomial (degree can be adjusted)
    coeffs = Polynomial.fit(t_valid, f_valid, deg=2).convert().coef
    trend = Polynomial(coeffs)(t_seg)
    flux_detrended[start:end] = f_seg - trend

# Replace lc.flux with detrended result
lc_detrended = lc.copy()
lc_detrended.flux = flux_detrended

# Optionally: plot
lc_detrended.plot()
'''

'\nimport numpy as np\nfrom numpy.polynomial.polynomial import Polynomial\nfrom info import cadence_bounds\n\n# Identify gap indices (likely downlinks)\ngap_threshold = 360 #1/2 day\ndt = np.diff(cadence_data)\ngap_indices = np.where(dt > gap_threshold)[0]\n\n# Include start and end for segmentation\nsplit_indices = np.concatenate(([0], gap_indices + 1, [len(time)]))\n\n# Detrend each segment\nflux_detrended = np.zeros_like(flux)\nfor i in range(len(split_indices) - 1):\n    start, end = split_indices[i], split_indices[i+1]\n    start = np.max(split_indices[i], split_indices[i+1] - 360)\n    t_seg = time[start:end]\n    f_seg = flux[start:end]\n\n    # Remove NaNs\n    valid = np.isfinite(t_seg) & np.isfinite(f_seg)\n    if np.count_nonzero(valid) < 5:\n        flux_detrended[start:end] = f_seg  # too few points to fit\n        continue\n\n    t_valid = t_seg[valid]\n    f_valid = f_seg[valid]\n\n    # Fit polynomial (degree can be adjusted)\n    coeffs = Polynomial.fit(t_valid, f_vali

In [2]:
#error fix
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'

In [5]:
#exosearch 
import lightkurve as lk
import numpy as np
import pickle
import jax
import jax.numpy as jnp
from util import *
import scipy
import pandas as pd

day_to_cadence = 720

def transit_num(y_d, num_period):
    ''''
    Arg:
    y_d : y.Cov_inv * transit_profile_d the size of this is ~ N_full / delta

    Returns:
    num_det: returns numerator of likelihoods as a 2D array indexed by period/delta and epoch/delta
    '''
    num_det = jnp.zeros((num_period, num_period)) 
    for p in range(num_period):
        for t0 in range(p+1):
            num_det = num_det.at[p, t0].set(jnp.sum(y_d[t0 :: p+1]))
    return num_det

def transit_den(K_d, num_period):
    '''
    Arg:
    K_d : Cov_inv * (transit_profile_d.transit_profile_d^T) the size of this is ~ (N_full / delta, N_full / delta)

    Returns:
    den_det: returns denominator of likelihoods as a 2D array indexed by period/delta and epoch/delta
    '''
    den_det = jnp.zeros((num_period, num_period))
    for p in range(num_period):
        for t0 in range(p+1): 
            den_det = den_det.at[p, t0].set(jnp.sum(K_d[t0 :: p+1, t0 :: p+1]))
    return den_det


# Load lc and inverse covariance model
# ====================================
#(lc_data, processed_lc_data, detrend_data, norm_offset, quality_data, time_data, cam_data, ccd_data, coeff_ls, centroid_xy_data, pos_xy_corr) = pickle.load(open(os.path.expanduser('~/TESS/data/%s.p' % (tic_id)), 'rb')) 

#our own list of TIC ID's to search

df = pd.read_csv('persistant_tids.txt', header=None, names=['tic_id'])
tic_ids = df['tic_id'].tolist() #reads out our list of tids


tic_id = tic_ids[504]
sector = 73
cam = 1
ccd = 1

lc_detrend, cov_inv = covariance_sector(tic_id, sector)
#data = np.load('TESS/data/light_curves/info/transit_templates(2).npz', allow_pickle=True)
#transit_profile_d = data['transit_templates']
#transit_profile_d = pickle.load(open("TESS/data/light_curves/info/transit_templates.p", "rb"))
data = np.load("transit_templates (2).npz")
print (data['10'].shape)
print("Line 53 finished")


# ====================================
# Defining transit parameter search space (period, epoch, duration)
# Period ranges from 0 to N/2
# epoch ranges from 0 to P

delta = 10 # period and epoch step size in 2-minute samples (origianlly 5)
durations = jnp.array([1, 2, 3, 4, 6, 8, 10, 12, 14, 16])*30 
N_full = len(lc_detrend)
lc_cov_inv = cov_inv.dot(lc_detrend) # y^T Cov_inv
# Compute transit likelihoods over parameter space 

print("Complete")

num_period = int((N_full - durations[-1]) // (2 * delta))# number of periods to search in stepsize of delta
transit_likelihood_stats = np.zeros((len(durations), num_period, num_period))

jit_lik_num = jax.jit(transit_num, static_argnums=(1))
jit_lik_den = jax.jit(transit_den, static_argnums=(1))

print("starting loop")

for i in range(len(durations)):
    print (i)
    d = durations[i]
    #transit_profile = jnp.ones(d) * transit_profile_d[durations[i]//30][:d]
    transit_profile = jnp.ones(d) #* transit_profile_d[durations[i]//30][:d] 
    transit_kernel = jnp.outer(transit_profile, transit_profile)

    #y_d = jax.scipy.signal.convolve(lc_cov_inv, transit_profile)[int(d/2)-1:N_full-int(d/2)-1][::delta]
    y_d = jax.scipy.signal.convolve(lc_cov_inv, transit_profile, method='fft')[int(d/2)-1:N_full-int(d/2)-1][::delta]
    
    
    # commented this out as I get a memory error
    #K_d = jax.scipy.signal.convolve2d(cov_inv, transit_kernel)[int(d/2)-1:N_full-int(d/2)-1,int(d/2)-1:N_full-int(d/2)-1][::delta,::delta]

    # different way to calculate K_d
    K_d = np.zeros((np.shape(y_d)[0], np.shape(y_d)[0]))
    for l in range(num_period):
        for m in range(num_period):
            K_d[l,m] = np.sum(transit_kernel*cov_inv[(l*delta):(l*delta) + d, (m*delta):(m*delta) + d])    
    K_d = jnp.array(K_d)

    likelihoods_num = transit_num(y_d, num_period)
    likelihoods_den = transit_den(K_d, num_period)

    # Output transit detection tests, indexed as [P/delta, t_0/delta]
    transit_likelihood_stats[i] = np.divide(likelihoods_num, np.sqrt(likelihoods_den), out=np.zeros_like(likelihoods_num), where=likelihoods_den!=0.)
    
    break
    
print("ended loop")
    
top_detections = nd_argsort(transit_likelihood_stats)

for i in range(5):
    print (top_detections[i], transit_likelihood_stats[top_detections[i][0], top_detections[i][1], top_detections[i][2]])
    print ('LRT (SNR): ', np.round(transit_likelihood_stats[top_detections[i][0], top_detections[i][1], top_detections[i][2]],2), 'duration (hr): ', np.round(top_detections[i][0]/30, 2), 'period (day): ', np.round(delta*(top_detections[i][1]+1)/(day_to_cadence),2), 'epoch(day): ',  np.round((delta/day_to_cadence)*top_detections[i][2], 2) )
    
print("Script Finished")

symmetry check True
nan check 0
816.24646
(1, 301)
Line 53 finished
Complete
starting loop
0
ended loop
[  0 504 466] 29.484617233276367
LRT (SNR):  29.48 duration (hr):  0.0 period (day):  7.01 epoch(day):  6.47
[  0 501 466] 25.409574508666992
LRT (SNR):  25.41 duration (hr):  0.0 period (day):  6.97 epoch(day):  6.47
[  0 519 466] 21.948177337646484
LRT (SNR):  21.95 duration (hr):  0.0 period (day):  7.22 epoch(day):  6.47
[  0 684 466] 19.71776008605957
LRT (SNR):  19.72 duration (hr):  0.0 period (day):  9.51 epoch(day):  6.47
[  0 515 466] 18.077341079711914
LRT (SNR):  18.08 duration (hr):  0.0 period (day):  7.17 epoch(day):  6.47
Script Finished


In [1]:
print(1)

1


In [8]:
data = np.load('TESS/data/light_curves/info/transit_templates.npz', allow_pickle=True)
print(data.files)
#transit_profile_d = data['transit_templates']
#print(transit_profile_d.files)

names = data.files
print(names[0])

['transit_templates']
transit_templates


In [4]:
!pip install --user numpy==1.26.4
# old version 1.22.4
import numpy
print(numpy.__version__)
print(numpy.__file__)

Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.22.4
    Uninstalling numpy-1.22.4:
      Successfully uninstalled numpy-1.22.4
[31mERROR: Could not install packages due to an OSError: [Errno 16] Device or resource busy: '.nfs0000001112deafd500000c65'
[0m


In [2]:
import jax
print (jax.devices())


[cuda(id=0), cuda(id=1)]
