In [1]:
import os
import time
import json
import math
from tqdm import tqdm

import cv2
import numpy as np
import tensorflow as tf

# 데이터 전처리

In [2]:
def getUnionBBox(aBB, bBB, ih, iw):
    margin = 10
    return [max(0, min(aBB[0], bBB[0]) - margin),
            max(0, min(aBB[1], bBB[1]) - margin),
            min(iw, max(aBB[2], bBB[2]) + margin),
            min(ih, max(aBB[3], bBB[3]) + margin)]

In [3]:
def getAppr(im, bb):
    subim = im[bb[1] : bb[3], bb[0] : bb[2], :]
    subim = cv2.resize(subim, None, None, 224.0 / subim.shape[1], 224.0 / subim.shape[0], interpolation=cv2.INTER_LINEAR)
    #pixel_means = np.array([[[103.939, 116.779, 123.68]]])
    #subim -= pixel_means
    subim = subim / 255.0
    return subim

In [4]:
def getDualMask(ih, iw, bb):
    rh = 32.0 / ih
    rw = 32.0 / iw
    x1 = max(0, int(math.floor(bb[0] * rw)))
    x2 = min(32, int(math.ceil(bb[2] * rw)))
    y1 = max(0, int(math.floor(bb[1] * rh)))
    y2 = min(32, int(math.ceil(bb[3] * rh)))
    mask = np.zeros((32, 32))
    mask[y1 : y2, x1 : x2] = 1
    assert(mask.sum() == (y2 - y1) * (x2 - x1))
    return mask

In [5]:
def forward_batch(model, ims, poses, qas, qbs):
    test_set = []
    for i in range(ims.shape[0]):
        test_set.append({'qa': qas[i], 'qb': qbs[i], 'im': ims[i], 'posdata': poses[i]})

    test_elements = tuple(test_set)
    test_dataset = tf.data.Dataset.from_generator(
        lambda: test_elements, {'qa': tf.int32, 'qb': tf.int32, 'im': tf.float32, 'posdata': tf.float32}
    )
    test_dataset = test_dataset.cache().batch(ims.shape[0]).prefetch(buffer_size=AUTOTUNE)
    
    for sample in test_dataset:
        itr_pred = model(sample['qa'], sample['qb'], sample['im'], sample['posdata'])
    
    return itr_pred

In [6]:
def test_model(model, out_path):
    num_img = len(image_paths)
    num_class = 101
    thresh = 0.05
    batch_size = 20
    pred = []
    pred_bboxes = []

    for i in range(num_img):
        im = cv2.imread(image_paths[i]).astype(np.float32, copy=False)
        ih = im.shape[0]
        iw = im.shape[1]
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        pred.append([])
        pred_bboxes.append([])
        ims = []
        poses = []
        qas = []
        qbs = []
        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            rBB = getUnionBBox(sub, obj, ih, iw)
            rAppr = getAppr(im, rBB)
            rMask = np.array([getDualMask(ih, iw, sub), getDualMask(ih, iw, obj)])
            ims.append(rAppr)
            poses.append(rMask)
            qa = np.zeros(num_class - 1)
            qa[gts[j, 0] - 1] = 1
            qb = np.zeros(num_class - 1)
            qb[gts[j, 2] - 1] = 1
            qas.append(qa)
            qbs.append(qb)
        if len(ims) == 0:
            continue
        ims = np.array(ims)
        poses = np.array(poses)
        qas = np.array(qas)
        qbs = np.array(qbs)
        poses = poses.transpose((0, 2, 3, 1))
        _cursor = 0
        itr_pred = None
        num_ins = ims.shape[0]
        while _cursor < num_ins:
            _end_batch = min(_cursor + batch_size, num_ins)
            itr_pred_batch = forward_batch(model, ims[_cursor : _end_batch], poses[_cursor : _end_batch], qas[_cursor : _end_batch], qbs[_cursor : _end_batch])
            if itr_pred is None:
                itr_pred = itr_pred_batch
            else:
                itr_pred = np.vstack((itr_pred, itr_pred_batch))
            _cursor = _end_batch

        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            for k in range(itr_pred.shape[1]):
                if itr_pred[j, k] < thresh: 
                    continue
                pred[i].append([itr_pred[j, k], 1, 1, gts[j, 0], k, gts[j, 2]])
                pred_bboxes[i].append([sub, obj])
        pred[i] = np.array(pred[i])
        pred_bboxes[i] = np.array(pred_bboxes[i])

    print("writing file..")
    np.savez(out_path, pred=pred, pred_bboxes=pred_bboxes)

In [7]:
def computeArea(bb):
    return max(0, bb[2] - bb[0] + 1) * max(0, bb[3] - bb[1] + 1)

In [8]:
def computeIoU(bb1, bb2):
    ibb = [max(bb1[0], bb2[0]), \
        max(bb1[1], bb2[1]), \
        min(bb1[2], bb2[2]), \
        min(bb1[3], bb2[3])]
    iArea = computeArea(ibb)
    uArea = computeArea(bb1) + computeArea(bb2) - iArea
    return (iArea + 0.0) / uArea

In [9]:
def computeOverlap(detBBs, gtBBs):
    aIoU = computeIoU(detBBs[0, :], gtBBs[0, :])
    bIoU = computeIoU(detBBs[1, :], gtBBs[1, :])
    return min(aIoU, bIoU)

