In [3]:
import pandas as pd
import numpy as np
import re
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# set device and seeds
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# handle special source mapping
def fill_repeater_from_source(row):
    if row['Source'] == 'FRB20220912A':
        return 1
    return row['Repeater']

# clean numeric entries
def clean_numeric_value(value):
    if isinstance(value, str):
        value = value.strip()
        if not value:
            return np.nan
        try:
            for char in ['/', '+', '<', '>', '~']:
                value = value.replace(char, '')
            if '-' in value:
                value = value.split('-')[0]
            return float(value)
        except ValueError:
            return np.nan
    try:
        return float(value)
    except (ValueError, TypeError):
        return np.nan

# choose activation function
def get_activation_function(name):
    if name == 'ReLU':
        return nn.ReLU()
    if name == 'LeakyReLU':
        return nn.LeakyReLU(0.1)
    if name == 'ELU':
        return nn.ELU()
    if name == 'SELU':
        return nn.SELU()
    if name == 'GELU':
        return nn.GELU()
    raise ValueError(f"Unknown activation: {name}")

# define supervised VAE
class SupervisedVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, dropout_rate=0.3, activation=nn.LeakyReLU(0.1)):
        super(SupervisedVAE, self).__init__()
        self.activation = activation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate)
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim)
        )
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            activation,
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 4, 1)
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        class_prob = self.classifier(mu)
        return recon_x, mu, logvar, class_prob

# loss for VAE and classification
def loss_function(recon_x, x, mu, logvar, class_prob, labels, beta, gamma, class_weight, classification_multiplier):
    mse = nn.MSELoss(reduction='sum')
    pos_weight = torch.tensor([class_weight], dtype=torch.float32, device=device)
    bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    recon_loss = mse(recon_x, x)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    class_loss = classification_multiplier * bce(class_prob, labels.unsqueeze(1).float())
    total = recon_loss + beta * kl_loss + gamma * class_loss
    return total, recon_loss, kl_loss, class_loss

# train for one epoch
def train_epoch(model, loader, optimizer, beta, gamma, class_weight, classification_multiplier):
    model.train()
    total = 0
    for data, labels in loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        recon, mu, logvar, logits = model(data)
        loss, *_ = loss_function(recon, data, mu, logvar, logits, labels, beta, gamma, class_weight, classification_multiplier)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader.dataset)

# validate for one epoch
def validate_epoch(model, loader, beta, gamma, class_weight, classification_multiplier):
    model.eval()
    total = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            recon, mu, logvar, logits = model(data)
            loss, *_ = loss_function(recon, data, mu, logvar, logits, labels, beta, gamma, class_weight, classification_multiplier)
            total += loss.item()
    return total / len(loader.dataset)

# stop if no improvement
def early_stopping(losses, patience):
    if len(losses) <= patience:
        return False
    return all(losses[-i-2] <= losses[-i-1] for i in range(1, patience+1))

# get latent codes
def get_latent_representations(model, loader):
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for data, lbl in loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            latents.append(mu.cpu().numpy())
            labels.append(lbl.numpy())
    return np.concatenate(latents), np.concatenate(labels)

# evaluate classification
def evaluate(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            logits = model(data)[-1]
            preds = (logits > 0.5).float().cpu().numpy().squeeze()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, target_names=["Non-Repeater", "Repeater"])
    matrix = confusion_matrix(all_labels, all_preds)
    return accuracy, report, matrix

# load and clean data
frb_data = pd.read_csv('frb-data.csv')
frb_data['Repeater'] = frb_data['Repeater'].map({'Yes': 1, 'No': 0}).fillna(0).astype(int)
frb_data['Repeater'] = frb_data.apply(fill_repeater_from_source, axis=1)
error_features = ['DM_SNR', 'DM_alig', 'Flux_density', 'Fluence', 'Energy', 'Polar_l', 'Polar_c', 'RM_syn', 'RM_QUfit', 'Scatt_t']
base_features = ['Observing_band', 'SNR', 'Freq_high', 'Freq_low', 'Freq_peak', 'Width']
for feat in base_features + error_features:
    frb_data[feat] = frb_data[feat].apply(clean_numeric_value)
    frb_data[f'{feat}_err'] = frb_data.get(f'{feat}_err', pd.Series(np.nan)).apply(clean_numeric_value)
for feat in error_features:
    frb_data[f'{feat}_upper'] = frb_data[feat] + frb_data[f'{feat}_err']
    frb_data[f'{feat}_lower'] = (frb_data[feat] - frb_data[f'{feat}_err']).clip(lower=0)
features = base_features + error_features + [f'{f}_upper' for f in error_features] + [f'{f}_lower' for f in error_features]
frb_data_clean = frb_data[features].fillna(0)
labels = frb_data['Repeater']
original_data = frb_data.copy()
all_false_positives = ['FRB20181102A','FRB20180309A','FRB20141113A','FRB20190221B','FRB20210213A','FRB20210303A','FRB20200514B','FRB20211212A','FRB20220506D','FRB20150418A','FRB20190423B','FRB20010621A','FRB20190429B','FRB20010125A','FRB20191109A','FRB20190625A','FRB20191020B','FRB20220725A','FRB20210408H','FRB20190420A','FRB20180907E','FRB20140514A','FRB20010305A','FRB20110523A','FRB20010312A','FRB20190714A','FRB20191221A','FRB20210206A','FRB20221101A','FRB20230718A','FRB20190112A','FRB20200917A','FRB20200125A','FRB20200405A','FRB20210202D']

