In [6]:
import re
from os.path import join

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, recall_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset

# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reset matplotlib to its stock settings, then switch to serif fonts rendered
# through LaTeX.
# NOTE(review): 'text.usetex' requires a working LaTeX installation at draw time.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams['font.family'] = 'serif'
matplotlib.rcParams['text.usetex'] = True

# helper to fill repeater based on source
def fill_repeater_from_source(row, data=None, repeater_sources=('FRB20220912A',)):
    """Return the repeater flag for one catalogue row, overriding it by source name.

    Args:
        row: mapping-like row (e.g. a pandas Series) with 'Source' and 'Repeater'.
        data: unused; kept so existing ``DataFrame.apply(..., data=...)`` calls
            keep working.
        repeater_sources: collection of source names whose rows are forced
            to 1. Defaults to the single source the original code special-cased.

    Returns:
        1 if ``row['Source']`` is in ``repeater_sources``, otherwise
        ``row['Repeater']`` unchanged.
    """
    if row['Source'] in repeater_sources:
        return 1
    return row['Repeater']

# clean numeric strings → float
def clean_numeric_value(value):
    """Best-effort conversion of a catalogue cell to a float.

    Strings are stripped of the decorations '/', '+', '<', '>' and '~', and
    everything after a range/uncertainty dash is discarded, so '<5' -> 5.0,
    '12-15' -> 12.0 and '1.2+/-0.3' -> 1.2.

    A leading minus sign (optionally after '<', '>' or '~') is preserved, so
    negative values survive ('-35.2' -> -35.2). A dash inside an exponent is
    not treated as a separator ('1e-5' -> 1e-05). A bare '+/-0.3' still
    yields NaN, because it starts with '+'.

    Args:
        value: cell content (str, number, None, ...).

    Returns:
        The parsed float, or ``np.nan`` when the value cannot be parsed.
    """
    if not isinstance(value, str):
        try:
            return float(value)
        except (ValueError, TypeError):
            return np.nan

    v = value.strip()
    if not v:
        return np.nan
    # A genuine negative number begins with '-' (possibly after a comparison or
    # "approximately" marker). '+/-x' begins with '+', so it is not negative.
    negative = v.lstrip('<>~ ').startswith('-')
    for ch in ['/', '+', '<', '>', '~']:
        v = v.replace(ch, '')
    if negative:
        # Drop the sign temporarily so the dash split below does not eat it.
        v = v.lstrip()[1:]
    # Keep only what precedes a range/uncertainty dash; dashes that follow an
    # exponent marker ('e'/'E') belong to the number itself.
    v = re.split(r'(?<![eE])-', v, maxsplit=1)[0]
    if negative:
        v = '-' + v
    try:
        return float(v)
    except ValueError:
        return np.nan

# 1) Load the catalogue and derive the binary 'Repeater' label
frb_data = pd.read_csv('frb-data.csv')
repeater_numeric = frb_data['Repeater'].map({'Yes':1,'No':0})
# Any value other than exactly 'Yes'/'No' (including missing) ends up as 0
frb_data['Repeater'] = repeater_numeric.fillna(0).astype(int)
# Apply the per-source override implemented in fill_repeater_from_source
frb_data['Repeater'] = frb_data.apply(fill_repeater_from_source, axis=1, data=frb_data)
labels = frb_data['Repeater']

# 2) Raw per-sample SNR, used as a per-sample loss weight rather than a feature.
#    Unparseable or missing SNR becomes 0, i.e. such rows get zero weight.
snr_parsed = frb_data['SNR'].apply(clean_numeric_value)
frb_data['SNR_raw'] = snr_parsed.fillna(0)
snr_array = frb_data['SNR_raw'].values

# 3) Define features (exclude 'SNR': it is only used as a per-sample loss weight)
base_features = [
    'Observing_band',
    'Freq_high', 'Freq_low', 'Freq_peak',
    'Width'
]
# Measurements that have a companion '<name>_err' uncertainty column
error_features = [
    'DM_SNR','DM_alig','Flux_density','Fluence','Energy',
    'Polar_l','Polar_c','RM_syn','RM_QUfit','Scatt_t'
]
# Signed quantities (rotation measures, circular polarisation) can legitimately
# be negative, so their lower bound must not be clipped at zero.
signed_features = {'Polar_c', 'RM_syn', 'RM_QUfit'}