In [10]:
def eval_recall(det_file_path, num_dets=50, ov_thresh=0.5):
    det_file = np.load(det_file_path, allow_pickle=True)
    dets = det_file['pred']
    det_bboxes = det_file['pred_bboxes']
    num_img = len(dets)
    tp = []
    fp = []
    score = []
    total_num_gts = 0
    for i in range(num_img):
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        total_num_gts += num_gts
        gt_detected = np.zeros(num_gts)
        if isinstance(dets[i], np.ndarray) and dets[i].shape[0] > 0:
            det_score = np.log(dets[i][:, 0]) + np.log(dets[i][:, 1]) + np.log(dets[i][:, 2])
            inds = np.argsort(det_score)[::-1]
            if num_dets > 0 and num_dets < len(inds):
                inds = inds[:num_dets]
            top_dets = dets[i][inds, 3:]
            top_scores = det_score[inds]
            top_det_bboxes = det_bboxes[i][inds, :]
            num_dets = len(inds)
            for j in range(num_dets):
                ov_max = 0
                arg_max = -1
                for k in range(num_gts):
                    if gt_detected[k] == 0 and top_dets[j, 0] == gts[k, 0] and top_dets[j, 1] == gts[k, 1] and top_dets[j, 2] == gts[k, 2]:
                        ov = computeOverlap(top_det_bboxes[j, :, :], gt_bboxes[k, :, :])
                        if ov >= ov_thresh and ov > ov_max:
                            ov_max = ov
                            arg_max = k
                if arg_max != -1:
                    gt_detected[arg_max] = 1
                    tp.append(1)
                    fp.append(0)
                else:
                    tp.append(0)
                    fp.append(1)
                score.append(top_scores[j])
    score = np.array(score)
    tp = np.array(tp)
    fp = np.array(fp)
    inds = np.argsort(score)
    inds = inds[::-1]
    tp = tp[inds]
    fp = fp[inds]
    tp = np.cumsum(tp)
    fp = np.cumsum(fp)
    recall = (tp + 0.0) / total_num_gts
    top_recall = recall[-1]
    print('Recall:', top_recall)
    return top_recall

In [11]:
dataset = './reltrain.json'
nclass = 100

samples = json.load(open(dataset))
num_instance = len(samples)
name_to_top_map = {"qa": 0, "qb": 1, "im": 2, "posdata": 3, "labels": 4}

In [12]:
qas = []
qbs = []
ims = []
poses = []
labels = []

for i in range(num_instance):
    sample = samples[i]
    im = cv2.imread(sample["imPath"]).astype(np.float32, copy=False)
    ih = im.shape[0]
    iw = im.shape[1]
    qa = np.zeros(nclass)
    qa[sample["aLabel"] - 1] = 1
    qas.append(qa)
    qb = np.zeros(nclass)
    qb[sample["bLabel"] - 1] = 1
    qbs.append(qb)
    ims.append(getAppr(im, sample["rBBox"]))
    poses.append([getDualMask(ih, iw, sample["aBBox"]), 
                  getDualMask(ih, iw, sample["bBBox"])])
    labels.append(sample["rLabel"])

In [13]:
poses = np.array(poses).transpose((0, 2, 3, 1))

qa = np.array(qas)
qb = np.array(qbs)
im = np.array(ims)
posdata = np.array(poses)
labels = np.array(labels)

In [14]:
train_set = []

for i in range(num_instance):
    train_set.append({'qa': qa[i], 'qb': qb[i], 'im': im[i], 'posdata': posdata[i], 'labels': labels[i]})
    
train_elements = tuple(train_set)

In [15]:
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_elements, {'qa': tf.int32, 'qb': tf.int32, 'im': tf.float32, 'posdata': tf.float32, 'labels': tf.int32}
)

In [16]:
for i, sample in enumerate(train_dataset.take(3)):
    print(i+1, ':', sample)

1 : {'qa': <tf.Tensor: id=42, shape=(100,), dtype=int32, numpy=
array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'qb': <tf.Tensor: id=43, shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])>, 'im': <tf.Tensor: id=39, shape=(224, 224, 3), dtype=float32, numpy=
array([[[1.        , 0.94425774, 0.9089636 ],
        [0.99596465, 0.9567489 , 0.9175332 ],
        [1.        , 0.96431196, 0.929403

In [17]:
BUFFER_SIZE = num_instance
BATCH_SIZE = 32
AUTOTUNE = tf.data.experimental.AUTOTUNE
EPOCHS = 100

In [18]:
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).prefetch(buffer_size=AUTOTUNE)

In [19]:
test_dataset = './reltest.json'
nclass = 100

test_samples = json.load(open(test_dataset))
test_num_instance = len(test_samples)

In [20]:
image_paths = []
gt_label = []
gt_box = []
j = 0

for i in range(test_num_instance):
    if(i == 0):
        gt_label.append([])
        gt_box.append([])
        img_path = test_samples[i]['imPath']
        image_paths.append(img_path)
        gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
        gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])
    else:
        if(img_path == test_samples[i]['imPath']):
            gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
            gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])
        else:
            j += 1
            gt_label.append([])
            gt_box.append([])
            img_path = test_samples[i]['imPath']
            image_paths.append(img_path)
            gt_label[j].append([test_samples[i]['aLabel'], test_samples[i]['rLabel'], test_samples[i]['bLabel']])
            gt_box[j].append([test_samples[i]['aBBox'], test_samples[i]['bBBox']])

In [21]:
all_gts = np.array(gt_label)
all_gt_bboxes = np.array(gt_box)

# 모델

