In [None]:
import os
import json
import logging
import datetime

import numpy as np
import pandas as pd
import tensorflow as tf

# Notebook-wide logger; the threshold is set on the logger itself so that
# INFO records (e.g. per-file shape messages) are not filtered out here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [None]:
# Constants
# NOTE(review): IMG_SIZE, VAL_SIZE, train_dataset_path and
# train_data_with_label are not referenced by any cell visible in this
# notebook; they are kept in case other cells or tooling rely on them.
IMG_SIZE = (512, 512, 3)
VAL_SIZE = 0.2

train_dataset_path = "train_landmark_files/"
# Directory holding one "<sequence_id>.npy" file per sample.
save_dataset_path = "test/"
# CSV listing the samples (read by get_train_dataset).
class_path = "train.csv"
# JSON mapping each sign to its integer label.
class_to_label_path = "sign_to_prediction_index_map.json"
train_data_with_label = "train_data_with_label.csv"
# Participants whose samples form the validation split.
val_participants = [55372, 61333, 62590]

# A single timestamp identifies the run. Calling now() separately for the
# log directory and the checkpoint directory could straddle a minute
# boundary and leave the two with different names.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")

log_dir = os.path.join("./logs/fit/", run_timestamp)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1
)

# Checked at the end of every epoch; a checkpoint is written only when
# val_loss improves. The former `save_period=128` keyword was dropped: it is
# not a ModelCheckpoint argument, so it was either silently ignored or
# rejected with a TypeError depending on the Keras version.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(f"./model/{run_timestamp}", "cp-{epoch:02d}"),
    save_freq="epoch",
    monitor="val_loss",
    save_best_only=True,
)


In [None]:
def get_train_dataset(
    class_path: str,
    class_to_label_path: str,
    dataset_dir=None,
    val_ids=None,
) -> pd.DataFrame:
    """Build the sample table used to feed the training/validation sequences.

    Args:
        class_path: CSV file with at least the columns ``sign``,
            ``sequence_id`` and ``participant_id``.
        class_to_label_path: JSON file mapping each sign to its integer label.
        dataset_dir: Directory containing the ``<sequence_id>.npy`` files.
            Defaults to the notebook-level ``save_dataset_path``.
        val_ids: Participant ids assigned to the validation split.
            Defaults to the notebook-level ``val_participants``.

    Returns:
        The CSV rows (original order) with three extra columns:
        ``y_label`` (label looked up from the JSON map; missing signs yield
        NaN because the lookup is a left join), ``save_dataset_path`` (path
        of the sample's ``.npy`` file) and ``example_set`` (``"val"`` or
        ``"train"``).
    """
    # Fall back to the notebook-level configuration when not given explicitly.
    if dataset_dir is None:
        dataset_dir = save_dataset_path
    if val_ids is None:
        val_ids = val_participants

    samples = pd.read_csv(class_path)

    with open(class_to_label_path, "r") as fp:
        sign_to_label = json.load(fp)
    y_label = pd.DataFrame.from_dict(sign_to_label, orient="index", columns=["y_label"])

    # Left join on the sign name, then restore "sign" as a regular column.
    df = samples.set_index("sign").join(y_label).reset_index()

    # Join directory and file name as separate components so the result is
    # correct whether or not the directory ends with a path separator.
    df["save_dataset_path"] = df["sequence_id"].apply(
        lambda seq_id: os.path.join(dataset_dir, f"{seq_id}.npy")
    )
    df["example_set"] = (
        df["participant_id"].isin(list(val_ids)).map({True: "val", False: "train"})
    )

    return df


In [None]:
df = get_train_dataset(class_path, class_to_label_path)

# Distinct label counts per split (train first, then val). Each sequence
# derives its one-hot depth from its own split, so these two numbers should
# match.
print(
    df.loc[df["example_set"] == "train", "y_label"].nunique(),
    df.loc[df["example_set"] == "val", "y_label"].nunique(),
)


In [None]:
# Preview the first rows of the sample table.
df.head(n=5)

