In [124]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import re
from sklearn.decomposition import PCA
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, adjusted_rand_score, confusion_matrix, accuracy_score, classification_report, recall_score, f1_score, fbeta_score
from sklearn.manifold import TSNE
from sklearn.model_selection import StratifiedKFold
from collections import Counter
import umap.umap_ as umap
import matplotlib
from sklearn.manifold import Isomap
from os.path import join
from sklearn.neighbors import NearestNeighbors
import pickle
import optuna
import os
from svae import SupervisedVAE, loss_function


In [125]:
device = torch.device("cpu")


df = pd.read_csv('chimefrbcat1.csv')

df['repeater_name'].value_counts()
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        
        
set_seed(42)

In [126]:
frb_data = pd.read_csv('chimefrbcat1.csv')

frb_data.head()

def is_repeater(repeater_name):
    return 1 if repeater_name != "-9999" else 0

# Create a new column 'repeater' based on 'repeater_name', if repeater_name is not -9999, set to 1, else 0
frb_data['repeater'] = frb_data['repeater_name'].apply(is_repeater)

print(frb_data.columns)

frb_data['repeater'].value_counts()

frb_data.head(15)
frb_data['tns_name'].value_counts().head(20)
frb_data["mjd_400"] = pd.to_numeric(frb_data["mjd_400"], errors="coerce")

# choose the precision that defines sameness
PREC = 6  # use 3 if that is what is needed

frb_data["mjd_400_r"] = frb_data["mjd_400"].round(PREC)


# round MJD to the nearest 3 decimal places
mask = frb_data["repeater"] == 0
frb_data = pd.concat([
    frb_data[mask].drop_duplicates(subset=["tns_name", "mjd_400_r"], keep="first"),
    frb_data[~mask]
])

frb_data = frb_data.sort_index()

frb_data[frb_data['tns_name']=='FRB20190122C'][['mjd_400_r', 'repeater']].values

len(frb_data)
labels = frb_data['repeater']

base_features = ['bonsai_dm', 'dm_exc_ne2001', 'dm_exc_ymw16', 'bc_width', 'high_freq', 'low_freq', 'peak_freq']
error_features = ['dm_fitb', 'fluence', 'flux', 'sp_idx', 'sp_run']

all_features = base_features + error_features

for feature in all_features:
    # convert to integer if the feature is not already an integer
    if not pd.api.types.is_integer_dtype(frb_data[feature]):
        frb_data[feature] = pd.to_numeric(frb_data[feature]).astype(int)


for feature in error_features:
    frb_data[f"{feature}_lower"] = frb_data[feature] - frb_data[f'{feature}_err']
    frb_data[f"{feature}_upper"] = frb_data[feature] + frb_data[f'{feature}_err']
    

new_features = [f"{feature}_lower" for feature in error_features] + [f"{feature}_upper" for feature in error_features] + base_features



frb_data[new_features + ['repeater']].head(15)
frb_data_clean = frb_data[new_features].dropna()
scaler = StandardScaler()
frb_data_scaled = scaler.fit_transform(frb_data_clean)
indices = frb_data_clean.index
train_data, val_data, train_labels, val_labels, train_indices, val_indices = train_test_split(
    frb_data_scaled, labels, indices, test_size=0.2, random_state=42, stratify=labels
)

# Convert to PyTorch tensors
train_tensor = torch.tensor(train_data, dtype=torch.float32)
val_tensor = torch.tensor(val_data, dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels.values, dtype=torch.long)
val_labels_tensor = torch.tensor(val_labels.values, dtype=torch.long)

# Create datasets and dataloaders
batch_size = 64
train_dataset = TensorDataset(train_tensor, train_labels_tensor)
val_dataset = TensorDataset(val_tensor, val_labels_tensor)

full_dataset = ConcatDataset([train_dataset, val_dataset])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
input_dim = val_tensor.shape[1]
hidden_dim = 256
latent_dim = 10
stop_patience = 8
num_epochs = 150

