# Imports

In [None]:
import os
import warnings
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import timm
import scipy.signal
from scipy.signal.windows import tukey
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

# Turn off all warnings.
warnings.filterwarnings('ignore')

# Probe for nnAudio, which provides the CQT1992v2 layer; report the outcome.
try:
    from nnAudio.Spectrogram import CQT1992v2
except ImportError:
    print("Warning: nnAudio not installed. Please pip install nnAudio.")
else:
    print("Library nnAudio imported successfully!")

# Configuration

In [None]:
class CFG:
    """Global run configuration, read everywhere as ``CFG.<attribute>``."""

    # --- Run switches ---
    debug = False                  # True: work on a small random subsample of the data
    train = True                   # False: skip the training loop entirely
    seed = 42                      # seed for every RNG and for the fold split

    # --- Model / optimisation ---
    model_name = 'tf_efficientnet_b4_ns'   # timm architecture name
    epochs = 12
    batch_size = 32
    lr = 1e-3                      # peak learning rate
    weight_decay = 1e-2
    folds = 5                      # number of StratifiedKFold splits

    # --- Runtime ---
    num_workers = 2                # DataLoader worker processes
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Paths (Kaggle directory layout) ---
    ROOT_DIR = '/kaggle/working'
    DATA_DIR = '/kaggle/input/g2net-gravitational-wave-detection/train'
    TEST_DIR = '/kaggle/input/g2net-gravitational-wave-detection/test'
    CSV_PATH = '/kaggle/input/g2net-gravitational-wave-detection/training_labels.csv'
    SAMPLE_SUB = '/kaggle/input/g2net-gravitational-wave-detection/sample_submission.csv'

def seed_everything(seed=42, deterministic=False):
    """Seed Python, NumPy and PyTorch RNGs.

    Args:
        seed: value used for every RNG below.
        deterministic: when True, request deterministic cuDNN kernels and
            disable cuDNN autotuning so GPU runs repeat exactly. The default
            (False) keeps the speed-oriented ``cudnn.benchmark = True``, under
            which GPU results may differ slightly between runs even with
            identical seeds.
    """
    random.seed(seed)
    # NOTE: CPython reads PYTHONHASHSEED at interpreter start-up, so this only
    # affects child processes started afterwards, not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.benchmark = True

# Seed every RNG once, then report the compute device chosen in CFG.
seed_everything(seed=CFG.seed)
print("Device: {}".format(CFG.device))

# Signal Processing Utils

In [None]:
# Shared DSP constants. The sampling rate and band edges are named once so the
# filter design and its normalisation factor stay in sync.
_SAMPLE_RATE = 2048        # Hz
_BAND = (20, 500)          # bandpass corner frequencies in Hz (low, high)
_WINDOW_LEN = 4096         # window length in samples

# 4th-order Butterworth bandpass in second-order-sections form.
SOS_FILTER = scipy.signal.butter(4, list(_BAND), btype="bandpass", output="sos", fs=_SAMPLE_RATE)
# sqrt(passband width / Nyquist frequency).
NORMALIZATION_FACTOR = np.sqrt((_BAND[1] - _BAND[0]) / (_SAMPLE_RATE / 2))
# Tukey (tapered-cosine) window with a 10% taper.
WINDOW_TUKEY = tukey(_WINDOW_LEN, alpha=0.1)

def apply_bandpass(x):
    """Zero-phase bandpass-filter ``x`` and divide by NORMALIZATION_FACTOR.

    ``sosfiltfilt`` runs SOS_FILTER forward and then backward along the last
    axis, so the output has no phase delay and keeps the shape of ``x``.
    """
    filtered = scipy.signal.sosfiltfilt(SOS_FILTER, x)
    return filtered / NORMALIZATION_FACTOR

