## Key Ideas

- smoother decision boundaries at multiple levels of representation
- less confident predictions
- more robust to adversarial examples (an input needs a big/perceptible change to cross a decision boundary)
- Manifold Mixup has 3 effects on training:
    - smooths decision boundaries
    - improves the arrangement of hidden representations (encourages regions of low confidence)
    - flattens representations (minimal amount of directions of variation)
- with standard training (no Manifold Mixup), decision boundaries are usually sharp and close to the data
- Central message of the paper:

```
Manifold Mixup improves the hidden representations and decision boundaries of neural networks at
multiple layers.
```

- Important factors for generalization (that Manifold Mixup acts on):
    - smoothness and margin
    - flatten representation/compression

In [20]:
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss

In [21]:
def mixup_process(out, target_reweighted, lam):
    """Mix a batch with a randomly shuffled copy of itself.

    Both the activations and their (soft / one-hot) targets are combined
    with the *same* permutation and the same coefficient, so every mixed
    activation stays aligned with its mixed target.

    Args:
        out: Tensor whose first dimension is the batch dimension.
        target_reweighted: Tensor of per-sample targets, indexed along the
            same batch dimension as ``out``.
        lam: Mixing coefficient; the original sample is weighted by ``lam``
            and its shuffled partner by ``1 - lam``.

    Returns:
        A tuple ``(mixed_out, mixed_target)``.
    """
    # Draw the permutation with torch (on the activations' device) rather
    # than NumPy, which this module never imports.
    indices = torch.randperm(out.size(0), device=out.device)
    out = out * lam + out[indices] * (1 - lam)

    # Re-use the same permutation for the targets; move it only if the
    # targets happen to live on a different device.
    target_shuffled_onehot = target_reweighted[indices.to(target_reweighted.device)]
    target_reweighted = target_reweighted * lam + target_shuffled_onehot * (1 - lam)
    return out, target_reweighted


def mixup_data(x, y, alpha):
    """Compute mixup data for one batch.

    Each sample is linearly combined with a randomly chosen partner from
    the same batch. The targets are *not* combined here; both target sets
    are returned so the caller can weight them by ``lam``.

    Args:
        x: Tensor whose first dimension is the batch dimension.
        y: Per-sample targets, indexed along the same batch dimension.
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the mixing coefficient is drawn from. A value ``<= 0`` disables
            mixing (``lam`` is fixed to 1).

    Returns:
        A tuple ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y``, ``y_b``
        is ``y`` under the sampled permutation and ``lam`` is a Python float.
    """
    if alpha > 0.:
        # Sampled with torch so the function needs neither NumPy nor a GPU.
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.

    batch_size = x.size(0)
    # Build the permutation on the input's own device instead of forcing
    # CUDA, so the function also works for CPU tensors.
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam

### Basic Manifold Mixup block

In [22]:
def forward(self, x, target=None, mixup_hidden = False,  mixup_alpha = 0.1, layer_mix=None):
    """Run the network, optionally applying Manifold Mixup at one depth.

    Args:
        x: Input batch.
        target: Per-sample targets; required when ``mixup_hidden`` is true.
        mixup_hidden: When true, mix the batch with a shuffled copy of itself
            at the depth selected by ``layer_mix``.
        mixup_alpha: ``alpha`` forwarded to ``mixup_data``.
        layer_mix: Where to mix. ``0`` = raw input, ``1``-``4`` = after
            ``layer1``-``layer4``, ``5`` = after the final linear layer.
            ``None`` picks uniformly from ``{0, 1, 2}``. Ignored when
            ``mixup_hidden`` is false.

    Returns:
        ``out`` when ``mixup_hidden`` is false; otherwise the tuple
        ``(out, y_a, y_b, lam)`` where ``lam`` is a tensor with the same
        shape as ``y_a`` filled with the mixing coefficient.

    Raises:
        ValueError: If ``mixup_hidden`` is true and ``target`` is ``None`` or
            ``layer_mix`` is not one of ``0..5``.
    """
    import random  # local import: the module's import cell does not provide it

    if self.per_img_std:
        x = per_image_standardization(x)

    mix_at = None
    if mixup_hidden:
        if target is None:
            raise ValueError("target is required when mixup_hidden=True")
        if layer_mix is None:
            layer_mix = random.randint(0, 2)
        elif layer_mix not in range(6):
            # Previously this fell through every branch and crashed later
            # with an unbound-variable error.
            raise ValueError("layer_mix must be an integer in [0, 5], got %r" % (layer_mix,))
        mix_at = layer_mix

    mix_result = None

    def _maybe_mix(hidden, point):
        # Apply mixup only at the single selected depth.
        nonlocal mix_result
        if mix_at == point:
            hidden, y_a, y_b, lam = mixup_data(hidden, target, mixup_alpha)
            mix_result = (y_a, y_b, lam)
        return hidden

    out = _maybe_mix(x, 0)
    # Feed `out` (not `x`) into the stem so that input-level mixing
    # (layer_mix == 0) is not silently discarded.
    out = F.relu(self.bn1(self.conv1(out)))
    out = _maybe_mix(self.layer1(out), 1)
    out = _maybe_mix(self.layer2(out), 2)
    out = _maybe_mix(self.layer3(out), 3)
    out = _maybe_mix(self.layer4(out), 4)
    out = F.avg_pool2d(out, 4)
    out = out.view(out.size(0), -1)
    out = _maybe_mix(self.linear(out), 5)

    if not mixup_hidden:
        return out

    y_a, y_b, lam = mix_result
    # Keep lam on the same device as the output instead of forcing CUDA.
    lam = torch.as_tensor(lam, device=out.device)
    lam = lam.repeat(y_a.size())
    return out, y_a, y_b, lam

