In [None]:
# default_exp data.mixed_augmentation

# Label-mixing transforms

> Callbacks that perform data augmentation by mixing samples in different ways.

In [None]:
#export
from tsai.imports import *
from tsai.utils import *
from tsai.data.external import *
from tsai.data.core import *
from tsai.data.preprocessing import *

In [None]:
#export
from torch.distributions.beta import Beta

In [None]:
#export
def _reduce_loss(loss, reduction='mean'):
    "Collapse `loss` with `.mean()` or `.sum()` as named by `reduction`; any other value returns `loss` untouched"
    if reduction == 'mean': return loss.mean()
    if reduction == 'sum': return loss.sum()
    # Anything else (e.g. 'none') means "leave the per-item losses alone"
    return loss

In [None]:
#export
class MixHandler1d(Callback):
    "A handler class for implementing mixed sample data augmentation"
    run_valid = False

    def __init__(self, alpha=0.5):
        # Symmetric Beta(alpha, alpha) that subclasses sample their mixing coefficient from
        self.distrib = Beta(alpha, alpha)
        # Start with "nothing swapped" so the restore hooks are safe even if `before_train` never ran
        self.labeled, self.stack_y = False, False

    def before_train(self):
        "Record whether targets are present and, if the loss declares `y_int`, swap in `self.lf` so the loss is mixed"
        self.labeled = bool(self.dls.d != 0)
        self.stack_y = False
        if self.labeled:
            self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
            if self.stack_y: self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf

    def after_train(self):
        "Put the original loss function back if `before_train` replaced it"
        if self.labeled and self.stack_y: self.learn.loss_func = self.old_lf

    def after_cancel_train(self):
        "Restore the original loss function when the training phase is cancelled"
        self.after_train()

    def after_cancel_fit(self):
        "Restore the original loss function when the whole fit is cancelled (in that case `after_train` may never run)"
        self.after_train()

    def lf(self, pred, *yb):
        "Stand-in loss: interpolates the unreduced loss on the shuffled targets and on the original targets by `self.lam`"
        # Outside of training the wrapped loss is used unchanged
        if not self.training: return self.old_lf(pred, *yb)
        # torch.lerp(a, b, w) = a + w * (b - a): `self.lam` is the weight given to the original targets
        with NoneReduce(self.old_lf) as lf: loss = torch.lerp(lf(pred, *self.yb1), lf(pred, *yb), self.lam)
        # Re-apply the reduction the wrapped loss declares (falls back to 'mean' when it declares none)
        return _reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))

In [None]:
#export
class MixUp1d(MixHandler1d):
    "Implementation of https://arxiv.org/abs/1710.09412"

    def __init__(self, alpha=.4):
        super().__init__(alpha)

    def before_batch(self):
        "Blend each sample of the batch (and its target, or its loss) with a randomly chosen partner from the same batch"
        bs = self.x.size(0)
        lam = self.distrib.sample((bs, ))
        # max(lam, 1 - lam) keeps the larger share on the original sample
        self.lam = torch.max(lam, 1 - lam).to(self.x.device)
        shuffle = torch.randperm(bs)
        # Build a *tuple* of shuffled inputs so that `map_zip` pairs whole tensors with whole tensors.
        # Zipping a bare tensor against the `xb` tuple iterates the tensor's rows and would pair only
        # its first row with the batch, mixing every sample with the same partner.
        xb1 = tuple(L(self.xb).itemgot(shuffle))
        self.learn.xb = tuple(L(xb1, self.xb).map_zip(torch.lerp, weight=unsqueeze(self.lam, n=self.x.ndim - 1)))
        if self.labeled:
            # Shuffled targets are kept for `lf`, which mixes the losses when `stack_y` is set
            self.yb1 = tuple((self.y[shuffle], ))
            if not self.stack_y:
                # Otherwise the targets themselves are interpolated with the same per-sample weights
                self.learn.yb = tuple(L(self.yb1, self.yb).map_zip(torch.lerp, weight=unsqueeze(self.lam, n=self.y.ndim - 1)))

MixUp1D = MixUp1d

In [None]:
from tsai.models.utils import *
from tsai.models.ResNet import *
from tsai.learner import *
# Dataset id handed to `get_UCR_data`
dsid = 'NATOPS'
X, y, splits = get_UCR_data(dsid, return_split=False)
# Presumably one entry per element of the (X, y) pair: no item transform for X, `Categorize()` for y — confirm against `get_ts_dls`
tfms = [None, Categorize()]
batch_tfms = TSStandardize()
dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)
# Model built from the dataloaders; used by the training cell that follows
model = build_model(ResNet, dls=dls)

