In [109]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

import os
from datetime import datetime
import sys
import time

sys.path.insert(0, "../scripts/")
sys.path.insert(0, os.path.join("..", "keras_unets"))

import importlib

import utils
import metrics
import vis

# Reload the local helper modules BEFORE binding names out of them.
# importlib.reload() re-executes a module in place but does not rebind objects
# that were already pulled in with `from module import name`, so reloading
# after the from-imports would leave names such as `image_batch_generator`
# pointing at the stale, pre-edit definitions when this cell is re-run.
# The loop form also keeps the module repr from being echoed as cell output.
for _module in (metrics, vis, utils):
    importlib.reload(_module)

from utils import train_test_split, image_batch_generator, get_train_augmentation, random_batch_generator, get_table_augmentation
from utils import DATASET_PATH, DS_IMAGES, DS_MASKS, PAGE_IMAGES, SaveValidSamplesCallback
from metrics import iou, f1_score, jaccard_distance
from vis import anshow, imshow
from models import TableNet, load_unet_model

from keras_unet_collection.models import att_unet_2d

# Directory entries of both image folders, concatenated (DS_IMAGES first).
IMAGE_NAMES = os.listdir(DS_IMAGES) + os.listdir(PAGE_IMAGES)

<module 'utils' from 'c:\\Users\\user\\Desktop\\analysed.ai\\table_extraction_from_docs\\notebooks\\../scripts\\utils.py'>

In [110]:
# Run settings read by the cells below. Only "batch_size", "input_shape" and
# "three_channel" are consumed in this notebook; the remaining keys are kept
# as-is (presumably mirroring the training configuration -- TODO confirm).
TR_CONFIG = dict(
    epochs=100,
    batch_size=7,
    lr=1e-4,
    input_shape=(512, 512),
    band_size=2,
    three_channel=False,
)

In [111]:
# Per-level sizes handed to att_unet_2d as its second positional argument
# (presumably the filter count of each down-sampling level -- TODO confirm).
# A narrower alternative tried earlier was [16, 32, 64, 128].
down_scales = [32, 64, 128, 256]

# NOTE(review): the channel count is hard-coded to 2 rather than derived from
# TR_CONFIG["three_channel"] -- confirm it matches what the batch generator yields.
model = att_unet_2d((TR_CONFIG["input_shape"][0], TR_CONFIG["input_shape"][1], 2), down_scales, n_labels=1,
            stack_num_down=2, stack_num_up=2,
            activation='ReLU', atten_activation='ReLU', attention='add', output_activation="Sigmoid", 
            batch_norm=True, pool=False, unpool='bilinear', name='attunet'
        )

# Single source of truth for the checkpoint location, so the log line always
# reports the path that is actually restored (previously the message omitted
# the "../scripts/" prefix and the path was spelled out twice).
CHECKPOINT_PATH = "../scripts/training_checkpoints/" + "2022.08.28-14/ckpt-114"

checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=tf.keras.optimizers.Adam(), net=model)
print(f"loading checkpoint {CHECKPOINT_PATH}")
status = checkpoint.restore(CHECKPOINT_PATH)
# Only part of the checkpointed state is needed here, so unmatched values are
# explicitly marked as expected.
status.expect_partial()

loading checkpoint training_checkpoints/2022.08.28-14/ckpt-114


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2be6a592fb0>

In [112]:
# Batch source for the evaluation loop below: every entry of IMAGE_NAMES,
# resized to the configured input shape, with no augmentation.
# return_names=True is what lets the loop unpack (X, y, names) per batch.
generator_options = dict(
    batch_size=TR_CONFIG["batch_size"],
    resize_shape=TR_CONFIG["input_shape"],
    aug_transform=None,
    normalize=True,
    three_channel=TR_CONFIG["three_channel"],
    return_names=True,
)
train_batch_generator = image_batch_generator(IMAGE_NAMES, **generator_options)

In [113]:
def print_progress(name, metrics, step, all_steps):
    """Print a one-line running summary of the collected metrics.

    The line ends with a carriage return instead of a newline, so successive
    calls overwrite each other in the cell output.

    Args:
        name: Label printed in front of the metric values (e.g. "valid").
        metrics: Mapping with the keys "loss", "tf_iou", "iou", "f1",
            "precision" and "recall", each holding the values collected so
            far; the mean of each is reported. Note that inside this function
            the parameter shadows the imported ``metrics`` module.
        step: 1-based index of the step that just finished.
        all_steps: Total number of steps.
    """
    # Progress is reported as "<current>/<total>"; the two values used to be
    # swapped, which printed e.g. "884/1" on the first step.
    str_prog = f"{step}/{all_steps}: "
    str_prog += "{} loss {:.4f}, tf_iou {:.4f}, iou {:.4f}, f1 {:.4f}, prec {:.4f}, rec {:.4f}".format(
        name,
        np.mean(metrics["loss"]),
        np.mean(metrics["tf_iou"]),
        np.mean(metrics["iou"]),
        np.mean(metrics["f1"]),
        np.mean(metrics["precision"]),
        np.mean(metrics["recall"])
    )

    print(str_prog, end='\r')

In [114]:
# Per-batch metric values, averaged on the fly by print_progress.
tr_metrics = {n: [] for n in ("loss", "iou", "tf_iou", "f1", "precision", "recall")}