### During Training

```python
# Select the loss computation according to the training mode in `args.train`.
# Every branch ends up with `loss = bce_loss(softmax(output), reweighted_target)`.
# NOTE(review): each call below unpacks TWO values (output, reweighted_target),
# while the `forward` shown earlier in these notes returns either a single
# tensor or a 4-tuple (out, y_a, y_b, lam) and has no `mixup` keyword --
# presumably this snippet targets a different revision of the model; confirm.
if args.train == 'mixup':
    # Input-level mixup: the model is asked to mix (`mixup= True`) and hands
    # back the mixed targets alongside the output.
    output, reweighted_target = model(input_var,target_var, mixup= True, mixup_alpha = args.mixup_alpha)
    loss = bce_loss(softmax(output), reweighted_target)#mixup_criterion(target_a, target_b, lam)
    # The string literal below is disabled code kept for reference: it mixes
    # the inputs outside the model with `mixup_data` and uses `mixup_criterion`.
    """
    mixed_input, target_a, target_b, lam = mixup_data(input, target, args.mixup_alpha)
    input_var, mixed_input_var, target_var, target_a_var, target_b_var = Variable(input),Variable(mixed_input), Variable(target), Variable(target_a), Variable(target_b)

    mixed_output = model(mixed_input_var)
    output = model(input_var)

    loss_func = mixup_criterion(target_a_var, target_b_var, lam)
    loss = loss_func(criterion, mixed_output)
    """

elif args.train== 'mixup_hidden':
    # Manifold Mixup: same call shape as above, but with `mixup_hidden= True`.
    output, reweighted_target = model(input_var, target_var, mixup_hidden= True, mixup_alpha = args.mixup_alpha)
    loss = bce_loss(softmax(output), reweighted_target)#mixup_criterion(target_a, target_b, lam)
    # Disabled reference code: consumes the 4-tuple (mixed_output, target_a,
    # target_b, lam) and builds the mixed one-hot target by hand.
    """
    input_var, target_var = Variable(input), Variable(target)
    mixed_output, target_a, target_b, lam = model(input_var, target_var, mixup_hidden = True,  mixup_alpha = args.mixup_alpha)
    output = model(input_var)

    lam = lam[0]
    target_a_one_hot = to_one_hot(target_a, args.num_classes)
    target_b_one_hot = to_one_hot(target_b, args.num_classes)
    mixed_target = target_a_one_hot * lam + target_b_one_hot * (1 - lam)
    loss = bce_loss(softmax(output), mixed_target)
    """
elif args.train == 'vanilla':
    # No mixing requested; the model is still expected to return a target
    # tensor to feed into the same loss.
    output, reweighted_target = model(input_var, target_var)
    loss = bce_loss(softmax(output), reweighted_target)
```

## References
1. [Original repo](https://github.com/vikasverma1077/manifold_mixup)