In [1]:
import numpy as np
import os
from glob import glob
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import cv2
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.


In [2]:
tf.enable_eager_execution(device_policy=tfe.DEVICE_PLACEMENT_SILENT)

## Global Variables

In [3]:
DIR_TFRECORDS = 'data_small_tfrecords'
DATA_TRAIN = glob('./'+DIR_TFRECORDS+'/*.tfrecords')

NUM_OBJECTS = 20
MAX_DETECTIONS_PER_IMAGE = 20

GRID_H, GRID_W = 16, 16
GRID_SIZE = 608//GRID_H 

ANCHORS_NORMALIZED = np.array(
    [
        [0.09112895, 0.06958421],
        [0.21102316, 0.16803947],
        [0.42625895, 0.26609842],
        [0.25476474, 0.49848   ],
        [0.52668947, 0.59138947]
    ]
)
ANCHORS = ANCHORS_NORMALIZED * np.array([GRID_H, GRID_W])
NUM_ANCHORS = ANCHORS.shape[0]

IMG_H, IMG_W = GRID_H * GRID_SIZE, GRID_W * GRID_SIZE

COOEFFICIENT_OBJ = 1
COOEFFICIENT_NO_OBJ = 1
COOEFFICIENT_REG = 5

THRESHOLD_IOU_SCORES = 0.5
COEFF_LOSS_CONFIDENCE_OBJECT_PRESENT = 5
COEFF_LOSS_CONFIDENCE_OBJECT_ABSENT = 1
THRESHOLD_OUT_PROB = 0.5
THRESHOLD_IOU_NMS = 0.5

NUM_EPOCHS = 10 # not used
BATCH_SIZE = 5
CHECKPOINT_DIR = 'model'
CHECKPOINT_PREFIX = os.path.join(CHECKPOINT_DIR, "ckpt")
DIR_IMG_OUT = 'imgs_out'

DEBUG = False

## Utility

In [4]:
def sigmoid(x):
    return np.exp(x)/(1+np.exp(x))

def draw_adjusted_anchor(img, idx_h, idx_w, idx_a, label_bbox):
    # box_adjustment.shape = [5,]
    
    center_y = (label_bbox[0]+idx_h)*GRID_SIZE
    center_x = (label_bbox[1]+idx_w)*GRID_SIZE
    
    height = label_bbox[2] * ANCHORS[idx_a,0] * GRID_SIZE
    width = label_bbox[3] * ANCHORS[idx_a,1] * GRID_SIZE
     
    left = int(center_x - width/2)
    top = int(center_y - height/2)
    right = int(left + width)
    bottom = int(top + height)

    img = cv2.rectangle(img, (left, top), (right, bottom), color=(0, 255, 0), thickness=3)
    
    return img
    
def label(img, label):
    # unnormalize image
    img = (img * 255).astype(np.uint8)
    
    for idx_h in range(GRID_H):
        for idx_w in range(GRID_W):
            for idx_a in range(NUM_ANCHORS):
                if sigmoid(label[idx_h, idx_w, idx_a, 5]) > 0.1:
                    bbox = label[idx_h, idx_w, idx_a, :4]
                    img = draw_adjusted_anchor(img, idx_h, idx_w, idx_a, bbox)
    
    return img

def draw_output(img, output):
    # unnormalize image
    img = (img * 255).astype(np.uint8)
    
    output = output.astype(np.int32)
    for idx_box in range(output.shape[0]):
        bbox = output[idx_box]
        img = cv2.rectangle(img, (bbox[1], bbox[0]), (bbox[3], bbox[2]), color=(255, 0, 0), thickness=3)
    
    return img

def center2corner(predictions_yx, predictions_hw):
    # predictions_yx = [GRID_H, GRID_W, NUM_ANCHORS, 2]
    
    bbox_min = predictions_yx - (predictions_hw/2.)
    bbox_max = predictions_yx + (predictions_hw/2.)
    
    predictions_corner = tf.concat([bbox_min[...,0:1], bbox_min[...,1:2], bbox_max[...,0:1], bbox_max[...,1:2]], axis=-1)
    return predictions_corner

