In [None]:
import json
import os
import time
import sys

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from imgaug.augmentables.kps import KeypointsOnImage
from imgaug.augmentables.kps import Keypoint
import imgaug.augmenters as iaa


In [None]:
# Ask TensorFlow's native logging to suppress messages (level '3').
# NOTE(review): TensorFlow is already imported in the first cell; this variable is
# usually expected to be set before `import tensorflow` - confirm it still takes effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
# Print the GPU devices visible to TensorFlow (presumably an empty list when none is available).
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

In [None]:
# Side length (pixels) that the augmenters resize every image to.
IMG_SIZE = 224
# Default number of samples per batch produced by KeyPointsDataset.
BATCH_SIZE = 32

# Length of the flattened target vector: 16 joints x (x, y) coordinates.
NUM_KEYPOINTS = 16*2
# NOTE(review): the epoch budget is derived from NUM_KEYPOINTS (32 * 3 = 96); that
# coupling looks incidental - confirm whether an independent constant was intended.
EPOCHS = NUM_KEYPOINTS*3

# Adam learning rate for the first training stage; fine-tuning uses a tenth of it.
BASE_LEARNING_RATE = 1e-4

In [None]:
# Locations of the image folder and of the annotation file.
IMG_DIR = "datasets/mpii/images_selected/"
JSON = "datasets/mpii/trainval.json"


# Read the list of annotation records from disk.
with open(JSON) as infile:
    json_data = json.load(infile)

# Index the records by image file name for direct lookup.
json_dict = {}
for record in json_data:
    json_dict[record["image"]] = record

In [None]:
# Fold each joint's visibility flag into its coordinate entry so that every joint
# becomes [x, y, visible], then drop the now-redundant 'joints_vis' list.
# Records that no longer carry 'joints_vis' are skipped, so re-running this cell is
# a no-op instead of raising KeyError; the joint count is taken from the data
# rather than assumed to be 16.
for annotation in json_dict.values():
    visibility = annotation.pop('joints_vis', None)
    if visibility is None:
        continue  # already merged by an earlier run of this cell
    for joint, visible in zip(annotation['joints'], visibility):
        joint.append(visible)

In [None]:
# Training-time augmentation pipeline; KeyPointsDataset applies it jointly to an
# image and its keypoints.
train_aug = iaa.Sequential(
    [
        # Resize every image to IMG_SIZE.
        iaa.Resize(IMG_SIZE, interpolation="linear"),
        # Horizontal flip with probability 0.3.
        # NOTE(review): a flip mirrors coordinates but presumably does not swap
        # left/right joint identities - confirm this is acceptable for the labels.
        iaa.Fliplr(0.3),
        # With probability 0.3, apply a rotation and a down-scaling.
        # NOTE(review): `rotate=10` looks like a fixed angle rather than a range such
        # as (-10, 10) - confirm against the imgaug Affine documentation.
        iaa.Sometimes(0.3, iaa.Affine(rotate=10, scale=(0.5, 0.7))),
    ]
)

# Validation-time pipeline: resize only.
test_aug = iaa.Sequential([iaa.Resize(IMG_SIZE, interpolation="linear")])

In [None]:
def get(name):
    """Return the annotation record for `name` together with its decoded image.

    Args:
        name: Key into `json_dict` (the image file name).

    Returns:
        A new dict holding every field of the stored annotation plus an
        "img_data" entry with the pixel array read from `IMG_DIR`.

    Raises:
        KeyError: If `name` is not present in `json_dict`.
    """
    record = json_dict[name]
    img_data = plt.imread(IMG_DIR + record["image"])
    # NOTE(review): RGBA images are returned with all four channels; no RGB
    # conversion is performed here.
    # Return a shallow copy: writing the pixels into the shared `json_dict` entry
    # would keep every image ever requested alive for the lifetime of the kernel.
    return {**record, "img_data": img_data}

In [None]:

def visualize_keypoints(images, keypoints):
    """Plot each image twice: untouched on the left, with keypoints on the right.

    Args:
        images: Sequence of image arrays accepted by `imshow`.
        keypoints: One entry per image; either a `KeypointsOnImage` or an
            array-like of rows whose first two values are (x, y).
    """
    # squeeze=False keeps `axes` two-dimensional, so a single image still yields one
    # (left, right) pair per row instead of a flat array that cannot be unpacked.
    fig, axes = plt.subplots(
        nrows=len(images), ncols=2, figsize=(16, 12), squeeze=False
    )
    for ax in axes.ravel():
        ax.axis("off")

    for (ax_orig, ax_all), image, current_keypoint in zip(axes, images, keypoints):
        ax_orig.imshow(image)
        ax_all.imshow(image)

        if isinstance(current_keypoint, KeypointsOnImage):
            xs = [kp.x for kp in current_keypoint.keypoints]
            ys = [kp.y for kp in current_keypoint.keypoints]
        else:
            # Keep only the (x, y) columns; any further columns are ignored.
            coords = np.array(current_keypoint)[:, :2]
            xs, ys = coords[:, 0], coords[:, 1]

        # One scatter call per image draws all of its markers at once.
        ax_all.scatter(xs, ys, c='red', marker="v", s=50, linewidths=2)

    plt.tight_layout(pad=2.0)
    plt.show()