# cv2.imwrite does not create missing directories and reports failure only
# through its return value, so without this the previews would be silently
# dropped on a fresh checkout.
os.makedirs("preds", exist_ok=True)

# Number of batches to evaluate (loop-invariant, so computed once).
total_steps = len(IMAGE_NAMES) // TR_CONFIG["batch_size"]

for i, (batch_X, batch_y, image_names) in enumerate(train_batch_generator):

    batch_X = tf.convert_to_tensor(batch_X, dtype=tf.float32)
    batch_y = tf.convert_to_tensor(batch_y, dtype=tf.float32)

    # Inference only; the trailing size-1 axis of the output is dropped.
    pred = model(batch_X, training=False)
    pred = tf.squeeze(pred, -1)

    # NOTE(review): called as (pred, target) while calculate_metrics below is
    # called as (target, pred) -- confirm jaccard_distance's argument order.
    loss_value = metrics.jaccard_distance(pred, batch_y)

    for name, X, y, pred_y in zip(image_names, batch_X, batch_y, pred):

        # Zero out predictions below 0.9; values at or above it are kept as-is.
        mask = np.array(pred_y)
        mask[mask < 0.9] = 0
        pred_y = tf.convert_to_tensor(mask)

        iou_val = metrics.iou(y, pred_y) * 100

        # Side-by-side preview: first input channel | thresholded prediction | target.
        # The score goes first in the file name, so listings can be ordered by it.
        cv2.imwrite(
            "preds/{:.4f}_{}.png".format(iou_val, name),
            cv2.hconcat([
                np.array(X[:, :, 0] * 255, dtype=np.uint8),
                np.array(pred_y * 255, dtype=np.uint8),
                np.array(y * 255, dtype=np.uint8)
            ])
        )

    # Batch-level metrics use the raw (un-thresholded) predictions.
    (
        iou_value, tf_iou_value,
        f1_score_value,
        precision_value,
        recall_value
    ) = metrics.calculate_metrics(batch_y, pred)

    tr_metrics["loss"].append(loss_value)
    tr_metrics["iou"].append(iou_value)
    tr_metrics["tf_iou"].append(tf_iou_value)
    tr_metrics["f1"].append(f1_score_value)
    tr_metrics["precision"].append(precision_value)
    tr_metrics["recall"].append(recall_value)

    print_progress("valid", tr_metrics, i+1, total_steps)
    # Stop explicitly after total_steps batches rather than relying on the
    # generator to end (presumably it cycles indefinitely -- TODO confirm).
    if (i + 1) >= total_steps:
        break

884/884: valid loss 0.1263, tf_iou 0.8922, iou 0.8737, f1 0.9066, prec 0.9484, rec 0.9238

In [115]:
import os

In [116]:
# Extension-less image names, one per line. The deletion loop further down
# removes the matching files from ../datasets/all_images and the
# corresponding "<name>_mask" files from ../datasets/all_masks.
names = """58_248
85_250
53_134
85_231
94_49
6008_014
90_111
10.1.1.160.606_5
21_72
10.1.1.160.699_3
77_126
52_275
10.1.1.160.659_3
10.1.1.160.656_12
10.1.1.160.652_9
10.1.1.160.653_12
41_50
10.1.1.160.655_9
21_133
9_139
2208-10297-pdf_page_10_jpg.rf.4266a173854f0a9cc83bd4910d994e29
10.1.1.160.657_4
74_29
55_301
10.1.1.1.2103_4
30_4
10.1.1.160.651_22
44_92
29_122
29_15
11_150
33_252
15_261
81_90
2208-10406-pdf_page_5_jpg.rf.b80a8a94650a8487b97770eb1ee5b977
24_11
15_111
33_11
4_16
17_229
1852_095
63_59
6286_013
0651_008
NFE_Roster_page_1_jpg.rf.310a7dec8cce7418a155705fd2651f2d
NFE_Roster_page_0_jpg.rf.c30b226e49ac8dc4a9d6694059ab80a5
10.1.1.38.2480_2
5727_109
5727_096
5727_105
2022_Freakout_JFDS-pdf_page_0_jpg.rf.0278be037a923592bce7969534d65532""".splitlines()

In [118]:
# Sanity check: how many entries the multi-line string above was split into.
len(names)

51

In [126]:
# Counter of deleted image files; incremented by the deletion loop in the next cell.
removed = 0

In [127]:
# DESTRUCTIVE: permanently deletes every image listed in `names`, together
# with its "<name>_mask" file, trying each known extension in turn.
#
# The counter is reset here so the cell is self-contained: re-running it
# reports only what this run deleted instead of accumulating across runs.
# On a second run the files are already gone, so the count is 0.
removed = 0

for n in names:
    for ext in ["jpg", "bmp", "png", "jpeg"]:
        image_path = f"../datasets/all_images/{n}.{ext}"
        mask_path = f"../datasets/all_masks/{n}_mask.{ext}"

        if os.path.exists(image_path):
            removed += 1
            os.remove(image_path)

        # Masks are deleted as well but are not included in the count.
        if os.path.exists(mask_path):
            os.remove(mask_path)

In [128]:
# Number of image files deleted by the loop above (mask deletions are not counted).
removed

0