def get_filtered_predictions(predictions_corner, predictions_prob_obj, predictions_prob_class):
    # compute overall prob for each anchor in each grid
    predictions_prob = predictions_prob_obj * predictions_prob_class
    
    # get max prob among all classes at each anchor in each grid
    predictions_idx_class_max = tf.argmax(predictions_prob, axis=-1)
    predictions_prob = tf.reduce_max(predictions_prob, axis=-1)
    
    # compute filter mask
    mask_filter = predictions_prob >= THRESHOLD_OUT_PROB
    
    # apply mask on output
    bbox_filtered = tf.boolean_mask(predictions_corner, mask_filter)
    prob_filtered = tf.boolean_mask(predictions_prob, mask_filter)
    with tf.device('/cpu:0'):
        idx_class_filtered = tf.boolean_mask(predictions_idx_class_max, mask_filter)
    
    return bbox_filtered, prob_filtered, idx_class_filtered


def predictions2outputs(predictions):
    # apply corresponding transformations on predictions
    predictions_yx, predictions_hw, predictions_prob_obj, predictions_prob_class = apply_transformations(predictions)
    
    # map predictions_bbox to [0,1] space
    predictions_yx, predictions_hw = grid2normalized(predictions_yx, predictions_hw)
    
    # represent boxes using corners
    predictions_corner = center2corner(predictions_yx, predictions_hw)
    
    # filter predictions based on (prob_obj * prob_class). (needs to be done separately for each image in batch)
    bbox_filtered, prob_filtered, idx_class_filtered = get_filtered_predictions(predictions_corner, predictions_prob_obj, predictions_prob_class)
#     bbox_filtered, prob_filtered = get_filtered_predictions(predictions_corner, predictions_prob_obj, predictions_prob_class)
    # bbox_filtered.shape = [BATCH_SIZE, NUM_FILTERED, 4]
    
    # TODO: perform nms for each class separately
    # scale boxes from [0,1] to image space
    img_space = tf.reshape(tf.cast(tf.stack([IMG_H, IMG_W, IMG_H, IMG_W]), tf.float32), [1, 1, 4])
    bbox_filtered = tf.reshape(bbox_filtered, [-1, 4])  # tf.nms takes num_boxes (no batch support)
    
    # perform non-max suppression
    with tf.device('/cpu:0'):
        bbox_nms_indices = tf.image.non_max_suppression(bbox_filtered, tf.reshape(prob_filtered,[-1]), MAX_DETECTIONS_PER_IMAGE)
    bbox_nms = tf.gather(bbox_filtered, bbox_nms_indices)  # box_nms.shape = [len(bbox_nms_indices), 4]
    prob_nms = tf.expand_dims(tf.gather(prob_filtered, bbox_nms_indices), axis=-1) # prob_nms.shape = [len(bbox_nms_indices), 1]
    with tf.device('/cpu:0'):
        idx_class_nms = tf.expand_dims(tf.cast(tf.gather(idx_class_filtered, bbox_nms_indices), tf.float32), axis=-1)
    
    # concat return data
    output = tf.concat([bbox_nms, prob_nms, idx_class_nms], axis=-1)
#     output = tf.concat([bbox_nms, prob_nms], axis=-1)
    
    return tf.expand_dims(output, axis=0)

In [5]:
def parse_record(record):
    # dictionary as per saved TFRecord
    keys_to_features = {
        'img': tf.FixedLenFeature(shape=(), dtype=tf.string),
        'label': tf.FixedLenFeature(shape=(), dtype=tf.string),
    }

    # parse record
    parsed = tf.parse_single_example(record, keys_to_features)

    # decode image
    img = tf.decode_raw(parsed['img'], tf.uint8)
    img = tf.cast(tf.reshape(img, [IMG_H, IMG_W, 3]), tf.float32)
    img /= 255.  # normalize

    # decode label
    label = tf.decode_raw(parsed['label'], tf.float32)
    label = tf.reshape(label, [GRID_H, GRID_W, NUM_ANCHORS, 6])

    return img, label

In [6]:
def apply_transformations(predictions):
    predictions_yx = tf.sigmoid(predictions[..., 0:2])
    predictions_hw = tf.exp(predictions[...,2:4])
    predictions_prob_obj = tf.sigmoid(predictions[...,4:5])
    predictions_prob_class = tf.nn.softmax(predictions[...,5:])
    
    return predictions_yx, predictions_hw, predictions_prob_obj, predictions_prob_class

