In [1]:
import os
import time
import json
import math
from tqdm import tqdm

import cv2
import numpy as np
import tensorflow as tf

# 데이터 전처리

In [2]:
def getUnionBBox(aBB, bBB, ih, iw):
    margin = 10
    return [max(0, min(aBB[0], bBB[0]) - margin),
             max(0, min(aBB[1], bBB[1]) - margin),
             min(iw, max(aBB[2], bBB[2]) + margin),
             min(ih, max(aBB[3], bBB[3]) + margin)]

In [3]:
def getAppr(im, bb):
    subim = im[bb[1] : bb[3], bb[0] : bb[2], :]
    subim = cv2.resize(subim, None, None, 224.0 / subim.shape[1], 224.0 / subim.shape[0], interpolation=cv2.INTER_LINEAR)
    subim = tf.keras.applications.vgg16.preprocess_input(subim)
    return subim

In [4]:
def getDualMask(ih, iw, bb):
    rh = 32.0 / ih
    rw = 32.0 / iw
    x1 = max(0, int(math.floor(bb[0] * rw)))
    x2 = min(32, int(math.ceil(bb[2] * rw)))
    y1 = max(0, int(math.floor(bb[1] * rh)))
    y2 = min(32, int(math.ceil(bb[3] * rh)))
    mask = np.zeros((32, 32))
    mask[y1 : y2, x1 : x2] = 1
    assert(mask.sum() == (y2 - y1) * (x2 - x1))
    return mask

In [5]:
def forward_batch(model, ims, poses, qas, qbs):
    test_set = []
    for i in range(ims.shape[0]):
        test_set.append({'qa': qas[i], 'qb': qbs[i], 'im': ims[i], 'posdata': poses[i]})

    test_elements = tuple(test_set)
    test_dataset = tf.data.Dataset.from_generator(
        lambda: test_elements, {'qa': tf.int32, 'qb': tf.int32, 'im': tf.float32, 'posdata': tf.float32}
    )
    test_dataset = test_dataset.cache().batch(ims.shape[0]).prefetch(buffer_size=AUTOTUNE)
    
    for sample in test_dataset:
        itr_pred = model(sample['qa'], sample['qb'], sample['im'], sample['posdata'])
    
    return itr_pred

In [6]:
def test_model(model, out_path):
    num_img = len(image_paths)
    num_class = 101
    thresh = 0.05
    batch_size = 20
    pred = []
    pred_bboxes = []

    for i in range(num_img):
        im = cv2.imread(image_paths[i]).astype(np.float32, copy=False)
        ih = im.shape[0]
        iw = im.shape[1]
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        pred.append([])
        pred_bboxes.append([])
        ims = []
        poses = []
        qas = []
        qbs = []
        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            rBB = getUnionBBox(sub, obj, ih, iw)
            rAppr = getAppr(im, rBB)
            rMask = np.array([getDualMask(ih, iw, sub), getDualMask(ih, iw, obj)])
            ims.append(rAppr)
            poses.append(rMask)
            qa = np.zeros(num_class - 1)
            qa[gts[j, 0] - 1] = 1
            qb = np.zeros(num_class - 1)
            qb[gts[j, 2] - 1] = 1
            qas.append(qa)
            qbs.append(qb)
        if len(ims) == 0:
            continue
        ims = np.array(ims)
        poses = np.array(poses)
        qas = np.array(qas)
        qbs = np.array(qbs)
        poses = poses.transpose((0, 2, 3, 1))
        _cursor = 0
        itr_pred = None
        num_ins = ims.shape[0]
        while _cursor < num_ins:
            _end_batch = min(_cursor + batch_size, num_ins)
            itr_pred_batch = forward_batch(model, ims[_cursor : _end_batch], poses[_cursor : _end_batch], qas[_cursor : _end_batch], qbs[_cursor : _end_batch])
            if itr_pred is None:
                itr_pred = itr_pred_batch
            else:
                itr_pred = np.vstack((itr_pred, itr_pred_batch))
            _cursor = _end_batch

        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            for k in range(itr_pred.shape[1]):
                if itr_pred[j, k] < thresh: 
                    continue
                pred[i].append([itr_pred[j, k], 1, 1, gts[j, 0], k, gts[j, 2]])
                pred_bboxes[i].append([sub, obj])
        pred[i] = np.array(pred[i])
        pred_bboxes[i] = np.array(pred_bboxes[i])

    print("writing file..")
    np.savez(out_path, pred=pred, pred_bboxes=pred_bboxes)

In [7]:
def computeArea(bb):
    return max(0, bb[2] - bb[0] + 1) * max(0, bb[3] - bb[1] + 1)