In [22]:
class AppearanceSubnet(tf.keras.layers.Layer):
    def __init__(self, input_shape):
        super(AppearanceSubnet, self).__init__()
        
        self.conv1 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, 3, padding='same', input_shape=input_shape, name='conv1_1'),
            tf.keras.layers.ReLU(name='relu1_1'),
            tf.keras.layers.Conv2D(64, 3, padding='same', name='conv1_2'),
            tf.keras.layers.ReLU(name='relu1_2'),
            tf.keras.layers.MaxPool2D(2, 2, name='pool1')
        ])
        
        self.conv2 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(128, 3, padding='same', name='conv2_1'),
            tf.keras.layers.ReLU(name='relu2_1'),
            tf.keras.layers.Conv2D(128, 3, padding='same', name='conv2_2'),
            tf.keras.layers.ReLU(name='relu2_2'),
            tf.keras.layers.MaxPool2D(2, 2, name='pool2')
        ])
        
        self.conv3 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, 3, padding='same', name='conv3_1'),
            tf.keras.layers.ReLU(name='relu3_1'),
            tf.keras.layers.Conv2D(256, 3, padding='same', name='conv3_2'),
            tf.keras.layers.ReLU(name='relu3_2'),
            tf.keras.layers.Conv2D(256, 3, padding='same', name='conv3_3'),
            tf.keras.layers.ReLU(name='relu3_3'),
            tf.keras.layers.MaxPool2D(2, 2, name='pool3')
        ])
        
        self.conv4 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv4_1'),
            tf.keras.layers.ReLU(name='relu4_1'),
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv4_2'),
            tf.keras.layers.ReLU(name='relu4_2'),
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv4_3'),
            tf.keras.layers.ReLU(name='relu4_3'),
            tf.keras.layers.MaxPool2D(2, 2, name='pool4')
        ])
        
        self.conv5 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv5_1'),
            tf.keras.layers.ReLU(name='relu5_1'),
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv5_2'),
            tf.keras.layers.ReLU(name='relu5_2'),
            tf.keras.layers.Conv2D(512, 3, padding='same', name='conv5_3'),
            tf.keras.layers.ReLU(name='relu5_3'),
            tf.keras.layers.MaxPool2D(2, 2, name='pool5')
        ])
        
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Flatten(name='flat'),
            tf.keras.layers.Dense(4096, name='fc6'),
            tf.keras.layers.ReLU(name='relu6'),
            tf.keras.layers.Dense(4096, name='fc7'),
            tf.keras.layers.ReLU(name='relu7'),
            tf.keras.layers.Dense(256, kernel_initializer='glorot_normal', name='fc8'),
            tf.keras.layers.ReLU(name='relu8')
        ]) 
        
    def call(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.fc(out)
        return out

In [23]:
class SpatialSubnet(tf.keras.layers.Layer):
    def __init__(self, input_shape):
        super(SpatialSubnet, self).__init__()
        
        self.conv = tf.keras.Sequential([
            tf.keras.layers.Conv2D(96, 5, 2, padding='same', input_shape=input_shape, name='conv1_p'),
            tf.keras.layers.ReLU(name='relu1_p'),
            tf.keras.layers.Conv2D(128, 5, 2, padding='same', name='conv2_p'),
            tf.keras.layers.Conv2D(64, 8, name='conv3_p'),
            tf.keras.layers.ReLU(name='relu3_p')
        ])
        
    def call(self, x):
        out = self.conv(x)
        return out

In [24]:
class CombineSubnets(tf.keras.layers.Layer):
    def __init__(self):
        super(CombineSubnets, self).__init__()
        
        self.concat1_c = tf.keras.layers.Concatenate(name='concat1_c')
        
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Dense(128, name='fc2_c'),
            tf.keras.layers.ReLU(name='relu2_c'),
            tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_0'),
            tf.keras.layers.ReLU(name='relu_0')
        ])
        
    def call(self, x1, x2):
        x2 = x2[:, 0, 0, :]
        out = self.concat1_c([x1, x2])
        out = self.fc(out) # qr0
        return out

In [25]:
class DRLayer(tf.keras.layers.Layer):
    def __init__(self, i, activation=True):
        super(DRLayer, self).__init__()
        
        self.activation = activation
        
        self.PhiA = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiA_%d'%(i)) # qar_i
        self.PhiB = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiB_%d'%(i)) # qbr_i
        self.PhiR = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_%d'%(i)) # q_i_r
        self.QSum = tf.keras.layers.Add(name='QSum_%d'%(i)) # qr_i_un
        if(activation == True):
            self.relu = tf.keras.layers.ReLU(name='relu_%d'%(i)) # qr_i
        
    def call(self, qa, qb, qr):
        qar = self.PhiA(qa)
        qbr = self.PhiB(qb)
        qr = self.PhiR(qr)
        qrun = self.QSum([qar, qbr, qr])
        if self.activation:
            qr = self.relu(qrun)
        else:
            qr = qrun
        return qr

In [26]:
class DRModule(tf.keras.layers.Layer):
    def __init__(self, num_layers):
        super(DRModule, self).__init__()
        
        self.num_layers = num_layers
        
        self.dr_layers = [DRLayer(i+1, activation=True) if((i+1) != num_layers) 
                          else DRLayer(i+1, activation=False) for i in range(num_layers)]
        
    def call(self, qa, qb, qr):
        for i in range(self.num_layers):
            qr = self.dr_layers[i](qa, qb, qr)
        return qr

In [27]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        self.dr = DRModule(num_layers=num_layers)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        qr = self.dr(qa, qb, qr0)
        out = self.softmax(qr)
        return out

# 학습

## Appearance