def get_coordinates(h, w):
    coordinates_y = tf.range(h)
    coordinates_x = tf.range(w)
    x, y = tf.meshgrid(coordinates_x, coordinates_y)
    coordinates = tf.stack([y,x], axis=-1)
    coordinates = tf.reshape(coordinates, [1, h, w, 1, 2])
    coordinates = tf.cast(coordinates, tf.float32)
    
    return coordinates

def grid2normalized(predictions_yx, predictions_hw):    
    # create cartesian coordinates on grid space
    coordinates = get_coordinates(GRID_H, GRID_W)
    
    # map from grid space to [0,19] space
    anchors = tf.cast(tf.reshape(ANCHORS, [1, 1, 1, ANCHORS.shape[0], 2]), dtype=tf.float32)  # [0,19] space
    predictions_yx += coordinates
    predictions_hw *= anchors
    
    # map from [0,19] space to [0,1] space
    shape = tf.cast(tf.reshape([GRID_H, GRID_W], [1, 1, 1, 1, 2]), tf.float32)
    predictions_yx /= shape
    predictions_hw /= shape
    
    return predictions_yx, predictions_hw

def get_boxes_gt(args_map):
    # extract ground truth bboxes wherever prob_obj = 1
    mask_object = tf.cast(tf.reshape(args_map[1], [GRID_H, GRID_W, NUM_ANCHORS]), tf.bool)
    bboxes = tf.boolean_mask(args_map[0], mask_object)
    # bboxes.shape = [NUM_DETECTIONS, 4]; NUM_DETECTIONS vary with each image
    
    # pad bboxes so that bboxes is fixed dimension (fix NUM_DETECTIONS to MAX_DETECTIONS_PER_IMAGE)
    pad = tf.zeros((MAX_DETECTIONS_PER_IMAGE - tf.shape(bboxes)[0], 4))  # TODO: when NUM_DETECTIONS > MAX_DETECTIONS_PER_IMAGE
    bboxes = tf.concat([bboxes, pad], axis=0)
    
    return bboxes

def get_iou_scores(predictions_yx, predictions_hw, bboxes_gt):
    # predictions_yx.shape = predictions_hw.shape = [BATCH_SIZE, GRID_H, GRID_W, NUM_ANCHORS, 2]
    # bboxes_gt.shape = [BATCH_SIZE, MAX_DETECTIONS_PER_IMAGE, 4]
    
    # compute ious for each anchor in each grid in axis=4
    predictions_yx = tf.expand_dims(predictions_yx, 4)
    predictions_hw = tf.expand_dims(predictions_hw, 4)
    
    predictions_min = predictions_yx - predictions_hw/2.
    predictions_max = predictions_yx + predictions_hw/2.
    
    bboxes_gt = tf.reshape(bboxes_gt, [tf.shape(bboxes_gt)[0], 1, 1, 1, MAX_DETECTIONS_PER_IMAGE, 4])
    bboxes_gt_yx = bboxes_gt[..., 0:2]
    bboxes_gt_hw = bboxes_gt[..., 2:4]
    
    bboxes_gt_min = bboxes_gt_yx - bboxes_gt_hw/2.
    bboxes_gt_max = bboxes_gt_yx + bboxes_gt_hw/2.
    
    intersection_min = tf.maximum(predictions_min, bboxes_gt_min)
    intersection_max = tf.minimum(predictions_max, bboxes_gt_max)
    intersection_hw = tf.maximum(intersection_max - intersection_min, 0.)
    area_intersection = intersection_hw[..., 0] * intersection_hw[..., 1]
    
    area_predictions = predictions_hw[...,0] * predictions_hw[...,1]
    area_bboxes_gt = bboxes_gt_hw[...,0] * bboxes_gt_hw[...,1]
    area_union = area_bboxes_gt + area_predictions - area_intersection
    iou = area_intersection / area_union
    
    return iou