# set hyperparameters
best_params = {'hidden_dim':1082,'latent_dim':18,'beta':1.149574612306723,'gamma':1.9210647260496314,'dropout_rate':0.13093239424733344,'lr':0.0011823749066137313,'scheduler_patience':7,'class_weight':0.35488674730648145,'classification_multiplier':7817.124805902009,'activation':'ReLU'}
stop_patience = 8
num_epochs = 100

# ablation loop
for feature in features:
    print(f"Ablating feature: {feature}")
    feats = [f for f in features if f != feature]
    subset = frb_data_clean[feats]
    repeater_counts = frb_data[frb_data['Repeater']==1]['Source'].value_counts()
    top5 = repeater_counts.head(5).index.tolist()
    sampled = []
    for rep in top5:
        idxs = frb_data[frb_data['Source']==rep].index
        sampled.extend(np.random.choice(idxs, size=min(5,len(idxs)), replace=False))
    non_top5 = frb_data[~frb_data['Source'].isin(top5)].index
    final_idx = np.concatenate([non_top5, sampled])
    data_ab = subset.loc[final_idx]
    labels_ab = labels.loc[final_idx]
    scaled = StandardScaler().fit_transform(data_ab)
    train_x, val_x, train_y, val_y = train_test_split(scaled, labels_ab, test_size=0.2, random_state=42, stratify=labels_ab)
    train_loader = DataLoader(TensorDataset(torch.tensor(train_x,dtype=torch.float32), torch.tensor(train_y.values,dtype=torch.long)), batch_size=64, shuffle=True)
    val_loader = DataLoader(TensorDataset(torch.tensor(val_x,dtype=torch.float32), torch.tensor(val_y.values,dtype=torch.long)), batch_size=64, shuffle=False)

    model = SupervisedVAE(input_dim=val_x.shape[1], hidden_dim=best_params['hidden_dim'], latent_dim=best_params['latent_dim'], dropout_rate=best_params['dropout_rate'], activation=get_activation_function(best_params['activation'])).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=best_params['lr'])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=best_params['scheduler_patience'])

    val_losses = []
    for epoch in range(1, num_epochs+1):
        train_epoch(model, train_loader, optimizer, best_params['beta'], best_params['gamma'], best_params['class_weight'], best_params['classification_multiplier'])
        vloss = validate_epoch(model, val_loader, best_params['beta'], best_params['gamma'], best_params['class_weight'], best_params['classification_multiplier'])
        val_losses.append(vloss)
        scheduler.step(vloss)
        if early_stopping(val_losses, stop_patience):
            break

    # classification performance
    acc, report, cm = evaluate(model, val_loader)
    print(f"Validation Accuracy after ablating {feature}: {acc:.4f}")
    print("Confusion Matrix:")
    print(cm)

    # nearest neighbor similarity
    val_latent, _ = get_latent_representations(model, val_loader)
    val_idx = val_y.index
    latent_df = pd.DataFrame(val_latent, index=val_idx)
    latent_df['Source'] = original_data.loc[val_idx,'Source'].values
    nbrs = NearestNeighbors(n_neighbors=6).fit(val_latent)
    dists, inds = nbrs.kneighbors(val_latent)
    results = []
    for src in all_false_positives:
        if src in latent_df['Source'].values:
            idx0 = latent_df[latent_df['Source']==src].index[0]
            pos = latent_df.index.get_loc(idx0)
            for npos, dist in zip(inds[pos][1:], dists[pos][1:]):
                nbr_idx = latent_df.index[npos]
                nbr_src = latent_df.loc[nbr_idx,'Source']
                if labels.loc[idx0] != labels.loc[nbr_idx]:
                    if labels.loc[idx0] == 0:
                        non_rep, rep = src, nbr_src
                    else:
                        non_rep, rep = nbr_src, src
                    results.append((non_rep, rep, float(dist)))
    df_out = pd.DataFrame(results, columns=['Non-Repeater','Repeater','Score'])
    df_out.to_csv(f"ablated_results/{feature}_similar_signals.csv", index=False)


Ablating feature: Observing_band
Validation Accuracy after ablating Observing_band: 0.8996
Confusion Matrix:
[[141   9]
 [ 19 110]]
Ablating feature: SNR
Validation Accuracy after ablating SNR: 0.9104
Confusion Matrix:
[[138  12]
 [ 13 116]]
Ablating feature: Freq_high
Validation Accuracy after ablating Freq_high: 0.8459
Confusion Matrix:
[[121  29]
 [ 14 115]]
Ablating feature: Freq_low
Validation Accuracy after ablating Freq_low: 0.8889
Confusion Matrix:
[[133  17]
 [ 14 115]]
Ablating feature: Freq_peak
Validation Accuracy after ablating Freq_peak: 0.8315
Confusion Matrix:
[[112  38]
 [  9 120]]
Ablating feature: Width
Validation Accuracy after ablating Width: 0.8853
Confusion Matrix:
[[144   6]
 [ 26 103]]
Ablating feature: DM_SNR
Validation Accuracy after ablating DM_SNR: 0.8996
Confusion Matrix:
[[141   9]
 [ 19 110]]
Ablating feature: DM_alig
Validation Accuracy after ablating DM_alig: 0.8853
Confusion Matrix:
[[143   7]
 [ 25 104]]
Ablating feature: Flux_density
Validation Accu