# Credits:
* @marcosnovaes  https://www.kaggle.com/marcosnovaes/hubmap-looking-at-tfrecords and https://www.kaggle.com/marcosnovaes/hubmap-unet-keras-model-fit-with-tpu
* @mgornergoogle https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
* qubvel https://github.com/qubvel/segmentation_models  !! 25 available backbones for each of 4 architectures


## Setup and Imports

In [None]:
%%capture
# Install the extra packages this notebook needs; the %%capture cell magic
# above keeps pip's output out of the notebook.
# Weights & Biases client, used for experiment tracking.
!pip install wandb -q
# using https://github.com/qubvel/segmentation_models
! pip install segmentation_models -q

In [None]:
import tensorflow as tf

import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import glob
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
import matplotlib
%matplotlib inline

from kaggle_datasets import KaggleDatasets
GCS_PATH = KaggleDatasets().get_gcs_path('hubmap-tfrecord-512')

print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE

# import segmentation models
import segmentation_models as sm

# import W&B for ML experiment tracking
import wandb
from wandb.keras import WandbCallback
# Authenticate with W&B without embedding the API key in the notebook.
# wandb.login() uses the WANDB_API_KEY environment variable or credentials
# stored by an earlier `wandb login`, and otherwise prompts for the key.
# NOTE(review): an API key was previously committed in this cell; it should
# be treated as compromised and revoked in the W&B account settings.
wandb.login()

#### Setup TPU

In [None]:
# Choose the tf.distribute strategy used for the rest of the notebook:
# a TPUStrategy when a TPU cluster can be resolved, otherwise the default
# strategy. `strategy` is read later for the batch size and the model scope.
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError: # no TPU found, detect GPUs
    #strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

# Report how many replicas the chosen strategy keeps in sync.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

## Hyperparameters

In [None]:
# Reproducibility and console output.
SEED = 0
VERBOSE = 1

# Cross-validation layout and training length.
NFOLDS = 4
EPOCHS = 60

# Encoder used by the segmentation model.
BACKBONE = 'efficientnetb7'

# Global batch size: 8 samples for every replica of the distribution strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * 8

## Dataset

### GCS_PATHS

Based on: https://www.kaggle.com/marcosnovaes/hubmap-looking-at-tfrecords

In [None]:
%%time
# Tile index exported by the "hubmap-looking-at-tfrecords" notebook.
uber_tile_df = pd.read_csv('/kaggle/input/hubmap-looking-at-tfrecords/train_all_tiles.csv')

# Derive the GCS location of every tile by swapping the local dataset prefix
# in `local_path` for the bucket path resolved at import time.
uber_tile_df['gcs_path'] = uber_tile_df['local_path'].replace(
    regex='/kaggle/input/hubmap-tfrecord-512',
    value=GCS_PATH,
)

# Keep only tiles whose mask is not empty.
uber_tile_df = uber_tile_df[uber_tile_df['mask_density'].gt(0)].copy()
uber_tile_df.shape

### K-fold split by image id

In [None]:
# Split on image ids so that all tiles of one image land on the same side.
img_ids = uber_tile_df['img_id'].unique()
kf = KFold(n_splits=NFOLDS, shuffle=True, random_state=SEED)

# Only a single fold is trained for now: the last split produced by KFold.
*_, (train_index, val_index) = kf.split(img_ids)
train_ids = list(img_ids[train_index])
val_ids = list(img_ids[val_index])

# Resolve the chosen image ids to the GCS paths of their tiles.
TRAINING_FILENAMES = uber_tile_df.loc[uber_tile_df['img_id'].isin(train_ids), 'gcs_path'].tolist()
VALIDATION_FILENAMES = uber_tile_df.loc[uber_tile_df['img_id'].isin(val_ids), 'gcs_path'].tolist()

NUM_TRAINING_IMAGES = len(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = len(VALIDATION_FILENAMES)
# Whole batches per epoch; a trailing partial batch is not counted.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
print(NUM_VALIDATION_IMAGES)

### Datasets pipeline

In [None]:
# Backbone-specific preprocessing function from segmentation_models.
# NOTE(review): nothing in the visible pipeline applies it to the images - confirm
# whether that is intentional.
preprocess_input = sm.get_preprocessing(BACKBONE)

# Schema of one serialized tile: scalar int64 metadata around two raw byte
# strings holding the image pixels and the mask.
image_feature_description = {
    **{
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('img_index', 'height', 'width', 'num_channels')
    },
    **{
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('img_bytes', 'mask')
    },
    **{
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('tile_id', 'tile_col_pos', 'tile_row_pos')
    },
}

def _parse_image_function(example_proto):
    """Decode one serialized tile record into an ``(image, mask)`` pair.

    Args:
        example_proto: serialized example matching ``image_feature_description``.

    Returns:
        A tuple of float32 tensors: the image reshaped to (512, 512, 3) and
        the mask reshaped to (512, 512, 1).
    """
    features = tf.io.parse_single_example(example_proto, image_feature_description)

    # The pixels are stored as raw uint8 bytes; reshape, then convert to float32.
    pixels = tf.io.decode_raw(features['img_bytes'], out_type='uint8')
    image = tf.image.convert_image_dtype(tf.reshape(pixels, (512, 512, 3)), tf.float32)

    # The mask is stored as raw booleans, one per pixel.
    mask_bits = tf.io.decode_raw(features['mask'], out_type='bool')
    mask = tf.reshape(mask_bits, (512, 512, 1))

    # cast as float32 required for TPU
    return image, tf.cast(mask, tf.float32)

def load_dataset(filenames, ordered=False):
    """Build a dataset of decoded ``(image, mask)`` pairs from GZIP TFRecords.

    Args:
        filenames: paths of the TFRecord files to read.
        ordered: when False (the default) deterministic ordering is switched
            off on the dataset options; when True the options are left at
            their defaults.

    Returns:
        An unbatched ``tf.data.Dataset`` mapped through ``_parse_image_function``.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(
        filenames,
        compression_type="GZIP",
        num_parallel_reads=AUTO,
    )
    return records.with_options(options).map(
        _parse_image_function, num_parallel_calls=AUTO
    )

def get_training_dataset():
    """Return the endlessly repeating, shuffled and batched training dataset.

    The dataset repeats forever, so the consumer has to bound an epoch itself
    (e.g. through ``steps_per_epoch``).
    """
    return (
        load_dataset(TRAINING_FILENAMES)
        .repeat()
        .shuffle(128)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=False):
    """Return the batched and cached validation dataset.

    Args:
        ordered: forwarded to ``load_dataset`` to control deterministic ordering.

    Returns:
        A ``tf.data.Dataset`` of full batches; a trailing partial batch is
        dropped.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, ordered=ordered)
        .batch(BATCH_SIZE, drop_remainder=True)
        .cache()
        .prefetch(AUTO)
    )