In [8]:
def computeIoU(bb1, bb2):
    ibb = [max(bb1[0], bb2[0]), \
        max(bb1[1], bb2[1]), \
        min(bb1[2], bb2[2]), \
        min(bb1[3], bb2[3])]
    iArea = computeArea(ibb)
    uArea = computeArea(bb1) + computeArea(bb2) - iArea
    return (iArea + 0.0) / uArea

In [9]:
def computeOverlap(detBBs, gtBBs):
    aIoU = computeIoU(detBBs[0, :], gtBBs[0, :])
    bIoU = computeIoU(detBBs[1, :], gtBBs[1, :])
    return min(aIoU, bIoU)

In [10]:
def eval_recall(det_file_path, num_dets=50, ov_thresh=0.5):
    det_file = np.load(det_file_path, allow_pickle=True)
    dets = det_file['pred']
    det_bboxes = det_file['pred_bboxes']
    num_img = len(dets)
    tp = []
    fp = []
    score = []
    total_num_gts = 0
    for i in range(num_img):
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        total_num_gts += num_gts
        gt_detected = np.zeros(num_gts)
        if isinstance(dets[i], np.ndarray) and dets[i].shape[0] > 0:
            det_score = np.log(dets[i][:, 0]) + np.log(dets[i][:, 1]) + np.log(dets[i][:, 2])
            inds = np.argsort(det_score)[::-1]
            if num_dets > 0 and num_dets < len(inds):
                inds = inds[:num_dets]
            top_dets = dets[i][inds, 3:]
            top_scores = det_score[inds]
            top_det_bboxes = det_bboxes[i][inds, :]
            temp_num_dets = len(inds)
            for j in range(temp_num_dets):
                ov_max = 0
                arg_max = -1
                for k in range(num_gts):
                    if gt_detected[k] == 0 and top_dets[j, 0] == gts[k, 0] and top_dets[j, 1] == gts[k, 1] and top_dets[j, 2] == gts[k, 2]:
                        ov = computeOverlap(top_det_bboxes[j, :, :], gt_bboxes[k, :, :])
                        if ov >= ov_thresh and ov > ov_max:
                            ov_max = ov
                            arg_max = k
                if arg_max != -1:
                    gt_detected[arg_max] = 1
                    tp.append(1)
                    fp.append(0)
                else:
                    tp.append(0)
                    fp.append(1)
                score.append(top_scores[j])
    score = np.array(score)
    tp = np.array(tp)
    fp = np.array(fp)
    inds = np.argsort(score)
    inds = inds[::-1]
    tp = tp[inds]
    fp = fp[inds]
    tp = np.cumsum(tp)
    fp = np.cumsum(fp)
    recall = (tp + 0.0) / total_num_gts
    top_recall = recall[-1]
    print('Recall:', top_recall)
    return top_recall

In [11]:
dataset = './reltrain.json'
nclass = 100

samples = json.load(open(dataset))
num_instance = len(samples)

In [12]:
qas = []
qbs = []
ims = []
poses = []
labels = []

for i in range(num_instance):
    sample = samples[i]
    im = cv2.imread(sample["imPath"]).astype(np.float32, copy=False)
    ih = im.shape[0]
    iw = im.shape[1]
    qa = np.zeros(nclass)
    qa[sample["aLabel"] - 1] = 1
    qas.append(qa)
    qb = np.zeros(nclass)
    qb[sample["bLabel"] - 1] = 1
    qbs.append(qb)
    ims.append(getAppr(im, sample["rBBox"]))
    poses.append([getDualMask(ih, iw, sample["aBBox"]), 
                  getDualMask(ih, iw, sample["bBBox"])])
    labels.append(sample["rLabel"])

In [13]:
poses = np.array(poses).transpose((0, 2, 3, 1))

qa = np.array(qas)
qb = np.array(qbs)
im = np.array(ims)
posdata = np.array(poses)
labels = np.array(labels)

In [14]:
train_set = []

for i in range(num_instance):
    train_set.append({'qa': qa[i], 'qb': qb[i], 'im': im[i], 'posdata': posdata[i], 'labels': labels[i]})
    
train_set, valid_set = train_set[:-7500], train_set[-7500:]
train_elements = tuple(train_set)
valid_elements = tuple(valid_set)

In [15]:
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_elements, {'qa': tf.int32, 'qb': tf.int32, 'im': tf.float32, 'posdata': tf.float32, 'labels': tf.int32}
)

In [16]:
valid_dataset = tf.data.Dataset.from_generator(
    lambda: valid_elements, {'qa': tf.int32, 'qb': tf.int32, 'im': tf.float32, 'posdata': tf.float32, 'labels': tf.int32}
)

