This notebook shows an implementation of
the [FMix](https://arxiv.org/pdf/2002.12047.pdf) data augmentation in TensorFlow.
Fmix is a kind of Mixed Sample Data Augmentation.
It uses binary masks obtained by applying a threshold to low frequency images sampled from Fourier space.

I referred to the following:
* [Understanding and Enhancing Mixed Sample Data Augmentation](https://arxiv.org/pdf/2002.12047.pdf) -- Original paper.
* https://github.com/ecs-vlc/FMix -- Original code.
* [Getting started with 100+ flowers on TPU](https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu) -- based on this notebook.
* https://github.com/numpy/numpy/blob/master/numpy/fft/helper.py -- fftfreq() implementation.

In [None]:
import math, re, os
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from kaggle_datasets import KaggleDatasets
# Report which TensorFlow version this notebook is running on.
print("Tensorflow version " + tf.__version__)
# Short alias for tf.data's AUTOTUNE value; passed below as
# num_parallel_reads / num_parallel_calls when building datasets.
AUTO = tf.data.experimental.AUTOTUNE

# TPU or GPU detection

In [None]:
# Choose the tf.distribute strategy for this session: a TPUStrategy when a
# TPU cluster can be resolved, otherwise TensorFlow's default strategy.
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable
    # is set. On Kaggle this is always the case.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # A ValueError from the lines above is treated as "no TPU available".
    tpu = None

if not tpu:
    # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the cluster and initialise the TPU system before the
    # strategy object is created.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Competition data access

In [None]:
# GCS path reported by KaggleDatasets for the attached dataset; the TFRecord
# directory paths in the configuration cell are built on top of it.
GCS_DS_PATH = KaggleDatasets().get_gcs_path() # you can list the bucket with "!gsutil ls $GCS_DS_PATH"

# Configuration

In [None]:
# Image resolution used throughout the notebook.
# At 512 x 512 a GPU will run out of memory, so use the TPU.
# For GPU training, please select 224 x 224 px image size.
IMAGE_SIZE = [512, 512]

# Global batch size: 16 images for every replica of the strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * 16

# TFRecord directory for each available (square) image size, keyed by side length.
GCS_PATH_SELECT = {
    side: GCS_DS_PATH + '/tfrecords-jpeg-{0}x{0}'.format(side)
    for side in (192, 224, 331, 512)
}
# Directory matching the configured size (looked up by IMAGE_SIZE[0]).
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

# All training TFRecord files found under that directory.
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')

# Datasets

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float image tensor of shape [*IMAGE_SIZE, 3].

    Args:
        image_data: Encoded JPEG passed straight to tf.image.decode_jpeg.

    Returns:
        A tf.float32 tensor with values divided by 255.0 (so in the [0, 1]
        range) and reshaped to [*IMAGE_SIZE, 3].
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    # Convert to floats in the [0, 1] range.
    scaled = tf.cast(decoded, tf.float32) / 255.0
    # An explicit size is needed for TPU.
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, label) pair.

    Args:
        example: A serialized example holding an "image" bytestring and an
            int64 "class" value.

    Returns:
        A tuple (image, label): the image decoded by decode_image() and the
        class cast to tf.int32.  This is a single pair; the dataset of pairs
        is produced by mapping this function over the records.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),   # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, id) pair.

    Args:
        example: A serialized example holding an "image" bytestring and a
            string "id".  There is no "class" field: predicting the flower
            class for the test dataset is this competition's challenge.

    Returns:
        A tuple (image, idnum): the image decoded by decode_image() and the
        id string as stored in the record.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),     # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded examples from TFRecord files.

    Args:
        filenames: TFRecord file names handed to tf.data.TFRecordDataset.
        labeled: If True, records are parsed with read_labeled_tfrecord,
            otherwise with read_unlabeled_tfrecord.
        ordered: If False (the default), deterministic ordering is switched
            off to increase speed; the data is expected to be shuffled anyway.

    Returns:
        A dataset of (image, label) pairs if labeled=True, or (image, id)
        pairs if labeled=False.
    """
    parse_record = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    read_options = tf.data.Options()
    if not ordered:
        # Disable order, increase speed: data is used as soon as it streams
        # in rather than in its original order.
        read_options.experimental_deterministic = False

    # Reading from multiple files at once; reads are interleaved automatically.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(read_options).map(parse_record, num_parallel_calls=AUTO)

## Fmix

The steps of Fmix are as follows:
1. We first sample a random complex tensor for which both the real and imaginary parts are independent and Gaussian.
2. We then scale each component according to its frequency via the parameter δ such that higher values of δ correspond to increased decay of high frequency information.
3. Next, we perform an inverse Fourier transform on the complex tensor and take the real part to obtain a grey-scale image.
4. Finally, we set the top proportion of the image to have value ‘1’ and the rest to have value ‘0’ to obtain our binary mask, then mix two images by using this mask.

In [None]:
# I could not find an equivalent in Tensorflow, so I made this.
# Original: https://github.com/numpy/numpy/blob/master/numpy/fft/helper.py

def fftfreq(n, d=1.0):
    """Return discrete Fourier transform sample frequencies as a float32 tensor.

    TensorFlow port of numpy's fftfreq(): the non-negative bins
    0 .. (n - 1) // 2 come first, followed by the negative bins
    -(n // 2) .. -1, all multiplied by 1 / (n * d).

    Args:
        n: Window length (number of samples).
        d: Sample spacing; defaults to 1.0.

    Returns:
        A 1-D tf.float32 tensor of length n.
    """
    step = 1.0 / (n * d)
    non_negative_count = (n - 1) // 2 + 1
    bins = tf.concat(
        [
            tf.range(0, non_negative_count, dtype=tf.float32),
            tf.range(-(n // 2), 0, dtype=tf.float32),
        ],
        0,
    )
    return bins * step

In [None]:
# https://github.com/ecs-vlc/FMix/blob/master/fmix.py

# Parameters are 'h' and 'w' for simplicity.
def fftfreqnd(h, w):
    """Get frequency magnitudes for a discrete Fourier transform of size (h, w).

    Entry [y, x] of the result is sqrt(fx[x] ** 2 + fy[y] ** 2), where fx and
    fy are the fftfreq() bins along the width and the height.

    The original FMix code truncates fx with '[: w // 2 + 2]' or
    '[: w // 2 + 1]' because it feeds np.fft.irfftn.  That is not done here:
    tf.signal.ifft2d returns the same shape as its input and takes no output
    shape, so the full (h, w) grid is kept.

    Args:
        h: Grid height.
        w: Grid width.

    Returns:
        A tf.float32 tensor of shape (h, w).
    """
    freq_x = fftfreq(w)
    freq_y = fftfreq(h)

    # Broadcast a (1, w) row of squared x-bins against an (h, 1) column of
    # squared y-bins to get the full (h, w) grid.
    squared_x_1w = tf.expand_dims(freq_x * freq_x, 0)
    squared_y_h1 = tf.expand_dims(freq_y * freq_y, 1)
    return tf.math.sqrt(squared_x_1w + squared_y_h1)

Here is the result of 'fftfreqnd(512, 512)'.
Four corners correspond to low frequency,
and center is high frequency.
Now, the values at corners are small
and those at center are big.
The result here will be used as a divisor,
so finally values for low frequency become big and
those for high frequency are small.

In [None]:
# Frequency-magnitude grid for a 512 x 512 image.
ffreqnd = fftfreqnd(512, 512)

# Show the grid as a pseudocolor plot; the colorbar gives the value scale.
plt.pcolor(ffreqnd)
plt.colorbar()
plt.title('fftfreqnd(512, 512) result')
plt.show()

One difficult point in TensorFlow programs is
keeping track of the shape of each tensor.
To make this clear, I add a suffix to variable names.
The characters used in the suffix are:
* b -- batch
* h, w, c -- image height, width, and color.
* p -- pixels in an image (height x width)
* t -- total pixels in a batch (batch x pixels in an image)
* 1, 2 -- constant

For example:
* ...._bhwc means a rank 4 tensor whose dimensions are batch, height, width, and color.
* tf.expand_dims(..._bhw, -1) adds one dimension at the end, so the result is ..._bhw1.

In [None]:
# Largest side of the configured image size.  get_spectrum() uses it only as
# the fallback grid size when `freqs` has no dimensions to read.
IMAGE_MAX_W_H = max(IMAGE_SIZE[0], IMAGE_SIZE[1])

def get_spectrum(data_count, freqs, decay_power):
    """Sample random complex spectra whose amplitude decays with frequency.

    Args:
        data_count: Number of spectra to generate; becomes the first (batch)
            dimension of the result.  In the original FMix program this
            dimension is used for channels instead.
        freqs: Frequency-magnitude grid, e.g. the result of fftfreqnd().  Its
            static shape must be known because it sizes the random tensor.
        decay_power: Exponent of the decay; each component is scaled by
            1 / freq ** decay_power, so low frequencies get bigger weights
            and high frequencies smaller ones.

    Returns:
        A float tensor of shape [data_count, *freqs.shape, 2].  The last
        dimension holds the real and imaginary parts of a complex number.
    """
    freq_dims = [int(dim) for dim in freqs.shape]

    # Make freqs greater than 0 to avoid division by 0.  The clamp is
    # 1 / (largest side of the grid that was passed in), as in the FMix
    # reference code, so it stays correct for grids that were not built from
    # IMAGE_SIZE.  For the (h, w) = IMAGE_SIZE grid this equals
    # 1 / IMAGE_MAX_W_H.
    lowest_freq = tf.constant(1. / max(freq_dims, default=IMAGE_MAX_W_H))
    freqs_gt_zero = tf.math.maximum(freqs, lowest_freq)
    scale_hw = 1.0 / tf.math.pow(freqs_gt_zero, decay_power)

    # Generate Gaussian random numbers of data_count x height x width x 2.
    param_size = [data_count] + freq_dims + [2]
    param_bhw2 = tf.random.normal(param_size)

    # Make a spectrum by multiplying scale and param.  The scale gets a
    # leading dimension for the batch and a trailing one for the
    # real/imaginary part so that it broadcasts.
    scale_1hw1 = tf.expand_dims(scale_hw, -1)[ tf.newaxis, : ]
    spectrum_bhw2 = scale_1hw1 * param_bhw2
    return spectrum_bhw2

In [None]:
def make_low_freq_images(data_count, decay):
    """Generate low-frequency grey-scale images scaled to the [0, 1] range.

    A random spectrum from get_spectrum() on the IMAGE_SIZE frequency grid is
    inverse-Fourier-transformed, the real part is kept, and every image is
    then min-max scaled independently.

    Args:
        data_count: Number of images to generate.
        decay: Decay power forwarded to get_spectrum().

    Returns:
        A float tensor of shape [data_count, *IMAGE_SIZE]; in each image the
        smallest value is 0 and the largest is 1.
    """
    freq_grid = fftfreqnd(*IMAGE_SIZE)
    spectrum_bhw2 = get_spectrum(data_count, freq_grid, decay)

    # The last dimension holds the real and imaginary parts.
    complex_spectrum_bhw = tf.complex(spectrum_bhw2[..., 0], spectrum_bhw2[..., 1])
    grey_bhw = tf.math.real(tf.signal.ifft2d(complex_spectrum_bhw))

    # Scale the values from 0 to 1: shift each image so its minimum is 0,
    # then divide by the maximum of the shifted image.
    lowest_b11 = tf.reduce_min(grey_bhw, axis=(-2, -1), keepdims=True)
    shifted_bhw = grey_bhw - lowest_b11
    highest_b11 = tf.reduce_max(shifted_bhw, axis=(-2, -1), keepdims=True)
    return shifted_bhw / highest_b11

In [None]:
def plot_as_images(tf_0_1, title, figsize=(12, 2), rows=1):
    """Draw a batch tensor as a grid of images.

    Args:
        tf_0_1: Tensor of shape [batch, height, width] with an optional
            trailing color dimension.  Expected values are from 0 to 1.
        title: Title prefix; the 1-based image number is appended to it.
        figsize: Size passed to plt.figure().
        rows: Number of rows in the grid; the column count is
            ceil(batch / rows).
    """
    plt.figure(figsize=figsize)

    # If the shape is 'bhw1', then make it 'bhw'.
    if tf_0_1.ndim == 4 and tf_0_1.shape[-1] == 1:
        tf_0_1 = tf.squeeze(tf_0_1, axis=-1)

    pixels_0_255 = (tf_0_1.numpy() * 255).astype(np.uint8)
    total = pixels_0_255.shape[0]
    columns = (total + rows - 1) // rows

    for position, pixels in enumerate(pixels_0_255, start=1):
        plt.subplot(rows, columns, position)
        plt.imshow(Image.fromarray(pixels), cmap='gray')
        plt.title(title + str(position))
        plt.axis("off")
    plt.show()

Here are some samples of low frequency images.

In [None]:
# Generate 5 low-frequency grey-scale images with decay power 3.0
# and show them in a single row.
lfimages = make_low_freq_images(5, 3.0)
plot_as_images(lfimages, "Low Freq Image ")

In [None]:
# Number of pixels in one image (height x width).
IMAGE_PIXEL_COUNT = IMAGE_SIZE[0] * IMAGE_SIZE[1]

def make_binary_masks(data_count, low_freq_images_bhw, mix_ratios_b):
    """Threshold grey-scale images into binary masks.

    The goal is "top proportion of the image to have value '1' and the rest
    to have value '0'".  Pixels of each image are ranked from the largest
    value to the smallest, and tf.scatter_nd() writes 1.0 to the top-ranked
    pixels and 0.0 to the others.

    Args:
        data_count: Number of images in the batch.
        low_freq_images_bhw: Grey-scale images with IMAGE_PIXEL_COUNT pixels
            each, e.g. the result of make_low_freq_images().
        mix_ratios_b: One ratio per image; a pixel is set to 1.0 when its rank
            position on a 0.0 .. 1.0 scale is <= the ratio.

    Returns:
        A float tensor of shape [data_count, *IMAGE_SIZE, 1] holding 0.0 / 1.0.
    """
    # For each image, pixel positions ordered from the largest value to the
    # smallest, then flattened to 1D to pair with the image numbers.
    flat_images_bp = tf.reshape(low_freq_images_bhw, [data_count, -1])
    ranked_pixels_bp = tf.argsort(flat_images_bp, axis=-1, direction='DESCENDING', stable=True)
    pixel_ids_t = tf.reshape(ranked_pixels_bp, [-1])

    # Image number of every entry, which looks like
    # '[ 0 ... 0 1 ... 1 ..... data_count-1 ... data_count-1]'
    image_ids_t = tf.repeat(tf.range(data_count, dtype=tf.int32), IMAGE_PIXEL_COUNT, axis=-1)

    # (image, pixel) index pairs for scatter_nd.
    scatter_indices_t2 = tf.stack([image_ids_t, pixel_ids_t], axis=1)

    # Rank positions on a 0.0 .. 1.0 scale, one identical row per image:
    # [[ 0.0 ... 1.0 ]   \
    #   ...               | data_count
    #  [ 0.0 ... 1.0 ]]  /
    rank_scale_1p = tf.linspace(0.0, 1.0, IMAGE_PIXEL_COUNT)[ tf.newaxis, : ]
    rank_scale_bp = tf.repeat(rank_scale_1p, data_count, axis=0)

    # 1.0 for the leading ranks up to each image's mix ratio, 0.0 for the rest:
    # [[ 1.0 1.0 ... 0.0 ]   \    <-- ranks up to mix_ratios_b[0] are 1.0
    #   ...                   | data_count
    #  [ 1.0 1.0 ... 0.0 ]]  /    <-- ranks up to mix_ratios_b[data_count - 1] are 1.0
    in_top_bp = rank_scale_bp <= mix_ratios_b[ :, tf.newaxis]
    scatter_updates_t = tf.reshape(tf.where(in_top_bp, 1.0, 0.0), [-1])

    # Write each value to its pixel position, then restore the image shape.
    masks_bp = tf.scatter_nd(scatter_indices_t2, scatter_updates_t, [data_count, IMAGE_PIXEL_COUNT])
    return tf.reshape(masks_bp, [data_count, *IMAGE_SIZE, 1])

Here are some binary mask samples.

In [None]:
# Generate 5 new low-frequency images and threshold them with a different
# mix ratio per image (0.1, 0.3, 0.5, 0.7, 0.9), then show the masks.
lfimages = make_low_freq_images(5, 3.0)
bin_masks = make_binary_masks(5, lfimages, tf.constant([ 0.1, 0.3, 0.5, 0.7, 0.9 ]))
plot_as_images(bin_masks, "Binary mask ")

In [None]:
def do_mix(mix_ratios, orig_batch, mixing_batch):
    """Blend two batches element-wise with the given ratios.

    Args:
        mix_ratios: Weight of orig_batch; with a 0/1 mask, 1 keeps the
            original value and 0 takes the mixing value.
        orig_batch: Original batch data.
        mixing_batch: Batch data to mix in.

    Returns:
        mix_ratios * orig_batch + (1.0 - mix_ratios) * mixing_batch.
    """
    from_orig = mix_ratios * orig_batch
    from_mixing = (1.0 - mix_ratios) * mixing_batch
    return from_orig + from_mixing

The following are samples of low frequency gray-scale images,
binary masks, images to be mixed, and Fmix images.

In [None]:
# Build 5 binary masks that each keep a mix ratio of 0.5.
lfimages = make_low_freq_images(5, 3.0)
bin_masks = make_binary_masks(5, lfimages, tf.constant([ 0.5 ] * 5))
# Take two consecutive batches of 5 training images; `dataset` is the
# iterator over the batched dataset.
dataset = iter(load_dataset(TRAINING_FILENAMES, labeled=True).batch(5))
images_1, labels_1 = next(dataset)
images_2, labels_2 = next(dataset)
# Where the mask is 1 the pixel comes from images_1, where it is 0 from images_2.
mixed_images = do_mix(bin_masks, images_1, images_2)

# Show every stage, one row of 5 images each.
plot_as_images(lfimages, "Gray-scale ")
plot_as_images(bin_masks, "Binary mask ")
plot_as_images(images_1, "Image 1-")
plot_as_images(images_2, "Image 2-")
plot_as_images(mixed_images, "Fmix ")

In [None]:
import tensorflow_probability as tfp
tfd = tfp.distributions

# Hyper parameters: FMIX_ALPHA is used for both parameters of the beta
# distribution that generates mix ratios, FMIX_DECAY is the decay power of
# the low frequency images.
FMIX_ALPHA = 1.0
FMIX_DECAY = 3.0
fmix_beta_dist = tfd.Beta(FMIX_ALPHA, FMIX_ALPHA)

def fmix(orig_images):
    """Apply Fmix to a batch, mixing every image with another from the same batch.

    Args:
        orig_images: Batch of images whose first dimension is the batch.

    Returns:
        The mixed images.  Only images are returned; the shuffle order and
        the mix ratios used are not.
    """
    # orig_images.shape[0] is None when making computational graph,
    # so the batch length has to be a tensor.
    batch_len = tf.shape(orig_images)[0]

    # Randomly select the partner image of every image by shuffling indices.
    partner_indices = tf.random.shuffle(tf.range(batch_len, dtype=tf.int32))
    partner_images = tf.gather(orig_images, partner_indices)

    # One mix ratio per image, drawn from the beta distribution.
    ratios_b = fmix_beta_dist.sample([batch_len])

    # Generate binary masks, then mix images.
    grey_bhw = make_low_freq_images(batch_len, FMIX_DECAY)
    masks_bhw1 = make_binary_masks(batch_len, grey_bhw, ratios_b)
    return do_mix(masks_bhw1, orig_images, partner_images)

Here are Fmix samples.

In [None]:
# Take one batch of 15 training images and apply fmix() to it.
dataset = iter(load_dataset(TRAINING_FILENAMES, labeled=True).batch(15))
orig_images, orig_labels = next(dataset)
mixed_images = fmix(orig_images)

# Show the 15 mixed images on a grid of 3 rows.
plot_as_images(mixed_images, "Fmix ", figsize=(12, 7), rows=3)

Thanks for reading!
If you have any questions or find mistakes, please let me know.

# History
* 2020/03/29 -- First version
* 2020/04/25 -- Fixed an issue for "data_count" in "fmix()". It should be a tensor. Otherwise it might generate an error if data count was not the predefined constant "BATCH_SIZE". 