In [None]:
# Integer run-mode switches, tested for truthiness further down:
#   TPU  - connect to a TPU and build a TPUStrategy instead of probing GPU/CPU
#   VAL  - hold out the validation file ids and build a validation dataset
#   LOAD - not read anywhere in this part of the notebook
TPU, VAL, LOAD = 1, 0, 0

In [None]:
import time
!pip install Levenshtein
import Levenshtein as lev

In [None]:
if 'tpu_init' not in globals():
    import pandas as pd
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import json
    import tensorflow_addons as tfa

In [None]:
# One-time accelerator setup. The `tpu_init` sentinel keeps a re-run of this
# cell from connecting to the TPU a second time.
if 'tpu_init' not in globals():
    tpu_init = True
    if TPU:
        # Resolve and connect to the TPU cluster, then wrap it in a strategy.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    else:
        # No TPU requested: pick a strategy from the number of visible GPUs.
        ngpu = len(tf.config.experimental.list_physical_devices('GPU'))
        if ngpu>1:
            print("Using multi GPU")
            strategy = tf.distribute.MirroredStrategy()
        elif ngpu==1:
            print("Using single GPU")
            strategy = tf.distribute.get_strategy()
            BATCH_SIZE = 64
        else:
            print("Using CPU")
            strategy = tf.distribute.get_strategy()
            # NOTE(review): the hyper-parameter cell further down assigns
            # BATCH_SIZE = 64 unconditionally, so this value does not survive.
            BATCH_SIZE = 2 

In [None]:
from shutil import copyfile
# copy our file into the working directory (make sure it has .py suffix)
# so that the `from CTC_TPU import ...` below can find it as a module
copyfile(src = "/kaggle/input/ctc-tpu/CTC_TPU.py", dst = "/kaggle/working//CTC_TPU.py")

# import the CTC loss implementation shipped in that file
from CTC_TPU import classic_ctc_loss

In [None]:
# Global hyper-parameters for the data pipeline.
COMPUTE_TYPE = 'float32'
# When True every landmark contributes x, y and z columns; otherwise x and y only.
USE_Z_AXIS = True
BATCH_SIZE = 64
# Number of frames every sequence is padded / resized to (see resize_pad).
FRAME_LEN = 220  # 128
# Length every target phrase is padded to (see decode_fn).
PHRASE_MAX_LEN = 32 # + 2
# Seed both TensorFlow and NumPy global generators.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
DROP_OUT = 0.1
# NORM = tf.keras.layers.BatchNormalization()
# NORM = tf.keras.layers.LayerNormalization()
MASKING = False
if TPU:
    # On TPU the global batch is 25 samples per replica.
    BATCH_SIZE = 25 * strategy.num_replicas_in_sync
    print(BATCH_SIZE)

In [None]:
# Character vocabulary: load the character -> class index mapping from JSON.
drive_path = "/kaggle/input/asl-fingerspelling"
with open(f"{drive_path}/character_to_prediction_index.json", "r") as f:
    char_to_num = json.load(f)

# Append one extra symbol that is used to pad target phrases.
pad_token, pad_token_idx = "P", 59
char_to_num[pad_token] = pad_token_idx

# Number of classes including the padding symbol.
CLASS_NUM = len(char_to_num)
print(CLASS_NUM)

# Inverse mapping: class index -> character.
num_to_char = {index: char for char, index in char_to_num.items()}

# Training metadata table (its file_id column is used to locate the TFRecords).
df = pd.read_csv(f"/kaggle/input/asl-fingerspelling/train.csv")

In [None]:
# ---------------------------------------------------------------------------
# Landmark index tables. Every value is a landmark number that ends up in a
# column name such as `x_<idx>` (see idx_to_cols).
# ---------------------------------------------------------------------------

# Nose points, plus the single left / right representative of that group.
NOSE = [1, 2, 98, 327]
LNOSE = [98]
RNOSE = [327]

# Lip points (40 in total) and the left / right subsets.
LIP = [
    0,
    61, 185, 40, 39, 37, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
    95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
]
LLIP = [
    84, 181, 91, 146, 61, 185, 40, 39, 37,
    87, 178, 88, 95, 78, 191, 80, 81, 82,
]
RLIP = [
    314, 405, 321, 375, 291, 409, 270, 269, 267,
    317, 402, 318, 324, 308, 415, 310, 311, 312,
]

# Pose points. The literals are 1-based, so each one is shifted down by one.
POSE = [i - 1 for i in range(490, 523)]
LPOSE = [i - 1 for i in (494, 495, 497, 499, 501, 503, 505, 507, 509, 511, 513, 515, 517, 519)]
RPOSE = [i - 1 for i in (491, 492, 496, 498, 500, 502, 504, 506, 508, 510, 512, 514, 516, 518)]

# Eye contours.
REYE = [
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    246, 161, 160, 159, 158, 157, 173,
]
LEYE = [
    263, 249, 390, 373, 374, 380, 381, 382, 362,
    466, 388, 387, 386, 385, 384, 398,
]

# The two hands occupy contiguous index ranges of 21 points each.
LHAND = list(range(468, 489))
RHAND = list(range(522, 543))

# Order in which the landmark groups are laid out in the selected columns.
POINT_LANDMARKS = LIP + LHAND + RHAND + NOSE + REYE + LEYE + POSE  # use POSE

def idx_to_cols(idx_array):
    """Build the column names for a sequence of landmark indices.

    For each index, in order, the names ``x_<idx>``, ``y_<idx>`` and (only
    when ``USE_Z_AXIS`` is true) ``z_<idx>`` are emitted, so the coordinates
    of one landmark sit next to each other in the result.
    """
    axes = ('x', 'y', 'z') if USE_Z_AXIS else ('x', 'y')
    return [f'{axis}_{idx}' for idx in idx_array for axis in axes]

# Column names of every selected landmark coordinate, in POINT_LANDMARKS order.
SEL_COLS = idx_to_cols(POINT_LANDMARKS)
# Print total and distinct counts; they differ if any column name repeats.
print(len(SEL_COLS))
print(len(set(SEL_COLS)))

def get_index(arr):
    """Return the positions inside SEL_COLS of the columns of landmarks `arr`.

    The result follows the order produced by ``idx_to_cols(arr)``; a landmark
    that is not part of SEL_COLS raises ``ValueError`` from ``list.index``.
    """
    return [SEL_COLS.index(col_name) for col_name in idx_to_cols(arr)]

In [None]:
# Number of coordinates stored per landmark.
DATA_DIM = 3 if USE_Z_AXIS else 2

# Point counts of the groups as they appear in the tensor built by
# split_data(): lip, left hand, right hand, nose, right eye, left eye, pose.
HAND_NUMS = len(LHAND)+ len(RHAND)
FACE_NUMS = len(LIP) + len(REYE) + len(LEYE) + len(NOSE)
# split_data() only emits the LPOSE and RPOSE points, not the full POSE list,
# and KPT_NUM counts them the same way. POSE_NUMS is used for `-POSE_NUMS`
# slices of that tensor, so it must be the emitted count; using len(POSE)
# would make those slices swallow the last eye landmarks.
POSE_NUMS = len(LPOSE) + len(RPOSE)
print(HAND_NUMS, FACE_NUMS, POSE_NUMS)
print(HAND_NUMS+FACE_NUMS+POSE_NUMS)

# Column positions (inside SEL_COLS) of every landmark group.
LIP_IDX = get_index(LIP)
LHAND_IDX = get_index(LHAND)
RHAND_IDX = get_index(RHAND)
NOSE_IDX = get_index(NOSE)
REYE_IDX = get_index(REYE)
LEYE_IDX = get_index(LEYE)

LLIP_IDX = get_index(LLIP)
RLIP_IDX = get_index(RLIP)
LNOSE_IDX = get_index(LNOSE)
RNOSE_IDX = get_index(RNOSE)

POSE_IDX = get_index(POSE)
LPOSE_IDX = get_index(LPOSE)
RPOSE_IDX = get_index(RPOSE)

In [None]:
# Value that replaces NaNs at the end of preprocessing and fills padded frames.
NAN_FILL_VALUE = tf.constant(0, dtype=tf.float32)
# NaN as a float32 tensor constant.
NAN_VALUE = tf.constant(np.nan, dtype=tf.float32)
# Only referenced by the commented-out masking code in resize_pad.
PADDING_MASKING_VALUE = tf.constant(-100, dtype=tf.float32)

def resize_pad(x):
    """Bring a (frames, points, coords) tensor to exactly FRAME_LEN frames.

    Shorter sequences are padded at the end with NAN_FILL_VALUE; sequences
    with FRAME_LEN frames or more are resized along the frame axis with
    nearest-neighbour sampling.
    """
    n_frames = tf.shape(x)[0]
    if n_frames < FRAME_LEN:
        # Append the missing frames; the other two axes are left untouched.
        tail_padding = [[0, (FRAME_LEN - n_frames)], [0, 0], [0, 0]]
        x = tf.pad(x, tail_padding, constant_values=NAN_FILL_VALUE)
    else:
        x = tf.image.resize(x, (FRAME_LEN, tf.shape(x)[1]), "nearest")
    return x

def tf_nan_mean(x, axis=0, keepdims=False):
    """Mean of `x` along `axis` that skips NaN entries.

    NaNs contribute neither to the sum nor to the element count. A slice made
    only of NaNs yields 0 / 0.
    """
    nan_mask = tf.math.is_nan(x)
    zeros = tf.zeros_like(x)
    valid_sum = tf.reduce_sum(
        tf.where(nan_mask, zeros, x), axis=axis, keepdims=keepdims
    )
    valid_count = tf.reduce_sum(
        tf.where(nan_mask, zeros, tf.ones_like(x)), axis=axis, keepdims=keepdims
    )
    return valid_sum / valid_count


def tf_nan_std(x, center=None, axis=0, keepdims=False):
    """NaN-skipping standard deviation of `x` along `axis`.

    `center` is the value deviations are measured from; when omitted, the
    NaN-skipping mean along `axis` (kept broadcastable) is used.
    """
    if center is None:
        center = tf_nan_mean(x, axis=axis, keepdims=True)
    deviation = x - center
    variance = tf_nan_mean(deviation * deviation, axis=axis, keepdims=keepdims)
    return tf.math.sqrt(variance)


def self_norm(x):
    """Standardise one landmark group with its own statistics.

    Mean and standard deviation are computed NaN-aware over the first two
    axes (frames and points), separately for every coordinate channel.
    """
    center = tf_nan_mean(x, axis=[0,1], keepdims=True)
    spread = tf_nan_std(x, center=center, axis=[0,1], keepdims=True)
    return (x - center) / (spread)

