In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import cv2
import os, time, pickle
from tqdm.notebook import tqdm

from ssd_data import InputGenerator
from sl_utils import PriorUtil
from sl_training import SegLinkLoss, SegLinkFocalLoss
from sl_metric import evaluate_results
from ssd_metric import fscore
from utils.model import load_weights, count_parameters, calc_memory_usage
from utils.training import Logger, LearningRateDecay, MetricUtility

### Data

In [None]:
from data_synthtext import GTUtility

file_name = 'gt_util_synthtext_seglink.pkl'
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# from a trusted source.
with open(file_name, 'rb') as f:
    gt_util = pickle.load(f)

# Split into training / validation first, then keep only the first 25% of
# EACH split to shorten experiments. The subsets must be taken from their own
# split: taking both from the full `gt_util` would make the training and
# validation sets the very same samples.
gt_util_train, gt_util_val = gt_util.split(0.9)
gt_util_train, _ = gt_util_train.split(0.25)
gt_util_val, _ = gt_util_val.split(0.25)

print(gt_util)

### Model

In [None]:
from sl_model import SL384x512_dense
# Build the SegLink model used throughout this notebook. Later cells rely on
# its `image_size` and `source_layers` attributes.
model = SL384x512_dense()

In [None]:
# Show the output shape of each source layer (the feature maps that
# PriorUtil builds its prior maps from), then parameter and memory statistics.
for i, l in enumerate(model.source_layers):
    print('%-2s %s' % (i, l.get_shape().as_list()))
print()
count_parameters(model)
calc_memory_usage(model)

# Number of layers whose class is exactly Conv2D.
c = sum(1 for layer in model.layers if layer.__class__.__name__ == "Conv2D")
print(c)

In [None]:
from ssd_data import preprocess

# Load a few validation samples and prepare, for each one, the network input
# (via `preprocess`), a displayable RGB float image in [0, 1], and its
# ground-truth data.
inputs = []
images = []
data = []

gtu = gt_util_val
image_size = model.image_size
# OpenCV expects sizes as (width, height), i.e. the reverse of `image_size`.
image_size_cv = image_size[::-1]

# Seeds the (currently disabled) random sample selection below.
np.random.seed(1337)

sample_idxs = [0]  # alternative: np.random.randint(0, gtu.num_samples, 16)

for i in sample_idxs:

    img_path = os.path.join(gtu.image_path, gtu.image_names[i])
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising.
        raise FileNotFoundError('could not read image %s' % img_path)
    print('img_shape', img.shape)

    print('image_size_ssd', image_size)
    print('image_size_cv ', image_size_cv)
    inpt = preprocess(img, image_size_cv)
    inputs.append(inpt)
    print('image_size_inp', inpt.shape)

    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, not `interpolation`.
    img = cv2.resize(img, image_size_cv, interpolation=cv2.INTER_LINEAR).astype('float32')
    print('image_size_img', img.shape)
    img = img[:, :, (2, 1, 0)]  # BGR to RGB
    img /= 255
    images.append(img)

    boxes = gtu.data[i]
    data.append(boxes)

    print()

inputs = np.asarray(inputs)

test_idx = 0
test_input = inputs[test_idx]
test_img = images[test_idx]
test_gt = data[test_idx]

### Encoding/Decoding

In [None]:
# Encode the test ground truth against the priors and visualise one prior
# map: its locations, the neighbour links and the GT assignment.
prior_util = PriorUtil(model)

plt.figure(figsize=[12, 9])
plt.axis('off')
plt.axis('equal')
plt.imshow(test_img)

test_encoded_gt = prior_util.encode(test_gt, debug=False)

loc_idxs = list(range(1000))

m_idx = 5  # index into prior_util.prior_maps of the map to inspect
m = prior_util.prior_maps[m_idx]
m.plot_locations()
prior_util.plot_neighbors(m_idx, loc_idxs, inter_layer=False)
prior_util.plot_assignment(m_idx)

plt.show()