In [28]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        #self.spatial = SpatialSubnet(input_shape=posdata_shape)
        #self.combine = CombineSubnets()
        #self.dr = DRModule(num_layers=num_layers)
        self.temp_fc = tf.keras.layers.Dense(70)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        #conv3_p = self.spatial(posdata)
        #qr0 = self.combine(fc8, conv3_p)
        #qr = self.dr(qa, qb, qr0)
        qr = self.temp_fc(fc8)
        out = self.softmax(qr)
        return out

In [29]:
model_A = DRNet()

In [30]:
for sample in train_dataset:
    print(model_A(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

tf.Tensor(
[[0.0142784  0.01428739 0.01426942 ... 0.01428815 0.01427437 0.0142937 ]
 [0.0142751  0.01428419 0.01427061 ... 0.01429363 0.01425983 0.01430507]
 [0.01427801 0.0142874  0.0142685  ... 0.0142898  0.0142664  0.01429612]
 ...
 [0.0142755  0.01429006 0.01427076 ... 0.01429318 0.01426644 0.01430127]
 [0.01427758 0.01428216 0.01427443 ... 0.01429116 0.01427181 0.01429635]
 [0.01427878 0.01428629 0.01426942 ... 0.01428832 0.01426927 0.01429617]], shape=(32, 70), dtype=float32)


In [31]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [32]:
checkpoint_path = "./checkpoints/model_A"
ckpt = tf.train.Checkpoint(model_A=model_A, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [33]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss += loss_object(label, y)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, y)

In [34]:
def eval_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    test_model(model, out_path)
    test_recall = eval_recall(out_path, num_dets, ov_thresh)
    return test_recall

In [35]:
max_recall = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()
    
    train_loss.reset_states()
    train_accuracy.reset_states()

    for sample in train_dataset:
        train_step(model_A, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
        
    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))
    
    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    
    if (epoch+1) % 5 == 0:
        test_recall = eval_step(model_A)
        if(test_recall > max_recall):
            max_recall = test_recall
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))
        else:
            early_stop_cnt += 1
        
    if early_stop_cnt == 10:
        break

  1%|▊                                                                              | 1/100 [02:54<4:47:48, 174.43s/it]

Epoch 1 Loss 6.9940 Accuracy 0.1605
Time taken for 1 epoch: 174.43052291870117 secs



  2%|█▌                                                                             | 2/100 [05:45<4:43:23, 173.51s/it]

Epoch 2 Loss 2.8981 Accuracy 0.1660
Time taken for 1 epoch: 171.35398650169373 secs



  3%|██▎                                                                            | 3/100 [08:37<4:39:35, 172.94s/it]

Epoch 3 Loss 2.8968 Accuracy 0.1667
Time taken for 1 epoch: 171.61574745178223 secs



  4%|███▏                                                                           | 4/100 [11:29<4:36:04, 172.54s/it]

Epoch 4 Loss 2.8951 Accuracy 0.1656
Time taken for 1 epoch: 171.61362886428833 secs

Epoch 5 Loss 2.8936 Accuracy 0.1659
Time taken for 1 epoch: 171.62937903404236 secs

writing file..
Recall: 0.16483228511530398


  5%|███▉                                                                           | 5/100 [18:51<6:41:31, 253.60s/it]

Saving checkpoint for epoch 5 at ./checkpoints/model_A\ckpt-1


  6%|████▋                                                                          | 6/100 [21:42<5:58:26, 228.79s/it]

Epoch 6 Loss 2.8932 Accuracy 0.1670
Time taken for 1 epoch: 170.91446447372437 secs



  7%|█████▌                                                                         | 7/100 [24:34<5:28:01, 211.63s/it]

Epoch 7 Loss 2.8933 Accuracy 0.1672
Time taken for 1 epoch: 171.57035994529724 secs



  8%|██████▎                                                                        | 8/100 [27:25<5:06:07, 199.64s/it]

Epoch 8 Loss 2.8933 Accuracy 0.1676
Time taken for 1 epoch: 171.68561172485352 secs



  9%|███████                                                                        | 9/100 [30:17<4:50:05, 191.26s/it]

Epoch 9 Loss 2.8928 Accuracy 0.1675
Time taken for 1 epoch: 171.71502208709717 secs

Epoch 10 Loss 2.8926 Accuracy 0.1678
Time taken for 1 epoch: 171.711040019989 secs

writing file..


 10%|███████▊                                                                      | 10/100 [37:24<6:33:00, 262.00s/it]

Recall: 0.1501572327044025


 11%|████████▌                                                                     | 11/100 [40:18<5:49:21, 235.53s/it]

Epoch 11 Loss 2.8934 Accuracy 0.1669
Time taken for 1 epoch: 173.7468295097351 secs



 12%|█████████▎                                                                    | 12/100 [43:10<5:17:19, 216.36s/it]

Epoch 12 Loss 2.8920 Accuracy 0.1685
Time taken for 1 epoch: 171.62308049201965 secs



 13%|██████████▏                                                                   | 13/100 [46:01<4:54:13, 202.91s/it]

Epoch 13 Loss 2.8929 Accuracy 0.1671
Time taken for 1 epoch: 171.51315999031067 secs



 14%|██████████▉                                                                   | 14/100 [48:53<4:37:20, 193.49s/it]

Epoch 14 Loss 2.8919 Accuracy 0.1679
Time taken for 1 epoch: 171.51588940620422 secs

Epoch 15 Loss 2.8927 Accuracy 0.1675
Time taken for 1 epoch: 171.49336647987366 secs