def global_norm(x):
    """Normalise all landmarks using the mean of the trailing "pose" slice.

    The mean is taken over the last ``len(POSE_IDX)`` points, the standard
    deviation over the whole tensor around that mean, and the result is
    divided by an extra factor of 3.

    NOTE(review): ``len(POSE_IDX)`` counts columns (points * coordinates), not
    points, so the slice is wider than the pose group - confirm this is
    intended. The only visible call site (in preprocess2) is commented out.
    """
    # face = x[:,:len(LIP_IDX),:]
    pose = x[:,-len(POSE_IDX):,:]
    mean_no_nan = tf_nan_mean(pose, axis=[0,1],keepdims=True)
    # NOTE(review): the std is computed over `x`, not over `pose` - confirm.
    std_no_nan = tf_nan_std(x, center=mean_no_nan, axis=[0,1],keepdims=True)
    x = (x - mean_no_nan) / (std_no_nan) / 3
    return x


def _gather_points(x, col_idx):
    """Gather one landmark group and unpack it to (frames, points, coords).

    `col_idx` comes from get_index(), whose columns follow idx_to_cols():
    the coordinates of a landmark are adjacent (x_i, y_i[, z_i], x_j, ...).
    A reshape therefore restores one row per landmark.
    """
    coords_per_point = 3 if USE_Z_AXIS else 2
    flat = tf.gather(x, col_idx, axis=1)
    return tf.reshape(flat, (-1, len(col_idx) // coords_per_point, coords_per_point))


def split_data(x):
    """Turn the flat (frames, len(SEL_COLS)) matrix into a point tensor.

    Returns a (frames, points, coords) tensor whose points are ordered
    lip, left hand, right hand, nose, right eye, left eye, left pose,
    right pose. `coords` is 3 with USE_Z_AXIS and 2 without.

    The previous implementation sliced each gathered group into thirds as if
    all x columns came first, then all y, then all z. With the interleaved
    column order produced by idx_to_cols() that mixed coordinates of
    different landmarks, and the 2-D path never rebuilt `lip` or `pose`.
    """
    lip = _gather_points(x, LIP_IDX)
    lhand = _gather_points(x, LHAND_IDX)
    rhand = _gather_points(x, RHAND_IDX)
    nose = _gather_points(x, NOSE_IDX)
    reye = _gather_points(x, REYE_IDX)
    leye = _gather_points(x, LEYE_IDX)
    lpose = _gather_points(x, LPOSE_IDX)
    rpose = _gather_points(x, RPOSE_IDX)

    # Same point order as before: lip, hands, rest of the face, then pose.
    return tf.concat(
        [lip, lhand, rhand, nose, reye, leye, lpose, rpose], axis=1
    )



def spatial_random_rotation_for_finger(
    xyz,
    degree=(-10, 10),
):
    """Rotate the x/y coordinates of a point group by one random angle.

    The rotation pivot is the first point of the group in every frame. The
    angle is drawn uniformly (in degrees) from `degree`; passing ``None``
    disables the rotation. With USE_Z_AXIS the z channel is passed through.
    """
    planar = xyz[:, :, 0:2] if USE_Z_AXIS else xyz
    # use first position for rotation center
    pivot = planar[:, 0:1, :]
    if degree is not None:
        angle_deg = tf.random.uniform(shape=[], minval=degree[0], maxval=degree[1], dtype=tf.float32)
        theta = angle_deg / 180 * np.pi
        cos_t = tf.math.cos(theta)
        sin_t = tf.math.sin(theta)
        rotation = tf.identity(
            [
                [cos_t, sin_t],
                [-sin_t, cos_t],
            ]
        )
        planar = (planar - pivot) @ rotation
        planar = planar + pivot
    if not USE_Z_AXIS:
        return planar
    return tf.concat([planar, xyz[:, :, 2:3]], axis=-1)

def spatial_random_scale_for_finger(
    xyz,
    scale=(0.9, 1.1),
):
    """Scale the x/y coordinates of a point group by one random factor.

    Scaling is done relative to the first point of the group in every frame.
    The factor is drawn uniformly from `scale`; ``None`` disables scaling.
    With USE_Z_AXIS the z channel is passed through unchanged.
    """
    planar = xyz[:, :, 0:2] if USE_Z_AXIS else xyz
    # use first position as the scaling center
    anchor = planar[:, 0:1, :]
    if scale is not None:
        factor = tf.random.uniform(shape=[], minval=scale[0], maxval=scale[1], dtype=tf.float32)
        planar = (planar - anchor) * factor
        planar = planar + anchor
    if not USE_Z_AXIS:
        return planar
    return tf.concat([planar, xyz[:, :, 2:3]], axis=-1)

def only_rotate(xyz,degree=(-15,15),shear = (-0.10,0.10),scale  = (0.75,1.5),
                ):
    """Randomly scale, rotate and shear a landmark group.

    Despite its name the function applies three transforms:
    1. every coordinate (z included) is multiplied by a factor drawn from `scale`;
    2. x/y are rotated about their NaN-aware mean by an angle from `degree`;
    3. x/y are sheared along exactly one axis by an amount from `shear`.
    With USE_Z_AXIS the (scaled) z channel is appended back unchanged.
    """
    scale = tf.random.uniform((),*scale)
    xyz *= scale
    if USE_Z_AXIS:
        xy = xyz[:, :, 0:2]
    else:
        xy = xyz
    # Centre of the group over frames and points, ignoring NaNs.
    center = tf_nan_mean(xy, axis=[0,1])
    degree = tf.random.uniform((),*degree)
    xy -= center

    radian = degree/180*np.pi
    c = tf.math.cos(radian)
    s = tf.math.sin(radian)
    rotate_mat = tf.identity([
        [c,s],
        [-s, c],
    ])
    xy = xy @ rotate_mat


    # One shear amount is drawn, then one of the two directions is zeroed
    # with equal probability so that only a single axis is sheared.
    shear_x = shear_y = tf.random.uniform((),*shear)
    if tf.random.uniform(()) < 0.5:
        shear_x = 0.
    else:
        shear_y = 0.
    shear_mat = tf.identity([
        [1.,shear_x],
        [shear_y,1.]
    ])
    xy = xy @ shear_mat



    # Move the group back to its original centre.
    xy = xy + center
    if USE_Z_AXIS:
        return tf.concat([xy, xyz[:, :, 2:3]], axis=-1)
    else:
        return xy

def inner_flip(x):
    """Reflect the x channel of a 3-channel point group about its own mean x.

    The mean is NaN-aware and taken over frames and points; y and z are
    returned unchanged.
    """
    xs, ys, zs = tf.unstack(x, axis=-1)
    mirrored_xs = 2*tf_nan_mean(xs, axis=[0,1], keepdims=True) - xs
    return tf.stack([mirrored_xs, ys, zs], -1)

def flip_lr(x):
    """Mirror a (frames, points, 3) landmark tensor left-to-right.

    Steps, as coded:
    1. reflect every x coordinate about the NaN-aware mean x of the face
       points (lip block plus everything between the hands and the pose block);
    2. swap the two trailing eye blocks of the face and swap the two hand blocks;
    3. reflect lip and nose once more about their own mean x (inner_flip);
    4. swap the two halves of the trailing POSE_NUMS points.

    Assumes the point layout produced by split_data - TODO confirm.
    NOTE(review): every `-POSE_NUMS` slice assumes POSE_NUMS equals the number
    of pose points in `x`; split_data emits len(LPOSE)+len(RPOSE) of them.
    """
    x,y,z = tf.unstack(x, axis=-1)
    # Face = lip block + the points between the hand blocks and the pose block.
    face = tf.concat([x[:,:len(LIP)], x[:,len(LIP)+len(LHAND)+len(RHAND):-POSE_NUMS]], axis=1)
    face_mean = tf_nan_mean(face, axis=[0,1], keepdims=True)
    # Reflect x about the face centre.
    x = 2*face_mean -x
    # x = 1 - x
    new_x = tf.stack([x,y,z], -1)

    # Exchange the last two eye-sized blocks of the face, and the two hands.
    face = tf.concat([new_x[:,:len(LIP),:], new_x[:,len(LIP)+len(LHAND)+len(RHAND):-POSE_NUMS,:]], axis=1)
    leye = face[:,-len(LEYE):,:]
    reye = face[:,-len(REYE)-len(LEYE):-len(LEYE),:]
    face = tf.concat([face[:,:-len(LEYE)-len(REYE),:],leye,reye], axis=1)
    rhand = new_x[:,len(LIP)+len(LHAND):len(LIP)+len(LHAND)+len(RHAND),:]
    lhand = new_x[:,len(LIP):len(LIP)+len(LHAND),:]
    pose = new_x[:,-POSE_NUMS:,:]
    new_x = tf.concat([face[:,:len(LIP)],rhand, lhand,face[:,len(LIP):],pose], axis=1)

    # Lip and nose keep their point order, so they are reflected a second
    # time about their own centre.
    lip = new_x[:,:len(LIP),:]
    lip = inner_flip(lip)
    nose = new_x[:,len(LIP)+len(LHAND)+len(RHAND):len(LIP)+len(LHAND)+len(RHAND)+len(NOSE),:]
    nose = inner_flip(nose)
    # Exchange the first and second half of the pose block.
    pose = new_x[:,-POSE_NUMS:,:]
    rpose = pose[:,POSE_NUMS//2:,:]
    lpose = pose[:,:POSE_NUMS//2,:]
    pose = tf.concat([rpose,lpose], axis=1)
    new_x = tf.concat([lip, new_x[:,len(LIP):len(LIP)+len(LHAND),:], new_x[:,len(LIP)+len(LHAND):len(LIP)+len(LHAND)+len(RHAND),:], nose,new_x[:,len(LIP)+len(LHAND)+len(RHAND)+len(NOSE):-POSE_NUMS,:],pose], axis=1)
    return new_x



def interp1d_(x, target_len, method="random"):
    """Resize `x` along its first axis to `target_len` rows (at least 1).

    `method` is any name accepted by ``tf.image.resize``. The default
    ``"random"`` picks bilinear with probability 0.33 and otherwise bicubic
    or nearest with equal probability.
    """
    target_len = tf.maximum(1, target_len)
    out_size = (target_len, tf.shape(x)[1])
    if method != "random":
        return tf.image.resize(x, out_size, method)

    if tf.random.uniform(()) < 0.33:
        x = tf.image.resize(x, out_size, "bilinear")
    else:
        if tf.random.uniform(()) < 0.5:
            x = tf.image.resize(x, out_size, "bicubic")
        else:
            x = tf.image.resize(x, out_size, "nearest")
    return x


def resample(x, rate):
    """Stretch or shrink `x` in time by a random factor.

    The factor is drawn uniformly from ``(rate[0], rate[1])``; the new frame
    count is the truncated product of the factor and the current length, and
    the interpolation method is chosen at random by interp1d_.
    """
    factor = tf.random.uniform((), rate[0], rate[1])
    n_frames = tf.cast(tf.shape(x)[0], tf.float32)
    target_len = tf.cast(factor * n_frames, tf.int32)
    return interp1d_(x, target_len)

def uniform_resample(x):
    """Resample `x` to a random length between about 10 and FRAME_LEN frames.

    The bounds are expressed as rates relative to the current length and
    handed to resample().
    """
    n_frames = tf.cast(tf.shape(x)[0], tf.float32)
    lowest_rate = 10/n_frames
    highest_rate = FRAME_LEN/n_frames
    return resample(x, (lowest_rate, highest_rate))

def resample_sub(x, rate=(0.5, 1.5)):
    """Randomly time-warp either a sub-window or the whole sequence.

    * Fewer than 2 frames: the sequence is stretched by a factor in (1, 5).
    * Otherwise, with probability 0.5 a random window ``[start, end)`` is
      resampled by a factor from `rate` and spliced back between the
      untouched head and tail; else the whole sequence is resampled.
    """
    if tf.shape(x)[0] < 2:
        x = resample(x, (1.0,5.0))
        return x
    if tf.random.uniform(()) < 0.5:
        # `end` is drawn from [start + 1, length), so start < end always
        # holds and no swap of the two bounds is needed.
        start = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(x)[0]-1, dtype=tf.int32)
        end = tf.random.uniform(shape=[], minval=start+1, maxval=tf.shape(x)[0], dtype=tf.int32)
        x = tf.concat([x[:start], resample(x[start:end], rate), x[end:]], axis=0)
    else:
        x = resample(x, rate)
    return x


def mask_along_axis(tensor, param_min, param_max, axis, mask_value=float('nan')):
    """Overwrite one contiguous run along `axis` of a 3-D tensor.

    The run length is a random integer between ``param_min`` and
    ``param_max`` times the axis size (inclusive), and its start is uniform
    over the positions where it still fits. ``axis == 1`` masks points; any
    other value masks along axis 0 (frames).
    """
    dim_size = tf.shape(tensor)[axis]
    dim_size_f = tf.cast(dim_size, tf.float32)
    shortest = tf.cast(param_min * dim_size_f, tf.int32)
    longest = tf.cast(param_max * dim_size_f, tf.int32)
    # tf.random.uniform needs minval < maxval, so the degenerate case is
    # handled separately.
    run_length = tf.cond(
        tf.equal(shortest, longest),
        lambda: shortest,
        lambda: tf.random.uniform((), shortest, longest+1, dtype=tf.int32)
    )
    run_start = tf.random.uniform([], 0, dim_size - run_length, tf.int32)
    positions = tf.cast(tf.range(start=0, limit=dim_size, delta=1), tf.int32)
    keep = tf.logical_or(positions < run_start, positions >= (run_start + run_length))
    keep_shape = (1, dim_size, 1) if axis == 1 else (dim_size, 1, 1)
    keep = tf.reshape(tf.cast(keep, tf.float32), keep_shape)
    return tf.where(keep == 0, mask_value, tensor)

def discrete_mask(tensor, param_min, param_max, axis,mask_value=float('nan')):
    """Keep a random, scattered subset of positions along `axis`.

    A count between ``param_min`` and ``param_max`` times the axis size is
    drawn, that many distinct positions are picked at random, and every
    position that was NOT picked is overwritten with `mask_value`.
    ``axis == 1`` works on points; any other value works on axis 0 (frames).
    """
    dim_size = tf.shape(tensor)[axis]
    dim_size_f = tf.cast(dim_size, tf.float32)

    fewest = tf.cast(param_min * dim_size_f, tf.int32)
    most = tf.cast(param_max * dim_size_f, tf.int32)
    n_picked = tf.cond(
        tf.equal(fewest, most),
        lambda: fewest,
        lambda: tf.random.uniform((), fewest, most + 1, dtype=tf.int32)
    )

    # Distinct random positions: the head of a shuffled index range.
    picked = tf.random.shuffle(tf.range(dim_size))[:n_picked]

    # 1.0 at the picked positions, 0.0 elsewhere.
    indicator = tf.scatter_nd(
        tf.expand_dims(picked, 1),
        tf.ones(n_picked, dtype=tf.float32),
        [dim_size]
    )

    indicator_shape = (1, dim_size, 1) if axis == 1 else (dim_size, 1, 1)
    indicator = tf.reshape(indicator, indicator_shape)

    return tf.where(indicator == 0, mask_value, tensor)



# Point indices inside a 21-point hand, grouped into five fingers of four
# points each. Index 0 is not listed; random_rotate_fingers re-attaches it
# as the hand root.
fingers = [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16],[17,18,19,20]]


def random_rotate_fingers(x):
    """Jitter every finger of one hand independently.

    `x` is a (frames, 21, coords) hand tensor. Each finger listed in
    `fingers` is first rotated and then scaled about its first point with
    independent random parameters; the hand root (point 0) is kept as is and
    the original point order is restored on output.
    """
    parts = [tf.gather(x, joint_ids, axis=1) for joint_ids in fingers]
    # All rotations are drawn before all scalings, finger by finger.
    parts = [spatial_random_rotation_for_finger(part) for part in parts]
    parts = [spatial_random_scale_for_finger(part) for part in parts]
    hand_root = tf.expand_dims(x[:,0,:], axis=1)
    return tf.concat([hand_root] + parts, axis=1)


def spatial_mask(x, size=(0.1,0.4), mask_value=float('nan')):
    """Overwrite every point that falls inside a random x/y rectangle.

    The rectangle's lower corner is uniform inside the bounding box of the
    non-NaN x and y coordinates, and each side is a fraction (drawn from
    `size`) of the corresponding extent. Points inside it get `mask_value`
    in all channels.

    The extents are computed per channel. Indexing the full tensor with the
    NaN mask, as done previously, selected whole points and so reduced over
    x, y and z together.
    """
    xs = x[..., 0]
    ys = x[..., 1]
    valid_xs = tf.boolean_mask(xs, tf.logical_not(tf.math.is_nan(xs)))
    valid_ys = tf.boolean_mask(ys, tf.logical_not(tf.math.is_nan(ys)))
    x_min = tf.math.reduce_min(valid_xs)
    x_max = tf.math.reduce_max(valid_xs)
    y_min = tf.math.reduce_min(valid_ys)
    y_max = tf.math.reduce_max(valid_ys)
    mask_offset_x = tf.random.uniform((1,), x_min, x_max)
    mask_offset_y = tf.random.uniform((1,), y_min, y_max)
    mask_size_x = tf.random.uniform((1,), *size) * (x_max - x_min)
    mask_size_y = tf.random.uniform((1,), *size) * (y_max - y_min)
    # Comparisons with NaN are False, so missing points are never selected.
    mask_x = (mask_offset_x < xs) & (xs < mask_offset_x + mask_size_x)
    mask_y = (mask_offset_y < ys) & (ys < mask_offset_y + mask_size_y)
    mask = mask_x & mask_y
    x = tf.where(mask[...,None], mask_value, x)
    return x


def temporal_mask(x, size=(0.1,0.3), mask_value=float('nan')):
    """Overwrite one contiguous run of frames with `mask_value`.

    The run length is a random fraction (from `size`) of the sequence length
    and the start is uniform over the positions where the run fits. The
    replacement block is built with KPT_NUM // DATA_DIM points of DATA_DIM
    coordinates, so `x` has to have exactly that per-frame shape.
    """
    n_frames = tf.shape(x)[0]
    fraction = tf.random.uniform((), *size)
    run_length = tf.cast(tf.cast(n_frames, tf.float32) * fraction, tf.int32)
    # The upper bound is clipped to at least 1 so the sampling range is never empty.
    latest_start = tf.clip_by_value(n_frames-run_length,1,n_frames)
    run_start = tf.random.uniform((), 0, latest_start, dtype=tf.int32)
    frame_ids = tf.range(run_start, run_start+run_length)[...,None]
    filler = tf.fill([run_length,KPT_NUM//DATA_DIM,DATA_DIM],mask_value)
    return tf.tensor_scatter_nd_update(x, frame_ids, filler)


def random_shift(x, shift_range=0.1):
    """Translate a whole (frames, points, coords) group by one random offset.

    One offset per coordinate channel is drawn uniformly from
    ``[-shift_range, shift_range)`` and added to every point of every frame.
    The channel count is taken from `x` rather than hard-coded to 3, so 2-D
    (x, y) input works too.
    """
    # Prefer the static channel count; fall back to the dynamic one.
    n_coords = x.shape[-1]
    if n_coords is None:
        n_coords = tf.shape(x)[-1]
    shift_var = tf.random.uniform((1, 1, n_coords), -shift_range, shift_range)
    # Broadcasting spreads the (1, 1, coords) offset over frames and points.
    return x + shift_var

def rotate_partial(x,partial=False):
    """Apply only_rotate() independently to face, each hand and pose.

    With ``partial=True`` and more than 10 frames, only a random window
    ``[start, end)`` of frames is transformed and spliced back between the
    untouched head and tail; otherwise the whole sequence is transformed.

    Assumes the point layout produced by split_data - TODO confirm.
    NOTE(review): the `-POSE_NUMS` slices assume POSE_NUMS equals the number
    of pose points actually present in `x`.
    """
    def apply(x):
        # Face = lip block + the points between the hand blocks and the pose block.
        face = tf.concat([x[:,:len(LIP),:], x[:,len(LIP)+len(LHAND)+len(RHAND):-POSE_NUMS,:]], axis=1)
        face = only_rotate(face,degree=(-30,30))
        x = tf.concat([face[:,:len(LIP),:],x[:,len(LIP):len(LIP)+len(LHAND)+len(RHAND),:], face[:,len(LIP):,:],x[:,-POSE_NUMS:,:]], axis=1)
        # Each remaining group is cut out, transformed and written back in place.
        lhand = x[:,len(LIP):len(LIP)+len(LHAND),:]
        lhand = only_rotate(lhand,degree=(-30,30))
        x = tf.concat([x[:,:len(LIP),:], lhand, x[:,len(LIP)+len(LHAND):,:]], axis=1)
        rhand = x[:,len(LIP)+len(LHAND):len(LIP)+len(LHAND)+len(RHAND),:]
        rhand = only_rotate(rhand,degree=(-30,30))
        x = tf.concat([x[:,:len(LIP)+len(LHAND),:], rhand, x[:,len(LIP)+len(LHAND)+len(RHAND):,:]], axis=1)
        pose = x[:,-POSE_NUMS:,:]
        pose = only_rotate(pose,degree=(-30,30))
        x = tf.concat([x[:,:-POSE_NUMS,:], pose], axis=1)
        return x
    if tf.shape(x)[0] > 10 and partial:
        # `end` is drawn from [start + 1, length), so the window is never empty.
        start = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(x)[0]-1, dtype=tf.int32)
        end = tf.random.uniform(shape=[], minval=start+1, maxval=tf.shape(x)[0], dtype=tf.int32)
        x = tf.concat([x[:start], apply(x[start:end]), x[end:]], axis=0)
    else:
        return apply(x)
    return x

def rotate_finger_partial(x,partial=False):
    """Jitter the fingers of both hands, on the whole clip or on a window.

    With ``partial=True`` and more than 10 frames, only a random window
    ``[start, end)`` of frames is processed and spliced back between the
    untouched head and tail; otherwise every frame is processed.
    """
    # Point offsets of the two hand blocks, which follow the lip block.
    lhand_begin = len(LIP)
    rhand_begin = lhand_begin + len(LHAND)
    hands_end = rhand_begin + len(RHAND)

    def apply(clip):
        lhand = random_rotate_fingers(clip[:, lhand_begin:rhand_begin, :])
        rhand = random_rotate_fingers(clip[:, rhand_begin:hands_end, :])
        return tf.concat(
            [clip[:, :lhand_begin, :], lhand, rhand, clip[:, hands_end:, :]], axis=1
        )

    if tf.shape(x)[0] > 10 and partial:
        start = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(x)[0]-1, dtype=tf.int32)
        end = tf.random.uniform(shape=[], minval=start+1, maxval=tf.shape(x)[0], dtype=tf.int32)
        x = tf.concat([x[:start], apply(x[start:end]), x[end:]], axis=0)
    else:
        x = apply(x)
    return x

def shift_partial(x,partial=False):
    """Translate face, each hand and pose by independent random offsets.

    With ``partial=True`` and more than 10 frames, only a random window
    ``[start, end)`` of frames is shifted and spliced back between the
    untouched head and tail; otherwise every frame is shifted.
    """
    # Point offsets of the blocks; the pose block is the trailing POSE_NUMS points.
    lip_end = len(LIP)
    rhand_begin = lip_end + len(LHAND)
    hands_end = rhand_begin + len(RHAND)

    def apply(clip):
        # Offsets are drawn in the order face, left hand, right hand, pose.
        face = random_shift(
            tf.concat([clip[:, :lip_end, :], clip[:, hands_end:-POSE_NUMS, :]], axis=1)
        )
        lhand = random_shift(clip[:, lip_end:rhand_begin, :])
        rhand = random_shift(clip[:, rhand_begin:hands_end, :])
        pose = random_shift(clip[:, -POSE_NUMS:, :])
        return tf.concat(
            [face[:, :lip_end, :], lhand, rhand, face[:, lip_end:, :], pose], axis=1
        )

    if tf.shape(x)[0] > 10 and partial:
        start = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(x)[0]-1, dtype=tf.int32)
        end = tf.random.uniform(shape=[], minval=start+1, maxval=tf.shape(x)[0], dtype=tf.int32)
        x = tf.concat([x[:start], apply(x[start:end]), x[end:]], axis=0)
    else:
        x = apply(x)
    return x


def combined_mask_along_axis(tensor, s1, s2, mask_value=float('nan')):
    """Mask the intersection of one run along axis 1 and one along axis 0.

    `s1` and `s2` are ``(min, max)`` fractions of the axis size for the run
    along axis 1 and axis 0 respectively. An element is overwritten with
    `mask_value` only when it lies inside both runs.
    """
    def get_mask(tensor, param_min, param_max, axis, mask_value=float('nan')):
        # Boolean "keep" mask that is False inside one random contiguous run.
        dim_size = tf.shape(tensor)[axis]
        dim_size_f = tf.cast(dim_size, tf.float32)
        shortest = tf.cast(param_min * dim_size_f, tf.int32)
        longest = tf.cast(param_max * dim_size_f, tf.int32)
        run_length = tf.cond(
            tf.equal(shortest, longest),
            lambda: shortest,
            lambda: tf.random.uniform((), shortest, longest+1, dtype=tf.int32)
        )
        run_start = tf.random.uniform([], 0, dim_size - run_length, tf.int32)
        positions = tf.cast(tf.range(start=0, limit=dim_size, delta=1), tf.int32)
        keep = tf.logical_or(positions < run_start, positions >= (run_start + run_length))
        keep_shape = (1, dim_size, 1) if axis == 1 else (dim_size, 1, 1)
        return tf.reshape(keep, keep_shape)

    keep_axis1 = get_mask(tensor, s1[0], s1[1], 1, mask_value)
    keep_axis0 = get_mask(tensor, s2[0], s2[1], 0, mask_value)
    # Kept if either mask keeps it, i.e. dropped only inside both runs.
    keep = tf.cast(tf.logical_or(keep_axis1, keep_axis0), tf.float32)
    return tf.where(keep == 0, mask_value, tensor)

def augment_fn(x):
    """Apply the random training augmentations to one point tensor.

    Each stage is switched on by its own uniform draw, in this order:
    temporal resampling, left/right flip, group-wise rotate/scale/shear,
    per-finger jitter, group-wise shift, frame masking and point masking.
    Masked entries are set to NaN; NaNs are not filled here (preprocess2
    does that after normalisation).
    """
    # Temporal resampling (applied with probability about 0.67).
    if tf.random.uniform(()) > 0.33:
        x = resample_sub(x)
        # if tf.random.uniform(()) > 0.5:
        #     x = resample_sub(x) # temporal
        # else:
        #     x = uniform_resample(x) # temporal
    if tf.random.uniform(()) > 0.5:
        # flip face
        x = flip_lr(x)

    # Rotate / scale / shear every group, on a frame window or on all frames.
    if tf.random.uniform(()) < 0.5:
        if tf.random.uniform(()) < 0.5:
            x = rotate_partial(x,partial=True)
        else:
            x = rotate_partial(x,partial=False)

    # Per-finger rotation and scaling of both hands.
    if tf.random.uniform(()) < 0.5:
        if tf.random.uniform(()) < 0.5:
            x = rotate_finger_partial(x,partial=True)
        else:
            x = rotate_finger_partial(x,partial=False)

    # # random shift
    if tf.random.uniform(()) > 0.5:
        if tf.random.uniform(()) < 0.5:
            x = shift_partial(x,partial=True)
        else:
            x = shift_partial(x,partial=False)

    # Masking along the frame axis: one of three strategies.
    if tf.random.uniform(()) < 0.5:
        # T = tf.minimum(tf.cast(tf.shape(x)[0],tf.float32),200.0)
        # factor = T/200 * 0.3
        if tf.random.uniform(()) > 0.33:
            x = mask_along_axis(x,0.0,0.4,0) # originally 0.2~0.4, can change to 0.1~0.3 for faster convergence
        elif tf.random.uniform(()) > 0.5:
            x = temporal_mask(x,size=(0.0,0.4))
        else:
            x = discrete_mask(x,0.0,0.4,0)

    if tf.random.uniform(()) < 0.5:
        # spatial masking
        if tf.random.uniform(()) > 0.33:
            x = mask_along_axis(x,0.0,0.4,1) # can be 0.2,0.4
        elif tf.random.uniform(()) > 0.5:
            x = spatial_mask(x,size=(0.0,0.4))
        else:
            x = discrete_mask(x,0.0,0.4,1)


    # x = tf.where(tf.math.is_nan(x),NAN_FILL_VALUE,x)
    return x


# Number of scalar coordinates per frame: the point count of all groups times
# DATA_DIM. The pose part counts LPOSE and RPOSE, not the full POSE list.
KPT_NUM = (len(LIP) + len(LHAND) + len(RHAND) + len(NOSE) + len(LEYE) + len(REYE) + len(LPOSE) + len(RPOSE)) * DATA_DIM


def preprocess1(x):
    """First preprocessing stage: regroup the flat columns via split_data().

    NaNs are deliberately left in place at this stage.
    """
    return split_data(x)


def preprocess2(x):
    """Second preprocessing stage: normalise, fix the length, add motion.

    1. Face, left hand, right hand and pose are standardised separately and
       re-concatenated in that order.
    2. The sequence is padded / resized to FRAME_LEN frames.
    3. One-frame and two-frame differences are appended as extra channels.
    4. The result is flattened to (FRAME_LEN, INPUT_DIM) and NaNs are
       replaced by NAN_FILL_VALUE.
    """
    # Point offsets of the blocks; the pose block is the trailing POSE_NUMS points.
    lip_end = len(LIP)
    rhand_begin = lip_end + len(LHAND)
    hands_end = rhand_begin + len(RHAND)

    face = tf.concat([x[:, :lip_end, :], x[:, hands_end:-POSE_NUMS, :]], axis=1)
    pose = x[:, -POSE_NUMS:, :]
    lhand = x[:, lip_end:rhand_begin, :]
    rhand = x[:, rhand_begin:hands_end, :]
    normalized_groups = [self_norm(face), self_norm(lhand), self_norm(rhand), self_norm(pose)]
    x = tf.concat(normalized_groups, axis=1)

    x = resize_pad(x)

    n_points = KPT_NUM//DATA_DIM
    n_channels = tf.shape(x)[-1]

    # Frame-to-frame difference, zero for the first frame.
    dx = tf.concat([tf.zeros((1, n_points, n_channels)), x[1:,:,:] - x[:-1,:,:]], axis=0)
    # Difference over two frames, zero for the first two frames.
    if tf.shape(x)[0] > 1:
        dx2 = x[2:,:,:] - x[:-2,:,:]
        dx2 = tf.concat([tf.zeros((2, n_points, n_channels)), dx2], axis=0)
    else:
        dx2 = tf.zeros_like(dx)
    x = tf.concat([x, dx, dx2], axis=-1)

    x = tf.reshape(x, (FRAME_LEN, INPUT_DIM))
    return tf.where(tf.math.is_nan(x), NAN_FILL_VALUE, x)

def total_process(x, phrase, augment=False):
    """Run the full sample pipeline and pass the label through.

    Applies preprocess1, the random augmentations when `augment` is true,
    then preprocess2. Returns ``(features, phrase)``.
    """
    points = preprocess1(x)
    if augment:
        points = augment_fn(points)
    return preprocess2(points), phrase

# Features per frame: the coordinates plus their one-frame and two-frame
# differences (three blocks of KPT_NUM values, see preprocess2).
INPUT_DIM = KPT_NUM * 3 # coordinate + velocity
# Number of frames per preprocessed sample.
TEMPORAL_DIM = FRAME_LEN

def preprocess_(x):
    """Label-free, augmentation-free pipeline: preprocess1 then preprocess2."""
    return preprocess2(preprocess1(x))

def preprocess_fn(x, phrase, augment=False):
    """Dataset map function: returns the ``(features, phrase)`` pair from total_process()."""
    return total_process(x, phrase, augment)


# Show the per-sample feature shape: frames x features.
print(TEMPORAL_DIM,INPUT_DIM)

In [None]:
# Graph-side character -> class index lookup built from char_to_num.
# Characters that are not in the mapping are mapped to -1.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=list(char_to_num.keys()),
        values=list(char_to_num.values()),
    ),
    default_value=tf.constant(-1),
    name="class_weight",
)


