# Simple starter using Tez
This is a very simple starter built on [Tez](https://github.com/abhishekkrthakur/tez), a fantastic PyTorch trainer developed by [@abhishek](https://www.kaggle.com/abhishek); for an example of Tez in use, see https://www.kaggle.com/abhishek/using-tez-in-leaf-disease-classification. This notebook installs a customized fork, [tez_custom](https://github.com/yseeker/tez_custom).

Model: efficientnetv2_rw_s

In [None]:
!pip install timm
!pip install -q nnAudio
!pip install git+https://github.com/yseeker/tez_custom
!pip install wandb
from pathlib import Path

import os
import sys
import random
from tqdm import tqdm
import math

import pandas as pd
import numpy as np
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold
from sklearn import model_selection as sk_model_selection
import torch
import torch.nn as nn
import torchvision

import cv2
from PIL import Image
import albumentations as A

import tez
from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping
import timm
from nnAudio.Spectrogram import CQT1992v2

In [None]:

class CFG:
    """Global hyper-parameters and run settings, read directly as class attributes.

    ``model.fit`` also receives the whole class (``cfg=CFG``), so the customized
    tez fork may read fields that this notebook's own code never references.
    """
    # NOTE(review): not read in this notebook; presumably consumed by tez_custom via cfg.
    project_name = 'project name'
    # timm architecture name passed to timm.create_model.
    pretrained_model_name = 'efficientnetv2_rw_s'
    # AdamW learning rate.
    lr = 5e-4
    # Train / validation batch size.
    batch_size = 128
    # Run tag; also embedded in checkpoint file names.
    wandb_note = f'bs{batch_size}_adamW_default_lr{lr}'
    # True: timm fetches the weights itself; False: they are loaded from `pretrained_path`.
    pretrained = True
    pretrained_path = '../input/timm_weight/efficientnet_v2s_ra2_288-a6477665.pth'
    # Misspelled alias kept so any code still reading the old name keeps working.
    prettained_path = pretrained_path
    # Number of input channels of the backbone (the CQT image has one).
    input_channels = 1
    # Number of logits produced by the classification head.
    out_dim = 1
    colab_or_kaggle = 'colab'
    wandb_exp_name = f'{pretrained_model_name}_{colab_or_kaggle}_{wandb_note}'
    # Metric watched by EarlyStopping.
    monitor = 'valid_roc_auc'
    epochs = 2
    # n_splits for StratifiedKFold.
    num_of_fold = 5
    # NOTE(review): not used by this notebook's own code (set_seed is never called).
    seed = 42
    # Passed to fit() as n_jobs.
    num_workers = 8
    fp16 = False
    # Prefix (directory incl. trailing separator, or '') for checkpoint files.
    checkpoint_path = ''
    # EarlyStopping settings: 'max' because a higher ROC AUC is better.
    patience_mode = 'max'
    patience = 10
    delta = 0.001
    # NOTE(review): not referenced in this notebook; presumably read by tez_custom.
    mixup_alpha = 1.0
    # Passed to fit() as benchmark.
    benchmark = False
    # NOTE(review): not referenced in this notebook; presumably read by tez_custom.
    wandb = False

In [None]:
# Augmentation pipeline (albumentations).
# NOTE(review): `train_aug` is not passed to any dataset in this notebook —
# every ClassificationDataset below is built with transform=None.
train_aug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(
        p=0.5,
        shift_limit=0.2,
        scale_limit=0.2,
        rotate_limit=30,
        border_mode=cv2.BORDER_REPLICATE,
    ),
    A.OneOf([A.MedianBlur(p=0.3), A.MotionBlur(p=0.3)]),
])

# Training labels; each sample's waveform file is sharded by the first three
# characters of its id: train/<id[0]>/<id[1]>/<id[2]>/<id>.npy
df = pd.read_csv('../input/g2net-gravitational-wave-detection/training_labels.csv')
df['img_path'] = df['id'].map(
    lambda sample_id: f"../input/g2net-gravitational-wave-detection/train/{sample_id[0]}/{sample_id[1]}/{sample_id[2]}/{sample_id}.npy"
)

X = df['img_path'].to_numpy()
Y = df['target'].to_numpy()

# Stratified split without shuffling (CFG.seed is not used here).
skf = StratifiedKFold(n_splits=CFG.num_of_fold)

In [None]:
def sigmoid(gamma):
    """Return the logistic sigmoid ``1 / (1 + exp(-gamma))`` of a scalar.

    ``math.exp`` is only ever called with a non-positive argument, so it cannot
    overflow. For negative inputs the result is computed as ``e / (1 + e)``
    with ``e = exp(gamma)``; the earlier ``1 - 1 / (1 + e)`` form cancelled
    catastrophically and rounded small probabilities to exactly 0.0
    (e.g. ``sigmoid(-50)``), creating spurious ties when ranking scores.
    """
    if gamma < 0:
        e = math.exp(gamma)
        return e / (1 + e)
    return 1 / (1 + math.exp(-gamma))