def get_confidence_loss(labels_prob_obj, iou_mask, predictions_prob_obj):
    mask_object_absent = (1 - labels_prob_obj) * (1 - iou_mask)
    loss_object_absent = mask_object_absent * tf.square(predictions_prob_obj)
    
    loss_object_present = labels_prob_obj * tf.square(1-predictions_prob_obj)
    
    loss_confidence = COEFF_LOSS_CONFIDENCE_OBJECT_ABSENT * loss_object_absent \
            + COEFF_LOSS_CONFIDENCE_OBJECT_PRESENT * loss_object_present
    
    return tf.reduce_sum(loss_confidence)
    
def get_classification_loss(labels_prob_obj, labels_class, predictions_prob_class):
    labels_class = tf.cast(labels_class, tf.int32)
    labels_class = tf.one_hot(labels_class, NUM_OBJECTS)
    
    loss_classification = labels_prob_obj * tf.squared_difference(labels_class, predictions_prob_class)
    
    return tf.reduce_sum(loss_classification)

def get_regression_loss(labels_bbox, predictions_bbox, labels_prob_obj):
    loss_regression = labels_prob_obj * tf.squared_difference(labels_bbox,predictions_bbox)
    
    return tf.reduce_sum(loss_regression)

## Model

In [18]:
class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.optimizer = tf.train.AdamOptimizer()
        
        # add layers
        self.conv1 = tf.keras.layers.Conv2D(16, 3, padding='same')
        self.norm1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D()
        self.conv2 = tf.keras.layers.Conv2D(32, 3, padding='same')
        self.norm2 = tf.keras.layers.BatchNormalization()
        self.pool2 = tf.keras.layers.MaxPool2D()
        self.conv3 = tf.keras.layers.Conv2D(64, 3, padding='same')
        self.norm3 = tf.keras.layers.BatchNormalization()
        self.pool3 = tf.keras.layers.MaxPool2D()
        self.conv4 = tf.keras.layers.Conv2D(128, 3, padding='same')
        self.norm4 = tf.keras.layers.BatchNormalization()
        self.pool4 = tf.keras.layers.MaxPool2D()
        self.conv5 = tf.keras.layers.Conv2D(256, 3, padding='same')
        self.norm5 = tf.keras.layers.BatchNormalization()
        self.pool5 = tf.keras.layers.MaxPool2D()
        self.conv6 = tf.keras.layers.Conv2D(512, 3, padding='same')
        self.norm6 = tf.keras.layers.BatchNormalization()  # till here same as classification network
        
        self.conv7 = tf.keras.layers.Conv2D(1024, 3, padding='same')
        self.norm7 = tf.keras.layers.BatchNormalization()
        self.conv8 = tf.keras.layers.Conv2D(1024, 3, padding='same')
        self.norm8 = tf.keras.layers.BatchNormalization()  # extra layers for detection
        
        self.conv9 = tf.keras.layers.Conv2D(64, 1, padding='same')
        self.norm9 = tf.keras.layers.BatchNormalization()  # applies on skipped layer
        
        self.conv10 = tf.keras.layers.Conv2D(1024, 3, padding='same')  # applies on concatenated output
        self.norm10 = tf.keras.layers.BatchNormalization()
        self.conv11 = tf.keras.layers.Conv2D(NUM_ANCHORS*(5+NUM_OBJECTS), 1, padding='same')
        
