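To run this notebook:
> 1. Switch to your CitC client: `g4d your_citc_client`<br>
> 2. Start the Colab TensorFlow runtime on port 8888: `/google/data/ro/teams/colab/tensorflow --port=8888`<br>
> 3. In the top right of this notebook, connect to the local runtime on port 8888.<br>
>
> Note: the import cell below loads `input_reader`, `mode_keys` and `params_dict` from the `multichannel_training` CitC client of user `salbro`; change those arguments to use your own client.<br>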

In [0]:
import os
import tensorflow.compat.v1 as tf
from google3.pyglib import flags
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt

from colabtools import adhoc_import

with adhoc_import.Google3CitcClient("multichannel_training", username="salbro", behavior="preferred", verbose=True):
  from dataloader import input_reader
  from dataloader import mode_keys as ModeKeys
  from hyperparameters import params_dict

# Dataloader

## Dataset-specific settings (edit these for your data).

In [0]:
# Extra (non-RGB) input channel: the feature key handed to the parser via
# `extra_channel_keys`, and the mean / standard deviation used to map the
# channel back to its original range when plotting.
extra_channel_key = "image/relative_altitude"
extra_channel_mean, extra_channel_std = 103, 76

# File pattern (glob) of the TFRecords read by the input pipeline.
_TRAIN_FILE_NAME = '/cns/ue-d/home/geo-imagery-lerna-models-dev/geo-sky-image-processing/mpallone/training_sets/eos/multichannel/geoalign.test.normalized.tfrecord*'

# Output image size, and the total number of classes (this count includes the
# background class, index 0).
output_size = [1024] * 2
num_classes = 2

## Model and input-pipeline parameters.

In [0]:
batch_size = 20
anchor_size = 4.0

# FPN level range and anchor layout. These values appear in the anchor config,
# the RetinaNet head and post-processing, so they are defined once here and
# referenced below instead of being repeated as literals.
min_level = 3
max_level = 7
num_scales = 3
aspect_ratios = [1.0, 2.0, 0.5]
# Anchors per feature-map location for the RetinaNet head:
# num_scales * len(aspect_ratios) == 9 with the values above.
anchors_per_location = num_scales * len(aspect_ratios)

max_num_instances = 100
skip_crowd_during_training = True
# NOTE(review): `use_autoaugment` is not referenced by `params` below; confirm
# whether it should be passed to the parser.
use_autoaugment = False

# Settings that appear in both the parser and the ShapeMask head configs.
use_bfloat16 = True
mask_crop_size = 32
upsample_factor = 4

SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
params = params_dict.ParamsDict({
    'type': 'shapemask',
    'train': {
        'total_steps': 45000,
        'learning_rate': {
            'learning_rate_steps': [30000, 40000],
        },
        'frozen_variable_prefix': SHAPEMASK_RESNET_FROZEN_VAR_PREFIX,
    },
    'eval': {
        'type': 'shapemask_box_and_mask',
        'mask_eval_class': 'all',  # 'all', 'voc', or 'nonvoc'.
    },
    'architecture': {
        'parser': 'shapemask_parser',
        'backbone': 'resnet',
        'multilevel_features': 'fpn',
        'use_bfloat16': use_bfloat16,
    },
    'anchor': {
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    },
    'shapemask_parser': {
        'output_size': output_size,
        'match_threshold': 0.5,
        'unmatched_threshold': 0.5,
        'aug_rand_hflip': True,
        'aug_scale_min': 0.8,
        'aug_scale_max': 1.2,
        'skip_crowd_during_training': skip_crowd_during_training,
        'max_num_instances': max_num_instances,
        'use_bfloat16': use_bfloat16,
        # Shapemask specific parameters
        'mask_train_class': 'all',  # 'all', 'voc', or 'nonvoc'.
        'use_category': True,
        'outer_box_scale': 1.25,
        'num_sampled_masks': 8,
        'mask_crop_size': mask_crop_size,
        'mask_min_level': 3,
        'mask_max_level': 5,
        'box_jitter_scale': 0.025,
        'upsample_factor': upsample_factor,
        'extra_channel_keys': [extra_channel_key],
    },
    'retinanet_head': {
        'min_level': min_level,
        'max_level': max_level,
        # Note that `num_classes` is the total number of classes including
        # one background class whose index is 0.
        'num_classes': num_classes,
        'anchors_per_location': anchors_per_location,
        'retinanet_head_num_convs': 4,
        'retinanet_head_num_filters': 256,
        'use_separable_conv': False,
        'use_batch_norm': True,
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'shapemask_head': {
        'num_classes': num_classes,
        'num_downsample_channels': 128,
        'mask_crop_size': mask_crop_size,
        'use_category_for_mask': False,
        'num_convs': 4,
        'upsample_factor': upsample_factor,
        'shape_prior_path': '',
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'retinanet_loss': {
        'num_classes': num_classes,
        'focal_loss_alpha': 0.4,
        'focal_loss_gamma': 1.5,
        'huber_loss_delta': 0.15,
        'box_loss_weight': 50,
    },
    'shapemask_loss': {
        'shape_prior_loss_weight': 0.1,
        'coarse_mask_loss_weight': 1.0,
        'fine_mask_loss_weight': 1.0,
    },
    'postprocess': {
        'min_level': min_level,
        'max_level': max_level,
    },
})

# Run a session to fetch one batch of images and labels from the dataset.

In [0]:
# Build the input pipeline in a fresh graph and fetch one batch.
# The graph and session are scoped to this cell. Re-running the cell therefore
# neither accumulates ops in the global default graph nor leaks an open
# session. The fetched numpy values stay available to later cells.
with tf.Graph().as_default(), tf.Session() as sess:
  input_fn = input_reader.InputFn(_TRAIN_FILE_NAME, params, ModeKeys.TRAIN)
  dataset = input_fn({'batch_size': batch_size})
  # `tf` is already `tensorflow.compat.v1`, so no extra `compat.v1` is needed.
  iterator = tf.data.make_one_shot_iterator(dataset)
  images_t, labels_t = iterator.get_next()
  # `images` is a batched numpy array. `labels` is a dict whose entries
  # (e.g. 'cls_targets') may themselves be per-level dicts of numpy arrays.
  images, labels = sess.run([images_t, labels_t])

# Sanity check: plot each image with its matched anchors and their assigned ground-truth boxes.

In [0]:
# Random display colors indexed by class id. An anchor is drawn with
# colormap[cls] and its assigned ground-truth box with colormap[cls + 1],
# hence num_classes + 2 rows. (+0.3, /1.3) keeps every component >= ~0.23 so
# no color is near-black.
colormap = (np.random.rand(num_classes + 2, 3) + 0.3) / 1.3

# Per-channel statistics used to undo the RGB normalization for display.
_RGB_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_RGB_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _denormalize_rgb(im):
  """Returns the first three channels of `im` denormalized and clipped to [0, 1]."""
  # Cast first: the parser may emit bfloat16 images.
  rgb = np.asarray(im[:, :, :3], dtype=np.float32) * _RGB_STD + _RGB_MEAN
  # imshow expects float RGB in [0, 1]; clip explicitly rather than relying on
  # matplotlib's clip-with-warning behavior.
  return np.clip(rgb, 0.0, 1.0)


def _denormalize_extra_channel(im):
  """Returns the first extra (4th) channel of `im` rescaled to [0, 255] as uint8."""
  extra = np.asarray(im[:, :, 3], dtype=np.float32)
  extra = extra * extra_channel_std + extra_channel_mean
  return np.clip(extra, 0, 255).astype(np.uint8)


def _draw_anchor_targets(ax, labels, i):
  """Draws the matched anchors of batch element `i` and their decoded GT boxes.

  Args:
    ax: matplotlib Axes to draw on.
    labels: the label dict fetched from the dataset. Its `cls_targets`,
      `box_targets` and `anchor_boxes` entries are indexed by FPN level.
    i: index of the image within the batch.
  """
  for level in range(min_level, max_level + 1):
    cls_targets = labels['cls_targets'][level]
    box_targets = labels['box_targets'][level]
    # Take the shape from this level's array; `labels['cls_targets']` itself
    # is a per-level mapping and has no `.shape`.
    batch, height, width, num_anchors = cls_targets.shape
    # Each anchor is laid out as [ymin, xmin, ymax, xmax].
    anchor_boxes = np.asarray(
        labels['anchor_boxes'][level], dtype=np.float32).reshape(
            batch, height, width, num_anchors, 4)
    # Anchors with a non-negative class target have an assigned box. Negative
    # values presumably mark unmatched/ignored anchors.
    for y, x, a in zip(*np.nonzero(cls_targets[i] >= 0)):
      cls = int(cls_targets[i, y, x, a])
      ymin_a, xmin_a, ymax_a, xmax_a = anchor_boxes[i, y, x, a]
      ha = ymax_a - ymin_a
      wa = xmax_a - xmin_a
      ax.add_patch(patches.Rectangle(
          (xmin_a, ymin_a), wa, ha,
          linewidth=1, edgecolor=colormap[cls], facecolor='none'))
      # Decode the (ty, tx, th, tw) regression target relative to the anchor.
      ty, tx, th, tw = np.asarray(
          box_targets[i, y, x, 4 * a:4 * (a + 1)], dtype=np.float32)
      box_h = np.exp(th) * ha
      box_w = np.exp(tw) * wa
      ycenter = ty * ha + (ymin_a + 0.5 * ha)
      xcenter = tx * wa + (xmin_a + 0.5 * wa)
      ax.add_patch(patches.Rectangle(
          (xcenter - box_w / 2., ycenter - box_h / 2.), box_w, box_h,
          linewidth=3, linestyle='-.', edgecolor=colormap[cls + 1],
          facecolor='none'))


for i, im in enumerate(images):
  # Show the RGB image and, when present, the extra channel side by side.
  # Drawn on the same axes, the second image would completely hide the first.
  # The anchor/target boxes are drawn on both panels.
  has_extra_channel = im.shape[-1] > 3
  num_panels = 2 if has_extra_channel else 1
  _, axes = plt.subplots(
      1, num_panels, figsize=[10 * num_panels, 10], squeeze=False)
  axes = axes[0]
  axes[0].imshow(_denormalize_rgb(im))
  axes[0].set_title('RGB')
  if has_extra_channel:
    axes[1].imshow(
        _denormalize_extra_channel(im), cmap='gray', vmin=0, vmax=255)
    axes[1].set_title(extra_channel_key)
  for ax in axes:
    _draw_anchor_targets(ax, labels, i)
  plt.show()