In [17]:
for i, sample in enumerate(train_dataset.take(3)):
    print(i+1, ':', sample)

1 : {'qa': <tf.Tensor: id=73, shape=(100,), dtype=int32, numpy=
array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'qb': <tf.Tensor: id=74, shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])>, 'im': <tf.Tensor: id=70, shape=(224, 224, 3), dtype=float32, numpy=
array([[[127.84672  , 124.00672  , 131.32     ],
        [130.03198  , 127.19198  , 130.29099  ],
        [133.05893  , 129.12054  , 

In [18]:
for i, sample in enumerate(valid_dataset.take(3)):
    print(i+1, ':', sample)

1 : {'qa': <tf.Tensor: id=95, shape=(100,), dtype=int32, numpy=
array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'qb': <tf.Tensor: id=96, shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'im': <tf.Tensor: id=92, shape=(224, 224, 3), dtype=float32, numpy=
array([[[ -31.561043 ,  -72.40103  , -103.30203  ],
        [ -34.46836  ,  -75.26817  , -106.161194 ],
        [ -36.227394 ,  -75.0

In [19]:
BUFFER_SIZE = num_instance
BATCH_SIZE = 32
AUTOTUNE = tf.data.experimental.AUTOTUNE
EPOCHS = 100

In [20]:
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).prefetch(buffer_size=AUTOTUNE)
valid_dataset = valid_dataset.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

In [21]:
test_dataset = './reltest.json'
nclass = 100

test_samples = json.load(open(test_dataset))
test_num_instance = len(test_samples)

In [22]:
image_paths = []
gt_label = []
gt_box = []
j = 0

for i in range(test_num_instance):
    if(i == 0):
        gt_label.append([])
        gt_box.append([])
        img_path = test_samples[i]['imPath']
        image_paths.append(img_path)
        gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
        gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])
    else:
        if(img_path == test_samples[i]['imPath']):
            gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
            gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])
        else:
            j += 1
            gt_label.append([])
            gt_box.append([])
            img_path = test_samples[i]['imPath']
            image_paths.append(img_path)
            gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
            gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])

In [23]:
all_gts = np.array(gt_label)
all_gt_bboxes = np.array(gt_box)

# 모델

In [24]:
class AppearanceSubnet(tf.keras.layers.Layer):
    def __init__(self, input_shape):
        super(AppearanceSubnet, self).__init__()
        
        self.vgg16 = tf.keras.applications.vgg16.VGG16(include_top=False, 
                                                       weights='imagenet', 
                                                       input_shape=input_shape)
        
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Flatten(name='flat'),
            tf.keras.layers.Dense(4096, name='fc6'),
            tf.keras.layers.ReLU(name='relu6'),
            tf.keras.layers.Dense(4096, name='fc7'),
            tf.keras.layers.ReLU(name='relu7'),
            tf.keras.layers.Dense(256, kernel_initializer='glorot_normal', name='fc8'),
            tf.keras.layers.ReLU(name='relu8')
        ]) 
        
    def call(self, x):
        out = self.vgg16(x)
        out = self.fc(out)
        return out

In [25]:
class SpatialSubnet(tf.keras.layers.Layer):
    def __init__(self, input_shape):
        super(SpatialSubnet, self).__init__()
        
        self.conv = tf.keras.Sequential([
            tf.keras.layers.Conv2D(96, 5, 2, padding='same', input_shape=input_shape, name='conv1_p'),
            tf.keras.layers.ReLU(name='relu1_p'),
            tf.keras.layers.Conv2D(128, 5, 2, padding='same', name='conv2_p'),
            tf.keras.layers.Conv2D(64, 8, name='conv3_p'),
            tf.keras.layers.ReLU(name='relu3_p')
        ])
        
    def call(self, x):
        out = self.conv(x)
        return out

In [26]:
class CombineSubnets(tf.keras.layers.Layer):
    def __init__(self):
        super(CombineSubnets, self).__init__()
        
        self.concat1_c = tf.keras.layers.Concatenate(name='concat1_c')
        
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Dense(128, name='fc2_c'),
            tf.keras.layers.ReLU(name='relu2_c'),
            tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_0'),
            tf.keras.layers.ReLU(name='relu_0')
        ])
        
    def call(self, x1, x2):
        x2 = x2[:, 0, 0, :]
        out = self.concat1_c([x1, x2])
        out = self.fc(out) # qr0
        return out