In [None]:
# Smoke test: attach MixUp1d (alpha=0.4) as a callback and run `fit_one_cycle(1)`
learn = Learner(dls, model, cbs=MixUp1d(0.4))
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,time
0,1.874871,1.762203,00:03


In [None]:
#export
class CutMix1d(MixHandler1d):
    "Implementation of `https://arxiv.org/abs/1905.04899`"

    def __init__(self, alpha=1.):
        super().__init__(alpha)

    def before_batch(self):
        "Overwrite one random window along the last axis of the batch with the same window from shuffled samples"
        bs, *_, seq_len = self.x.size()
        # A single coefficient is drawn for the whole batch
        self.lam = self.distrib.sample((1, ))
        shuffle = torch.randperm(bs)
        xb1 = self.x[shuffle]
        x1, x2 = self.rand_bbox(seq_len, self.lam)
        # In-place paste of the window [x1, x2) taken from the shuffled copy
        self.learn.xb[0][..., x1:x2] = xb1[..., x1:x2]
        # Replace the sampled value by the fraction of the sequence actually left untouched (the window may be clamped)
        self.lam = (1 - (x2 - x1) / float(seq_len)).item()
        if self.labeled:
            # Shuffled targets are kept for `lf`, which mixes the losses when `stack_y` is set
            self.yb1 = tuple((self.y[shuffle], ))
            if not self.stack_y:
                # `self.lam` is a plain Python float at this point, so it is passed to `torch.lerp` as a scalar
                # weight; calling `unsqueeze` on it would fail for targets with more than one dimension.
                self.learn.yb = tuple(L(self.yb1, self.yb).map_zip(torch.lerp, weight=self.lam))

    def rand_bbox(self, seq_len, lam):
        "Return `(x1, x2)`, one-element long tensors bounding a random window clamped to `[0, seq_len]`"
        cut_rat = torch.sqrt(1. - lam)
        cut_seq_len = torch.round(seq_len * cut_rat).type(torch.long)
        # Explicit floor division: `//` on tensors goes through the deprecated `torch.floor_divide`
        half_len = torch.div(cut_seq_len, 2, rounding_mode='floor')

        # uniform
        cx = torch.randint(0, seq_len, (1, ))
        x1 = torch.clamp(cx - half_len, 0, seq_len)
        x2 = torch.clamp(cx + half_len, 0, seq_len)
        return x1, x2

CutMix1D = CutMix1d

In [None]:
# Repeat the data and model setup so that CutMix1d is exercised on a newly built model
dsid = 'NATOPS'
X, y, splits = get_UCR_data(dsid, return_split=False)
tfms = [None, Categorize()]
batch_tfms = TSStandardize()
dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)
model = build_model(ResNet, dls=dls)
# Smoke test: attach CutMix1d (alpha=1.) as a callback and run `fit_one_cycle(1)`
learn = Learner(dls, model, cbs=CutMix1d(1.))
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,time
0,1.794683,1.767164,00:03


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


In [None]:
#hide
# Run `create_scripts` and hand its result to `beep`
out = create_scripts(); beep(out)

<IPython.core.display.Javascript object>

Converted 000_export.ipynb.
Converted 001_utils.ipynb.
Converted 010_data.validation.ipynb.
Converted 011_data.preparation.ipynb.
Converted 012_data.external.ipynb.
Converted 012b_data.tabular.ipynb.
Converted 012c_data.mixed.ipynb.
Converted 013_data.core.ipynb.
Converted 014_data.unwindowed.ipynb.
Converted 015_data.metadatasets.ipynb.
Converted 016_data.preprocessing.ipynb.
Converted 017_data.transforms.ipynb.
Converted 018_data.mixed_augmentation.ipynb.
Converted 019_data.image.ipynb.
Converted 020_data.features.ipynb.
Converted 050_losses.ipynb.
Converted 051_metrics.ipynb.
Converted 052_learner.ipynb.
Converted 052b_tslearner.ipynb.
Converted 053_optimizer.ipynb.
Converted 060_callback.core.ipynb.
Converted 061_callback.noisy_student.ipynb.
Converted 063_callback.MVP.ipynb.
Converted 064_callback.PredictionDynamics.ipynb.
Converted 100_models.layers.ipynb.
Converted 100b_models.utils.ipynb.
Converted 100c_models.explainability.ipynb.
Converted 101_models.ResNet.ipynb.
Converted 1