In [37]:
import pandas as pd
import numpy as np
from astropy.coordinates import SkyCoord
from astropy import units as u

import sys
sys.path.append('../')
import os
from src.classification import get_match_label_simple
from src.data import get_data_basic_matches
from sklearn.model_selection import train_test_split
from src.utils import transform_features, normalize_train_test

from joblib import load

In [38]:
df_all_model = pd.read_parquet('../scripts/nway_csc21_gaia3_full_neg_study_dis_niter200.parquet')

In [3]:
benchmark_ids = load('../scripts/jobs/models/neg_study_dis_niter200_withint_with_int_5X_lgbm_0-3_20241113_235113/benchmark_ids.joblib')

In [4]:
def get_train_val_test_splits(df_all_model, benchmark_ids, range_offaxis='0-3', separation=1.3):
   # get initial positives and split test set
   df_pos, _ = get_data_basic_matches(df_all_model, range_offaxis, separation)
   cscids = df_pos['csc21_name'].unique()
   cscids_train_val, cscids_test = train_test_split(cscids, test_size=0.2, random_state=42)
   
   # get train/val split from filtered train_val data
   df_train_val = df_all_model[df_all_model['csc21_name'].isin(cscids_train_val)]
   train_pos, _ = get_data_basic_matches(df_train_val, range_offaxis, separation)
   train_val_cscids = train_pos['csc21_name'].unique()
   cscids_train, cscids_val = train_test_split(train_val_cscids, test_size=0.2, random_state=42)
   
   assert set(benchmark_ids) == set(cscids_test)
   
   # get final datasets
   splits = {}
   for name, ids in [('train', cscids_train), ('val', cscids_val), ('test', cscids_test)]:
       data = df_all_model[df_all_model['csc21_name'].isin(ids)]
       pos, neg = get_data_basic_matches(data, range_offaxis, separation)
       splits[name] = {'pos': pos, 'neg': neg, 'full': data}
       
   return splits

splits = get_train_val_test_splits(df_all_model, benchmark_ids)

Range 0-3: 30279 positives, 310020 negatives
Range 0-3: 24223 positives, 245627 negatives
Range 0-3: 19378 positives, 195135 negatives
Range 0-3: 4845 positives, 50492 negatives
Range 0-3: 6056 positives, 64393 negatives


In [5]:
def validate_splits(splits, model_path):
   """validate train/val splits match saved model data"""
   
   # combine pos/neg sets
   val_data = splits['val']['full']
   train_data = splits['train']['full']
   
   # load saved validation data
   X_eval_saved = load(os.path.join(model_path, 'X_eval.joblib'))
   y_eval_saved = load(os.path.join(model_path, 'y_eval.joblib'))
   
   # prepare validation data
   val_data['eval_label'] = np.where(val_data['match_flag'] == 1, 1, 0)
   
   # preprocess
   X_train, _ = transform_features(train_data, log_transform=False, model_type='lgbm')
   X_val, cat_features = transform_features(val_data, log_transform=False, model_type='lgbm')
   _, X_val_norm, _ = normalize_train_test(X_train, X_val, method='none', 
                                         categorical_features=cat_features)
   
   # verify
   assert X_eval_saved.equals(X_val_norm)
   assert np.array_equal(y_eval_saved, val_data['eval_label'].values)
   
   return True

validate_splits(splits, '../scripts/jobs/models/neg_study_dis_niter200_withint_with_int_5X_lgbm_0-3_20241113_235113')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_data['eval_label'] = np.where(val_data['match_flag'] == 1, 1, 0)


True

In [6]:
# get the test data + everything from the dataset that was not in train and val
test_data = splits['test']['full']

# now get everything that was not in train and val
train_val_data = pd.concat([splits['train']['full'], splits['val']['full']])
train_val_ids = train_val_data['csc21_name'].unique()
not_train_val_data = df_all_model[~df_all_model['csc21_name'].isin(train_val_ids)].copy()

# check if test is IN not_train_val_data
assert set(test_data['csc21_name'].unique()) <= set(not_train_val_data['csc21_name'].unique())

In [7]:
# Define the center of the Orion Nebula region
orion_center = SkyCoord(ra=83.8210 * u.deg, dec=-5.3944 * u.deg, frame='icrs')

# Create SkyCoord objects for all sources
orion_source_coords = SkyCoord(ra=df_all_model['csc21_ra'].values * u.deg, dec=df_all_model['csc21_dec'].values * u.deg, frame='icrs')

# Calculate separations
orion_separations = orion_source_coords.separation(orion_center).to(u.arcmin)

# Filter the dataframe
df_all_model['separation_from_orion'] = orion_separations
orion_sources_in_region = df_all_model[orion_separations <= 30 * u.arcmin].copy()
orion_cscid_list = orion_sources_in_region['csc21_name'].str.replace('_', ' ').str.strip().unique().tolist()
orion_sources_in_region['num_possible_counterparts'] = orion_sources_in_region.groupby('csc21_name')['gaia3_source_id'].transform('count')


In [8]:
df_in_crit = get_match_label_simple(orion_sources_in_region, p_threshold=0.35)