In [27]:
class DRLayer(tf.keras.layers.Layer):
    def __init__(self, i, activation=True):
        super(DRLayer, self).__init__()
        
        self.activation = activation
        
        self.PhiA = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiA_%d'%(i)) # qar_i
        self.PhiB = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiB_%d'%(i)) # qbr_i
        self.PhiR = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_%d'%(i)) # q_i_r
        self.QSum = tf.keras.layers.Add(name='QSum_%d'%(i)) # qr_i_un
        if(activation == True):
            self.relu = tf.keras.layers.ReLU(name='relu_%d'%(i)) # qr_i
        
    def call(self, qa, qb, qr):
        qar = self.PhiA(qa)
        qbr = self.PhiB(qb)
        qr = self.PhiR(qr)
        qrun = self.QSum([qar, qbr, qr])
        if self.activation:
            qr = self.relu(qrun)
        else:
            qr = qrun
        return qr

In [28]:
class DRModule(tf.keras.layers.Layer):
    def __init__(self, num_layers):
        super(DRModule, self).__init__()
        
        self.num_layers = num_layers
        
        self.dr_layers = [DRLayer(i+1, activation=True) if((i+1) != num_layers) 
                          else DRLayer(i+1, activation=False) for i in range(num_layers)]
        
    def call(self, qa, qb, qr):
        for i in range(self.num_layers):
            qr = self.dr_layers[i](qa, qb, qr)
        return qr

In [29]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        self.dr = DRModule(num_layers=num_layers)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        qr = self.dr(qa, qb, qr0)
        out = self.softmax(qr)
        return out

# 학습

## Appearance + Spatial + DRNet

In [30]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        #self.dr = DRModule(num_layers=num_layers)
        self.temp_fc = tf.keras.layers.Dense(70)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        #qr = self.dr(qa, qb, qr0)
        qr = self.temp_fc(qr0)
        out = self.softmax(qr)
        return out

In [31]:
model_AS = DRNet()

