# CutMix with FastAI

In this notebook I want to share how to implement CutMix using the fastai callback system. I used [this](https://www.kaggle.com/muellerzr/fastai-efficientnet-b3-with-ranger-83-5) notebook by Zach Mueller as a reference; it is a great fastai starter.

After writing this I found that CutMix has already been implemented in fastai. It was discussed and developed in the [forums](https://forums.fast.ai/t/implementing-cutmix-in-fastaiv2/67350) by Aman Arora, Sanyam Bhutani, Akash Palrecha, Rekil Prashanth, and Asimo, and submitted as a pull request to the [repo](https://github.com/fastai/fastai/pull/3037) by Zach Mueller.

## Installs and imports

We install `timm` to have access to EfficientNet with pretrained weights.
To install a new library without internet access, I download the wheel file from [PyPI](https://pypi.org/project/timm/#files) and add it as a dataset by clicking this button in the top right corner:
![](https://i.ibb.co/X73N5PZ/Screenshot-from-2020-11-25-17-02-25.png)

Then I install it with this command:
```python
!pip install ../input/pytorchimagemodels/
```
For this notebook, though, we can turn internet access on and install it directly with pip:

In [None]:
!pip install timm

In [None]:
from fastai.vision.all import *
from fastai.callback.mixup import *
from torch.distributions.beta import Beta
import timm
set_seed(314)

## CutMix

We define a `CutMix` callback for regularization ([paper](https://arxiv.org/pdf/1905.04899.pdf)).

In [None]:
class CutMix(MixUp):
    def __init__(self, alpha=1.): self.distrib = Beta(tensor(alpha), tensor(alpha))
    def before_batch(self):
        # Sample the mixing ratio and a random pairing of images within the batch
        lam = self.distrib.sample().squeeze().to(self.x.device)
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        self.yb1 = tuple(L(self.yb).itemgot(shuffle))
        bs, c, h, w = self.x.shape
        # Box center (rx, ry) and box size (rw, rh); the box covers about (1-lam) of the image
        rx, ry = w*self.distrib.sample(), h*self.distrib.sample()
        rw, rh = w*(1-lam).sqrt(), h*(1-lam).sqrt()
        # Clip the box to the image borders
        x1 = (rx-rw/2).clamp(min=0).round().to(int)
        x2 = (rx+rw/2).clamp(max=w).round().to(int)
        y1 = (ry-rh/2).clamp(min=0).round().to(int)
        y2 = (ry+rh/2).clamp(max=h).round().to(int)
        # Paste the box from the shuffled images into the current batch
        self.learn.xb[0][:,:,y1:y2,x1:x2] = self.learn.xb[0][shuffle,:,y1:y2,x1:x2]
        # Recompute lam from the clipped box so it matches the pixels actually kept
        self.lam = 1- float(x2-x1)*(y2-y1)/(h*w)

        if not self.stack_y:
            # Blend the targets with the same ratio as the pixels
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))

This callback is basically a copy of the `MixUp` callback that's already in `fastai` ([here](https://github.com/fastai/fastai/blob/master/nbs/19_callback.mixup.ipynb) is the code).
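
To make the box arithmetic in `before_batch` easier to follow, here is a small sketch in plain Python, without tensors. The helper `cutmix_box` and its arguments are names I made up for illustration; it only mirrors the clipping and the recomputation of `lam`, not the full callback.

```python
import math

def cutmix_box(w, h, lam, cx, cy):
    """Return the clipped box (x1, y1, x2, y2) and the adjusted lam.

    cx, cy give the box center as fractions of the width and height, in [0, 1].
    """
    # Box sides are scaled so that the unclipped box covers (1 - lam) of the image
    rw, rh = w * math.sqrt(1 - lam), h * math.sqrt(1 - lam)
    # Clip the box to the image borders
    x1 = round(max(cx * w - rw / 2, 0))
    x2 = round(min(cx * w + rw / 2, w))
    y1 = round(max(cy * h - rh / 2, 0))
    y2 = round(min(cy * h + rh / 2, h))
    # Fraction of pixels that still come from the original image
    adjusted = 1 - (x2 - x1) * (y2 - y1) / (w * h)
    return (x1, y1, x2, y2), adjusted

# Box in the middle of the image: nothing is clipped, so lam is unchanged
box, lam = cutmix_box(w=300, h=300, lam=0.75, cx=0.5, cy=0.5)
# Box centered on the top-left corner: three quarters of it fall outside the image
box_edge, lam_edge = cutmix_box(w=300, h=300, lam=0.75, cx=0.0, cy=0.0)
print(box, lam)
print(box_edge, lam_edge)
```

When the box is clipped at a border, fewer pixels are replaced than the sampled `lam` suggests, which is why the callback recomputes `lam` from the final box before mixing the targets.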

We load the CSV, map the numeric labels to disease names, and add a `valid` column indicating whether a row belongs to the training set or the validation set.

In [None]:
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
df['image_id'] = df['image_id'].apply(lambda x: f'train_images/{x}')
df.head()
idx2label = json.load(open('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json'))
df.label = df.label.map(str).map(idx2label)
idxs = L.range(len(df)).shuffle()
df['valid'] = False
df.loc[idxs[:4279], 'valid'] = True
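
The `.map(str)` step in the cell above is there because JSON object keys are always strings, while the `label` column holds integers. A minimal sketch with a made-up two-class map (the class names here are hypothetical, not the competition's):

```python
import json

# Hypothetical map; the real one is read from label_num_to_disease_map.json
raw = '{"0": "Class A", "1": "Class B"}'
idx2label = json.loads(raw)

labels = [1, 0, 1]  # integer labels, as they would appear in the CSV

# Looking up an int key fails, so each label is converted to str first
names = [idx2label[str(label)] for label in labels]
print(names)
```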

With the dataset defined as above, we can use `ImageDataLoaders.from_df` to load our data. Each image is first resized to 320 pixels, and the batch augmentations then produce 300-pixel images. We use `size=300` because efficientnet_b3 was trained with images of that resolution ([here](https://github.com/rwightman/pytorch-image-models/blob/9c406532bde4ffe281d356de6e597717d2e53205/timm/models/efficientnet.py#L225)).
We use `mult=2` and `max_zoom=2` to make the augmentation stronger.

In [None]:
dls = ImageDataLoaders.from_df(df, path='../input/cassava-leaf-disease-classification/', bs=64,
                               item_tfms=Resize(320),
                               valid_col='valid',
                               batch_tfms=aug_transforms(size=300,  mult=2, max_zoom=2.))
dls.show_batch()

Let's see `CutMix` in action.

In [None]:
mixup = CutMix()
with Learner(dls, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=mixup) as learn:
    learn.epoch,learn.training = 0,True
    learn.dl = dls.train
    b = dls.one_batch()
    learn._split(b)
    learn('before_batch')

_,axs = plt.subplots(3,3, figsize=(9,9))
dls.show_batch(b=(mixup.xb[0],mixup.y), ctxs=axs.flatten())

## EfficientNet

In [None]:
model = timm.create_model('tf_efficientnet_b3_ns', pretrained=False)

In [None]:
model.load_state_dict(torch.load('../input/timm-pretrained-efficientnet/efficientnet/tf_efficientnet_b3_ns-9d44bf68.pth'))

## Train

In [None]:
model.classifier = nn.Linear(model.classifier.in_features, len(dls.vocab))
learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(), splitter=methodcaller('parameters'), metrics=accuracy, model_dir='/kaggle/working/models')
learn.to_fp16()
learn.freeze()

In [None]:
learn.lr_find()

In [None]:
learn.fine_tune(16, base_lr=8.3e-4, cbs=[ShowGraphCallback(), CutMix()])

## Predict

We can calculate the validation accuracy with test-time augmentation (TTA):

In [None]:
dl = dls.valid
a1, target = learn.tta(dl=dl, n=16)
pred_1 = a1.argmax(dim=1)
(pred_1==target).to(float).mean()
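
As far as I understand it, `learn.tta` averages the predictions over `n` augmented passes and then blends that average with the predictions on the un-augmented images. The sketch below shows this idea for a single image in plain Python. The function `tta_combine` is a name I made up, and the `beta=0.25` weight is my assumption about the library default, so treat it as an illustration rather than fastai's actual code.

```python
def tta_combine(aug_preds, plain_preds, beta=0.25):
    """Average the augmented predictions, then blend with the plain ones."""
    n = len(aug_preds)
    n_classes = len(plain_preds)
    mean_aug = [sum(p[c] for p in aug_preds) / n for c in range(n_classes)]
    # Linear interpolation: beta=0 keeps only the augmented average,
    # beta=1 keeps only the un-augmented prediction
    return [a + beta * (p - a) for a, p in zip(mean_aug, plain_preds)]

aug_preds = [[0.7, 0.3], [0.5, 0.5]]  # two augmented passes for one image
plain_preds = [0.4, 0.6]              # prediction without augmentation
combined = tta_combine(aug_preds, plain_preds)
pred_class = max(range(len(combined)), key=combined.__getitem__)
print(combined, pred_class)
```

In this toy case the augmented average favours class 0 strongly enough that the blended prediction stays at class 0, even though the un-augmented pass alone would pick class 1.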

Finally, we predict on the test set and write the submission file.

In [None]:
sample_df = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
sample_df.head()
sample_copy = sample_df.copy()
sample_copy['image_id'] = sample_copy['image_id'].apply(lambda x: f'test_images/{x}')
test_dl = learn.dls.test_dl(sample_copy)
test_dl.show_batch()

In [None]:
a, _ = learn.tta(dl=test_dl, n=16)
pred = a.argmax(dim=1).numpy()
sample_df['label'] = pred
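
One thing to keep in mind: `argmax` returns positions in `dls.vocab`, which holds disease names because we mapped the labels earlier, while the submission expects the original numeric labels. The direct assignment above is fine only when both orderings coincide. A defensive sketch, assuming the vocab is the sorted list of names and using a hypothetical three-class map:

```python
# Hypothetical label map and vocab; in the notebook these would come from
# label_num_to_disease_map.json and dls.vocab respectively
idx2label = {"0": "Rust", "1": "Blight", "2": "Healthy"}
vocab = sorted(idx2label.values())  # ['Blight', 'Healthy', 'Rust']

# Invert the map: disease name -> original numeric label
label2idx = {name: int(num) for num, name in idx2label.items()}

pred = [0, 2, 1]  # argmax positions in vocab
labels = [label2idx[vocab[p]] for p in pred]
print(labels)
```

Here the vocab order differs from the numeric order, so writing `pred` directly into the `label` column would give wrong labels, while the mapped `labels` are correct.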

In [None]:
sample_df.to_csv('submission.csv',index=False)