In [1]:
import tensorflow as tf
import os

In [20]:
from functools import reduce
from itertools import accumulate

# (points, values per point) for each landmark group, in the order the groups
# are stored in every flat row: pose, face, left hand, right hand.
landmark_lens = (
    (33, 4),
    (468, 3),
    (21, 3),
    (21, 3),
)
# Start offset of each group within a flat row, with the total row width as the
# final entry.
landmark_locs = list(accumulate((points * dims for points, dims in landmark_lens), initial=0))
# Floats per row; identical to the last offset.
landmarks_len = landmark_locs[-1]
print(landmark_locs)

[0, 132, 1536, 1599, 1662]


In [21]:
# Class names are the sub-directories of tracks_binary/. os.listdir returns
# entries in arbitrary, filesystem-dependent order, so sort them: otherwise the
# label -> id mapping (i.e. the meaning of each softmax output unit) can differ
# between runs or machines, silently invalidating a saved model.
labels = sorted(
    label for label in os.listdir('tracks_binary')
    if os.path.isdir(os.path.join('tracks_binary', label))
)
NUM_CLASSES = len(labels)

labels_tensor = tf.constant(labels)
ids_tensor = tf.constant(range(len(labels)))

# label string -> integer id; strings not in `labels` map to -1.
ids_from_labels = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(labels_tensor, ids_tensor),
    default_value=-1,
)

# integer id -> label string (inverse of ids_from_labels); unknown ids map to "".
labels_from_ids = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(ids_tensor, labels_tensor),
    default_value="",
)


def to_categorical(label):
    """One-hot encode a label string as a vector of length NUM_CLASSES.

    A label missing from `labels` looks up to -1, which tf.one_hot turns into
    an all-zero vector rather than raising an error.
    """
    return tf.one_hot(ids_from_labels.lookup(label), depth=NUM_CLASSES)

In [22]:
def process_binary(file_path):
    """Decode one binary landmark recording into per-group tensors plus its label.

    The file is read as raw float32 values (tf.io.decode_raw's default
    little-endian byte order) and viewed as rows of `landmarks_len` floats
    (presumably one row per frame). Each row is split at the offsets in
    `landmark_locs` and reshaped using `landmark_lens`, giving the tuple
    (pose, face, left hand, right hand) with shapes
    [rows, 33, 4], [rows, 468, 3], [rows, 21, 3], [rows, 21, 3].

    Returns:
        ((pose, face, lh, rh), one_hot_label), where the label is the name of
        the file's parent directory encoded by to_categorical.
    """
    # The class name is the parent directory: tracks_binary/<label>/<file>.
    label = tf.strings.split(file_path, os.sep)[-2]

    raw = tf.io.read_file(file_path)
    data = tf.io.decode_raw(raw, tf.float32)
    data = tf.reshape(data, [-1, landmarks_len])

    # Slice bounds and group shapes come from landmark_lens / landmark_locs so
    # the row layout is defined in exactly one place.
    groups = tuple(
        tf.reshape(data[:, start:end], [-1, n_points, n_dims])
        for (n_points, n_dims), start, end
        in zip(landmark_lens, landmark_locs[:-1], landmark_locs[1:])
    )
    return groups, to_categorical(label)

In [23]:
FRAMES = 64  # every example is cut or padded to this many rows (see random_window)


def flatten(x):
    """Concatenate the landmark groups of each row back into one flat row.

    Args:
        x: tuple (pose, face, lh, rh) as produced by process_binary, each of
            shape [rows, n_points, n_dims] as listed in landmark_lens.

    Returns:
        Tensor of shape [rows, landmarks_len], groups in landmark_lens order.
    """
    # Widths are Python ints so the last dimension stays statically known.
    return tf.concat(
        [
            tf.reshape(part, shape=[-1, n_points * n_dims])
            for part, (n_points, n_dims) in zip(x, landmark_lens)
        ],
        axis=1,
    )


def random_window(x):
    """Return exactly FRAMES consecutive rows derived from `x`.

    If `x` has at least FRAMES rows, a window of FRAMES rows starting at a
    uniformly random offset (tf.random.uniform) is returned. Otherwise `x` is
    padded by repeating its first row at the start and its last row at the
    end; when the deficit is odd, the extra row goes at the start.

    Args:
        x: 2-D tensor [rows, features], e.g. the output of flatten.
    """
    size = tf.shape(x)[0]

    def pad():
        missing = FRAMES - size
        # Integer arithmetic keeps the tile multiples int32 (tf.tile only
        # accepts int32/int64 multiples).
        end_pad = missing // 2         # floor(missing / 2)
        start_pad = missing - end_pad  # ceil(missing / 2)
        return tf.concat([
            tf.tile([x[0]], [start_pad, 1]),
            x,
            tf.tile([x[-1]], [end_pad, 1]),
        ], axis=0)

    def random_slice():
        # maxval is exclusive, so the last valid start (size - FRAMES) is included.
        start = tf.random.uniform(
            shape=(), maxval=size + 1 - FRAMES, dtype=tf.int32)
        return x[start:start + FRAMES]

    return tf.cond(size < FRAMES, pad, random_slice)
    