#         self.dense1 = tf.keras.layers.Dense(256, activation=tf.nn.relu, kernel_initializer=tf.glorot_uniform_initializer())
#         self.dense2 = tf.keras.layers.Dense(256, activation=tf.nn.relu, kernel_initializer=tf.glorot_uniform_initializer())
#         self.dense3 = tf.keras.layers.Dense(GRID_H*GRID_W*NUM_ANCHORS*(5+NUM_OBJECTS))
        
    def forward(self, imgs):
        # imgs.shape = [B, IMG_H, IMG_W, 3]
        
        # for now, resize and reshape imgs to vector
        imgs = tf.image.resize_images(imgs, [512, 512])
        
        c1 = self.conv1(imgs)
        n1 = self.norm1(c1)
        a1 = tf.nn.leaky_relu(n1)
        p1 = self.pool1(a1)
        c2 = self.conv2(p1)
        n2 = self.norm2(c2)
        a2 = tf.nn.leaky_relu(n2)
        p2 = self.pool2(a2)
        c3 = self.conv3(p2)
        n3 = self.norm3(c3)
        a3 = tf.nn.leaky_relu(n3)
        p3 = self.pool3(a3)
        c4 = self.conv4(p3)
        n4 = self.norm4(c4)
        a4 = tf.nn.leaky_relu(n4)
        p4 = self.pool4(a4)
        c5 = self.conv5(p4)
        n5 = self.norm5(c5)
        a5 = tf.nn.leaky_relu(n5)
        p5 = self.pool5(a5)
        c6 = self.conv6(p5)
        n6 = self.norm6(c6)
        a6 = tf.nn.leaky_relu(n6)
        
        c7 = self.conv7(a6)
        n7 = self.norm7(c7)
        a7 = tf.nn.leaky_relu(n7)
        c8 = self.conv8(a7)
        n8 = self.norm8(c8)
        a8 = tf.nn.leaky_relu(n8)
        
        c9 = self.conv9(a4)  # change a to match dimension appropriately
        n9 = self.norm9(c9)
        a9 = tf.nn.leaky_relu(n9)
        
        a9_reshaped = tf.space_to_depth(a9, 4)
        a_8_9_concat = tf.concat([a8, a9_reshaped], axis=3)
        
        c10 = self.conv10(a_8_9_concat)
        n10 = self.norm10(c10)
        a10 = tf.nn.leaky_relu(n10)
        c11 = self.conv11(a10)
        
        # reshape output
        pred = tf.reshape(c11, [-1, GRID_H, GRID_W, NUM_ANCHORS, 5+NUM_OBJECTS])
        
        return pred
    
    def get_loss(self, predictions, labels):
        # predictions.shape = [BATCH_SIZE, GRID_H, GRID_W, NUM_ANCHORS, 5+NUM_OBJECTS] (they are in grid space)
        # labels.shape = [BATCH_SIZE, GRID_H, GRID_W, NUM_ANCHORS, 6]

        # apply corresponding transformations on predictions
        predictions_yx, predictions_hw, predictions_prob_obj, predictions_prob_class = apply_transformations(predictions)

        # map predictions_bbox to [0,1] space
        predictions_yx, predictions_hw = grid2normalized(predictions_yx, predictions_hw)

        # map labels_bbox to [0,1] space
        labels_yx, labels_hw = grid2normalized(labels[...,0:2], labels[...,2:4])

        # get ground truth bboxes using labels_bbox & prob_obj in labels
        labels_bbox = tf.concat([labels_yx, labels_hw], axis=-1)
        bboxes_gt = tf.map_fn(get_boxes_gt, (labels_bbox, labels[...,5]), dtype=tf.float32)

        # compute iou scores for each anchor in each grid for all bboxes_gt
        iou_scores = get_iou_scores(predictions_yx, predictions_hw, bboxes_gt)

        # keep anchors whose iou_scores are above THRESHOLD_IOU_SCORES
        iou_scores_best = tf.reduce_max(iou_scores, axis=4, keep_dims=True)
        iou_mask = tf.cast(iou_scores_best > THRESHOLD_IOU_SCORES, tf.float32)

        ## Loss
        # object confidence loss (presence and absence)
        loss_confidence = get_confidence_loss(labels[...,5:6], iou_mask, predictions_prob_obj)

        # classification loss
        loss_classification = get_classification_loss(labels[...,5:6], labels[...,4], predictions_prob_class)

        # regression loss
        predictions_bbox = tf.concat([predictions_yx, predictions_hw], axis=-1)
        loss_regression = get_regression_loss(labels_bbox, predictions_bbox, labels[...,5:6])

        # total loss
        loss = ( loss_confidence + loss_classification + loss_regression ) / tf.cast(tf.shape(labels)[0], tf.float32)

        return loss
    
    def train(self, dataset):
        '''trains the model for one epoch'''
        epoch_loss = tf.constant(0.)
        for idx_batch, data in enumerate(tfe.Iterator(dataset)):
            with tfe.GradientTape() as tape:
                # forward pass
                predictions = self.forward(data[0])

                # compute loss
                loss = self.get_loss(predictions, data[1])
                
            # backward pass (compute gradients)
            gradients = tape.gradient(loss, self.variables)
            
            # update parameters
            self.optimizer.apply_gradients(
                zip(gradients, self.variables), 
                global_step=tf.train.get_or_create_global_step()
            )
            
            epoch_loss += loss
            
        return (epoch_loss/(idx_batch+1)).numpy()
        
    def predict(self, imgs):
        '''predicts bboxes and draws them on the image'''
        # imgs.shape = [B, IMG_H, IMG_W, 3]
        
        # forward pass
        predictions = self.forward(imgs)
        
        # post-process to get bounding boxes
        outputs = predictions2outputs(predictions)  
        # CAUTION!!!
        # TODO: use batch multi-class nms (currently works with BATCH_SIZE=1)
        # reference: https://github.com/tensorflow/models/blob/master/research/object_detection/core/post_processing.py
        
        # draw outputs on the image
        with tf.device('/cpu:0'):
            imgs_out = tf.image.draw_bounding_boxes(imgs, outputs[..., 0:4])
        
        return imgs_out, outputs