# Pick a few random annotated images and preview them with their raw keypoints.
samples = list(json_dict.keys())
num_samples = 4
selected_samples = np.random.choice(samples, num_samples, replace=False)

# Load each chosen record once, then split it into parallel image / keypoint lists.
selected_records = [get(sample) for sample in selected_samples]
images = [record["img_data"] for record in selected_records]
keypoints = [record["joints"] for record in selected_records]

visualize_keypoints(images, keypoints)

In [None]:
class KeyPointsDataset(tf.keras.utils.Sequence):
    """Keras `Sequence` yielding batches of augmented images and normalised keypoints.

    Each batch is a tuple `(images, keypoints)`:
      * images: int array of shape (batch, IMG_SIZE, IMG_SIZE, 3);
      * keypoints: float32 array of shape (batch, 1, 1, NUM_KEYPOINTS) holding
        x0, y0, x1, y1, ... divided by IMG_SIZE.

    Args:
        image_keys: Keys into `json_dict` that belong to this split.
        aug: Augmenter called as `aug(image=..., keypoints=...)`.
        batch_size: Number of samples per batch.
        train: When True the sample order is reshuffled at the end of every epoch.
    """

    def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True):
        # Run the base initialiser: newer Keras releases expect Sequence/PyDataset
        # subclasses to do so, and it is harmless where the base class defines none.
        super().__init__()
        self.image_keys = image_keys
        self.aug = aug
        self.batch_size = batch_size
        self.train = train
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return len(self.image_keys) // self.batch_size

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when in training mode."""
        self.indexes = np.arange(len(self.image_keys))
        if self.train:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number `index` as `(images, keypoints)`."""
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        image_keys_temp = [self.image_keys[k] for k in indexes]
        return self.__data_generation(image_keys_temp)

    @staticmethod
    def _to_rgb(img_data):
        """Return `img_data` with exactly three channels.

        Grayscale input is repeated across three channels and an alpha channel is
        dropped, so the result always fits the (IMG_SIZE, IMG_SIZE, 3) batch buffer.
        """
        if img_data.ndim == 2:
            return np.stack([img_data] * 3, axis=-1)
        if img_data.shape[-1] == 1:
            return np.repeat(img_data, 3, axis=-1)
        if img_data.shape[-1] == 4:
            return img_data[..., :3]
        return img_data

    def __data_generation(self, image_keys_temp):
        """Load, augment and pack the samples named in `image_keys_temp`."""
        # Size the buffers by the keys actually received so that no uninitialised
        # np.empty rows can leak out when fewer than batch_size keys are passed.
        n_items = len(image_keys_temp)
        batch_images = np.empty((n_items, IMG_SIZE, IMG_SIZE, 3), dtype="int")
        batch_keypoints = np.empty((n_items, 1, 1, NUM_KEYPOINTS), dtype="float32")

        for i, key in enumerate(image_keys_temp):
            data = json_dict[key]
            img_data = self._to_rgb(plt.imread(IMG_DIR + data["image"]))

            # Only the (x, y) columns are used; any further column is ignored.
            kps = [Keypoint(x=x, y=y) for x, y in np.array(data["joints"])[:, :2]]
            kps_obj = KeypointsOnImage(kps, shape=img_data.shape)

            new_image, new_kps_obj = self.aug(image=img_data, keypoints=kps_obj)
            batch_images[i] = new_image

            # Flatten to x0, y0, x1, y1, ... and replace NaN coordinates with 0.
            flat = []
            for keypoint in new_kps_obj.keypoints:
                flat.append(np.nan_to_num(keypoint.x))
                flat.append(np.nan_to_num(keypoint.y))
            batch_keypoints[i] = np.array(flat).reshape(1, 1, NUM_KEYPOINTS)

        # Express coordinates as fractions of the resized image side.
        return batch_images, batch_keypoints / IMG_SIZE

In [None]:
# Shuffle the keys in place, hold out the first 15% for validation and train on
# the remainder.
np.random.shuffle(samples)
n_validation = int(len(samples) * 0.15)
validation_keys = samples[:n_validation]
train_keys = samples[n_validation:]

In [None]:
# Build the two data sources; only the training one shuffles and applies random augmentation.
train_dataset = KeyPointsDataset(train_keys, train_aug)
validation_dataset = KeyPointsDataset(validation_keys, test_aug, train=False)