In [None]:
class GISLRSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of fixed-length frame sequences.

    Every sample is a ``.npy`` file whose first axis is treated as the frame
    axis. A sample is padded (by repeating existing frames) or truncated to
    ``sample_size`` frames, NaNs are replaced by 0, and labels are one-hot
    encoded.

    Args:
        df: Table with one row per sample.
        x_col: Column holding the path of each sample's ``.npy`` file.
        y_col: Column holding the integer class label.
        sample_size: Number of frames every returned sample has.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle the sample order at the end of every epoch. The
            first epoch always uses the order of ``df``.
        n_classes: Depth of the one-hot encoding. Defaults to the number of
            distinct labels in ``df``; pass it explicitly when ``df`` is a
            split that may not contain every class.
    """

    def __init__(
            self, df: pd.DataFrame,
            x_col: str,
            y_col: str,
            sample_size: int,
            batch_size: int,
            shuffle: bool = True,
            n_classes: "int | None" = None
        ) -> None:
        # Newer Keras versions expect subclasses to run the base initializer;
        # on older versions this call is a no-op.
        super().__init__()

        self.batch_size = batch_size
        self.sample_size = sample_size
        self.shuffle = shuffle
        self.n_classes = df[y_col].nunique() if n_classes is None else n_classes

        self.x_files_path = df[x_col]
        self.y = df[y_col]

        # Index *labels* of df (not positions); batches are selected by label.
        self.indices = df.index.to_list()

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.x_files_path) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(X, one_hot_y)``."""
        start = idx * self.batch_size
        subset = self.indices[start:start + self.batch_size]

        # Explicit label-based lookup, matching how self.indices was built.
        batch_x = self.x_files_path.loc[subset]
        batch_y = self.y.loc[subset]

        X, y = self._get_data(batch_x, batch_y)

        return X, tf.one_hot(y, depth=self.n_classes)

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _get_data(self, batch_x, batch_y):
        """Load every file in ``batch_x`` and stack samples and labels."""
        samples = [self._get_transformed_data(path) for path in batch_x]
        return np.array(samples), np.array(list(batch_y))

    def _get_transformed_data(self, dataset_file_path) -> np.ndarray:
        """Load one sample and bring it to exactly ``sample_size`` frames.

        Short samples are extended at the end with copies of frames that are
        already in the sequence (see ``_padding_source``); long samples are
        truncated. NaN entries are replaced by 0.

        Raises:
            ValueError: If the file holds zero frames and padding is needed.
        """
        all_imgs = np.load(dataset_file_path)
        logger.info(f"Found: {all_imgs.shape} frames in {dataset_file_path}")

        og_size = all_imgs.shape[0]
        # Drawn for every sample, padded or not, so the global RNG stream is
        # consumed exactly once per call.
        bfill = np.random.uniform() < 0.5

        n_missing = self.sample_size - og_size
        if n_missing > 0:
            if og_size == 0:
                raise ValueError(
                    f"{dataset_file_path} contains no frames; cannot pad to "
                    f"{self.sample_size} frames"
                )
            # One gather replaces the former per-frame np.concatenate loop,
            # which re-copied the whole array for every padded frame.
            all_imgs = all_imgs[self._padding_source(og_size, n_missing, bfill)]

        all_imgs = np.where(np.isnan(all_imgs), 0, all_imgs)

        return all_imgs[:self.sample_size, :]

    @staticmethod
    def _padding_source(og_size: int, n_missing: int, bfill: bool) -> list:
        """Map every output frame to the original frame it is copied from.

        Padding step ``i`` copies the frame currently at position
        ``i // og_size`` (or ``abs(og_size - i) // og_size`` when ``bfill``
        is set) of the sequence built so far, so later steps may re-copy
        earlier padding. The pick is clamped to the current length: with a
        single-frame sample the ``bfill`` rule asks for position 1 on the
        first step, which does not exist yet and used to raise IndexError.
        """
        source = list(range(og_size))
        for i in range(n_missing):
            pick = abs(og_size - i) // og_size if bfill else i // og_size
            source.append(source[min(pick, len(source) - 1)])
        return source