def apply_whitening(x, sr=2048):
    """Whiten a 1-D signal by dividing its spectrum by the estimated ASD.

    The signal is Tukey-windowed (10% taper). Its PSD is estimated with
    Welch's method using Hann segments of ``sr`` samples. Every rFFT bin is
    then divided by the square root of the PSD, interpolated to that bin's
    frequency.

    Args:
        x: 1-D array of samples (any length >= ``sr``).
        sr: sampling rate in Hz.

    Returns:
        The whitened signal, with the same length as ``x``.
    """
    n = len(x)
    # Build the window for the actual input length; for 4096-sample inputs
    # this equals tukey(4096, alpha=0.1).
    x = x * tukey(n, alpha=0.1)
    freqs, psd = scipy.signal.welch(x, fs=sr, nperseg=sr, window='hann')

    x_f = np.fft.rfft(x)
    freqs_f = np.fft.rfftfreq(n, d=1 / sr)

    valid_indices = (freqs_f >= freqs.min()) & (freqs_f <= freqs.max())
    x_f_whitened = np.zeros_like(x_f)

    # Regularise with a floor relative to the PSD's own scale. A fixed absolute
    # epsilon such as 1e-20 dominates whenever the PSD itself is far below it
    # (any signal with amplitude well under ~1e-10), which turns "whitening"
    # into a constant rescale. A relative floor makes the result independent
    # of the input's overall amplitude; `tiny` keeps all-zero input finite.
    floor = max(psd.max() * np.finfo(np.float64).eps, np.finfo(np.float64).tiny)
    psd_values = np.interp(freqs_f[valid_indices], freqs, psd)
    x_f_whitened[valid_indices] = x_f[valid_indices] / np.sqrt(psd_values + floor)

    return np.fft.irfft(x_f_whitened, n=n)

def apply_timeshift(x, max_shift=0.2, sr=2048):
    """Circularly shift ``x`` along its last axis by a random amount.

    The shift magnitude is drawn uniformly in ``[0, max_shift * sr)`` samples
    (truncated to int). A second draw picks the direction: above 0.5 means a
    negative shift. Both draws come from NumPy's global RNG, in that order.
    """
    magnitude = int(np.random.random() * max_shift * sr)
    direction = -1 if np.random.random() > 0.5 else 1
    return np.roll(x, direction * magnitude, axis=-1)

# Dataset & DataLoader

In [None]:
class CustomDataset(Dataset):
    """Map dataframe rows to preprocessed multi-channel waveform tensors.

    Each row's ``id`` locates ``<root_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.npy``.
    The loaded array is optionally time-shifted (``'train'`` stage with
    ``augment=True``). Every channel is then bandpassed, whitened and scaled
    to unit standard deviation.

    Args:
        dataframe: table with an ``id`` column, plus ``target`` unless
            ``stage == 'test'``. Rows are addressed positionally (``iloc``).
        root_dir: directory holding the nested ``.npy`` files.
        stage: ``'train'``, ``'valid'`` or ``'test'``. Only ``'test'`` omits
            the target; only ``'train'`` applies augmentation.
        augment: enable random time shifts during the train stage.
    """

    # Shape (channels, samples) of the all-zero sample returned for a missing file.
    _FALLBACK_SHAPE = (3, 4096)

    def __init__(self, dataframe, root_dir, stage='train', augment=False):
        self.df = dataframe
        self.root_dir = root_dir
        self.stage = stage
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def _file_path(self, file_id):
        """Return the ``.npy`` path for ``file_id``, nested by its first three characters."""
        path_part = f"{file_id[0]}/{file_id[1]}/{file_id[2]}"
        return os.path.join(self.root_dir, path_part, f"{file_id}.npy")

    @staticmethod
    def _preprocess(waves):
        """Bandpass, whiten and unit-std-normalise every channel (axis 0); return float32 tensor."""
        # NOTE(review): whitening runs after the bandpass, so an effective
        # whitening step re-flattens the spectrum over the whole 0..sr/2 range.
        # The more common order is whiten -> bandpass; confirm this is intended.
        cleaned_waves = []
        for channel in waves:
            w = apply_bandpass(channel)       # Step 1: Bandpass
            w = apply_whitening(w)            # Step 2: Whitening
            w = w / (np.std(w) + 1e-20)       # Step 3: Normalization
            cleaned_waves.append(w)
        return torch.tensor(np.stack(cleaned_waves), dtype=torch.float32)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_path = self._file_path(str(row['id']))

        # Only the file read can raise FileNotFoundError; keep the try minimal.
        try:
            raw = np.load(file_path).astype(np.float32)
        except FileNotFoundError:
            # Missing files are tolerated: an all-zero sample is returned
            # instead of raising.
            # NOTE(review): in train/valid this pairs zeros with the real
            # label; consider logging the missing id.
            waves = torch.zeros(self._FALLBACK_SHAPE, dtype=torch.float32)
        else:
            if self.stage == 'train' and self.augment:
                raw = apply_timeshift(raw)
            waves = self._preprocess(raw)

        if self.stage != 'test':
            return waves, torch.tensor(row['target'], dtype=torch.float32)
        return waves

