In [None]:
!pip install tez
!pip install timm
!pip install nnaudio

In [None]:
import os
import albumentations
import tez
import torch
import random
import timm

import pandas as pd
import torch.nn as nn
import numpy as np

from nnAudio.Spectrogram import CQT1992v2
from scipy import signal
from sklearn import metrics
from tez.callbacks import EarlyStopping
from tqdm import tqdm

In [None]:
def seed_everything(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch's
    CPU and CUDA generators. It also configures cuDNN to use deterministic
    kernels so repeated runs produce the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly different)
    # algorithms per run, which defeats `deterministic`; keep it off.
    torch.backends.cudnn.benchmark = False

In [None]:
class G2NetDataset:
    """Map-style dataset that turns raw G2Net ``.npy`` samples into CQT images.

    For each item:
    1. The array at ``base_path`` is loaded.
    2. Every row of it goes through a constant-Q transform.
    3. The stacked result is optionally passed through an albumentations-style
       augmentation.

    Args:
        ids: Indexable sequence of sample id strings.
        targets: Indexable sequence of labels aligned with ``ids``.
        base_path: Format string filled with ``id[0]``, ``id[1]``, ``id[2]``
            and the full ``id`` (in that order) to locate the ``.npy`` file.
        augmentation: Callable accepting ``image=`` (channels-last array) and
            returning a dict with an ``"image"`` key, or ``None`` to skip.
    """

    def __init__(self, ids, targets, base_path, augmentation):
        self.ids = ids
        self.targets = targets
        self.augmentation = augmentation
        self.base_path = base_path
        # Settings forwarded verbatim to nnAudio's CQT1992v2.
        self.qtransform_params = {
            "sr": 2048,
            "fmin": 20,
            "fmax": 1024,
            "hop_length": 32,
            "bins_per_octave": 8,
        }
        self.wave_transform = CQT1992v2(**self.qtransform_params)

    def __len__(self):
        return len(self.ids)

    def apply_qtransform(self, waves, transform):
        """Apply ``transform`` to each row of ``waves`` and stack the results.

        Args:
            waves: Iterable of 1-D NumPy arrays (one per detector channel).
            transform: Callable mapping a 1-D float tensor to a tensor with a
                leading singleton batch axis (assumed for CQT1992v2 output,
                TODO confirm).

        Returns:
            NumPy array shaped ``(n_waves, *transform_output_shape[1:])``.
        """
        return_data = []
        for wave in waves:
            wave = torch.from_numpy(wave).float()
            return_data.append(transform(wave))
        # Stacking on dim=1 puts the waves after the shared batch axis:
        # (1, n_waves, ...).
        return_data = torch.stack(return_data, dim=1)
        # Drop only that leading axis. A bare squeeze() would also collapse the
        # wave axis when n_waves == 1 and break the channels-first layout.
        return_data = return_data.squeeze(0).numpy()
        return return_data

    def __getitem__(self, idx):
        sample = self.ids[idx]
        path = self.base_path.format(sample[0], sample[1], sample[2], sample)

        data = np.load(path)
        data = self.apply_qtransform(data, self.wave_transform)
        # Channels-last for the augmentation pipeline.
        data = np.transpose(data, (1, 2, 0)).astype(np.float32)

        targets = self.targets[idx]

        if self.augmentation is not None:
            augmented = self.augmentation(image=data)
            data = augmented["image"]

        # Back to channels-first for the model.
        data = np.transpose(data, (2, 0, 1)).astype(np.float32)
        return {
            "x": torch.tensor(data, dtype=torch.float32),
            "targets": torch.tensor(targets, dtype=torch.float),
        }