## Train

In [8]:
# dataset processing
dataset_train = tf.data.TFRecordDataset(DATA_TRAIN)
dataset_train = dataset_train.map(parse_record)
dataset_train = dataset_train.shuffle(buffer_size=1024)
dataset_train = dataset_train.batch(BATCH_SIZE)

In [21]:
with tf.device('/gpu:0'):
    model = Model()

In [None]:
with tf.device('/gpu:0'):
    model.optimizer = tf.train.AdamOptimizer(0.0001)
    for i in range(100):
        loss = model.train(dataset_train)
        print('Epoch:{}, Loss={}'.format(i, loss))

Epoch:0, Loss=14.834146499633789
Epoch:1, Loss=14.711410522460938
Epoch:2, Loss=14.533823013305664
Epoch:3, Loss=14.395456314086914
Epoch:4, Loss=14.36840534210205
Epoch:5, Loss=14.37568473815918
Epoch:6, Loss=14.322985649108887
Epoch:7, Loss=14.25763988494873
Epoch:8, Loss=14.23637866973877
Epoch:9, Loss=14.185025215148926
Epoch:10, Loss=14.088035583496094
Epoch:11, Loss=14.121101379394531
Epoch:12, Loss=14.162786483764648
Epoch:13, Loss=14.016748428344727
Epoch:14, Loss=14.043904304504395
Epoch:15, Loss=13.952638626098633
Epoch:16, Loss=13.8106107711792
Epoch:17, Loss=14.192758560180664
Epoch:18, Loss=14.171308517456055
Epoch:19, Loss=14.117353439331055
Epoch:20, Loss=14.068041801452637
Epoch:21, Loss=14.01365852355957
Epoch:22, Loss=14.028668403625488
Epoch:23, Loss=13.999699592590332
Epoch:24, Loss=14.002134323120117
Epoch:25, Loss=13.957536697387695
Epoch:26, Loss=13.850397109985352
Epoch:27, Loss=13.74223518371582
Epoch:28, Loss=13.6453275680542
Epoch:29, Loss=13.593106269836426


In [11]:
checkpoint = tfe.Checkpoint(model=model, optimizer_step=tf.train.get_or_create_global_step())
checkpoint.save(file_prefix=CHECKPOINT_PREFIX)

'model/ckpt-1'

## Prediction

In [12]:
# dataset processing
dataset_test = tf.data.TFRecordDataset(DATA_TRAIN)
dataset_test = dataset_test.map(parse_record)
dataset_test = dataset_test.batch(1)

In [13]:
with tf.device('/gpu:0'):
    # load trained model
    checkpoint = tfe.Checkpoint(model=model, optimizer_step=tf.train.get_or_create_global_step())
    checkpoint.restore(tf.train.latest_checkpoint(CHECKPOINT_DIR))
    
    for idx_img, data in enumerate(tfe.Iterator(dataset_test)):
        # predict
        imgs_out, output = model.predict(data[0])
        
        # write images
        img = imgs_out.numpy()[0]
        img = (img * 255).astype(np.uint8)
        cv2.imwrite(DIR_IMG_OUT+ '/'+str(idx_img)+'.png', img)

a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16

a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16  16 125], shape=(4,), dtype=int32)
a8.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
a9.shape= tf.Tensor([ 1 64 64 64], shape=(4,), dtype=int32)
a9_reshaped.shape= tf.Tensor([   1   16   16 1024], shape=(4,), dtype=int32)
tf.Tensor([  1  16