writing file..
Recall: 0.1806865828092243


 15%|███████████▋                                                                  | 15/100 [56:23<6:23:25, 270.65s/it]

Saving checkpoint for epoch 15 at ./checkpoints/model_A\ckpt-2


 16%|████████████▍                                                                 | 16/100 [59:15<5:37:14, 240.89s/it]

Epoch 16 Loss 2.8923 Accuracy 0.1663
Time taken for 1 epoch: 171.43409132957458 secs



 17%|████████████▉                                                               | 17/100 [1:02:06<5:04:26, 220.08s/it]

Epoch 17 Loss 2.8914 Accuracy 0.1680
Time taken for 1 epoch: 171.51330375671387 secs



 18%|█████████████▋                                                              | 18/100 [1:04:58<4:40:51, 205.51s/it]

Epoch 18 Loss 2.8913 Accuracy 0.1680
Time taken for 1 epoch: 171.50398302078247 secs



 19%|██████████████▍                                                             | 19/100 [1:07:49<4:23:40, 195.32s/it]

Epoch 19 Loss 2.8908 Accuracy 0.1679
Time taken for 1 epoch: 171.5336091518402 secs

Epoch 20 Loss 2.8907 Accuracy 0.1667
Time taken for 1 epoch: 171.51616263389587 secs

writing file..


 20%|███████████████▏                                                            | 20/100 [1:15:29<6:06:16, 274.71s/it]

Recall: 0.1669287211740042


 21%|███████████████▉                                                            | 21/100 [1:18:20<5:20:47, 243.64s/it]

Epoch 21 Loss 2.8908 Accuracy 0.1676
Time taken for 1 epoch: 171.11683678627014 secs



 22%|████████████████▋                                                           | 22/100 [1:21:12<4:48:37, 222.01s/it]

Epoch 22 Loss 2.8909 Accuracy 0.1685
Time taken for 1 epoch: 171.56145477294922 secs



 23%|█████████████████▍                                                          | 23/100 [1:24:04<4:25:30, 206.88s/it]

Epoch 23 Loss 2.8905 Accuracy 0.1681
Time taken for 1 epoch: 171.5800838470459 secs



 24%|██████████████████▏                                                         | 24/100 [1:26:55<4:08:38, 196.29s/it]

Epoch 24 Loss 2.8903 Accuracy 0.1678
Time taken for 1 epoch: 171.57830357551575 secs

Epoch 25 Loss 2.8907 Accuracy 0.1684
Time taken for 1 epoch: 171.57793474197388 secs

writing file..


 25%|███████████████████                                                         | 25/100 [1:34:20<5:38:29, 270.79s/it]

Recall: 0.16627358490566038


 26%|███████████████████▊                                                        | 26/100 [1:37:12<4:57:30, 241.22s/it]

Epoch 26 Loss 2.8897 Accuracy 0.1684
Time taken for 1 epoch: 172.21651577949524 secs



 27%|████████████████████▌                                                       | 27/100 [1:40:04<4:28:17, 220.52s/it]

Epoch 27 Loss 2.8896 Accuracy 0.1676
Time taken for 1 epoch: 172.12737202644348 secs



 28%|█████████████████████▎                                                      | 28/100 [1:42:56<4:07:09, 205.96s/it]

Epoch 28 Loss 2.8898 Accuracy 0.1672
Time taken for 1 epoch: 171.93139338493347 secs



 29%|██████████████████████                                                      | 29/100 [1:45:48<3:51:38, 195.75s/it]

Epoch 29 Loss 2.8899 Accuracy 0.1684
Time taken for 1 epoch: 171.8504922389984 secs

Epoch 30 Loss 2.8899 Accuracy 0.1669
Time taken for 1 epoch: 171.635427236557 secs

writing file..


 30%|██████████████████████▊                                                     | 30/100 [1:53:10<5:14:31, 269.59s/it]

Recall: 0.1669287211740042


 31%|███████████████████████▌                                                    | 31/100 [1:56:01<4:35:56, 239.95s/it]

Epoch 31 Loss 2.8894 Accuracy 0.1680
Time taken for 1 epoch: 170.79392910003662 secs



 32%|████████████████████████▎                                                   | 32/100 [1:58:51<4:08:07, 218.93s/it]

Epoch 32 Loss 2.8894 Accuracy 0.1684
Time taken for 1 epoch: 169.87082982063293 secs



 33%|█████████████████████████                                                   | 33/100 [2:01:42<3:48:37, 204.74s/it]

Epoch 33 Loss 2.8886 Accuracy 0.1679
Time taken for 1 epoch: 171.61904525756836 secs



 34%|█████████████████████████▊                                                  | 34/100 [2:04:34<3:34:16, 194.80s/it]

Epoch 34 Loss 2.8894 Accuracy 0.1677
Time taken for 1 epoch: 171.6206510066986 secs

Epoch 35 Loss 2.8893 Accuracy 0.1673
Time taken for 1 epoch: 171.62686443328857 secs

writing file..


 35%|██████████████████████████▌                                                 | 35/100 [2:12:07<4:55:01, 272.32s/it]

Recall: 0.16483228511530398


 36%|███████████████████████████▎                                                | 36/100 [2:14:58<4:18:07, 242.00s/it]

Epoch 36 Loss 2.8892 Accuracy 0.1672
Time taken for 1 epoch: 171.22217464447021 secs



 37%|████████████████████████████                                                | 37/100 [2:17:51<3:52:05, 221.04s/it]