In [None]:
# Smoke test: build a sequence over the training rows and inspect the shapes
# of a single batch.
test = GISLRSequence(
    df=df.loc[df["example_set"] == "train"],
    x_col="save_dataset_path",
    y_col="y_label",
    sample_size=64,
    batch_size=16,
)

# Only the first batch is needed, so stop right after fetching it.
for x, y in test:
    break

print(x.shape, y.shape)

In [None]:
# Training batches: 32 frames per sample, reshuffled after every epoch
# (GISLRSequence shuffles by default).
train_dataset = GISLRSequence(
    df[df.example_set == "train"],
    x_col="save_dataset_path",
    y_col="y_label",
    sample_size=32,
    batch_size = 16
)

# Validation batches: shuffling is disabled. GISLRSequence drops the trailing
# partial batch, so reshuffling between epochs would change which validation
# samples are evaluated and make val_loss less comparable from one epoch to
# the next.
val_dataset = GISLRSequence(
    df[df.example_set == "val"],
    x_col="save_dataset_path",
    y_col="y_label",
    sample_size=32,
    batch_size = 16,
    shuffle=False
)


In [None]:
class BaseAttention(tf.keras.layers.Layer):
    """Base layer owning a multi-head attention sub-layer and an ``Add`` sub-layer.

    It defines no ``call`` of its own; subclasses decide how ``self.mha`` and
    ``self.add`` are wired together.
    """
    def __init__(self, **kwargs):
        # The base Layer is initialised without arguments; every keyword
        # argument is forwarded to MultiHeadAttention instead.
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.add = tf.keras.layers.Add()

class GlobalSelfAttention(BaseAttention):
    """Self-attention over the input with a residual connection."""

    def call(self, x):
        """Attend ``x`` to itself and return ``x + attention(x)``."""
        attn_output = self.mha(
            query=x,
            value=x,
            key=x
        )
        # Return the residual sum. Previously it was computed and then
        # discarded, and the raw attention output was returned instead.
        return self.add([x, attn_output])

class GISLRModelv2():
    """Builds the classifier: depthwise conv -> batch norm -> average pool -> MLP.

    Args:
        n_classes: Number of output classes (width of the softmax layer).
        input_shape: Shape of one sample, excluding the batch axis.
    """

    def __init__(self, n_classes: int, input_shape = (32, 1629)) -> None:
        model_input = tf.keras.Input(shape=input_shape)
        # query_seq_encoding = tf.keras.layers.Conv1D(filters=128, kernel_size=4, padding='same')(model_input)
        depthwise = tf.keras.layers.DepthwiseConv1D(kernel_size=4, padding='same', depth_multiplier=1)(model_input)

        # mha = GlobalSelfAttention(num_heads=2, key_dim=128)(query_seq_encoding)
        normalized = tf.keras.layers.BatchNormalization()(depthwise)
        # Average over the sequence axis, leaving one vector per sample.
        pooled = tf.keras.layers.GlobalAveragePooling1D()(normalized)

        hidden = tf.keras.layers.Dense(256, activation="relu")(pooled)
        hidden = tf.keras.layers.Dropout(0.2)(hidden)

        output = tf.keras.layers.Dense(
            n_classes, activation="softmax")(hidden)

        self.model = tf.keras.Model(model_input, output)

    def get_model(self):
        """Compile the model, print its summary and return it.

        Targets are expected to be one-hot encoded (categorical
        cross-entropy loss and metrics).
        """
        self.model.compile(
            loss='categorical_crossentropy',
            optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
            metrics=[tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.CategoricalCrossentropy()]
        )
        # summary() prints the table itself and returns None, so wrapping it
        # in print() only appended a stray "None" line.
        self.model.summary()

        return self.model


In [None]:
# The softmax width is the number of distinct labels in the full table.
model = GISLRModelv2(
    n_classes=df["y_label"].nunique(),
).get_model()

In [None]:
# Train on the training sequence and evaluate on the validation sequence,
# logging to TensorBoard and checkpointing through the callbacks defined
# above.
history = model.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=100,
    verbose=1,
    callbacks=[tensorboard_callback, checkpoint],
)