def evaluate_classifier(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(device)
            labels = labels.to(device)
            class_logits = model(data)[-1]
            preds = (torch.sigmoid(class_logits) > 0.5).float().cpu().numpy().squeeze()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    
    accuracy = accuracy_score(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, target_names=["Non-Repeater", "Repeater"])
    conf_matrix = confusion_matrix(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds, average='weighted')

    
    false_positives = np.sum((all_labels == 0) & (all_preds == 1))

    return accuracy, class_report, conf_matrix, recall, false_positives  # Return F1 score as well

def get_activation_function(name):
    if name == 'ReLU':
        return nn.ReLU()
    elif name == 'LeakyReLU':
        return nn.LeakyReLU(0.1)
    elif name == 'ELU':
        return nn.ELU()
    elif name == 'SELU':
        return nn.SELU()
    elif name == 'GELU':
        return nn.GELU()
    else:
        raise ValueError(f"Unknown activation function: {name}")

def evaluate_classifier_full(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(device)
            class_logits = model(data)[-1]
            preds = (torch.sigmoid(class_logits) > 0.5).float().cpu().numpy().squeeze()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, target_names=["Non-Repeater", "Repeater"])
    conf_matrix = confusion_matrix(all_labels, all_preds)
    
    return accuracy, class_report, conf_matrix, all_preds, all_labels


Index(['tns_name', 'previous_name', 'repeater_name', 'ra', 'ra_err',
       'ra_notes', 'dec', 'dec_err', 'dec_notes', 'gl', 'gb', 'exp_up',
       'exp_up_err', 'exp_up_notes', 'exp_low', 'exp_low_err', 'exp_low_notes',
       'bonsai_snr', 'bonsai_dm', 'low_ft_68', 'up_ft_68', 'low_ft_95',
       'up_ft_95', 'snr_fitb', 'dm_fitb', 'dm_fitb_err', 'dm_exc_ne2001',
       'dm_exc_ymw16', 'bc_width', 'scat_time', 'scat_time_err', 'flux',
       'flux_err', 'flux_notes', 'fluence', 'fluence_err', 'fluence_notes',
       'sub_num', 'mjd_400', 'mjd_400_err', 'mjd_inf', 'mjd_inf_err',
       'width_fitb', 'width_fitb_err', 'sp_idx', 'sp_idx_err', 'sp_run',
       'sp_run_err', 'high_freq', 'low_freq', 'peak_freq', 'chi_sq', 'dof',
       'flag_frac', 'excluded_flag', 'repeater'],
      dtype='object')


In [127]:
original_data = pd.read_csv('chimefrbcat1.csv')

In [128]:

garcia_list = '''
FRB20180907E
FRB20180920B
FRB20180928A
FRB20181017B
FRB20181022E
FRB20181125A
FRB20181125A
FRB20181125A
FRB20181214A
FRB20181220A
FRB20181226E
FRB20181229B
FRB20190112A
FRB20190128C
FRB20190206B
FRB20190206A
FRB20190218B
FRB20190223A
FRB20190308C
FRB20190308C
FRB20190323D
FRB20190329A
FRB20190410A
FRB20190412B
FRB20190423B
FRB20190423B
FRB20190429B
FRB20190430A
FRB20190527A
FRB20190527A
FRB20190601C
FRB20190601C
FRB20190617B
FRB20180910A
FRB20190210C
FRB20200726D
'''.split()

luo_list = '''
FRB20181229B
FRB20190423B
FRB20190410A
FRB20181017B
FRB20181128C
FRB20190422A
FRB20190409B
FRB20190329A
FRB20190423B
FRB20190206A
FRB20190128C
FRB20190106A
FRB20190129A
FRB20181030E
FRB20190527A
FRB20190218B
FRB20190609A
FRB20190412B
FRB20190125B
FRB20181231B
FRB20181221A
FRB20190112A
FRB20190125A
FRB20181218C
FRB20190429B
FRB20190109B
FRB20190206B
'''.split()

zhu_ge_list = '''
FRB20180911A
FRB20180915B
FRB20180920B
FRB20180923A
FRB20180923C
FRB20180928A
FRB20181013E
FRB20181017B
FRB20181030E
FRB20181125A
FRB20181125A
FRB20181125A
FRB20181130A
FRB20181214A
FRB20181220A
FRB20181221A
FRB20181226E
FRB20181229B
FRB20181231B
FRB20190106B
FRB20190109B
FRB20190110C
FRB20190111A
FRB20190112A
FRB20190129A
FRB20190204A
FRB20190206A
FRB20190218B
FRB20190220A
FRB20190221A
FRB20190222B
FRB20190223A
FRB20190228A
FRB20190308C
FRB20190308C
FRB20190308B
FRB20190308B
FRB20190323D
FRB20190329A
FRB20190403E
FRB20190409B
FRB20190410A
FRB20190412B
FRB20190418A
FRB20190419A
FRB20190422A
FRB20190422A
FRB20190423A
FRB20190423B
FRB20190423B
FRB20190429B
FRB20190430A
FRB20190517C
FRB20190527A
FRB20190527A
FRB20190531C
FRB20190601B
FRB20190601C
FRB20190601C
FRB20190609A
FRB20190617A
FRB20190617B
FRB20190618A
FRB20190625A
'''.split()

In [129]:
best_params = {'hidden_dim': 1530, 'latent_dim': 16, 'beta': 1.2211908840673436, 'gamma': 0.5885532829581379, 'dropout_rate': 0.10966445430577035, 'lr': 0.00013082216688850454, 'scheduler_patience': 7, 'class_weight': 0.8946298975578247, 'activation': 'ReLU', 'classification_multiplier': 12452.143276136809}

lr = best_params["lr"]
scheduler_patience = best_params["scheduler_patience"]
num_epochs = 150

all_false_positives = []
all_false_negatives = []
all_true_positives = []
all_true_negatives = []

val_preds_full = []
val_labels_full = []

num_epochs = 150

n_folds = 3
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

accuracy = 0

for fold, (train_index, val_index) in enumerate(skf.split(frb_data_scaled, labels)):
# print(f"\n=== Fold {fold + 1}/{n_folds} ===")

    train_data, val_data = frb_data_scaled[train_index], frb_data_scaled[val_index]
    train_labels, val_labels = labels.iloc[train_index], labels.iloc[val_index]
    
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)
    train_labels_tensor = torch.tensor(train_labels.values, dtype=torch.long)
    val_labels_tensor = torch.tensor(val_labels.values, dtype=torch.long)
    
    train_dataset = TensorDataset(train_tensor, train_labels_tensor)
    val_dataset = TensorDataset(val_tensor, val_labels_tensor)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    best_model = SupervisedVAE(
        input_dim,
        best_params["hidden_dim"],
        best_params["latent_dim"],
        best_params["dropout_rate"],
        get_activation_function(best_params["activation"])
    ).to(device)

    best_model.load_state_dict(torch.load(f"saves/trial_224/model_fold_{fold+1}.pth", map_location=device))
        

    val_accuracy, val_class_report, val_conf_matrix, val_preds, val_labels = evaluate_classifier_full(best_model, val_loader, device)
    
    false_positives = original_data.loc[val_index[(val_labels == 0) & (val_preds == 1)], "tns_name"]
    false_negatives = original_data.loc[val_index[(val_labels == 1) & (val_preds == 0)], "tns_name"]
    true_positives = original_data.loc[val_index[(val_labels == 1) & (val_preds == 1)], "tns_name"]
    true_negatives = original_data.loc[val_index[(val_labels == 0) & (val_preds == 0)], "tns_name"]
    
    val_preds_full.extend(val_preds)
    val_labels_full.extend(val_labels)

    print(f"Fold {fold + 1}/{n_folds} - Validation Accuracy: {val_accuracy:.4f}")
    print("Classification Report:")
    print(val_class_report)
    print("Confusion Matrix:")
    print(val_conf_matrix)
    
    all_false_negatives.extend(false_negatives)
    all_true_positives.extend(true_positives)
    all_true_negatives.extend(true_negatives)
    all_false_positives.extend(false_positives)
    accuracy += val_accuracy
    


