# Starter using the Vision Transformer (ViT)

What the Transformer does:
- divides the spectrogram into patches
- builds an embedding for each patch
- applies attention between the different patches

Since the default ViT works on 16x16 patches, the spectrograms are padded so that the final images are 32x128, i.e. both sides are multiples of 16.

References:

- Yasufumi Nakama's (@yasufuminakama) spectrogram preprocessing notebooks and datasets:
    * Train: [Notebook](https://www.kaggle.com/yasufuminakama/g2net-spectrogram-generation-train), [Dataset](https://www.kaggle.com/yasufuminakama/g2net-n-mels-128-train-images)
    * Test: [Notebook](https://www.kaggle.com/yasufuminakama/g2net-spectrogram-generation-test), [Dataset](https://www.kaggle.com/yasufuminakama/g2net-n-mels-128-test-images)
- @xhlulu's pipeline: https://www.kaggle.com/xhlulu/g2net-rnn-starter-from-spectrogram

In [None]:
!pip install -q vit-keras

In [None]:
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

In [None]:
from vit_keras import vit

In [None]:
# Index of the cross-validation fold whose validation split is used in this run;
# it also appears in the name of the submission file.
FOLD = 0
# Number of folds passed to StratifiedKFold.
N_SPLITS = 5

In [None]:
class CustomDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving batches of arrays stored one file per sample.

    Every row of ``df`` names one file: its ``id`` value is joined with
    ``directory`` and ``ext`` and loaded with ``np.load``. The arrays of a
    batch are stacked, cast to float32, padded along axis 1 and replicated
    into three identical channels on a new last axis.

    Args:
        df: DataFrame with an ``id`` column and, when ``target`` is true,
            a ``target`` column.
        directory: Folder containing the per-sample files.
        batch_size: Maximum number of rows per batch (must be positive).
        random_state: Seed of the private RNG used for shuffling.
        shuffle: If true, rows are shuffled on construction and again at the
            end of every epoch.
        target: If true, ``__getitem__`` returns ``(inputs, labels)``;
            otherwise it returns the inputs only.
        ext: File extension appended to each id.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, df, directory,
                 batch_size=32,
                 random_state=1127802825,
                 shuffle=True, target=True, ext='.npy'):
        # Newer Keras releases expect Sequence subclasses to initialise the
        # base class; on older releases this call is a no-op.
        super().__init__()

        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # A private RNG keeps shuffling reproducible without reseeding the
        # process-wide NumPy generator every time a dataset is created.
        self._rng = np.random.RandomState(random_state)

        self.directory = directory
        self.df = df
        self.shuffle = shuffle
        self.target = target
        self.batch_size = batch_size
        self.ext = ext

        # Shuffle once up front so the first epoch is already randomised.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch as a plain Python ``int``."""
        # Ceiling division in integer arithmetic; the last batch may be short.
        return -(-len(self.df) // self.batch_size)

    def __getitem__(self, idx):
        """Load and preprocess batch number ``idx``.

        Returns:
            ``(signals, targets)`` when ``self.target`` is true, otherwise
            ``signals`` alone. ``signals`` is a float32 tensor with a
            trailing channel axis of size 3.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"batch index {idx} out of range for {len(self)} batches")

        start_idx = idx * self.batch_size
        # Explicit positional slicing: the frame's index labels are irrelevant.
        batch = self.df.iloc[start_idx: start_idx + self.batch_size]

        signals = np.stack([
            np.load(os.path.join(self.directory, fname + self.ext))
            for fname in batch['id']
        ]).astype('float32')

        # Mirror-pad axis 1 with 2 leading and 3 trailing entries (+5 in total).
        # NOTE(review): presumably this takes 27-row spectrograms to the 32
        # rows the model input expects — confirm against the stored arrays.
        signals = tf.pad(signals, tf.constant([[0, 0], [2, 3], [0, 0]]), "SYMMETRIC")
        # Replicate the single channel three times on a new last axis.
        signals = tf.tile(tf.expand_dims(signals, axis=-1), multiples=[1, 1, 1, 3])

        if self.target:
            return signals, batch['target'].values
        return signals

    def on_epoch_end(self):
        """Reshuffle the rows (when shuffling is enabled) between epochs."""
        if self.shuffle:
            self.df = (
                self.df
                .sample(frac=1, random_state=self._rng)
                .reset_index(drop=True)
            )

In [None]:
# ViT-B/16 backbone from vit-keras with pretrained weights, built for 32x128
# inputs and without the classification top.
# NOTE(review): with include_top=False the `activation` and `classes`
# arguments presumably have no effect — verify against the vit-keras docs.
vit_model = vit.vit_b16(
    image_size=(32, 128),
    activation='softmax',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
    classes=2,
)

In [None]:
def build_model():
    """Assemble the classifier on top of the module-level ``vit_model``.

    The network takes a (32, 128, 3) input, passes it through ``vit_model``,
    then batch normalisation, a 128-unit dense layer with GELU activation and
    a single sigmoid unit named ``"sigmoid"``.

    Returns:
        The uncompiled ``tf.keras.Model``.
    """
    image = layers.Input(shape=(32, 128, 3))

    features = vit_model(image)
    features = tf.keras.layers.BatchNormalization()(features)
    hidden = layers.Dense(128, activation=tfa.activations.gelu)(features)
    probability = layers.Dense(1, activation="sigmoid", name="sigmoid")(hidden)

    return tf.keras.Model(inputs=image, outputs=probability)

In [None]:
# Load the training labels table; its `target` column drives the stratified split.
train = pd.read_csv('../input/g2net-gravitational-wave-detection/training_labels.csv')
# Notebook cell output: preview the first rows.
train.head()

In [None]:
# Stratified K-fold split on the `target` column; only the fold numbered FOLD
# is used in this run.
cv = StratifiedKFold(n_splits=N_SPLITS, random_state=1127802825, shuffle=True)
cv_splits = cv.split(X=train, y=train['target'].values)
for _fold, (train_idx, valid_idx) in enumerate(cv_splits):
    if _fold == FOLD:
        break
else:
    # The loop ran through every fold without matching FOLD. Without this
    # guard the indices of the LAST fold would be used silently.
    raise ValueError(
        f"FOLD={FOLD} is not a valid fold index for N_SPLITS={N_SPLITS}")

# Positional selection of the rows belonging to each side of the chosen fold.
train_df = train.iloc[train_idx, :]
valid_df = train.iloc[valid_idx, :]

In [None]:
# Both generators read from the same image folder; only the training one
# shuffles (shuffle defaults to True in CustomDataset).
train_dset = CustomDataset(
    df=train_df,
    directory='../input/g2net-n-mels-128-train-images',
    batch_size=64,
)

valid_dset = CustomDataset(
    df=valid_df,
    directory='../input/g2net-n-mels-128-train-images',
    batch_size=64,
    shuffle=False,
)

# Sanity check: fetch the first training batch and print the shape of each
# element it contains.
sample = next(iter(train_dset))
for item in sample:
    print(item.shape)

In [None]:
# Build the network and compile it for binary classification, tracking AUC
# alongside the cross-entropy loss.
model = build_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="binary_crossentropy",
    metrics=[tf.keras.metrics.AUC()],
)
model.summary()

In [None]:
# Keep only the best weights seen so far. No `monitor` is passed, so the
# callback's default criterion decides what "best" means.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath="model_weights.h5",
    save_best_only=True,
    save_weights_only=True,
)

# Train for 8 epochs, validating on the held-out fold after each one.
train_history = model.fit(
    x=train_dset,
    validation_data=valid_dset,
    epochs=8,
    callbacks=[ckpt],
    verbose=1,
)

In [None]:
# Restore the weights written to 'model_weights.h5' by the checkpoint callback,
# replacing whatever the model holds after the final epoch.
model.load_weights('model_weights.h5')

In [None]:
# The sample submission supplies the ids of the test samples to score.
sub = pd.read_csv('../input/g2net-gravitational-wave-detection/sample_submission.csv')

# target=False: yield inputs only. shuffle=False: keep the row order so the
# predictions line up with the rows of `sub`.
test_dset = CustomDataset(
    df=sub,
    directory="../input/g2net-n-mels-128-test-images",
    batch_size=64,
    shuffle=False,
    target=False,
)

y_pred = model.predict(test_dset, verbose=1)
# Overwrite the placeholder column with the model outputs and write one
# submission file per fold.
sub['target'] = y_pred
sub.to_csv(f'vit_sub_{FOLD}.csv', index=False)