# Element-wise version for NumPy arrays (used on model outputs in this notebook).
sigmoid_v = np.vectorize(sigmoid)

def set_seed(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Seed applied to every generator.

    Returns:
        A new ``np.random.RandomState`` seeded with ``seed``.
    """
    # Takes effect only for Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return np.random.RandomState(seed)

## Dataset and Model

In [None]:
class ClassificationDataset():
    """Map-style dataset turning stored waveform files into CQT images.

    Each item loads a ``.npy`` file, concatenates its rows into one 1-D signal,
    divides it by its maximum value and applies the constant-Q transform held in
    ``self.wave_transform`` (nnAudio ``CQT1992v2``).

    Args:
        image_paths: Indexable sequence of ``.npy`` file paths.
        targets: Indexable sequence of labels aligned with ``image_paths``.
        transform: Optional albumentations-style augmentation, called as
            ``transform(image=hwc_array)["image"]`` on the CQT image.
            ``None`` (what every caller in this notebook passes) disables it.
    """

    def __init__(self, image_paths, targets, transform = None):
        self.image_paths = image_paths
        self.targets = targets
        # Previously hard-coded to None, silently discarding the caller's argument.
        self.transform = transform
        self.wave_transform = CQT1992v2(sr=2048, fmin=20, fmax=1024, hop_length=64)

    def __len__(self):
        return len(self.image_paths)

    def apply_qtransform(self, waves, transform):
        """Concatenate ``waves`` into one signal, scale it, and return ``transform(signal)``.

        The signal is divided by its maximum; when that maximum is 0 the signal
        is left unscaled instead of turning into NaN/inf.
        """
        waves = np.hstack(waves)
        peak = np.max(waves)
        if peak != 0:
            waves = waves / peak
        waves = torch.from_numpy(waves).float()
        return transform(waves)

    def _augment(self, image):
        """Run ``self.transform`` on a channel-first tensor via a channel-last array.

        NOTE(review): assumes the CQT output is (channels, freq, time) — confirm
        against nnAudio's output shape before enabling augmentation.
        """
        hwc = np.ascontiguousarray(image.detach().cpu().numpy().transpose(1, 2, 0))
        hwc = self.transform(image=hwc)["image"]
        if hwc.ndim == 2:
            # Some OpenCV-backed ops drop a trailing single channel; restore it.
            hwc = hwc[:, :, None]
        return torch.from_numpy(np.ascontiguousarray(hwc.transpose(2, 0, 1)))

    def __getitem__(self, item):
        targets = self.targets[item]
        waves = np.load(self.image_paths[item])
        image = self.apply_qtransform(waves, self.wave_transform)
        if self.transform is not None:
            image = self._augment(image)

        return {
            # `image` is already a tensor: torch.tensor() on it would copy and warn.
            "image": image.detach().float(),
            "targets": torch.tensor(targets, dtype=torch.float),
        }
    

class CustomNN(tez.Model):
    """timm backbone with a ``CFG.out_dim``-logit head, trained with BCE-with-logits.

    All hyper-parameters are read from the global ``CFG``.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(
            CFG.pretrained_model_name,
            pretrained=CFG.pretrained,
            in_chans=CFG.input_channels,
        )
        if not CFG.pretrained:
            # Offline path: load the weights from a local checkpoint instead of
            # letting timm fetch them.
            self._load_local_weights()
        # Replace the original head with one emitting CFG.out_dim logits.
        self.model.classifier = nn.Linear(self.model.classifier.in_features, CFG.out_dim)

        self.criterion = nn.BCEWithLogitsLoss()

    def _load_local_weights(self):
        """Load the local checkpoint into ``self.model``, adapting input-conv weights.

        Reads ``CFG.pretrained_path`` when defined and falls back to the
        misspelled ``CFG.prettained_path`` (the previous lookup used a name CFG
        did not define and raised ``AttributeError``). A checkpoint trained with
        a different number of input channels cannot be loaded as-is into a model
        built with ``in_chans=CFG.input_channels``; such 4-D conv weights are
        adapted like timm does for its own pretrained weights: summed over the
        input channels for a single channel, otherwise tiled and rescaled.
        """
        path = getattr(CFG, "pretrained_path", None) or CFG.prettained_path
        state_dict = torch.load(path, map_location="cpu")
        model_state = self.model.state_dict()
        for key, weight in list(state_dict.items()):
            target = model_state.get(key)
            if target is None or weight.ndim != 4 or weight.shape == target.shape:
                continue
            if weight.shape[:1] + weight.shape[2:] != target.shape[:1] + target.shape[2:]:
                # Not an input-channel mismatch; let load_state_dict report it.
                continue
            in_chans, ckpt_chans = target.shape[1], weight.shape[1]
            if in_chans == 1:
                weight = weight.sum(dim=1, keepdim=True)
            else:
                repeats = math.ceil(in_chans / ckpt_chans)
                weight = weight.repeat(1, repeats, 1, 1)[:, :in_chans] * (ckpt_chans / in_chans)
            state_dict[key] = weight
        self.model.load_state_dict(state_dict)

    def forward(self, image, targets = None):
        """Return ``(logits, loss, batch_metrics)``; loss and metrics are None without targets."""
        outputs = self.model(image)
        if targets is None:
            return outputs, None, None
        loss = self.criterion(outputs, targets.view(-1, 1))
        # Named so it does not shadow the sklearn `metrics` module.
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

    def epoch_metrics(self, outputs, targets):
        """ROC AUC of sigmoid-transformed ``outputs`` against ``targets`` over an epoch."""
        outputs = sigmoid_v(outputs)
        roc_auc = metrics.roc_auc_score(targets, outputs)
        return roc_auc

    def monitor_metrics(self, outputs, targets):
        """Per-batch ROC AUC; 0.5 when the batch holds a single class (AUC undefined)."""
        outputs = outputs.sigmoid().cpu().detach().numpy()
        targets = targets.cpu().detach().numpy()
        if len(np.unique(targets)) > 1:
            roc_auc = metrics.roc_auc_score(targets, outputs)
        else:
            roc_auc = 0.5
        return {"roc_auc": roc_auc}

    def configure_optimizer(self):
        """AdamW over all parameters with ``CFG.lr`` and weight decay 0.01."""
        opt = torch.optim.AdamW(self.parameters(), lr=CFG.lr, weight_decay=0.01)
        return opt

    def configure_scheduler(self):
        """Cosine annealing with warm restarts every 10 scheduler steps (min lr 1e-6)."""
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

In [None]:
# Cross-validation training. Only the first fold is actually trained: the
# loop breaks right after the first model is fitted and stored.
models = []
for fold_cnt, (train_idx, valid_idx) in enumerate(skf.split(X, Y)):
    train_dataset = ClassificationDataset(
        image_paths=X[train_idx],
        targets=Y[train_idx],
        transform=None,
    )
    valid_targets = Y[valid_idx]
    valid_dataset = ClassificationDataset(
        image_paths=X[valid_idx],
        targets=valid_targets,
        transform=None,
    )
    model = CustomNN()

    # One checkpoint file per fold, written by EarlyStopping.
    checkpoint_file = (
        CFG.checkpoint_path
        + f'{CFG.pretrained_model_name}_{fold_cnt}fold_{CFG.wandb_note}.cpt'
    )
    es = EarlyStopping(
        monitor=CFG.monitor,
        model_path=checkpoint_file,
        patience=CFG.patience,
        mode=CFG.patience_mode,
        delta=CFG.delta,
    )
    model.fit(
        cfg=CFG,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        valid_targets=valid_targets,
        train_bs=CFG.batch_size,
        valid_bs=CFG.batch_size,
        epochs=CFG.epochs,
        callbacks=[es],
        n_jobs=CFG.num_workers,
        fp16=CFG.fp16,
        benchmark=CFG.benchmark,
    )
    models.append(model)
    break

In [None]:
# Inference on the test set. Test waveforms mirror the training layout:
# test/<id[0]>/<id[1]>/<id[2]>/<id>.npy
submission = pd.read_csv('../input/g2net-gravitational-wave-detection/sample_submission.csv')
submission['img_path'] = submission['id'].apply(
    lambda x: f"../input/g2net-gravitational-wave-detection/test/{x[0]}/{x[1]}/{x[2]}/{x}.npy"
)
# The sample submission's `target` column only fills the dataset's `targets`
# slot; it is overwritten with the predictions below.
test_dataset = ClassificationDataset(
    image_paths=submission.img_path.values,
    targets=submission.target.values,
    transform=None,
)

if not models:
    raise RuntimeError("No trained models: run the training cell before inference.")

# Number of prediction passes averaged per model.
# NOTE(review): without test-time augmentation the passes are presumably identical.
num_of_ave = 1
outs = []
for model in models:
    # Reset the accumulator for every model. It used to persist across models,
    # so each later model's "average" also contained all earlier models' logits.
    final_preds = None
    for _ in range(num_of_ave):
        preds = model.predict(test_dataset, batch_size=CFG.batch_size, n_jobs=-1)
        # Out-of-place sum so the array returned by predict() is not mutated.
        final_preds = preds if final_preds is None else final_preds + preds
    final_preds = final_preds / num_of_ave
    # Averaged outputs (presumably logits, given BCEWithLogitsLoss) -> probabilities.
    outs.append(sigmoid_v(final_preds))

# Ensemble: mean probability across models. Flatten to one value per row
# (outputs are presumably (N, 1) since the head has CFG.out_dim == 1 logit).
pred = np.mean(np.array(outs), axis=0)
submission['target'] = np.asarray(pred).reshape(-1)
submission.drop(['img_path'], axis=1, inplace=True)
submission.to_csv('submission.csv', index=False)