In [1]:
import os
import time
import json
import math
from tqdm import tqdm

import cv2
import numpy as np
import tensorflow as tf

# 데이터 전처리

In [2]:
def getUnionBBox(aBB, bBB, ih, iw):
    """Return the padded bounding box that encloses both ``aBB`` and ``bBB``.

    Boxes are indexed as four-element sequences; elements 0 and 2 are clipped
    against ``iw`` and elements 1 and 3 against ``ih``. The enclosing box is
    grown by a fixed 10-unit margin on each side and then clipped so that it
    stays within ``[0, iw]`` horizontally and ``[0, ih]`` vertically.

    Returns:
        A new list ``[x1, y1, x2, y2]``.
    """
    margin = 10
    left = min(aBB[0], bBB[0]) - margin
    top = min(aBB[1], bBB[1]) - margin
    right = max(aBB[2], bBB[2]) + margin
    bottom = max(aBB[3], bBB[3]) + margin
    return [max(0, left), max(0, top), min(iw, right), min(ih, bottom)]

In [3]:
def getAppr(im, bb):
    """Crop ``bb`` out of ``im`` and turn the crop into a 224x224 VGG16 input.

    Args:
        im: image array indexed as ``[row, column, channel]``.
        bb: box ``[x1, y1, x2, y2]``; used directly as slice bounds, so the
            crop covers rows ``y1:y2`` and columns ``x1:x2``.

    Returns:
        The crop, bilinearly rescaled by ``224 / width`` and ``224 / height``
        and passed through the Keras VGG16 ``preprocess_input``.
    """
    x1, y1, x2, y2 = bb[0], bb[1], bb[2], bb[3]
    crop = im[y1:y2, x1:x2, :]
    crop_h, crop_w = crop.shape[0], crop.shape[1]
    # No explicit output size is given; the per-axis scale factors drive the resize.
    resized = cv2.resize(crop, None,
                         fx=224.0 / crop_w,
                         fy=224.0 / crop_h,
                         interpolation=cv2.INTER_LINEAR)
    return tf.keras.applications.vgg16.preprocess_input(resized)

In [4]:
def getDualMask(ih, iw, bb):
    """Rasterize box ``bb`` of an ``ih`` x ``iw`` image onto a 32x32 binary mask.

    The box ``[x1, y1, x2, y2]`` is scaled into the 32x32 grid; its lower
    edges are floored (and clamped at 0) and its upper edges are ceiled (and
    clamped at 32), so every grid cell the box touches is set to 1.

    Returns:
        A ``(32, 32)`` NumPy array containing zeros and ones.
    """
    side = 32
    scale_y = 32.0 / ih
    scale_x = 32.0 / iw

    col_lo = max(0, int(math.floor(bb[0] * scale_x)))
    col_hi = min(side, int(math.ceil(bb[2] * scale_x)))
    row_lo = max(0, int(math.floor(bb[1] * scale_y)))
    row_hi = min(side, int(math.ceil(bb[3] * scale_y)))

    mask = np.zeros((side, side))
    mask[row_lo:row_hi, col_lo:col_hi] = 1
    # Sanity check: the filled region must match the computed rectangle exactly.
    assert mask.sum() == (row_hi - row_lo) * (col_hi - col_lo)
    return mask

In [5]:
def forward_batch(model, ims, poses, qas, qbs):
    """Run ``model`` once on a single in-memory batch and return its output.

    The arrays are cast to the dtypes the training pipeline feeds the model
    (``int32`` for the one-hot vectors, ``float32`` for images and masks) and
    passed straight to the model. No ``tf.data`` pipeline is built for this
    one-shot call, so the function does not depend on any notebook-level
    dataset constant.

    Args:
        model: callable invoked as ``model(qa, qb, im, posdata)``.
        ims: array of image crops; ``ims.shape[0]`` is the batch size.
        poses: array of spatial masks, one entry per image crop.
        qas: array of subject one-hot vectors.
        qbs: array of object one-hot vectors.

    Returns:
        Whatever ``model`` returns for the batch.

    Raises:
        ValueError: if the batch is empty.
    """
    if ims.shape[0] == 0:
        raise ValueError("forward_batch() requires at least one sample")

    qa = tf.cast(qas, tf.int32)
    qb = tf.cast(qbs, tf.int32)
    im = tf.cast(ims, tf.float32)
    posdata = tf.cast(poses, tf.float32)
    return model(qa, qb, im, posdata)

In [6]:
def _as_object_array(items):
    """Pack ``items`` into a 1-D ``dtype=object`` array, one element per entry."""
    packed = np.empty(len(items), dtype=object)
    for idx, item in enumerate(items):
        packed[idx] = item
    return packed


