In [71]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

import os
from datetime import datetime
import sys
import time

sys.path.insert(0, "../scripts/")

from utils import train_test_split, image_batch_generator, get_train_augmentation, random_batch_generator, get_table_augmentation
from utils import DATASET_PATH, DS_IMAGES, DS_MASKS, SaveValidSamplesCallback
import utils
from metrics import iou, f1_score, jaccard_distance
import metrics
from vis import anshow, imshow
import vis
from models import TableNet, load_unet_model

IMAGE_NAMES = os.listdir(DS_IMAGES)

import importlib
importlib.reload(metrics)
importlib.reload(vis)
importlib.reload(utils)

<module 'utils' from '/home/cseadmin/Tigran/table_extractor/notebooks/../scripts/utils.py'>

In [54]:
with tf.device("/CPU:0"):
    model = load_unet_model((512, 512), 2, weight_decay=0.1)
    optim = tf.keras.optimizers.Adam()

In [55]:
with tf.device("/CPU:0"):
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=model)
    print(f"loading checkpoint {'training_checkpoints/' + '2022.08.03-13/ckpt-347'}")
    status = checkpoint.restore("../scripts/training_checkpoints/" + '2022.08.03-13/ckpt-347')

loading checkpoint training_checkpoints/2022.08.03-13/ckpt-347


In [56]:
loss_fn = jaccard_distance

train_names, valid_names = train_test_split(IMAGE_NAMES, shuffle=True, random_state=2022, test_size=0.2)

In [57]:
len(valid_names)

840

In [10]:
valid_batch_generator = image_batch_generator(
                            valid_names, 
                            batch_size=1, 
                            resize_shape=(512, 512),
                            aug_transform=None,
                            normalize=True, include_edges_as_band=True
                        )

In [58]:
def print_progress(name, metrics, step, all_steps):
    str_prog = f"{all_steps}/{step}: "
    str_prog += "{} loss {:.4f}, tf_iou {:.4f}, iou {:.4f}, f1 {:.4f}, prec {:.4f}, rec {:.4f}".format(
        name,
        np.mean(metrics["loss"]),
        np.mean(metrics["tf_iou"]),  
        np.mean(metrics["iou"]), 
        np.mean(metrics["f1"]),
        np.mean(metrics["precision"]),
        np.mean(metrics["recall"])
    )

    print(str_prog, end='\r')

In [19]:
# with tf.device("CPU:0"):
#     val_metrics = {n:[] for n in ("loss", "tf_iou", "iou", "f1", "precision", "recall")}

#     mean_time = []

#     # valid loop
#     # with tf.device("GPU:0"):

#     for i, (batch_X, batch_y) in enumerate(valid_batch_generator):

#         start = time.time()

#         # print(batch_X.dtype, batch_y.dtype)
#         # print(batch_y.min(), batch_y.max())
#         batch_X = tf.convert_to_tensor(batch_X, dtype=tf.float32)
#         batch_y = tf.convert_to_tensor(batch_y, dtype=tf.float32)

#         logits = model(batch_X, training=False)
#         logits = tf.squeeze(logits)

#         # rgb_masks = np.array(logits * 255, dtype=np.uint8)
#         # final_masks = []

#         # for mask in rgb_masks:

#         #     thresh = thresh = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY)[1]
#         #     contours, hierarchy = cv2.findContours(image=thresh, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE)

#         #     final_mask = np.zeros_like(mask)

#         #     for ind, c in enumerate(contours):
#         #         if len(c) > 100:
#         #             min_x, max_x = np.squeeze(c)[:, 0].min(), np.squeeze(c)[:, 0].max()
#         #             min_y, max_y = np.squeeze(c)[:, 1].min(), np.squeeze(c)[:, 1].max()
#         #             final_mask[min_y:max_y, min_x:max_x] = 1

#         #     final_masks.append(final_mask)

#         # final_masks = tf.convert_to_tensor(final_masks, np.float32)

#         # print(np.unique(final_masks))

#         # print([metrics.iou(gt, pr) for gt, pr in zip(batch_y, final_masks)])
#         # break

#         loss_value = loss_fn(batch_y, logits)
#         # print(loss_value)

#         (
#             iou_value, tf_iou_value,
#             f1_score_value, 
#             presicion_value, 
#             recall_value
#         ) = metrics.calculate_metrics(batch_y, logits)
#         val_metrics["loss"].append(np.mean(loss_value))
#         val_metrics["tf_iou"].append(tf_iou_value)
#         val_metrics["iou"].append(iou_value)
#         val_metrics["f1"].append(f1_score_value)
#         val_metrics["precision"].append(presicion_value)
#         val_metrics["recall"].append(recall_value)

#         # break

#         mean_time.append(time.time() - start)
#         # print(f"{len(valid_names)//8}/{i+1}", end='\r')
#         print_progress("valid", val_metrics, i+1, len(valid_names)//8)
#         # break
#         if (i + 1) >= len(valid_names)//8:
#             break

In [None]:
loss_fn()

In [59]:
with tf.device("/CPU:0"):
    batch_X, batch_y = utils.read_inf_sample(valid_names[:50], (512, 512))

In [73]:
val_metrics = {n:[] for n in ("loss", "iou", "tf_iou", "f1", "precision", "recall")}
mean_time = []

with tf.device("/CPU:0"):

    for i, (X, y) in enumerate(zip(batch_X, batch_y)):

        start = time.time()

        raw = model(tf.expand_dims(X, 0), training=False)
        raw = tf.squeeze(raw)

        pred1 = utils.preprocess_raw_output(raw, 2, 100)
        pred2 = utils.preprocess_raw_output(pred1, 2, 2000, max_seg_dist=40)
        # pred2 = utils.preprocess_raw_output(pred2, 2, 0, max_seg_dist=20)
        # pred2 = utils.preprocess_raw_output(pred2, 2, 0)

        mean_time.append(time.time() - start)
        # print(y.shape, pred.shape)

        loss_value = loss_fn(tf.expand_dims(y, 0), tf.expand_dims(pred2, 0))
        (
            iou_value, tf_iou_value,
            f1_score_value,
            presicion_value, 
            recall_value
        ) = metrics.calculate_metrics([y], [pred2])

        val_metrics["loss"].append(np.mean(loss_value))
        val_metrics["iou"].append(iou_value)
        val_metrics["tf_iou"].append(tf_iou_value)
        val_metrics["f1"].append(f1_score_value)
        val_metrics["precision"].append(presicion_value)
        val_metrics["recall"].append(recall_value)

        # print(X[:, :, 0].shape, raw.shape, y.shape, pred.shape)
        final_img = cv2.hconcat([
            np.array(X[:, :, 0] * 255, dtype=np.uint8),
            np.array(raw * 255, dtype=np.uint8),
            np.array(y * 255, dtype=np.uint8),
            np.array(pred1 * 255, dtype=np.uint8),
            np.array(pred2 * 255, dtype=np.uint8)
            # np.array(thresh)
        ])
        cv2.imwrite("pred_samples/{:.4f}_{}_image.png".format(iou_value, i), final_img)
        # break

In [50]:
np.mean(mean_time)

0.7750808954238891

In [51]:
print_progress("valid", val_metrics, 1, 1)

1/1: valid loss 22.9774, tf_iou 0.8437, iou 0.6292, f1 0.8094, prec 0.6996, rec 0.6683