In [None]:
!pip install -q   timm

In [None]:
from pathlib import Path

import numpy as np
import librosa.display as lbd
import pandas as pd

import torch
from  torch.utils.data import Dataset, DataLoader

from matplotlib import pyplot as plt

from tqdm.notebook import tqdm

import timm

In [None]:
# --- Inference configuration ---------------------------------------------
TEST_BATCH_SIZE = 768   # samples per DataLoader batch at test time
TEST_NUM_WORKERS = 2    # DataLoader worker processes

# Prefer the GPU; fall back to the CPU when CUDA is not available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Device:", DEVICE)

In [None]:
def get_file_path(file_id, base_dir="../input/g2net-test-mels-2/audio_images"):
    """Return the path of the precomputed ``.npy`` image for ``file_id``.

    Files are sharded into nested sub-directories named after the first three
    characters of the id, e.g. ``"abc123"`` -> ``<base_dir>/a/b/c/abc123.npy``.

    Args:
        file_id: sample id string (must have at least 3 characters).
        base_dir: root directory of the sharded ``.npy`` files; defaults to the
            test-image dataset used by this notebook.

    Returns:
        The file path as a ``str``.
    """
    return f"{base_dir}/{file_id[0]}/{file_id[1]}/{file_id[2]}/{file_id}.npy"

In [None]:
# Start from the sample submission (one row per test id) and attach the path of
# each id's precomputed image.
df = (
    pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
    .assign(impath=lambda frame: frame["id"].apply(get_file_path))
)

print(df.shape)
df

# Data

In [None]:
class G2NetDataset(Dataset):
    """Map-style dataset yielding one precomputed image per row of ``data``.

    Args:
        data: DataFrame with an ``"impath"`` column holding paths to ``.npy``
            files. Rows are addressed by position, so any index works.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load row ``idx`` (positional) as a float32 array with a leading channel axis.

        The stored array is scaled by 1/255 (presumably uint8 images -- confirm),
        stacked along its first axis with ``np.vstack`` (a ``(C, H, W)`` array
        becomes ``(C*H, W)``), and a singleton axis is prepended.
        """
        # DataLoader samplers pass positions 0..len-1; ``.loc`` would treat them
        # as labels and break on any frame without a default RangeIndex.
        path = self.data["impath"].iloc[idx]
        image = np.load(path).astype("float32") / 255.0
        return np.vstack(image)[None]

In [None]:
ds = G2NetDataset(data=df)
print(len(ds))

# Seeded generator so the sample inspected (and plotted below) is the same on every run.
rng = np.random.default_rng(0)
x = ds[rng.integers(len(ds))]
print(x.shape)

In [None]:
# Render the single channel of the sample loaded above as an image.
fig, ax = plt.subplots()
lbd.specshow(x[0], ax=ax)
ax.set_title("Random test sample (channel 0)")
plt.show()

# Inference

In [None]:
def load_net(checkpoint_path):
    """Build a single-channel tf_efficientnet_b0 and load weights from ``checkpoint_path``.

    The returned model lives on ``DEVICE`` and is in eval mode.
    """
    model = timm.create_model("tf_efficientnet_b0", pretrained=False, num_classes=1)
    # The dataset yields 1-channel images, so the stem takes 1 input channel;
    # the checkpoint's ``conv_stem`` weights must have this exact shape.
    model.conv_stem = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    model = model.to(DEVICE)
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.eval()

In [None]:
@torch.no_grad()
def predict(nets, test_data):
    """Average the sigmoid outputs of an ensemble of models over a data loader.

    Args:
        nets: non-empty sequence of models; each is expected to output logits
            of shape ``(B, 1)`` (squeezed on dim 1 below).
        test_data: iterable of input batches (tensors), e.g. a ``DataLoader``.

    Returns:
        1-D ``np.ndarray`` of averaged probabilities in iteration order
        (empty if ``test_data`` yields no batches).

    Raises:
        ValueError: if ``nets`` is empty.
    """
    if len(nets) == 0:
        raise ValueError("predict() needs at least one model")

    preds = []
    for xb in tqdm(test_data):
        # Copy the batch to the device once and reuse it for every model,
        # instead of one host->device transfer per model.
        batch = xb.to(DEVICE)
        probs = sum(torch.sigmoid(net(batch).squeeze(1)) for net in nets) / len(nets)
        preds.append(probs.cpu().numpy())

    if not preds:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(preds)

In [None]:
# One checkpoint per fold (presumably cross-validation folds); `predict`
# averages their outputs.
checkpoint_paths = [
    f"../input/g2net-kkiller-public-models/tf_efficientnet_b0_fold{fold}.pth"
    for fold in range(3)
]

nets = list(map(load_net, checkpoint_paths))

print("n_models:", len(checkpoint_paths))

In [None]:
test_data = G2NetDataset(data=df)
test_laoder = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, num_workers=TEST_NUM_WORKERS, shuffle=False)

len(test_data), len(test_laoder)

In [None]:
sub = df[["id", "target"]].copy()
sub["target"] = predict(nets, test_laoder)

sub.to_csv("submission.csv", index=False)

print(sub.shape)
sub