def test_model(model, out_path):
    """Score every ground-truth pair of every test image and save the results.

    Reads the notebook-level ``image_paths``, ``all_gts`` and
    ``all_gt_bboxes``. For each image, every ground-truth (subject, object)
    pair is turned into a union-box crop, a two-channel spatial mask and two
    one-hot class vectors; the model is run on them in batches of 20, and
    every predicate whose score is not below 0.05 is kept.

    The ``.npz`` written to ``out_path`` holds two object arrays with one
    entry per image:

    * ``pred``: rows ``[score, 1, 1, subject_label, predicate_index, object_label]``
    * ``pred_bboxes``: the matching ``[subject_box, object_box]`` pairs

    Images without ground truths get empty arrays.

    Args:
        model: callable accepted by ``forward_batch``.
        out_path: destination passed to ``np.savez``.

    Raises:
        IOError: if an image cannot be read.
    """
    num_img = len(image_paths)
    num_class = 101
    thresh = 0.05
    batch_size = 20
    pred = []
    pred_bboxes = []

    for i in range(num_img):
        raw = cv2.imread(image_paths[i])
        if raw is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("could not read image: {}".format(image_paths[i]))
        im = raw.astype(np.float32, copy=False)
        ih, iw = im.shape[0], im.shape[1]
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        if num_gts == 0:
            pred.append(np.array([]))
            pred_bboxes.append(np.array([]))
            continue

        # Build the model inputs for every ground-truth pair of this image.
        ims = []
        poses = []
        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            ims.append(getAppr(im, getUnionBBox(sub, obj, ih, iw)))
            poses.append([getDualMask(ih, iw, sub), getDualMask(ih, iw, obj)])
        ims = np.array(ims)
        # (pair, mask, row, col) -> (pair, row, col, mask): channels-last layout.
        poses = np.array(poses).transpose((0, 2, 3, 1))

        # One-hot vectors; labels are 1-based, hence the "- 1".
        rows = np.arange(num_gts)
        qas = np.zeros((num_gts, num_class - 1))
        qas[rows, gts[:, 0] - 1] = 1
        qbs = np.zeros((num_gts, num_class - 1))
        qbs[rows, gts[:, 2] - 1] = 1

        # Run the model batch by batch, always collecting plain NumPy arrays.
        outputs = []
        for start in range(0, num_gts, batch_size):
            stop = min(start + batch_size, num_gts)
            batch_out = forward_batch(model, ims[start:stop], poses[start:stop],
                                      qas[start:stop], qbs[start:stop])
            outputs.append(np.asarray(batch_out))
        itr_pred = np.vstack(outputs)

        img_pred = []
        img_bboxes = []
        for j in range(num_gts):
            sub = gt_bboxes[j, 0, :]
            obj = gt_bboxes[j, 1, :]
            # Keep every predicate whose score is not below the threshold.
            for k in np.flatnonzero(~(itr_pred[j] < thresh)):
                img_pred.append([itr_pred[j, k], 1, 1, gts[j, 0], k, gts[j, 2]])
                img_bboxes.append([sub, obj])
        pred.append(np.array(img_pred))
        pred_bboxes.append(np.array(img_bboxes))

    print("writing file..")
    # Per-image results have different lengths, so store them as explicit
    # object arrays instead of letting NumPy guess a shape for ragged lists.
    np.savez(out_path,
             pred=_as_object_array(pred),
             pred_bboxes=_as_object_array(pred_bboxes))

In [7]:
def computeArea(bb):
    """Area of box ``[x1, y1, x2, y2]`` with inclusive corner coordinates.

    Width and height are each ``max - min + 1``; a negative extent is treated
    as 0, so an inverted box has area 0.
    """
    width = bb[2] - bb[0] + 1
    height = bb[3] - bb[1] + 1
    return max(0, width) * max(0, height)

