# Search for Effective Data Augmentation & TTA
I don't know much about data augmentation for spectrograms, so I tried several types of augmentation and checked whether each one improves the validation score.

![](https://images.unsplash.com/photo-1613744450985-fc6372fe6a12?ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&ixlib=rb-1.2.1&auto=format&fit=crop&w=1350&q=80)

# TL;DR
I experimented with the following augmentations and also measured the effect of test-time augmentation (TTA):
- Mixup
- CutMix
- Horizontal flip
- Vertical flip
- Shift scale rotate
- Random resized crop
- Motion blur
- SpecAugment
- Random brightness contrast
- Gaussian noise

Comments are welcome!

# Notes
I'd like to mention these notebooks/discussions that helped me a lot!
- https://www.kaggle.com/micheomaano/efficientnet-b4-mixup-cv-0-98-lb-0-97
- https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644
- https://www.kaggle.com/yasufuminakama/seti-nfnet-l0-starter-training

**UPDATE on June 5th**: Added CutMix, SpecAugment, TTA

# Setups

In [None]:
!pip install git+https://github.com/rwightman/pytorch-image-models
import timm

In [None]:
import sys
import os
import math
import time
import random
import gc

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.core.transforms_interface import ImageOnlyTransform

In [None]:
class CFG:
    """Global experiment settings read by the cells below."""

    # Reproducibility and quick-run switches
    seed = 46
    debug = False

    # Backbone and optimisation schedule
    model_name = "tf_efficientnet_b0"
    n_epoch = 10
    lr = 5e-4

    # Side length images are resized to, and number of TTA passes
    size = 256
    n_tta = 3

    # Prefer the GPU whenever PyTorch can see one
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


print(f"Using device {CFG.device}")

def seed_torch(seed):
    """Seed the Python, NumPy and PyTorch RNGs with the same value.

    Also sets the ``PYTHONHASHSEED`` environment variable and switches
    cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every RNG source gets the identical seed
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True

seed_torch(CFG.seed)

In [None]:
# Load the label table and derive the on-disk path of every sample
train = pd.read_csv('../input/seti-breakthrough-listen/train_labels.csv')
train['file_path'] = [
    f"../input/seti-breakthrough-listen/train/{x[0]}/{x}.npy" for x in train['id']
]

# Debug mode: a single epoch on 128 rows; otherwise work on a 25% subsample
if CFG.debug:
    CFG.n_epoch = 1
    sample_kwargs = {"n": 128}
else:
    sample_kwargs = {"frac": 0.25}
train = train.sample(random_state=CFG.seed, **sample_kwargs).reset_index(drop=True)

# Hold out a quarter of the subsample for validation
train, valid = train_test_split(train, test_size=0.25, random_state=CFG.seed)
for frame in (train, valid):
    display(frame)

# Augmentations

In [None]:
def spec_augment(x, alpha=0.1):
    """Zero out one random band of rows and one random band of columns.

    Parameters
    ----------
    x : np.ndarray
        Array with at least two dimensions; bands are cut along axis 0
        and axis 1.
    alpha : float
        Band width is drawn uniformly from ``[0, int(axis_len * alpha))``.
        When that upper bound is 0 the axis is left unmasked.

    Returns
    -------
    np.ndarray
        A masked copy of ``x``. The input array is not modified.
    """
    masked = x.copy()  # never write into the caller's buffer

    for axis in (0, 1):
        length = masked.shape[axis]
        max_width = int(length * alpha)
        if max_width < 1:
            # np.random.randint(0, 0) would raise; nothing to mask here
            continue

        start = np.random.randint(0, length)
        width = np.random.randint(0, max_width)
        stop = min(start + width, length)

        if axis == 0:
            masked[start:stop] = 0
        else:
            masked[:, start:stop] = 0

    return masked

class SpecAugment(ImageOnlyTransform):
    """Image-only transform that hands the image to ``spec_augment``."""

    def apply(self, img, **params):
        # Extra keyword params are accepted but not used
        masked = spec_augment(img)
        return masked

# Probability with which each augmentation below is applied
p = 0.5

# Registry mapping an augmentation name to its transform instance.
# Every entry shares `p`, so changing it above retunes all of them at once.
DA_DICT = {
    "spec_augment": SpecAugment(p=p),
    "hflip": A.HorizontalFlip(p=p),
    "vflip": A.VerticalFlip(p=p),
    # rotate_limit=0 turns rotation off: shift and scale only
    "shift_scale_rotate": A.ShiftScaleRotate(rotate_limit=0, p=p),
    "random_resized_crop": A.RandomResizedCrop(height=CFG.size, width=CFG.size, p=p),
    "motion_blur": A.MotionBlur(p=p),
    "random_brightness_contrast": A.RandomBrightnessContrast(p=p),
    "gauss_noise": A.GaussNoise(var_limit=(0.1, 1), p=p),
}

# Dataset with Different Transforms

In [None]:
class SETIDataset(Dataset):
    """Dataset yielding one ``(image, label)`` pair per row of ``df``.

    ``df`` must provide a ``file_path`` column (path to a ``.npy`` file) and
    a ``target`` column. ``transform`` is called as ``transform(image=...)``
    and must return a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, transform):
        self.df = df
        self.transform = transform
        # Pull the two needed columns out once as plain arrays
        self.file_names = df['file_path'].values
        self.labels = df["target"].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        raw = np.load(self.file_names[idx])  # (6, 273, 256)
        # Stack the leading axis vertically, then swap the two axes
        stacked = np.vstack(raw.astype(np.float32))
        image = np.transpose(stacked, (1, 0))  # (256, 1638)
        augmented = self.transform(image=image)
        label = torch.tensor(self.labels[idx])
        return augmented['image'], label.float()