# Round-trip check: decode an unmodified copy of the encoded GT and draw it
# on top of the ground truth.
dummy_output = np.copy(test_encoded_gt)

plt.figure(figsize=[12, 9])
ax = plt.gca()
plt.imshow(test_img)
res = prior_util.decode(dummy_output, debug=False, debug_combining=True)
prior_util.plot_gt()
prior_util.plot_results(res)
plt.axis('off')
plt.xlim(0, image_size[1])
plt.ylim(image_size[0], 0)
plt.show()

### Training

In [None]:
epochs = 100
initial_epoch = 0
batch_size = 6
freeze = []  # names of layers that are not trained
experiment = 'sl384x512_synthtext'

prior_util = PriorUtil(model)

#optimizer = tf.optimizers.SGD(learning_rate=1e-3, momentum=0.9, nesterov=True)
# `decay` is not passed: the former value 0.0 was a no-op, and the optimizers
# of newer TensorFlow/Keras releases reject that argument.
optimizer = tf.optimizers.Adam(learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=0.001)

#loss = SegLinkLoss(lambda_offsets=1.0, lambda_links=1.0, neg_pos_ratio=3.0)
loss = SegLinkFocalLoss(lambda_segments=100.0, lambda_offsets=1.0, lambda_links=100.0,
                        gamma_segments=2, gamma_links=2, first_map_size=(96,128))

regularizer = keras.regularizers.l2(5e-4)  # set to None to disable

gen_train = InputGenerator(gt_util_train, prior_util, batch_size, model.image_size, augmentation=False)
gen_val = InputGenerator(gt_util_val, prior_util, batch_size, model.image_size, augmentation=False)

dataset_train, dataset_val = gen_train.get_dataset(), gen_val.get_dataset()
iterator_train, iterator_val = iter(dataset_train), iter(dataset_val)

# Only a quarter of the training batches is consumed per epoch. The iterators
# persist across epochs, so consecutive epochs see different batches.
# NOTE(review): presumably get_dataset() repeats; otherwise next() raises
# StopIteration once the data is exhausted -- verify.
num_train_batches = gen_train.num_batches // 4

checkdir = './checkpoints/' + time.strftime('%Y%m%d%H%M') + '_' + experiment
os.makedirs(checkdir, exist_ok=True)

# Snapshot every executed notebook cell (IPython's `In` history) next to the
# checkpoints so that a run can be traced back to its code.
with open(os.path.join(checkdir, 'source.py'), 'wb') as f:
    source = ''.join(['# In[%i]\n%s\n\n' % (i, In[i]) for i in range(len(In))])
    f.write(source.encode())

print(checkdir)

for l in model.layers:
    l.trainable = l.name not in freeze
    if regularizer and l.__class__.__name__.startswith('Conv'):
        # `l=l` binds the current layer into the closure. NOTE: running this
        # cell again adds the penalty terms a second time; rebuild the model first.
        model.add_loss(lambda l=l: regularizer(l.kernel))

metric_util = MetricUtility(loss.metric_names, logdir=checkdir)

@tf.function
def step(x, y_true, training=False):
    """Process one batch and return the metric dict from `loss.compute`.

    With training=True the model runs in training mode and one optimizer update
    is applied. The loss minimized then also includes `model.losses` (the
    regularization terms), but the returned 'loss' entry does not.
    With training=False the model runs in inference mode and no update is made.
    """
    if training:
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            metric_values = loss.compute(y_true, y_pred)
            total_loss = metric_values['loss']
            if len(model.losses):
                total_loss += tf.add_n(model.losses)
        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    else:
        # Evaluate in inference mode: layers such as BatchNormalization or
        # Dropout (if present) must not act as in training, and BN must not
        # update its moving statistics from validation data.
        y_pred = model(x, training=False)
        metric_values = loss.compute(y_true, y_pred)
    return metric_values