In [8]:
def computeIoU(bb1, bb2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Areas come from ``computeArea``. When the union area is zero (both boxes
    have zero area) the result is defined as ``0.0`` instead of dividing by
    zero.

    Returns:
        A float in ``[0, 1]``.
    """
    ibb = [max(bb1[0], bb2[0]),
           max(bb1[1], bb2[1]),
           min(bb1[2], bb2[2]),
           min(bb1[3], bb2[3])]
    iArea = computeArea(ibb)
    uArea = computeArea(bb1) + computeArea(bb2) - iArea
    if uArea <= 0:
        # Nothing to overlap with: avoid ZeroDivisionError / NaN.
        return 0.0
    return (iArea + 0.0) / uArea

In [9]:
def computeOverlap(detBBs, gtBBs):
    """Overlap between a detected and a ground-truth box pair.

    Row 0 of each argument is compared with row 0 of the other, and row 1
    with row 1, using ``computeIoU``; the smaller of the two IoUs is returned.
    """
    ious = [computeIoU(detBBs[row, :], gtBBs[row, :]) for row in (0, 1)]
    return min(ious)

In [10]:
def eval_recall(det_file_path, num_dets=50, ov_thresh=0.5):
    """Compute recall of the saved detections against the ground truth.

    Reads the notebook-level ``all_gts`` and ``all_gt_bboxes``. For every
    image the detections are ranked by the sum of the logs of their first
    three columns, the best ``num_dets`` are kept, and each one is greedily
    matched to the not-yet-matched ground truth with the same
    (subject, predicate, object) labels and the highest overlap, provided the
    overlap is at least ``ov_thresh``.

    Args:
        det_file_path: ``.npz`` file with ``pred`` and ``pred_bboxes`` entries.
        num_dets: detections kept per image; values <= 0 keep all of them.
        ov_thresh: minimum overlap for a match.

    Returns:
        Matched ground truths divided by all ground truths, as a float;
        ``0.0`` when there are no ground truths.
    """
    # NpzFile keeps the archive open; the context manager closes it.
    # NOTE(review): allow_pickle=True is needed for the object arrays and is
    # only safe for files produced by this notebook.
    with np.load(det_file_path, allow_pickle=True) as det_file:
        dets = det_file['pred']
        det_bboxes = det_file['pred_bboxes']

    num_matched = 0
    total_num_gts = 0
    for i in range(len(dets)):
        gts = np.array(all_gts[i])
        gt_bboxes = np.array(all_gt_bboxes[i])
        num_gts = gts.shape[0]
        total_num_gts += num_gts
        if not (isinstance(dets[i], np.ndarray) and dets[i].shape[0] > 0):
            continue

        gt_detected = np.zeros(num_gts)
        det_score = np.log(dets[i][:, 0]) + np.log(dets[i][:, 1]) + np.log(dets[i][:, 2])
        inds = np.argsort(det_score)[::-1]
        if num_dets > 0 and num_dets < len(inds):
            inds = inds[:num_dets]
        top_dets = dets[i][inds, 3:]
        top_det_bboxes = det_bboxes[i][inds, :]

        for j in range(len(inds)):
            ov_max = 0
            arg_max = -1
            for k in range(num_gts):
                if (gt_detected[k] == 0
                        and top_dets[j, 0] == gts[k, 0]
                        and top_dets[j, 1] == gts[k, 1]
                        and top_dets[j, 2] == gts[k, 2]):
                    ov = computeOverlap(top_det_bboxes[j, :, :], gt_bboxes[k, :, :])
                    if ov >= ov_thresh and ov > ov_max:
                        ov_max = ov
                        arg_max = k
            if arg_max != -1:
                # Each ground truth can be matched at most once.
                gt_detected[arg_max] = 1
                num_matched += 1

    # Final recall only depends on the number of matches, so no global
    # ranking of the detections is required.
    if total_num_gts == 0:
        top_recall = 0.0
    else:
        top_recall = num_matched / float(total_num_gts)
    print('Recall:', top_recall)
    return top_recall

In [11]:
# Training annotations: a JSON list with one record per relationship instance.
dataset = './reltrain.json'
nclass = 100

# Context manager closes the file handle as soon as parsing is done.
with open(dataset) as train_file:
    samples = json.load(train_file)
num_instance = len(samples)

In [12]:
# Per-instance model inputs and labels, filled in the same order as `samples`.
qas = []
qbs = []
ims = []
poses = []
labels = []

for i in range(num_instance):
    sample = samples[i]
    im = cv2.imread(sample["imPath"])
    if im is None:
        # cv2.imread returns None on failure; fail with the offending path.
        raise IOError("could not read image: {}".format(sample["imPath"]))
    im = im.astype(np.float32, copy=False)
    ih = im.shape[0]
    iw = im.shape[1]

    # One-hot subject / object vectors; labels are 1-based, hence the "- 1".
    qa = np.zeros(nclass)
    qa[sample["aLabel"] - 1] = 1
    qas.append(qa)
    qb = np.zeros(nclass)
    qb[sample["bLabel"] - 1] = 1
    qbs.append(qb)

    # Appearance crop of the relationship box plus one spatial mask per entity.
    ims.append(getAppr(im, sample["rBBox"]))
    poses.append([getDualMask(ih, iw, sample["aBBox"]),
                  getDualMask(ih, iw, sample["bBBox"])])
    labels.append(sample["rLabel"])

In [13]:
# Stack the per-instance lists into arrays.
qa = np.array(qas)
qb = np.array(qbs)
im = np.array(ims)
# (instance, mask, row, col) -> (instance, row, col, mask): channels-last.
# `poses` is deliberately left untouched so that re-running this cell does
# not permute the axes a second time.
posdata = np.array(poses).transpose((0, 2, 3, 1))
labels = np.array(labels)

In [14]:
# One dict per instance, keyed the way the tf.data pipeline expects.
records = [
    {'qa': qa[k], 'qb': qb[k], 'im': im[k], 'posdata': posdata[k], 'labels': labels[k]}
    for k in range(num_instance)
]

# The last 7500 instances are held out for validation; the rest is for training.
train_set = records[:-7500]
valid_set = records[-7500:]
train_elements = tuple(train_set)
valid_elements = tuple(valid_set)

In [15]:
# Stream the training records through tf.data; every field is converted to
# the dtype listed here.
train_output_types = {
    'qa': tf.int32,
    'qb': tf.int32,
    'im': tf.float32,
    'posdata': tf.float32,
    'labels': tf.int32,
}
train_dataset = tf.data.Dataset.from_generator(lambda: train_elements, train_output_types)

In [16]:
# Stream the validation records through tf.data; every field is converted to
# the dtype listed here.
valid_output_types = {
    'qa': tf.int32,
    'qb': tf.int32,
    'im': tf.float32,
    'posdata': tf.float32,
    'labels': tf.int32,
}
valid_dataset = tf.data.Dataset.from_generator(lambda: valid_elements, valid_output_types)

In [17]:
# Print the first three training records, numbered from 1.
for position, sample in enumerate(train_dataset.take(3), start=1):
    print(position, ':', sample)

1 : {'qa': <tf.Tensor: id=73, shape=(100,), dtype=int32, numpy=
array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'qb': <tf.Tensor: id=74, shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])>, 'im': <tf.Tensor: id=70, shape=(224, 224, 3), dtype=float32, numpy=
array([[[127.84672  , 124.00672  , 131.32     ],
        [130.03198  , 127.19198  , 130.29099  ],
        [133.05893  , 129.12054  , 

In [18]:
# Print the first three validation records, numbered from 1.
for position, sample in enumerate(valid_dataset.take(3), start=1):
    print(position, ':', sample)

1 : {'qa': <tf.Tensor: id=95, shape=(100,), dtype=int32, numpy=
array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'qb': <tf.Tensor: id=96, shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>, 'im': <tf.Tensor: id=92, shape=(224, 224, 3), dtype=float32, numpy=
array([[[ -31.561043 ,  -72.40103  , -103.30203  ],
        [ -34.46836  ,  -75.26817  , -106.161194 ],
        [ -36.227394 ,  -75.0

In [19]:
# Training schedule.
EPOCHS = 100
BATCH_SIZE = 32

# Input pipeline: the shuffle buffer is sized to the full instance count, and
# the prefetch depth is left for tf.data to tune.
BUFFER_SIZE = num_instance
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [20]:
# NOTE: both names are rebound to their batched versions, so running this
# cell a second time would batch the already-batched datasets again.

# Training: cache, shuffle, fixed-size batches (last partial batch dropped).
train_dataset = (
    train_dataset
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation: cache and batch in order, keeping the final partial batch.
valid_dataset = (
    valid_dataset
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

In [21]:
# Test annotations: a JSON list with one record per relationship instance.
test_dataset = './reltest.json'
nclass = 100

# Context manager closes the file handle as soon as parsing is done.
with open(test_dataset) as test_file:
    test_samples = json.load(test_file)
test_num_instance = len(test_samples)

In [22]:
# Group consecutive test records that share an image path:
#   image_paths[j] - path of the j-th image group
#   gt_label[j]    - [aLabel, rLabel, bLabel] triples of that group
#   gt_box[j]      - [aBBox, bBBox] pairs of that group
# Only adjacent records are merged; a path that reappears later starts a new group.
image_paths = []
gt_label = []
gt_box = []
j = 0  # index of the group currently being filled
img_path = None

for i, record in enumerate(test_samples):
    starts_new_group = (i == 0) or (record['imPath'] != img_path)
    if starts_new_group:
        if i != 0:
            j += 1
        img_path = record['imPath']
        image_paths.append(img_path)
        gt_label.append([])
        gt_box.append([])
    gt_label[j].append([record['aLabel'], record['rLabel'], record['bLabel']])
    gt_box[j].append([record['aBBox'], record['bBBox']])

In [23]:
# Images carry different numbers of annotations, so the per-image lists are
# ragged. Store them in explicit 1-D object arrays (one entry per image)
# rather than asking NumPy to infer a shape, which raises on ragged input in
# current NumPy releases. Consumers convert each entry with np.array(...).
all_gts = np.empty(len(gt_label), dtype=object)
all_gt_bboxes = np.empty(len(gt_box), dtype=object)
for img_idx in range(len(gt_label)):
    all_gts[img_idx] = gt_label[img_idx]
    all_gt_bboxes[img_idx] = gt_box[img_idx]

# 모델

In [24]:
class AppearanceSubnet(tf.keras.layers.Layer):
    """Appearance branch: VGG16 convolutional base plus a three-layer dense head.

    The base is the Keras VGG16 with ImageNet weights and without its top
    classifier. The head flattens the feature map and applies
    fc6 (4096) -> fc7 (4096) -> fc8 (256), each followed by a ReLU.
    """

    def __init__(self, input_shape):
        """Build the branch for images of shape ``input_shape``."""
        super().__init__()
        layers = tf.keras.layers

        self.vgg16 = tf.keras.applications.vgg16.VGG16(
            include_top=False,
            weights='imagenet',
            input_shape=input_shape,
        )

        head_layers = [
            layers.Flatten(name='flat'),
            layers.Dense(4096, name='fc6'),
            layers.ReLU(name='relu6'),
            layers.Dense(4096, name='fc7'),
            layers.ReLU(name='relu7'),
            layers.Dense(256, kernel_initializer='glorot_normal', name='fc8'),
            layers.ReLU(name='relu8'),
        ]
        self.fc = tf.keras.Sequential(head_layers)

    def call(self, x):
        """Return the 256-dimensional appearance feature for image batch ``x``."""
        return self.fc(self.vgg16(x))

In [25]:
class SpatialSubnet(tf.keras.layers.Layer):
    """Spatial branch: three convolutions over the two-channel box masks.

    conv1_p and conv2_p are 5x5, stride 2, 'same'-padded convolutions with 96
    and 128 filters; conv3_p is an 8x8 convolution with 64 filters and
    default padding. ReLUs follow conv1_p and conv3_p.
    NOTE(review): there is no activation between conv2_p and conv3_p -
    confirm that this is intentional.
    """

    def __init__(self, input_shape):
        """Build the branch for mask tensors of shape ``input_shape``."""
        super().__init__()
        layers = tf.keras.layers

        conv_layers = [
            layers.Conv2D(96, 5, 2, padding='same', input_shape=input_shape, name='conv1_p'),
            layers.ReLU(name='relu1_p'),
            layers.Conv2D(128, 5, 2, padding='same', name='conv2_p'),
            layers.Conv2D(64, 8, name='conv3_p'),
            layers.ReLU(name='relu3_p'),
        ]
        self.conv = tf.keras.Sequential(conv_layers)

    def call(self, x):
        """Return the convolutional feature map for mask batch ``x``."""
        return self.conv(x)

In [26]:
class CombineSubnets(tf.keras.layers.Layer):
    """Fuse the appearance and spatial features into the initial predicate vector.

    The two features are concatenated and passed through fc2_c (128) and
    PhiR_0 (70), each followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers

        self.concat1_c = layers.Concatenate(name='concat1_c')

        fusion_layers = [
            layers.Dense(128, name='fc2_c'),
            layers.ReLU(name='relu2_c'),
            layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_0'),
            layers.ReLU(name='relu_0'),
        ]
        self.fc = tf.keras.Sequential(fusion_layers)

    def call(self, x1, x2):
        """Combine appearance feature ``x1`` with spatial feature map ``x2``.

        Only spatial position ``[0, 0]`` of ``x2`` is used, which turns the
        4-D map into one vector per sample before concatenation.
        """
        spatial_vec = x2[:, 0, 0, :]
        fused = self.concat1_c([x1, spatial_vec])
        return self.fc(fused)  # qr0

In [27]:
class DRLayer(tf.keras.layers.Layer):
    """One refinement step: project qa, qb and qr to 70 dims and sum them.

    Computes ``PhiA(qa) + PhiB(qb) + PhiR(qr)`` and, when ``activation`` is
    truthy, applies a ReLU to the sum.
    """

    def __init__(self, i, activation=True):
        """Create step number ``i`` (used only in the layer names).

        Args:
            i: index embedded in the sub-layer names.
            activation: whether a ReLU follows the sum. It is normalized to
                ``bool`` so that the check made here and the one made in
                ``call`` can never disagree.
        """
        super().__init__()

        self.activation = bool(activation)

        self.PhiA = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiA_%d' % (i))  # qar_i
        self.PhiB = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiB_%d' % (i))  # qbr_i
        self.PhiR = tf.keras.layers.Dense(70, kernel_initializer='glorot_normal', name='PhiR_%d' % (i))  # q_i_r
        self.QSum = tf.keras.layers.Add(name='QSum_%d' % (i))  # qr_i_un
        if self.activation:
            self.relu = tf.keras.layers.ReLU(name='relu_%d' % (i))  # qr_i

    def call(self, qa, qb, qr):
        """Return the refined predicate vector for this step."""
        qrun = self.QSum([self.PhiA(qa), self.PhiB(qb), self.PhiR(qr)])
        if self.activation:
            return self.relu(qrun)
        return qrun

In [28]:
class DRModule(tf.keras.layers.Layer):
    """Stack of ``num_layers`` DRLayer steps applied one after another.

    Steps are numbered from 1. Every step except the last one ends with a
    ReLU; the last step returns the raw sum.
    """

    def __init__(self, num_layers):
        super().__init__()

        self.num_layers = num_layers

        self.dr_layers = []
        for step in range(1, num_layers + 1):
            is_last = (step == num_layers)
            self.dr_layers.append(DRLayer(step, activation=not is_last))

    def call(self, qa, qb, qr):
        """Refine ``qr`` through every step, feeding ``qa`` and ``qb`` to each."""
        for step_idx in range(self.num_layers):
            qr = self.dr_layers[step_idx](qa, qb, qr)
        return qr

In [29]:
class DRNet(tf.keras.Model):
    """Full network: appearance and spatial branches, fusion, DR module, softmax."""

    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        """Build the model.

        Args:
            num_layers: number of DRLayer steps in the DR module.
            im_shape: shape of one image crop.
            posdata_shape: shape of one spatial-mask tensor.
        """
        super().__init__()

        self.appr = AppearanceSubnet(input_shape=im_shape)
        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.combine = CombineSubnets()
        self.dr = DRModule(num_layers=num_layers)
        self.softmax = tf.keras.layers.Softmax()

    def call(self, qa, qb, im, posdata):
        """Return softmax predicate probabilities for a batch."""
        appearance_feat = self.appr(im)          # fc8
        spatial_feat = self.spatial(posdata)     # conv3_p
        qr0 = self.combine(appearance_feat, spatial_feat)
        refined = self.dr(qa, qb, qr0)
        return self.softmax(refined)

# 학습

## Spatial

In [30]:
class DRNet(tf.keras.Model):
    """Spatial-only variant: spatial branch -> Dense(70) -> softmax.

    This definition replaces the full ``DRNet`` defined earlier in the
    notebook namespace. The appearance branch, the fusion layer and the DR
    module are not built here; a temporary dense layer maps the spatial
    feature straight to the 70 outputs.
    """

    def __init__(self, num_layers=8, im_shape=(224, 224, 3), posdata_shape=(32, 32, 2)):
        """Build the model.

        ``num_layers`` and ``im_shape`` are accepted so the signature matches
        the full model, but only ``posdata_shape`` is used.
        """
        super().__init__()

        self.spatial = SpatialSubnet(input_shape=posdata_shape)
        self.temp_fc = tf.keras.layers.Dense(70)
        self.softmax = tf.keras.layers.Softmax()

    def call(self, qa, qb, im, posdata):
        """Return softmax probabilities computed from ``posdata`` alone.

        ``qa``, ``qb`` and ``im`` are ignored.
        """
        spatial_map = self.spatial(posdata)  # conv3_p
        # Only spatial position [0, 0] of the feature map feeds the classifier.
        logits = self.temp_fc(spatial_map[:, 0, 0, :])
        return self.softmax(logits)

In [31]:
# Instantiate the most recently defined DRNet (the spatial-only variant) with its defaults.
model_S = DRNet()

In [32]:
# Smoke test: run the untrained model on a single training batch and print the output.
for sample in train_dataset.take(1):
    batch_output = model_S(sample['qa'], sample['qb'], sample['im'], sample['posdata'])
    print(batch_output)

tf.Tensor(
[[0.01545621 0.01420001 0.01431436 ... 0.01339362 0.01447584 0.01402828]
 [0.01484783 0.01383682 0.01420806 ... 0.01439826 0.01460338 0.01449128]
 [0.01536262 0.01430611 0.01362091 ... 0.01404728 0.01462132 0.0143776 ]
 ...
 [0.01598221 0.01372999 0.01398286 ... 0.01397879 0.0146846  0.0135144 ]
 [0.015683   0.01375536 0.01374532 ... 0.01445804 0.0155774  0.013614  ]
 [0.01434089 0.0140108  0.01434687 ... 0.01392833 0.01426934 0.01419551]], shape=(32, 70), dtype=float32)


In [33]:
# The model ends in a softmax, so the loss receives probabilities, not logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Running loss means for the training and validation passes.
train_loss, eval_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'eval_loss')
)
# Running accuracies for the training and validation passes.
train_accuracy, eval_accuracy = (
    tf.keras.metrics.SparseCategoricalAccuracy(name=metric_name)
    for metric_name in ('train_accuracy', 'eval_accuracy')
)

