In [1]:
import pandas as pd
import numpy as np
from astropy.coordinates import SkyCoord
from astropy import units as u

import sys
sys.path.append('../')
import os
from src.classification import get_match_label_simple, get_match_label_advanced
from src.data import get_data_basic_matches
from sklearn.model_selection import train_test_split
from src.utils import transform_features, normalize_train_test

from joblib import load

In [2]:
# this is the whole dataset with probabilities generated by the model
df_all_model = pd.read_parquet('../scripts/nway_csc21_gaia3_full_neg_study_dis_niter200.parquet')

In [3]:
benchmark_ids = load('../scripts/jobs/models/neg_study_dis_niter200_withint_with_int_5X_lgbm_0-3_20241113_235113/benchmark_ids.joblib')

In [4]:
def get_train_val_test_splits(df_all_model, benchmark_ids, range_offaxis='0-3', separation=1.3):
   # get initial positives and split test set
   df_pos, _ = get_data_basic_matches(df_all_model, range_offaxis, separation)
   cscids = df_pos['csc21_name'].unique()
   cscids_train_val, cscids_test = train_test_split(cscids, test_size=0.2, random_state=42)
   
   # get train/val split from filtered train_val data
   df_train_val = df_all_model[df_all_model['csc21_name'].isin(cscids_train_val)]
   train_pos, _ = get_data_basic_matches(df_train_val, range_offaxis, separation)
   train_val_cscids = train_pos['csc21_name'].unique()
   cscids_train, cscids_val = train_test_split(train_val_cscids, test_size=0.2, random_state=42)
   
   assert set(benchmark_ids) == set(cscids_test)
   
   # get final datasets
   splits = {}
   for name, ids in [('train', cscids_train), ('val', cscids_val), ('test', cscids_test)]:
       data = df_all_model[df_all_model['csc21_name'].isin(ids)]
       pos, neg = get_data_basic_matches(data, range_offaxis, separation)
       splits[name] = {'pos': pos, 'neg': neg, 'full': data}
       
   return splits

splits = get_train_val_test_splits(df_all_model, benchmark_ids)

Range 0-3: 30279 positives, 310020 negatives
Range 0-3: 24223 positives, 245627 negatives
Range 0-3: 19378 positives, 195135 negatives
Range 0-3: 4845 positives, 50492 negatives
Range 0-3: 6056 positives, 64393 negatives


In [5]:
def validate_splits(splits, model_path):
   """validate train/val splits match saved model data"""
   
   # combine pos/neg sets
   val_data = splits['val']['full']
   train_data = splits['train']['full']
   
   # load saved validation data
   X_eval_saved = load(os.path.join(model_path, 'X_eval.joblib'))
   y_eval_saved = load(os.path.join(model_path, 'y_eval.joblib'))
   
   # prepare validation data
   val_data['eval_label'] = np.where(val_data['match_flag'] == 1, 1, 0)
   
   # preprocess
   X_train, _ = transform_features(train_data, log_transform=False, model_type='lgbm')
   X_val, cat_features = transform_features(val_data, log_transform=False, model_type='lgbm')
   _, X_val_norm, _ = normalize_train_test(X_train, X_val, method='none', 
                                         categorical_features=cat_features)
   
   # verify
   assert X_eval_saved.equals(X_val_norm)
   assert np.array_equal(y_eval_saved, val_data['eval_label'].values)
   
   return True

validate_splits(splits, '../scripts/jobs/models/neg_study_dis_niter200_withint_with_int_5X_lgbm_0-3_20241113_235113')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_data['eval_label'] = np.where(val_data['match_flag'] == 1, 1, 0)


True

In [6]:
# get the test data + everything from the dataset that was not in train and val
test_data = splits['test']['full']

# now get everything that was not in train and val
train_val_data = pd.concat([splits['train']['full'], splits['val']['full']])
train_val_ids = train_val_data['csc21_name'].unique()
not_train_val_data = df_all_model[~df_all_model['csc21_name'].isin(train_val_ids)].copy()