for k in tqdm(range(initial_epoch, epochs), 'total', leave=False):
    print('\nepoch %i/%i' % (k+1, epochs))
    metric_util.on_epoch_begin()

    for i in tqdm(range(num_train_batches), 'training', leave=False):
        x, y_true = next(iterator_train)
        metric_values = step(x, y_true, training=True)
        metric_util.update(metric_values, training=True)

    model.save_weights(checkdir+'/weights.%03i.h5' % (k+1,))

    for i in tqdm(range(gen_val.num_batches), 'validation', leave=False):
        x, y_true = next(iterator_val)
        metric_values = step(x, y_true, training=False)
        metric_util.update(metric_values, training=False)

    metric_util.on_epoch_end(verbose=1)

### Predict

In [None]:
# Checkpoint from an earlier training run, and the decoding thresholds used
# for the qualitative results below.
weights_path = './checkpoints/201809251754_sl384x512_synthtext/weights.020.h5'
segment_threshold = 0.50
link_threshold = 0.45
model.load_weights(weights_path)

In [None]:
# Draw a random batch of validation samples (network inputs, display images,
# GT data) and run the model on it one sample at a time.
_, inputs, images, data = gt_util_val.sample_random_batch(
    input_size=model.image_size,
    batch_size=1024,
)

preds = model.predict(inputs, verbose=1, batch_size=1)

In [None]:
results = [
    prior_util.decode(pred, segment_threshold, link_threshold)
    for pred in preds
]

# Show the first 8 samples with ground truth and decoded predictions overlaid.
for i in range(8):
    plt.figure(figsize=[8, 8])
    plt.imshow(images[i])
    # encode() is called before plot_gt(), as in the encoding cell; presumably
    # it stores the GT that plot_gt() draws -- verify.
    prior_util.encode(data[i])
    prior_util.plot_gt()
    prior_util.plot_results(results[i])
    plt.axis('off')
    plt.show()

### Grid Search

In [None]:
# Exhaustive search over (segment_threshold, link_threshold). fmes_grid[i, j]
# holds the f-measure for steps_seg[i] and steps_lnk[j].
steps_seg = np.arange(0.05, 1, 0.05)
steps_lnk = np.arange(0.05, 1, 0.05)

fmes_grid = np.zeros((len(steps_seg), len(steps_lnk)))

for i, st in enumerate(steps_seg):
    for j, lt in enumerate(steps_lnk):
        results = [prior_util.decode(p, st, lt) for p in preds]
        TP, FP, FN = evaluate_results(data, results, image_size=image_size)
        # At high thresholds there may be no detections (TP+FP == 0). Define
        # precision, recall and f-measure as 0 in that case instead of
        # dividing by zero.
        recall = TP / (TP+FN) if TP+FN > 0 else 0.0
        precision = TP / (TP+FP) if TP+FP > 0 else 0.0
        fmes = fscore(precision, recall) if precision+recall > 0 else 0.0
        fmes_grid[i,j] = fmes
        print('segment_threshold %.2f link_threshold %.2f f-measure %.2f' % (st, lt, fmes))

In [None]:
# Locate the best threshold pair. np.unravel_index converts the flat argmax
# into (row, column) = (segment index, link index) for any grid shape, not
# only square grids.
max_idx = np.argmax(fmes_grid)
max_idx1, max_idx2 = np.unravel_index(max_idx, fmes_grid.shape)
print(steps_seg[max_idx1], steps_lnk[max_idx2], fmes_grid[max_idx1,max_idx2])

plt.figure(figsize=[8]*2)
plt.imshow(fmes_grid, cmap='jet', origin='lower', interpolation='bicubic') # nearest, bilinear, bicubic
plt.title('f-measure')
plt.xticks(range(len(steps_lnk)), steps_lnk.astype('float32'))
plt.yticks(range(len(steps_seg)), steps_seg.astype('float32'))
# x is the column (link threshold), y is the row (segment threshold)
plt.plot(max_idx2, max_idx1, 'or')
plt.xlabel('link_threshold')
plt.ylabel('segment_threshold')
plt.grid()
plt.show()