In [34]:
# Track the model and optimizer together; keep only the five newest checkpoints.
checkpoint_path = "./checkpoints/model_S"
ckpt = tf.train.Checkpoint(model_S=model_S, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint in the directory, if there is one.
latest_ckpt = ckpt_manager.latest_checkpoint
if latest_ckpt:
    ckpt.restore(latest_ckpt)
    print('Latest checkpoint restored!!')

In [35]:
@tf.function
def train_step(model, qa, qb, im, posdata, label):
    """Run one optimization step on a batch and update the training metrics.

    Uses the notebook-level ``loss_object``, ``optimizer``, ``train_loss``
    and ``train_accuracy``.
    """
    with tf.GradientTape() as tape:
        y = model(qa, qb, im, posdata)
        loss = loss_object(label, y)

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    train_loss(loss)
    train_accuracy(label, y)

In [36]:
@tf.function
def eval_step(model, qa, qb, im, posdata, label):
    """Evaluate one batch and update the validation metrics.

    Uses the notebook-level ``loss_object``, ``eval_loss`` and
    ``eval_accuracy``; no gradients are computed and no weights change.
    """
    y = model(qa, qb, im, posdata)
    loss = loss_object(label, y)

    eval_loss(loss)
    eval_accuracy(label, y)

In [37]:
def test_step(model, out_path='temp.npz', num_dets=50, ov_thresh=0.5):
    """Write the model's test predictions to ``out_path`` and return their recall.

    Args:
        model: model handed to ``test_model``.
        out_path: file that ``test_model`` writes and ``eval_recall`` reads.
        num_dets: forwarded to ``eval_recall``.
        ov_thresh: forwarded to ``eval_recall``.
    """
    test_model(model, out_path)
    return eval_recall(out_path, num_dets, ov_thresh)

In [38]:
# Train for up to EPOCHS epochs. A checkpoint is saved whenever validation
# accuracy improves; training stops after 20 epochs in a row without improvement.
max_acc = 0
early_stop_cnt = 0

for epoch in tqdm(range(EPOCHS)):
    start = time.time()

    # Metrics accumulate across calls, so clear them at the start of each epoch.
    for metric in (train_loss, train_accuracy, eval_loss, eval_accuracy):
        metric.reset_states()

    for sample in train_dataset:
        train_step(model_S, sample['qa'], sample['qb'], sample['im'],
                   sample['posdata'], sample['labels'])

    for sample in valid_dataset:
        eval_step(model_S, sample['qa'], sample['qb'], sample['im'],
                  sample['posdata'], sample['labels'])

    end = time.time()
    eval_acc = eval_accuracy.result()

    epoch_summary = 'Epoch {} Loss {:.4f} Accuracy {:.4f} Eval_Loss {:.4f} Eval_Accuracy {:.4f}'
    print(epoch_summary.format(epoch + 1,
                               train_loss.result(),
                               train_accuracy.result(),
                               eval_loss.result(),
                               eval_acc))
    print('Time taken for 1 epoch: {} secs\n'.format(end - start))

    if eval_acc > max_acc:
        # New best validation accuracy: remember it and reset the patience counter.
        max_acc = eval_acc
        early_stop_cnt = 0
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
    else:
        early_stop_cnt += 1

    if early_stop_cnt == 20:
        break

  1%|▊                                                                                 | 1/100 [00:29<49:02, 29.72s/it]

Epoch 1 Loss 2.3664 Accuracy 0.3673 Eval_Loss 2.1656 Eval_Accuracy 0.3969
Time taken for 1 epoch: 29.60547947883606 secs

Saving checkpoint for epoch 1 at ./checkpoints/model_S\ckpt-1


  2%|█▋                                                                                | 2/100 [00:37<37:45, 23.12s/it]

Epoch 2 Loss 2.0921 Accuracy 0.4199 Eval_Loss 2.0925 Eval_Accuracy 0.4225
Time taken for 1 epoch: 7.6280176639556885 secs

Saving checkpoint for epoch 2 at ./checkpoints/model_S\ckpt-2


  3%|██▍                                                                               | 3/100 [00:43<29:20, 18.15s/it]

Epoch 3 Loss 1.9905 Accuracy 0.4393 Eval_Loss 2.0781 Eval_Accuracy 0.4241
Time taken for 1 epoch: 6.460819244384766 secs

Saving checkpoint for epoch 3 at ./checkpoints/model_S\ckpt-3


  4%|███▎                                                                              | 4/100 [00:50<23:24, 14.63s/it]

Epoch 4 Loss 1.9037 Accuracy 0.4580 Eval_Loss 2.0645 Eval_Accuracy 0.4256
Time taken for 1 epoch: 6.329664707183838 secs

Saving checkpoint for epoch 4 at ./checkpoints/model_S\ckpt-4


  5%|████                                                                              | 5/100 [00:56<19:10, 12.11s/it]

Epoch 5 Loss 1.8055 Accuracy 0.4769 Eval_Loss 2.1148 Eval_Accuracy 0.4139
Time taken for 1 epoch: 6.222521066665649 secs



  6%|████▉                                                                             | 6/100 [01:02<16:08, 10.31s/it]

Epoch 6 Loss 1.6884 Accuracy 0.5032 Eval_Loss 2.2267 Eval_Accuracy 0.3919
Time taken for 1 epoch: 6.108520746231079 secs



  7%|█████▋                                                                            | 7/100 [01:08<14:03,  9.07s/it]

Epoch 7 Loss 1.5517 Accuracy 0.5328 Eval_Loss 2.3340 Eval_Accuracy 0.3945
Time taken for 1 epoch: 6.192621946334839 secs



  8%|██████▌                                                                           | 8/100 [01:15<12:33,  8.19s/it]

Epoch 8 Loss 1.4099 Accuracy 0.5651 Eval_Loss 2.4830 Eval_Accuracy 0.3557
Time taken for 1 epoch: 6.127825975418091 secs



  9%|███████▍                                                                          | 9/100 [01:21<11:29,  7.58s/it]

Epoch 9 Loss 1.2655 Accuracy 0.6036 Eval_Loss 2.6212 Eval_Accuracy 0.3767
Time taken for 1 epoch: 6.142406225204468 secs



 10%|████████                                                                         | 10/100 [01:27<10:43,  7.15s/it]

Epoch 10 Loss 1.1402 Accuracy 0.6369 Eval_Loss 2.7658 Eval_Accuracy 0.3645
Time taken for 1 epoch: 6.144956350326538 secs



 11%|████████▉                                                                        | 11/100 [01:33<10:20,  6.98s/it]

Epoch 11 Loss 1.0181 Accuracy 0.6736 Eval_Loss 3.0288 Eval_Accuracy 0.3569
Time taken for 1 epoch: 6.566776990890503 secs



 12%|█████████▋                                                                       | 12/100 [01:40<09:53,  6.74s/it]

Epoch 12 Loss 0.9291 Accuracy 0.6973 Eval_Loss 3.1464 Eval_Accuracy 0.3547
Time taken for 1 epoch: 6.19245719909668 secs



 13%|██████████▌                                                                      | 13/100 [01:46<09:29,  6.54s/it]

Epoch 13 Loss 0.8447 Accuracy 0.7246 Eval_Loss 3.4325 Eval_Accuracy 0.3460
Time taken for 1 epoch: 6.067538261413574 secs



 14%|███████████▎                                                                     | 14/100 [01:52<09:09,  6.39s/it]

Epoch 14 Loss 0.7851 Accuracy 0.7473 Eval_Loss 3.5310 Eval_Accuracy 0.3524
Time taken for 1 epoch: 6.032993793487549 secs



 15%|████████████▏                                                                    | 15/100 [01:58<08:54,  6.29s/it]

Epoch 15 Loss 0.7369 Accuracy 0.7602 Eval_Loss 3.6801 Eval_Accuracy 0.3475
Time taken for 1 epoch: 6.034044981002808 secs



 16%|████████████▉                                                                    | 16/100 [02:04<08:41,  6.21s/it]

Epoch 16 Loss 0.6900 Accuracy 0.7739 Eval_Loss 3.8810 Eval_Accuracy 0.3487
Time taken for 1 epoch: 6.037369728088379 secs



 17%|█████████████▊                                                                   | 17/100 [02:10<08:31,  6.16s/it]

Epoch 17 Loss 0.6567 Accuracy 0.7856 Eval_Loss 3.8698 Eval_Accuracy 0.3329
Time taken for 1 epoch: 6.046459674835205 secs



 18%|██████████████▌                                                                  | 18/100 [02:16<08:22,  6.13s/it]

Epoch 18 Loss 0.6207 Accuracy 0.7957 Eval_Loss 4.1130 Eval_Accuracy 0.3499
Time taken for 1 epoch: 6.032814979553223 secs



 19%|███████████████▍                                                                 | 19/100 [02:22<08:13,  6.10s/it]

Epoch 19 Loss 0.5972 Accuracy 0.8038 Eval_Loss 4.2418 Eval_Accuracy 0.3477
Time taken for 1 epoch: 6.030694246292114 secs



 20%|████████████████▏                                                                | 20/100 [02:28<08:06,  6.08s/it]

Epoch 20 Loss 0.5795 Accuracy 0.8096 Eval_Loss 4.1359 Eval_Accuracy 0.3403
Time taken for 1 epoch: 6.0364813804626465 secs



 21%|█████████████████                                                                | 21/100 [02:34<08:00,  6.08s/it]

Epoch 21 Loss 0.5544 Accuracy 0.8171 Eval_Loss 4.3376 Eval_Accuracy 0.3385
Time taken for 1 epoch: 6.064593315124512 secs



 22%|█████████████████▊                                                               | 22/100 [02:40<07:54,  6.08s/it]

Epoch 22 Loss 0.5367 Accuracy 0.8199 Eval_Loss 4.6056 Eval_Accuracy 0.3409
Time taken for 1 epoch: 6.087301731109619 secs



 23%|██████████████████▋                                                              | 23/100 [02:46<07:47,  6.08s/it]

Epoch 23 Loss 0.5238 Accuracy 0.8262 Eval_Loss 4.2430 Eval_Accuracy 0.3347
Time taken for 1 epoch: 6.05962872505188 secs



 23%|██████████████████▋                                                              | 23/100 [02:52<09:38,  7.51s/it]

Epoch 24 Loss 0.5094 Accuracy 0.8286 Eval_Loss 4.4407 Eval_Accuracy 0.3417
Time taken for 1 epoch: 6.0527708530426025 secs






In [39]:
# Rebuild the checkpoint objects and reload the newest saved checkpoint, so
# evaluation uses the last saved weights rather than the last-epoch weights.
checkpoint_path = "./checkpoints/model_S"
ckpt = tf.train.Checkpoint(model_S=model_S, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_path, max_to_keep=5)

latest_ckpt = ckpt_manager.latest_checkpoint
if latest_ckpt:
    ckpt.restore(latest_ckpt)
    print('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [40]:
# Evaluate the restored model with test_step's defaults; the returned recall is the cell output.
test_step(ckpt.model_S)

writing file..
Recall: 0.7497379454926625


0.7497379454926625