accuracy /= n_folds


all_false_positives = pd.Series(all_false_positives)
all_false_negatives = pd.Series(all_false_negatives)
all_true_positives = pd.Series(all_true_positives)
all_true_negatives = pd.Series(all_true_negatives)

print(accuracy)

Fold 1/3 - Validation Accuracy: 0.9789
Classification Report:
              precision    recall  f1-score   support

Non-Repeater       0.99      0.99      0.99       159
    Repeater       0.94      0.94      0.94        31

    accuracy                           0.98       190
   macro avg       0.96      0.96      0.96       190
weighted avg       0.98      0.98      0.98       190

Confusion Matrix:
[[157   2]
 [  2  29]]
Fold 2/3 - Validation Accuracy: 0.9789
Classification Report:
              precision    recall  f1-score   support

Non-Repeater       0.98      0.99      0.99       159
    Repeater       0.97      0.90      0.93        31

    accuracy                           0.98       190
   macro avg       0.97      0.95      0.96       190
weighted avg       0.98      0.98      0.98       190

Confusion Matrix:
[[158   1]
 [  3  28]]
Fold 3/3 - Validation Accuracy: 0.9842
Classification Report:
              precision    recall  f1-score   support

Non-Repeater       0.99

In [130]:
print("")

print("\n=== Summary ===")
print(f"Total False Positives: {all_false_positives.size}")
print(f"Total False Negatives: {all_false_negatives.size}")
print(f"Total True Positives: {all_true_positives.size}")
print(f"Total True Negatives: {all_true_negatives.size}")

