# **Imports & Environment Setup**

In [1]:
import os
import warnings
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import timm
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Install dependencies silently
!pip install nnAudio timm > /dev/null 2>&1



# **Utility Module**

In [None]:
%%writefile utils.py
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import scipy.signal
import scipy.interpolate
from scipy.signal.windows import tukey

try:
    from nnAudio.Spectrogram import CQT1992v2
except ImportError:
    pass

SOS_FILTER = scipy.signal.butter(4, [20, 500], btype="bandpass", output="sos", fs=2048)
NORMALIZATION_FACTOR = np.sqrt((500 - 20) / (2048 / 2))
WINDOW_TUKEY = tukey(4096, alpha=0.1)


class ProjectConfig:
    """Centralized configuration for paths."""
    ROOT_DIR = '/kaggle/working'
    DATA_DIR = '/kaggle/input/g2net-gravitational-wave-detection/train'
    CSV_PATH = '/kaggle/input/g2net-gravitational-wave-detection/training_labels.csv'

def seed_everything(seed=42):
    """Ensures reproducibility across runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True 
        torch.backends.cudnn.deterministic = False

# --- Signal Processing (High Performance) ---

def apply_bandpass(x):
    return scipy.signal.sosfiltfilt(SOS_FILTER, x) / NORMALIZATION_FACTOR

def apply_whitening(x, sr=2048):
    # 1. Windowing
    x = x * WINDOW_TUKEY
    
    # 2. Calculate PSD (Power Spectral Density)
    freqs, psd = scipy.signal.welch(x, fs=sr, nperseg=sr, window='hann')
    
    # 3. FFT (Fast Fourier Transform)
    x_f = np.fft.rfft(x)
    freqs_f = np.fft.rfftfreq(len(x), d=1/sr)
    
    valid_indices = (freqs_f >= freqs.min()) & (freqs_f <= freqs.max())
    x_f_whitened = np.zeros_like(x_f)
    
    # 4. Interpolation
    psd_values = np.interp(freqs_f[valid_indices], freqs, psd)
    
    # 5. Normalize & IFFT
    x_f_whitened[valid_indices] = x_f[valid_indices] / np.sqrt(psd_values + 1e-20)
    return np.fft.irfft(x_f_whitened, n=len(x))

def apply_timeshift(x, max_shift=0.2, sr=2048):
    shift_amt = int(random.random() * max_shift * sr)
    if random.random() > 0.5:
        shift_amt = -shift_amt
    return np.roll(x, shift_amt, axis=-1)

# --- Model Components ---

class WaveToImage(nn.Module):
    def __init__(self, sr=2048, fmin=20, fmax=1024, hop_length=64, device='cuda'):
        super().__init__()
        self.transform = CQT1992v2(sr=sr, fmin=fmin, fmax=fmax, hop_length=hop_length, 
                                   output_format="Magnitude", verbose=False)
    
    def forward(self, x):
        batch_size, channels, time_steps = x.shape
        x = x.view(batch_size * channels, time_steps)
        images = self.transform(x)
        _, h, w = images.shape
        return images.view(batch_size, channels, h, w)

class CustomDataset(Dataset):
    def __init__(self, dataframe, root_dir, stage='train', augment=False):
        self.df = dataframe
        self.root_dir = root_dir
        self.stage = stage
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_id = str(row['id'])
        # Organization of G2Net: a/b/c/abc0123.npy
        path_part = f"{file_id[0]}/{file_id[1]}/{file_id[2]}"
        file_path = os.path.join(self.root_dir, path_part, f"{file_id}.npy")
        
        try:
            waves = np.load(file_path).astype(np.float32)
            
            if self.stage == 'train' and self.augment:
                waves = apply_timeshift(waves)

            cleaned_waves = []
            for i in range(3):
                w = waves[i]
                w = apply_bandpass(w)
                w = apply_whitening(w)
                w = w / (np.std(w) + 1e-20) 
                
                cleaned_waves.append(w)
            
            waves = np.stack(cleaned_waves)
            waves = torch.tensor(waves, dtype=torch.float32)
            
        except FileNotFoundError:
            waves = torch.zeros((3, 4096), dtype=torch.float32)

        if self.stage != 'test':
            return waves, torch.tensor(row['target'], dtype=torch.float32)
        return waves

Overwriting utils.py


# **Configuration & Data Preparation**

In [None]:
import utils

class CFG:
    debug = False          # Set to False for submission, True for quick debugging (10% data)
    seed = 42
    model_name = 'tf_efficientnet_b4_ns' 
    epochs = 12           
    batch_size = 32   
    lr = 1e-3
    weight_decay = 1e-2
    folds = 5
    fold_idx = 0
    num_workers = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_everything(CFG.seed)
config = utils.ProjectConfig()

# Load and Split Data
df = pd.read_csv(config.CSV_PATH)

if CFG.debug:
    df = df.sample(frac=0.1, random_state=CFG.seed).reset_index(drop=True)

skf = StratifiedKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
for fold, (_, val_idx) in enumerate(skf.split(df, df['target'])):
    df.loc[val_idx, 'fold'] = fold

print(f"Data loaded. Shape: {df.shape}. Device: {CFG.device}")

Data loaded. Shape: (560000, 3). Device: cuda


# **Model Definition & Data Loaders**

In [4]:
class G2NetModel(nn.Module):
    def __init__(self, model_name, pretrained=True):
        super().__init__()
        self.wave_to_img = utils.WaveToImage(device=CFG.device)
        self.batch_norm = nn.BatchNorm2d(3)
        self.backbone = timm.create_model(model_name, pretrained=pretrained, in_chans=3, num_classes=1)

    def forward(self, x):
        x = self.wave_to_img(x)
        x = self.batch_norm(x)
        x = self.backbone(x)
        return x

# Dataset & Loader Initialization
train_df = df[df['fold'] != CFG.fold_idx].reset_index(drop=True)
valid_df = df[df['fold'] == CFG.fold_idx].reset_index(drop=True)

train_ds = utils.CustomDataset(train_df, config.DATA_DIR, stage='train', augment=True)
valid_ds = utils.CustomDataset(valid_df, config.DATA_DIR, stage='train', augment=False)

train_loader = DataLoader(train_ds, 
                          batch_size=CFG.batch_size, 
                          shuffle=True, 
                          num_workers=CFG.num_workers, 
                          pin_memory=True,
                          persistent_workers=True,
                          prefetch_factor=2)

valid_loader = DataLoader(valid_ds, 
                          batch_size=CFG.batch_size, 
                          shuffle=False, 
                          num_workers=CFG.num_workers, 
                          pin_memory=True,
                          persistent_workers=True,
                          prefetch_factor=2)

# **Training Engine**

In [None]:
def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def train_one_epoch(model, loader, optimizer, criterion, scaler, scheduler=None):
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Train", leave=False)
    
    for waves, labels in pbar:
        waves = waves.to(CFG.device)
        labels = labels.to(CFG.device).unsqueeze(1)
        
        optimizer.zero_grad()
        
        # --- 1. MIXUP ---
        mixed_waves, labels_a, labels_b, lam = mixup_data(waves, labels, alpha=0.2)
        
        with autocast(): 
            outputs = model(mixed_waves)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # --- 2. SCHEDULER STEP ---
        if scheduler is not None:
            scheduler.step() 
        
        running_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}")
        
    return running_loss / len(loader)

@torch.no_grad()
def validate(model, loader):
    model.eval()
    preds, targets = [], []
    
    for waves, labels in tqdm(loader, desc="Valid", leave=False):
        waves = waves.to(CFG.device)
        outputs = model(waves)
        probs = torch.sigmoid(outputs).squeeze()
        
        if torch.isnan(probs).any():
            probs = torch.nan_to_num(probs, nan=0.0)
            
        preds.extend(probs.cpu().numpy())
        targets.extend(labels.numpy())
        
    return roc_auc_score(targets, preds)

# --- Main Loop ---
model = G2NetModel(CFG.model_name).to(CFG.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=CFG.lr,               
    epochs=CFG.epochs,           
    steps_per_epoch=len(train_loader), 
    pct_start=0.3,               
    div_factor=25.0,            
    final_div_factor=1000.0    
)
criterion = nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler()
best_auc = 0.0

print(f"Starting training for {CFG.epochs} epochs...")

for epoch in range(CFG.epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, scaler, scheduler)
    val_auc = validate(model, valid_loader)
    
    print(f"Epoch {epoch+1}/{CFG.epochs} | Loss: {train_loss:.4f} | AUC: {val_auc:.5f}")
    
    if val_auc > best_auc:
        best_auc = val_auc
        save_name = f"best_model_fold_{CFG.fold_idx}.pth"
        torch.save(model.state_dict(), save_name) 
        print(f" -> Saved Best Model: {save_name} (AUC: {best_auc:.5f})")

print(f"Training complete. Best AUC: {best_auc:.5f}")

Starting training for 12 epochs...


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 1/12 | Loss: 0.6064 | AUC: 0.84145
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.84145)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 2/12 | Loss: 0.5179 | AUC: 0.84870
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.84870)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 3/12 | Loss: 0.5128 | AUC: 0.85282
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.85282)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 4/12 | Loss: 0.5057 | AUC: 0.85417
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.85417)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 5/12 | Loss: 0.5007 | AUC: 0.85618
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.85618)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 6/12 | Loss: 0.4949 | AUC: 0.85950
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.85950)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 7/12 | Loss: 0.4904 | AUC: 0.86043
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86043)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 8/12 | Loss: 0.4872 | AUC: 0.86105
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86105)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 9/12 | Loss: 0.4829 | AUC: 0.86310
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86310)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 10/12 | Loss: 0.4809 | AUC: 0.86347
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86347)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 11/12 | Loss: 0.4775 | AUC: 0.86397
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86397)


Train:   0%|          | 0/14000 [00:00<?, ?it/s]

Valid:   0%|          | 0/3500 [00:00<?, ?it/s]

Epoch 12/12 | Loss: 0.4751 | AUC: 0.86405
 -> Saved Best Model: best_model_fold_0.pth (AUC: 0.86405)
Training complete. Best AUC: 0.86405


# **Inference & Submission**

In [6]:
print("Generating submission...")

sample_sub = pd.read_csv('/kaggle/input/g2net-gravitational-wave-detection/sample_submission.csv')
test_ds = utils.CustomDataset(
    sample_sub, 
    config.DATA_DIR.replace('train', 'test'), 
    stage='test', 
    augment=False
)
test_loader = DataLoader(test_ds, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

weights_path = f"best_model_fold_{CFG.fold_idx}.pth"

if os.path.exists(weights_path):
    checkpoint = torch.load(weights_path, map_location=CFG.device)
    new_state_dict = {}
    for k, v in checkpoint.items():
        name = k.replace("_orig_mod.", "") 
        new_state_dict[name] = v
        
    model.load_state_dict(new_state_dict)
    print(f"Loaded best model weights from {weights_path}")
else:
    print(f"Warning: {weights_path} not found. Using current weights.")

model.eval()

test_preds = []
with torch.no_grad():
    for waves in tqdm(test_loader, desc="Predicting"):
        waves = waves.to(CFG.device)
        outputs = model(waves)
        probs = torch.sigmoid(outputs).squeeze()
        
        if probs.ndim == 0:
            test_preds.append(probs.item())
        else:
            test_preds.extend(probs.cpu().numpy())

sample_sub['target'] = test_preds
sample_sub.to_csv('submission.csv', index=False)
print("Submission saved to 'submission.csv'")

Generating submission...
Loaded best model weights from best_model_fold_0.pth


Predicting:   0%|          | 0/7063 [00:00<?, ?it/s]

Submission saved to 'submission.csv'