In [32]:
for sample in train_dataset:
    print(model_AS(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

tf.Tensor(
[[1.13224905e-05 1.71477659e-04 1.29768625e-02 ... 7.77317109e-06
  1.07966225e-05 1.31354454e-05]
 [2.15396540e-07 5.25443602e-06 1.91818867e-06 ... 5.42237011e-09
  3.57638585e-10 5.70835934e-10]
 [9.02796726e-11 1.69140549e-07 7.02400257e-06 ... 1.63505698e-09
  4.73820450e-08 1.37312856e-10]
 ...
 [1.78591711e-06 8.85407701e-07 2.13635445e-04 ... 9.36024117e-06
  3.56996861e-05 6.40996223e-05]
 [2.36437132e-04 2.41724818e-04 9.74151073e-04 ... 1.62374854e-05
  2.08381622e-04 5.13813924e-04]
 [4.67778300e-05 8.90298470e-05 1.67208854e-02 ... 9.64279934e-06
  8.60984528e-08 3.33787757e-05]], shape=(32, 70), dtype=float32)


In [33]:
checkpoint_path = "./checkpoints/model_AS"
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
ckpt = tf.train.Checkpoint(model_AS=model_AS, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [34]:
model_AS.get_layer(index=0).weights

[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32, numpy=
 array([[[[ 4.27931070e-01,  1.07833758e-01,  3.07149403e-02, ...,
           -1.29441008e-01, -5.16037457e-02,  6.89771818e-03],
          [ 5.48374116e-01,  1.63840204e-02,  9.50021371e-02, ...,
           -8.09179693e-02, -5.09200953e-02,  3.40959579e-02],
          [ 4.77695614e-01, -1.72479630e-01,  3.30862179e-02, ...,
           -1.21648513e-01, -5.03711998e-02,  3.12613398e-02]],
 
         [[ 3.72379839e-01,  1.52337238e-01, -1.46048912e-03, ...,
           -1.45280331e-01, -2.33460948e-01, -6.33275658e-02],
          [ 4.38339412e-01,  4.21929546e-02,  4.76780683e-02, ...,
           -9.46645066e-02, -2.95681745e-01, -7.34420791e-02],
          [ 4.06389177e-01, -1.70360178e-01, -9.46670026e-03, ...,
           -1.16538994e-01, -2.76247591e-01, -4.24190387e-02]],
 
         [[-6.22575507e-02,  1.26073912e-01, -1.19190760e-01, ...,
           -1.37872234e-01, -3.75898689e-01, -3.01074803e-01],
    

In [35]:
model_AS.get_layer(index=1).weights

[<tf.Variable 'conv1_p/kernel:0' shape=(5, 5, 2, 96) dtype=float32, numpy=
 array([[[[-1.36488527e-01, -3.68125290e-02, -1.26451537e-01, ...,
            1.08364308e-02,  4.73632812e-02,  1.34240817e-02],
          [-4.63199876e-02,  2.65155938e-02, -7.69321844e-02, ...,
           -4.48953696e-02, -1.29076645e-01, -3.79951932e-02]],
 
         [[-1.67776972e-01, -1.25102345e-02, -1.10174783e-01, ...,
           -6.15776591e-02,  3.18346657e-02,  6.26979396e-02],
          [-5.74336648e-02,  1.13600865e-02, -8.50612298e-02, ...,
           -1.32284552e-01, -8.98456499e-02,  5.63021488e-02]],
 
         [[-8.72625038e-02,  1.47426249e-02, -4.51460630e-02, ...,
           -7.81078190e-02,  5.41246906e-02,  8.95245075e-02],
          [-8.66236910e-02, -2.71416791e-02, -4.55229916e-02, ...,
           -1.04289569e-01, -6.76627010e-02, -2.23814808e-02]],
 
         [[-5.31949103e-02, -2.55441368e-02,  3.55071947e-02, ...,
           -9.03653651e-02,  2.54687648e-02,  2.85620131e-02],
      

In [36]:
model_AS.get_layer(index=2).weights

[<tf.Variable 'dr_net/combine_subnets/sequential_2/fc2_c/kernel:0' shape=(320, 128) dtype=float32, numpy=
 array([[-0.00968383,  0.06317469,  0.07667089, ..., -0.11179165,
          0.01952122, -0.07292804],
        [ 0.07201405, -0.06636026,  0.10938871, ...,  0.04089434,
          0.07850075,  0.05342102],
        [-0.04323443, -0.0855433 ,  0.05487521, ...,  0.05189518,
          0.07151001, -0.00344929],
        ...,
        [-0.09635261, -0.11027512,  0.04752418, ...,  0.01985394,
          0.10196023,  0.02087229],
        [ 0.0611864 , -0.09159049,  0.03113977, ..., -0.07005785,
         -0.08391578, -0.11150547],
        [-0.03242181,  0.07993075, -0.08217543, ...,  0.05101888,
         -0.00510558,  0.09665211]], dtype=float32)>,
 <tf.Variable 'dr_net/combine_subnets/sequential_2/fc2_c/bias:0' shape=(128,) dtype=float32, numpy=
 array([ 0.00263214,  0.00660579, -0.01093696,  0.00015087,  0.00481015,
         0.00940789, -0.00241184, -0.00633316,  0.01254156,  0.00991397,
     

In [37]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        self.dr = DRModule(num_layers=num_layers)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        qr = self.dr(qa, qb, qr0)
        out = self.softmax(qr)
        return out

In [38]:
model_ASD = DRNet()

In [39]:
for sample in train_dataset:
    print(model_ASD(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

tf.Tensor(
[[0.01853327 0.01870873 0.01126556 ... 0.0066454  0.01288078 0.01492342]
 [0.02019287 0.01447443 0.01200012 ... 0.00857792 0.01615425 0.01016115]
 [0.01676575 0.01418035 0.01744283 ... 0.00786474 0.01867569 0.0142611 ]
 ...
 [0.01757145 0.01669572 0.02355957 ... 0.00433662 0.01471759 0.01294818]
 [0.01943001 0.01558659 0.01687654 ... 0.00838649 0.01629069 0.01635185]
 [0.02712214 0.01681428 0.0180939  ... 0.00900768 0.01106688 0.01019764]], shape=(32, 70), dtype=float32)


In [40]:
for i in range(len(model_ASD.get_layer(index=0).weights)):
    model_ASD.get_layer(index=0).weights[i].assign(model_AS.get_layer(index=0).weights[i])

In [41]:
model_ASD.get_layer(index=0).weights

[<tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32, numpy=
 array([[[[ 4.27931070e-01,  1.07833758e-01,  3.07149403e-02, ...,
           -1.29441008e-01, -5.16037457e-02,  6.89771818e-03],
          [ 5.48374116e-01,  1.63840204e-02,  9.50021371e-02, ...,
           -8.09179693e-02, -5.09200953e-02,  3.40959579e-02],
          [ 4.77695614e-01, -1.72479630e-01,  3.30862179e-02, ...,
           -1.21648513e-01, -5.03711998e-02,  3.12613398e-02]],
 
         [[ 3.72379839e-01,  1.52337238e-01, -1.46048912e-03, ...,
           -1.45280331e-01, -2.33460948e-01, -6.33275658e-02],
          [ 4.38339412e-01,  4.21929546e-02,  4.76780683e-02, ...,
           -9.46645066e-02, -2.95681745e-01, -7.34420791e-02],
          [ 4.06389177e-01, -1.70360178e-01, -9.46670026e-03, ...,
           -1.16538994e-01, -2.76247591e-01, -4.24190387e-02]],
 
         [[-6.22575507e-02,  1.26073912e-01, -1.19190760e-01, ...,
           -1.37872234e-01, -3.75898689e-01, -3.01074803e-01],
  

In [42]:
for i in range(len(model_ASD.get_layer(index=1).weights)):
    model_ASD.get_layer(index=1).weights[i].assign(model_AS.get_layer(index=1).weights[i])

In [43]:
model_ASD.get_layer(index=1).weights

[<tf.Variable 'conv1_p_1/kernel:0' shape=(5, 5, 2, 96) dtype=float32, numpy=
 array([[[[-1.36488527e-01, -3.68125290e-02, -1.26451537e-01, ...,
            1.08364308e-02,  4.73632812e-02,  1.34240817e-02],
          [-4.63199876e-02,  2.65155938e-02, -7.69321844e-02, ...,
           -4.48953696e-02, -1.29076645e-01, -3.79951932e-02]],
 
         [[-1.67776972e-01, -1.25102345e-02, -1.10174783e-01, ...,
           -6.15776591e-02,  3.18346657e-02,  6.26979396e-02],
          [-5.74336648e-02,  1.13600865e-02, -8.50612298e-02, ...,
           -1.32284552e-01, -8.98456499e-02,  5.63021488e-02]],
 
         [[-8.72625038e-02,  1.47426249e-02, -4.51460630e-02, ...,
           -7.81078190e-02,  5.41246906e-02,  8.95245075e-02],
          [-8.66236910e-02, -2.71416791e-02, -4.55229916e-02, ...,
           -1.04289569e-01, -6.76627010e-02, -2.23814808e-02]],
 
         [[-5.31949103e-02, -2.55441368e-02,  3.55071947e-02, ...,
           -9.03653651e-02,  2.54687648e-02,  2.85620131e-02],
    

In [44]:
for i in range(len(model_ASD.get_layer(index=2).weights)):
    model_ASD.get_layer(index=2).weights[i].assign(model_AS.get_layer(index=2).weights[i])

In [45]:
model_ASD.get_layer(index=2).weights

[<tf.Variable 'dr_net_1/combine_subnets_1/sequential_5/fc2_c/kernel:0' shape=(320, 128) dtype=float32, numpy=
 array([[-0.00968383,  0.06317469,  0.07667089, ..., -0.11179165,
          0.01952122, -0.07292804],
        [ 0.07201405, -0.06636026,  0.10938871, ...,  0.04089434,
          0.07850075,  0.05342102],
        [-0.04323443, -0.0855433 ,  0.05487521, ...,  0.05189518,
          0.07151001, -0.00344929],
        ...,
        [-0.09635261, -0.11027512,  0.04752418, ...,  0.01985394,
          0.10196023,  0.02087229],
        [ 0.0611864 , -0.09159049,  0.03113977, ..., -0.07005785,
         -0.08391578, -0.11150547],
        [-0.03242181,  0.07993075, -0.08217543, ...,  0.05101888,
         -0.00510558,  0.09665211]], dtype=float32)>,
 <tf.Variable 'dr_net_1/combine_subnets_1/sequential_5/fc2_c/bias:0' shape=(128,) dtype=float32, numpy=
 array([ 0.00263214,  0.00660579, -0.01093696,  0.00015087,  0.00481015,
         0.00940789, -0.00241184, -0.00633316,  0.01254156,  0.0099139

In [46]:
for sample in train_dataset:
    print(model_ASD(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

tf.Tensor(
[[0.01352933 0.0163089  0.01421448 ... 0.00958284 0.01884048 0.01371165]
 [0.01796558 0.01886468 0.01635736 ... 0.0092968  0.01349617 0.01184941]
 [0.01247367 0.01919288 0.01545339 ... 0.00713109 0.01541986 0.01276136]
 ...
 [0.01356694 0.0167219  0.01398242 ... 0.00963992 0.0191011  0.01341463]
 [0.01607051 0.01784141 0.02184272 ... 0.00677053 0.01672702 0.01202476]
 [0.01193113 0.01573065 0.01597536 ... 0.01227649 0.01436133 0.01025282]], shape=(32, 70), dtype=float32)


In [47]:
for layer in model_ASD.layers[:3]:
    layer.trainable = False

In [48]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')

In [49]:
checkpoint_path = "./checkpoints/model_ASD"
ckpt = tf.train.Checkpoint(model_ASD=model_ASD, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')



In [50]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss += loss_object(label, y)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, y)

In [51]:
@tf.function
def eval_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    y = model(qa, qb, im, posdata)
    loss += loss_object(label, y)
    
    eval_loss(loss)
    eval_accuracy(label, y)

In [52]:
def test_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    test_model(model, out_path)
    test_recall = eval_recall(out_path, num_dets, ov_thresh)
    return test_recall

In [53]:
max_acc = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()
    
    train_loss.reset_states()
    train_accuracy.reset_states()
    eval_loss.reset_states()
    eval_accuracy.reset_states()

    for sample in train_dataset:
        train_step(model_ASD, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
        
    for sample in valid_dataset:
        eval_step(model_ASD, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
    
    end = time.time()
    eval_acc = eval_accuracy.result()
    print ('Epoch {} Loss {:.4f} Accuracy {:.4f} Eval_Loss {:.4f} Eval_Accuracy {:.4f}'.format(epoch + 1, 
                                                                                               train_loss.result(), 
                                                                                               train_accuracy.result(),
                                                                                               eval_loss.result(),
                                                                                               eval_accuracy.result())) 
    
    print ('Time taken for 1 epoch: {} secs\n'.format(end - start))
    
    if(eval_acc > max_acc):
        max_acc = eval_acc
        early_stop_cnt = 0
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                             ckpt_save_path))
    else:
        early_stop_cnt += 1
        
    if early_stop_cnt == 20:
        break

  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 1 Loss 1.8735 Accuracy 0.4888 Eval_Loss 1.7536 Eval_Accuracy 0.4987
Time taken for 1 epoch: 59.33489274978638 secs



  1%|▊                                                                               | 1/100 [01:00<1:40:12, 60.74s/it]

Saving checkpoint for epoch 1 at ./checkpoints/model_ASD\ckpt-1
Epoch 2 Loss 1.5141 Accuracy 0.5551 Eval_Loss 1.6672 Eval_Accuracy 0.5123
Time taken for 1 epoch: 55.09110713005066 secs



  2%|█▌                                                                              | 2/100 [01:57<1:37:01, 59.40s/it]

Saving checkpoint for epoch 2 at ./checkpoints/model_ASD\ckpt-2
Epoch 3 Loss 1.4210 Accuracy 0.5702 Eval_Loss 1.6598 Eval_Accuracy 0.5275
Time taken for 1 epoch: 55.12201905250549 secs



  3%|██▍                                                                             | 3/100 [02:53<1:34:32, 58.48s/it]

Saving checkpoint for epoch 3 at ./checkpoints/model_ASD\ckpt-3
Epoch 4 Loss 1.3627 Accuracy 0.5815 Eval_Loss 1.6270 Eval_Accuracy 0.5288
Time taken for 1 epoch: 55.25356411933899 secs



  4%|███▏                                                                            | 4/100 [03:49<1:32:35, 57.87s/it]

Saving checkpoint for epoch 4 at ./checkpoints/model_ASD\ckpt-4


  5%|████                                                                            | 5/100 [04:45<1:30:21, 57.07s/it]

Epoch 5 Loss 1.3150 Accuracy 0.5894 Eval_Loss 1.6172 Eval_Accuracy 0.5252
Time taken for 1 epoch: 55.18105173110962 secs



  6%|████▊                                                                           | 6/100 [05:40<1:28:33, 56.52s/it]

Epoch 6 Loss 1.2728 Accuracy 0.5994 Eval_Loss 1.6616 Eval_Accuracy 0.5269
Time taken for 1 epoch: 55.22300744056702 secs



  7%|█████▌                                                                          | 7/100 [06:35<1:27:00, 56.13s/it]

Epoch 7 Loss 1.2367 Accuracy 0.6031 Eval_Loss 1.6559 Eval_Accuracy 0.5115
Time taken for 1 epoch: 55.1819703578949 secs



  8%|██████▍                                                                         | 8/100 [07:30<1:25:39, 55.86s/it]

Epoch 8 Loss 1.2028 Accuracy 0.6121 Eval_Loss 1.6838 Eval_Accuracy 0.5127
Time taken for 1 epoch: 55.21241331100464 secs



  9%|███████▏                                                                        | 9/100 [08:25<1:24:26, 55.68s/it]

Epoch 9 Loss 1.1770 Accuracy 0.6210 Eval_Loss 1.6837 Eval_Accuracy 0.5108
Time taken for 1 epoch: 55.231849193573 secs



 10%|███████▉                                                                       | 10/100 [09:21<1:23:18, 55.54s/it]

Epoch 10 Loss 1.1499 Accuracy 0.6249 Eval_Loss 1.6852 Eval_Accuracy 0.5204
Time taken for 1 epoch: 55.192997455596924 secs



 11%|████████▋                                                                      | 11/100 [10:16<1:22:14, 55.45s/it]

Epoch 11 Loss 1.1204 Accuracy 0.6312 Eval_Loss 1.6842 Eval_Accuracy 0.5225
Time taken for 1 epoch: 55.19548034667969 secs



 12%|█████████▍                                                                     | 12/100 [11:11<1:21:10, 55.35s/it]

Epoch 12 Loss 1.0957 Accuracy 0.6358 Eval_Loss 1.7426 Eval_Accuracy 0.5065
Time taken for 1 epoch: 55.09964394569397 secs



 13%|██████████▎                                                                    | 13/100 [12:06<1:20:08, 55.27s/it]

Epoch 13 Loss 1.0691 Accuracy 0.6448 Eval_Loss 1.7495 Eval_Accuracy 0.5096
Time taken for 1 epoch: 55.06776785850525 secs



 14%|███████████                                                                    | 14/100 [13:01<1:19:08, 55.21s/it]

Epoch 14 Loss 1.0431 Accuracy 0.6505 Eval_Loss 1.8203 Eval_Accuracy 0.5049
Time taken for 1 epoch: 55.04536509513855 secs



 15%|███████████▊                                                                   | 15/100 [13:56<1:18:10, 55.18s/it]

Epoch 15 Loss 1.0196 Accuracy 0.6552 Eval_Loss 1.8388 Eval_Accuracy 0.5117
Time taken for 1 epoch: 55.09011101722717 secs



 16%|████████████▋                                                                  | 16/100 [14:51<1:17:13, 55.16s/it]

Epoch 16 Loss 1.0000 Accuracy 0.6588 Eval_Loss 1.9451 Eval_Accuracy 0.5084
Time taken for 1 epoch: 55.07764959335327 secs



 17%|█████████████▍                                                                 | 17/100 [15:47<1:16:23, 55.22s/it]

Epoch 17 Loss 0.9684 Accuracy 0.6686 Eval_Loss 1.9141 Eval_Accuracy 0.5009
Time taken for 1 epoch: 55.32738661766052 secs



 18%|██████████████▏                                                                | 18/100 [16:42<1:15:39, 55.36s/it]

Epoch 18 Loss 0.9518 Accuracy 0.6742 Eval_Loss 1.9551 Eval_Accuracy 0.4961
Time taken for 1 epoch: 55.668256521224976 secs



 19%|███████████████                                                                | 19/100 [17:38<1:14:43, 55.35s/it]

Epoch 19 Loss 0.9313 Accuracy 0.6779 Eval_Loss 1.9566 Eval_Accuracy 0.4951
Time taken for 1 epoch: 55.312034368515015 secs



 20%|███████████████▊                                                               | 20/100 [18:33<1:13:43, 55.29s/it]

Epoch 20 Loss 0.9092 Accuracy 0.6842 Eval_Loss 2.0325 Eval_Accuracy 0.4995
Time taken for 1 epoch: 55.12065935134888 secs



 21%|████████████████▌                                                              | 21/100 [19:28<1:12:43, 55.23s/it]

Epoch 21 Loss 0.8915 Accuracy 0.6892 Eval_Loss 2.0402 Eval_Accuracy 0.4979
Time taken for 1 epoch: 55.07346487045288 secs



 22%|█████████████████▍                                                             | 22/100 [20:23<1:11:46, 55.21s/it]

Epoch 22 Loss 0.8726 Accuracy 0.6957 Eval_Loss 2.0472 Eval_Accuracy 0.4884
Time taken for 1 epoch: 55.11211919784546 secs



 23%|██████████████████▏                                                            | 23/100 [21:18<1:10:48, 55.18s/it]

Epoch 23 Loss 0.8508 Accuracy 0.6987 Eval_Loss 2.1888 Eval_Accuracy 0.4808
Time taken for 1 epoch: 55.07874870300293 secs



 23%|██████████████████▏                                                            | 23/100 [22:13<1:14:25, 58.00s/it]

Epoch 24 Loss 0.8320 Accuracy 0.7095 Eval_Loss 2.2668 Eval_Accuracy 0.4928
Time taken for 1 epoch: 55.112958908081055 secs






In [54]:
checkpoint_path = "./checkpoints/model_ASD"
ckpt = tf.train.Checkpoint(model_ASD=model_ASD, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [55]:
test_step(ckpt.model_ASD)

writing file..
Recall: 0.8173480083857443


0.8173480083857443