##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# RNGDet evaluation

## Setup

Select the GPU to use and import the necessary modules.

In [None]:
import os

# Enumerate CUDA devices in PCI bus order, so that index "1" refers to the same
# physical GPU that `nvidia-smi` lists as GPU 1. Then expose only that GPU.
# These variables must be set before TensorFlow initializes CUDA, which is why
# this cell runs before `import tensorflow`.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# If the `official` package (TF Model Garden) is not installed, append your
# local tf-models checkout to `sys.path` here before running the cells below.

In [None]:
import tensorflow as tf
import numpy as np

In [None]:
# NOTE(review): importing the task module presumably also registers the
# 'rngdet_cityscale' experiment with exp_factory — confirm before reordering.
from official.projects.rngdet.tasks import rngdet
from official.core import exp_factory

# Registered experiment config whose task definition is used to build the model.
EXPERIMENT_NAME = 'rngdet_cityscale'

exp_config = exp_factory.get_exp_config(EXPERIMENT_NAME)
task_obj = rngdet.RNGDetTask(exp_config.task)
model = task_obj.build_model()
# `task_obj.initialize(model)` is intentionally not called: the weights are
# restored from an explicit evaluation checkpoint in the following cell.

In [None]:
# Checkpoint to evaluate: either a directory (the latest checkpoint in it is
# used) or a checkpoint path prefix. Override with the RNGDET_CKPT_DIR
# environment variable instead of editing this machine-specific default.
ckpt_dir_or_file = os.environ.get(
    'RNGDET_CKPT_DIR', '/mnt/hdd-nfs-intern/ghpark/03_temp/ckpt/test_10')

if tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if ckpt_path is None:
        # latest_checkpoint returns None for a directory with no checkpoint;
        # fail here with a clear message rather than at the status assertion.
        raise FileNotFoundError(f'No checkpoint found in {ckpt_dir_or_file!r}')
else:
    ckpt_path = ckpt_dir_or_file

# The keyword names below must match those used when the checkpoint was saved.
ckpt = tf.train.Checkpoint(
    backbone=model.backbone,
    backbone_history=model.backbone_history,
    transformer=model.transformer,
    segment_fpn=model._segment_fpn,
    keypoint_fpn=model._keypoint_fpn,
    query_embeddings=model._query_embeddings,
    segment_head=model._segment_head,
    keypoint_head=model._keypoint_head,
    class_embed=model._class_embed,
    bbox_embed=model._bbox_embed,
    input_proj=model.input_proj)
status = ckpt.restore(ckpt_path)
# expect_partial(): don't warn about checkpoint values with no counterpart here.
# assert_existing_objects_matched(): every object tracked above must be restored.
status.expect_partial().assert_existing_objects_matched()
print(f'Loaded checkpoint: {ckpt_path}')

In [None]:
from PIL import Image
# Imported under an alias so the `agent` instance created below does not shadow
# the module (which would otherwise make `agent.Agent` unreachable afterwards).
from official.projects.rngdet.eval import agent as agent_lib

# NOTE(review): `pad_size` is not referenced by any cell shown here — confirm
# whether it is still needed.
pad_size = 128
SAT_IMAGE_PATH = './region_0_sat.png'

# Read the satellite tile; the context manager closes the file handle once the
# pixels have been copied into the numpy array.
with Image.open(SAT_IMAGE_PATH) as sat_pil:
    sat_image = np.array(sat_pil)

sat_image = tf.cast(sat_image, tf.float32)
agent = agent_lib.Agent(model, sat_image)

In [None]:
# Threshold passed to `agent.step` as `thr` when selecting predicted vertices.
logit_threshold = 0.75
# Border (in pixels) trimmed from each side of `agent.historical_map` before saving.
# NOTE(review): presumably matches the padding the agent adds around the tile — confirm.
roi_size = 128
# Optional cap on the number of steps (useful while debugging). None means run
# until the agent reports that it has finished the image.
max_steps = None
output_dir = './segmentation'

while True:
    agent.step_counter += 1

    # Crop the satellite and historical-map ROIs around the current coordinate,
    # add a leading batch axis, and divide by 255.0 (presumably 8-bit pixels -> [0, 1]).
    sat_ROI, historical_ROI = agent.crop_ROI(agent.current_coord)
    sat_ROI = tf.expand_dims(sat_ROI, 0) / 255.0
    # (gunho) The original implementation also divides historical_ROI by 255.0.
    historical_ROI = tf.expand_dims(historical_ROI, 0) / 255.0
    # Append a trailing axis (presumably the single channel) and ensure float32.
    historical_ROI = tf.expand_dims(historical_ROI, -1)
    historical_ROI = tf.cast(historical_ROI, tf.float32)

    # Predict the vertices for the next step; use the last entry of `outputs`.
    outputs, pred_segment, pred_keypoint = model(sat_ROI, historical_ROI, training=False)
    outputs = outputs[-1]
    pred_coords = outputs['box_outputs']
    pred_probs = outputs['cls_outputs']

    # Let the agent move according to the predictions above the threshold.
    pred_coords_ROI = agent.step(pred_probs, pred_coords, thr=logit_threshold)

    if agent.finish_current_image:
        print('STEP 3: Finish exploration. Save visualization and graph...')
        # The output directory may not exist on a fresh checkout.
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): only the visualization is written here; the graph itself is not saved.
        Image.fromarray(
            agent.historical_map[roi_size:-roi_size, roi_size:-roi_size].astype(np.uint8)
        ).convert('RGB').save(os.path.join(output_dir, '0_result.png'))
        break

    # Stop action: optional safety cap so a non-terminating exploration can be bounded.
    if max_steps is not None and agent.step_counter >= max_steps:
        print(f'Stopped after {agent.step_counter} steps without finishing the image.')
        break

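# TEMP

**TODO:** The notes below were carried over from a classification tutorial template. They do not apply to this notebook yet: `ds_info` is not defined here.

Use `ds_info` (an instance of `tfds.core.DatasetInfo`) to look up the text description of each class ID.

Run a batch of the processed training data through the model and view the results.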