# Model

In [None]:
# Create the model inside the strategy scope so its variables are placed
# according to the selected distribution strategy.
with strategy.scope():
    # U-Net with the configured encoder; every other argument is left at the
    # segmentation_models default.
    model = sm.Unet(backbone_name=BACKBONE)

## Callbacks

### Early Stopping

In [None]:
# Stop training once the validation loss has not improved for 10 epochs.
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
)

### Model Checkpoint

In [None]:
# Write the weights (not the full model) to disk, and only when the
# validation loss improves.
model_ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath='model_weights.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

### Custom Callback to Visualize Segmentation Masks using W&B

In [None]:
# Class names for the mask overlays; the list position is the class id.
segmentation_classes = ['issue', 'no issue']


def labels():
    """Return a fresh ``{class_id: class_name}`` dict for ``segmentation_classes``."""
    return dict(enumerate(segmentation_classes))

def wandb_mask(bg_img, pred_mask, true_mask):
    """Wrap an image and its two masks in a ``wandb.Image`` with mask overlays.

    Args:
        bg_img: background image shown underneath the masks.
        pred_mask: mask data logged under the "prediction" key.
        true_mask: mask data logged under the "ground truth" key.

    Returns:
        A ``wandb.Image`` whose ``masks`` carry both overlays, each with the
        class-label mapping from ``labels()``.
    """
    overlays = {
        title: {"mask_data": data, "class_labels": labels()}
        for title, data in (("prediction", pred_mask), ("ground truth", true_mask))
    }
    return wandb.Image(bg_img, masks=overlays)

In [None]:
class SemanticLogger(tf.keras.callbacks.Callback):
    """Log predicted and ground-truth masks for one fixed batch to W&B.

    A single batch is taken from ``dataloader`` at construction time and is
    re-predicted at the end of every epoch, so the logged images show how the
    model's output evolves on the same tiles.

    Args:
        dataloader: iterable yielding ``(images, masks)`` batches; only the
            first batch is used. Masks are expected to carry a trailing
            channel axis of size 1 (it is squeezed before logging).
    """

    def __init__(self, dataloader):
        super().__init__()
        self.val_images, self.val_masks = next(iter(dataloader))

    def on_epoch_end(self, epoch, logs=None):
        """Predict the stored batch and log image/mask overlays to W&B.

        Keras invokes this hook as ``on_epoch_end(epoch, logs)``; neither
        argument is used here.
        """
        probs = self.model.predict(self.val_images)

        # Turn model output into integer class ids. With a single output
        # channel an argmax over the channel axis is always 0, so threshold
        # the probability instead; argmax is only meaningful for >1 channels.
        if probs.shape[-1] == 1:
            pred_masks = (probs[..., 0] > 0.5).astype(np.uint8)
        else:
            pred_masks = np.argmax(probs, axis=-1).astype(np.uint8)

        # Images are rescaled to the uint8 range for display.
        val_images = tf.image.convert_image_dtype(self.val_images, tf.uint8).numpy()

        # Masks hold class ids, so cast them without rescaling:
        # convert_image_dtype would map 1.0 to 255, which is not a key of the
        # class-label mapping.
        val_masks = tf.cast(tf.squeeze(self.val_masks, axis=-1), tf.uint8).numpy()

        mask_list = [
            wandb_mask(image, predicted, truth)
            for image, predicted, truth in zip(val_images, pred_masks, val_masks)
        ]

        wandb.log({"predictions" : mask_list})

# Train Model Using W&B

In [None]:
# Compile with binary cross-entropy and track IoU alongside accuracy.
optimizer = 'adam'
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[sm.metrics.iou_score, 'accuracy'],
)

# Open the W&B run before any W&B callback is constructed.
wandb.init(project='HuBMAP')

# The training dataset repeats forever, so steps_per_epoch bounds each epoch.
# The validation dataset is built twice: once for Keras' own validation pass
# and once to supply the fixed batch visualised by SemanticLogger.
_ = model.fit(
    get_training_dataset(),
    validation_data=get_validation_dataset(),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    verbose=VERBOSE,
    callbacks=[
        early_stopper,
        model_ckpt,
        WandbCallback(),
        SemanticLogger(get_validation_dataset()),
    ],
)

# Close the W&B run.
wandb.finish()

In [None]:
# Save the whole model for submission without internet.
# First restore the weights from 'model_weights.h5' (the file the checkpoint
# callback writes whenever val_loss improves), then export the complete model
# to a single HDF5 file.
model.load_weights('model_weights.h5')
model.save('model.h5')