print(f"Total batches in training set: {len(train_dataset)}")
print(f"Total batches in validation set: {len(validation_dataset)}")

# Sanity check: pull the first training batch and plot four augmented samples.
sample_images, sample_keypoints = next(iter(train_dataset))

# Undo the dataset's normalisation: (n, 1, 1, NUM_KEYPOINTS) fractions become
# (n, joints, 2) pixel coordinates. The joint count is derived from NUM_KEYPOINTS
# instead of repeating the literal 16.
sample_keypoints = sample_keypoints[:4].reshape(-1, NUM_KEYPOINTS // 2, 2) * IMG_SIZE
visualize_keypoints(sample_images[:4], sample_keypoints)

In [None]:
# MobileNetV3-Small with ImageNet weights and without its top classifier; used as the
# feature extractor underneath the keypoint regression layers.
backbone = tf.keras.applications.MobileNetV3Small(
    weights="imagenet", include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3)
)

In [None]:
# Stage 1: freeze the whole backbone so that only the new head layers are trained.
# (Alternative kept for reference: freeze all but the last 10 backbone layers.)
# for i in range(len(backbone.layers)-10):
#     backbone.layers[i].trainable = False
backbone.trainable = False
#backbone.summary()

In [None]:
# Keypoint regression head stacked on the backbone features.
inputs = tf.keras.layers.Input((IMG_SIZE, IMG_SIZE, 3))
x = tf.keras.applications.mobilenet_v3.preprocess_input(inputs)
x = backbone(x)
x = tf.keras.layers.Dropout(0.2)(x)
# The two separable convolutions must reduce the feature map to (1, 1, NUM_KEYPOINTS)
# so that the output matches the targets built by KeyPointsDataset.
# NOTE(review): presumably 7x7 -> 3x3 -> 1x1 for a 224x224 input - confirm in model.summary().
x = tf.keras.layers.SeparableConv2D(
    NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
)(x)
# Sigmoid keeps predictions in (0, 1), the scale of coordinates divided by IMG_SIZE.
outputs = tf.keras.layers.SeparableConv2D(
    NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
)(x)

model = tf.keras.Model(inputs, outputs, name="keypoint_detector")

model.summary()

In [None]:
# inputs = tf.keras.layers.Input((IMG_SIZE, IMG_SIZE, 3))
# x = tf.keras.applications.resnet.preprocess_input(inputs)
# x = backbone(x)
# x = tf.keras.layers.Dropout(0.2)(x)
# x = tf.keras.layers.SeparableConv2D(
#     NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
# )(x)
# outputs = tf.keras.layers.SeparableConv2D(
#     NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
# )(x)

# model = tf.keras.Model(inputs, outputs, name="keypoint_detector")

# model.summary()

In [None]:
# inputs = tf.keras.layers.Input((IMG_SIZE, IMG_SIZE, 3))
# x = tf.keras.applications.resnet.preprocess_input(inputs)
# x = backbone(x)

# x = tf.keras.layers.Flatten()(x)
# NEURONS = 1000
# x = tf.keras.layers.Dense(NEURONS, activation='relu')(x)
# outputs = tf.keras.layers.Dense(NUM_KEYPOINTS, activation = 'softmax')(x)


# model = tf.keras.Model(inputs, outputs, name="keypoint_detector")

# model.summary()

In [None]:
# inputs = tf.keras.layers.Input((IMG_SIZE, IMG_SIZE, 3))
# x = tf.keras.applications.mobilenet_v3.preprocess_input(inputs)
# x = backbone(x)
# x = tf.keras.layers.Dropout(0.2)(x)
# x = tf.keras.layers.SeparableConv2D(
#     576, kernel_size=5, strides=1, activation="relu"
# )(x)
# x = tf.keras.layers.SeparableConv2D(
#     576, kernel_size=3, strides=1, activation="sigmoid"
# )(x)
# x = tf.keras.layers.Flatten()(x)
# x = tf.keras.layers.Dense(NUM_KEYPOINTS, activation='relu')(x)
# x = tf.keras.layers.Dropout(0.3)(x)
# x = tf.keras.layers.Dense(NUM_KEYPOINTS, activation='relu')(x)
# x = tf.keras.layers.Dropout(0.3)(x)
# outputs = tf.keras.layers.Dense(NUM_KEYPOINTS, activation='softmax')(x)

# model = tf.keras.Model(inputs, outputs, name="keypoint_detector")

# model.summary()

In [None]:
# Stage-1 compile: Adam at BASE_LEARNING_RATE with a mean-squared-error loss.
# The commented lines below are earlier compile variants kept for reference.
# model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE),metrics=['accuracy'], run_eagerly=True)
# model.compile(optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE), loss='mean_squared_error', metrics=[tf.keras.metrics.RootMeanSquaredError()], run_eagerly=True)
# model.compile(optimizer='adam', loss='mean_squared_error', metrics=[tf.keras.metrics.RootMeanSquaredError()])#, run_eagerly=True)
# model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE),metrics=['mae'], run_eagerly=True)
model.compile(optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE), loss='mse')