# check if test is IN not_train_val_data
assert set(test_data['csc21_name'].unique()) <= set(not_train_val_data['csc21_name'].unique())

In [7]:
# Galactic coordinates from Townsley et al.
l = 287.7 * u.degree
b = -0.8 * u.degree

# Convert to RA/Dec
coord = SkyCoord(l=l, b=b, frame='galactic')
radec = coord.icrs

print(f"RA: {radec.ra.to_string(unit=u.hour, sep=':')}")
print(f"Dec: {radec.dec.to_string(unit=u.degree, sep=':')}")

RA: 10:45:09.18751197
Dec: -59:53:00.13780856


In [10]:
# Define the center of the Carina Complex region
carina_center = radec

# Create SkyCoord objects for all sources
carina_source_coords = SkyCoord(ra=df_all_model['csc21_ra'].values * u.deg, dec=df_all_model['csc21_dec'].values * u.deg, frame='icrs')

# Calculate separations
carina_separations = carina_source_coords.separation(carina_center).to(u.arcmin)

# Filter the dataframe
df_all_model['separation_from_carina'] = carina_separations
carina_sources_in_region = df_all_model[carina_separations <= 30 * u.arcmin].copy()
carina_cscid_list = carina_sources_in_region['csc21_name'].str.replace('_', ' ').str.strip().unique().tolist()
carina_sources_in_region['num_possible_counterparts'] = carina_sources_in_region.groupby('csc21_name')['gaia3_source_id'].transform('count')


In [11]:
roc_threshold = 0.44
chance_threshold = 0.466
p_tr = chance_threshold

In [12]:
df_in_crit = get_match_label_advanced(carina_sources_in_region, p_threshold=p_tr)

In [13]:
df_in_crit_test_and_more = df_in_crit[df_in_crit['csc21_name'].isin(not_train_val_data['csc21_name'].unique())]

In [24]:
#df_in_crit_test_and_more = df_in_crit_test_and_more[df_in_crit_test_and_more['min_theta_mean'] <= 3]

In [14]:
df_in_crit_test_and_more.csc21_name.nunique()

3869

In [15]:
df_in_crit_test = df_in_crit_test_and_more[df_in_crit_test_and_more['csc21_name'].isin(test_data['csc21_name'].unique())]