def decode_fn(record_bytes):
    """Parse one serialized example into ``(landmarks, phrase)``.

    `landmarks` has one row per frame and one column per entry of SEL_COLS
    (in that order). `phrase` is the per-character class index vector, padded
    at the end with pad_token_idx up to PHRASE_MAX_LEN.
    """
    feature_spec = {col: tf.io.VarLenFeature(dtype=tf.float32) for col in SEL_COLS}
    feature_spec["phrase"] = tf.io.FixedLenFeature([], dtype=tf.string)
    example = tf.io.parse_single_example(record_bytes, feature_spec)

    # One dense vector per column, stacked and transposed to (frames, columns).
    per_column = [tf.sparse.to_dense(example[col]) for col in SEL_COLS]
    landmarks = tf.transpose(per_column)

    # Split the phrase into single bytes and map each one to its class index.
    char_ids = table.lookup(tf.strings.bytes_split(example["phrase"]))
    pad_len = PHRASE_MAX_LEN - tf.shape(char_ids)[0]
    char_ids = tf.pad(
        char_ids,
        paddings=[[0, pad_len]],
        constant_values=pad_token_idx,
    )
    return landmarks, char_ids


# File ids held out for validation when VAL is set.
val_file_ids = [234418913]#, 1967755728, 425182931]


