# FMix

### FMix is a variant of MixUp, CutMix, etc. introduced in the paper 'FMix: Enhancing Mixed Sampled Data Augmentation'. It uses masks sampled from Fourier space to mix training examples. 

### In short, it is like CutMix, but with an irregularly shaped mask instead of a rectangular one. In this kernel, let's see what FMix and CutMix look like on our dataset.

In [None]:
from IPython.display import Image, display

# Show the example figure bundled with the FMix-master input dataset.
display(
    Image('../input/image-fmix/FMix-master/fmix_example.png')
)

In [None]:
# Show the animated GIF bundled with the FMix-master input dataset.
display(
    Image('../input/image-fmix/FMix-master/fmix_3d.gif')
)

In [None]:
import sys

# Make the FMix code shipped in the image-fmix input importable as `fmix`.
package_path = '../input/image-fmix/FMix-master'
# Guard so re-running this notebook cell does not append duplicate entries.
if package_path not in sys.path:
    sys.path.append(package_path)
from fmix import sample_mask

In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from  torch.cuda.amp import autocast, GradScaler

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage.interpolation import zoom

In [None]:
# Training metadata; CassavaDataset reads its `image_id` and `label` columns.
train = pd.read_csv(
    filepath_or_buffer='../input/cassava-leaf-disease-classification/train.csv',
)
train.head()

## Dataloader

In [None]:
def get_img(path):
    """Load an image from ``path`` as an RGB numpy array.

    ``cv2.imread`` returns pixels in BGR channel order, so the last axis is
    reversed to obtain RGB. The result is a reversed-stride view of the
    decoded array, not a copy.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise surface as an opaque ``TypeError`` on slicing.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return im_bgr[:, :, ::-1]

# Sanity check: load one training image and display it.
img = get_img(
    '../input/cassava-leaf-disease-classification/train_images/1000015157.jpg'
)
plt.imshow(img)
plt.show()

In [None]:
class CassavaDataset(Dataset):
    """Map-style dataset over the images listed in ``df``.

    Args:
        df: table with an ``image_id`` column (file name under ``data_root``)
            and, when ``output_label`` is true, a ``label`` column.
        data_root: directory containing the image files.
        transforms: optional callable invoked as
            ``transforms(image=img)['image']`` (albumentations-style).
        output_label: if true, ``__getitem__`` returns ``(img, target)``;
            otherwise it returns only ``img``.
    """

    def __init__(
        self, df, data_root, transforms=None, output_label=True
    ):
        super().__init__()
        # Reset to a 0..n-1 index (and copy) so positional ``iloc`` lookups
        # are independent of the caller's index and later edits to ``df``.
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        # Read the label up front so a missing column fails before image I/O.
        target = row['label'] if self.output_label else None

        path = "{}/{}".format(self.data_root, row['image_id'])

        # Scale pixels to [0, 1]; the division also yields a new float array
        # instead of get_img's reversed-stride view.
        img = get_img(path) / 255.

        if self.transforms:
            img = self.transforms(image=img)['image']

        # The label is returned as-is (no label smoothing is applied here).
        if self.output_label:
            return img, target
        return img

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_default_transforms():
    """Build the fixed preprocessing pipeline applied to every image.

    Takes the central 128x128 patch, resizes it to 128x128 (effectively a
    no-op after the crop), and converts the HWC numpy image to a CHW tensor.
    """
    steps = [
        CenterCrop(128, 128, p=1.),
        Resize(128, 128),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

In [None]:
train_ds = CassavaDataset(
    train,
    '../input/cassava-leaf-disease-classification/train_images/',
    transforms=get_default_transforms(),
    output_label=True,
)

# Unshuffled batches of 32, used by the visualisation cells below.
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
)

## Reference Kernel: https://www.kaggle.com/virajbagal/mixup-cutmix-fmix-visualisations

## Augmentation APIs

In [None]:
from pylab import rcParams
# Default figure size as (width, height) in inches for the comparison grids.
rcParams['figure.figsize'] = (20, 40)

In [None]:
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of an image's area.

    Args:
        size: shape of an NCHW batch (e.g. ``tensor.size()``): ``size[2]`` is
            the height and ``size[3]`` the width.
        lam: fraction of the image to keep, in ``[0, 1]``; each box side is
            ``sqrt(1 - lam)`` times the matching image side.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The x-bounds index the last (width)
        axis and the y-bounds index axis 2 (height), clipped to the image.
        The box centre is uniform over the image, so clipping at the border
        can make the box smaller than requested.
    """
    # NCHW layout: axis 2 is height, axis 3 is width. Callers slice
    # ``[:, :, bby1:bby2, bbx1:bbx2]``, so x must be bounded by the width.
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int(): the ``np.int`` alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha):
    """Apply CutMix: paste one random box from a shuffled copy of the batch.

    Args:
        data: image batch indexed as ``data[:, :, y, x]``.
        target: labels, one per sample along dim 0.
        alpha: parameter of the Beta(alpha, alpha) distribution the mixing
            ratio is drawn from.

    Returns:
        ``(new_data, (target, shuffled_target, lam))``. ``lam`` is the
        fraction of each image's pixels still taken from ``data``. It is
        recomputed from the box that was actually pasted.
    """
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    # lam is clipped to [0.3, 0.4]. The requested box covers ~(1 - lam) of
    # the area, or less if the box is clipped at the image border.
    # NOTE(review): this clipping means the pasted region always dominates;
    # confirm that is intended.
    lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    # Reuse the already-gathered shuffled batch instead of indexing data again.
    new_data[:, :, bby1:bby2, bbx1:bbx2] = shuffled_data[:, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    targets = (target, shuffled_target, lam)

    return new_data, targets

def fmix(data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """Apply FMix to a batch, using one Fourier-sampled mask for all samples.

    Args:
        data: image batch; the mask returned by ``sample_mask`` must broadcast
            against it (presumably ``shape`` equals the images' spatial size).
        targets: labels, one per sample along dim 0.
        alpha, decay_power, shape, max_soft, reformulate: forwarded unchanged
            to ``sample_mask``.

    Returns:
        ``(mixed, (targets, shuffled_targets, lam))``. ``mixed`` is
        ``mask * data + (1 - mask) * shuffled_data``, and ``lam`` is the
        mixing ratio returned by ``sample_mask``.
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]
    # The numpy mask is created on the CPU. Move it to the batch's device so
    # GPU batches work. Match a floating batch's dtype so float32 data is not
    # upcast by a float64 mask.
    mask = torch.from_numpy(mask).to(data.device)
    if data.is_floating_point():
        mask = mask.to(data.dtype)
    x1 = mask * data
    x2 = (1 - mask) * shuffled_data
    targets = (targets, shuffled_targets, lam)

    return (x1 + x2), targets

## Cutmix 

In [None]:
iter_data = iter(train_loader)
data, target = next(iter_data)
data_aug, target = cutmix(data, target, 1.)


# Three rows of two (original, augmented) pairs drawn at random from the batch.
for i in range(3):
    f, axarr = plt.subplots(1, 4)
    for p in (0, 2):
        idx = np.random.randint(0, len(data))
        img_org, new_img = data[idx], data_aug[idx]
        pair = ((axarr[p], img_org, 'original'),
                (axarr[p + 1], new_img, 'cutmix image'))
        for ax, shown, title in pair:
            ax.imshow(shown.permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')

> As we can see, a rectangular region of each image is replaced by the same region of another image from the batch.

## FMix

1. FMix is a variant of MixUp and CutMix. Paper: [FMix: Enhancing Mixed Sampled Data Augmentation](https://arxiv.org/abs/2002.12047)
2. It uses masks sampled from Fourier space to mix training examples
3. [Parameters](https://github.com/ecs-vlc/FMix/blob/master/fmix.py#L124) for sample_mask:
    * alpha: Alpha value for beta distribution from which to sample mean of mask
    * decay_power: Decay power d; the mask's frequency spectrum is attenuated proportionally to 1/f**d
    * shape: Shape of desired mask, list up to 3 dims
    * max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.

In [None]:
iter_data = iter(train_loader)
data, target = next(iter_data)
data_aug, target = fmix(data, target, alpha=1., decay_power=3., shape=(128,128))
# Three rows of two (original, FMix) pairs drawn at random from the batch.
for i in range(3):
    f, axarr = plt.subplots(1, 4)
    for p in (0, 2):
        idx = np.random.randint(0, len(data))
        img_org, new_img = data[idx], data_aug[idx]
        pair = ((axarr[p], img_org, 'original'),
                (axarr[p + 1], new_img, 'fmix image'))
        for ax, shown, title in pair:
            ax.imshow(shown.permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')

In [None]:
iter_data = iter(train_loader)
data, target = next(iter_data)
data_aug, target = fmix(data, target, alpha=1., decay_power=10., shape=(128,128))

# Same grid as above, with a larger decay_power for the FMix mask.
for i in range(3):
    f, axarr = plt.subplots(1, 4)
    for p in (0, 2):
        idx = np.random.randint(0, len(data))
        img_org, new_img = data[idx], data_aug[idx]
        pair = ((axarr[p], img_org, 'original'),
                (axarr[p + 1], new_img, 'fmix image'))
        for ax, shown, title in pair:
            ax.imshow(shown.permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')

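> In short: FMix does CutMix-style mixing, but with an irregularly shaped mask instead of a rectangle. A larger decay_power (10 vs. 3 above) suppresses high frequencies more strongly, which gives smoother masks with larger connected regions.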

## What's next
1. Apply CutMix and FMix in the training data loader and check validation performance.
2. Increase diversity within a batch. Instead of mixing images from the same batch with a single shared mask, we could mix each image with a randomly chosen image from the whole dataset and sample a different mask for every image in the batch.