In [1]:
import tensorflow as tf, matplotlib.pyplot as plt, os, numpy as np, matplotlib.pyplot as plt, matplotlib.patches as patches
from sklearn.metrics import confusion_matrix as C_M, accuracy_score as A_S, classification_report as C_R
from framework.utils import bbox_utils, data_utils, drawing_utils, eval_utils, io_utils, train_utils
from framework.models import rpn_vgg16, faster_rcnn

# Read Dataset from TFRecord

In [2]:
# Sentinel that lets tf.data pick tuning values (e.g. the prefetch buffer size) at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
IMAGE_SIZE = 512  # must be kept equal to the size of the pictures stored in the TFRecords

In [3]:
def decode_image(img):
    """Return `img` converted to an int32 tensor with `tf.cast`."""
    return tf.cast(img, tf.int32)

In [4]:
def read_tfrecord(example):
    """Parse one serialized example into a dict of tensors.

    Every feature ("filename", "pic", "bbox", "label") is stored as a scalar
    string; all but "filename" hold serialized tensors that are decoded with
    `tf.io.parse_tensor`.

    Args:
        example: A serialized example, as yielded by a TFRecord dataset.

    Returns:
        dict with keys "filename" (tf.string), "image" (int32, decoded from
        uint8 by `decode_image`), "bbox" (float32) and "label" (int32).
    """
    # All four features share the same descriptor: a single scalar string.
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        key: scalar_string for key in ("filename", "pic", "bbox", "label")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    return {
        "filename": tf.cast(parsed["filename"], tf.string),
        "image": decode_image(tf.io.parse_tensor(parsed["pic"], out_type=tf.uint8)),
        "bbox": tf.io.parse_tensor(parsed["bbox"], out_type=tf.float32),
        "label": tf.io.parse_tensor(parsed["label"], out_type=tf.int32),
    }

In [5]:
def load_dataset(filenames):
    """Build a dataset of parsed samples from one or more TFRecord files.

    Args:
        filenames: Path (or collection of paths) accepted by
            `tf.data.TFRecordDataset`.

    Returns:
        A `tf.data.Dataset` whose elements are the dicts produced by
        `read_tfrecord`. Element order is not guaranteed.
    """
    ignore_order = tf.data.Options()
    # Allow parallel transformations to emit elements out of order for speed.
    ignore_order.experimental_deterministic = False
    # No num_parallel_reads is passed, so the file(s) are read sequentially.
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.with_options(ignore_order)
    # Parse in parallel: without num_parallel_calls the map is sequential and
    # the non-deterministic option above would have no effect.
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

In [6]:
def get_dataset(filenames):
    """Load `filenames` via `load_dataset`, then shuffle and prefetch.

    Shuffling uses a 2048-element buffer; the prefetch buffer size is left to
    tf.data (`AUTOTUNE`).
    """
    return (
        load_dataset(filenames)
        .shuffle(2048)
        .prefetch(buffer_size=AUTOTUNE)
    )

In [7]:
def read_label_map(label_map_path):
    """Parse a label-map text file into a ``{name: id}`` dictionary.

    The file is expected to hold one ``key: value`` field per line, e.g.::

        item {
          id: 1
          name: 'cat'
        }

    Only the exact keys ``id`` and ``name`` are read. Other fields (for
    example ``display_name``), structural lines (``item {``, ``}``), blank
    lines and lines without a colon are ignored. An entry is recorded as soon
    as both an id and a name have been seen, in either order.

    Args:
        label_map_path: Path of the label-map file.

    Returns:
        dict mapping each label name (quotes removed) to its integer id.

    Raises:
        ValueError: If the value of an ``id`` field is not an integer.
    """
    item_id = None
    item_name = None
    items = {}

    with open(label_map_path, "r") as file:
        for raw_line in file:
            line = raw_line.strip()
            if ":" not in line:
                # "item {", "}", blank lines: nothing to extract.
                continue

            key, value = line.split(":", 1)
            key = key.strip()
            # Compare the key exactly; a substring test would misread values
            # such as 'spider' (contains "id") or the key "display_name".
            if key == "id":
                item_id = int(value.strip())
            elif key == "name":
                item_name = value.replace("'", "").replace("\"", "").strip()

            if item_id is not None and item_name is not None:
                items[item_name] = item_id
                item_id = None
                item_name = None

    return items

In [8]:
# Label-map file, given relative to the notebook's working directory.
label_map_path = "./data_preparation/label_map.pbtxt"
# Mapping of label name -> integer id, as returned by read_label_map.
label_map_dict = read_label_map(label_map_path)

