In [9]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import re
from sklearn.decomposition import PCA
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, adjusted_rand_score, confusion_matrix
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, roc_curve, auc
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import f1_score, recall_score
from itertools import product
from sklearn.model_selection import StratifiedKFold
import optuna
import plotly.express as px
from collections import Counter
import umap.umap_ as umap
import matplotlib
from sklearn.manifold import Isomap

# Run on the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from os.path import join
def set_seed(seed):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch. When CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Integer seed applied to each generator.
    """
    # Local import keeps this function self-contained; the notebook's import
    # cell does not bring in ``random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        
        
# Fix the random seeds before any stochastic step runs.
set_seed(42)

# Start from matplotlib's built-in defaults, then layer the figure style on top.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
_figure_style = {
    # Serif text rendered through LaTeX.
    'font.family': 'serif',
    'text.usetex': True,
    # Base font, titles, axis labels and tick labels all use 12 pt ...
    'font.size': 12,
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    # ... while the legend is set smaller, at 8 pt.
    'legend.fontsize': 8,
    'figure.titlesize': 12,
    # Load lmodern in the LaTeX preamble (scalable fonts).
    'text.latex.preamble': r'\usepackage{lmodern}',
}
matplotlib.rcParams.update(_figure_style)
def fill_repeater_from_source(row, data):
    """Return the repeater flag for one row, forcing source FRB20220912A to 1.

    Args:
        row: Mapping-like record providing the 'Source' and 'Repeater' entries.
        data: Accepted for the caller's benefit but never read here.

    Returns:
        1 when the row's 'Source' is 'FRB20220912A', otherwise the row's
        existing 'Repeater' value unchanged.
    """
    is_forced_repeater = row['Source'] == 'FRB20220912A'
    return 1 if is_forced_repeater else row['Repeater']
# Load the catalogue and turn the 'Repeater' column into an integer 0/1 flag:
# 'Yes' -> 1, 'No' -> 0, and anything that does not map is treated as 0.
frb_data = pd.read_csv('frb-data.csv')
frb_data['Repeater'] = (
    frb_data['Repeater']
    .map({'Yes': 1, 'No': 0})
    .fillna(0)
    .astype(int)
)

# Row-wise override: rows whose 'Source' is FRB20220912A are flagged as repeaters.
frb_data['Repeater'] = frb_data.apply(fill_repeater_from_source, axis=1, data=frb_data)

# Count of missing labels; the value is computed but not shown because this is
# not the last expression of the cell.
frb_data['Repeater'].isna().sum()

labels = frb_data['Repeater']

# Function to clean numerical strings and convert to float
def clean_numeric_value(value):
    """Convert a raw cell value to ``float``, returning NaN when it cannot be parsed.

    For strings, surrounding whitespace and the marker characters
    ``/ + < > ~`` are removed first. If the remaining text contains a ``-``
    used as a separator (for example ``"400-800"`` or ``"12.3+/-0.5"``), only
    the part before that separator is converted. A leading minus sign and the
    minus of an exponent are not separators, so ``"-5.3"`` and ``"1e-5"``
    parse to their numeric values instead of becoming NaN.

    Args:
        value: Any object; typically a string or a number.

    Returns:
        The parsed ``float``, or ``np.nan`` for empty/whitespace-only strings,
        unparsable strings, and objects ``float()`` rejects.
    """
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return np.nan
        # Drop limit/approximation markers and the '+' and '/' of "+/-".
        for char in ('/', '+', '<', '>', '~'):
            text = text.replace(char, '')
        # Cut at the first '-' that separates two values. Position 0 is a sign
        # and a '-' right after 'e'/'E' belongs to an exponent, so both are kept.
        for pos in range(1, len(text)):
            if text[pos] == '-' and text[pos - 1] not in 'eE':
                text = text[:pos]
                break
        try:
            return float(text)
        except ValueError:
            return np.nan
    try:
        return float(value)
    except (ValueError, TypeError):
        return np.nan
    
# Quantities that have a companion '<name>_err' column in the catalogue.
# ('Scin_f' is currently left out.)
error_features = [
    'DM_SNR', 'DM_alig', 'Flux_density', 'Fluence', 'Energy',
    'Polar_l', 'Polar_c', 'RM_syn', 'RM_QUfit', 'Scatt_t',
]

# Quantities used as-is, without an error column.
# ('GL', 'GB', 'Repeater' and 'MJD' are currently left out.)
base_features = [
    'Observing_band',
    'SNR',
    'Freq_high',
    'Freq_low',
    'Freq_peak',
    'Width',
]

# Coerce every selected column to float; unparsable cells become NaN.
for feature in base_features + error_features:
    frb_data[feature] = frb_data[feature].apply(clean_numeric_value)

# For each quantity with an error: clean the error column, then derive the
# value+error and value-error bounds, flooring the lower bound at zero.
for feature in error_features:
    err_col = f'{feature}_err'
    frb_data[err_col] = frb_data[err_col].apply(clean_numeric_value)
    frb_data[f'{feature}_upper'] = frb_data[feature] + frb_data[err_col]
    frb_data[f'{feature}_lower'] = (frb_data[feature] - frb_data[err_col]).clip(lower=0)

# Full feature list: base, measured, then all upper bounds, then all lower bounds.
features = [
    *base_features,
    *error_features,
    *(f'{feature}_upper' for feature in error_features),
    *(f'{feature}_lower' for feature in error_features),
]

In [10]:

# Compare the set of 'Non-Repeater' entries found with all features against the
# set found when one feature is ablated, and record what was added or removed.

# The baseline does not depend on the ablated feature, so read it once.
base_df = pd.read_csv('ablated_results/base_similar_signals.csv')
base_set = set(base_df['Non-Repeater'])

delta_records = []
# The loop variable is deliberately not named `features`: reusing that name
# would rebind the feature list to a single string and break a re-run of this cell.
for ablated_feature in features:
    ablated_df = pd.read_csv(f'ablated_results/{ablated_feature}_similar_signals.csv')
    ablated_set = set(ablated_df['Non-Repeater'])

    # Entries present only after ablation / only in the baseline, comma-separated.
    added = ", ".join(str(item) for item in sorted(ablated_set - base_set))
    removed = ", ".join(str(item) for item in sorted(base_set - ablated_set))

    print(f"Ablation: {ablated_feature}")
    print(f"Added: {added}")
    print(f"Removed: {removed}")

    delta_records.append([ablated_feature, added, removed])

# Build the summary table in one go instead of growing it row by row.
delta_df = pd.DataFrame(delta_records, columns=['Ablated Feature', 'Added', 'Removed'])


Ablation: Observing_band
Added: FRB20210408H
Removed: FRB20180309A
Ablation: SNR
Added: 
Removed: 
Ablation: Freq_high
Added: 
Removed: 
Ablation: Freq_low
Added: 
Removed: 
Ablation: Freq_peak
Added: 
Removed: 
Ablation: Width
Added: 
Removed: 
Ablation: DM_SNR
Added: 
Removed: 
Ablation: DM_alig
Added: 
Removed: 
Ablation: Flux_density
Added: 
Removed: 
Ablation: Fluence
Added: 
Removed: 
Ablation: Energy
Added: 
Removed: 
Ablation: Polar_l
Added: 
Removed: 
Ablation: Polar_c
Added: 
Removed: FRB20180309A
Ablation: RM_syn
Added: 
Removed: 
Ablation: RM_QUfit
Added: 
Removed: 
Ablation: Scatt_t
Added: FRB20210408H
Removed: FRB20180309A
Ablation: DM_SNR_upper
Added: 
Removed: FRB20180309A
Ablation: DM_alig_upper
Added: FRB20210408H
Removed: 
Ablation: Flux_density_upper
Added: 
Removed: 
Ablation: Fluence_upper
Added: 
Removed: 
Ablation: Energy_upper
Added: 
Removed: 
Ablation: Polar_l_upper
Added: 
Removed: FRB20180309A
Ablation: Polar_c_upper
Added: 
Removed: 
Ablation: RM_syn_upper

In [11]:
# Show the ablation summary table (columns: 'Ablated Feature', 'Added', 'Removed').
delta_df

Unnamed: 0,Ablated Feature,Added,Removed
0,Observing_band,FRB20210408H,FRB20180309A
1,SNR,,
2,Freq_high,,
3,Freq_low,,
4,Freq_peak,,
5,Width,,
6,DM_SNR,,
7,DM_alig,,
8,Flux_density,,
9,Fluence,,


In [15]:
def escape_latex_underscores(text):
    """Return ``text`` with every underscore replaced by the LaTeX escape ``\\_``."""
    return text.replace("_", "\\_")


# Print one LaTeX table row per ablated feature; empty cells are shown as "-".
# The escaping lives in a helper because a backslash inside an f-string
# replacement field is a SyntaxError on Python versions before 3.12.
for index, row in delta_df.iterrows():
    feature_cell = f'\\texttt{{{escape_latex_underscores(row["Ablated Feature"])}}}'
    added_cell = escape_latex_underscores(row["Added"]) if row["Added"] else "-"
    removed_cell = escape_latex_underscores(row["Removed"]) if row["Removed"] else "-"
    print(f'{feature_cell} & {added_cell} & {removed_cell} \\\\ \\hline')

\texttt{Observing\_band} & FRB20210408H & FRB20180309A \\ \hline
\texttt{SNR} & - & - \\ \hline
\texttt{Freq\_high} & - & - \\ \hline
\texttt{Freq\_low} & - & - \\ \hline
\texttt{Freq\_peak} & - & - \\ \hline
\texttt{Width} & - & - \\ \hline
\texttt{DM\_SNR} & - & - \\ \hline
\texttt{DM\_alig} & - & - \\ \hline
\texttt{Flux\_density} & - & - \\ \hline
\texttt{Fluence} & - & - \\ \hline
\texttt{Energy} & - & - \\ \hline
\texttt{Polar\_l} & - & - \\ \hline
\texttt{Polar\_c} & - & FRB20180309A \\ \hline
\texttt{RM\_syn} & - & - \\ \hline
\texttt{RM\_QUfit} & - & - \\ \hline
\texttt{Scatt\_t} & FRB20210408H & FRB20180309A \\ \hline
\texttt{DM\_SNR\_upper} & - & FRB20180309A \\ \hline
\texttt{DM\_alig\_upper} & FRB20210408H & - \\ \hline
\texttt{Flux\_density\_upper} & - & - \\ \hline
\texttt{Fluence\_upper} & - & - \\ \hline
\texttt{Energy\_upper} & - & - \\ \hline
\texttt{Polar\_l\_upper} & - & FRB20180309A \\ \hline
\texttt{Polar\_c\_upper} & - & - \\ \hline
\texttt{RM\_syn\_upper} & - &