# CutMixSemantic

In [1]:
#| default_exp CutMixSemantic

In [2]:
#| export
from semantic_segmentation_augmentations.HoleMakerTechnique import *
from semantic_segmentation_augmentations.HoleMakerRandom import *
from semantic_segmentation_augmentations.HolesFilling import *
import numpy as np

In [3]:
#| hide
from nbdev.showdoc import *
from fastcore.utils import *

In [5]:
#| export
class CutMixSemantic(HolesFilling):
    "Defines the amount of holes, the technique used to make them and the probability of applying the technique."
    def __init__(self,
                 holes_num = 1, # The amount of holes to make.
                 occlusion_class = -1, # The class to remove. If -1, selects it randomly in each use.
                 hole_maker: "HoleMakerTechnique" = None, # The strategy used to make the holes.
                 p = 1.0): # The probability of applying this technique.
        # The hole-making strategy is handled by the parent `HolesFilling`.
        super().__init__(hole_maker)
        self.holes_num = holes_num
        self.occlusion_class = occlusion_class
        self.p = p

    def before_batch(self):
        "Applies the CutMix technique with semantic information (only applies the CutMix to a selected class)."
        # For every sample in the batch, `holes_num` times:
        #   1. pick a donor sample from the batch,
        #   2. cut a region (`make_hole`) out of the donor's image and mask,
        #   3. inside that region, blank out pixels of the occlusion class
        #      (image -> min value of the target image, mask -> 0),
        #   4. paste the region into the current sample (`fill_hole`).
        # `randrange` has an exclusive upper bound, so every drawn index is valid.
        from random import randrange
        if random() < self.p:
            # Both `zip` below and `self.x[...]` iterate/index the first axis,
            # so its length is the valid range for the donor index.
            batch_size = len(self.x)
            for image, mask in zip(self.x, self.y):
                for _ in range(self.holes_num):
                    # Draw the donor from the whole batch. `image.shape[0]` is the
                    # channel axis of a single (channel-first) image, not the batch size.
                    rand = randrange(batch_size)
                    other_image, other_mask = self.x[rand], self.y[rand]
                    xhole, yhole = self.make_hole(mask)
                    # NOTE(review): the random choice draws an integer in
                    # [1, number of distinct values in `mask`], which is not guaranteed
                    # to be a class actually present in `other_mask` — confirm intent.
                    occlusion_value = self.occlusion_class if self.occlusion_class != -1 else randint(1, len(mask.unique()))
                    # Clone: slicing yields views, and the in-place writes below would
                    # otherwise corrupt the donor sample inside the batch.
                    sub_image = TensorBase(other_image[:, yhole, xhole].clone())
                    sub_mask = TensorBase(other_mask[yhole, xhole].clone())
                    replacement_mask = sub_mask == occlusion_value
                    sub_image[:, replacement_mask] = torch.min(image)
                    sub_mask[replacement_mask] = 0
                    self.fill_hole(image, mask, xhole, yhole, [sub_image, sub_mask])

The default technique used to make those holes is the `HoleMakerRandom` technique.

In [6]:
# Render the API documentation (signature + docstring) of `before_batch`.
show_doc(CutMixSemantic.before_batch)

---

### CutMixSemantic.before_batch

>      CutMixSemantic.before_batch ()

Applies the CutMix technique with semantic information (only applies the CutMix to a selected class).

In [6]:
#| hide
import nbdev
# Export this notebook's `#| export` cells to the library module declared by `#| default_exp`.
nbdev.nbdev_export()