In [18]:
def create_performance_metrics(df, p_threshold=0.466):
   """compute match statistics between nway and ml model"""
   metrics = {}
   
   # base counts
   metrics['N_CSC'] = df['csc21_name'].nunique()
   metrics['N_yNWAY'] = df[df['p_any'] > 0.5]['csc21_name'].nunique()
   
   # combined criteria using label column
   ok_matches = (df['p_any'] > 0.5) & (df['label'] == 1) & (df['match_flag'] == 1)
   ok_ids = df[ok_matches]['csc21_name'].unique()
   metrics['N_OK'] = df[ok_matches]['csc21_name'].nunique()
   
   no_ml = (df['p_any'] > 0.5) & ~df.groupby('csc21_name')['label'].transform(any)
   metrics['N_NoML'] = df[no_ml]['csc21_name'].nunique()
   
   # delete all csc_ids which have an ok match
   not_ok_ids = df[~df['csc21_name'].isin(ok_ids)]['csc21_name'].unique()
   ynway_flipml = df[df['csc21_name'].isin(not_ok_ids) & (df['p_any'] > 0.5) & (df['label'] == 1)]
   metrics['N_yNWAY_FLIP'] = ynway_flipml['csc21_name'].nunique()
   
   flip = (df['p_any'] <= 0.5) & (df['label'] == 1)
   metrics['N_FLIP'] = df[flip]['csc21_name'].nunique()
   
   none = (df['p_any'] <= 0.5) & ~df.groupby('csc21_name')['label'].transform(any)
   metrics['N_NONE'] = df[none]['csc21_name'].nunique()
   
   # single/multiple match counts 
   ok_single = ok_matches & (df.groupby('csc21_name')['label'].transform(sum) == 1)
   metrics['N_yNWAY+yMLeq1'] = df[ok_single]['csc21_name'].nunique()
   metrics['N_yNWAY+yMLgt1'] = metrics['N_OK'] - metrics['N_yNWAY+yMLeq1']
   
   flip_single = flip & (df.groupby('csc21_name')['label'].transform(sum) == 1)
   metrics['N_FLIPeq1'] = df[flip_single]['csc21_name'].nunique()
   metrics['N_FLIPgt1'] = metrics['N_FLIP'] - metrics['N_FLIPeq1']

   # For each Gaia candidate (gaia_id) that has at least one row with high p_match_ind,
   # count it only if in all rows the candidate is outside the separation threshold.   
   high_ml = df[df['p_match_ind'] > p_threshold]
   gaia_flag = high_ml.groupby('gaia3_source_id').apply(lambda g: (g['separation'] > g['threshold_sep']).all())
   print(high_ml['threshold_sep']) 
   print(gaia_flag.sum())
   metrics['N_MLglobal'] = gaia_flag[gaia_flag].index.nunique()

   table = f"""
   N_CSC = {metrics['N_CSC']}
   N_yNWAY = {metrics['N_yNWAY']}
   N_OK = {metrics['N_OK']}
   N_NoML = {metrics['N_NoML']}
   N_yNWAY_FLIP = {metrics['N_yNWAY_FLIP']}
   N_FLIP = {metrics['N_FLIP']}
   N_NONE = {metrics['N_NONE']}
   N_yNWAY+yMLeq1 = {metrics['N_yNWAY+yMLeq1']}
   N_yNWAY+yMLgt1 = {metrics['N_yNWAY+yMLgt1']}
   N_FLIPeq1 = {metrics['N_FLIPeq1']}
   N_FLIPgt1 = {metrics['N_FLIPgt1']}
   N_MLglobal = {metrics['N_MLglobal']}
   """

   return metrics, table

In [19]:
stats_vk, email_vk = create_performance_metrics(df_in_crit_test_and_more)
print(email_vk)

  ok_single = ok_matches & (df.groupby('csc21_name')['label'].transform(sum) == 1)
  flip_single = flip & (df.groupby('csc21_name')['label'].transform(sum) == 1)
  gaia_flag = high_ml.groupby('gaia3_source_id').apply(lambda g: (g['separation'] > g['threshold_sep']).all())


1288020    3
1288025    3
1288026    3
1288029    3
1288135    3
          ..
1383060    4
1383061    4
1383062    4
1383151    5
1383154    5
Name: threshold_sep, Length: 8680, dtype: int64
4467

   N_CSC = 3869
   N_yNWAY = 2953
   N_OK = 1852
   N_NoML = 972
   N_yNWAY_FLIP = 129
   N_FLIP = 116
   N_NONE = 800
   N_yNWAY+yMLeq1 = 1750
   N_yNWAY+yMLgt1 = 102
   N_FLIPeq1 = 103
   N_FLIPgt1 = 13
   N_MLglobal = 4467
   


In [29]:
def get_gaia_multiple_n(df):
   """count total gaia matches in multiple match cases"""
   # get sources with multiple matches
   ok_multiples = (df['p_any'] > 0.5) & (df.groupby('csc21_name')['label'].transform(sum) > 1)
   flip_multiples = (df['p_any'] <= 0.5) & (df.groupby('csc21_name')['label'].transform(sum) > 1)

   # count matches
   n_ok_gaia = df[ok_multiples & (df['label'] == 1)]['gaia3_source_id'].nunique()
   n_flip_gaia = df[flip_multiples & (df['label'] == 1)]['gaia3_source_id'].nunique()
   
   return {
       'N_yNWAY+yMLgt1_gaia': n_ok_gaia,
       'N_FLIPgt1_gaia': n_flip_gaia
   }