def _tfrecord_path(file_id):
    """Path of the preprocessed TFRecord that belongs to one file id."""
    return f"/kaggle/input/aslfr-preprocess-dataset/tfds-v2/{file_id}.tfrecord"


if VAL:
    # Split the distinct TFRecord paths by validation membership.
    in_val = df.file_id.isin(val_file_ids)
    val_tffiles = df[in_val].file_id.map(_tfrecord_path).unique()
    train_tffiles = df[~in_val].file_id.map(_tfrecord_path).unique()
else:
    # Full-data training: every file is a training file, no val_tffiles.
    train_tffiles = df.file_id.map(_tfrecord_path).unique()

In [None]:
val_size = BATCH_SIZE


def _build_dataset(tfrecord_files, augment):
    """Endless batched pipeline over `tfrecord_files`.

    Records are decoded, preprocessed (with random augmentation when
    `augment` is true), batched to BATCH_SIZE, repeated and prefetched.
    """
    return (
        tf.data.TFRecordDataset(tfrecord_files, num_parallel_reads=tf.data.AUTOTUNE)
        .map(decode_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .map(
            lambda x, phrase: preprocess_fn(x, phrase, augment=augment),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .batch(BATCH_SIZE)
        .repeat()
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


if VAL:
    print(' Train Val Split')
    train_dataset = _build_dataset(train_tffiles, augment=True)
    val_dataset = _build_dataset(val_tffiles, augment=False)
    # Un-augmented stream over the first training file only.
    aux_dataset = _build_dataset(train_tffiles[:1], augment=False)
    print('train size', 64211+1997)
    print('val size', 1000)
    print('train datafiles',train_tffiles)
    print('val datafiles',val_tffiles)
else:
    print('Full Train')
    train_dataset = _build_dataset(train_tffiles, augment=True)

    print('train size', 64211+2997)

In [None]:
class LearningRateLogger(tf.keras.callbacks.Callback):
    """Keras callback that prints the optimizer's learning rate after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the learning rate in effect at the end of `epoch`.

        `logs` is accepted for the Keras callback signature but is not read.
        """
        optimizer = self.model.optimizer
        schedule_cls = tf.keras.optimizers.schedules.LearningRateSchedule
        current = optimizer.lr
        # A schedule object must be evaluated at the optimizer's step counter
        # to obtain a concrete value; anything else is used as-is.
        lr = current(optimizer.iterations) if isinstance(current, schedule_cls) else current
        print(f"\nLearning rate at end of epoch {epoch}: {lr.numpy()}\n")


# Single shared instance of the logging callback.
lr_logger = LearningRateLogger()

In [None]:
class OneCycleLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    '''
    Unified single-cycle learning rate scheduler for tensorflow.
    2022 Hoyeol Sohn <hoeyol0730@gmail.com>

    The rate returned for an optimizer step is decided in this order, each
    rule overriding the previous one:

    1. warm-up   - while ``step < warmup_steps`` the rate rises linearly from
                   ``lr_start`` to ``lr``; otherwise it is ``lr``.
    2. decay     - once ``step >= warmup_steps + sustain_steps`` the rate
                   follows half a cosine from ``lr`` down to ``lr_min``.
                   ``decay_epochs`` is treated as the absolute epoch at which
                   the cosine ends, and the step is clamped to that point.
    3. fine-tune - once ``step >= total_steps - finetune_steps`` the rate is
                   the constant ``finetune_lr``.

    Every ``*_epochs`` argument is converted to steps by multiplying it with
    ``steps_per_epoch``. ``warmup_type`` and ``decay_type`` are stored, but
    ``__call__`` always applies linear warm-up and cosine decay.
    '''
    def __init__(self,
                lr=1e-3,
                epochs=10,
                steps_per_epoch=100,
                steps_per_update=1,
                resume_epoch=0,
                decay_epochs=10,
                sustain_epochs=0,
                warmup_epochs=10,
                lr_start=0,
                lr_min=0,
                warmup_type='linear',
                decay_type='cosine',
                finetune_steps=0,
                finetune_lr=1e-5,
                **kwargs):

        super().__init__(**kwargs)
        self.lr = float(lr)
        self.epochs = float(epochs)
        self.steps_per_update = float(steps_per_update)
        self.resume_epoch = float(resume_epoch)
        self.steps_per_epoch = float(steps_per_epoch)
        self.decay_epochs = float(decay_epochs)
        self.sustain_epochs = float(sustain_epochs)
        self.warmup_epochs = float(warmup_epochs)
        self.lr_start = float(lr_start)
        self.lr_min = float(lr_min)
        self.decay_type = decay_type
        self.warmup_type = warmup_type
        self.finetune_steps = finetune_steps
        self.finetune_lr = float(finetune_lr)

    def __call__(self, step):
        """Return the learning rate for optimizer step `step`."""
        step = tf.cast(step, tf.float32)
        total_steps = self.epochs * self.steps_per_epoch
        warmup_steps = self.warmup_epochs * self.steps_per_epoch
        sustain_steps = self.sustain_epochs * self.steps_per_epoch
        decay_steps = self.decay_epochs * self.steps_per_epoch

        # When resuming, shift the counter as if `resume_epoch` epochs had run.
        if self.resume_epoch > 0:
            step = step + self.resume_epoch * self.steps_per_epoch

        # Clamp to the end of the cosine so the rate stays at its final value.
        step = tf.cond(step > decay_steps, lambda :decay_steps, lambda :step)
        # True division followed by the same multiplication leaves `step`
        # unchanged (up to float rounding).
        # NOTE(review): possibly meant to snap to multiples of steps_per_update
        # with a floor division - confirm before changing.
        step = tf.math.truediv(step, self.steps_per_update) * self.steps_per_update

        warmup_cond = step < warmup_steps
        decay_cond = step >= (warmup_steps + sustain_steps)
        # Evaluated on the clamped step, so it can only fire when
        # total_steps - finetune_steps <= decay_steps.
        finetune_cond = step >= (total_steps - self.finetune_steps)

        # divide_no_nan keeps a zero-length warm-up from producing NaN.
        lr = tf.cond(warmup_cond, lambda: tf.math.divide_no_nan(self.lr-self.lr_start , warmup_steps) * step + self.lr_start, lambda: self.lr)

        lr = tf.cond(decay_cond, lambda: 0.5 * (self.lr - self.lr_min) * (1 + tf.cos(3.14159265359 * (step - warmup_steps - sustain_steps) / (decay_steps - warmup_steps - sustain_steps))) + self.lr_min, lambda:lr)

        lr = tf.cond(finetune_cond, lambda: tf.constant(self.finetune_lr,tf.float32), lambda:lr)
        return lr

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized.

        Without this override the base class raises NotImplementedError when an
        optimizer that uses this schedule is serialized. The returned dict can
        be passed back to the constructor as keyword arguments.
        """
        return {
            'lr': self.lr,
            'epochs': self.epochs,
            'steps_per_epoch': self.steps_per_epoch,
            'steps_per_update': self.steps_per_update,
            'resume_epoch': self.resume_epoch,
            'decay_epochs': self.decay_epochs,
            'sustain_epochs': self.sustain_epochs,
            'warmup_epochs': self.warmup_epochs,
            'lr_start': self.lr_start,
            'lr_min': self.lr_min,
            'warmup_type': self.warmup_type,
            'decay_type': self.decay_type,
            'finetune_steps': self.finetune_steps,
            'finetune_lr': self.finetune_lr,
        }


In [None]:
class LateDropout(tf.keras.layers.Layer):
    """Dropout that stays inactive until `start_step` training calls have happened.

    Args:
        rate: dropout rate handed to the wrapped `Dropout` layer.
        noise_shape: forwarded to the wrapped `Dropout` layer.
        start_step: number of training calls during which inputs pass through
            unchanged.
    """

    def __init__(self, rate, noise_shape=None, start_step=0, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.rate = rate
        self.start_step = start_step
        self.dropout = tf.keras.layers.Dropout(rate, noise_shape=noise_shape)

    def build(self, input_shape):
        super().build(input_shape)
        # Non-trainable counter of training-mode calls seen so far.
        self._train_counter = tf.Variable(
            0,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            trainable=False,
        )

    def call(self, inputs, training=False):
        def passthrough():
            return inputs

        def dropped():
            return self.dropout(inputs, training=training)

        before_start = self._train_counter < self.start_step
        outputs = tf.cond(before_start, passthrough, dropped)
        # Only training-mode calls advance the counter.
        if training:
            self._train_counter.assign_add(1)
        return outputs

class CausalDWConv1D(tf.keras.layers.Layer):
    """Depthwise 1-D convolution preceded by zero padding at the front only.

    The input is padded with `dilation_rate * (kernel_size - 1)` zeros before
    the first step and none after it, then convolved with `padding='valid'`.
    """

    def __init__(self,
        kernel_size=17,
        dilation_rate=1,
        use_bias=True,
        depthwise_initializer='glorot_uniform',
        name='', **kwargs):
        super().__init__(name=name,**kwargs)
        front_padding = dilation_rate*(kernel_size-1)
        self.causal_pad = tf.keras.layers.ZeroPadding1D((front_padding, 0))
        self.dw_conv = tf.keras.layers.DepthwiseConv1D(
            kernel_size,
            strides=1,
            dilation_rate=dilation_rate,
            padding='valid',
            use_bias=use_bias,
            depthwise_initializer=depthwise_initializer,
        )
        self.supports_masking = True

    def call(self, inputs):
        """Pad at the front, then apply the depthwise convolution."""
        return self.dw_conv(self.causal_pad(inputs))

class GLU(tf.keras.layers.Layer):
    """Gated linear unit: split the input in two along `dim` and gate one half.

    Args:
        dim: axis along which the input is split into two equal parts.
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.dim = dim

    def call(self, inputs):
        """Return `first_half * sigmoid(second_half)`."""
        value, gate = tf.split(inputs, num_or_size_splits=2, axis=self.dim)
        return value * tf.sigmoid(gate)

    def compute_mask(self, inputs, mask=None):
        """Pass the incoming mask through unchanged."""
        return mask



class CausalConv1D(tf.keras.layers.Layer):
    """1-D convolution preceded by zero padding at the front only.

    The input is padded with `dilation_rate * (kernel_size - 1)` zeros before
    the first step and none after it, then convolved with `padding='valid'`.

    Args:
        hid_dim: number of output filters.
        kernel_size: convolution kernel length.
        dilation_rate: dilation of the convolution.
        groups: forwarded to `Conv1D(groups=...)`.
        name: layer name; the empty default lets Keras pick one.
        **kwargs: forwarded to the base `Layer` constructor.
    """

    def __init__(self,
        hid_dim,
        kernel_size=17,
        dilation_rate=1,
        groups = 1,
        name='', **kwargs):
        # Forward name/kwargs to the base layer (they used to be dropped
        # silently). The '' default is mapped to None because some Keras
        # versions only auto-generate a name when it is None.
        super().__init__(name=name or None, **kwargs)
        self.causal_pad = tf.keras.layers.ZeroPadding1D((dilation_rate*(kernel_size-1),0))
        # Attribute keeps its historical name even though this is a regular
        # (optionally grouped) convolution, not a depthwise one.
        self.dw_conv = tf.keras.layers.Conv1D(
                            hid_dim,
                            kernel_size,
                            strides=1,
                            dilation_rate=dilation_rate,
                            padding='valid',groups=groups)
        self.supports_masking = True

    def call(self, inputs):
        """Pad at the front, then apply the convolution."""
        x = self.causal_pad(inputs)
        x = self.dw_conv(x)
        return x


class SqueezeformerBlock(tf.keras.layers.Layer):
    """Encoder block: feed-forward, self-attention, convolution, feed-forward.

    Each of the four sub-modules is applied to a layer-normalised copy of the
    running tensor and added back as `residual + scale * branch`, where `scale`
    is one trainable scalar shared by all four branches.

    Args:
        dim: channel width of the block's input and output.
        num_heads: number of attention heads.
        dropout: dropout rate used throughout the block.
        expansion_factor: accepted but not used; the feed-forward hidden width
            is fixed at `dim * 4`.
    """

    def __init__(self, dim, num_heads=8, dropout=DROP_OUT, expansion_factor=2):
        super().__init__()
        self.dim = dim

        def make_feed_forward():
            # Dense(4*dim) -> swish -> dropout -> Dense(dim) -> dropout
            return tf.keras.Sequential([
                tf.keras.layers.Dense(dim*4),
                tf.keras.layers.Activation('swish'),
                tf.keras.layers.Dropout(dropout),
                tf.keras.layers.Dense(dim),
                tf.keras.layers.Dropout(dropout)
            ])

        # --- first feed-forward module ---
        self.ff1_norm = tf.keras.layers.LayerNormalization()
        self.ff1 = make_feed_forward()

        # --- attention module ---
        self.norm1 = tf.keras.layers.LayerNormalization()
        self.mhsa = MultiHeadSelfAttention(dim, num_heads, dropout)

        # --- convolution module ---
        self.conv_norm = tf.keras.layers.LayerNormalization()
        self.conv1 = tf.keras.layers.Conv1D(dim*2, 1)
        self.glu = GLU(-1)
        if MASKING:
            self.depthwise_conv = CausalDWConv1D(kernel_size=31)
        else:
            self.depthwise_conv = tf.keras.layers.Conv1D(dim, kernel_size=31, padding='same', groups=dim)
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.activation = tf.keras.layers.Activation('swish')
        self.pointwise_conv = tf.keras.layers.Conv1D(dim, 1)
        self.conv_dropout = tf.keras.layers.Dropout(dropout)

        # --- second feed-forward module ---
        self.ff2_norm = tf.keras.layers.LayerNormalization()
        self.ff2 = make_feed_forward()

        self.dropout = tf.keras.layers.Dropout(dropout)
        self.scale = tf.Variable(1.0, trainable=True)

    def call(self, x):
        """Run the four residual sub-modules in sequence and return the result."""
        # First feed-forward branch.
        ff1_out = self.ff1(self.ff1_norm(x))
        x = x + self.scale * ff1_out

        # Self-attention branch.
        attn_out = self.dropout(self.mhsa(self.norm1(x)))
        x = x + self.scale * attn_out

        # Convolution branch: pointwise expand -> GLU -> depthwise -> BN ->
        # swish -> pointwise -> dropout.
        conv_out = self.glu(self.conv1(self.conv_norm(x)))
        conv_out = self.batch_norm(self.depthwise_conv(conv_out))
        conv_out = self.pointwise_conv(self.activation(conv_out))
        conv_out = self.conv_dropout(conv_out)
        x = x + self.scale * conv_out

        # Second feed-forward branch.
        ff2_out = self.ff2(self.ff2_norm(x))
        return x + self.scale * ff2_out


class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with a learned relative-position term.

    Attention logits are `(q @ k^T + q . rel_embedding[i - j]) * dim ** -0.5`;
    note the scale uses the full width `dim`. The relative-position table has
    `2 * FRAME_LEN - 1` rows.

    Args:
        dim: channel width of the input and output.
        num_heads: number of heads; each head has `dim // num_heads` channels.
        dropout: dropout rate applied to the attention weights.
    """

    def __init__(self, dim=256, num_heads=4, dropout=DROP_OUT, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.num_heads = num_heads
        self.qkv = tf.keras.layers.Dense(3 * dim)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.proj = tf.keras.layers.Dense(dim)
        self.rel_embedding = self.add_weight("rel_embedding", shape=[FRAME_LEN * 2 - 1, dim // num_heads])
        self.supports_masking = True

    def call(self, inputs, mask=None):
        """Apply self-attention; `mask`, when given, marks valid positions.

        Assumes `inputs` is (batch, frames, dim) - TODO confirm against callers.
        """
        head_dim = self.dim // self.num_heads

        # Project to q/k/v, split the channels into heads and move the head
        # axis in front of the frame axis.
        split_heads = tf.keras.layers.Reshape(
            (-1, self.num_heads, self.dim * 3 // self.num_heads)
        )
        heads_first = tf.keras.layers.Permute((2, 1, 3))
        qkv = heads_first(split_heads(self.qkv(inputs)))
        q, k, v = tf.split(qkv, [head_dim] * 3, axis=-1)

        # Index (i - j + FRAME_LEN - 1) into the relative-position table,
        # cropped to the number of frames actually present.
        positions = tf.range(FRAME_LEN)
        rel_indices = positions[:, None] - positions[None, :] + FRAME_LEN - 1
        n_frames = tf.shape(q)[-2]
        rel_indices = rel_indices[:n_frames, :n_frames]
        rel_k = tf.nn.embedding_lookup(self.rel_embedding, rel_indices)
        rel_logits = tf.einsum('bhid,ijd->bhij', q, rel_k)

        content_logits = tf.matmul(q, k, transpose_b=True)
        attn = (content_logits + rel_logits) * self.scale

        if mask is not None:
            # Outer product of the mask with itself: 1 only where both the
            # query and the key position are valid.
            key_mask = tf.cast(mask[:, None, None,:],tf.float32)
            query_mask = tf.cast(mask[:, None, : , None],tf.float32)
            pair_mask = query_mask @ key_mask
            attn = attn + ((1-pair_mask) * -1e9)

        attn = tf.keras.layers.Softmax(axis=-1)(attn)
        attn = self.drop1(attn)

        # Weighted sum of values, then merge the heads back into `dim` channels.
        context = attn @ v
        context = tf.keras.layers.Permute((2, 1, 3))(context)
        context = tf.keras.layers.Reshape((-1, self.dim))(context)
        return self.proj(context)


def ConformerBlock(dim=256, num_heads=8, expand=2, attn_dropout=DROP_OUT, drop_rate=DROP_OUT, activation='swish',ksize=11):
    """Return a function that appends one conformer-style block to a Keras tensor.

    The block is: half-weighted feed-forward residual, self-attention residual,
    `ConformerConvModule(dim, ksize)`, half-weighted feed-forward residual.
    `ConformerConvModule` is not defined in this section of the notebook.

    Args:
        dim: channel width of the block.
        num_heads: number of attention heads.
        expand: multiplier for the feed-forward hidden width (`dim * expand`).
        attn_dropout: dropout rate inside the attention layer.
        drop_rate: dropout rate everywhere else.
        activation: activation of the first feed-forward Dense layer.
        ksize: kernel size handed to `ConformerConvModule`.
    """
    def feed_forward(tensor):
        # BatchNorm -> Dense(dim*expand) -> dropout -> Dense(dim) -> dropout
        hidden = tf.keras.layers.BatchNormalization()(tensor)
        hidden = tf.keras.layers.Dense(dim*expand, activation=activation)(hidden)
        hidden = tf.keras.layers.Dropout(drop_rate)(hidden)
        hidden = tf.keras.layers.Dense(dim, kernel_initializer="he_normal")(hidden)
        hidden = tf.keras.layers.Dropout(drop_rate)(hidden)
        return hidden

    def apply(inputs):
        # First feed-forward branch, added with weight 0.5.
        ff1 = feed_forward(inputs)
        mlp1_x = tf.keras.layers.Add()([inputs, 0.5*ff1])

        # Self-attention branch.
        attended = tf.keras.layers.BatchNormalization()(mlp1_x)
        attended = MultiHeadSelfAttention(dim=dim,num_heads=num_heads,dropout=attn_dropout)(attended)
        attended = tf.keras.layers.Dropout(drop_rate)(attended)
        attn_out = tf.keras.layers.Add()([mlp1_x, attended])

        # Convolution module.
        conv_out = ConformerConvModule(dim,ksize)(attn_out)

        # Second feed-forward branch, added with weight 0.5.
        ff2 = feed_forward(conv_out)
        return tf.keras.layers.Add()([conv_out, 0.5*ff2])

    return apply

# copy from https://www.kaggle.com/code/markwijkhuizen/aslfr-transformer-training-inference#Landmark-Embedding
class LandmarkEmbedding(tf.keras.layers.Layer):
    """Embed per-frame landmark features with a small MLP.

    The MLP is Dense((units + INPUT_DIM) // 2, swish) -> Dense(units) ->
    BatchNormalization.

    Args:
        units: output width of the embedding.
        name: prefix for the layer name (`{name}_embedding`) and for the inner
            Sequential (`{name}_dense`).
    """

    def __init__(self, units, name='emb'):
        super().__init__(name=f'{name}_embedding')
        self.supports_masking = True

        # Disabled: learnable embedding substituted for missing landmarks.
        # self.empty_embedding = self.add_weight(
        #     name=f'{name}_empty_embedding',
        #     shape=[units],
        #     initializer=tf.keras.initializers.constant(0.0),
        #     trainable=True,
        # )
        hidden_units = (units+INPUT_DIM)//2
        mlp_layers = [
            tf.keras.layers.Dense(hidden_units,activation='swish'),
            tf.keras.layers.Dense(units),
            tf.keras.layers.BatchNormalization(),
        ]
        self.dense = tf.keras.Sequential(mlp_layers, name=f'{name}_dense')

    def call(self, x):
        """Return the MLP embedding of `x`."""
        embedded = self.dense(x)
        return embedded
        # Disabled counterpart of the empty embedding above:
        # return tf.where(
        #         # Checks whether landmark is missing in frame
        #         tf.reduce_all(x==PADDING_MASKING_VALUE, axis=-1, keepdims=True),
        #         # If so, the empty embedding is used
        #         self.empty_embedding,
        #         # Otherwise the landmark data is embedded
        #         self.dense(x),
        #     )

In [None]:
def num_to_char_fn(y):
    """Map every id in `y` to its character via `num_to_char`.

    Ids that are not in the table are mapped to the empty string.
    """
    chars = []
    for token_id in y:
        chars.append(num_to_char.get(token_id, ""))
    return chars

# @tf.function()
# def decode_phrase(pred):
#     x = tf.argmax(pred, axis=1)
#     diff = tf.not_equal(x[:-1], x[1:])
#     adjacent_indices = tf.where(diff)[:, 0]
    
#     # Adding the index of the last token
#     adjacent_indices = tf.concat([adjacent_indices, [tf.size(x) - 1]], 0)
    
#     x = tf.gather(x, adjacent_indices)
#     mask = x != pad_token_idx
#     x = tf.boolean_mask(x, mask, axis=0)
#     return x

space_token_idx = char_to_num[' ']
@tf.function()
def decode_phrase(pred):
    """Greedy decode of per-frame class scores into a sequence of token ids.

    Takes the argmax over axis 1 of `pred`, keeps one id per run of identical
    consecutive ids, and removes `pad_token_idx` (the blank). If fewer than
    four ids remain, the ids of the characters ' -aero' are appended.
    """
    frame_ids = tf.argmax(pred, axis=1)

    # Position i closes a run when frame i and frame i + 1 disagree.
    run_ends = tf.where(tf.not_equal(frame_ids[:-1], frame_ids[1:]))[:, 0]
    # The very last frame always closes the final run.
    run_ends = tf.concat([run_ends, [tf.size(frame_ids) - 1]], 0)

    tokens = tf.gather(frame_ids, run_ends)
    tokens = tf.boolean_mask(tokens, tokens != pad_token_idx, axis=0)
#     no_space_x = tf.boolean_mask(x, x != space_token_idx, axis=0)
#     if tf.shape(no_space_x)[0] < 4:
#         x = tf.concat([no_space_x,tf.convert_to_tensor([char_to_num[c] for c in ' -aero'],dtype=tf.int64)],axis=0)
    # Very short outputs are extended with a fixed filler string.
    if tf.shape(tokens)[0] < 4:
        tokens = tf.concat([tokens,tf.convert_to_tensor([char_to_num[c] for c in ' -aero'],dtype=tf.int64)],axis=0)
    return tokens

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    """Decode each element of `pred` with `decode_phrase` and return the strings."""
    return [
        "".join(num_to_char_fn(decode_phrase(sample).numpy()))
        for sample in pred
    ]

# A callback class to output a few transcriptions during training
class CallbackEval(tf.keras.callbacks.Callback):
    """Every fifth epoch, score the model on two datasets with an edit-distance metric.

    The metric is `(n - d) / n`, where `n` is the total number of label
    characters and `d` the total Levenshtein distance between labels and
    decoded predictions, over `1000 // BATCH_SIZE` batches. Scores are
    appended to `val_score` (for `dataset`) and `train_score` (for
    `aux_dataset`).
    """

    def __init__(self, dataset,aux_dataset,model):
        super().__init__()
        self.dataset = dataset
        self.aux_dataset = aux_dataset
        self.model = model
        self.train_score = []
        self.val_score = []

    def _normalized_score(self, dataset):
        """Return `(n - d) / n` for `dataset`, or NaN when no label characters were seen."""
        total_distance = 0
        total_chars = 0
        for X, y in dataset.take(1000//BATCH_SIZE):
            batch_predictions = decode_batch_predictions(self.model([X]))
            for label, prediction in zip(y, batch_predictions):
                label = "".join(num_to_char_fn(label.numpy())).replace(pad_token,'')
                total_chars += len(label)
                total_distance += lev.distance(label, prediction)
        # BATCH_SIZE > 1000 yields zero batches; avoid dividing by zero then.
        if total_chars == 0:
            return float('nan')
        return (total_chars - total_distance) / total_chars

    def on_epoch_end(self, epoch: int, logs=None):
        """Print and record both scores when `epoch` is a multiple of 5."""
        if epoch % 5 !=0:
            return
        start_time = time.time()

        val_metric = self._normalized_score(self.dataset)
        print('val metric: ', val_metric)
        self.val_score.append(val_metric)

        train_metric = self._normalized_score(self.aux_dataset)
        print('train metric: ', train_metric)
        self.train_score.append(train_metric)

        print('eval time: ', time.time() - start_time)

In [None]:
class RestoreBestWeightsIfIncrease(tf.keras.callbacks.Callback):
    """Roll model and optimizer back to the best epoch when the loss jumps.

    The callback tracks the lowest value of the `"loss"` entry in `logs`
    (the training loss). When the loss has not improved for `patience` epochs,
    the epoch index is above `min_epoch`, and the current loss exceeds the
    best one by more than `restore_threshold`, the weights of the model and of
    `optimizer` recorded at the best epoch are restored.

    Args:
        optimizer: optimizer whose weights are saved and restored together
            with the model's.
        patience: epochs without improvement before a restore is considered.
        restore_threshold: minimum loss increase over the best loss that
            triggers a restore.
        min_epoch: no restore happens at or before this epoch index.
    """

    def __init__(self,optimizer,patience=1, restore_threshold=0.1,min_epoch=0):
        super(RestoreBestWeightsIfIncrease, self).__init__()
        self.optimizer = optimizer
        self.best_optimizer_weights = None
        self.patience = patience
        self.restore_threshold = restore_threshold
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None
        # best_loss Will keep track of the lowest loss so far.
        # (np.inf rather than np.Inf: the capitalised alias no longer exists
        # in NumPy 2.)
        self.best_loss = np.inf
        # wait Will keep track of the number of epochs the training has waited when loss is no longer minimum.
        self.wait = 0
        self.min_epoch = min_epoch

    def on_epoch_end(self, epoch, logs=None):
        """Record a new best, or restore the best weights after a large increase."""
        current_loss = (logs or {}).get("loss")
        if current_loss is None:
            # Nothing to compare against this epoch.
            return
        if np.less(current_loss, self.best_loss):
            self.best_loss = current_loss
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
            self.best_optimizer_weights = self.optimizer.get_weights()

            print('record best loss = ',self.best_loss)
        else:
            self.wait += 1
            if self.wait >= self.patience and epoch > self.min_epoch:
                if current_loss - self.best_loss > self.restore_threshold:
                    self.model.set_weights(self.best_weights)
                    self.optimizer.set_weights(self.best_optimizer_weights)
                    print("\nRestoring model weights from the end of the best epoch.")
                    self.wait = 0


In [None]:
class SAM(tf.keras.Model):
    """Keras model whose train step performs a sharpness-aware (two-pass) update.

    Each step computes the gradient, moves the weights by
    `rho * grad / ||grad||`, computes the gradient again at the moved weights,
    moves the weights back, and applies the second gradient with the
    optimizer.

    Args:
        rho: size of the weight perturbation.
    """

    def __init__(self, *args, rho=.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.rho = rho

    def _grad_norm(self, gradients):
        """Return the global L2 norm over all gradients that are not None."""
        norm = tf.norm(
            tf.stack([
                tf.norm(grad) for grad in gradients if grad is not None
            ])
        )
        return norm

    def train_step_sam(self, data):
        """Run one two-pass training step on `data` and return the metric results."""
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        e_ws = []

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)

        trainable_params = self.trainable_variables
        gradients = tape.gradient(loss, trainable_params)
        grad_norm = self._grad_norm(gradients)
        scale = self.rho / (grad_norm + 1e-12)

        # First pass: move each weight along its gradient. Variables without a
        # gradient are left alone (consistent with _grad_norm, which skips
        # them); a None placeholder keeps e_ws aligned with trainable_params.
        for (grad, param) in zip(gradients, trainable_params):
            if grad is None:
                e_ws.append(None)
                continue
            e_w = grad * scale
            param.assign_add(e_w)
            e_ws.append(e_w)

        # Second pass: gradient at the perturbed weights.
        with tf.GradientTape() as tape2:
            y_pred = self(x, training=True)
            new_loss = self.compiled_loss(y, y_pred)

        sam_gradients = tape2.gradient(new_loss, trainable_params)

        # Undo the perturbation before the optimizer update.
        for (param, e_w) in zip(trainable_params, e_ws):
            if e_w is not None:
                param.assign_sub(e_w)

        self.optimizer.apply_gradients(
            zip(sam_gradients, trainable_params))

        # Metrics use the predictions from the second (perturbed) pass.
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


    def train_step(self, data):
        """Keras entry point; delegates to `train_step_sam`."""
        return self.train_step_sam(data)

In [None]:
class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    """Save the model weights periodically once training is far enough along.

    Args:
        model: model whose weights are saved.
        start_epoch: first (0-based) epoch index at which saving is allowed.
        period: weights are saved when the 1-based epoch number is a multiple
            of this value.
    """

    def __init__(self,model, start_epoch=350, period=10):
        super().__init__()
        self.model = model
        self.start_epoch = start_epoch
        self.period = period

    def on_epoch_end(self, epoch, logs=None):
        """Write `model_weights_<epoch + 1>.h5` when the epoch qualifies."""
        # With the defaults this saves every 10th epoch from epoch 360 on.
        if epoch >= self.start_epoch and (epoch + 1) % self.period == 0:
            self.model.save_weights(f'model_weights_{epoch + 1:02d}.h5')

In [None]:
def positional_encoding(maxlen, num_hid):
    """Return a sinusoidal position table with `maxlen` rows.

    The first `num_hid / 2` columns hold sines and the remaining columns hold
    cosines of `position / 10000 ** (i / (num_hid / 2))`.
    """
    half_width = num_hid/2
    row_positions = tf.range(maxlen, dtype = tf.float32)[..., tf.newaxis]
    normalized_depths = tf.range(half_width, dtype = tf.float32)[np.newaxis, :]/half_width
    base = tf.cast(10000, tf.float32)
    inverse_frequencies = tf.math.divide(1, tf.math.pow(base, normalized_depths))
    angles = tf.linalg.matmul(row_positions, inverse_frequencies)
    return tf.concat([tf.math.sin(angles), tf.math.cos(angles)], axis=-1)

In [None]:
# Steps per epoch: floor(number of training examples / batch size) + 1.
steps_per_epoch = int((64211 if VAL else 67208)//BATCH_SIZE) + 1

# Data-related settings.
data_cfg = {
    'FRAME_LEN': FRAME_LEN,
    'INPUT_DIM': INPUT_DIM,
    'batch_size': BATCH_SIZE,
}

# Architecture settings.
model_cfg = {
    "num_layers_enc": 6,
    "encoder_dim": 384,
    "num_head": 8,
    'kernel_size': 27,  # original note: should go back to 13
    "decoder_dim": 384,  # previously 256
    "source_maxlen": TEMPORAL_DIM,
    "target_maxlen": PHRASE_MAX_LEN,
    "num_classes": CLASS_NUM,
}


def get_model():
    """Build the network: embedding, Squeezeformer blocks, GRU, per-frame class scores.

    Input shape is `(TEMPORAL_DIM, INPUT_DIM)` per example; sizes come from
    `model_cfg`. A `Masking` layer is inserted first only when `MASKING` is set.
    """
    frames = tf.keras.layers.Input(shape=(TEMPORAL_DIM,INPUT_DIM))
    if MASKING:
        hidden = tf.keras.layers.Masking(mask_value=PADDING_MASKING_VALUE)(frames)
    else:
        hidden = frames

    hidden = LandmarkEmbedding(model_cfg["encoder_dim"])(hidden)

    for _ in range(model_cfg['num_layers_enc']):
        encoder_block = SqueezeformerBlock(
            dim=model_cfg["encoder_dim"],
            num_heads=model_cfg["num_head"]
        )
        hidden = encoder_block(hidden)

    hidden = tf.keras.layers.GRU(model_cfg['decoder_dim'], return_sequences=True)(hidden)
    scores = tf.keras.layers.Dense(model_cfg['num_classes'])(hidden)
    return tf.keras.Model(inputs=frames, outputs=scores)


def aux_ctc_loss(y_true, y_pred):
    """Mean CTC loss over the per-replica batch, with `pad_token_idx` as blank.

    Both tensors are pinned to static shapes derived from `BATCH_SIZE` and the
    number of replicas. Label lengths count the non-pad entries of each row;
    every example uses the full time axis of `y_pred` as its logit length.
    """
    per_replica_batch = BATCH_SIZE//strategy.num_replicas_in_sync
    y_true = tf.ensure_shape(y_true, [per_replica_batch,PHRASE_MAX_LEN])
    y_pred = tf.ensure_shape(y_pred, [per_replica_batch,TEMPORAL_DIM,CLASS_NUM])

    is_real_token = tf.cast(y_true != pad_token_idx, tf.int32)
    label_length = tf.reduce_sum(is_real_token, axis=-1)
    logit_length = tf.ones(tf.shape(y_pred)[0], dtype=tf.int32) * tf.shape(y_pred)[1]

    per_example_loss = classic_ctc_loss(
        labels=y_true,
        logits=y_pred,
        label_length=label_length,
        logit_length=logit_length,
        blank_index=pad_token_idx,
    )
    return tf.reduce_mean(per_example_loss)

# Optimisation settings.
lr_cfg = {
    'lr': 1e-3,
    'weight_decay': 1e-3,  # previously 1e-6
    'epochs': 500,
    'optimizer': tfa.optimizers.RectifiedAdam,
    'alpha': 0.05,  # final lr = lr * alpha
    'finetune_epochs': 0,
    'warmup_epochs': 0,  # alternative: 0.1*lr_cfg['epochs']
    'scheduler': OneCycleLR,
}
initial_learning_rate = lr_cfg['lr']

# Learning-rate schedule: starts at the peak rate (no warm-up epochs) and
# decays towards lr * alpha over the configured epochs.
lr_schedule = lr_cfg['scheduler'](
    lr=initial_learning_rate,
    epochs=lr_cfg['epochs'] + lr_cfg['finetune_epochs'],
    steps_per_epoch=steps_per_epoch,
    steps_per_update=1,
    resume_epoch=0,
    decay_epochs=lr_cfg['epochs'] - lr_cfg['warmup_epochs'],
    sustain_epochs=0,
    warmup_epochs=lr_cfg['warmup_epochs'],
    finetune_steps=0,
    finetune_lr=0.05 * initial_learning_rate,
    lr_start=initial_learning_rate*0.01,
    lr_min=initial_learning_rate*lr_cfg['alpha'],
)


# Training driver: everything below runs inside the distribution strategy's
# scope. It builds the optimizer, model and callbacks, and trains unless LOAD
# is set.
with strategy.scope():
    # Pull one batch only to print the input/target shapes.
    if VAL:
        batch = next(iter(val_dataset))
    else:
        batch = next(iter(train_dataset))
    print("raw input shape", batch[0].shape)
    print("raw target shape", batch[1].shape)


    # Keeps only the best weights in 'best.h5', judged on validation loss when
    # a validation split exists and on training loss otherwise.
    model_cb = tf.keras.callbacks.ModelCheckpoint(
        'best.h5',
        monitor = 'val_loss' if VAL else 'loss',
        verbose = 1,
        save_best_only = True,
        save_weights_only= True,
        mode = 'auto',
        save_freq='epoch',
        options=None,
        initial_value_threshold=None,
    )

    # Base optimizer from lr_cfg, driven by lr_schedule, wrapped in Lookahead.
    optimizer = lr_cfg['optimizer'](lr_schedule, weight_decay = lr_cfg['weight_decay'],) # clipnorm = 5.0)
    optimizer = tfa.optimizers.Lookahead(optimizer, sync_period=5)
    model = get_model()
#     model = SAM(model.input,model.output)
    # steps_per_execution equals steps_per_epoch, so a whole epoch's worth of
    # batches is run per execution call.
    model.compile(optimizer=optimizer, loss=aux_ctc_loss,steps_per_execution=steps_per_epoch)
    model.summary()
    model.build((None,TEMPORAL_DIM,INPUT_DIM))

    # The restore callback receives the same (Lookahead-wrapped) optimizer the
    # model was compiled with.
    restore_cb = RestoreBestWeightsIfIncrease(optimizer,patience=1, restore_threshold=1.0,min_epoch=10)
    epochwise_checkpoint = CustomModelCheckpoint(model)
    callbacks = [lr_logger,model_cb, restore_cb,epochwise_checkpoint]

    # NOTE(review): when LOAD is truthy this cell neither trains nor loads
    # weights - confirm the weights are loaded elsewhere before export.
    if not LOAD:
        history = model.fit(
            train_dataset,
            validation_data = val_dataset if VAL else None,
            callbacks=callbacks,
            epochs=lr_cfg['epochs'],
            steps_per_epoch=steps_per_epoch,
            # -(1000//-BATCH_SIZE) is ceil(1000 / BATCH_SIZE).
            validation_steps=0 if not VAL else -(1000//-BATCH_SIZE),
        )
        model.save_weights('model.h5')

# SAM -> 0.767 (100 -> 0.754)
# No SAM 50/0.71

In [None]:
class TFLiteModel(tf.Module):
    """Wrap preprocessing, the Keras model and decoding into one exportable call.

    Args:
        model: trained model, called on a single batched example.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, len(SEL_COLS)], dtype=tf.float32, name="inputs")
        ]
    )
    def __call__(self, inputs, training=False):
        """Turn a `[frames, len(SEL_COLS)]` float array into one-hot predicted tokens.

        Returns a dict with key "outputs" holding a one-hot tensor of depth 59.
        """
        frames = tf.cast(inputs, tf.float32)

        # An input with zero frames is replaced by a single all-zero frame.
        expanded = frames[None]
        non_empty = tf.cond(
            tf.shape(expanded)[1] == 0,
            lambda: tf.zeros((1, 1, len(SEL_COLS))),
            lambda: tf.identity(expanded),
        )

        # Preprocess the unbatched frames, then run the model on a batch of one.
        features = preprocess2(preprocess1(non_empty[0]))
        frame_scores = self.model(features[None])[0]

        tokens = decode_phrase(frame_scores)
        # Guard against an empty decode by emitting a single token 0.
        tokens = tf.cond(
            tf.shape(tokens)[0] == 0,
            lambda: tf.zeros(1, tf.int64),
            lambda: tf.identity(tokens),
        )
        return {"outputs": tf.one_hot(tokens, 59)}

tflitemodel_base = TFLiteModel(model)

In [None]:
# Convert the wrapped model to TFLite and write the two submission artifacts.
keras_model_converter = tf.lite.TFLiteConverter.from_keras_model(tflitemodel_base)

# Restrict the converted graph to built-in TFLite ops.
keras_model_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

# Current setting: default optimizations with float16 as a supported type.
# (An earlier variant left `optimizations` unset for speed.)
keras_model_converter.optimizations = [tf.lite.Optimize.DEFAULT]
keras_model_converter.target_spec.supported_types = [tf.float16]

tflite_model = keras_model_converter.convert()
with open("model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

# Column names the inference side should feed to the model.
infargs = {"selected_columns": SEL_COLS}

with open("inference_args.json", "w") as args_file:
    args_file.write(json.dumps(infargs))


In [None]:
import zipfile

# Bundle the converted model and its inference arguments into one archive.
file_names = ["model.tflite", "inference_args.json"]
zip_name = "submission.zip"

with zipfile.ZipFile(zip_name, 'w') as archive:
    for member_path in file_names:
        archive.write(member_path)

print(f'{zip_name} file is created.')