conf_mat_dups = np.zeros((2, 2))
conf_mat_dups[0, 0] = all_true_negatives.size
conf_mat_dups[0, 1] = all_false_positives.size
conf_mat_dups[1, 0] = all_false_negatives.size
conf_mat_dups[1, 1] = all_true_positives.size


conf_mat_dups = pd.DataFrame(conf_mat_dups, index=["Non-Repeater", "Repeater"], columns=["Non-Repeater", "Repeater"])
print("\nConfusion Matrix (with duplicates):")
print(conf_mat_dups)

print("accuracy_score")
accuracy = (all_true_positives.size + all_true_negatives.size) / (all_false_positives.size + all_false_negatives.size + all_true_positives.size + all_true_negatives.size)
print(accuracy)



=== Summary ===
Total False Positives: 4
Total False Negatives: 7
Total True Positives: 87
Total True Negatives: 472

Confusion Matrix (with duplicates):
              Non-Repeater  Repeater
Non-Repeater         472.0       4.0
Repeater               7.0      87.0
accuracy_score
0.980701754385965


In [131]:
cm = conf_mat_dups.to_numpy().astype(int)

print(cm)
y_true = np.concatenate([np.zeros(cm[0].sum()), np.ones(cm[1].sum())])
y_pred = np.concatenate([
    np.concatenate([np.zeros(cm[0, 0]), np.ones(cm[0, 1])]),
    np.concatenate([np.zeros(cm[1, 0]), np.ones(cm[1, 1])])
])

print(classification_report(y_true, y_pred, target_names=['Non-Repeater', 'Repeater'], digits=4))

[[472   4]
 [  7  87]]
              precision    recall  f1-score   support

Non-Repeater     0.9854    0.9916    0.9885       476
    Repeater     0.9560    0.9255    0.9405        94

    accuracy                         0.9807       570
   macro avg     0.9707    0.9586    0.9645       570
weighted avg     0.9805    0.9807    0.9806       570



In [132]:
for fp in all_false_positives:
    if fp in luo_list or fp in zhu_ge_list or fp in garcia_list:
        print(fp)

FRB20181218C
FRB20190221A


In [133]:
all_false_positives.to_list()

['FRB20181218C', 'FRB20190122C', 'FRB20190221A', 'FRB20190320A']

In [134]:
def get_model_size_and_params(model):
    # Total number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    

    # Each parameter is a float32 (4 bytes)
    total_size_bytes = total_params * 4
    total_size_mb = total_size_bytes / (1024 ** 2)  # Convert to MB
    
    print(f"Total parameters: {total_params:,}")
    print(f"Model size: {total_size_mb:.2f} MB")

# Example usage:
best_model = SupervisedVAE(
    input_dim,
    best_params["hidden_dim"],
    best_params["latent_dim"],
    best_params["dropout_rate"],
    get_activation_function(best_params["activation"])
).to(device)

get_model_size_and_params(best_model)

Total parameters: 9,822,649
Model size: 37.47 MB


In [135]:
for param in best_params:
    if type(best_params[param]) == float:
        print(f"{param}: {best_params[param]:.4f}")
    else:
        print(f"{param}: {best_params[param]}")

hidden_dim: 1530
latent_dim: 16
beta: 1.2212
gamma: 0.5886
dropout_rate: 0.1097
lr: 0.0001
scheduler_patience: 7
class_weight: 0.8946
activation: ReLU
classification_multiplier: 12452.1433


In [136]:
fbeta_score(val_preds_full, val_labels_full, beta=2, average='weighted')

0.9807405337693228