# About

This notebook is the inference part of [@hidehisaarai1213](https://www.kaggle.com/hidehisaarai1213)'s [G2Net: Read from TFRecord & Train with PyTorch](https://www.kaggle.com/hidehisaarai1213/g2net-read-from-tfrecord-train-with-pytorch) notebook. It is also based on [@yasufuminakama](https://www.kaggle.com/yasufuminakama)'s [G2Net / efficientnet_b7 / baseline [training]](https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b7-baseline-training) notebook. 


**This code is, again, based on [@yasufuminakama](https://www.kaggle.com/yasufuminakama)'s original [G2Net / efficientnet_b7 / baseline [inference]](https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b7-baseline-inference) notebook.**

Please show your support for the original authors as well!

In [None]:
!pip install -q nnAudio
!pip install -q timm

# Packages

In [None]:
import os
import time
import math
import glob
import random
from pathlib import Path

import numpy as np
import pandas as pd
import scipy as sp
import tensorflow as tf  # for reading TFRecord Dataset
import tensorflow_datasets as tfds  # for making tf.data.Dataset to return numpy arrays
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import timm
from kaggle_datasets import KaggleDatasets
from nnAudio.Spectrogram import CQT1992v2
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from tqdm import tqdm

In [None]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CFG

In [None]:
class CFG:
    """Static settings for this notebook, stored as plain class attributes."""

    # --- run control ---
    debug = False
    train = True
    seed = 42
    print_freq = 50
    num_workers = 4

    # --- model and input representation ---
    model_name = "tf_efficientnet_b0_ns"
    # Keyword arguments forwarded verbatim to the CQT layer of the model.
    qtransform_params = {
        "sr": 2048,
        "fmin": 20,
        "fmax": 1024,
        "hop_length": 24,
        "bins_per_octave": 12,
    }
    target_size = 1
    target_col = "target"

    # --- optimisation ---
    scheduler = "CosineAnnealingLR"
    epochs = 3
    T_max = 3
    lr = 1e-4
    min_lr = 1e-7
    batch_size = 64
    weight_decay = 1e-3
    gradient_accumulation_steps = 1
    max_grad_norm = 1000

    # --- cross-validation ---
    n_fold = 5
    trn_fold = [0, 1, 2, 3, 4]


# A debug run is cut down to a single epoch.
if CFG.debug:
    CFG.epochs = 1

In [None]:
def seed_torch(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every seeding function; its string form is also
            written to the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once, using the seed stored in the config.
seed_torch(seed=CFG.seed)

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
class TestDataset(torch.utils.data.Dataset):
    """Dataset that loads one ``.npy`` file per row of a data frame.

    Args:
        df: Frame with a ``file_path`` column holding the path of each file.
        transform: Optional callable invoked as ``transform(image=array)``;
            the ``'image'`` entry of its result is returned.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        self.file_names = df['file_path'].values

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def apply_transform(self, waves):
        """Divide each row by its maximum along axis 1; return a float tensor."""
        row_max = np.max(waves, axis=1)
        scaled = waves / row_max[:, None]
        return torch.from_numpy(scaled).float()

    def __getitem__(self, idx):
        """Load, scale and (optionally) transform the sample at ``idx``."""
        raw = np.load(self.file_names[idx])
        # Round-trip through torch gives a squeezed float32 NumPy array.
        image = self.apply_transform(raw).squeeze().numpy()
        if not self.transform:
            return image
        return self.transform(image=image)['image']

# Model

In [None]:
class CustomModel(nn.Module):
    """CQT front-end followed by a timm backbone with a linear output head.

    Args:
        cfg: Config object providing ``qtransform_params`` (keyword arguments
            for ``CQT1992v2``), ``model_name`` (timm model name) and
            ``target_size`` (number of outputs of the replaced classifier).
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, cfg, pretrained=False):
        super().__init__()
        self.cfg = cfg
        # Read the CQT settings from the config that was passed in, not from
        # the global CFG, so a different cfg is honoured for the front-end too.
        self.wave_transform = CQT1992v2(**self.cfg.qtransform_params)
        self.model = timm.create_model(self.cfg.model_name, pretrained=pretrained, in_chans=3)
        # Swap the backbone's classifier for a head with `target_size` outputs.
        self.n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(self.n_features, self.cfg.target_size)

    def forward(self, x):
        """Transform the first three slices along dim 1 and run the backbone.

        Args:
            x: Tensor indexed as ``x[:, i]`` for ``i`` in 0..2 -- presumably
                (batch, 3, samples); TODO confirm against the data pipeline.

        Returns:
            The raw output of the backbone (no activation is applied here).
        """
        spectrograms = [self.wave_transform(x[:, channel]) for channel in range(3)]
        # Stack the per-slice transforms as the 3 input channels of the backbone.
        return self.model(torch.stack(spectrograms, dim=1))


In [None]:
def get_test_file_path(image_id):
    """Return the test-set ``.npy`` path for ``image_id``.

    The first three characters of the id name the three nested directories
    that contain the file.
    """
    template = "../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy"
    nested_dirs = (image_id[0], image_id[1], image_id[2])
    return template.format(*nested_dirs, image_id)

In [None]:
# One row per test id, extended with the path of the matching waveform file.
df_sub = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
df_sub['file_path'] = df_sub['id'].apply(get_test_file_path)

# Sequential, complete pass over the test set (no shuffling, no dropped batch).
test_dataset = TestDataset(df_sub)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=CFG.num_workers,
    pin_memory=True,
)


def _load_model_states(pattern):
    """Return the "model" entry of each checkpoint matching ``pattern``, in sorted path order."""
    model_states = []
    for checkpoint_path in sorted(glob.glob(pattern)):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_states.append(checkpoint["model"])
    return model_states


best_loss_models = _load_model_states("../input/g2net-read-from-tfrecord-train-with-pytorch/*best_loss.pth")
best_score_models = _load_model_states("../input/g2net-read-from-tfrecord-train-with-pytorch/*best_score.pth")

# Empty model shells; the loaded states are copied into them at inference time.
best_loss_model = CustomModel(CFG)
best_score_model = CustomModel(CFG)

# Inference

In [None]:
def inference(model, states, data_loader, device):
    """Average the sigmoid predictions of several checkpoints over a loader.

    For every batch, each state dict in ``states`` is loaded into ``model`` in
    turn, the batch is scored, and the per-checkpoint sigmoid outputs are
    averaged element-wise.

    Args:
        model: Module whose ``load_state_dict`` accepts every entry of
            ``states``. It is moved to ``device`` and left in eval mode.
        states: Iterable of state dicts; it is materialised once so that a
            one-shot iterable can still be reused for every batch.
        data_loader: Sized iterable yielding batches of input tensors.
        device: Device the model and each batch are moved to.

    Returns:
        NumPy array with the averaged probabilities of all batches,
        concatenated along the first axis.

    Raises:
        ValueError: If ``states`` is empty.
    """
    states = list(states)
    if not states:
        # Fail fast: with no checkpoints the mean below would be NaN and the
        # final concatenate would fail with a far less helpful message.
        raise ValueError("inference() needs at least one model state, got none")

    model.to(device)
    # Loading a state dict does not change the train/eval flag, so set it once.
    model.eval()

    probs = []
    with torch.no_grad():
        for images in tqdm(data_loader, total=len(data_loader)):
            images = images.to(device)
            checkpoint_preds = []
            for state in states:
                model.load_state_dict(state)
                checkpoint_preds.append(model(images).sigmoid().cpu().numpy())
            probs.append(np.mean(checkpoint_preds, axis=0))

    return np.concatenate(probs)


In [None]:
# Ensemble prediction using only the "best score" checkpoints.
preds = inference(best_score_model, best_score_models, test_loader, device)
# Alternative ensemble from the "best loss" checkpoints (currently disabled):
# preds_2 = inference(best_loss_model, best_loss_models, test_loader, device)

In [None]:
# Attach the predictions, remove the helper path column, then write the CSV.
df_sub["target"] = preds
df_sub.drop(columns=["file_path"], inplace=True)
df_sub.to_csv("submission.csv", index=False)

In [None]:
# Display the submission frame as the cell output.
df_sub