## About this notebook

This notebook is the inference notebook for [G2Net: TF On-the-fly CQT TPU Training](https://www.kaggle.com/hidehisaarai1213/g2net-tf-on-the-fly-cqt-tpu-training).

On-the-fly CQT computation achieves better results than [Welf's Notebook](https://www.kaggle.com/miklgr500/g2net-efficientnetb1-tpu-evaluate) given the same image size and EfficientNet size, which means that if you scale up the model or the image size, you may well get the best single model among the publicly shared ones.
It also lets you create more variations of the input, which gives you a great advantage.

### Updates

* V3: Use the weights of V2 of the Training Notebook
    * EfficientNetB0 -> EfficientNetB7

Reference: https://www.kaggle.com/hidehisaarai1213/g2net-read-from-tfrecord-train-with-pytorch


## Install Dependencies

In [None]:
!pip install efficientnet tensorflow_addons > /dev/null
!pip install -q nnAudio
!pip install timm

In [None]:
import os
import math
import random
import re
import warnings
from pathlib import Path
from typing import Optional, Tuple

import efficientnet.tfkeras as efn
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from scipy.signal import get_window
from matplotlib import pyplot as plt
import timm

In [None]:
tf.__version__

## Config

In [None]:
# Module-level settings.
# NOTE(review): the PyTorch inference path in this notebook reads its batch size
# and model name from CFG; these constants look like leftovers from the
# TensorFlow version — confirm before relying on them.
IMAGE_SIZE = 256 #for b0
BATCH_SIZE = 64
EFFICIENTNET_SIZE = 5
WEIGHTS = "imagenet"
class CFG:
    """Static settings for the PyTorch inference run."""

    # General run options.
    debug = False
    num_workers = 4
    seed = 42

    # Checkpoint lookup: weights are read from
    # f"{model_dir}{model_name}_fold{fold}_best_score.pth" for each fold in trn_fold.
    model_name = 'tf_efficientnet_b0_ns'
    model_dir = '../input/'
    n_fold = 5
    trn_fold = [1]  # [0, 1, 2, 3, 4]

    # Batching and target description.
    batch_size = 512
    target_size = 1
    target_col = 'target'

    # Keyword arguments forwarded to nnAudio's CQT1992v2.
    qtransform_params = {
        "sr": 2048,
        "fmin": 20,
        "fmax": 1024,
        "hop_length": 64,
        "bins_per_octave": 24,
    }

## Utilities

In [None]:
def set_seed(seed=42):
    """Seed the Python, NumPy and TensorFlow random number generators.

    The seed is also written to the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed shared by every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)


set_seed(1213)

In [None]:
def auto_select_accelerator():
    """Return a distribution strategy, preferring a TPU when one is reachable.

    Returns:
        A ``(strategy, tpu_found)`` tuple. ``strategy`` is a TPU strategy when
        TPU setup succeeds, otherwise the current default strategy;
        ``tpu_found`` is True only in the TPU case.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
        tpu_found = True
    except ValueError:
        # A ValueError during TPU setup is treated as "no TPU": fall back to
        # whatever strategy TensorFlow currently has in scope.
        strategy = tf.distribute.get_strategy()
        tpu_found = False

    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    return strategy, tpu_found

In [None]:
# Select the distribution strategy once; `tpu_detected` records whether TPU setup succeeded.
strategy, tpu_detected = auto_select_accelerator()
# Sentinel that lets tf.data pick parallelism / prefetch sizes on its own.
AUTO = tf.data.experimental.AUTOTUNE
# Number of replicas the chosen strategy keeps in sync.
REPLICAS = strategy.num_replicas_in_sync

## Data Loading

In [None]:
# Look up the GCS path of each test TFRecord dataset
# ("g2net-waveform-tfrecords-test-0-4" and "...-5-9") and echo it.
gcs_paths = []
for i, j in [(0, 4), (5, 9)]:
    GCS_path = KaggleDatasets().get_gcs_path(f"g2net-waveform-tfrecords-test-{i}-{j}")
    gcs_paths.append(GCS_path)
    print(GCS_path)

In [None]:
# Gather every "test*.tfrecords" file from each GCS path; files are sorted by
# name within a path, and paths are kept in the order they were resolved.
all_files = []
for path in gcs_paths:
    all_files.extend(np.sort(np.array(tf.io.gfile.glob(path + "/test*.tfrecords"))))

print("test_files: ", len(all_files))

## Dataset Preparation

In [None]:
IMAGE_SIZE

In [None]:
def prepare_wave(wave):
    """Decode a serialized waveform and scale each channel by its maximum.

    Args:
        wave: Scalar string tensor holding 3 * 4096 raw float64 values.

    Returns:
        A float32 tensor of shape (3, 4096) in which every row has been
        divided by that row's largest value.
    """
    channels = tf.reshape(tf.io.decode_raw(wave, tf.float64), (3, 4096))
    # Per-row maximum kept as a (3, 1) column so it broadcasts over the samples.
    row_max = tf.math.reduce_max(channels, axis=1, keepdims=True)
    return tf.cast(channels / row_max, tf.float32)


def read_labeled_tfrecord(example):
    """Parse one serialized labeled example.

    Args:
        example: Scalar string tensor containing a serialized ``tf.Example``
            with ``wave``, ``wave_id`` and ``target`` features.

    Returns:
        A ``(wave, target, wave_id)`` tuple: the waveform as returned by
        ``prepare_wave``, the target cast to float32 and reshaped to ``[1]``,
        and the ``wave_id`` string tensor.
    """
    tfrec_format = {
        "wave": tf.io.FixedLenFeature([], tf.string),
        "wave_id": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(example, tfrec_format)
    wave = prepare_wave(example["wave"])
    target = tf.reshape(tf.cast(example["target"], tf.float32), [1])
    return wave, target, example["wave_id"]


def read_unlabeled_tfrecord(example, return_image_id):
    """Parse one serialized unlabeled example.

    Args:
        example: Scalar string tensor containing a serialized ``tf.Example``
            with ``wave`` and ``wave_id`` features.
        return_image_id: When truthy the second returned element is the
            ``wave_id`` string tensor; otherwise it is the constant ``0``.

    Returns:
        A ``(wave, wave_id_or_zero)`` tuple, where ``wave`` is the waveform as
        returned by ``prepare_wave``.
    """
    tfrec_format = {
        "wave": tf.io.FixedLenFeature([], tf.string),
        "wave_id": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrec_format)
    wave_id = example["wave_id"] if return_image_id else 0
    return prepare_wave(example["wave"]), wave_id


def count_data_items(fileids):
    """Return the example count for ``fileids``, assuming 28000 records per file."""
    records_per_file = 28000
    return records_per_file * len(fileids)


def count_data_items_test(fileids):
    """Return the example count for ``fileids``, assuming 22600 records per file."""
    records_per_file = 22600
    return records_per_file * len(fileids)




def get_dataset(files, batch_size=16, repeat=False, shuffle=False, aug=True, labeled=True, return_image_ids=True):
    """Build a batched, prefetched TFRecord pipeline and expose it as NumPy batches.

    Args:
        files: TFRecord file paths (GZIP-compressed) to read.
        batch_size: Number of examples per batch.
        repeat: Repeat the dataset indefinitely when True.
        shuffle: Shuffle with a 2048-element buffer and allow
            non-deterministic ordering when True.
        aug: Currently unused; the augmentation step is commented out.
        labeled: Parse records with ``read_labeled_tfrecord`` when True,
            otherwise with ``read_unlabeled_tfrecord``.
        return_image_ids: Forwarded to ``read_unlabeled_tfrecord``; only
            relevant when ``labeled`` is False.

    Returns:
        The result of ``tfds.as_numpy`` applied to the dataset, i.e. an
        iterable that yields batches as NumPy arrays.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having already imported `tfds`.
    import tensorflow_datasets as tfds

    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO, compression_type="GZIP")
    #ds = ds.cache()

    if repeat:
        ds = ds.repeat()

    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)

    if labeled:
        ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda example: read_unlabeled_tfrecord(example, return_image_ids), num_parallel_calls=AUTO)

    ds = ds.batch(batch_size)
    #if aug:
    #    ds = ds.map(lambda x, y: aug_f(x, y, batch_size * REPLICAS), num_parallel_calls=AUTO)
    ds = ds.prefetch(AUTO)
    return tfds.as_numpy(ds)

## Model

In [None]:
class TFRecordDataLoader:
    """Sized, re-iterable wrapper around the dataset built by ``get_dataset``.

    ``len(loader)`` reports the number of batches derived from a fixed
    per-file example count, which lets progress bars show a total.
    """

    def __init__(self, files, batch_size=32, cache=False, train=False, repeat=False,
                 shuffle=False, labeled=False, return_image_ids=True):
        """Create the underlying dataset and record its expected size.

        Args:
            files: TFRecord paths handed to ``get_dataset``.
            batch_size: Examples per batch.
            cache: Accepted but not used.
            train: Selects ``count_data_items`` (True) or
                ``count_data_items_test`` (False) for the example count.
            repeat, shuffle, labeled, return_image_ids: Forwarded to
                ``get_dataset``.
        """
        self.ds = get_dataset(
            files,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            labeled=labeled,
            return_image_ids=return_image_ids,
        )

        count_examples = count_data_items if train else count_data_items_test
        self.num_examples = count_examples(files)

        self.batch_size = batch_size
        self.labeled = labeled
        self.return_image_ids = return_image_ids
        self._iterator = None

    def __iter__(self):
        """Start a fresh pass over the dataset and return its iterator."""
        self._reset()
        return self._iterator

    def _reset(self):
        """Replace the stored iterator with a new one over ``self.ds``."""
        self._iterator = iter(self.ds)

    def __next__(self):
        """Return the next batch from the most recently created iterator."""
        return next(self._iterator)

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        full_batches, leftover = divmod(self.num_examples, self.batch_size)
        return full_batches + (1 if leftover else 0)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from nnAudio.Spectrogram import CQT1992v2
class CustomModel(nn.Module):
    """timm backbone that computes a CQT of each input channel on the fly."""

    def __init__(self, cfg, pretrained=False):
        """Build the CQT front end and the classifier backbone.

        Args:
            cfg: Settings object providing ``qtransform_params`` (keyword
                arguments for ``CQT1992v2``), ``model_name`` and
                ``target_size``.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()
        self.cfg = cfg
        # Read the CQT settings from the cfg that was passed in, like every
        # other setting of this model, rather than from a module-level global.
        self.wave_transform = CQT1992v2(**self.cfg.qtransform_params)
        self.model = timm.create_model(self.cfg.model_name, pretrained=pretrained, in_chans=3)
        self.n_features = self.model.classifier.in_features
        # Swap the backbone's classification head for one with target_size outputs.
        self.model.classifier = nn.Linear(self.n_features, self.cfg.target_size)

    def forward(self, x):
        """Return backbone outputs for a batch of 3-channel waveforms.

        Args:
            x: Tensor indexed as ``x[:, i]`` for the three channels
                (presumably shaped (batch, 3, samples) — confirm against the
                data loader).
        """
        # Transform each channel on its own, then stack the results along
        # dim 1 so they act as the 3 input channels of the backbone.
        spectrograms = [self.wave_transform(x[:, channel]) for channel in range(3)]
        return self.model(torch.stack(spectrograms, dim=1))

## Inference

In [None]:
# Array copy of the test file list.
files_test_all = np.array(all_files)
# Accumulates one prediction array per batch during inference.
all_test_preds = []

In [None]:
import warnings
# Silence all warnings for the remainder of the notebook.
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the network without pretrained weights and move it to `device`.
# NOTE(review): no forward pass happens here, so the no_grad context looks
# unnecessary — confirm.
with torch.no_grad():
    model = CustomModel(CFG, pretrained=False)
    model.to(device)

In [None]:
from tqdm.notebook import tqdm as tqdm
import tensorflow_datasets as tfds
import gc 

In [None]:

# Load one checkpoint per fold listed in CFG.trn_fold.
states = [torch.load(CFG.model_dir+f'{CFG.model_name}_fold{fold}_best_score.pth',
                     map_location=device)
          for fold in CFG.trn_fold]
# NOTE(review): not read anywhere in this cell; the number of models actually
# ensembled is len(states).
num_folds=4
filenames = []
# Start from an empty list so that re-running this cell does not append a
# second copy of the predictions while `filenames` is reset just above.
all_test_preds = []

test_loader = TFRecordDataLoader(
    all_files, batch_size=CFG.batch_size, shuffle=False)

# load_state_dict only copies parameters and buffers, so switching to eval
# mode once before the loop is enough.
model.eval()

with torch.no_grad():
    for step, d in enumerate(tqdm(test_loader)):
        # The loader is unlabeled, so each batch is (waves, wave_ids); the ids
        # arrive as bytes and are decoded for the submission file.
        filenames.extend([f.decode("UTF-8") for f in d[1]])

        images = torch.from_numpy(d[0]).to(device)
        del d
        gc.collect()

        fold_preds = []
        for state in states:
            model.load_state_dict(state['model'])
            y_preds = model(images)
            fold_preds.append(y_preds.sigmoid().to('cpu').numpy())

        # Average only after every fold has predicted. Averaging inside the
        # fold loop would turn the list into an ndarray and break the next
        # append as soon as more than one fold is used.
        avg_preds = np.mean(fold_preds, axis=0)
        all_test_preds.append(avg_preds)

probs = np.concatenate(all_test_preds)

In [None]:
probs = np.concatenate(all_test_preds)


In [None]:
#all_test_preds
# Assemble the submission table: one row per decoded wave id, paired with the
# flattened prediction array (both were filled in the same batch order).
test_df = pd.DataFrame({
    "id": filenames ,
    "target": probs.reshape(-1)
})

In [None]:
#ds_test = get_dataset(files_test_all, batch_size=BATCH_SIZE * 2, repeat=False, shuffle=False, aug=False, labeled=False, return_image_ids=True)
#file_ids = np.array([target.numpy() for img, target in iter(ds_test.unbatch())])

In [None]:



test_df.head()

In [None]:
plt.hist(test_df.target.values)

In [None]:
test_df.to_csv("submission.csv", index=False)