def get_transforms(da):
    """Build the preprocessing pipeline for augmentation key ``da``.

    The pipeline always resizes to ``CFG.size`` x ``CFG.size`` and ends with
    ``ToTensorV2``. When ``da`` is a key of ``DA_DICT`` the matching
    augmentation is inserted between those two steps; any other value
    (including ``None``) yields the plain pipeline.
    """
    steps = [A.Resize(CFG.size, CFG.size)]
    if da in DA_DICT:
        steps.append(DA_DICT[da])
    steps.append(ToTensorV2())
    return A.Compose(steps)
    
def mixup(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the same batch.

    Parameters
    ----------
    x : torch.Tensor
        Batch of inputs, batch dimension first.
    y : torch.Tensor
        Targets aligned with ``x`` along the first dimension.
    alpha : float
        Beta(alpha, alpha) concentration for the mixing weight. With
        ``alpha <= 0`` the weight is fixed to 1 (inputs returned unmixed).

    Returns
    -------
    (mixed_x, y_a, y_b, lam)
        ``mixed_x = lam * x + (1 - lam) * x[perm]``; ``y_a`` are the original
        targets, ``y_b`` the partners' targets, ``lam`` the mixing weight.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    # Place the permutation on the same device as the data instead of relying
    # on a global device setting, so the index always matches `x`.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def _rand_center(half, extent):
    """Pick a box centre in ``[half, extent - half)``; fall back to ``half`` if empty."""
    low, high = half, extent - half
    if high <= low:
        # The box spans the whole axis (e.g. lam == 0): only one placement fits
        return low
    return np.random.randint(low, high)


def rand_bbox(W, H, lam):
    """Sample a box covering roughly ``1 - lam`` of a ``W`` x ``H`` area.

    The box always lies fully inside the area.

    Returns
    -------
    (bbx1, bby1, bbx2, bby2)
        Half-open bounds: ``bbx1:bbx2`` along the ``W`` axis and
        ``bby1:bby2`` along the ``H`` axis.
    """
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; the builtin truncates the same way
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = _rand_center(cut_w // 2, W)
    cy = _rand_center(cut_h // 2, H)

    bbx1 = cx - cut_w // 2
    bby1 = cy - cut_h // 2
    bbx2 = cx + cut_w // 2
    bby2 = cy + cut_h // 2

    return bbx1, bby1, bbx2, bby2


def cutmix(x, y, alpha=1.0):
    """Paste a random rectangle from a partner sample into each sample.

    Parameters
    ----------
    x : torch.Tensor
        Batch whose LAST two dimensions are spatial, e.g. ``(B, H, W)`` or
        ``(B, C, H, W)``.
    y : torch.Tensor
        Targets aligned with ``x`` along the first dimension.
    alpha : float
        Beta(alpha, alpha) concentration for the mixing weight. With
        ``alpha <= 0`` the weight is fixed to 1 (nothing is pasted).

    Returns
    -------
    (mixed_x, y_a, y_b, lam)
        ``mixed_x`` is a new tensor (the input is left untouched); ``y_a``
        are the original targets, ``y_b`` the partners' targets, and ``lam``
        is the fraction of each sample that was NOT replaced.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Cut over the spatial (last two) dims. Indexing dims 1 and 2 would hit
    # the channel axis of a (B, C, H, W) batch and yield an empty box.
    extent_a, extent_b = x.size(-2), x.size(-1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(extent_a, extent_b, lam)

    mixed_x = x.clone()
    mixed_x[..., bbx1:bbx2, bby1:bby2] = x[index, ..., bbx1:bbx2, bby1:bby2]

    # Integer truncation makes the pasted area differ from the sampled ratio;
    # report the weight that matches the pixels actually replaced.
    total = extent_a * extent_b
    if total > 0:
        lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1)) / total

    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Training & Evaluation

In [None]:
def criterion(outputs, targets):
    """Binary cross-entropy on raw logits, averaged over all elements."""
    loss_fn = nn.BCEWithLogitsLoss()
    return loss_fn(outputs, targets)

def mix_criterion(pred, y_a, y_b, lam):
    """Weight the loss against ``y_a`` by ``lam`` and against ``y_b`` by ``1 - lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def train_one_epoch(model, optimizer, dataloader, mix=None):
    """Run one optimisation pass over ``dataloader``.

    ``mix`` selects a batch-level augmentation: ``"mixup"`` or ``"cutmix"``
    blend samples and use the mixed loss; any other value trains on the
    batch as delivered by the loader.
    """
    model.train()

    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for _, (images, labels) in progress:
        images = images.to(CFG.device)
        labels = labels.to(CFG.device)

        if mix == "mixup" or mix == "cutmix":
            blend = mixup if mix == "mixup" else cutmix
            # Targets are reshaped to (B, 1) to line up with the model output
            images, targets_a, targets_b, lam = blend(images, labels.view(-1, 1))
            optimizer.zero_grad()
            outputs = model(images)
            loss = mix_criterion(outputs, targets_a, targets_b, lam)
        else:
            optimizer.zero_grad()
            outputs = model(images)
            # Flatten the logits so they line up with the 1-D labels
            loss = criterion(outputs.view(-1), labels)

        loss.backward()
        optimizer.step()

def valid_one_epoch(model, dataloader):
    """Predict over ``dataloader`` once and return the ROC AUC."""
    model.eval()
    batch_preds, batch_targets = [], []

    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for _, (images, labels) in progress:
        images = images.to(CFG.device)
        labels = labels.to(CFG.device)
        with torch.no_grad():
            logits = model(images)
        probs = logits.sigmoid()
        batch_preds.append(probs.cpu().detach().numpy())
        batch_targets.append(labels.view(-1).cpu().detach().numpy())

    all_targets = np.concatenate(batch_targets)
    all_preds = np.concatenate(batch_preds)
    return roc_auc_score(all_targets, all_preds)

def valid_one_epoch_tta(model, dataloader):
    """Average predictions over ``CFG.n_tta`` passes and return the ROC AUC.

    Each pass iterates the whole ``dataloader`` again; the per-sample
    probabilities are summed across passes and divided by ``CFG.n_tta``.
    """
    model.eval()
    mean_preds = np.zeros(len(dataloader.dataset))

    for _ in range(CFG.n_tta):
        pass_targets, pass_preds = [], []
        progress = tqdm(enumerate(dataloader), total=len(dataloader))
        for _step, (images, labels) in progress:
            images = images.to(CFG.device)
            labels = labels.to(CFG.device)
            with torch.no_grad():
                logits = model(images)
            pass_preds.append(logits.sigmoid().cpu().detach().numpy())
            pass_targets.append(labels.view(-1).cpu().detach().numpy())
        # Targets from the most recent pass are the ones used for scoring
        targets = np.concatenate(pass_targets)
        mean_preds += np.concatenate(pass_preds).reshape(-1)

    mean_preds /= CFG.n_tta
    return roc_auc_score(targets, mean_preds)

def run(da):
    """Train a fresh model with augmentation ``da`` and track validation AUC.

    Returns ``(valid_scores, tta_scores)``: one AUC per epoch on the
    un-augmented validation set, and one per epoch with TTA. The TTA list
    stays empty unless ``da`` is a key of ``DA_DICT``.
    """
    def make_loader(frame, transform_key, batch_size, shuffle):
        dataset = SETIDataset(frame, transform=get_transforms(transform_key))
        return DataLoader(dataset, batch_size=batch_size,
                          num_workers=4, shuffle=shuffle, pin_memory=True)

    def free_memory():
        gc.collect()
        torch.cuda.empty_cache()

    train_loader = make_loader(train, da, 32, True)
    valid_loader = make_loader(valid, None, 64, False)
    # Same validation rows, but passed through the training-time augmentation
    tta_loader = make_loader(valid, da, 64, False)

    model = timm.create_model(CFG.model_name, pretrained=True, in_chans=1, num_classes=1)
    model.to(CFG.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

    valid_scores, tta_scores = [], []
    for epoch in range(1, CFG.n_epoch + 1):
        train_one_epoch(model, optimizer, train_loader, da)
        print(da, "Train")
        free_memory()

        valid_auc = valid_one_epoch(model, valid_loader)
        valid_scores.append(valid_auc)
        print(da, "Valid", valid_auc)
        free_memory()

        if da not in DA_DICT:
            continue
        tta_auc = valid_one_epoch_tta(model, tta_loader)
        tta_scores.append(tta_auc)
        print(da, "TTA", tta_auc)
        free_memory()

    return valid_scores, tta_scores

# Result

In [None]:
# Candidates: every per-sample transform plus the two batch-level mixers
das = [*DA_DICT, "mixup", "cutmix"]

# Reference curve: no augmentation at all
base_scores, _ = run(None)

epochs = range(1, CFG.n_epoch + 1)
da_adopt = []
for da in das:
    valid_scores, tta_scores = run(da)

    # One figure per augmentation, overlaid on the baseline
    curves = [(base_scores, "Baseline"), (valid_scores, "+DA")]
    if da in DA_DICT:
        curves.append((tta_scores, "+DA +TTA"))

    plt.figure()
    for scores, label in curves:
        plt.plot(epochs, scores, label=label, marker=".")
    plt.xlabel("Epochs")
    plt.ylabel("Valid AUC")
    plt.legend()
    plt.title(da)
    plt.show()

    # Adopt when the best epoch beats the baseline's best, with or without TTA
    best_base = max(base_scores)
    improved = max(valid_scores) > best_base or (
        bool(tta_scores) and max(tta_scores) > best_base
    )
    if improved:
        print(f"{da} improves AUC!")
        da_adopt.append(da)
    else:
        print(f"{da} does not improve AUC.")

    gc.collect()
    torch.cuda.empty_cache()

print("\n===== Result Summary =====")
print("The following data augmentation improves the validation score")
for da in da_adopt:
    print(da)

# Next Directions
- [RandAugment](https://arxiv.org/abs/1909.13719)
  - As far as I understand, it applies a fixed number of randomly chosen augmentations at a shared magnitude, so only two hyperparameters need tuning instead of a separate policy search
- Tuning augmentation probability `p`
  - [The EfficientNetV2 paper](https://arxiv.org/abs/2104.00298) says larger input images need stronger regularization and hence heavier augmentations