# clean numeric for all features
for feat in base_features + error_features:
    frb_data[feat] = frb_data[feat].apply(clean_numeric_value)
for feat in error_features:
    frb_data[f'{feat}_err'] = frb_data[f'{feat}_err'].apply(clean_numeric_value)
    # An uncertainty is a magnitude. A missing one is treated as zero spread,
    # so the bounds fall back to the measured value. Otherwise they would be
    # NaN, and the fillna(0) below would zero them even when the value is known.
    err = frb_data[f'{feat}_err'].abs().fillna(0)
    frb_data[f'{feat}_upper'] = frb_data[feat] + err
    lower = frb_data[feat] - err
    frb_data[f'{feat}_lower'] = lower if feat in signed_features else lower.clip(lower=0)

# build cleaned feature matrix: values, then upper bounds, then lower bounds;
# any remaining missing entry becomes 0
features = (
    base_features +
    error_features +
    [f'{f}_upper' for f in error_features] +
    [f'{f}_lower' for f in error_features]
)
frb_data_clean = frb_data[features].fillna(0)

# 4) Split first, then scale: the scaler is fitted on the training rows only so
#    that validation-set statistics do not leak into preprocessing.

# keep original indices
indices = frb_data_clean.index

# 5) Stratified 80/20 split, carrying labels, per-sample SNR and row indices along.
#    The split depends only on the row count, labels and seed, so it selects the
#    same rows as splitting a pre-scaled matrix would.
train_raw, val_raw, train_labels, val_labels, train_snr, val_snr, train_indices, val_indices = train_test_split(
    frb_data_clean, labels.values, snr_array, indices,
    test_size=0.2, random_state=42, stratify=labels.values
)
scaler = StandardScaler()
train_data = scaler.fit_transform(train_raw)
val_data = scaler.transform(val_raw)
# Whole feature matrix scaled with the train-fitted scaler (kept under the old name).
frb_data_scaled = scaler.transform(frb_data_clean)

# 6) To tensors & dataloaders; each batch yields (features, label, snr)
batch_size = 64
train_tensor = torch.tensor(train_data, dtype=torch.float32)
val_tensor   = torch.tensor(val_data,   dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
val_labels_tensor   = torch.tensor(val_labels,   dtype=torch.long)
train_snr_tensor = torch.tensor(train_snr, dtype=torch.float32)
val_snr_tensor   = torch.tensor(val_snr,   dtype=torch.float32)

train_dataset = TensorDataset(train_tensor, train_labels_tensor, train_snr_tensor)
val_dataset   = TensorDataset(val_tensor,   val_labels_tensor,   val_snr_tensor)
# BatchNorm1d raises on a one-sample batch in training mode, so drop the final
# training batch only when it would contain exactly one sample.
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                           drop_last=(len(train_dataset) % batch_size == 1))
val_loader    = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)

# 7) Model hyperparameters.
# NOTE(review): in the visible code hidden_dim / latent_dim are not used for the
# final model; step 11 takes its sizes from `best_params`.
input_dim = val_tensor.shape[1]    # number of feature columns after preprocessing
hidden_dim, latent_dim = 256, 10
stop_patience = 8                  # patience handed to early_stopping
num_epochs = 150                   # upper bound on training epochs

