# XYMasking aug in Albumentations

The [Albumentations](https://albumentations.ai/) library has many image augmentations, but apparently none of them did the same job as [TimeMasking](https://pytorch.org/audio/main/generated/torchaudio.transforms.TimeMasking.html) and [FrequencyMasking](https://pytorch.org/audio/main/generated/torchaudio.transforms.FrequencyMasking.html) from torchaudio.

Technically, you can create such transforms for this competition with the [Lambda](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Lambda) transform, where you define a Python function that is applied to the image. So, for all practical purposes, you do not need a special transform.
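As a rough illustration of the kind of function you could hand to Lambda, here is a minimal NumPy sketch of time masking. The name `time_mask` and its parameters are hypothetical and not part of Albumentations. The `**kwargs` is there on the assumption that Lambda forwards extra keyword arguments to the function.

```python
from typing import Optional

import numpy as np


def time_mask(
    image: np.ndarray,
    max_length: int = 20,
    fill_value: float = 0.0,
    rng: Optional[np.random.Generator] = None,
    **kwargs,
) -> np.ndarray:
    """Fill one random vertical stripe (a span of time steps) with fill_value."""
    if rng is None:
        rng = np.random.default_rng()
    width = image.shape[1]
    # Stripe width is sampled from [0, max_length], capped by the image width
    length = int(rng.integers(0, min(max_length, width) + 1))
    # Left edge is sampled so that the whole stripe fits inside the image
    start = int(rng.integers(0, width - length + 1))
    out = image.copy()  # do not modify the input in place
    out[:, start:start + length] = fill_value
    return out


# Hypothetical usage with Albumentations (not executed here):
#   transform = A.Compose([A.Lambda(image=time_mask, p=1.0)])
spec = np.ones((128, 256), dtype=np.float32)
masked = time_mask(spec, max_length=20, rng=np.random.default_rng(0))
masked_cols = np.flatnonzero((masked == 0).all(axis=0))
print("masked columns:", len(masked_cols))
```

The function returns a copy, so the original spectrogram stays untouched between calls.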

But:
1. Native transforms can be serialized to and deserialized from JSON, YAML, and Python dictionaries, which is not the case for Lambda.
2. Audio and EEG data are becoming more and more popular, so it makes sense to provide something that works out of the box.
3. The architecture of native masking transforms ([CoarseDropout](https://albumentations.ai/docs/api_reference/augmentations/dropout/coarse_dropout/), [MaskDropout](https://albumentations.ai/docs/api_reference/augmentations/dropout/mask_dropout/), [GridDropout](https://albumentations.ai/docs/api_reference/augmentations/dropout/grid_dropout/)) allows the transform to be applied in more advanced settings, which I describe in this kernel.
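To make the first point concrete, here is a small standard-library sketch of why a transform defined purely by parameters survives a JSON round trip, while one that holds a Python function does not. The dictionary layout below is hypothetical and only for illustration. It is not the schema Albumentations actually uses.

```python
import json

# Hypothetical description of a native transform: plain parameters only.
native_config = {
    "name": "XYMasking",
    "params": {"num_masks_x": 1, "mask_x_length": [0, 20], "fill_value": 0, "p": 1.0},
}

# A parameter-only description round-trips through JSON unchanged
restored = json.loads(json.dumps(native_config))
print(restored == native_config)

# Hypothetical description of a Lambda-style transform: it holds a function object
lambda_config = {"name": "Lambda", "params": {"image": lambda img, **kw: img}}

try:
    json.dumps(lambda_config)
    lambda_is_serializable = True
except TypeError:
    # json cannot encode arbitrary Python callables
    lambda_is_serializable = False
print(lambda_is_serializable)
```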

## XYMasking

We added [XYMasking](https://albumentations.ai/docs/api_reference/augmentations/dropout/xy_masking/) in version 1.4.0 (released on 17 Feb 2024). It may take some time for the Kaggle Docker image to update.


In [None]:
%matplotlib inline

In [None]:
from pylab import *
import torchaudio
import torch
from pathlib import Path
import numpy as np
import pandas as pd

In [None]:
!pip install -U albumentations

In [None]:
import albumentations as A

In [None]:
def spectrogram_from_eeg4(parquet_path: Path) -> np.ndarray:
    NAMES = ["LL", "LP", "RP", "RR"]

    FEATS = [
        ["Fp1", "F7", "T3", "T5", "O1"],
        ["Fp1", "F3", "C3", "P3", "O1"],
        ["Fp2", "F8", "T4", "T6", "O2"],
        ["Fp2", "F4", "C4", "P4", "O2"],
    ]
    
    FOUR = 4
    
    # Load middle 50 seconds of EEG series
    eeg = pd.read_parquet(parquet_path)
    middle = (len(eeg) - 10_000) // 2
    eeg = eeg.iloc[middle : middle + 10_000]

    # Variable to hold spectrogram
    img = np.zeros((128, 256, 4), dtype="float32")

    for k in range(4):
        cols = FEATS[k]

        for kk in range(4):
            # Compute pair differences
            x = eeg[cols[kk]].to_numpy() - eeg[cols[kk + 1]].to_numpy()

            # Fill NaNs
            m = np.nanmean(x)
            x = np.where(np.isnan(x), m, x)  # Vectorized operation for replacing NaNs

            # Convert to tensor and add a batch dimension
            x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0)

            # Create MelSpectrogram object
            mel_spectrogram = torchaudio.transforms.MelSpectrogram(
                sample_rate=200,
                n_fft=1024,
                win_length=128,
                hop_length=len(x) // 256,
                n_mels=128,
                f_min=0,
                f_max=20,
                power=2.0,
            )

            # Compute spectrogram
            mel_spec_tensor = mel_spectrogram(x_tensor)

            # Convert power spectrogram to dB scale
            mel_spec_db_tensor = torchaudio.transforms.AmplitudeToDB(stype="power")(mel_spec_tensor)

            # Ensure the spectrogram is the expected shape
            width = min(mel_spec_db_tensor.shape[2], 256)
            mel_spec_db_tensor = mel_spec_db_tensor[:, :, :width].squeeze(0)  # Remove batch dimension

            # Standardize to -1 to 1
            mel_spec_db_np = (mel_spec_db_tensor.numpy() + 40) / 40
            img[:, :width, k] += mel_spec_db_np

        # Average the 4 montage differences
        img[:, :width, k] /= 4.0

    return img[::-1]

In [None]:
eeg_file_path = "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/2208063991.parquet"

In [None]:
img = spectrogram_from_eeg4(eeg_file_path)

In [None]:
print("Spectrogram shape = ", img.shape)
print(img.min(), img.max())

In [None]:
img = np.ascontiguousarray(img)

In [None]:
imshow(img[:, :, 0])

## One vertical stripe (time masking) with fixed width, filled with 0 

In [None]:
params1 = {
    "num_masks_x": 1,    
    "mask_x_length": 20,
    "fill_value": 0,    

}
transform1 = A.Compose([A.XYMasking(**params1, p=1)])
imshow(transform1(image=img[:, :, 0])["image"])

## One vertical stripe (time masking) with randomly sampled width, filled with 0

In [None]:
params2 = {
    "num_masks_x": 1,    
    "mask_x_length": (0, 20), # This line changed from fixed  to a range
    "fill_value": 0,
}
transform2 = A.Compose([A.XYMasking(**params2, p=1)])
imshow(transform2(image=img[:, :, 0])["image"])

### Analogous transform in torchaudio

In [None]:
# TimeMasking samples the stripe width from [0, time_mask_param)
masking = torchaudio.transforms.TimeMasking(time_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
imshow(masked.numpy())

## One horizontal stripe (frequency masking) with randomly sampled height, filled with 0

In [None]:
params3 = {    
    "num_masks_y": 1,    
    "mask_y_length": (0, 20),
    "fill_value": 0,    

}
transform3 = A.Compose([A.XYMasking(**params3, p=1)])
imshow(transform3(image=img[:, :, 0])["image"])

### Analogous transform in torchaudio

In [None]:
# FrequencyMasking samples the stripe height from [0, freq_mask_param)
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
imshow(masked.numpy())

## Several vertical and horizontal stripes

In [None]:
params4 = {    
    "num_masks_x": (2, 4),
    "num_masks_y": 5,    
    "mask_y_length": 8,
    "mask_x_length": (10, 20),
    "fill_value": 0,  

}
transform4 = A.Compose([A.XYMasking(**params4, p=1)])
imshow(transform4(image=img[:, :, 0])["image"])

## Application to images with more than 3 channels, with different fill values for different channels

The transform works with any number of channels and supports the following image shapes:

* Grayscale: (height, width)
* RGB: (height, width, 3)
* Multichannel: (height, width, num_channels)

For the value used to fill the masked regions, you can use:
* a scalar, which is applied to every channel
* a list of numbers, one per channel, so that every channel has its own fill value
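The difference between the two fill options can be sketched with plain NumPy broadcasting. This only illustrates the semantics described above on a toy array. It is not the library's implementation, and the stripe position is fixed by hand rather than sampled.

```python
import numpy as np

# Toy multichannel "image": height 4, width 6, 4 channels, all values 9
image = np.full((4, 6, 4), 9.0, dtype=np.float32)

# Scalar fill: the same value is written to every channel of the stripe
scalar_fill = image.copy()
scalar_fill[:, 1:3] = 0

# Per-channel fill: one value per channel, broadcast along the last axis
per_channel_fill = image.copy()
per_channel_fill[:, 1:3] = np.array([0, 1, 2, 3], dtype=np.float32)

print(per_channel_fill[0, 1])  # [0. 1. 2. 3.]
```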

In [None]:
params5 = {    
    "num_masks_x": (2, 4),
    "num_masks_y": 5,    
    "mask_y_length": 8,
    "mask_x_length": (20, 30),
    "fill_value": (0, 1, 2, 3),  

}
transform5 = A.Compose([A.XYMasking(**params5, p=1)])
transformed = transform5(image=img)["image"]

In [None]:
# Create a figure with 4 subplots (2 rows, 2 columns)
fig, axs = plt.subplots(2, 2)  # Pass figsize=(width, height) to adjust the size

vmin=0
vmax=3

axs[0, 0].imshow(transformed[:, :, 0], vmin=vmin, vmax=vmax)
axs[0, 0].set_title('Channel 0')
axs[0, 0].axis('off')  # Hide axes for cleaner visualization

axs[0, 1].imshow(transformed[:, :, 1], vmin=vmin, vmax=vmax)
axs[0, 1].set_title('Channel 1')
axs[0, 1].axis('off')

axs[1, 0].imshow(transformed[:, :, 2], vmin=vmin, vmax=vmax)
axs[1, 0].set_title('Channel 2')
axs[1, 0].axis('off')

axs[1, 1].imshow(transformed[:, :, 3], vmin=vmin, vmax=vmax)
axs[1, 1].set_title('Channel 3')
axs[1, 1].axis('off')

plt.tight_layout()

### Conclusion

If you have any questions, proposals, or feature requests:

1. Feel free to open an [issue in the repository](https://github.com/albumentations-team/albumentations/issues),
2. join the [Discord server](https://discord.gg/AmMnDBdzYs), or
3. write to me directly on [LinkedIn](https://www.linkedin.com/in/iglovikov/).

**And most importantly, I hope that this transform will be helpful in this competition!**