# EfficientNet-B4

In [None]:
class WaveToImage(nn.Module):
    """Convert batched multi-channel waveforms into per-channel CQT magnitude images.

    Args:
        sr: sampling rate in Hz.
        fmin, fmax: frequency range of the CQT in Hz.
        hop_length: samples between successive CQT frames.
        (All four are forwarded to nnAudio's ``CQT1992v2``.)
    """

    def __init__(self, sr=2048, fmin=20, fmax=1024, hop_length=64):
        super().__init__()
        self.transform = CQT1992v2(sr=sr, fmin=fmin, fmax=fmax, hop_length=hop_length,
                                   output_format="Magnitude", verbose=False)

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, time) to (batch, channels, h, w)."""
        batch_size, channels, time_steps = x.shape
        # Fold channels into the batch so each channel gets its own CQT.
        # reshape (not view) also accepts non-contiguous inputs such as sliced
        # or permuted tensors; for contiguous ones it returns the same view.
        images = self.transform(x.reshape(batch_size * channels, time_steps))
        _, h, w = images.shape
        return images.reshape(batch_size, channels, h, w)

class G2NetModel(nn.Module):
    """Waveform classifier: CQT image -> BatchNorm2d -> timm backbone -> one logit.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_name, pretrained=True):
        super().__init__()
        # The attribute names below become the state_dict keys that are saved
        # and reloaded as checkpoints, so they must not change.
        self.wave_to_img = WaveToImage()
        self.batch_norm = nn.BatchNorm2d(3)  # per-channel normalisation of the 3 CQT images
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            in_chans=3,
            num_classes=1,
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for the waveform batch ``x``."""
        return self.backbone(self.batch_norm(self.wave_to_img(x)))

# Mixup, Train Loop

In [None]:
# Mixup Augmentation
def mixup_data(x, y, alpha=0.2):
    """Blend each sample in a batch with a randomly permuted partner.

    ``lam`` is drawn from Beta(alpha, alpha) when ``alpha > 0``, otherwise it
    is 1 (no mixing).

    Returns:
        (mixed_x, y, y_permuted, lam). The matching loss is
        ``lam * loss(y) + (1 - lam) * loss(y_permuted)``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm, :]
    return blended, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``criterion`` against both target sets, weighted by ``lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# Training Step