# define your VAE + classifier
class SupervisedVAE(nn.Module):
    """Variational autoencoder with a binary classification head.

    The encoder has three Linear -> BatchNorm1d -> activation -> Dropout stages of
    width ``hidden_dim``, followed by linear heads for the latent mean and
    log-variance. The decoder mirrors it and maps a latent sample back to
    ``input_dim``. The classifier reads the latent mean ``mu`` (not the sampled
    ``z``) and emits a single logit.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the encoder/decoder hidden layers (>= 4, so the
            classifier's ``hidden_dim // 4`` layer is non-empty).
        latent_dim: size of the latent space.
        dropout_rate: dropout probability after every hidden activation.
        activation: activation module used at every activation site. Defaults
            to a fresh ``nn.LeakyReLU(0.1)`` per model.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, dropout_rate=0.3, activation=None):
        super().__init__()
        if activation is None:
            # Created per instance instead of as a module-object default argument,
            # which would be shared by every model built with the default.
            activation = nn.LeakyReLU(0.1)
        # The same module object is reused at every activation site. That is fine
        # for stateless activations; a learnable one (e.g. PReLU) would share its
        # parameters across all layers.
        self.activation = activation
        # Encoder: input_dim -> hidden_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate)
        )
        self.fc_mu     = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent_dim -> input_dim (no output activation)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim)
        )
        # Classification head: latent mean -> one logit
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim//2),
            self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim//2, hidden_dim//4),
            self.activation, nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim//4, 1),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling happens in both train and eval mode, so reconstructions (and
        any loss built on them) are stochastic even under ``model.eval()``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent samples back to the input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(recon_x, mu, logvar, class_logit)``; the logit has shape [B, 1]."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        class_logit = self.classifier(mu)
        return recon_x, mu, logvar, class_logit

# 8) Per-sample SNR-weighted loss
def loss_function(recon_x, x, mu, logvar, class_logit,
                  labels, beta, gamma, class_weight, classification_multiplier,
                  snr_batch):
    """SNR-weighted supervised-VAE loss.

    Per sample: summed squared reconstruction error + ``beta`` * KL divergence
    to N(0, I) + ``gamma`` * ``classification_multiplier`` * BCE-with-logits
    (positive class weighted by ``class_weight``). The per-sample totals are
    multiplied by ``snr_batch`` and summed.

    Args:
        recon_x, x: reconstruction and input, shape [B, D].
        mu, logvar: latent posterior parameters, shape [B, latent].
        class_logit: classifier logits, shape [B, 1].
        labels: 0/1 targets, shape [B].
        beta, gamma: weights of the KL and classification terms.
        class_weight: ``pos_weight`` for the positive (label 1) class.
        classification_multiplier: extra scale on the classification term.
        snr_batch: per-sample weights, shape [B].

    Returns:
        ``(weighted_loss, recon_loss, kl_loss, class_loss)`` as scalar tensors.
        Only the first is SNR-weighted. The others are unweighted batch sums
        for logging, and ``class_loss`` already includes the multiplier.
    """
    # reconstruction: per element -> summed over features -> [B]
    recon_per = F.mse_loss(recon_x, x, reduction='none').sum(dim=1)
    # KL(q(z|x) || N(0, I)) per sample -> [B]
    kl_per = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # pos_weight is built on the logits' own device/dtype rather than a global device
    pos_w = class_logit.new_tensor([class_weight])
    class_elem = F.binary_cross_entropy_with_logits(
        class_logit, labels.unsqueeze(1).to(class_logit.dtype),
        reduction='none', pos_weight=pos_w
    ).squeeze(1) * classification_multiplier
    total_per = recon_per + beta * kl_per + gamma * class_elem
    # weight each sample by its SNR and reduce to a scalar
    weighted_loss = (snr_batch * total_per).sum()
    return weighted_loss, recon_per.sum(), kl_per.sum(), class_elem.sum()

# 9) Training & validation loops
def train_supervised(model, optimizer, scheduler, epoch, beta, gamma, class_weight,
                     classification_multiplier, loader=None):
    """Train ``model`` for one epoch and return averaged statistics.

    Args:
        model: network returning ``(recon, mu, logvar, class_logit)``.
        optimizer: optimizer stepped once per batch.
        scheduler, epoch: accepted for signature compatibility; unused here.
        beta, gamma, class_weight, classification_multiplier: forwarded to
            ``loss_function``.
        loader: DataLoader yielding ``(features, labels, snr)`` batches.
            Defaults to the module-level ``train_loader``; pass another loader
            (e.g. a cross-validation fold) to train on different data.

    Returns:
        ``(loss, recon, kl, class_loss, accuracy)``. The four losses are batch
        sums divided by ``len(loader.dataset)``; the first is SNR-weighted.
        ``accuracy`` is the fraction of samples whose sigmoid(logit) > 0.5
        prediction matches the label.
    """
    if loader is None:
        loader = train_loader
    model.train()
    total_loss = total_recon = total_kl = total_class = 0.0
    correct = total = 0
    for data, labels, snr in loader:
        data, labels, snr = data.to(device), labels.to(device), snr.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar, class_logit = model(data)
        loss, recon_l, kl_l, class_l = loss_function(
            recon_batch, data, mu, logvar, class_logit,
            labels, beta, gamma, class_weight, classification_multiplier,
            snr
        )
        loss.backward()
        optimizer.step()
        total_loss  += loss.item()
        total_recon += recon_l.item()
        total_kl    += kl_l.item()
        total_class += class_l.item()
        preds = (torch.sigmoid(class_logit) > 0.5).float().squeeze(1)
        correct += (preds == labels).sum().item()
        total   += labels.size(0)
    n_samples = len(loader.dataset)
    return (total_loss / n_samples,
            total_recon / n_samples,
            total_kl / n_samples,
            total_class / n_samples,
            correct / total)

def validate_supervised(model, scheduler, optimizer, epoch, beta, gamma, class_weight,
                        classification_multiplier, loader=None):
    """Evaluate ``model`` without gradient updates and return averaged statistics.

    Runs in eval mode under ``torch.no_grad()``. The model's forward still
    samples ``z``, so the reconstruction term (and the total loss) varies
    between calls.

    Args:
        model: network returning ``(recon, mu, logvar, class_logit)``.
        scheduler, optimizer, epoch: accepted for signature compatibility;
            unused here.
        beta, gamma, class_weight, classification_multiplier: forwarded to
            ``loss_function``.
        loader: DataLoader yielding ``(features, labels, snr)`` batches.
            Defaults to the module-level ``val_loader``.

    Returns:
        ``(loss, recon, kl, class_loss, accuracy)``. The four losses are batch
        sums divided by ``len(loader.dataset)``; the first is SNR-weighted.
        ``accuracy`` uses a 0.5 probability threshold.
    """
    if loader is None:
        loader = val_loader
    model.eval()
    val_loss = val_recon = val_kl = val_class = 0.0
    correct = total = 0
    with torch.no_grad():
        for data, labels, snr in loader:
            data, labels, snr = data.to(device), labels.to(device), snr.to(device)
            recon_batch, mu, logvar, class_logit = model(data)
            loss, recon_l, kl_l, class_l = loss_function(
                recon_batch, data, mu, logvar, class_logit,
                labels, beta, gamma, class_weight, classification_multiplier,
                snr
            )
            val_loss  += loss.item()
            val_recon += recon_l.item()
            val_kl    += kl_l.item()
            val_class += class_l.item()
            preds = (torch.sigmoid(class_logit) > 0.5).float().squeeze(1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)
    n_samples = len(loader.dataset)
    return (val_loss / n_samples,
            val_recon / n_samples,
            val_kl / n_samples,
            val_class / n_samples,
            correct / total)

# 10) Early stopping helper
def early_stopping(val_losses, patience):
    """Return True when the validation loss has not improved for ``patience`` epochs.

    Training should stop when none of the last ``patience`` losses is lower
    than the best loss recorded before them; ties count as no improvement.
    This tolerates noisy, non-monotonic losses, unlike a check for
    ``patience`` consecutive increases. At least ``patience + 1`` losses are
    required. With ``patience <= 0``, any history longer than ``patience``
    triggers a stop.

    Args:
        val_losses: validation losses, one per epoch, oldest first.
        patience: number of epochs without improvement to tolerate.

    Returns:
        bool: whether training should stop.
    """
    if len(val_losses) <= patience:
        return False
    if patience <= 0:
        return True
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) >= best_before

# 11) Hyperparams and training
best_params = {
    'hidden_dim': 1082,
    'latent_dim': 18,
    'beta': 1.149574612306723,
    'gamma': 1.9210647260496314,
    'dropout_rate': 0.13093239424733344,
    'lr': 0.0011823749066137313,
    'scheduler_patience': 7,
    'class_weight': 0.35488674730648145,
    'classification_multiplier': 7817.124805902009,
    'activation': nn.ReLU()
}

model = SupervisedVAE(
    input_dim, best_params['hidden_dim'], best_params['latent_dim'],
    best_params['dropout_rate'], best_params['activation']
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=best_params['lr'])
# Halve the learning rate when the validation loss stops decreasing.
scheduler = ReduceLROnPlateau(optimizer, mode='min',
                              factor=0.5,
                              patience=best_params['scheduler_patience'])

val_losses = []
# Keep a copy of the weights from the best validation epoch. Early stopping only
# fires some epochs after the best one, so the final weights are usually worse.
# NOTE(review): the same validation split is used for model selection and for
# the evaluation below, so those metrics are optimistic; consider a held-out test set.
best_val_loss = float('inf')
best_state = None
for epoch in range(1, num_epochs + 1):
    train_loss, _, _, _, train_acc = train_supervised(
        model, optimizer, scheduler, epoch,
        best_params['beta'], best_params['gamma'],
        best_params['class_weight'], best_params['classification_multiplier']
    )
    # argument order matches validate_supervised's (model, scheduler, optimizer, ...) signature
    val_loss, _, _, _, val_acc = validate_supervised(
        model, scheduler, optimizer, epoch,
        best_params['beta'], best_params['gamma'],
        best_params['class_weight'], best_params['classification_multiplier']
    )
    scheduler.step(val_loss)
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, "
          f"train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")
    val_losses.append(val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if early_stopping(val_losses, stop_patience):
        print(f"Early stopping triggered at epoch {epoch}")
        break

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from the best epoch (val_loss={best_val_loss:.4f})")

# 12) Final evaluation on the validation split
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    # `batch_labels` rather than `labels`, so the global label Series is not overwritten.
    for data, batch_labels, _ in val_loader:
        logits = model(data.to(device))[-1]
        # view(-1) turns [B, 1] into [B]. A plain .squeeze() would produce a
        # 0-d array for a final batch of one sample, which list.extend cannot iterate.
        preds = (torch.sigmoid(logits) > 0.5).long().view(-1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(batch_labels.numpy())

print(f"Validation Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
print("Classification Report:\n", classification_report(all_labels, all_preds,
                                                      target_names=["Non-Repeater","Repeater"]))
print("Confusion Matrix:\n", confusion_matrix(all_labels, all_preds))

Epoch 1: train_loss=33018.1137, val_loss=29742.3203, train_acc=0.6399, val_acc=0.5597
Epoch 2: train_loss=22320.3955, val_loss=42907.5988, train_acc=0.5724, val_acc=0.5932
Epoch 3: train_loss=18624.5551, val_loss=20186.4049, train_acc=0.6109, val_acc=0.5817
Epoch 4: train_loss=19623.9426, val_loss=21332.9525, train_acc=0.6319, val_acc=0.6024
Epoch 5: train_loss=19140.2292, val_loss=21948.5171, train_acc=0.6058, val_acc=0.6047
Epoch 6: train_loss=18674.4522, val_loss=20021.0766, train_acc=0.6923, val_acc=0.6884
Epoch 7: train_loss=17781.6923, val_loss=19549.8854, train_acc=0.6759, val_acc=0.6601
Epoch 8: train_loss=15081.0493, val_loss=14942.9997, train_acc=0.6311, val_acc=0.6411
Epoch 9: train_loss=15705.2692, val_loss=22289.2445, train_acc=0.6899, val_acc=0.6405
Epoch 10: train_loss=15976.8720, val_loss=17336.4876, train_acc=0.6996, val_acc=0.7467
Epoch 11: train_loss=14693.3480, val_loss=18692.4202, train_acc=0.6980, val_acc=0.7328
Epoch 12: train_loss=14187.6983, val_loss=16785.5578