# <a href='https://github.com/YuanGongND/ast'>AST: Audio Spectrogram Transformer</a>

**This model doesn't need you to resize your input. I don't know exactly what the model's hyper-parameters mean, so I have most likely made a mistake somewhere in them, but I have put together a quick implementation of the model for the community. Here you go.**

# Importing Libraries

In [None]:
!git clone https://github.com/YuanGongND/ast.git --quiet
!pip install llvmlite --quiet
!pip install wget --quiet
!pip install zipp --quiet
!pip install wandb --upgrade --quiet
!pip install nnAudio --quiet
!pip install pytorch_lightning --quiet

In [None]:
!pip install timm==0.4.5

In [None]:
import sys
sys.path.append('./ast/src')

In [None]:
import os
import glob
import wandb
import shutil

from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import timm
import torch
from models import ASTModel
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

from nnAudio.Spectrogram import CQT1992v2

from tqdm.notebook import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

# Select the torch device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%cd ./ast/

# Build the AST backbone with label_dim=1, patch strides of 5 along both axes,
# and a declared input size of 69 frequency bins by 193 time frames.
# ImageNet-pretrained weights are requested; AudioSet-pretrained weights are not.
base_model = ASTModel(
    label_dim=1,
    fstride=5,
    tstride=5,
    input_fdim=69,
    input_tdim=193,
    imagenet_pretrain=True,
    audioset_pretrain=False,
    model_size='base384',
)

# Data Loading

In [None]:
%cd /kaggle/working

# Load the competition tables: the training labels and the sample submission
# (presumably one row per test id -- its 'id' column is used below to build
# the test file paths).
train = pd.read_csv('../input/g2net-gravitational-wave-detection/training_labels.csv')
test = pd.read_csv('../input/g2net-gravitational-wave-detection/sample_submission.csv')

def get_train_file_path(image_id):
    """Return the path of the training ``.npy`` file for ``image_id``.

    Files live three directory levels deep, each level named after one of
    the first three characters of the id.
    """
    first, second, third = image_id[0], image_id[1], image_id[2]
    return (
        "../input/g2net-gravitational-wave-detection/train/"
        f"{first}/{second}/{third}/{image_id}.npy"
    )

def get_test_file_path(image_id):
    """Return the path of the test ``.npy`` file for ``image_id``.

    Files live three directory levels deep, each level named after one of
    the first three characters of the id.
    """
    first, second, third = image_id[0], image_id[1], image_id[2]
    return (
        "../input/g2net-gravitational-wave-detection/test/"
        f"{first}/{second}/{third}/{image_id}.npy"
    )

# Resolve every id to the path of its .npy file.
train['file_path'] = train['id'].apply(get_train_file_path)
test['file_path'] = test['id'].apply(get_test_file_path)

# Preview the first two rows of each table.
display(train.head(2))
display(test.head(2))

# CFG

In [None]:
# Log in to Weights & Biases before the WandbLogger is created further down.
wandb.login()

In [None]:
# ====================================================
# CFG
# ====================================================
class Config:
    """Hyper-parameters and run switches shared by the notebook cells."""

    # When True, the notebook subsamples the training table and does a dev run.
    debug = False
    # Worker processes used by each DataLoader.
    num_workers = 4
    # Maximum number of training epochs handed to the Trainer.
    # (An earlier duplicate assignment of 3 was dead code: it was always
    # overwritten by this value.)
    epochs = 10
    # Adam learning rate and weight decay.
    lr = 1e-4
    weight_decay = 1e-6
    batch_size = 32
    # Seed passed to pl.seed_everything and to the debug subsample.
    seed = 1234
    # NOTE(review): target_size and n_folds are not referenced in the visible
    # part of the notebook.
    target_size = 1
    n_folds = 5
    # Name of the label column in the training table.
    target_col = 'target'
    # Loss applied to the raw model outputs (logits).
    LOSS = torch.nn.BCEWithLogitsLoss()
    # Forwarded to the Trainer's fast_dev_run flag.
    dev_run = False

if Config.debug:
    # Debug mode: work on a reproducible subsample and do a single quick run.
    # The sample size is capped at the number of available rows because
    # DataFrame.sample raises when asked for more rows than exist
    # (without replacement).
    train = train.sample(n=min(50000, len(train)),
                         random_state=Config.seed).reset_index(drop=True)
    Config.epochs = 1
    Config.dev_run = True

# Seed the random number generators so runs are repeatable.
pl.seed_everything(Config.seed)

# Simple train valid split