In [None]:
class G2NetModel(tez.Model):
    """Binary classifier: timm ``tf_efficientnet_b7_ns`` backbone with a one-logit head.

    Args:
        learning_rate: Learning rate passed to Adam in ``fetch_optimizer``.
    """

    def __init__(self, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        # NOTE(review): ``pretrained`` is not passed, so timm builds randomly
        # initialised weights — confirm that is intended for an "_ns" model.
        self.model = timm.create_model("tf_efficientnet_b7_ns", in_chans=3)
        # Replace the head with a single logit. Read the input width from the
        # existing classifier instead of hard-coding B7's 2560 so other
        # backbones also work.
        self.model.classifier = nn.Linear(self.model.classifier.in_features, 1)
        # Presumably read by tez to step the scheduler once per epoch.
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Return ``{"auc": ...}`` for one batch.

        ROC-AUC is rank-based, so the raw logits can be scored without a
        sigmoid. When ``roc_auc_score`` raises ``ValueError`` (e.g. a batch
        containing a single class), 0 is reported instead.
        """
        outputs = outputs.cpu().detach().numpy()
        targets = targets.cpu().detach().numpy()
        try:
            auc = metrics.roc_auc_score(targets, outputs)
        except ValueError:
            auc = 0
        return {"auc": auc}

    def fetch_scheduler(self):
        """Cosine annealing with warm restarts every 10 scheduler steps, floor 1e-6."""
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

    def fetch_optimizer(self):
        """Adam over all parameters at ``self.learning_rate``."""
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return opt

    def forward(self, x, targets=None):
        """Run the backbone and, when targets are given, compute loss and metrics.

        Returns:
            ``(logits, loss, metrics_dict)``. ``loss`` is 0 and the dict is
            empty when ``targets`` is None.
        """
        outputs = self.model(x)

        if targets is not None:
            loss = nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
            # Named so it does not shadow the sklearn ``metrics`` module.
            metric_values = self.monitor_metrics(outputs, targets)
            return outputs, loss, metric_values
        return outputs, 0, {}

In [None]:
# change fold here
class args:
    """Hyper-parameters for a single-fold training run.

    ``fold`` selects which ``kfold`` value is held out for validation.
    """

    fold = 0  # held-out cross-validation fold

    # optimisation / training loop
    epochs = 10
    batch_size = 32
    learning_rate = 0.00001
    accumulation_steps = 1

In [None]:
seed_everything(42)

# Per-channel statistics for Normalize; with max_pixel_value=1 it computes
# (x - mean) / std per channel.
means = (6.90108482e-26, 5.11772679e-26, -1.38312479e-26)
stds = (7.42028294e-21, 7.41993950e-21, 1.83832928e-21)


def _normalize_aug():
    """Build the normalization-only albumentations pipeline used for train and valid."""
    return albumentations.Compose(
        [
            albumentations.Normalize(
                mean=means,
                std=stds,
                max_pixel_value=1,
                p=1.0,
            ),
        ],
        p=1.0,
    )


# No random augmentation is applied yet, so both splits get the same pipeline.
train_aug = _normalize_aug()
valid_aug = _normalize_aug()

df = pd.read_csv("../input/g2netfolds/train_folds.csv")
df_train = df[df.kfold != args.fold].reset_index(drop=True)
df_valid = df[df.kfold == args.fold].reset_index(drop=True)
# Fail fast on a fold id missing from the CSV; otherwise every row lands in
# df_train and validation runs on an empty set.
assert len(df_valid) > 0, f"fold {args.fold} has no rows in train_folds.csv"

In [None]:
# Raw-sample path template: {id[0]}/{id[1]}/{id[2]}/{id}.npy under train/.
# Validation rows are drawn from the same train/ directory.
train_npy_template = "../input/g2net-gravitational-wave-detection/train/{}/{}/{}/{}.npy"

train_dataset = G2NetDataset(
    df_train["id"].to_numpy(),
    df_train["target"].to_numpy(),
    train_npy_template,
    train_aug,
)
valid_dataset = G2NetDataset(
    df_valid["id"].to_numpy(),
    df_valid["target"].to_numpy(),
    train_npy_template,
    valid_aug,
)

model = G2NetModel(learning_rate=args.learning_rate)

# Track valid_auc (higher is better) and save only the weights to the
# per-fold checkpoint file.
es = EarlyStopping(
    model_path=f"model_f{args.fold}.bin",
    monitor="valid_auc",
    mode="max",
    patience=3,
    save_weights_only=True,
)

fit_options = dict(
    valid_dataset=valid_dataset,
    train_bs=args.batch_size,
    valid_bs=args.batch_size * 4,
    device="cuda",
    epochs=args.epochs,
    callbacks=[es],
    fp16=True,
    accumulation_steps=args.accumulation_steps,
)
model.fit(train_dataset, **fit_options)

In [None]:
# generate test and valid predictions
model.load(f"model_f{args.fold}.bin", device="cuda", weights_only=True)


def _flatten_predictions(batches):
    """Concatenate per-batch prediction arrays into a single flat Python list.

    Each element of ``batches`` (as yielded by ``model.predict``) is raveled,
    so any per-batch shape is accepted.
    """
    flat = []
    for preds in tqdm(batches):
        flat.extend(preds.ravel().tolist())
    return flat


final_valid_predictions = _flatten_predictions(
    model.predict(valid_dataset, batch_size=args.batch_size, n_jobs=-1)
)

# Build a new frame rather than assigning into a column slice of df_valid,
# which triggers SettingWithCopyWarning. "target" now holds the model's raw
# outputs (presumably logits, since training uses BCEWithLogitsLoss).
df_valid = df_valid[["id"]].assign(target=final_valid_predictions)
df_valid.to_csv(f"valid_predictions_f{args.fold}.csv", index=False)

df_test = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
# Test inference reuses the normalization-only `valid_aug` built alongside
# the fold split; targets are dummy zeros.
test_dataset = G2NetDataset(
    ids=df_test["id"].values,
    targets=np.zeros(len(df_test)),
    augmentation=valid_aug,
    base_path="../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy",
)

final_test_predictions = _flatten_predictions(
    model.predict(test_dataset, batch_size=args.batch_size, n_jobs=-1)
)

df_test["target"] = final_test_predictions
df_test.to_csv(f"test_predictions_f{args.fold}.csv", index=False)

# 