def prepare(ds, batch_size=32, shuffle_buffer=1000):
    """Turn decoded ((pose, face, lh, rh), label) examples into training batches.

    Each example's landmark groups are flattened to [rows, landmarks_len] and
    cut/padded to FRAMES rows by random_window; examples are then shuffled,
    batched and prefetched.

    Args:
        ds: dataset of ((pose, face, lh, rh), label) pairs, e.g. mapped with
            process_binary.
        batch_size: examples per batch (default 32).
        shuffle_buffer: shuffle buffer size (default 1000). The buffer holds
            fully decoded FRAMES x landmarks_len windows, so its memory use
            grows linearly with this value.

    Returns:
        Batched, prefetched dataset of (windows, labels).
    """
    ds = ds.map(
        lambda x, y: (random_window(flatten(x)), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

In [24]:
# One binary recording per file under tracks_binary/<label>/. Decoding runs in
# parallel like the maps in prepare; tf.data keeps element order deterministic.
ds = tf.data.Dataset.list_files('tracks_binary/*/*')
ds = ds.map(process_binary, num_parallel_calls=tf.data.AUTOTUNE)

ds = prepare(ds)

In [25]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Bidirectional
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow import keras

In [26]:
TRIAL = 10
log_dir = os.path.join('Logs', str(TRIAL))
tb_callback = TensorBoard(log_dir=log_dir)

# The fit() call in this notebook gets no validation data, so 'val_loss' is
# never logged and callbacks monitoring it never fire (Keras only warns).
# Monitor the training loss for now.
# TODO: hold out a validation split and switch both monitors back to 'val_loss'.
# The LR patience must be shorter than the early-stopping patience; otherwise
# training stops before the learning rate is ever reduced.
es_callback = EarlyStopping(monitor='loss', patience=20, restore_best_weights=True)
lr_callback = ReduceLROnPlateau(monitor='loss', patience=10)

2021-12-14 21:58:32.278818: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-14 21:58:32.278829: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-12-14 21:58:32.279127: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


In [27]:
# Three stacked bidirectional LSTMs over FRAMES-long sequences of flat landmark
# rows (landmarks_len features each), with Dense layers applied per timestep in
# between, ending in a softmax over the NUM_CLASSES labels. The input width is
# taken from landmarks_len so it always matches the data layout.
model = Sequential([
    Bidirectional(LSTM(64, return_sequences=True), input_shape=(FRAMES, landmarks_len)),
    Dense(64, activation='relu'),
    Bidirectional(LSTM(64, return_sequences=True)),
    Dense(64, activation='relu'),
    # Only this last LSTM collapses the time axis (return_sequences=False).
    Bidirectional(LSTM(64, return_sequences=False, dropout=0.2)),
    Dense(NUM_CLASSES, activation='softmax'),
])

In [28]:
# Adam with an initial learning rate of 1e-4. Categorical cross-entropy pairs
# with the one-hot targets from to_categorical and the softmax output layer.
opt = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy'],
)

In [29]:
# Print each layer's output shape and parameter count.
model.summary()


Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bidirectional_6 (Bidirection (None, 64, 128)           884224    
_________________________________________________________________
dense_6 (Dense)              (None, 64, 64)            8256      
_________________________________________________________________
bidirectional_7 (Bidirection (None, 64, 128)           66048     
_________________________________________________________________
dense_7 (Dense)              (None, 64, 64)            8256      
_________________________________________________________________
bidirectional_8 (Bidirection (None, 128)               66048     
_________________________________________________________________
dense_8 (Dense)              (None, 10)                1290      
Total params: 1,034,122
Trainable params: 1,034,122
Non-trainable params: 0
____________________________________________

In [30]:
# Pass every callback defined above so early stopping / LR reduction can act
# instead of training for the full 2000-epoch budget.
# NOTE: no validation_data is given, so only training metrics ('loss',
# 'accuracy') are logged; a callback only takes effect if it monitors one of them.
history = model.fit(
    ds,
    epochs=2000,  # upper bound; meant to be cut short by EarlyStopping
    callbacks=[tb_callback, es_callback, lr_callback],
)

Epoch 1/2000


2021-12-14 21:58:37.192270: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:37.729164: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:37.729206: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:37.953610: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:37.962615: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:38.112249: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-12-14 21:58:38.120876: I tensorflow/core/grappler/optimizers/cust

 1/10 [==>...........................] - ETA: 34s - loss: 2.2833 - accuracy: 0.1250

2021-12-14 21:58:39.350593: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-14 21:58:39.350605: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


 2/10 [=====>........................] - ETA: 2s - loss: 2.2849 - accuracy: 0.1562 

2021-12-14 21:58:39.653904: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-12-14 21:58:39.654817: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2021-12-14 21:58:39.656437: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: Logs/10/train/plugins/profile/2021_12_14_21_58_39

2021-12-14 21:58:39.657110: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to Logs/10/train/plugins/profile/2021_12_14_21_58_39/Stevens-MacBook-Air.local.trace.json.gz
2021-12-14 21:58:39.657853: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: Logs/10/train/plugins/profile/2021_12_14_21_58_39

2021-12-14 21:58:39.658028: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to Logs/10/train/plugins/profile/2021_12_14_21_58_39/Stevens-MacBook-Air.local.memory_profile

Epoch 2/2000
Epoch 3/2000
Epoch 4/2000
Epoch 5/2000
Epoch 6/2000
Epoch 7/2000
Epoch 8/2000
Epoch 9/2000
Epoch 10/2000
Epoch 11/2000
Epoch 12/2000
Epoch 13/2000
Epoch 14/2000
Epoch 15/2000
Epoch 16/2000
Epoch 17/2000
Epoch 18/2000
Epoch 19/2000
Epoch 20/2000
Epoch 21/2000
Epoch 22/2000
Epoch 23/2000
Epoch 24/2000
Epoch 25/2000
Epoch 26/2000
Epoch 27/2000
Epoch 28/2000
Epoch 29/2000
Epoch 30/2000
Epoch 31/2000
Epoch 32/2000
Epoch 33/2000
Epoch 34/2000
Epoch 35/2000
Epoch 36/2000
Epoch 37/2000
Epoch 38/2000
Epoch 39/2000
Epoch 40/2000

KeyboardInterrupt: 