In [9]:
df_in_crit_test_and_more = df_in_crit[df_in_crit['csc21_name'].isin(not_train_val_data['csc21_name'].unique())]

In [10]:
df_in_crit_test_and_more.csc21_name.nunique()

1412

In [39]:
df_in_crit_test_and_more['min_theta_mean'].describe()

count    3723.000000
mean        3.873243
std         3.295419
min         0.028973
25%         1.333459
50%         3.074227
75%         5.609001
max        21.031032
Name: min_theta_mean, dtype: float64

In [22]:
df_in_crit_test_and_more.query('min_theta_mean < 3').sort_values('separation').head(10)[['csc21_name', 'separation', 'p_i', 'p_any']]

Unnamed: 0,csc21_name,separation,p_i,p_any
1058146,2CXO J053517.0-052339,0.005774,1.0,0.999904
1054744,2CXO J053446.9-053414,0.008453,1.0,0.999904
1058743,2CXO J053518.4-052329,0.014907,1.0,0.999896
1059376,2CXO J053521.2-052457,0.015802,1.0,0.999904
1058133,2CXO J053517.0-052333,0.020119,1.0,0.999897
1058385,2CXO J053517.6-052153,0.022458,1.0,0.999899
1058357,2CXO J053517.5-052256,0.022656,1.0,0.999903
1059865,2CXO J053525.0-052258,0.023547,1.0,0.999902
1059800,2CXO J053524.1-052155,0.024223,1.0,0.999896
1058807,2CXO J053518.6-052313,0.024504,1.0,0.999898


In [40]:
# check if all the cases with min_theta_mean < 3 and separation <1.3 are in the test set

assert set(df_in_crit_test_and_more.query('min_theta_mean < 3 and separation<1.3 and p_any>0.9')['csc21_name'].unique()) <= set(test_data['csc21_name'].unique())

In [41]:
def create_performance_table(df):
   """create performance metrics and tabular summary"""
   total_sources = df['csc21_name'].nunique()
   
   match_cases = {
       'total_sources': total_sources,
       'contains_match': {'count': 0, 'pct': 0},
       'exact_match': {'count': 0, 'pct': 0}, 
       'different': {'count': 0, 'pct': 0},
       'no_match': {'count': 0, 'pct': 0},
       'multiple': {'count': 0, 'pct': 0}
   }

   for _, group in df.groupby('csc21_name'):
       ml_matches = group[group['label'] == 1]
       nway_matches = group[group['match_flag'] == 1]
       
       if len(nway_matches) == 1:
           if nway_matches.index.isin(ml_matches.index).all():
               match_cases['contains_match']['count'] += 1
       if len(ml_matches) == 1 and len(nway_matches) == 1:
           if ml_matches.index.equals(nway_matches.index):
               match_cases['exact_match']['count'] += 1
           else:
               match_cases['different']['count'] += 1
       elif len(ml_matches) == 0:
           match_cases['no_match']['count'] += 1
       elif len(ml_matches) > 1:
           match_cases['multiple']['count'] += 1
           
   for key in match_cases:
       if key != 'total_sources':
           match_cases[key]['pct'] = (
               match_cases[key]['count'] / total_sources * 100
           )
   
   table = f"""
Performance Summary (N={total_sources})
=====================================
Category          Count    Percent
-------------------------------------
Contains Match    {match_cases['contains_match']['count']:>5}     {match_cases['contains_match']['pct']:>6.1f}%
Exact Match      {match_cases['exact_match']['count']:>5}     {match_cases['exact_match']['pct']:>6.1f}%
Different Match  {match_cases['different']['count']:>5}     {match_cases['different']['pct']:>6.1f}%
No Match         {match_cases['no_match']['count']:>5}     {match_cases['no_match']['pct']:>6.1f}%
Multiple         {match_cases['multiple']['count']:>5}     {match_cases['multiple']['pct']:>6.1f}%
"""
   
   return match_cases, table

In [42]:
stats, email = create_performance_table(df_in_crit_test_and_more)
print(email)  # for quick review
# or stats for detailed analysis


Performance Summary (N=1412)
Category          Count    Percent
-------------------------------------
Contains Match      990       70.1%
Exact Match        950       67.3%
Different Match      6        0.4%
No Match           416       29.5%
Multiple            40        2.8%