Epoch 37 Loss 2.8886 Accuracy 0.1679
Time taken for 1 epoch: 172.06705856323242 secs



 38%|████████████████████████████▉                                               | 38/100 [2:20:43<3:33:12, 206.33s/it]

Epoch 38 Loss 2.8891 Accuracy 0.1676
Time taken for 1 epoch: 171.91234183311462 secs



 39%|█████████████████████████████▋                                              | 39/100 [2:23:34<3:19:16, 196.01s/it]

Epoch 39 Loss 2.8886 Accuracy 0.1684
Time taken for 1 epoch: 171.86379289627075 secs

Epoch 40 Loss 2.8884 Accuracy 0.1675
Time taken for 1 epoch: 171.8562777042389 secs

writing file..


 40%|██████████████████████████████▍                                             | 40/100 [2:31:02<4:31:21, 271.36s/it]

Recall: 0.16627358490566038


 41%|███████████████████████████████▏                                            | 41/100 [2:33:54<3:57:44, 241.77s/it]

Epoch 41 Loss 2.8890 Accuracy 0.1683
Time taken for 1 epoch: 172.70634818077087 secs



 42%|███████████████████████████████▉                                            | 42/100 [2:36:46<3:33:30, 220.87s/it]

Epoch 42 Loss 2.8889 Accuracy 0.1684
Time taken for 1 epoch: 172.03401231765747 secs



 43%|████████████████████████████████▋                                           | 43/100 [2:39:40<3:16:25, 206.77s/it]

Epoch 43 Loss 2.8887 Accuracy 0.1679
Time taken for 1 epoch: 173.7770836353302 secs



 44%|█████████████████████████████████▍                                          | 44/100 [2:42:32<3:03:15, 196.34s/it]

Epoch 44 Loss 2.8886 Accuracy 0.1684
Time taken for 1 epoch: 171.9303469657898 secs

Epoch 45 Loss 2.8884 Accuracy 0.1684
Time taken for 1 epoch: 173.27389788627625 secs

writing file..


 45%|██████████████████████████████████▏                                         | 45/100 [2:50:37<4:19:10, 282.75s/it]

Recall: 0.16627358490566038


 46%|██████████████████████████████████▉                                         | 46/100 [2:53:27<3:44:15, 249.17s/it]

Epoch 46 Loss 2.8884 Accuracy 0.1680
Time taken for 1 epoch: 170.8151113986969 secs



 47%|███████████████████████████████████▋                                        | 47/100 [2:56:19<3:19:28, 225.81s/it]

Epoch 47 Loss 2.8881 Accuracy 0.1685
Time taken for 1 epoch: 171.3224914073944 secs



 48%|████████████████████████████████████▍                                       | 48/100 [2:59:11<3:01:38, 209.58s/it]

Epoch 48 Loss 2.8883 Accuracy 0.1684
Time taken for 1 epoch: 171.703040599823 secs



 49%|█████████████████████████████████████▏                                      | 49/100 [3:02:02<2:48:29, 198.22s/it]

Epoch 49 Loss 2.8885 Accuracy 0.1683
Time taken for 1 epoch: 171.70091104507446 secs

Epoch 50 Loss 2.8881 Accuracy 0.1684
Time taken for 1 epoch: 171.6985321044922 secs

writing file..


 50%|██████████████████████████████████████                                      | 50/100 [3:09:28<3:47:04, 272.49s/it]

Recall: 0.16470125786163523


 51%|██████████████████████████████████████▊                                     | 51/100 [3:12:19<3:17:44, 242.13s/it]

Epoch 51 Loss 2.8884 Accuracy 0.1684
Time taken for 1 epoch: 171.26714849472046 secs



 52%|███████████████████████████████████████▌                                    | 52/100 [3:15:11<2:56:50, 221.05s/it]

Epoch 52 Loss 2.8882 Accuracy 0.1685
Time taken for 1 epoch: 171.77669143676758 secs



 53%|████████████████████████████████████████▎                                   | 53/100 [3:18:03<2:41:37, 206.33s/it]

Epoch 53 Loss 2.8885 Accuracy 0.1685
Time taken for 1 epoch: 171.9012541770935 secs



 54%|█████████████████████████████████████████                                   | 54/100 [3:20:55<2:30:13, 195.94s/it]

Epoch 54 Loss 2.8878 Accuracy 0.1685
Time taken for 1 epoch: 171.68867230415344 secs

Epoch 55 Loss 2.8881 Accuracy 0.1684
Time taken for 1 epoch: 171.70012617111206 secs

writing file..
Recall: 0.1829140461215933


 55%|█████████████████████████████████████████▊                                  | 55/100 [3:28:45<3:28:38, 278.19s/it]

Saving checkpoint for epoch 55 at ./checkpoints/model_A\ckpt-3


 56%|██████████████████████████████████████████▌                                 | 56/100 [3:31:37<3:00:34, 246.25s/it]

Epoch 56 Loss 2.8881 Accuracy 0.1685
Time taken for 1 epoch: 171.63407278060913 secs



 57%|███████████████████████████████████████████▎                                | 57/100 [3:34:30<2:40:50, 224.44s/it]

Epoch 57 Loss 2.8883 Accuracy 0.1684
Time taken for 1 epoch: 173.44537734985352 secs



 58%|████████████████████████████████████████████                                | 58/100 [3:37:22<2:26:06, 208.73s/it]

Epoch 58 Loss 2.8880 Accuracy 0.1685
Time taken for 1 epoch: 171.98773503303528 secs



 59%|████████████████████████████████████████████▊                               | 59/100 [3:40:14<2:15:04, 197.67s/it]