In [None]:
# Stage-1 callbacks: per-epoch weight checkpoints, early stopping and TensorBoard logs.
checkpoint_path = "training_ckp/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create the folder from Python instead of shelling out to `mkdir -p`, so the cell
# also works where that shell command is unavailable.
os.makedirs(checkpoint_dir, exist_ok=True)

# NOTE(review): SAVE_PERIOD is only referenced by the commented-out save_freq line below.
SAVE_PERIOD = 10

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq='epoch',
    #save_freq=int(SAVE_PERIOD * STEPS_PER_EPOCH),
)

# Stop once val_loss has not improved for 20 epochs and roll back to the best weights.
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=20, restore_best_weights=True
)

# Each run logs into its own timestamped sub-folder.
os.makedirs("training_logs", exist_ok=True)
tensorboard = tf.keras.callbacks.TensorBoard(log_dir="training_logs/{}".format(time.time()))

In [None]:
# Stage 1: train with the backbone frozen, validating on the held-out split and
# running the checkpoint, early-stopping and TensorBoard callbacks.
history = model.fit(train_dataset,
                    validation_data=validation_dataset,
                    epochs=EPOCHS,
                    callbacks=[
                                cp_callback,
                                earlystopping_callback,
                                tensorboard
                              ]
                   )

In [None]:
# Save the stage-1 model before fine-tuning starts. The folder is created from
# Python rather than through a `mkdir -p` shell call.
os.makedirs("saved_model", exist_ok=True)
model.save('saved_model/before_finetuning_model.h5')

In [None]:
# Unfreeze the backbone for fine-tuning. The flag has to be set on `backbone` itself:
# it was frozen with `backbone.trainable = False`, and a model whose own flag is False
# reports no trainable weights even when its inner layers are re-enabled one by one.
# Setting the flag on the model also propagates to every nested layer.
backbone.trainable = True


print("Number of layers in the base model: ", len(backbone.layers))

In [None]:
# Print the model summary after the backbone's trainable flags were changed.
model.summary()

In [None]:
# fine_tune_at = len(backbone.layers)-10

# for layer in backbone.layers[:fine_tune_at]:
#     layer.trainable = False

In [None]:
# NOTE(review): second summary; the cell before it contains only commented-out code,
# so this output is presumably identical to the previous summary.
model.summary()

In [None]:
# Re-compile for the fine-tuning stage: Adam at a tenth of BASE_LEARNING_RATE, MSE loss
# and an RMSE metric. The commented lines are earlier variants kept for reference.
# model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
#               optimizer = tf.keras.optimizers.Adam(BASE_LEARNING_RATE/10),
#               metrics=['accuracy'])
# model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE/10),metrics=['accuracy'], run_eagerly=True)
#model.compile(optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE/10), loss='mean_squared_error', metrics=['mae'], run_eagerly=True)
model.compile(optimizer=tf.keras.optimizers.Adam(BASE_LEARNING_RATE/10), loss='mean_squared_error', metrics=[tf.keras.metrics.RootMeanSquaredError()])


In [None]:
# Fine-tuning callbacks: periodic weight checkpoints and early stopping.
checkpoint_path = "finetuning_ckp/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create the folder from Python instead of shelling out to `mkdir -p`.
os.makedirs(checkpoint_dir, exist_ok=True)

# Save the weights every FINETUNE_SAVE_PERIOD epochs. An integer `save_freq` counts
# batches, so the period is converted with the number of batches per epoch. This
# replaces the deprecated `period=5` argument, which newer Keras versions reject.
FINETUNE_SAVE_PERIOD = 5
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=max(1, FINETUNE_SAVE_PERIOD * len(train_dataset)),
)

# Stop once val_loss has not improved for 20 epochs and roll back to the best weights.
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=20, restore_best_weights=True
)


In [None]:
# Stage 2: continue training with the re-compiled model.
# NOTE(review): the early-stopping patience (20) equals FINETUNE_EPOCHS, so early
# stopping presumably can never trigger during this stage - confirm the intent.
FINETUNE_EPOCHS = 20

# The stage-1 `tensorboard` callback object is reused here, so both stages write to
# the same log directory.
history_fine = model.fit(train_dataset,
                         epochs=FINETUNE_EPOCHS,
                         validation_data=validation_dataset,
                         callbacks=[
                                cp_callback,
                                earlystopping_callback,
                                tensorboard
                              ]
                        )


In [None]:
# Save the fine-tuned model; the saved_model/ folder is created by the earlier cell
# that saves the stage-1 model.
model.save('saved_model/final_model.h5')