def train_fn(model, loader, optimizer, criterion, scaler, scheduler):
    """Run one mixed-precision training epoch with mixup.

    Args:
        model: network returning one logit per sample.
        loader: yields ``(waves, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, targets)``.
        scaler: AMP gradient scaler.
        scheduler: stepped once per batch when truthy (may be None).

    Returns:
        Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    total_loss = 0.0
    for batch_waves, batch_labels in tqdm(loader, desc="Train", leave=False):
        inputs = batch_waves.to(CFG.device)
        # Add a trailing dim so targets line up with (batch, 1) logits.
        targets = batch_labels.to(CFG.device).unsqueeze(1)
        optimizer.zero_grad()

        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=0.2)

        with autocast():
            logits = model(inputs)
            loss = mixup_criterion(criterion, logits, targets_a, targets_b, lam)

        # AMP sequence: backprop the scaled loss, step through the scaler, update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if scheduler:
            scheduler.step()

        total_loss += loss.item()
    return total_loss / len(loader)

# Validation Step
@torch.no_grad()
def valid_fn(model, loader):
    """Evaluate ``model`` on ``loader`` and return the ROC AUC of sigmoid scores.

    Args:
        model: network returning one logit per sample.
        loader: yields ``(waves, labels)`` batches.

    Returns:
        ``roc_auc_score`` of the collected labels vs. predicted probabilities.
    """
    model.eval()
    preds, targets = [], []
    for waves, labels in tqdm(loader, desc="Valid", leave=False):
        waves = waves.to(CFG.device)
        outputs = model(waves)
        # reshape(-1) instead of squeeze(): a final batch of size 1 would be
        # squeezed to a 0-d array, which list.extend cannot iterate.
        probs = torch.sigmoid(outputs).reshape(-1)
        preds.extend(probs.cpu().numpy())
        targets.extend(labels.numpy())
    return roc_auc_score(targets, preds)

# Main

In [None]:
if __name__ == '__main__':
    df = pd.read_csv(CFG.CSV_PATH)
    if CFG.debug:
        # Explicit random_state: the debug subset does not depend on earlier RNG use.
        df = df.sample(frac=0.05, random_state=CFG.seed).reset_index(drop=True)
        print("DEBUG MODE: Using 5% data")

    # Stratified K-Fold: tag every row with the fold in which it is validated.
    df['fold'] = -1
    skf = StratifiedKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
    for fold, (_, val_idx) in enumerate(skf.split(df, df['target'])):
        # skf yields positions; translate them to index labels for .loc.
        df.loc[df.index[val_idx], 'fold'] = fold

    # Training Loop
    if CFG.train:
        print(f"Start Training: {CFG.folds} Folds")

        for fold in range(CFG.folds):
            print(f"\n>>> TRAINING FOLD {fold} <<<")

            # Train/Valid split by boolean mask. reset_index gives each subset
            # a clean 0..n-1 index for the dataset's positional access.
            train_df = df[df['fold'] != fold].reset_index(drop=True)
            valid_df = df[df['fold'] == fold].reset_index(drop=True)

            train_ds = CustomDataset(train_df, CFG.DATA_DIR, 'train', augment=True)
            valid_ds = CustomDataset(valid_df, CFG.DATA_DIR, 'valid', augment=False)

            train_loader = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True,
                                      num_workers=CFG.num_workers, pin_memory=True)
            valid_loader = DataLoader(valid_ds, batch_size=CFG.batch_size, shuffle=False,
                                      num_workers=CFG.num_workers, pin_memory=True)

            # Model & Optimizer
            model = G2NetModel(CFG.model_name).to(CFG.device)
            optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
            # OneCycleLR is sized for one scheduler.step() per batch (done in train_fn).
            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.lr,
                                                            epochs=CFG.epochs,
                                                            steps_per_epoch=len(train_loader),
                                                            pct_start=0.3)
            criterion = nn.BCEWithLogitsLoss()
            scaler = GradScaler()

            # Loop Epochs, checkpointing whenever validation AUC improves.
            best_score = 0.0
            for epoch in range(CFG.epochs):
                train_loss = train_fn(model, train_loader, optimizer, criterion, scaler, scheduler)
                val_score = valid_fn(model, valid_loader)
                print(f"Epoch {epoch+1} | Loss: {train_loss:.4f} | AUC: {val_score:.5f}")

                if val_score > best_score:
                    best_score = val_score
                    torch.save(model.state_dict(), f"best_model_fold_{fold}.pth")
                    print(f"Saved Best Model Fold {fold}: {best_score:.5f}")

# Inference & Ensemble

In [None]:
print("\nStarting Ensemble Inference...")

sample_sub = pd.read_csv(CFG.SAMPLE_SUB)
test_ds = CustomDataset(sample_sub, CFG.TEST_DIR, stage='test', augment=False)
test_loader = DataLoader(test_ds, batch_size=CFG.batch_size * 2, shuffle=False, num_workers=CFG.num_workers)

# Running sum of per-fold probabilities. It is averaged over the folds that
# were actually loaded, so skipped folds do not shrink the predictions.
pred_sum = np.zeros(len(sample_sub))
used_folds = 0

for fold in range(CFG.folds):
    weight_path = f"best_model_fold_{fold}.pth"

    if not os.path.exists(weight_path):
        print(f"Warning: Weights for Fold {fold} not found. Skipping...")
        continue

    print(f"Inference using Fold {fold}...")

    # Load model
    model = G2NetModel(CFG.model_name, pretrained=False).to(CFG.device)
    model.load_state_dict(torch.load(weight_path, map_location=CFG.device))
    model.eval()

    # Prediction
    fold_preds = []
    with torch.no_grad():
        for waves in tqdm(test_loader, desc=f"Pred Fold {fold}"):
            waves = waves.to(CFG.device)
            outputs = model(waves)
            # reshape(-1) instead of squeeze(): a size-1 final batch would
            # otherwise become a 0-d array.
            fold_preds.append(torch.sigmoid(outputs).reshape(-1).cpu().numpy())

    pred_sum += np.concatenate(fold_preds)
    used_folds += 1

if used_folds == 0:
    raise FileNotFoundError(
        "No fold weights (best_model_fold_*.pth) found; refusing to write an all-zero submission."
    )

final_preds = pred_sum / used_folds
sample_sub['target'] = final_preds
sample_sub.to_csv('submission.csv', index=False)
print("Ensemble Submission Saved Successfully!")
print(sample_sub.head())