In [60]:
def create_performance_metrics(df):
   """compute match statistics between nway and ml model"""
   metrics = {}
   
   # base counts
   metrics['N_CSC'] = df['csc21_name'].nunique()
   metrics['N_yNWAY'] = df[df['p_any'] > 0.5]['csc21_name'].nunique()
   
   # combined criteria using label column
   ok_matches = (df['p_any'] > 0.5) & (df['label'] == 1)
   metrics['N_OK'] = df[ok_matches]['csc21_name'].nunique()
   
   no_ml = (df['p_any'] > 0.5) & ~df.groupby('csc21_name')['label'].transform(any)
   metrics['N_NoML'] = df[no_ml]['csc21_name'].nunique()
   
   flip = (df['p_any'] <= 0.5) & (df['label'] == 1)
   metrics['N_FLIP'] = df[flip]['csc21_name'].nunique()
   
   none = (df['p_any'] <= 0.5) & ~df.groupby('csc21_name')['label'].transform(any)
   metrics['N_NONE'] = df[none]['csc21_name'].nunique()
   
   # single/multiple match counts 
   ok_single = ok_matches & (df.groupby('csc21_name')['label'].transform(sum) == 1)
   metrics['N_yNWAY+yMLeq1'] = df[ok_single]['csc21_name'].nunique()
   metrics['N_yNWAY+yMLgt1'] = metrics['N_OK'] - metrics['N_yNWAY+yMLeq1']
   
   flip_single = flip & (df.groupby('csc21_name')['label'].transform(sum) == 1)
   metrics['N_FLIPeq1'] = df[flip_single]['csc21_name'].nunique()
   metrics['N_FLIPgt1'] = metrics['N_FLIP'] - metrics['N_FLIPeq1']

   table = f"""
N_CSC = {metrics['N_CSC']}
N_yNWAY = {metrics['N_yNWAY']}
N_OK = {metrics['N_OK']}
N_NoML = {metrics['N_NoML']}
N_FLIP = {metrics['N_FLIP']}
N_NONE = {metrics['N_NONE']}
N_yNWAY+yMLeq1 = {metrics['N_yNWAY+yMLeq1']}
N_yNWAY+yMLgt1 = {metrics['N_yNWAY+yMLgt1']}
N_FLIPeq1 = {metrics['N_FLIPeq1']}
N_FLIPgt1 = {metrics['N_FLIPgt1']}
"""
   
   return metrics, table

In [61]:
stats_vk, email_vk = create_performance_metrics(df_in_crit_test_and_more)

  ok_single = ok_matches & (df.groupby('csc21_name')['label'].transform(sum) == 1)
  flip_single = flip & (df.groupby('csc21_name')['label'].transform(sum) == 1)


In [62]:
print(email_vk)


N_CSC = 1412
N_yNWAY = 1006
N_OK = 952
N_NoML = 54
N_FLIP = 44
N_NONE = 362
N_yNWAY+yMLeq1 = 920
N_yNWAY+yMLgt1 = 32
N_FLIPeq1 = 36
N_FLIPgt1 = 8



In [43]:
def analyze_matches(df, df_all_stacks, threshold):
   """analyze match cases and output standardized csvs:
   coup_nway1_mlX_tr{threshold}: no match by ML  
   coup_nway1_ml1M_tr{threshold}: multiple matches by ML
   coup_nway1_mlneq0_tr{threshold}: different match by ML"""
   
   output_cols = [
       'csc21_name', 'csc21_ra', 'csc21_dec', 'min_theta_mean',
       'detect_stack_id', 'gaia3_source_id', 'p_i', 'p_any', 
       'p_match_ind', 'separation', 'match_flag', 'label', 'phot_g_mean_mag',
       'phot_bp_mean_mag', 'phot_rp_mean_mag', 'bp_rp'
   ]
   
   result = {}
   different_matches = []
   no_counterparts = []
   multiple_matches = []
   
   for csc_id, group in df.groupby('csc21_name'):
       label_matches = group[group['label'] == 1]
       nway_matches = group[group['match_flag'] == 1]
       
       if len(label_matches) == 1 and len(nway_matches) == 1:
           if not label_matches.index.equals(nway_matches.index):
               different_matches.append(group)
       elif len(label_matches) == 0:
           no_counterparts.append(group)
       elif len(label_matches) > 1:
           multiple_matches.append(group)

   # map keys to filenames
   filenames = {
       'different_matches': f'coup_nway1_mlneq0_tr{threshold}',
       'no_counterparts': f'coup_nway1_mlX_tr{threshold}',
       'multiple_matches': f'coup_nway1_ml1M_tr{threshold}'
   }
   
   for key, data in zip(
       ['different_matches', 'no_counterparts', 'multiple_matches'],
       [different_matches, no_counterparts, multiple_matches]
   ):
       if data:
           df_merged = (pd.concat(data)
                       .merge(df_all_stacks[['name', 'detect_stack_id']],
                              left_on='csc21_name', right_on='name', how='left')
                       [output_cols])
           #df_merged.to_csv(f'outputs/{filenames[key]}.csv', index=False)
           result[key] = df_merged

   result['summary'] = {
       'different_matches': len(set(result['different_matches']['csc21_name'])),
       'no_counterparts': len(set(result['no_counterparts']['csc21_name'])),
       'multiple_matches': len(set(result['multiple_matches']['csc21_name']))
   }
   
   return result

In [44]:
# read vot to pandas ../data/all_stacks.vot
from astropy.io.votable import parse_single_table

# read vot to pandas
table = parse_single_table('../data/all_stacks.vot')

# recover column names
df_all_stacks = table.to_table().to_pandas()
df_all_stacks.columns = [col.name for col in table.fields]

In [45]:
result_discrepant = analyze_matches(df_in_crit_test_and_more, df_all_stacks, threshold=0.35)