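This notebook trains the auxiliary-loss ("aux loss") model described in the discussion [Easy Trick to Add Aux Loss](https://www.kaggle.com/c/siim-covid19-detection/discussion/263676): a segmentation network whose encoder features also feed a study-level classification head.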

In [None]:
# Select the tf.keras backend for segmentation_models; presumably read when the
# package is imported, so this must run before `import segmentation_models`.
%env SM_FRAMEWORK=tf.keras
# NOTE(review): unpinned upgrade; consider pinning a version for reproducible runs.
!pip install -U segmentation-models

In [None]:
# When True: only 2 cosine epochs, only the first fold, and only the first
# 4 TFRecord files are used.
DEBUG = False

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as L
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
import math
import pandas as pd
import tensorflow_addons as tfa
import segmentation_models as sm

# Log the TensorFlow version (steps_per_execution in make_model is not supported on 2.3).
print(tf.__version__)

In [None]:
# --- Model ---
SEG_MODEL = sm.FPN            # segmentation_models architecture constructor
BACKBONE = 'efficientnetb4'   # encoder name passed to SEG_MODEL
IMAGE_SIZE = 512              # images are reshaped to IMAGE_SIZE x IMAGE_SIZE x 3

# --- Optimisation ---
BATCH_SIZE = 128
INIT_LR = 1e-4
WARMUP_EPO = 2
COSINE_EPO = 2 if DEBUG else 28
N_EPOCHS = WARMUP_EPO + COSINE_EPO

# --- Cross-validation ---
N_FOLDS = 5
VID = 'V03'                   # version tag used in checkpoint file names
FOLD_I_LIST = [0]             # use [0, 1, 2, 3, 4] to train every fold
if DEBUG:
    FOLD_I_LIST = FOLD_I_LIST[:1]

In [None]:
# Kaggle dataset that holds the TFRecords.
train_data_name = 'siim-covid19-tfrecord-for-training'
# Fixed number of bbox slots per example in the TFRecord schema
# (presumably padded when an image has fewer boxes).
MAX_BBOXES = 8
# Study-level labels: Negative / Typical / Indeterminate / Atypical.
N_STUDY_LABELS = 4

## TPU

In [None]:
# Prefer a TPU when one can be connected to; if connecting raises ValueError,
# fall back to MirroredStrategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() 
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # otherwise detect GPUs
    strategy = tf.distribute.MirroredStrategy() # single-GPU or multi-GPU
    
print(f"Running on {strategy.num_replicas_in_sync} replicas")

In [None]:
# GCS location of the Kaggle dataset; the TFRecords are globbed from here below.
GCS_DS_PATH = KaggleDatasets().get_gcs_path(train_data_name)

GCS_DS_PATH

## Dataset

In [None]:
def decode_image(image_bytes):
    """Decode JPEG bytes into a [IMAGE_SIZE, IMAGE_SIZE, 3] tensor.

    The reshape pins a static shape and fails if the stored JPEG does not
    contain IMAGE_SIZE * IMAGE_SIZE * 3 values.
    """
    decoded = tf.image.decode_jpeg(image_bytes, channels=3)
    return tf.reshape(decoded, [IMAGE_SIZE, IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example into raw training fields.

    Returns:
        ids: [image_id, study_id] strings.
        fold: int64 cross-validation fold index.
        image: decoded image, [IMAGE_SIZE, IMAGE_SIZE, 3] (see decode_image).
        bbox_label: MAX_BBOXES string labels, one per box slot.
        bbox_pos: [left, top, right, bottom]; each is MAX_BBOXES int64 pixel
            coordinates rescaled from the x-ray size to IMAGE_SIZE and
            clipped to [0, IMAGE_SIZE - 1].
        study_label: four int64 study flags in the order Negative /
            Typical / Indeterminate / Atypical.
    """
    TFREC_FORMAT = {
        'image_id': tf.io.FixedLenFeature([], tf.string),
        'study_id': tf.io.FixedLenFeature([], tf.string),
        'fold': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'xray_height': tf.io.FixedLenFeature([], tf.int64),
        'xray_width': tf.io.FixedLenFeature([], tf.int64),
        'bbox_label': tf.io.FixedLenFeature([MAX_BBOXES], tf.string),
        'confidence': tf.io.FixedLenFeature([MAX_BBOXES], tf.float32),
        'left': tf.io.FixedLenFeature([MAX_BBOXES], tf.float32),
        'top': tf.io.FixedLenFeature([MAX_BBOXES], tf.float32),
        'right': tf.io.FixedLenFeature([MAX_BBOXES], tf.float32),
        'bottom': tf.io.FixedLenFeature([MAX_BBOXES], tf.float32),
        'Negative for Pneumonia': tf.io.FixedLenFeature([], tf.int64),
        'Typical Appearance': tf.io.FixedLenFeature([], tf.int64),
        'Indeterminate Appearance': tf.io.FixedLenFeature([], tf.int64),
        'Atypical Appearance': tf.io.FixedLenFeature([], tf.int64),
    }
    
    example = tf.io.parse_single_example(example, TFREC_FORMAT)
    ids = [example['image_id'], example['study_id']]
    fold = example['fold']
    image = decode_image(example['image'])
    bbox_label = example['bbox_label']

    xray_height = tf.cast(example['xray_height'], dtype=tf.float32)
    xray_width = tf.cast(example['xray_width'], dtype=tf.float32)
    # Box coordinates are presumably in original x-ray pixels: scale them to
    # IMAGE_SIZE, round to the nearest pixel and keep them inside the image.
    def _calc_bbox_pos(example_pos, xray_size):
        pos = example_pos * IMAGE_SIZE / xray_size
        pos = tf.cast(tf.math.round(pos), dtype=tf.int64)
        pos = tf.clip_by_value(pos, 0, IMAGE_SIZE - 1)
        return pos
    left = _calc_bbox_pos(example['left'], xray_width)
    top = _calc_bbox_pos(example['top'], xray_height)
    right = _calc_bbox_pos(example['right'], xray_width)
    bottom = _calc_bbox_pos(example['bottom'], xray_height)
    bbox_pos = [left, top, right, bottom]

    study_label = [
        example['Negative for Pneumonia'], example['Typical Appearance'],
        example['Indeterminate Appearance'], example['Atypical Appearance']]
    return ids, fold, image, bbox_label, bbox_pos, study_label

In [None]:
opacity_label = tf.constant(b"opacity")

def process_bbox(ids, fold, image, bbox_label, bbox_pos, study_label):
    """Replace per-box string labels with float flags.

    Each entry of ``bbox_label`` becomes 1.0 when it equals b"opacity" and
    0.0 otherwise (e.g. "none"); all other fields pass through unchanged.
    """
    is_opacity = tf.equal(bbox_label, opacity_label)
    return (ids, fold, image, tf.cast(is_opacity, dtype=tf.float32),
            bbox_pos, study_label)

In [None]:
def make_mask(ids, fold, image, bbox_label, bbox_pos, study_label):
    """Rasterise the bounding boxes into a binary [IMAGE_SIZE, IMAGE_SIZE, 1] mask.

    Each of the MAX_BBOXES slots becomes a filled rectangle, inclusive of its
    right/bottom edge, multiplied by its float ``bbox_label`` (0.0 or 1.0 after
    process_bbox). The slots are summed and capped at 1.0, so overlapping
    boxes stay binary. Per-box labels and positions are dropped from the
    output: ``(ids, fold, image, mask, study_label)``.
    """
    lefts   = bbox_pos[0]
    tops    = bbox_pos[1]
    rights  = bbox_pos[2]
    bottoms = bbox_pos[3]

    def _make_one_mask(i):
        # Rectangle spans rows tops[i]..bottoms[i] and cols lefts[i]..rights[i].
        # NOTE(review): assumes right >= left - 1 and bottom >= top - 1 for every
        # slot; otherwise tf.ones receives a negative size. Verify the data.
        mask_height = bottoms[i] - tops[i] + 1
        mask_width = rights[i] - lefts[i] + 1
        mask_shape = [mask_height, mask_width]
        mask = bbox_label[i] * tf.ones(mask_shape, dtype=tf.float32)

        # Zero-pad the rectangle out to the full IMAGE_SIZE x IMAGE_SIZE canvas.
        paddings = [
            [tops[i], IMAGE_SIZE - bottoms[i] - 1],
            [lefts[i], IMAGE_SIZE - rights[i] - 1]]
        mask = tf.pad(mask, paddings, mode='CONSTANT')
        return mask

    num_masks_rng = tf.range(MAX_BBOXES, dtype=tf.int64)
    masks = tf.map_fn(
        _make_one_mask, num_masks_rng,
        fn_output_signature=tf.float32)
    # Union of all boxes: sum, then cap at 1.0 where boxes overlap.
    mask = tf.math.reduce_sum(masks, axis=0)
    mask = tf.math.minimum(1.0, mask)
    mask = tf.reshape(mask, [IMAGE_SIZE, IMAGE_SIZE, 1])
    return ids, fold, image, mask, study_label

In [None]:
def load_dataset(filenames, num_parallel_calls=tf.data.experimental.AUTOTUNE):
    """Build the raw per-image dataset from TFRecord files.

    Pipeline: parse (read_tfrecord) -> float bbox flags (process_bbox)
    -> rasterised bbox mask (make_mask). Each element is
    ``(ids, fold, image, mask, study_label)``.

    Args:
        filenames: TFRecord path(s) accepted by tf.data.TFRecordDataset.
        num_parallel_calls: parallelism of the three map stages. Defaults to
            AUTOTUNE; pass None for the previous fully sequential behaviour.
            tf.data maps are deterministic by default, so element order is the
            same either way.

    Returns:
        tf.data.Dataset of parsed examples.
    """
    # Files are read one after another so element order follows `filenames`.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=None)
    # JPEG decoding and per-box mask rasterisation dominate the cost, so the
    # map stages are parallelised instead of running one element at a time.
    for stage in (read_tfrecord, process_bbox, make_mask):
        dataset = dataset.map(stage, num_parallel_calls=num_parallel_calls)
    return dataset

In [None]:
tfrec_file_names = sorted(tf.io.gfile.glob(GCS_DS_PATH + '/*.tfrec'))
# An empty glob would silently give an empty dataset and fail much later.
assert tfrec_file_names, f"no .tfrec files found under {GCS_DS_PATH}"
# In DEBUG mode use only the first 4 files. Keep a flat list of paths; the
# previous version wrapped the slice in an extra list ([[...]]).
if DEBUG:
    tfrec_file_names = tfrec_file_names[:4]
raw_ds = load_dataset(tfrec_file_names)

print(raw_ds)

## Fold Information

In [None]:
# One pass over the raw data to record the study_id and fold of every image.
study_id_list = []
fold_batch_list = []
for ids_batch, fold_batch, _, _, _ in raw_ds.batch(256):
    print('.', end='', flush=True)
    # Column 1 of each ids row is the study_id (column 0 is the image_id).
    study_id_list.extend(
        raw_id.decode('utf-8') for raw_id in ids_batch[:, 1].numpy())
    fold_batch_list.append(fold_batch.numpy())
print()

fold_array = np.concatenate(fold_batch_list, axis=0)
fold_info_df = pd.DataFrame({'study_id': study_id_list, 'fold': fold_array})

fold_info_df

In [None]:
# Images per fold; sorting by count first is unnecessary since we sort by fold id.
fold_info_df['fold'].value_counts(sort=False).sort_index()

In [None]:
def get_train_count(fold_i):
    """Number of images outside fold ``fold_i`` (the training split).

    Reads the module-level ``fold_info_df``. Uses a vectorised Series.sum()
    rather than the builtin sum(), which iterated the boolean Series in Python.
    """
    return int((fold_info_df['fold'] != fold_i).sum())

def get_val_count(fold_i):
    """Number of images in fold ``fold_i`` (the validation split)."""
    return int((fold_info_df['fold'] == fold_i).sum())

def get_val_study_ids(fold_i):
    """study_id of every image in fold ``fold_i`` (one entry per image), as a NumPy array."""
    fold_mask = (fold_info_df['fold'] == fold_i)
    val_study_ids = fold_info_df.loc[fold_mask, 'study_id']
    return val_study_ids.values

## Data Augmentation

In [None]:
def image_to_float_0_1(image):
    """Cast an image to float32 and divide by 255 (uint8 [0, 255] -> [0, 1])."""
    return tf.cast(image, dtype=tf.float32) / 255.0

In [None]:
def check_aug(aug_fun, with_mask):
    """Preview an augmentation on 10 copies of the first raw example.

    Args:
        aug_fun: callable (images, masks) -> (aug_images, aug_masks).
        with_mask: if True, also plot the augmented masks in a second figure.
    """
    rows, cols = 2, 5
    n_imgs = rows * cols

    # The same example repeated, so each panel shows an independent random draw.
    sample_batch = raw_ds.take(1).repeat(n_imgs).batch(n_imgs)
    _, _, images, masks, _ = next(iter(sample_batch))
    images = image_to_float_0_1(images)

    def _plot_grid(panels, **imshow_kwargs):
        for position, panel in enumerate(panels, start=1):
            plt.subplot(rows, cols, position)
            plt.imshow(panel, **imshow_kwargs)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    plt.figure(figsize=(12, 4))
    aug_images, aug_masks = aug_fun(images, masks)
    _plot_grid(aug_images)

    if with_mask:
        plt.figure(figsize=(12, 4))
        _plot_grid(aug_masks, cmap='gray')

In [None]:
def random_int(shape=[], minval=0, maxval=1):
    """Draw int32 values uniformly from [minval, maxval)."""
    return tf.random.uniform(shape, minval=minval, maxval=maxval, dtype=tf.int32)

def random_float(shape=[], minval=0.0, maxval=1.0):
    """Draw float32 values uniformly from [minval, maxval)."""
    return tf.random.uniform(shape, minval=minval, maxval=maxval, dtype=tf.float32)

In [None]:
class BaseAug():
    """Base class for batch augmentations applied per example.

    Subclasses implement ``aug_data(image, mask)`` for a single example.
    Calling an instance on a batch applies ``aug_data`` to each example
    independently with probability ``p`` and leaves the others unchanged.
    """
    def __init__(self, p):
        # Probability of augmenting each example.
        self.p = p
        
    def __call__(self, images, masks):
        """Augment a batch.

        Args:
            images: float32 batch of images, indexed along axis 0.
            masks: float32 batch of masks aligned with ``images``.

        Returns:
            (aug_images, aug_masks), both float32 with the same batch size.
        """
        def _aug_one_data(i):
            image = images[i]
            mask = masks[i]
            # random_float() samples [0, 1), so `rnd < p` fires with
            # probability exactly p (never for p=0, always for p=1).
            rnd = random_float()
            return tf.cond(
                rnd < self.p,
                lambda: self.aug_data(image, mask),
                lambda: self._no_aug(image, mask))
        
        batch = tf.shape(images)[0]
        batch_rng = tf.range(batch, dtype=tf.int64)
        aug_images, aug_masks = tf.map_fn(
            _aug_one_data, batch_rng,
            fn_output_signature=(tf.float32, tf.float32))
        return aug_images, aug_masks

    def aug_data(self, image, mask):
        """Augment one (image, mask) pair; subclasses must override this."""
        # `NotImplemented` is a constant, not an exception class: calling it
        # raised a confusing TypeError instead of NotImplementedError.
        raise NotImplementedError("aug_data() needs to implement")
        
    def _no_aug(self, image, mask):
        """Identity branch for examples that are not augmented."""
        return image, mask

In [None]:
def mirror_boundary(v, max_v):
    """Reflect coordinates into [0, max_v - 1] without repeating the edge pixel.

    The reflection has period 2 * (max_v - 1). For max_v = 512: v mod 1022 lies
    in [0, 1022); subtracting 511 centres it on 0; taking -|.| folds it to
    [-511, 0]; adding 511 maps it into [0, 511].
    """
    period = max_v * 2.0 - 2.0
    half = max_v - 1.0
    folded = tf.math.abs(v % period - half)
    return -1.0 * folded + max_v - 1.0

def clip_boundary(v, max_v):
    """Clamp coordinates to the valid index range [0, max_v - 1]."""
    return tf.clip_by_value(v, clip_value_min=0.0, clip_value_max=max_v - 1.0)

def interpolate_bilinear(image, map_x, map_y):
    """Bilinearly sample ``image`` at fractional coordinates.

    ``map_x`` indexes axis 0 of ``image`` and ``map_y`` indexes axis 1 (see
    ``_gather``). Both maps are 2-D and must already lie inside the image,
    e.g. after mirror_boundary / clip_boundary. Returns [h, w, channels].
    """
    def _gather(img, idx_x, idx_y):
        # [h, w, 2] integer index grid; gather_nd reads img[idx_x, idx_y, :].
        indices = tf.cast(tf.stack([idx_x, idx_y], axis=-1), dtype=tf.int32)
        return tf.gather_nd(img, indices)

    x_lo, x_hi = tf.math.floor(map_x), tf.math.ceil(map_x)
    y_lo, y_hi = tf.math.floor(map_y), tf.math.ceil(map_y)

    ll = _gather(image, x_lo, y_lo)
    lr = _gather(image, x_hi, y_lo)
    ul = _gather(image, x_lo, y_hi)
    ur = _gather(image, x_hi, y_hi)

    # Interpolate along x on both y neighbours, then blend along y.
    weight_x = (map_x % 1.0)[..., tf.newaxis]
    lower = (lr - ll) * weight_x + ll
    upper = (ur - ul) * weight_x + ul

    weight_y = (map_y % 1.0)[..., tf.newaxis]
    return (upper - lower) * weight_y + lower

def remap(image, height, width, map_x, map_y, mode):
    """Resample ``image`` at source coordinates (map_x, map_y).

    Args:
        image: [height, width, channels] tensor.
        height, width: output size; the maps are reshaped to [height, width].
        map_x, map_y: source coordinates; map_x indexes axis 0 and map_y axis 1
            (see interpolate_bilinear).
        mode: 'mirror' reflects out-of-range coordinates back into the image;
            'constant' samples clamped coordinates and then zeroes every output
            pixel whose source lies outside the image.
    """
    assert \
        mode in ('mirror', 'constant'), \
        "mode is neither 'mirror' nor 'constant'"

    height_f = tf.cast(height, dtype=tf.float32)
    width_f = tf.cast(width, dtype=tf.float32)
    grid_x = tf.reshape(map_x, shape=[height, width])
    grid_y = tf.reshape(map_y, shape=[height, width])

    boundary = mirror_boundary if mode == 'mirror' else clip_boundary
    sampled = interpolate_bilinear(
        image, boundary(grid_x, width_f), boundary(grid_y, height_f))

    if mode == 'mirror':
        return sampled

    inside = tf.math.reduce_all(
        tf.stack([0.0 <= grid_x, grid_x < width_f,
                  0.0 <= grid_y, grid_y < height_f]),
        axis=0)                                          # [h, w]
    return tf.where(inside[:, :, tf.newaxis], sampled, 0.0)

### Transpose

In [None]:
class Transpose(BaseAug):
    """Swap the two spatial axes of image and mask (reflect about the main diagonal)."""

    def aug_data(self, image, mask):
        swap_hw = [1, 0, 2]
        return tf.transpose(image, perm=swap_hw), tf.transpose(mask, perm=swap_hw)

In [None]:
# Transpose half of the examples; preview images and masks.
transpose = Transpose(p=0.5)
check_aug(aug_fun=transpose, with_mask=True)

### VerticalFlip

In [None]:
class VerticalFlip(BaseAug):
    """Flip image and mask upside down together."""

    def aug_data(self, image, mask):
        flipped = [tf.image.flip_up_down(t) for t in (image, mask)]
        return flipped[0], flipped[1]

In [None]:
# Flip half of the examples vertically; preview images and masks.
vertical_flip = VerticalFlip(p=0.5)
check_aug(aug_fun=vertical_flip, with_mask=True)

### HorizontalFlip

In [None]:
class HorizontalFlip(BaseAug):
    """Mirror image and mask left-to-right together."""

    def aug_data(self, image, mask):
        flipped = [tf.image.flip_left_right(t) for t in (image, mask)]
        return flipped[0], flipped[1]

In [None]:
# Flip half of the examples horizontally; preview images and masks.
horizontal_flip = HorizontalFlip(p=0.5)
check_aug(aug_fun=horizontal_flip, with_mask=True)

### RandomBrightness

In [None]:
class RandomBrightness(BaseAug):
    """Add a random brightness delta from [-max_delta, max_delta) to the image; mask untouched."""

    def __init__(self, max_delta, p):
        super().__init__(p)
        self.max_delta = max_delta

    def aug_data(self, image, mask):
        return tf.image.random_brightness(image, self.max_delta), mask

In [None]:
# Brightness jitter of up to +/-0.2 on 75% of the examples (images only).
random_brightness = RandomBrightness(p=0.75, max_delta=0.2)
check_aug(aug_fun=random_brightness, with_mask=False)

### RandomContrast

In [None]:
class RandomContrast(BaseAug):
    """Scale image contrast by a random factor from [lower, upper); mask untouched."""

    def __init__(self, lower, upper, p):
        super().__init__(p)
        self.lower, self.upper = lower, upper

    def aug_data(self, image, mask):
        return tf.image.random_contrast(image, self.lower, self.upper), mask

In [None]:
# NOTE(review): the contrast factor is drawn from [0.2, 0.8), so contrast is
# only ever reduced (down to 20%). If an albumentations-style limit of 0.2
# was intended, the range would be lower=0.8, upper=1.2. Confirm.
random_contrast = RandomContrast(p=0.75, lower=0.2, upper=0.8)
check_aug(aug_fun=random_contrast, with_mask=False)

### Blur

In [None]:
class Blur(BaseAug):
    """Gaussian blur with a random square kernel side in [3, blur_limit]; mask untouched."""

    def __init__(self, blur_limit, p):
        super().__init__(p)
        self.blur_limit = blur_limit

    def aug_data(self, image, mask):
        # Any integer side in [3, blur_limit] can be drawn, even sizes included.
        side = random_int([], 3, self.blur_limit + 1)
        blurred = tfa.image.gaussian_filter2d(image, filter_shape=(side, side))
        # Re-assert the static image shape after filtering with a dynamic kernel.
        return tf.reshape(blurred, [IMAGE_SIZE, IMAGE_SIZE, 3]), mask

In [None]:
# p=1.0 only affects this preview: OneOf below calls aug_data directly.
blur = Blur(p=1.0, blur_limit=5)
check_aug(aug_fun=blur, with_mask=False)

### MedianBlur

In [None]:
class MedianBlur(BaseAug):
    """Median filter with a square kernel of side ``blur_limit``; mask untouched.

    ``blur_limit`` used to be stored but ignored (the kernel was hard-coded to
    3x3). It now sets the kernel side; with blur_limit=3 the behaviour is
    unchanged.
    """
    def __init__(self, blur_limit, p):
        super(MedianBlur, self).__init__(p)
        self.blur_limit = blur_limit
        
    def aug_data(self, image, mask):
        # Python ints keep the kernel size static at trace time (a per-call
        # random tensor size is not used here).
        side = int(self.blur_limit)
        filter_shape = (side, side)
        aug_image = tfa.image.median_filter2d(
            image, filter_shape=filter_shape)
        aug_image = tf.reshape(aug_image, [IMAGE_SIZE, IMAGE_SIZE, 3])
        return aug_image, mask

In [None]:
# p=1.0 only affects this preview: OneOf below calls aug_data directly.
median_blur = MedianBlur(p=1.0, blur_limit=3)
check_aug(aug_fun=median_blur, with_mask=False)

### OneOf

In [None]:
class OneOf(BaseAug):
    """Apply exactly one of two augmentations, each chosen with ~50% odds.

    When this transform fires (probability ``p``), ``trans1.aug_data`` runs if a
    uniform draw is <= 0.5 and ``trans2.aug_data`` otherwise. The wrapped
    transforms' own ``p`` values are not consulted.
    """
    def __init__(self, trans1, trans2, p):
        super().__init__(p)
        self.trans1, self.trans2 = trans1, trans2

    def aug_data(self, image, mask):
        use_first = random_float() <= 0.5
        # Unpack explicitly so the result is always a 2-tuple like _no_aug's.
        aug_image, aug_mask = tf.cond(
            use_first,
            lambda: self.trans1.aug_data(image, mask),
            lambda: self.trans2.aug_data(image, mask))
        return aug_image, aug_mask

In [None]:
# 70% of the examples get either a Gaussian or a median blur.
one_of_blur_median_blur = OneOf(trans1=blur, trans2=median_blur, p=0.7)
check_aug(aug_fun=one_of_blur_median_blur, with_mask=False)

### JpegCompression

In [None]:
class JpegCompression(BaseAug):
    """Re-encode the image as JPEG at a random quality in [quality_lower, quality_upper]."""

    def __init__(self, quality_lower, quality_upper, p):
        super().__init__(p)
        self.quality_lower = quality_lower
        self.quality_upper = quality_upper

    def aug_data(self, image, mask):
        quality = random_int([], self.quality_lower, self.quality_upper + 1)
        compressed = tf.image.adjust_jpeg_quality(image, quality)
        return tf.reshape(compressed, [IMAGE_SIZE, IMAGE_SIZE, 3]), mask

In [None]:
# Half of the examples are re-encoded at JPEG quality 85-95.
jpeg_compression = JpegCompression(p=0.5, quality_lower=85, quality_upper=95)
check_aug(aug_fun=jpeg_compression, with_mask=False)

### OpticalDistortion

In [None]:
def initUndistortRectifyMap(height, width, k, dx, dy):
    """Build sampling maps for a radial lens distortion.

    Appears modelled on OpenCV's initUndistortRectifyMap, with focal lengths
    fx = width and fy = height, principal point (width/2 + dx, height/2 + dy),
    radial coefficients k1 = k2 = k and no tangential terms. Normalised pixel
    coordinates are taken around the image centre ((size - 1) / 2).

    Returns:
        (map_x, map_y) source coordinates for remap.

    NOTE: tf.meshgrid's default 'xy' indexing makes ``u`` (the width/x
    coordinate) vary along axis 0 of the maps, which matches remap /
    interpolate_bilinear indexing image axis 0 with map_x. The maps come out
    [width, height] and remap reshapes them to [height, width], so this
    presumably assumes square images. TODO confirm if non-square inputs occur.
    """
    height = tf.cast(height, dtype=tf.float32)
    width = tf.cast(width, dtype=tf.float32)
    
    # Camera intrinsics (focal length and principal point).
    f_x = width
    f_y = height
    c_x = width * 0.5 + dx
    c_y = height * 0.5 + dy
    
    # "New" camera used to normalise output pixel coordinates.
    f_dash_x = f_x
    c_dash_x = (width - 1.0) * 0.5
    f_dash_y = f_y
    c_dash_y = (height - 1.0) * 0.5

    h_rng = tf.range(height, dtype=tf.float32)
    w_rng = tf.range(width, dtype=tf.float32)
    v, u = tf.meshgrid(h_rng, w_rng)
    
    # Normalised coordinates (no rotation/rectification, so x' = x, y' = y).
    x = (u - c_dash_x) / f_dash_x
    y = (v - c_dash_y) / f_dash_y
    x_dash = x
    y_dash = y
    
    # Radial distortion: scale by 1 + k*r^2 + k*r^4.
    r_2 = x_dash * x_dash + y_dash * y_dash
    r_4 = r_2 * r_2
    x_dash_dash = x_dash * (1 + k*r_2 + k*r_4)
    y_dash_dash = y_dash * (1 + k*r_2 + k*r_4)

    # Back to pixel coordinates with the original camera.
    map_x = x_dash_dash * f_x + c_x
    map_y = y_dash_dash * f_y + c_y
    return map_x, map_y

In [None]:
class OpticalDistortion(BaseAug):
    """Random radial lens distortion applied identically to image and mask."""

    def __init__(self, distort_limit, shift_limit, p=1.0):
        super().__init__(p)
        self.distort_limit = distort_limit
        self.shift_limit = shift_limit

    def aug_data(self, image, mask):
        # Draw order: distortion strength k, then principal-point shifts dx, dy.
        k = random_float([], -self.distort_limit, self.distort_limit)
        dx = random_float([], -self.shift_limit, self.shift_limit)
        dy = random_float([], -self.shift_limit, self.shift_limit)

        shape = tf.shape(image)
        height, width = shape[0], shape[1]
        map_x, map_y = initUndistortRectifyMap(height, width, k, dx, dy)

        def _warp(tensor):
            return remap(tensor, height, width, map_x, map_y, mode='mirror')

        return _warp(image), _warp(mask)

In [None]:
# Lens distortion with |k| < 1.0 and a tiny principal-point shift; preview with masks.
optical_distortion = OpticalDistortion(p=0.75, distort_limit=1.0, shift_limit=0.05)
check_aug(aug_fun=optical_distortion, with_mask=True)

### GridDistortion

In [None]:
def make_grid_distorted_maps(height, width, num_steps, xsteps, ysteps):
    """Build sampling maps for a grid distortion.

    Each axis of length ``size`` is cut into ``num_steps`` cells of
    ``size // num_steps`` pixels plus a remainder cell. Cell ``i`` is stretched
    by ``steps[i]`` and filled with a linspace. The remainder cell (if any)
    runs from the end of the last stretched cell to ``size - 1``.

    Args:
        height, width: image size (int32 scalars).
        num_steps: Python int, number of grid cells per axis.
        xsteps, ysteps: float32 tensors of length num_steps + 1 holding the
            per-cell stretch factors; the final entry is not used.

    Returns:
        (map_x, map_y), each [height, width]. map_x varies along axis 0 and is
        built from ``ysteps``; map_y varies along axis 1 and is built from
        ``xsteps`` (matching remap's axis-0 = map_x convention).
    """
    def _make_maps_before_last(size, step, steps): # size=512, step=102,
                                                   # steps.shape=[num_steps]
        step_rep = tf.repeat(step, num_steps)  # [102, 102, 102, 102, 102]
        step_rep_f = tf.cast(step_rep, dtype=tf.float32)
        step_inc = step_rep_f * steps          # [102*s_0, ..., 102*s_4]
        cur = tf.math.cumsum(step_inc)         # [si_0, si_0 + si_1, ... ]
        zero = tf.zeros([1], dtype=tf.float32)
        prev = tf.concat([ zero, cur[ :-1] ], axis=0) # [0, c_0, ..., c_3]
        prev_cur = tf.stack([prev, cur])       # [[p_0, p_1, ...], [c_0, c_1, ...]]
        ranges = tf.transpose(prev_cur)        # [[p_0, c_0], [p_1, c_1], ... ]

        def _linspace_range(rng):
            return tf.linspace(rng[0], rng[1], step)

        maps_stack = tf.map_fn(_linspace_range, ranges)
        maps = tf.reshape(maps_stack, [-1])    # [-1] flatten into 1-D
        return maps

    def _make_last_map(size, step, last_start):
        last_step = size - step * num_steps  # 512 - 102*5 = 2
        size_f = tf.cast(size, dtype=tf.float32)
        # When size is a multiple of num_steps there is no remainder cell.
        # tf.linspace rejects num == 0, so return an empty map in that case.
        return tf.cond(
            tf.greater(last_step, 0),
            lambda: tf.linspace(last_start, size_f - 1.0, last_step),
            lambda: tf.zeros([0], dtype=tf.float32))

    def _make_distorted_map(size, steps):
        step = size // num_steps               # step=102
        maps_before_last = _make_maps_before_last(size, step, steps[ :-1 ])
        last_map = _make_last_map(size, step, maps_before_last[-1])
        distorted_map = tf.concat([maps_before_last, last_map], axis=0)
        return distorted_map

    xx = _make_distorted_map(width, xsteps)
    yy = _make_distorted_map(height, ysteps)
    map_y, map_x = tf.meshgrid(xx, yy)
    return map_x, map_y

class GridDistortion(BaseAug):
    """Random grid distortion applied identically to image and mask.

    Every grid cell is stretched by a factor drawn from
    [1 - distort_limit, 1 + distort_limit).
    """
    def __init__(self, num_steps, distort_limit, p=1.0):
        super().__init__(p)
        self.num_steps = num_steps
        self.distort_limit = distort_limit

    def _random_steps(self):
        # One stretch factor per cell plus one spare, as make_grid_distorted_maps expects.
        return tf.random.uniform(
            [self.num_steps + 1],
            minval=1.0 - self.distort_limit,
            maxval=1.0 + self.distort_limit)

    def aug_data(self, image, mask):
        xsteps = self._random_steps()
        ysteps = self._random_steps()

        shape = tf.shape(image)
        height, width = shape[0], shape[1]
        map_x, map_y = make_grid_distorted_maps(
            height, width, self.num_steps, xsteps, ysteps)

        def _warp(tensor):
            return remap(tensor, height, width, map_x, map_y, mode='mirror')

        return _warp(image), _warp(mask)

In [None]:
# 5x5 grid distortion with stretch factors in [0, 2); preview with masks.
grid_distortion = GridDistortion(p=0.75, num_steps=5, distort_limit=1.0)
check_aug(aug_fun=grid_distortion, with_mask=True)

### OneOf

In [None]:
# 75% of the examples get either an optical or a grid distortion.
one_of_opt_grid_distortion = OneOf(
    trans1=optical_distortion, trans2=grid_distortion, p=0.75)
check_aug(aug_fun=one_of_opt_grid_distortion, with_mask=True)

### HueSaturationValue

In [None]:
class HueSaturationValue(BaseAug):
    """Shift hue, saturation and value in HSV space by random amounts; mask untouched.

    Hue wraps around modulo 1.0; saturation and value are clamped to [0, 1].
    """
    def __init__(
            self, hue_shift_limit, sat_shift_limit,
            val_shift_limit, p):
        super().__init__(p)
        self.hue_shift_limit = hue_shift_limit
        self.sat_shift_limit = sat_shift_limit
        self.val_shift_limit = val_shift_limit

    def aug_data(self, image, mask):
        # Draw order: hue, saturation, value.
        hue_shift, sat_shift, val_shift = (
            random_float([], -limit, limit)
            for limit in (self.hue_shift_limit,
                          self.sat_shift_limit,
                          self.val_shift_limit))

        hsv = tf.image.rgb_to_hsv(image)
        hue = (hsv[..., 0:1] + hue_shift) % 1.0
        sat = tf.clip_by_value(hsv[..., 1:2] + sat_shift, 0.0, 1.0)
        val = tf.clip_by_value(hsv[..., 2:] + val_shift, 0.0, 1.0)
        shifted = tf.concat([hue, sat, val], axis=-1)
        return tf.image.hsv_to_rgb(shifted), mask

In [None]:
# HSV jitter on 75% of the examples (images only).
hue_saturation_value = HueSaturationValue(
    p=0.75, hue_shift_limit=0.2, sat_shift_limit=0.3, val_shift_limit=0.2)
check_aug(aug_fun=hue_saturation_value, with_mask=False)

### ShiftScaleRotate

In [None]:
def affine_transform(height, width, tx, ty, z, theta):
    """Source coordinates for a shift / scale / rotate warp around the image centre.

    Composes the 3x3 homogeneous matrix shift @ (zoom @ (rotation @ centre))
    and applies it to every output pixel coordinate. ``theta`` is in degrees,
    ``z`` is the zoom factor (the matrix uses 1 / z), and (tx, ty) is the
    shift in pixels.

    Returns:
        (map_x, map_y) as flat vectors; remap reshapes them. As elsewhere in
        this file, tf.meshgrid's 'xy' indexing makes x vary along axis 0.
    """
    cx = (width - 1.0) * 0.5
    cy = (height - 1.0) * 0.5

    center_shift_mat = tf.convert_to_tensor([
        [1.0, 0.0, -cx],
        [0.0, 1.0, -cy],
        [0.0, 0.0, 1.0]], dtype=tf.float32)

    rot_rad = -2.0 * math.pi * theta / 360.0
    cos_t = tf.math.cos(rot_rad)
    sin_t = tf.math.sin(rot_rad)
    rotation_mat = tf.convert_to_tensor([
        [cos_t, sin_t, 0.0],
        [-sin_t, cos_t, 0.0],
        [0.0, 0.0, 1.0]], dtype=tf.float32)

    inv_z = 1.0 / z
    zoom_mat = tf.convert_to_tensor([
        [inv_z, 0.0, 0.0],
        [0.0, inv_z, 0.0],
        [0.0, 0.0, 1.0]], dtype=tf.float32)

    shift_mat = tf.convert_to_tensor([
        [1.0, 0.0, cx - tx],
        [0.0, 1.0, cy - ty],
        [0.0, 0.0, 1.0]], dtype=tf.float32)

    # Same association order as composing step by step, right to left.
    trans_mat = tf.linalg.matmul(
        shift_mat,
        tf.linalg.matmul(
            zoom_mat, tf.linalg.matmul(rotation_mat, center_shift_mat)))

    rows = tf.range(height, dtype=tf.float32)
    cols = tf.range(width, dtype=tf.float32)
    y_grid, x_grid = tf.meshgrid(rows, cols)
    x_flat = tf.reshape(x_grid, [-1])
    y_flat = tf.reshape(y_grid, [-1])
    homogeneous = tf.stack([x_flat, y_flat, tf.ones_like(x_flat)])

    warped = tf.linalg.matmul(trans_mat, homogeneous)
    return warped[0], warped[1]

In [None]:
class ShiftScaleRotate(BaseAug):
    """Random shift, zoom and rotation; pixels sourced from outside the image become 0."""

    def __init__(
            self, shift_limit, scale_limit, rotate_limit, p):
        super().__init__(p)
        self.shift_limit = shift_limit
        self.scale_limit = scale_limit
        self.rotate_limit = rotate_limit

    def aug_data(self, image, mask):
        shape = tf.shape(image)
        height_i, width_i = shape[0], shape[1]
        height_f = tf.cast(height_i, dtype=tf.float32)
        width_f = tf.cast(width_i, dtype=tf.float32)

        # Draw order: shift (x, y) as a fraction of size, zoom, angle in degrees.
        shift = random_float([2], -self.shift_limit, self.shift_limit)
        tx, ty = width_f * shift[0], height_f * shift[1]
        z = random_float([], 1.0 - self.scale_limit, 1.0 + self.scale_limit)
        theta = random_float([], -self.rotate_limit, self.rotate_limit)

        map_x, map_y = affine_transform(height_f, width_f, tx, ty, z, theta)

        def _warp(tensor):
            return remap(tensor, height_i, width_i, map_x, map_y, mode='constant')

        return _warp(image), _warp(mask)

In [None]:
# Shift up to 20%, zoom +/-30%, rotate up to +/-30 degrees; preview with masks.
shift_scale_rotate = ShiftScaleRotate(
    p=0.75, shift_limit=0.2, scale_limit=0.3, rotate_limit=30)
check_aug(aug_fun=shift_scale_rotate, with_mask=True)

### Cutout

In [None]:
class Cutout(BaseAug):
    """Zero out ``num_cuts`` random rectangles of the image; the mask is returned unchanged.

    Each rectangle is about ``mask_factor`` times the image height/width,
    centred at a random pixel and clipped to the image.
    """
    def __init__(self, num_cuts, mask_factor, p):
        super(Cutout, self).__init__(p)
        self.num_cuts = num_cuts
        # Fraction of the image height/width used as the cut size.
        self.mask_factor = mask_factor

    def aug_data(self, image, mask):
        image_shape = tf.shape(image)
        height_i = image_shape[0]
        width_i = image_shape[1]
        height_f = tf.cast(height_i, dtype=tf.float32)
        width_f = tf.cast(width_i, dtype=tf.float32)
        cut_h = tf.cast(height_f * self.mask_factor, dtype=tf.int32)
        cut_w = tf.cast(width_f * self.mask_factor, dtype=tf.int32)

        # Random centres anywhere in the image.
        y_centers = random_int([self.num_cuts], 0, height_i)
        x_centers = random_int([self.num_cuts], 0, width_i)
        # Clip the top-left corner at 0 and the bottom-right at the last pixel.
        # The bounds are inclusive, so a cut spans up to cut_h + 1 rows and
        # cut_w + 1 columns.
        tops = tf.math.maximum(y_centers - cut_h//2, 0)
        lefts = tf.math.maximum(x_centers - cut_w//2, 0)
        bottoms = tf.math.minimum(tops + cut_h, height_i - 1)
        rights = tf.math.minimum(lefts + cut_w, width_i - 1)

        def _make_one_mask(i):
            # Boolean rectangle for cut i, padded with False to the full image.
            # (This local `mask` only shadows the segmentation mask inside
            # this helper.)
            mask_height = bottoms[i] - tops[i] + 1
            mask_width = rights[i] - lefts[i] + 1
            mask_shape = [mask_height, mask_width]
            mask = tf.ones(mask_shape, dtype=tf.bool)

            paddings = [
                [tops[i], height_i - bottoms[i] - 1],
                [lefts[i], width_i - rights[i] - 1]]
            mask = tf.pad(mask, paddings, mode='CONSTANT')
            return mask

        num_cuts_rng = tf.range(self.num_cuts, dtype=tf.int64)
        cut_masks = tf.map_fn(
            _make_one_mask, num_cuts_rng,
            fn_output_signature=tf.bool)
        # A pixel is cut if any rectangle covers it; add a channel axis so the
        # boolean mask broadcasts over the image channels.
        cut_mask = tf.reduce_any(cut_masks, axis=0)
        cut_mask = cut_mask[ ..., tf.newaxis ]

        # Cut pixels become 0.0; the segmentation mask is not modified.
        mask_value = tf.constant(0.0, dtype=tf.float32)
        aug_image = tf.where(cut_mask, mask_value, image)
        return aug_image, mask

In [None]:
# One ~40%-sized black rectangle on 75% of the examples (images only).
cut_out = Cutout(p=0.75, num_cuts=1, mask_factor=0.4)
check_aug(aug_fun=cut_out, with_mask=False)

## Do Augment

In [None]:
def do_augment(image, mask, labels):
    """Run the training augmentation pipeline on a batch.

    Transforms run in a fixed order; each decides per example whether it
    fires. ``labels`` pass through unchanged.
    """
    pipeline = (
        transpose,
        vertical_flip,
        horizontal_flip,
        random_brightness,
        random_contrast,
        one_of_blur_median_blur,
        jpeg_compression,
        one_of_opt_grid_distortion,
        hue_saturation_value,
        shift_scale_rotate,
        cut_out,
    )
    for augmentation in pipeline:
        image, mask = augmentation(image, mask)
    return image, mask, labels

## Dataset 2: train / validation pipelines

In [None]:
def select_train(ds, fold_i):
    """Keep examples whose fold differs from ``fold_i``."""
    def _in_train_split(ids, fold, image, mask, study_label):
        return fold != fold_i
    return ds.filter(_in_train_split)
    
def select_val(ds, fold_i):
    """Keep examples that belong to fold ``fold_i``."""
    def _in_val_split(ids, fold, image, mask, study_label):
        return fold == fold_i
    return ds.filter(_in_val_split)

In [None]:
def get_data(ids, fold, image, mask, study_label):
    """Drop the bookkeeping fields (ids, fold), keeping (image, mask, study_label)."""
    model_fields = (image, mask, study_label)
    return model_fields

def rescale_image(image, mask, study_label):
    """Convert the image to float32 in [0, 1]; mask and labels pass through."""
    return image_to_float_0_1(image), mask, study_label

In [None]:
def reform_for_model(image, mask, labels):
    """Pack targets into the dict keyed like the model's loss dict ('sigmoid', 'study_label')."""
    targets = dict(sigmoid=mask, study_label=labels)
    return image, targets

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

def build_dataset(
        dset, augment=True, repeat=True, shuffle=1024):
    """Turn a raw (ids, fold, image, mask, study_label) dataset into model input.

    Order: select fields -> [repeat] -> [shuffle] -> batch -> rescale to [0, 1]
    -> [augment] -> pack targets -> prefetch. Augmentation runs on whole
    batches because every BaseAug maps over the batch axis itself.

    NOTE(review): batches keep their remainder, so a non-repeated dataset ends
    with a smaller batch. Verify this is acceptable on TPU.
    """
    dset = dset.map(get_data, num_parallel_calls=AUTOTUNE)
    if repeat:
        dset = dset.repeat()
    if shuffle:
        dset = dset.shuffle(shuffle)
    dset = dset.batch(BATCH_SIZE)
    dset = dset.map(rescale_image, num_parallel_calls=AUTOTUNE)
    if augment:
        dset = dset.map(do_augment, num_parallel_calls=AUTOTUNE)
    dset = dset.map(reform_for_model, num_parallel_calls=AUTOTUNE)
    return dset.prefetch(AUTOTUNE)

In [None]:
def make_datasets(fold_i):
    """Return (train_ds, val_ds, train_steps, val_steps) for fold ``fold_i``.

    Training data is repeated, shuffled and augmented; validation data is read
    once, unshuffled, without augmentation. Step counts are image counts
    floor-divided by BATCH_SIZE.
    """
    train_ds = build_dataset(
        select_train(raw_ds, fold_i),
        augment=True, repeat=True, shuffle=1024)
    val_ds = build_dataset(
        select_val(raw_ds, fold_i),
        augment=False, repeat=False, shuffle=None)

    train_steps = get_train_count(fold_i) // BATCH_SIZE
    val_steps = get_val_count(fold_i) // BATCH_SIZE
    return train_ds, val_ds, train_steps, val_steps

## Visualization

In [None]:
train_ds, val_ds, train_steps, val_steps = make_datasets(0)

# Sanity check for fold 0: dataset specs, then steps per epoch.
print(train_ds, val_ds, train_steps, val_steps, sep='\n')

In [None]:
def show_images(ds):
    """Plot 10 images (top two rows) and their masks (bottom two rows) from ``ds``."""
    rows, cols = 4, 5
    n_imgs = (rows // 2) * cols

    first_items = ds.unbatch().take(n_imgs).batch(n_imgs)
    images, label_dict = next(iter(first_items))
    masks = label_dict['sigmoid']

    def _panel(position, picture, **imshow_kwargs):
        plt.subplot(rows, cols, position)
        plt.imshow(picture, **imshow_kwargs)
        plt.axis("off")

    plt.figure(figsize=(12, 8))
    for position, picture in enumerate(images, start=1):
        _panel(position, picture)
    for position, picture in enumerate(masks, start=n_imgs + 1):
        _panel(position, picture, cmap='gray')
    plt.tight_layout()
    plt.show()

In [None]:
show_images(ds=train_ds)  # augmented training samples

In [None]:
show_images(ds=val_ds)  # un-augmented validation samples

## Model

In [None]:
def make_model():
    """Build and compile the segmentation model with an auxiliary study-label head.

    Outputs:
        base_model.output: per-pixel sigmoid mask (trained with
            bce_jaccard_loss, metric IoU), keyed 'sigmoid' in loss/metrics.
        'study_label': N_STUDY_LABELS independent sigmoid probabilities from
            globally pooled encoder features (binary cross-entropy, PR-AUC).

    NOTE(review): assumes the backbone exposes a layer named 'top_activation'
    (verify if BACKBONE changes). Sigmoid + binary cross-entropy treats the
    study labels as independent; confirm this is intended if they are
    mutually exclusive.
    """
    base_model = SEG_MODEL(
        BACKBONE, encoder_weights='imagenet', 
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        classes=1, activation='sigmoid')
    
    # Auxiliary head: pool the encoder's final activation map into a vector.
    x = base_model.get_layer(name='top_activation').output 
    x = L.GlobalAveragePooling2D(name='avgpool')(x)
    study_label = L.Dense(
        N_STUDY_LABELS, activation='sigmoid', name='study_label')(x)

    model = tf.keras.Model(
        inputs=base_model.input, 
        outputs=[base_model.output, study_label]) 
    
    pr_auc = tf.keras.metrics.AUC(
        curve="PR", multi_label=True, name="pr_auc")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss={
            'sigmoid': sm.losses.bce_jaccard_loss,
            'study_label': "binary_crossentropy" },
        # Equal weighting of the segmentation and study-label losses.
        loss_weights = [1.0, 1.0],
        metrics={
            'sigmoid': sm.metrics.iou_score,
            'study_label': pr_auc },
        # Running several steps per execution reduces per-step overheads
        # and allows the XLA compiler to unroll the loop on TPU
        # and optimize hardware utilization.
        # Needs to be commented out for TensorFlow 2.3.
        steps_per_execution=8)
    return model

In [None]:
# Build inside the distribution strategy scope so the model's variables are
# created under the selected TPU/GPU strategy.
with strategy.scope(): 
    model = make_model()
    
# Snapshot of the freshly initialised weights (presumably used to reset the
# model between folds).
initial_weights = model.get_weights()
# model.summary()

## Training

In [None]:
# Per-epoch schedule: linear warm-up, optional plateau, then cosine decay.
LR_START = INIT_LR                # learning rate at epoch 0
LR_MAX = 1e-3                     # peak learning rate reached after warm-up
LR_MIN = 1e-5                     # floor the cosine decay approaches
LR_RAMPUP_EPOCHS = WARMUP_EPO     # epochs of linear warm-up
LR_SUSTAIN_EPOCHS = 0             # epochs held at LR_MAX after warm-up
EPOCHS = N_EPOCHS

def lrfn(epoch):
    """Return the learning rate for a 0-based ``epoch`` index.

    * ``epoch < LR_RAMPUP_EPOCHS``: linear ramp from ``LR_START`` towards
      ``LR_MAX``.
    * next ``LR_SUSTAIN_EPOCHS`` epochs: constant ``LR_MAX``.
    * remaining epochs: half-cosine decay from ``LR_MAX`` (first decay epoch)
      down to ``LR_MIN`` (epoch ``EPOCHS - 1``).

    Used by ``tf.keras.callbacks.LearningRateScheduler``.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX

    decay_start = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS
    # With a single decay epoch the span below would be 0 and raise
    # ZeroDivisionError; clamping to 1 makes that epoch use LR_MAX.
    decay_total_epochs = max(EPOCHS - decay_start - 1, 1)
    # Clamp so epochs beyond EPOCHS - 1 stay at LR_MIN instead of the cosine
    # climbing back up; inside the planned range this is a no-op.
    decay_epoch_index = min(epoch - decay_start, decay_total_epochs)
    phase = math.pi * decay_epoch_index / decay_total_epochs
    cosine_decay = 0.5 * (1 + math.cos(phase))
    return (LR_MAX - LR_MIN) * cosine_decay + LR_MIN

# Plot the schedule produced by lrfn over the whole run (one point per epoch).
rng = list(range(EPOCHS))
lr_y = [lrfn(x) for x in rng]
plt.figure(figsize=(10, 4))
plt.plot(rng, lr_y)
plt.title("Learning rate schedule")
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(
    lr_y[0], max(lr_y), lr_y[-1]))

In [None]:
cb_monitor = 'val_study_label_pr_auc'

class RestoreBestWeights(tf.keras.callbacks.Callback):
    """Remember the weights of the epoch with the highest monitored value and
    load them back into the model when training ends.

    Args:
        monitor: key of the epoch-end ``logs`` dict to maximise. Defaults to
            the module-level ``cb_monitor``.
    """

    def __init__(self, monitor=cb_monitor):
        super(RestoreBestWeights, self).__init__()
        self.monitor = monitor
        # np.Inf was removed in NumPy 2.0; np.inf exists on every version.
        self.best_monitor = -np.inf
        self.best_weights = None
        self.best_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        current_monitor = (logs or {}).get(self.monitor)
        # The key is absent on epochs without validation (or if the metric
        # name is wrong); comparing None with a float would raise TypeError.
        if current_monitor is None:
            return
        # NaN never compares greater, so a NaN epoch is never recorded.
        if current_monitor > self.best_monitor:
            self.best_monitor = current_monitor
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch

    def on_train_end(self, logs=None):
        if self.best_weights is None:
            # Nothing was recorded; keep the final weights instead of
            # crashing on best_epoch + 1 / set_weights(None).
            print("No value of {0} was recorded; keeping the final weights.".format(
                self.monitor))
            return
        print("Restoring best weights on epoch {0}, {1} was {2:.5f}".format(
            self.best_epoch + 1, self.monitor, self.best_monitor))
        self.model.set_weights(self.best_weights)

In [None]:
def make_callbacks(fold_i):
    """Create the Keras callbacks used to train one fold.

    Returns a ``(checkpoint, lr_callback, restore_best_weights)`` tuple:

    * ``checkpoint`` saves the full model to
      ``study_aux_loss_model_<VID>_<fold_i>.hdf5`` each time ``cb_monitor``
      reaches a new maximum;
    * ``lr_callback`` applies the per-epoch ``lrfn`` schedule;
    * ``restore_best_weights`` is a ``RestoreBestWeights`` instance, which
      puts the best-epoch weights back into the model at the end of training.
    """
    best_model_file_name = \
        "study_aux_loss_model_{0}_{1}.hdf5".format(VID, fold_i)

    # ModelCheckpoint has no min_delta option: any improvement triggers a save.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        best_model_file_name, save_best_only=True,
        save_weights_only=False, monitor=cb_monitor, mode='max',
        verbose=1)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)
    restore_best_weights = RestoreBestWeights()

    return checkpoint, lr_callback, restore_best_weights

In [None]:
def fit_one_fold(fold_i):
    """Train the module-level ``model`` on one cross-validation fold.

    Builds the fold's datasets and callbacks, runs ``model.fit`` for
    ``EPOCHS`` epochs with validation, and returns
    ``(history, val_dataset, val_steps)`` so the caller can evaluate on the
    same validation split.
    """
    datasets = make_datasets(fold_i)
    train_dataset, val_dataset, train_steps, val_steps = datasets

    # Keep the callback order returned by make_callbacks (checkpoint first).
    callbacks = list(make_callbacks(fold_i))

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        steps_per_epoch=train_steps,
        validation_steps=val_steps,
        callbacks=callbacks,
        verbose=1)
    return history, val_dataset, val_steps

In [None]:
def plot_history(history, title, labels, subplot):
    """Draw selected training curves on one subplot of the current figure.

    Args:
        history: object returned by ``model.fit``; its ``.history`` dict is read.
        title: subplot title.
        labels: keys of ``history.history`` to plot; each is also its legend entry.
        subplot: ``(nrows, ncols, index)`` forwarded to ``plt.subplot``.
    """
    ax = plt.subplot(*subplot)
    ax.set_title(title)
    for metric_key in labels:
        ax.plot(history.history[metric_key], label=metric_key)
    ax.legend()

In [None]:
def plot_fit_result(history):
    """Show loss, study PR-AUC and segmentation IoU curves (train vs. val)
    as three panels of a 2x2 grid."""
    panels = (
        ("Loss", ['loss', 'val_loss']),
        ("PR_AUC", ['study_label_pr_auc', 'val_study_label_pr_auc']),
        ("IOU", ['sigmoid_iou_score', 'val_sigmoid_iou_score']),
    )
    plt.figure(figsize=(12, 8))
    for panel_index, (title, metric_keys) in enumerate(panels, start=1):
        plot_history(history, title, metric_keys, (2, 2, panel_index))
    plt.show()

In [None]:
def predict_one_fold(model, val_dataset, val_steps):
    """Predict study labels on a validation dataset and pair them with the
    ground truth.

    Labels and predictions come from the *same* pass over ``val_dataset``.
    Reading the dataset twice (once for labels, once for images) would decode
    every image twice and misalign rows if the pipeline reshuffles between
    iterations.

    Args:
        model: two-output Keras model; output index 1 is the study-label head.
        val_dataset: ``tf.data.Dataset`` yielding ``(images, label_dict)``
            batches with targets in ``label_dict['study_label']``.
        val_steps: number of batches to read if ``val_dataset`` is infinite
            (e.g. ``.repeat()``-ed); unused for finite datasets.

    Returns:
        ``(val_true, val_pred)`` numpy arrays with the same number of rows.
    """
    dataset = val_dataset
    if (tf.data.experimental.cardinality(dataset)
            == tf.data.experimental.INFINITE_CARDINALITY):
        # Without a bound, iterating a repeated dataset never terminates.
        dataset = dataset.take(val_steps)

    val_true_list = []
    val_pred_list = []
    for images, label_dict in dataset:
        val_true_list.append(np.asarray(label_dict['study_label']))
        outputs = model(images, training=False)
        val_pred_list.append(np.asarray(outputs[1]))

    val_true = np.concatenate(val_true_list, axis=0)
    val_pred = np.concatenate(val_pred_list, axis=0)
    return val_true, val_pred

In [None]:
# Display names of the study-level classes; entry i names column i of the
# 'study_label' targets/predictions (presumably the same order as the
# lowercase labels written by make_pred_str -- confirm if either changes).
study_labels = [
    "Negative for Pneumonia",
    "Typical Appearance",
    "Indeterminate Appearance",
    "Atypical Appearance",
]

In [None]:
from sklearn.metrics import average_precision_score

def show_average_precision_score(val_true, val_pred):
    """Report per-class and mean average precision for the study labels.

    For every column ``i`` of the 2-D arrays ``val_true`` (targets) and
    ``val_pred`` (scores), prints sklearn's ``average_precision_score`` next
    to ``study_labels[i]``, then prints the unweighted mean and draws a bar
    chart of the per-class values.

    Returns:
        The mean average precision (the value printed on the "Mean" row).
    """
    n_classes = val_true.shape[-1]
    class_names = [study_labels[i] for i in range(n_classes)]

    average_precision_list = []
    for i, class_name in enumerate(class_names):
        average_precision = average_precision_score(
            val_true[ : , i], val_pred[ : , i])
        print("{0:30s}: {1:.4f}".format(class_name, average_precision))
        average_precision_list.append(average_precision)

    mean_average_precision = np.mean(average_precision_list)
    print("{0:30s}: {1:.4f}".format(
        "Mean", mean_average_precision))

    # Classes are unordered categories, so bars (not a connecting line);
    # ticks follow the number of classes actually present in the data.
    ticks = np.arange(n_classes)
    plt.figure(figsize=(8, 4))
    plt.bar(ticks, average_precision_list)
    plt.xticks(ticks=ticks, labels=class_names, rotation=45)
    plt.ylim(0, 1)
    plt.ylabel("Average precision")
    plt.title("Per-class average precision (mean {0:.4f})".format(
        mean_average_precision))
    plt.show()
    return mean_average_precision

In [None]:
def make_pred_str(pred):
    """Format one study's class scores as a submission PredictionString.

    Emits ``"<label> <score> 0 0 1 1"`` for each class (score with six
    decimals, followed by the constant ``0 0 1 1``), in the order negative,
    typical, indeterminate, atypical, separated by single spaces. ``pred`` is
    indexed positionally, so it needs at least four entries.
    """
    class_names = ('negative', 'typical', 'indeterminate', 'atypical')
    return ' '.join(
        "{0} {1:.6f} 0 0 1 1".format(name, pred[idx])
        for idx, name in enumerate(class_names))

def make_submission(fold_i, val_pred):
    """Build ``[id, PredictionString]`` rows for one fold's validation studies.

    Pairs ``get_val_study_ids(fold_i)`` with the rows of ``val_pred``
    positionally; the two are assumed to be in the same order. The id column
    is ``<study_id>_study``.

    Warns when the counts differ, because ``zip`` would otherwise silently
    drop the unmatched tail and the submission would miss studies.
    """
    import warnings

    val_study_ids = list(get_val_study_ids(fold_i))
    if len(val_study_ids) != len(val_pred):
        warnings.warn(
            "Fold {0}: {1} study ids but {2} predictions; only the first {3} "
            "pairs are written.".format(
                fold_i, len(val_study_ids), len(val_pred),
                min(len(val_study_ids), len(val_pred))))

    return [
        [study_id + "_study", make_pred_str(pred)]
        for study_id, pred in zip(val_study_ids, val_pred)
    ]

In [None]:
# Train every requested fold from the same initial weights, evaluate it on its
# validation split and collect its study-level submission rows.
study_sub_list = []
for fold_i in FOLD_I_LIST:
    print("####################\n# Fold {0}".format(fold_i))
    model.set_weights(initial_weights)
    history, val_dataset, val_steps = fit_one_fold(fold_i)
    plot_fit_result(history)

    val_true, val_pred = predict_one_fold(model, val_dataset, val_steps)
    show_average_precision_score(val_true, val_pred)

    fold_sub_list = make_submission(fold_i, val_pred)
    study_sub_list += fold_sub_list

In [None]:
# One submission row per study: id "<study_id>_study" plus its prediction string.
study_sub_df = pd.DataFrame(
    data=study_sub_list,
    columns=['id', 'PredictionString'],
)

study_sub_df

In [None]:
study_sub_file_name = "study_sub_{0}.csv".format(VID)
study_sub_df.to_csv(study_sub_file_name, index=False)

! head study_sub_*.csv