In [80]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

import os
from datetime import datetime
import sys
import time

sys.path.insert(0, "../scripts/")

from utils import train_test_split, image_batch_generator, get_train_augmentation, random_batch_generator, get_table_augmentation
from utils import DATASET_PATH, DS_IMAGES, DS_MASKS, SaveValidSamplesCallback
import utils
from metrics import iou, f1_score, jaccard_distance
import metrics
from vis import anshow, imshow
import vis
from models import TableNet, load_unet_model

IMAGE_NAMES = os.listdir(DS_IMAGES)

import importlib
importlib.reload(metrics)
importlib.reload(vis)
importlib.reload(utils)

<module 'utils' from '/home/cseadmin/Tigran/table_extractor/notebooks/../scripts/utils.py'>

In [2]:
model = load_unet_model((512, 512), 2, weight_decay=0.1)
optim = tf.keras.optimizers.Adam()

2022-08-01 14:38:38.190246: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-01 14:38:38.990090: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9237 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:65:00.0, compute capability: 7.5


In [3]:
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=model)
print(f"loading checkpoint {'training_checkpoints/' + '2022.07.30-22/' + 'ckpt-238'}")
status = checkpoint.restore("../scripts/training_checkpoints/" + '2022.07.30-22/' + 'ckpt-238')

loading checkpoint training_checkpoints/2022.07.30-22/ckpt-238


In [4]:
loss_fn = jaccard_distance

train_names, valid_names = train_test_split(IMAGE_NAMES, shuffle=True, random_state=2022, test_size=0.2)

In [31]:
valid_batch_generator = image_batch_generator(
                            valid_names, 
                            batch_size=8, 
                            resize_shape=(512, 512),
                            aug_transform=None,
                            normalize=True, include_edges_as_band=True
                        )

In [32]:
def print_progress(name, metrics, step, all_steps):
    str_prog = f"{all_steps}/{step}: "
    str_prog += "{} loss {:.4f}, IOU {:.4f}, f1 {:.4f}, prec {:.4f}, rec {:.4f}".format(
        name,
        np.mean(metrics["loss"]), 
        np.mean(metrics["iou"]), 
        np.mean(metrics["f1"]),
        np.mean(metrics["precision"]),
        np.mean(metrics["recall"])
    )

    print(str_prog, end='\r')

In [62]:
val_metrics = {n:[] for n in ("loss", "iou", "f1", "precision", "recall")}

mean_time = []

# valid loop
# with tf.device("GPU:0"):

for i, (batch_X, batch_y) in enumerate(valid_batch_generator):

    start = time.time()

    # print(batch_X.dtype, batch_y.dtype)
    # print(batch_y.min(), batch_y.max())
    batch_X = tf.convert_to_tensor(batch_X, dtype=tf.float32)
    batch_y = tf.convert_to_tensor(batch_y, dtype=tf.float32)

    logits = model(batch_X, training=False)
    logits = tf.squeeze(logits)

    rgb_masks = np.array(logits * 255, dtype=np.uint8)
    final_masks = []

    for mask in rgb_masks:

        thresh = thresh = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY)[1]
        contours, hierarchy = cv2.findContours(image=thresh, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE)

        final_mask = np.zeros_like(mask)

        for ind, c in enumerate(contours):
            if len(c) > 100:
                min_x, max_x = np.squeeze(c)[:, 0].min(), np.squeeze(c)[:, 0].max()
                min_y, max_y = np.squeeze(c)[:, 1].min(), np.squeeze(c)[:, 1].max()
                final_mask[min_y:max_y, min_x:max_x] = 255

        final_masks.append(final_mask / 255)

    final_masks = tf.convert_to_tensor(final_masks, np.float32)

    # print(np.unique(final_masks))

    # print([metrics.iou(gt, pr) for gt, pr in zip(batch_y, final_masks)])
    # break

    loss_value = loss_fn(batch_y, final_masks)
    # print(loss_value)

    (
        iou_value, f1_score_value, 
        presicion_value, 
        recall_value
    ) = metrics.calculate_metrics(batch_y, final_masks)
    val_metrics["loss"].append(np.mean(loss_value))
    val_metrics["iou"].append(iou_value)
    val_metrics["f1"].append(f1_score_value)
    val_metrics["precision"].append(presicion_value)
    val_metrics["recall"].append(recall_value)

    # break

    mean_time.append(time.time() - start)
    # print(f"{len(valid_names)//8}/{i+1}", end='\r')
    print_progress("valid", val_metrics, i+1, len(valid_names)//8)
    # break
    if (i + 1) >= len(valid_names)//8:
        break

105/105: valid loss 9.1748, IOU 0.5016, f1 0.7208, prec 0.5476, rec 0.5883

In [83]:
val_metrics = {n:[] for n in ("loss", "iou", "f1", "precision", "recall")}
mean_time = []

batch_X, batch_y = utils.read_inf_sample(valid_names[:10], (512, 512))
for i, (X, y) in enumerate(zip(batch_X, batch_y)):

    start = time.time()

    raw = model(tf.expand_dims(X, 0), training=False)
    raw = tf.squeeze(raw)

    pred = utils.preprocess_raw_output(raw, 2, 100 * 10)
    # pred = utils.preprocess_raw_output(raw, 5, 1000 * 10)

    mean_time.append(time.time() - start)

    loss_value = loss_fn(y, pred)
    (
        iou_value, f1_score_value,
        presicion_value, 
        recall_value
    ) = metrics.calculate_metrics([y], [pred])

    val_metrics["loss"].append(np.mean(loss_value))
    val_metrics["iou"].append(iou_value)
    val_metrics["f1"].append(f1_score_value)
    val_metrics["precision"].append(presicion_value)
    val_metrics["recall"].append(recall_value)

    # print(X[:, :, 0].shape, raw.shape, y.shape, pred.shape)
    final_img = cv2.hconcat([
        np.array(X[:, :, 0] * 255, dtype=np.uint8),
        np.array(raw * 255, dtype=np.uint8),
        np.array(y * 255, dtype=np.uint8),
        np.array(pred * 255, dtype=np.uint8)
        # np.array(thresh)
    ])
    cv2.imwrite("pred_samples/{:.4f}_{}_image.png".format(iou_value, i), final_img)
    # break