In [30]:
get_gaia_multiple_n(df_in_crit_test_and_more)

  ok_multiples = (df['p_any'] > 0.5) & (df.groupby('csc21_name')['label'].transform(sum) > 1)
  flip_multiples = (df['p_any'] <= 0.5) & (df.groupby('csc21_name')['label'].transform(sum) > 1)


{'N_yNWAY+yMLgt1_gaia': 278, 'N_FLIPgt1_gaia': 29}

In [42]:
def analyze_match_cases(df, df_all_stacks, train_val_ids):
    """export csvs for each match case type"""
    output_cols = [
    'csc21_name', 'csc21_ra', 'csc21_dec', 'min_theta_mean',
    'detect_stack_id', 'gaia3_source_id', 'p_i', 'p_any', 
    'p_match_ind', 'separation', 'match_flag', 'label',
    'phot_g_mean_mag', 'phot_bp_mean_mag', 'phot_rp_mean_mag', 'bp_rp'
   ]

    # initialize cases
    N_yNWAY = []
    N_OK = []
    N_NoML = [] 
    N_FLIP = []
    N_NONE = []

    for csc_id, group in df.groupby('csc21_name'):
        nway_confident = group['p_any'].max() > 0.5
        ml_matches = group[group['label'] == 1]
       
        if nway_confident:
            N_yNWAY.append(group)
            # check if the match_flag==1 is also label ==1
            if len(ml_matches) > 0 and (ml_matches['match_flag'] == 1).any():
                N_OK.append(group)
            elif len(ml_matches) == 0:
                N_NoML.append(group)
            elif (ml_matches['match_flag'] != 1).all():
                N_FLIP.append(group)
        elif len(ml_matches) == 0:
            N_NONE.append(group)
   
    # save cases to csv
    cases = {
        'N_yNWAY': N_yNWAY,
        'N_OK': N_OK,
        'N_NoML': N_NoML,
        'N_FLIP': N_FLIP,
        'N_NONE': N_NONE
    }

    for case_name, case_data in cases.items():
        if case_data:
            df_case = (pd.concat(case_data)
                        .merge(df_all_stacks[['name', 'detect_stack_id']], 
                            left_on='csc21_name', right_on='name', 
                            how='left')
                        [output_cols])
            os.makedirs(f'outputs/{p_tr}', exist_ok=True)
            # delete train_val data
            df_case = df_case[~df_case['csc21_name'].isin(train_val_ids)]
            print(f"Saving {case_name} with ({df_case['csc21_name'].nunique()}) csc21_names without train_val data")
            # save to csv
            df_case.to_csv(f'outputs/{p_tr}/carina_{case_name}_trs{p_tr}.csv', index=False)
            
    return {k: len(set(pd.concat(v)['csc21_name'])) if v else 0 
            for k,v in cases.items()}

In [35]:
# read vot to pandas ../data/all_stacks.vot
from astropy.io.votable import parse_single_table

# read vot to pandas
table = parse_single_table('../data/all_stacks.vot')

# recover column names
df_all_stacks = table.to_table().to_pandas()
df_all_stacks.columns = [col.name for col in table.fields]

In [43]:
result_discrepant = analyze_match_cases(df_in_crit_test_and_more, df_all_stacks, train_val_ids)

Saving N_yNWAY with (2953) csc21_names without train_val data
Saving N_OK with (1852) csc21_names without train_val data
Saving N_NoML with (972) csc21_names without train_val data
Saving N_FLIP with (129) csc21_names without train_val data
Saving N_NONE with (800) csc21_names without train_val data


In [44]:
result_discrepant

{'N_yNWAY': 2953, 'N_OK': 1852, 'N_NoML': 972, 'N_FLIP': 129, 'N_NONE': 800}