In [None]:
# Hold out 30% of the rows for validation, stratified on the 'target' column.
train, valid = train_test_split(train, test_size=0.3, stratify=train['target'])

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    """Dataset that loads a waveform file and returns its CQT image and label.

    Args:
        df: table providing a ``file_path`` column and the label column named
            by ``Config.target_col``.
        transform: optional callable invoked as ``transform(image=array)`` whose
            result is indexed with ``'image'``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.labels = df[Config.target_col].values
        # Constant-Q transform applied to every loaded waveform.
        self.wave_transform = CQT1992v2(sr=2048, fmin=20, fmax=1024, hop_length=64)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def apply_qtransform(self, waves, transform):
        """Flatten ``waves``, scale by its maximum and apply ``transform``.

        The arrays in ``waves`` are joined with ``np.hstack``, divided by the
        largest value of the joined signal, converted to a float32 tensor and
        passed to ``transform``; whatever ``transform`` returns is returned.
        """
        # presumably `waves` holds one series per detector -- TODO confirm
        joined = np.hstack(waves)
        scaled = torch.from_numpy(joined / np.max(joined)).float()
        return transform(scaled)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; the label is float32."""
        spectrogram = self.apply_qtransform(np.load(self.file_names[idx]),
                                            self.wave_transform)
        if self.transform:
            # The extra transform works on a numpy array with singleton
            # dimensions removed.
            as_array = spectrogram.squeeze().numpy()
            spectrogram = self.transform(image=as_array)['image']
        target = torch.tensor(self.labels[idx]).float()
        # Drop the leading dimension of the image before returning it.
        return spectrogram[0], target

# Transforms

In [None]:
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    """Return the albumentations pipeline for the given split.

    ``data`` is keyword-only. Both ``'train'`` and ``'valid'`` currently get a
    pipeline that only converts the image to a tensor; any other value yields
    ``None``.
    """
    if data == 'train':
        train_steps = [ToTensorV2()]
        return A.Compose(train_steps)

    if data == 'valid':
        valid_steps = [ToTensorV2()]
        return A.Compose(valid_steps)

    return None

In [None]:
# Training dataset and loader; shuffle=True is set for the training loader.
train_dataset = TrainDataset(train, transform=get_transforms(data='train'))

train_dl = DataLoader(train_dataset, 
                      batch_size=Config.batch_size,
                      num_workers=Config.num_workers,
                      shuffle=True,
                      pin_memory=True)

# Validation dataset and loader; no shuffle argument is passed here.
valid_dataset = TrainDataset(valid, transform=get_transforms(data='valid'))

valid_dl = DataLoader(valid_dataset,
                     batch_size=Config.batch_size,
                     num_workers=Config.num_workers,
                     pin_memory=True)

In [None]:
# Sanity check: pull a single batch, print the image tensor's shape and the
# labels, and keep the images in `sample`.
sample = None
for i in train_dl:
    print(i[0].shape, i[1])
    sample = i[0]
    break

# MODEL

In [None]:
class Classifier(pl.LightningModule):
    """Lightning module that trains an AST backbone with ``Config.LOSS``.

    Args:
        model: backbone to wrap. Defaults to the notebook-level ``base_model``.
            Like ``ASTModel``, it is called with a (batch, time, freq) tensor.
    """

    def __init__(self, model=None):
        super().__init__()
        self.model = base_model if model is None else model

    def forward(self, x):
        """Return the backbone output for a batch of spectrograms.

        ``x`` is taken in the layout the dataset yields: (batch, freq, time).
        """
        # ASTModel.forward documents its input as (batch, time, freq) and swaps
        # the last two axes itself, while the CQT spectrograms arrive as
        # (freq, time). Without this transpose the patch grid is built
        # transposed relative to the input_fdim/input_tdim layout the model
        # (and its positional embedding) was configured with; it only "works"
        # because the number of patches happens to match.
        return self.model(x.transpose(1, 2))

    def _compute_loss(self, batch):
        """Run the model on one ``(images, labels)`` batch and return the loss."""
        images, labels = batch
        outputs = self(images)
        # Flatten the outputs so they line up with the 1-D label vector.
        return Config.LOSS(outputs.view(-1), labels)

    def training_step(self, batch, batch_no):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_no):
        """Return the validation loss for ``batch`` and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        # Previously the validation loss was computed and then discarded.
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer configured from ``Config``."""
        return torch.optim.Adam(self.parameters(), lr=Config.lr,
                                weight_decay=Config.weight_decay)

In [None]:
# Instantiate the Lightning module that wraps the AST backbone.
model = Classifier()

In [None]:
# Send the run's metrics to the 'G2Net' Weights & Biases project.
wandb_logger = pl.loggers.WandbLogger(project='G2Net')

# Use one GPU when CUDA is available and fall back to the CPU otherwise, so the
# notebook does not fail on a machine without a GPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=Config.epochs,
                     fast_dev_run=Config.dev_run,
                     logger=wandb_logger)

In [None]:
# Train on train_dl, passing valid_dl as the validation loader.
trainer.fit(model, train_dl, valid_dl)