Epoch 59 Loss 2.8879 Accuracy 0.1685
Time taken for 1 epoch: 171.81480526924133 secs

Epoch 60 Loss 2.8880 Accuracy 0.1683
Time taken for 1 epoch: 171.6340618133545 secs

writing file..


 60%|█████████████████████████████████████████████▌                              | 60/100 [3:47:42<3:01:51, 272.79s/it]

Recall: 0.16561844863731656


 61%|██████████████████████████████████████████████▎                             | 61/100 [3:50:35<2:37:43, 242.64s/it]

Epoch 61 Loss 2.8880 Accuracy 0.1685
Time taken for 1 epoch: 172.2824740409851 secs



 62%|███████████████████████████████████████████████                             | 62/100 [3:53:26<2:20:09, 221.31s/it]

Epoch 62 Loss 2.8880 Accuracy 0.1674
Time taken for 1 epoch: 171.45160627365112 secs



 63%|███████████████████████████████████████████████▉                            | 63/100 [3:56:18<2:07:18, 206.43s/it]

Epoch 63 Loss 2.8880 Accuracy 0.1679
Time taken for 1 epoch: 171.662211894989 secs



 64%|████████████████████████████████████████████████▋                           | 64/100 [3:59:09<1:57:35, 196.00s/it]

Epoch 64 Loss 2.8883 Accuracy 0.1683
Time taken for 1 epoch: 171.64720368385315 secs

Epoch 65 Loss 2.8879 Accuracy 0.1684
Time taken for 1 epoch: 172.17649388313293 secs

writing file..


 64%|████████████████████████████████████████████████▋                           | 64/100 [4:08:02<2:19:31, 232.54s/it]

Recall: 0.18029350104821804





## Spatial

In [36]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        #self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        #self.combine = CombineSubnets()
        #self.dr = DRModule(num_layers=num_layers)
        self.temp_fc = tf.keras.layers.Dense(70)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        #fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        #qr0 = self.combine(fc8, conv3_p)
        #qr = self.dr(qa, qb, qr0)
        qr = self.temp_fc(conv3_p[:, 0, 0, :])
        out = self.softmax(qr)
        return out

In [37]:
model_S = DRNet()

In [38]:
for sample in train_dataset:
    print(model_S(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

tf.Tensor(
[[0.01491421 0.0140573  0.01439006 ... 0.01386466 0.01364683 0.01480633]
 [0.01435929 0.0144551  0.01473911 ... 0.01404168 0.01382555 0.01458371]
 [0.01478864 0.01382029 0.01425514 ... 0.01338174 0.014387   0.0149344 ]
 ...
 [0.01565182 0.01440093 0.01487634 ... 0.01333662 0.01275025 0.01625833]
 [0.01427076 0.01435189 0.0146712  ... 0.01425032 0.01381755 0.01472828]
 [0.0147174  0.01417669 0.01456229 ... 0.01369194 0.01376129 0.0159472 ]], shape=(32, 70), dtype=float32)


In [39]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [40]:
checkpoint_path = "./checkpoints/model_S"
ckpt = tf.train.Checkpoint(model_S=model_S, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [41]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss += loss_object(label, y)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, y)

In [42]:
def eval_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    test_model(model, out_path)
    test_recall = eval_recall(out_path, num_dets, ov_thresh)
    return test_recall

In [None]:
max_recall = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()
    
    train_loss.reset_states()
    train_accuracy.reset_states()

    for sample in train_dataset:
        train_step(model_S, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
    
    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))
    
    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    
    if (epoch+1) % 5 == 0:
        test_recall = eval_step(model_S)
        if(test_recall > max_recall):
            max_recall = test_recall
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))
        else:
            early_stop_cnt += 1
        
    if early_stop_cnt == 10:
        break

  1%|▊                                                                               | 1/100 [00:47<1:18:05, 47.33s/it]

Epoch 1 Loss 2.2920 Accuracy 0.3820
Time taken for 1 epoch: 47.3262140750885 secs



  2%|█▌                                                                              | 2/100 [01:24<1:12:08, 44.17s/it]

Epoch 2 Loss 2.0493 Accuracy 0.4305
Time taken for 1 epoch: 36.71003818511963 secs



  3%|██▍                                                                             | 3/100 [01:48<1:01:40, 38.15s/it]

Epoch 3 Loss 1.9615 Accuracy 0.4428
Time taken for 1 epoch: 24.00798225402832 secs



  4%|███▎                                                                              | 4/100 [02:03<50:01, 31.27s/it]

Epoch 4 Loss 1.8659 Accuracy 0.4611
Time taken for 1 epoch: 15.126852989196777 secs

Epoch 5 Loss 1.7505 Accuracy 0.4827
Time taken for 1 epoch: 7.49103569984436 secs

writing file..
Recall: 0.11386268343815513


  5%|███▉                                                                           | 5/100 [06:37<2:45:00, 104.21s/it]

Saving checkpoint for epoch 5 at ./checkpoints/model_S\ckpt-1


  6%|████▊                                                                           | 6/100 [07:17<2:13:04, 84.94s/it]

Epoch 6 Loss 1.6153 Accuracy 0.5111
Time taken for 1 epoch: 39.9422709941864 secs



  7%|█████▌                                                                          | 7/100 [07:52<1:48:25, 69.95s/it]

Epoch 7 Loss 1.4691 Accuracy 0.5447
Time taken for 1 epoch: 34.871830224990845 secs



  8%|██████▍                                                                         | 8/100 [08:11<1:23:37, 54.53s/it]