In [9]:
def get_label_text(result, doc = label_map_dict):
    """Return the label name whose id equals ``result + 1``.

    Args:
        result: Numeric class value; presumably a zero-based prediction index
            while the label-map ids start at 1 (confirm against the model).
        doc: ``{name: id}`` mapping to search; defaults to the label map that
            was loaded when this function was defined.

    Returns:
        The first matching name, or "Unpredictable" when no id matches.
    """
    matches = (name for name, label_id in doc.items() if label_id == result + 1)
    return next(matches, "Unpredictable")

In [10]:
# Shuffled, prefetched dataset built from the training TFRecord file.
train_data = get_dataset("./data_preparation/train.tfrecord")

# Same pipeline applied to the test TFRecord file.
test_data = get_dataset("./data_preparation/test.tfrecord")

# Bare expression: shows the dataset repr (element shapes and types) as the cell output.
train_data

<PrefetchDataset shapes: {filename: (), image: <unknown>, bbox: <unknown>, label: <unknown>}, types: {filename: tf.string, image: tf.int32, bbox: tf.float32, label: tf.int32}>

In [11]:
def show_data(data, n):
    """Plot the first `n` samples of `data` with their bounding boxes.

    Args:
        data: Dataset whose elements are dicts with "image" and "bbox" entries.
        n: Number of samples to take and display.

    Each box is read as (ymin, xmin, ymax, xmax) and multiplied by IMAGE_SIZE
    before drawing, so the stored coordinates are presumably normalized to
    [0, 1] (confirm against the data preparation step).
    """
    print(data)
    for sample in data.take(n):
        plt.imshow(sample["image"])
        axes = plt.gca()
        for coord in sample["bbox"]:
            # Scale to pixels and move to NumPy once per box.
            box = (coord * IMAGE_SIZE).numpy()
            top, left, bottom, right = box[0], box[1], box[2], box[3]
            outline = patches.Rectangle(
                (left, top),   # x1, y1
                right - left,  # width
                bottom - top,  # height
                linewidth=2, edgecolor="r", fill=False)
            axes.add_patch(outline)
        plt.show()

In [12]:
# show_data(train_data, 20)

# Train Model

In [13]:
# Training configuration.
batch_size = 4
epochs = 50
load_weights = False
backbone = "vgg16"

# Framework-provided hyper-parameters for the chosen backbone.
hyper_params = train_utils.get_hyper_params(backbone)
# Count the samples by streaming through the dataset once; building a list of
# every decoded sample just to take its length would hold all images in memory.
train_total_item = sum(1 for _ in train_data)

In [14]:
# Label names, in the order they were inserted into the label map.
labels = list(label_map_dict.keys())
# We add 1 class for background
hyper_params["total_labels"] = len(labels) + 1
# Apply the framework's pre-Faster-R-CNN preprocessing to every sample. IMAGE_SIZE is
# passed twice (presumably target height and width -- confirm against data_utils).
# NOTE: train_data is rebound here, so re-running this cell maps the preprocessing again.
train_data = train_data.map(lambda data : data_utils.preprocessing_before_frcnn(data, IMAGE_SIZE, IMAGE_SIZE))

In [15]:
# Per-component padded shapes and fill values supplied by the framework helpers.
data_shapes = data_utils.get_data_shapes()
padding_values = data_utils.get_padding_values()
# Group samples into batches of batch_size, padding each component to the shapes above.
# NOTE: train_data is rebound here, so re-running this cell batches the dataset again.
train_data = train_data.padded_batch(batch_size, padded_shapes=data_shapes, padding_values=padding_values)

In [16]:
# Anchors are derived from the hyper-parameters by the framework helper.
anchors = bbox_utils.generate_anchors(hyper_params)
# Training feed built by the framework from the batched dataset, the anchors and the
# hyper-parameters (presumably a Python generator -- confirm in train_utils).
frcnn_train_feed = train_utils.faster_rcnn_generator(train_data, anchors, hyper_params)

In [17]:
# Build the VGG16-based RPN model together with its feature extractor.
rpn_model, feature_extractor = rpn_vgg16.get_model_vgg16(hyper_params)
# Assemble the Faster R-CNN model from the feature extractor, the RPN model and the anchors.
frcnn_model = faster_rcnn.get_model_frcnn(feature_extractor, rpn_model, anchors, hyper_params)
# One None loss is passed per model output, so compile() attaches no loss itself
# (presumably the losses are added inside the model -- confirm in faster_rcnn).
frcnn_model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-5),
                    loss=[None] * len(frcnn_model.output))
# Framework-provided initialisation step for the compiled model.
faster_rcnn.init_model_frcnn(frcnn_model, hyper_params)