Epoch 8 Loss 1.3153 Accuracy 0.5819
Time taken for 1 epoch: 18.47900080680847 secs



  9%|███████▏                                                                        | 9/100 [08:21<1:02:29, 41.20s/it]

Epoch 9 Loss 1.1824 Accuracy 0.6195
Time taken for 1 epoch: 9.998266220092773 secs

Epoch 10 Loss 1.0674 Accuracy 0.6535
Time taken for 1 epoch: 8.220048427581787 secs



  9%|███████▏                                                                        | 9/100 [12:05<2:02:19, 80.65s/it]


## Appearance + Spatial

In [None]:
checkpoint_path = "./checkpoints/model_A"
ckpt = tf.train.Checkpoint(model_A=model_A, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
model_A.get_layer(index=0).weights

In [None]:
checkpoint_path = "./checkpoints/model_S"
ckpt = tf.train.Checkpoint(model_S=model_S, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
model_S.get_layer(index=0).weights

In [None]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        #self.dr = DRModule(num_layers=num_layers)
        self.temp_fc = tf.keras.layers.Dense(70)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        #qr = self.dr(qa, qb, qr0)
        qr = self.temp_fc(qr0)
        out = self.softmax(qr)
        return out

In [None]:
model_AS = DRNet()

In [None]:
for sample in train_dataset:
    print(model_AS(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

In [None]:
for i in range(len(model_AS.get_layer(index=0).weights)):
    model_AS.get_layer(index=0).weights[i].assign(model_A.get_layer(index=0).weights[i])

In [None]:
model_AS.get_layer(index=0).weights

In [None]:
for i in range(len(model_AS.get_layer(index=1).weights)):
    model_AS.get_layer(index=1).weights[i].assign(model_S.get_layer(index=0).weights[i])

In [None]:
model_AS.get_layer(index=1).weights

In [None]:
for sample in train_dataset:
    print(model_AS(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

In [None]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [None]:
checkpoint_path = "./checkpoints/model_AS"
ckpt = tf.train.Checkpoint(model_AS=model_AS, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss += loss_object(label, y)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, y)

In [None]:
def eval_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    test_model(model, out_path)
    test_recall = eval_recall(out_path, num_dets, ov_thresh)
    return test_recall

In [None]:
max_recall = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()
    
    train_loss.reset_states()
    train_accuracy.reset_states()

    for sample in train_dataset:
        train_step(model_AS, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
        
    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))
    
    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    
    if (epoch+1) % 5 == 0:
        test_recall = eval_step(model_AS)
        if(test_recall > max_recall):
            max_recall = test_recall
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))
        else:
            early_stop_cnt += 1
        
    if early_stop_cnt == 10:
        break

## Appearance + Spatial + DRNet

In [None]:
checkpoint_path = "./checkpoints/model_AS"
ckpt = tf.train.Checkpoint(model_AS=model_AS, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
model_AS.get_layer(index=0).weights

In [None]:
model_AS.get_layer(index=1).weights

In [None]:
model_AS.get_layer(index=2).weights

In [None]:
class DRNet(tf.keras.Model):
    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        super(DRNet, self).__init__()
        
        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        self.dr = DRModule(num_layers=num_layers)
        self.softmax = tf.keras.layers.Softmax()
        
    def call(self, qa, qb, im, posdata):
        fc8 = self.appr(im)
        conv3_p = self.spatial(posdata)
        qr0 = self.combine(fc8, conv3_p)
        qr = self.dr(qa, qb, qr0)
        out = self.softmax(qr)
        return out

In [None]:
model_ASD = DRNet()

In [None]:
for sample in train_dataset:
    print(model_ASD(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

In [None]:
for i in range(len(model_ASD.get_layer(index=0).weights)):
    model_ASD.get_layer(index=0).weights[i].assign(model_AS.get_layer(index=0).weights[i])

In [None]:
model_ASD.get_layer(index=0).weights

In [None]:
for i in range(len(model_ASD.get_layer(index=1).weights)):
    model_ASD.get_layer(index=1).weights[i].assign(model_AS.get_layer(index=1).weights[i])

In [None]:
model_ASD.get_layer(index=1).weights

In [None]:
for i in range(len(model_ASD.get_layer(index=2).weights)):
    model_ASD.get_layer(index=2).weights[i].assign(model_AS.get_layer(index=2).weights[i])

In [None]:
model_ASD.get_layer(index=2).weights

In [None]:
for sample in train_dataset:
    print(model_ASD(sample['qa'], sample['qb'], sample['im'], sample['posdata']))
    break

In [None]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [None]:
checkpoint_path = "./checkpoints/model_ASD"
ckpt = tf.train.Checkpoint(model_ASD=model_ASD, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    loss = 0
    
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss += loss_object(label, y)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, y)

In [None]:
def eval_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    test_model(model, out_path)
    test_recall = eval_recall(out_path, num_dets, ov_thresh)
    return test_recall

In [None]:
max_recall = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()
    
    train_loss.reset_states()
    train_accuracy.reset_states()

    for sample in train_dataset:
        train_step(model_ASD, sample['qa'], sample['qb'], sample['im'], sample['posdata'], sample['labels'])
        
    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))
    
    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    
    if (epoch+1) % 5 == 0:
        test_recall = eval_step(model_ASD)
        if(test_recall > max_recall):
            max_recall = test_recall
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))
        else:
            early_stop_cnt += 1
        
    if early_stop_cnt == 10:
        break