# Pip installs

In [25]:
!pip install tensorflow-addons
!pip install cached_property
!pip intstall tensorflow==13.0

[0mERROR: unknown command "intstall" - maybe you meant "install"


# For Colab

In [26]:
import os

# NOTE(review): assumes the COLAB_GPU environment variable is only present on Colab runtimes.
IN_COLAB = 'COLAB_GPU' in os.environ
if IN_COLAB:
    # Both helpers must be imported explicitly before use.
    from google.colab import auth, drive
    auth.authenticate_user()
    drive.mount('/content/drive')

In [27]:
# Make sure the Kaggle working directory exists; `-p` creates missing parents and is a no-op if it already exists.
!mkdir -p /kaggle/working

# Imports

In [28]:

import os
import numpy as np
import pandas as pd

import random
import psutil
import gc
import math
import tensorflow as tf
import tensorflow_addons as tfa
# Let TensorFlow grow GPU memory on demand instead of reserving all of it up front.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be configured before the GPUs are initialized, so re-running
        # this cell in a live kernel raises; whatever was configured at initialization stays in effect.
        print(f"Could not set memory growth on {gpu.name}: {err}")



# CTC Loss

In [29]:
from __future__ import annotations
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 18 20:29:39 2023
"""

# Copyright 2021 Alexey Tochin
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Union, Callable, List, Optional, Type
import tensorflow as tf
import numpy as np

from abc import ABC, abstractmethod
from cached_property import cached_property


# Float32 +infinity; compared against the loss below and negated as a log-space "zero" filler.
inf = tf.constant(np.inf, dtype=tf.float32)


def logit_to_logproba(logit: tf.Tensor, axis: int) -> tf.Tensor:
    """Normalizes logits into logarithmic probabilities along one axis:

        logit_to_logproba(x) = x - log(sum along axis (exp(x)))

    Args:
        logit:  tf.Tensor, dtype = tf.float32
        axis:   integer, interpreted as for tf.reduce_logsumexp

    Returns:    tf.Tensor with the same shape as logit
    """
    log_normalizer = tf.reduce_logsumexp(input_tensor=logit, axis=axis, keepdims=True)
    return logit - log_normalizer


def apply_logarithmic_mask(tensor: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
    """Masks a tensor that is stored in logarithmic form.

    The logarithm of the 0/1 mask is added to tensor. Where mask is True, the value is unchanged
    (log 1 = 0). Where mask is False, the value becomes -inf (log 0). A +inf entry under a False
    mask therefore becomes NaN (inf + -inf).

    Args:
        tensor: tf.Tensor, dtype = tf.float32, same shape as mask or broadcastable to it
        mask:   tf.Tensor, dtype = tf.bool, same shape as tensor or broadcastable to it

    Returns:    tf.Tensor, dtype = tf.float32, with the broadcast shape of tensor and mask
    """
    log_mask = tf.math.log(tf.cast(mask, dtype=tf.float32))
    return tensor + log_mask


def logsumexp(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """Elementwise, numerically stable log(e ** x + e ** y).

    The smaller argument is folded into the larger one through softplus, so exp never overflows.
    Equal arguments take the x + log(2) branch, which also keeps x == y == -inf at -inf.

    Args:
        x:      tf.Tensor, same shape as y or broadcastable to it
        y:      tf.Tensor, same shape as x or broadcastable to it

    Returns:    tf.Tensor with the broadcast shape of x and y
    """
    when_y_larger = y + tf.math.softplus(x - y)
    when_x_larger = x + tf.math.softplus(y - x)
    when_equal = x + np.log(2.)
    return tf.where(x < y, when_y_larger, tf.where(x > y, when_x_larger, when_equal))


def subexp(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """Elementwise, numerically stable exp(x) - exp(y).

    The larger exponent is factored out, and the remaining difference is computed with expm1.
    Equal inputs give exactly 0.

    Args:
        x:      tf.Tensor, shape broadcastable to y
        y:      tf.Tensor, shape broadcastable to x

    Returns:    tf.Tensor with the broadcast shape of x and y
    """
    when_x_larger = -tf.exp(x) * tf.math.expm1(y - x)
    when_y_larger = tf.exp(y) * tf.math.expm1(x - y)
    return tf.where(x > y, when_x_larger, tf.where(x < y, when_y_larger, tf.zeros_like(x)))


def unsorted_segment_logsumexp(data: tf.Tensor, segment_ids: tf.Tensor, num_segments: Union[int, tf.Tensor])\
        -> tf.Tensor:
    """Computes the logarithmic sum of exponents along segments of a tensor,
    like the other operators of the tf.math.unsorted_segment_* family.

    Edge cases:
    - A segment whose entries are all -inf yields -inf, not NaN.
    - A segment that receives no entries also yields -inf.
    - A segment containing +inf yields +inf.

    Args:
        data:           tf.Tensor,  shape = [...] + data_dims,
        segment_ids:    tf.Tensor,  shape = [...], dtype = tf.int32
        num_segments:   tf.Tensor,  shape = [], dtype = tf.int32

    Returns:            tf.Tensor,  shape = [num_segments] + data_dims, of the same type as data
    """
    data_max = tf.math.unsorted_segment_max(data=data, segment_ids=segment_ids, num_segments=num_segments)
    # Shift by the per-segment maximum for numerical stability, but use 0 when that maximum is not
    # finite: for an all -inf segment, `data - max` would be -inf - (-inf) = NaN.
    # Empty segments get the dtype's lowest finite value as their max, which is harmless here.
    shift = tf.where(tf.math.is_finite(data_max), data_max, tf.zeros_like(data_max))
    data_normed = data - tf.gather(params=shift, indices=segment_ids)
    output = shift + tf.math.log(tf.math.unsorted_segment_sum(
        data=tf.exp(data_normed),
        segment_ids=segment_ids,
        num_segments=num_segments,
    ))
    return output


def pad_until(
        tensor: tf.Tensor,
        desired_size: Union[tf.Tensor, int],
        axis: int,
        pad_value: Union[tf.Tensor, int, float, bool] = 0
) -> tf.Tensor:
    """Pads tensor on the right along one axis until that axis has desired_size entries.

    Args:
        tensor:         tf.Tensor, of any shape and type
        desired_size:   tf.Tensor or pythonic static integer; must not be smaller than the current
                        size along axis, since that would require a negative padding
        axis:           pythonic static integer for the padded axis; negative values count from
                        the last dimension, as in Python indexing
        pad_value:      tf.Tensor or pythonic numerical for padding

    Returns:            tf.Tensor, the same shape as tensor except along axis, whose size equals desired_size.

    Raises:
        ValueError:     if axis is outside [-rank, rank).
    """
    rank = len(tensor.shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for a tensor of rank {rank}")
    # Normalize negative axes; otherwise the paddings list below would have rank + 1 entries.
    axis = axis % rank

    current_size = tf.shape(tensor)[axis]
    paddings = [[0, 0]] * axis + [[0, desired_size - current_size]] + [[0, 0]] * (rank - axis - 1)
    return tf.pad(tensor=tensor, paddings=paddings, constant_values=pad_value)


def insert_zeros(tensor: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
    """Inserts a zero in front of every element of tensor whose mask entry is True.

    For example:
    ```python
        output = insert_zeros(
            tensor =  tf.constant([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]], dtype = tf.int32),
            mask = tf.constant([[False, True, False, False, True], [False, True,  True, True,  False]]),
        )
        # -> [[1, 0, 2, 3, 4, 0, 5, 0], [10, 0, 20, 0, 30, 0, 40, 50]]
        # A 0 is inserted before 2, 5, 20, 30 and 40 because their positions are True in mask;
        # rows with fewer insertions are right-padded with 0s.
    ```

    Args:
        tensor: tf.Tensor, shape = [batch, length], any type, the same shape as mask
        mask:   tf.Tensor, shape = [batch, length], dtype = tf.bool, the same shape as tensor

    Returns:    tf.Tensor, shape = [batch, length + max_num_insertions], where max_num_insertions is
                the largest number of True values in any single row of mask. dtype = same as tensor.
    """
    num_rows = tf.shape(tensor)[0]
    num_cols = tf.shape(mask)[1]

    # shift[b, j] counts the True values in mask[b, :j + 1], i.e. how far element j moves right.
    shift = tf.cumsum(tf.cast(mask, dtype=tf.int32), axis=1)
    output_width = num_cols + tf.reduce_max(shift[:, -1])

    col_ids, row_ids = tf.meshgrid(tf.range(num_cols), tf.range(num_rows))
    scatter_positions = tf.stack([row_ids, col_ids + shift], axis=2)

    return tf.scatter_nd(
        indices=tf.reshape(scatter_positions, shape=[-1, 2]),
        updates=tf.reshape(tensor, shape=[-1]),
        shape=tf.stack([num_rows, output_width]),
    )


def unfold(
        init_tensor: tf.Tensor,
        iterfunc: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
        num_iters: Union[int, tf.Tensor],
        d_i: int,
        element_shape: tf.TensorShape,
        swap_memory: bool = False,
        name: str = "unfold",
) -> tf.Tensor:
    """Stacks every intermediate value of an iterated function into one tensor.

    For d_i = +1 the output rows are
        init_tensor
        iterfunc(init_tensor, 0)
        iterfunc(iterfunc(init_tensor, 0), 1)
        ...
        iterfunc(..., num_iters - 1)
    For d_i = -1 the iteration runs from i = num_iters - 1 down to 0, and the rows are
        iterfunc(..., 0)
        ...
        iterfunc(iterfunc(init_tensor, num_iters - 1), num_iters - 2)
        iterfunc(init_tensor, num_iters - 1)
        init_tensor
    For example:
    ```python
        unfold(
            init_tensor=tf.constant(0),
            iterfunc=lambda x, i: x + i,
            num_iters=5,
            d_i=1,
            element_shape=tf.TensorShape([]),
        )
        # -> [0, 0, 1, 3, 6, 10]
    ```

    Args:
        init_tensor:    tf.Tensor of any shape, the initial value of the iterations
        iterfunc:       (tf.Tensor, tf.Tensor) -> tf.Tensor, the iteration function that maps a value
                            and the iteration index to a new value of the same shape
        num_iters:      tf.Tensor or static integer, the number of iterations
        d_i:            either +1 or -1:
                            +1 fills the rows from index 0 up to num_iters
                            -1 fills the rows from index num_iters down to 0
        element_shape:  tf.TensorShape, the shape of init_tensor
        swap_memory:    passed to tf.while_loop
        name:           str, local tensor name scope

    Returns:            tf.Tensor, shape = [num_iters + 1] + init_tensor.shape,
                        dtype the same as init_tensor
    """
    assert d_i in {-1, 1}
    forward = d_i == 1

    with tf.name_scope(name):
        num_iters = tf.convert_to_tensor(num_iters)

        # init_tensor occupies the first row going forward and the last row going backward.
        init_slot = 0 if forward else num_iters
        stacked = tf.TensorArray(
            dtype=init_tensor.dtype,
            size=num_iters + 1,
            element_shape=element_shape,
            clear_after_read=False,
            infer_shape=True,
            dynamic_size=False,
        ).write(init_slot, init_tensor)

        def step(i, array):
            # Going forward, row i feeds row i + 1; going backward, row i + 1 feeds row i.
            src, dst = (i, i + 1) if forward else (i + 1, i)
            array = array.write(dst, iterfunc(array.read(src), i))
            return i + d_i, array

        first_i = tf.constant(0, dtype=tf.int32) if forward else num_iters - 1
        _, filled = tf.while_loop(
            cond=lambda i, _: tf.constant(True),
            body=step,
            loop_vars=(first_i, stacked),
            maximum_iterations=num_iters,
            swap_memory=swap_memory,
            name="unfold_while_loop",
        )
        return filled.stack()


def reduce_max_with_default(input_tensor: tf.Tensor, default: tf.Tensor) -> tf.Tensor:
    """Maximum over all elements of input_tensor, or default when input_tensor has no elements.

    Only the axis=None case (full reduction to a scalar) is supported.

    Args:
        input_tensor:   tf.Tensor, of any shape and numerical type
        default:        tf.Tensor, shape = [], dtype the same as input_tensor

    Returns:            tf.Tensor, shape = [], dtype the same as input_tensor
    """
    is_non_empty = tf.size(input_tensor) > 0
    return tf.where(is_non_empty, tf.reduce_max(input_tensor), default)


def expand_many_dims(input: tf.Tensor, axes: List[int]) -> tf.Tensor:
    """Inserts several size-1 dimensions, like repeated calls of tf.expand_dims.

    The axes are applied one after another in the given order. Each index therefore refers to the
    rank reached after the previous insertions.

    For example:
        expand_many_dims(tf.zeros(shape=[5, 1, 3]), axes=[0, 4, 5]).shape
        # -> [1, 5, 1, 3, 1, 1]

    Args:
        input:  tf.Tensor of any shape and type
        axes:   list of integers, the positions of the new dimensions

    Returns:    tf.Tensor of the same type as input, with rank = rank(input) + len(axes)
    """
    expanded = input
    for new_axis in axes:
        expanded = tf.expand_dims(input=expanded, axis=new_axis)
    return expanded


def smart_transpose(a: tf.Tensor, perm: List[int]) -> tf.Tensor:
    """Extension of tf.transpose where perm may be shorter than the rank of a.

    Dimensions beyond len(perm) keep their positions.

    For example:
        smart_transpose(tf.zeros(shape=[2, 3, 4, 5, 6]), [2, 1, 0]).shape
        # -> [4, 3, 2, 5, 6]

    Args:
        a:      tf.Tensor of any rank, shape and type
        perm:   list (or tuple) of integers like for tf.transpose; may be shorter than the rank of a

    Returns:    tf.Tensor of the same type and rank as a.

    Raises:
        ValueError: if perm is longer than the rank of a.
    """
    # Accept tuples too; `tuple + list` would otherwise raise TypeError below.
    perm = list(perm)
    rank = len(a.shape)
    if len(perm) > rank:
        raise ValueError(f"Tensor with shape '{a.shape}' cannot be transposed with perm '{perm}'")

    return tf.transpose(a=a, perm=perm + list(range(len(perm), rank)))


def smart_reshape(tensor: tf.Tensor, shape: List[Optional[Union[int, tf.Tensor]]]) -> tf.Tensor:
    """Rank-preserving variant of tf.reshape.

    1. The output tensor always has the same rank as the input tensor.
    2. shape is a list no longer than the rank of tensor; it is extended with None up to that rank.
    3. A None entry means "keep this dimension unchanged".

    For example:
    ```python
        smart_reshape(
            tensor=tf.zeros(shape=[2, 3, 4, 5]),
            shape=[8, None, 1]
        )
        # -> tf.Tensor([8, 3, 1, 5])
    ```

    Args:
        tensor: tf.Tensor of any shape and type
        shape:  list of optional static or dynamic integers

    Returns:    tf.Tensor of the same type and rank as the input tensor

    Raises:
        ValueError: if shape is longer than the rank of tensor.
    """
    rank = len(tensor.shape)
    if len(shape) > rank:
        raise ValueError(f"Tensor with shape {tensor.shape} cannot be reshaped to {shape}.")

    dynamic_dims = tf.shape(tensor)
    full_shape = shape + [None] * (rank - len(shape))
    target_shape = [dynamic_dims[i] if dim is None else dim for i, dim in enumerate(full_shape)]
    return tf.reshape(tensor=tensor, shape=target_shape)



def ctc_loss(
        labels: tf.Tensor,
        logits: tf.Tensor,
        label_length: tf.Tensor,
        logit_length: tf.Tensor,
        blank_index: Union[int, tf.Tensor],
        ctc_loss_data_cls: Type[BaseCtcLossData],
) -> tf.Tensor:
    """Computes a version of CTC loss from raw logits, following
    http://www.cs.toronto.edu/~graves/icml_2006.pdf.

    The logits are first normalized into logarithmic probabilities along the token axis (axis 2).
    The result is then passed to ctc_loss_from_logproba.

    Args:
        labels:             tf.Tensor, shape = [batch, max_label_length],       dtype = tf.int32
        logits:             tf.Tensor, shape = [batch, max_length, num_tokens], dtype = tf.float32
        label_length:       tf.Tensor, shape = [batch],                         dtype = tf.int32
        logit_length:       tf.Tensor, shape = [batch],                         dtype = tf.int32
        blank_index:        static integer >= 0
        ctc_loss_data_cls:  BaseCtcLossData class

    Returns:                tf.Tensor, the samplewise loss returned by ctc_loss_from_logproba
                            (shape = [batch] for BaseCtcLossData subclasses), dtype = tf.float32
    """
    return ctc_loss_from_logproba(
        labels=labels,
        logprobas=logit_to_logproba(logit=logits, axis=2),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=blank_index,
        ctc_loss_data_cls=ctc_loss_data_cls,
    )


def ctc_loss_from_logproba(
        labels: tf.Tensor,
        logprobas: tf.Tensor,
        label_length: tf.Tensor,
        logit_length: tf.Tensor,
        blank_index: Union[int, tf.Tensor],
        ctc_loss_data_cls: Type[BaseCtcLossData],
) -> tf.Tensor:
    """Computes a version of CTC loss from logarithmic probabilities treated as independent parameters.

    The loss data object is built from a gradient-stopped copy of logprobas. Derivatives w.r.t.
    logprobas therefore come only from the custom gradient attached by its forward_fn.

    Args:
        labels:             tf.Tensor, shape = [batch, max_label_length],       dtype = tf.int32
        logprobas:          tf.Tensor, shape = [batch, max_length, num_tokens], dtype = tf.float32
        label_length:       tf.Tensor, shape = [batch],                         dtype = tf.int32
        logit_length:       tf.Tensor, shape = [batch],                         dtype = tf.int32
        blank_index:        static integer >= 0
        ctc_loss_data_cls:  BaseCtcLossData class

    Returns:                tf.Tensor, the samplewise loss from forward_fn
                            (shape = [batch] for BaseCtcLossData subclasses), dtype = tf.float32
    """
    return ctc_loss_data_cls(
        labels=labels,
        logprobas=tf.stop_gradient(logprobas),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=blank_index,
    ).forward_fn(logprobas)


class BaseCtcLossData(ABC):
    """Base class for CTC loss data.

    An instance holds the inputs of a single CTC loss evaluation. From them it derives the loss,
    its gradient and its Hessian w.r.t. the input logarithmic probabilities. Most derived tensors
    are cached properties, computed on first access and reused afterwards. Subclasses implement
    alpha, beta, gamma, loss and combine_transition_probabilities.
    """
    def __init__(
            self,
            labels: tf.Tensor,
            logprobas: tf.Tensor,
            label_length: tf.Tensor,
            logit_length: tf.Tensor,
            blank_index: Union[int, tf.Tensor],
            swap_memory: bool = False,
            **kwargs
    ):
        """Stores the inputs and checks their static ranks and batch dimensions.

        Args:
            labels:         rank-2 tf.Tensor, shape = [batch, label_width]
            logprobas:      rank-3 tf.Tensor, shape = [batch, max_logit_length, num_tokens], dtype = tf.float32
            label_length:   rank-1 tf.Tensor, shape = [batch]
            logit_length:   rank-1 tf.Tensor, shape = [batch]
            blank_index:    tf.Tensor / tf.Variable used as is, or a pythonic integer wrapped in an int32 constant
            swap_memory:    stored as self._swap_memory (presumably passed to tf.while_loop by subclasses)
            **kwargs:       forwarded to super().__init__
        """
        super().__init__(**kwargs)
        self._logprobas = logprobas
        self._original_label = labels
        self._logit_length = logit_length
        self._original_label_length = label_length
        # NOTE(review): this instance attribute presumably shadows the `max_label_length_plus_one`
        # cached_property defined below. The cached_property package's descriptor has no __set__, so
        # the instance __dict__ would win. The full label width is then used instead of
        # max(label_length) + 1. This is likely a deliberate change for JIT (see cleaned_label); confirm.
        self.max_label_length_plus_one = tf.shape(labels)[1]
        self._verify_inputs()

        if isinstance(blank_index, (tf.Tensor, tf.Variable)):
            self._blank_index = blank_index
        else:
            self._blank_index = tf.constant(blank_index, dtype=tf.int32)

        self._swap_memory = swap_memory

    def _verify_inputs(self) -> None:
        """Checks static ranks, the float32 dtype of logprobas, and that static batch dimensions agree."""
        assert len(self._logprobas.shape) == 3
        assert self._logprobas.dtype == tf.float32
        assert len(self._original_label.shape) == 2
        assert len(self._logit_length.shape) == 1
        assert len(self._original_label_length.shape) == 1

        assert self._logprobas.shape[0] == self._original_label.shape[0]
        assert self._logprobas.shape[0] == self._logit_length.shape[0]
        assert self._logprobas.shape[0] == self._original_label_length.shape[0]

    @tf.custom_gradient
    def forward_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor:
        """Returns self.loss and attaches a custom gradient w.r.t. unused_logprobas.

        The loss is computed from self._logprobas, not from the argument. The argument is only the
        tensor the custom gradient is attached to: ctc_loss_from_logproba builds the instance from
        tf.stop_gradient(logprobas) and passes the original logprobas here.
        """
        def backprop(d_loss):
            # d_loss: [batch_size] -> [batch_size, 1, 1], broadcast against the
            # [batch_size, max_logit_length, num_tokens] gradient.
            return expand_many_dims(d_loss, axes=[1, 2]) * self.gradient_fn(unused_logprobas)

        return self.loss, backprop

    @tf.custom_gradient
    def gradient_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor:
        """Returns self.gradient; its own custom gradient contracts the upstream gradient with self.hessian."""
        def backprop(d_gradient):
            # [batch, T, K] -> [batch, 1, 1, T, K], times the [batch, T, K, T, K] Hessian,
            # summed over the last two axes -> [batch, T, K].
            output = tf.reduce_sum(
                input_tensor=expand_many_dims(d_gradient, axes=[1, 2]) * self.hessian_fn(unused_logprobas),
                axis=[3, 4]
            )
            return output

        return self.gradient, backprop

    @tf.custom_gradient
    def hessian_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor:
        """Returns self.hessian; differentiating it further raises NotImplementedError."""
        def backprop(d_hessian):
            raise NotImplementedError(f"Third order derivative over the ctc loss function is not implemented.")

        return self.hessian, backprop

    @cached_property
    def hessian(self) -> tf.Tensor:
        """Calculates the Hessian of the loss w.r.t. the input logarithmic probabilities.

        Returns: tf.Tensor, shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]
        """
        alpha_gamma_term = self.combine_transition_probabilities(a=self.alpha[:, :-1], b=self.gamma[:, 1:])
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length + 1, max_label_length + 1]
        alpha_gamma_beta_term = \
            self.combine_transition_probabilities(a=alpha_gamma_term[:, :, :, :-1], b=self.beta[:, 1:])
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]
        alpha_gamma_beta_loss_term = expand_many_dims(self.loss, axes=[1, 2, 3, 4]) + alpha_gamma_beta_term
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]
        # (the [batch_size] loss is broadcast over the four trailing axes)
        logit_length_x_num_tokens = self.max_logit_length * self.num_tokens
        # View the 5-D tensor as a [batch, T*K, T*K] matrix, overwrite its diagonal (same t and same
        # token) with the logarithmic first-order gradient, then restore the 5-D shape.
        first_term = tf.reshape(
            tf.linalg.set_diag(
                input=tf.reshape(
                    tensor=alpha_gamma_beta_loss_term,
                    shape=[self.batch_size, logit_length_x_num_tokens, logit_length_x_num_tokens]
                ),
                diagonal=tf.reshape(
                    tensor=self.logarithmic_logproba_gradient,
                    shape=[self.batch_size, logit_length_x_num_tokens]
                )
            ),
            shape=tf.shape(alpha_gamma_beta_term),
        )

        # Upper-triangular (t <= t') boolean mask over the two time axes, shape = [1, T, 1, T, 1].
        # first_term is kept where t <= t' and mirrored from its transpose elsewhere.
        mask = expand_many_dims(
            input=tf.linalg.band_part(tf.ones(shape=[self.max_logit_length] * 2, dtype=tf.bool), 0, -1),
            axes=[0, 2, 4]
        )
        symmetrized_first_term = tf.where(
            condition=mask,
            x=first_term,
            y=tf.transpose(first_term, [0, 3, 4, 1, 2]),
        )
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]
        hessian = \
            -tf.exp(symmetrized_first_term) \
            + expand_many_dims(self.gradient, [3, 4]) * expand_many_dims(self.gradient, [1, 2])
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]

        # Filter out samples with infinite loss
        hessian = tf.where(
            condition=expand_many_dims(self.loss == inf, [1, 2, 3, 4]),
            x=tf.zeros(shape=[1, 1, 1, 1, 1]),
            y=hessian,
        )
        # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens]

        # Filter out logits beyond logit_length (first time axis, then second time axis)
        hessian = tf.where(
            condition=expand_many_dims(self.logit_length_mask, axes=[2, 3, 4]),
            x=hessian,
            y=0.
        )
        hessian = tf.where(
            condition=expand_many_dims(self.logit_length_mask, axes=[1, 2, 4]),
            x=hessian,
            y=0.
        )

        return hessian

    @cached_property
    def gradient(self) -> tf.Tensor:
        """Gradient of the loss w.r.t. the input logarithmic probabilities: -exp(logarithmic_logproba_gradient)."""
        # shape = [batch_size, max_logit_length, num_tokens]
        return -tf.exp(self.logarithmic_logproba_gradient)

    @cached_property
    def logarithmic_logproba_gradient(self) -> tf.Tensor:
        """Calculates logarithmic gradient of log loss w.r.t. input logarithmic probabilities.

        Samples with infinite loss and logits beyond logit_length are set to -inf, i.e. zero gradient.

        Returns: tf.Tensor, shape = [batch_size, max_logit_length, num_tokens]
        """
        # loss + the log-space combination of alpha and beta per (t, token); the exact contraction
        # is defined by the subclass's combine_transition_probabilities.
        logarithmic_logproba_gradient = \
            tf.reshape(self.loss, [-1, 1, 1]) \
            + self.combine_transition_probabilities(a=self.alpha[:, :-1], b=self.beta[:, 1:])
        # shape = [batch_size, max_logit_length, num_tokens]

        # Filter out samples with infinite loss
        logarithmic_logproba_gradient = tf.where(
            condition=expand_many_dims(self.loss == inf, [1, 2]),
            x=-inf,
            y=logarithmic_logproba_gradient,
        )
        # shape = [batch_size, max_logit_length, num_tokens]

        # Filter out logits beyond logit_length
        logarithmic_logproba_gradient = apply_logarithmic_mask(
            tensor=logarithmic_logproba_gradient,
            mask=tf.expand_dims(self.logit_length_mask, axis=2),
        )
        # shape = [batch_size, max_logit_length, num_tokens]

        return logarithmic_logproba_gradient

    @property
    @abstractmethod
    def alpha(self) -> tf.Tensor:
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ...]
        raise NotImplementedError()

    @property
    @abstractmethod
    def beta(self) -> tf.Tensor:
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ...]
        raise NotImplementedError()

    @property
    @abstractmethod
    def gamma(self) -> tf.Tensor:
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ...,
        #   max_logit_length + 1, max_label_length + 1, ...]
        raise NotImplementedError()

    @cached_property
    def expected_token_logproba(self) -> tf.Tensor:
        """Logarithmic probability to predict label token.

        Label positions beyond label_length are masked to -inf.

        Returns:shape = [batch_size, max_logit_length, max_label_length + 1]
        """
        # NOTE(review): this gather is identical to label_token_logproba; only the masking below differs.
        label_logproba = tf.gather(
            params=self.logproba,
            indices=self.label,
            axis=2,
            batch_dims=1,
        )
        expected_token_logproba = \
            apply_logarithmic_mask(label_logproba, tf.expand_dims(self.label_length_mask, axis=1))
        # shape = [batch_size, max_logit_length, max_label_length + 1]
        return expected_token_logproba

    @property
    @abstractmethod
    def loss(self) -> tf.Tensor:
        """Samplewise loss function value that is minus logarithmic probability to predict label sequence.

        Returns:    tf.Tensor, shape = [batch_size]
        """
        raise NotImplementedError()

    @cached_property
    def label_token_logproba(self) -> tf.Tensor:
        """ shape = [batch_size, max_logit_length, max_label_length + 1] """
        return tf.gather(
            params=self.logproba,
            indices=self.label,
            axis=2,
            batch_dims=1,
        )

    @cached_property
    def blank_logproba(self):
        """Calculates logarithmic probability to predict blank token for given logit.

        Returns:    tf.Tensor, shape = [batch_size, max_logit_length]
        """
        return self.logproba[:, :, self.blank_token_index]

    @cached_property
    def input_proba(self) -> tf.Tensor:
        """ shape = [batch_size, input_logit_tensor_length, num_tokens], dtype = tf.float32 """
        return tf.exp(self.logproba)

    @cached_property
    def logproba(self) -> tf.Tensor:
        """Input logarithmic probabilities with frames beyond logit_length forced to the blank token.

        Padded frames get log 1 = 0 at the blank index and log 0 = -inf elsewhere.
        Returns: shape = [batch_size, max_logit_length, num_tokens]
        """
        mask = tf.expand_dims(tf.sequence_mask(lengths=self._logit_length, maxlen=self.max_logit_length), 2)
        blank_logprobas = tf.reshape(tf.math.log(tf.one_hot(self.blank_token_index, self.num_tokens)), shape=[1, 1, -1])
        logprobas = tf.where(
            condition=mask,
            x=self._logprobas,
            y=blank_logprobas,
        )
        return logprobas

    # NOTE(review): the triple-quoted string below is dead code: an old cleaned_label stub kept as a
    # bare string expression in the class body. It has no effect.
    '''
    def cleaned_label(self) -> tf.Tensor:
        """ shape = [batch, max_label_length + 1] """
        _ = self.max_label_length_plus_one
    '''
    @cached_property
    def cleaned_label(self):
        """Labels with positions at or beyond label_length replaced by pad_token_index (the blank).

        Returns: shape = [batch, label columns kept by the slice below]
        """
        # Padding repair is disabled: apparently TPU/GPU JIT cannot handle the padding below (the
        # reason is unclear), and it does not seem necessary in our case. Labels are only sliced.
        labels = self._original_label[:, :self.max_label_length_plus_one]
        # NOTE(review): the string below is the disabled padding code kept as a no-op string expression.
        '''
        labels = tf.cond(
            pred=tf.shape(self._original_label)[1] > self.max_label_length,
            true_fn=lambda: self._original_label[:, :self.max_label_length_plus_one],
            false_fn=lambda: pad_until(
                tensor=self._original_label,
                desired_size=self.max_label_length_plus_one,
                pad_value=self.pad_token_index,
                axis=1
            )
        )
        '''
        mask = tf.sequence_mask(lengths=self._original_label_length, maxlen=tf.shape(labels)[1])
        blank_label = tf.ones_like(labels) * self.pad_token_index
        cleaned_label = tf.where(
            condition=mask,
            x=labels,
            y=blank_label,
        )
        return cleaned_label

    def select_from_act(self, act: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
        """Takes tensor of acts act_{b, a, t, u, ...} and labels label_{b,u},
        where b is the batch index, t is the logit index, and u is the label index,
        and returns for each token index k the tensor

            output_{b,a,t,k,...} = logsumexp over {u : label_{b,u} = k} of act_{b,a,t,u,...}

        that is the logarithmic sum of exponents of acts over all label positions u holding token k,
        given b, a, t and k.

        Args:
            act:    tf.Tensor, shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, ...]
            label:  tf.Tensor, shape = [batch_size, max_label_length + 1]

        Returns:    tf.Tensor, shape = [batch_size, dim_a, max_logit_length, num_tokens, ...]
        """
        data = smart_transpose(a=act, perm=[0, 3, 2, 1])
        # shape = [batch_size, max_label_length + 1, max_logit_length, dim_a, ...]
        data = tf.squeeze(
            input=smart_reshape(
                tensor=data,
                shape=[1, self.batch_size * self.max_label_length_plus_one, self.max_logit_length]
            ),
            axis=0
        )
        # shape = [batch_size * (max_label_length + 1), max_logit_length, dim_a, ...]

        # Each (batch, token) pair gets its own segment id: b * num_tokens + label_{b,u}.
        segment_ids = tf.reshape(label + tf.expand_dims(tf.range(self.batch_size), 1) * self.num_tokens, shape=[-1])
        # shape = [batch_size * (max_label_length + 1)]
        num_segments = self.batch_size * self.num_tokens

        output = unsorted_segment_logsumexp(data=data, segment_ids=segment_ids, num_segments=num_segments)
        # shape = [batch_size * num_tokens, max_logit_length, dim_a, ...]
        output = smart_reshape(tf.expand_dims(output, 0), [self.batch_size, self.num_tokens, self.max_logit_length])
        # shape = [batch_size, num_tokens, max_logit_length, dim_a, ...]
        output = smart_transpose(output, [0, 3, 2, 1])
        # shape = [batch_size, dim_a, max_logit_length, num_tokens, ...]
        return output

    @cached_property
    def max_logit_length_plus_one(self) -> tf.Tensor:
        return self.max_logit_length + tf.constant(1, dtype=tf.int32)

    @cached_property
    def max_logit_length(self) -> tf.Tensor:
        return tf.shape(self._logprobas)[1]

    # NOTE(review): presumably never evaluated, because __init__ assigns an instance attribute of the
    # same name (the label width), which takes precedence over this cached_property.
    @cached_property
    def max_label_length_plus_one(self) -> tf.Tensor:
        return self.max_label_length + tf.constant(1, dtype=tf.int32)

    @cached_property
    def max_label_length(self) -> tf.Tensor:
        # 0 for an empty batch.
        return reduce_max_with_default(self._original_label_length, default=tf.constant(0, dtype=tf.int32))

    @cached_property
    def pad_token_index(self) -> tf.Tensor:
        # Label padding uses the blank token.
        return self.blank_token_index

    @cached_property
    def num_tokens(self) -> tf.Tensor:
        return tf.shape(self._logprobas)[2]

    @cached_property
    def blank_token_index(self) -> tf.Tensor:
        return self._blank_index

    @cached_property
    def logit_length_mask(self) -> tf.Tensor:
        """ shape = [batch_size, max_logit_length] """
        return tf.sequence_mask(
            lengths=self._logit_length,
            maxlen=self.max_logit_length,
        )

    @cached_property
    def label_length_mask(self) -> tf.Tensor:
        """ shape = [batch_size, max_label_length + 1], dtype = tf.bool """
        return tf.sequence_mask(lengths=self.label_length, maxlen=self.max_label_length_plus_one)

    @property
    def label_length(self) -> tf.Tensor:
        return self._original_label_length

    @cached_property
    def preceded_label(self) -> tf.Tensor:
        """Preceded label. For example, for label "abc_" the sequence "_abc" is returned.

        Returns:    tf.Tensor, shape = [batch_size, max_label_length + 1]
        """
        return tf.roll(self.label, shift=1, axis=1)

    @cached_property
    def label(self) -> tf.Tensor:
        """ shape = [batch, max_label_length + 1] """
        return self.cleaned_label

    @cached_property
    def batch_size(self) -> tf.Tensor:
        return tf.shape(self._logprobas)[0]

    @abstractmethod
    def combine_transition_probabilities(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
        """Merges logarithmic probabilities a and b like
            a, b -> log( exp a exp p exp b )
        where p presumably stands for the per-step token logarithmic probability; the exact
        contraction and output shape are defined by the subclass.
        """
        raise NotImplementedError()


def classic_ctc_loss(
        labels: tf.Tensor,
        logits: tf.Tensor,
        label_length: tf.Tensor,
        logit_length: tf.Tensor,
        blank_index: Union[int, tf.Tensor] = 0,
) -> tf.Tensor:
    """Classic CTC (Connectionist Temporal Classification) loss,
    see http://www.cs.toronto.edu/~graves/icml_2006.pdf.

    Repeated non-blank tokens are merged and blanks are removed when decoding, so the
    predicted sequence
        a_bb_ccc_cc
    corresponds to the label
        abcc
    where "_" denotes the blank token.

    If a label is longer than the logit sequence can represent, the output loss for the
    corresponding sample in the batch is +inf and the gradient is 0. Repeated tokens need a
    separating blank: for the label "abb" at least 4 logits are needed, because the shortest
    matching output sequence is "ab_b".

    Args:
        labels:         tf.Tensor, shape = [batch, max_label_length],        dtype = tf.int32
        logits:         tf.Tensor, shape = [batch, max_length, num_tokens],  dtype = tf.float32
        label_length:   tf.Tensor, shape = [batch],                          dtype = tf.int32
        logit_length:   tf.Tensor, shape = [batch],                          dtype = tf.int32
        blank_index:    tf.Tensor or static Python int with 0 <= blank_index < num_tokens

    Returns:
        Whatever ``ctc_loss`` returns when driven by ``ClassicCtcLossData``; presumably the
        per-sample loss of shape [batch], dtype = tf.float32 (cf. ``ClassicCtcLossData.loss``)
        — confirm against ``ctc_loss``.
    """
    return ctc_loss(
        ctc_loss_data_cls=ClassicCtcLossData,
        blank_index=blank_index,
        labels=labels,
        logits=logits,
        label_length=label_length,
        logit_length=logit_length,
    )


class ClassicCtcLossData(BaseCtcLossData):
    """Calculate loss data for CTC (Connectionist Temporal Classification) loss from
    http://www.cs.toronto.edu/~graves/icml_2006.pdf.

    This loss is the negative logarithmic likelihood for a classification task with multiple expected classes.
    All predicted sequences consist of tokens (denoted like "a", "b", ... below) and the blank "_".
    The classic CTC decoding merges all repeated non-blank labels and removes the blank.
    For example, the predicted sequence
        a_bb_ccc_c is decoded as "abcc".
    All predicted sequences that coincide with the label after the decoding are the expected classes
    in the logarithmic likelihood loss function.

    Implementation:

    We calculate alpha_{b,t,l,s} and beta_{b,t,l,s}, which are logarithmic probabilities similar to
    the ones from the cited paper and are defined precisely below.
    Here, b corresponds to batch, t to logit position, l to label index, and s=0,1 to state (see below for details).

    During the decoding procedure, after handling a part of the logit sequence,
    we have predicted only a part (a prefix) of the target label tokens. We call this prefix a "state".
    For example, to decode the label "abc" we have to decode "a" first, then add "b" and move to the state "ab",
    and then to the state "abc".

    In order to handle token duplication in the classic CTC loss we extend the set of all possible states.
    For each token sequence we define two states called "closed" and "open".
    For example, for label "abc" we consider its two states denoted "abc>" (closed) and "abc<" (open).
    The difference between them is in their behaviour when a token is appended. The rules are:
        "...a>" + "_" -> "...a>",
        "...a<" + "_" -> "...a>",
        "...a>" + "a" -> "...aa<",
        "...a<" + "a" -> "...a<",
        "...a>" + "b" -> "...ab<",
        "...a<" + "b" -> "...ab<",
    for any different tokens "a" and "b" and any token sequence denoted by "...".
    Namely, appending a token that is equal to the last one to an open state does not change this state.
    Appending a blank to a state always makes this state closed.

    This is why alpha_{b,t,l,s} and beta_{b,t,l,s} in the code below are equipped with an additional index s=0,1.
    Closed states correspond to s=0 and open ones to s=1.

    In particular, the following identity is satisfied
        sum_s sum_l exp alpha_{b,t,l,s} * exp beta_{b,t,l,s} = exp(-loss_{b}), for any b and t,
    i.e. the sum is the total probability of the label, whose negative logarithm is the loss.
    """
    @cached_property
    def diagonal_non_blank_grad_term(self) -> tf.Tensor:
        """Logarithmic contribution of diagonal steps (appending the next label token into an open state).

        Combines alpha before logit t, the diagonal step log-probability at logit t, and the open-state
        beta after logit t at the next label position (hence the roll by -1 along the label axis),
        then maps label positions to tokens via `select_from_act`.

        Returns: shape = [batch_size, max_logit_length, num_tokens]
        """
        input_tensor = \
            self.alpha[:, :-1] \
            + self.any_to_open_diagonal_step_log_proba \
            + tf.roll(self.beta[:, 1:, :, 1:], shift=-1, axis=2)
        # shape = [batch_size, max_logit_length, max_label_length + 1, states]
        act = tf.reduce_logsumexp(
            input_tensor=input_tensor,
            axis=3,
        )
        # shape = [batch_size, max_logit_length, max_label_length + 1]
        diagonal_non_blank_grad_term = self.select_from_act(act=act, label=self.label)
        # shape = [batch_size, max_logit_length, num_tokens]
        return diagonal_non_blank_grad_term

    @cached_property
    def horizontal_non_blank_grad_term(self) -> tf.Tensor:
        """Horizontal steps from repeated token: open alpha state to open beta state.

        The repeated token is the one preceding each label position, hence `preceded_label`.

        Returns: shape = [batch_size, max_logit_length, num_tokens]
        """
        act = self.alpha[:, :-1, :, 1] + self.previous_label_token_log_proba + self.beta[:, 1:, :, 1]
        # shape = [batch_size, max_logit_length, max_label_length + 1]
        horizontal_non_blank_grad_term = self.select_from_act(act, self.preceded_label)
        return horizontal_non_blank_grad_term

    @cached_property
    def loss(self) -> tf.Tensor:
        """Negative logarithmic likelihood of each sample's label.

        Sums the last alpha slice over both states and picks the entry at each sample's label length.

        Returns: shape = [batch_size]
        """
        params = tf.reduce_logsumexp(self.alpha[:, -1], -1)
        # shape = [batch_size, max_label_length + 1]
        loss = -tf.gather(
            params=params,                # shape = [batch_size, max_label_length + 1]
            indices=self.label_length,    # shape = [batch_size]
            batch_dims=1,
        )
        return loss

    @cached_property
    def gamma(self) -> tf.Tensor:
        """Logarithmic transition probabilities between (logit, label position, state) triples.

        gamma[b, t1, l1, s1, t2, l2, s2] is built by iterating `gamma_step` forward from the identity
        `diagonal_gamma`; entries with t1 > t2 are masked by `apply_logarithmic_mask`
        (presumably to -inf — confirm in that helper).

        shape = [
                batch_size,
                max_logit_length + 1,
                max_label_length + 1,
                state,
                max_logit_length + 1,
                max_label_length + 1,
                state,
            ],
        """
        # Evaluate (and cache) these properties before `unfold` runs.
        # This is to avoid InaccessibleTensorError in graph mode
        _, _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba, self.diagonal_gamma

        gamma_forward_transposed = unfold(
            init_tensor=self.diagonal_gamma,
            # init_tensor=tf.tile(self.diagonal_gamma, [self.batch_size, self.max_logit_length_plus_one, 1, 1, 1, 1]),
            iterfunc=self.gamma_step,
            d_i=1,
            num_iters=self.max_logit_length,
            element_shape=tf.TensorShape([None, None, None, None, None, None]),
            name="gamma_1",
        )
        # shape = [max_logit_length + 1, batch_size, max_logit_length + 1, max_label_length + 1, state,
        #   max_label_length + 1, state]

        # Move the stacked (end logit) axis from position 0 to position 4.
        gamma_forward = tf.transpose(gamma_forward_transposed, [1, 2, 3, 4, 0, 5, 6])
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state,
        #   max_logit_length + 1, max_label_length + 1, state]

        # Upper-triangular (start logit <= end logit) boolean mask over the two logit axes.
        mask = expand_many_dims(
            input=tf.linalg.band_part(tf.ones(shape=[self.max_logit_length_plus_one] * 2, dtype=tf.bool), 0, -1),
            axes=[0, 2, 3, 5, 6]
        )
        # shape = [1, max_logit_length + 1, 1, 1, max_logit_length + 1, 1, 1]
        gamma = apply_logarithmic_mask(gamma_forward, mask)
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state,
        #   max_logit_length + 1, max_label_length + 1, state]

        return gamma

    def gamma_step(
        self,
        previous_slice: tf.Tensor,
        i: tf.Tensor,
    ) -> tf.Tensor:
        """Advance every gamma row by one logit (the one with index i).

        Rows whose start logit is <= i are advanced with a horizontal step and a diagonal step,
        exactly as in `alpha_step`; rows whose start logit is > i are reset to `diagonal_gamma`.

        Args:
            previous_slice: tf.Tensor,
                            shape = [batch_size, max_logit_length + 1, max_label_length + 1, state,
                                max_label_length + 1, state]
            i:              tf.Tensor,
                            shape = [], 0 <= i < max_logit_length + 1

        Returns:            tf.Tensor,
                            shape = [batch_size, max_logit_length + 1, max_label_length + 1, state,
                                max_label_length + 1, state]
        """
        horizontal_step_states = \
            expand_many_dims(self.horizontal_step_log_proba[:, i], axes=[1, 2, 3]) \
            + tf.expand_dims(previous_slice, 5)
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state,
        #          max_label_length + 1, next_state, previous_state]
        horizontal_step = tf.reduce_logsumexp(horizontal_step_states, axis=6)
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state]

        diagonal_step_log_proba = tf.reduce_logsumexp(
            expand_many_dims(self.any_to_open_diagonal_step_log_proba[:, i], axes=[1, 2, 3]) + previous_slice,
            axis=5
        )
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1]

        # We move by one token because it is a diagonal step
        moved_diagonal_step_log_proba = tf.roll(diagonal_step_log_proba, shift=1, axis=4)
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1]

        # Out state is always open: pad a -inf closed-state entry in front of the open one.
        diagonal_step = tf.pad(
            tensor=tf.expand_dims(moved_diagonal_step_log_proba, 5),
            paddings=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0]],
            constant_values=-np.inf
        )
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state]
        new_gamma_slice = logsumexp(
            x=horizontal_step,
            y=diagonal_step,
        )
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state]

        # Only rows whose start logit is <= i have started; the others stay at the identity.
        condition = tf.reshape(tf.range(self.max_logit_length_plus_one) <= i, shape=[1, -1, 1, 1, 1, 1])
        # shape = [1, max_logit_length + 1, 1, 1, 1, 1]
        output_slice = tf.where(
            condition=condition,
            x=new_gamma_slice,
            y=self.diagonal_gamma,
        )
        # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state]

        return output_slice

    @cached_property
    def diagonal_gamma(self) -> tf.Tensor:
        """Identity transition in log space: log of an identity matrix over (label position, state) pairs
        (0 on the diagonal, -inf elsewhere), tiled over the batch and start-logit axes.

        shape = [batch_size, max_logit_length_plus_one, max_label_length + 1, state,
                 max_label_length + 1, state]
        """
        diagonal_gamma = tf.math.log(
            tf.reshape(
                tensor=tf.eye(self.max_label_length_plus_one * 2, dtype=tf.float32),
                shape=[1, 1, self.max_label_length_plus_one, 2, self.max_label_length_plus_one, 2]
            )
        )
        diagonal_gamma = tf.tile(diagonal_gamma, [self.batch_size, self.max_logit_length_plus_one, 1, 1, 1, 1])
        return diagonal_gamma

    @cached_property
    def beta(self) -> tf.Tensor:
        """Calculates beta_{b,t,l,s}: the logarithmic probability, for sample 0 <= b < batch_size in the batch,
        that the logit subsequence
            t, t + 1, ... max_logit_length - 2, max_logit_length - 1,
        for t <= max_logit_length,
        predicts the remaining label tokens
            w_l, w_{l+1}, ..., w_{label_length - 1},
        starting from the state with l predicted tokens
        that is either closed s=0 or open s=1,
        where label_b = [w_0, w_1, ... w_{max_label_length - 2}, w_{max_label_length - 1}].

        This logarithmic probability is calculated by backward iterations
            exp beta_{t-1,l} = p_horizontal_step_{t-1,l} * exp beta_{t,l} + p_diagonal_step_{t-1,l} * exp beta_{t,l+1},
        for 0 < t <= max_logit_length, starting from `last_beta_slice`,
        where p_diagonal_step_{t,l} is the probability to predict label token w_l with logit t
        and p_horizontal_step_{t,l} is the probability to stay at label position l with logit t, for example,
        with the blank prediction.

        Returns:    tf.Tensor,  shape = [batch_size, max_logit_length + 1, max_label_length + 1, state],
                    dtype = tf.float32
        """
        # Evaluate (and cache) these properties before `unfold` runs.
        # This is to avoid InaccessibleTensorError in graph mode
        _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba

        beta = unfold(
            init_tensor=self.last_beta_slice,
            iterfunc=self.beta_step,
            d_i=-1,
            num_iters=self.max_logit_length,
            element_shape=tf.TensorShape([None, None, 2]),
            name="beta",
        )
        # shape = [logit_length + 1, batch, label_length + 1, state]
        return tf.transpose(beta, [1, 0, 2, 3])

    def beta_step(self, previous_slice: tf.Tensor, i: tf.Tensor) -> tf.Tensor:
        """One backward step: beta at logit i from beta at logit i + 1 (`previous_slice`).

        Args:
            previous_slice: shape = [batch_size, max_label_length + 1, state]
            i:              logit index of the step

        Returns:            shape = [batch_size, max_label_length + 1, state]
        """
        # Sum over the state reached after the horizontal step (axis 2 = next_state).
        horizontal_step = \
            tf.reduce_logsumexp(self.horizontal_step_log_proba[:, i] + tf.expand_dims(previous_slice, 3), 2)
        # shape = [batch_size, max_label_length + 1, state]
        # A diagonal step always lands in the open state of the next label position.
        diagonal_step = \
            self.any_to_open_diagonal_step_log_proba[:, i] + tf.roll(previous_slice[:, :, 1:], shift=-1, axis=1)
        # shape = [batch_size, max_label_length + 1, state]
        new_beta_slice = logsumexp(
            x=horizontal_step,  # shape = [batch_size, max_label_length + 1, state]
            y=diagonal_step,    # shape = [batch_size, max_label_length + 1, state]
        )
        # shape = [batch_size, max_label_length + 1, state]
        return new_beta_slice

    @cached_property
    def last_beta_slice(self) -> tf.Tensor:
        """Initial (t = max_logit_length) beta: log one-hot at each sample's label length, for both states.

        shape = [batch_size, max_label_length + 1, state]
        """
        beta_last = tf.math.log(tf.one_hot(indices=self.label_length, depth=self.max_label_length_plus_one))
        beta_last = tf.tile(input=tf.expand_dims(beta_last, axis=2), multiples=[1, 1, 2])
        return beta_last

    @cached_property
    def alpha(self) -> tf.Tensor:
        """Calculates alpha_{b,t,l,s}: the logarithmic probability, for sample 0 <= b < batch_size in the batch,
        that the logit subsequence 0, 1, 2, ... t - 2, t - 1, for t <= max_logit_length,
        predicts the sequence of tokens w_0, w_1, w_2, ... w_{l-2}, w_{l-1}, for l < max_label_length + 1,
        ending in a state that is either closed s=0 or open s=1,
        where label_b = [w_0, w_1, ... w_{max_label_length - 2}, w_{max_label_length - 1}].

        This logarithmic probability is calculated by iterations
            exp alpha_{t + 1,l} = p_horizontal_step_{t,l} * exp alpha_{t,l} + p_diagonal_step_{t,l} * exp alpha_{t,l-1},
        for 0 <= t < max_logit_length, starting from `first_alpha_slice`,
        where p_diagonal_step_{t,l} is the probability to predict label token w_l with logit t
        and p_horizontal_step_{t,l} is the probability to stay at label position l with logit t, for example,
        with the blank prediction.

        Returns:    tf.Tensor,  shape = [batch_size, max_logit_length + 1, max_label_length + 1, state],
                    dtype = tf.float32
        """
        # Evaluate (and cache) these properties before `unfold` runs.
        # This is to avoid InaccessibleTensorError in graph mode
        _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba

        alpha = unfold(
            init_tensor=self.first_alpha_slice,
            iterfunc=self.alpha_step,
            d_i=1,
            num_iters=self.max_logit_length,
            element_shape=tf.TensorShape([None, None, 2]),
            name="alpha",
        )
        # shape = [logit_length + 1, batch_size, label_length + 1, state]
        return tf.transpose(alpha, [1, 0, 2, 3])

    def alpha_step(self, previous_slice: tf.Tensor, i: tf.Tensor) -> tf.Tensor:
        """One forward step: alpha after logit i from alpha before it (`previous_slice`).

        Args:
            previous_slice: shape = [batch_size, max_label_length + 1, state]
            i:              logit index of the step

        Returns:            shape = [batch_size, max_label_length + 1, state]
        """
        temp = self.horizontal_step_log_proba[:, i] + tf.expand_dims(previous_slice, 2)
        # shape = [batch_size, max_label_length + 1, next_state, previous_state]
        horizontal_step = tf.reduce_logsumexp(temp, 3)
        # shape = [batch_size, max_label_length + 1, state]
        diagonal_step_log_proba = \
            tf.reduce_logsumexp(self.any_to_open_diagonal_step_log_proba[:, i] + previous_slice, 2)
        # shape = [batch_size, max_label_length + 1]

        # We move by one token because it is a diagonal step
        moved_diagonal_step_log_proba = tf.roll(diagonal_step_log_proba, shift=1, axis=1)
        # shape = [batch_size, max_label_length + 1]

        # Out state is always open: pad a -inf closed-state entry in front of the open one.
        diagonal_step = tf.pad(
            tensor=tf.expand_dims(moved_diagonal_step_log_proba, 2),
            paddings=[[0, 0], [0, 0], [1, 0]],
            constant_values=-np.inf
        )
        # shape = [batch_size, max_label_length + 1, state]
        new_alpha_slice = logsumexp(
            x=horizontal_step,
            y=diagonal_step,
        )
        # shape = [batch_size, max_label_length + 1, state]
        return new_alpha_slice

    @cached_property
    def first_alpha_slice(self) -> tf.Tensor:
        """Initial (t = 0) alpha: log-probability 0 at label position 0, closed state; -inf elsewhere.

        shape = [batch_size, max_label_length + 1, state]
        """
        alpha_0 = tf.math.log(tf.one_hot(indices=0, depth=self.max_label_length_plus_one * 2))
        alpha_0 = tf.tile(input=tf.reshape(alpha_0, [1, -1, 2]), multiples=[self.batch_size, 1, 1])
        return alpha_0

    @cached_property
    def any_to_open_diagonal_step_log_proba(self) -> tf.Tensor:
        """Logarithmic probability to make a diagonal step from given state to an open state

        The last axis is the source state: index 0 = closed, index 1 = open.

        Returns:shape = [batch_size, max_logit_length, max_label_length + 1, state]
        """
        return tf.stack(
            values=[self.closed_to_open_diagonal_step_log_proba, self.open_to_open_diagonal_step_log_proba],
            axis=3
        )

    @cached_property
    def open_to_open_diagonal_step_log_proba(self) -> tf.Tensor:
        """Logarithmic probability to make a diagonal step from an open state to an open state
        with expected token prediction that is different from the previous one.

        Returns:shape = [batch_size, max_logit_length, max_label_length + 1]
        """
        # We check that the predicting token does not equal to previous one
        # (position 0 is compared with the last position because of the roll wrap-around).
        token_repetition_mask = self.label != tf.roll(self.label, shift=1, axis=1)
        # shape = [batch_size, max_label_length + 1]
        open_diagonal_step_log_proba = \
            apply_logarithmic_mask(
                self.closed_to_open_diagonal_step_log_proba,
                tf.expand_dims(token_repetition_mask, axis=1)
            )
        return open_diagonal_step_log_proba

    @cached_property
    def closed_to_open_diagonal_step_log_proba(self) -> tf.Tensor:
        """Logarithmic probability to make a diagonal step from a closed state to an open state
        with expected token prediction.

        Returns:shape = [batch_size, max_logit_length, max_label_length + 1]
        """
        return self.expected_token_logproba

    @cached_property
    def horizontal_step_log_proba(self) -> tf.Tensor:
        """Calculates logarithmic probability of the horizontal step for given logit x label position.

        This is possible in two alternative cases:
        1. Blank: any previous state -> closed state (next_state = 0).
        2. Not blank token from previous label position: open -> open only (next_state = 1);
           closed -> open gets -inf.

        Returns: tf.Tensor, shape = [batch_size, max_logit_length, max_label_length + 1, next_state, previous_state]
        """
        # We map closed and open states to closed states
        blank_term = tf.tile(
            input=tf.expand_dims(tf.expand_dims(self.blank_logproba, 2), 3),
            multiples=[1, 1, self.max_label_length_plus_one, 2]
        )
        # shape = [batch_size, max_logit_length, max_label_length + 1, 2 (previous_state)]
        # -inf is padded in front, so only the open previous state keeps the repeated-token probability.
        non_blank_term = tf.pad(
            tf.expand_dims(self.not_blank_horizontal_step_log_proba, 3),
            paddings=[[0, 0], [0, 0], [0, 0], [1, 0]],
            constant_values=tf.constant(-np.inf),
        )
        # shape = [batch_size, max_logit_length, max_label_length + 1, 2 (previous_state)]
        horizontal_step_log_proba = tf.stack([blank_term, non_blank_term], axis=3)
        return horizontal_step_log_proba

    @cached_property
    def not_blank_horizontal_step_log_proba(self) -> tf.Tensor:
        """Log-probability to repeat the token preceding each label position, with the blank token masked out.

        shape = [batch_size, max_logit_length, max_label_length + 1]
        """
        mask = tf.reshape(1 - tf.one_hot(self.blank_token_index, depth=self.num_tokens), shape=[1, 1, -1])
        not_blank_log_proba = apply_logarithmic_mask(self.logproba, mask)
        not_blank_horizontal_step_log_proba = tf.gather(
            params=not_blank_log_proba,
            indices=tf.roll(self.label, shift=1, axis=1),
            axis=2,
            batch_dims=1,
        )
        # shape = [batch_size, max_logit_length, max_label_length + 1]
        return not_blank_horizontal_step_log_proba

    @cached_property
    def previous_label_token_log_proba(self) -> tf.Tensor:
        """Calculates the probability to predict the token that precedes each label position.

        Unlike `not_blank_horizontal_step_log_proba`, the blank token is not masked here.

        Returns:    tf.Tensor,  shape = [batch_size, max_logit_length, max_label_length + 1]
        """
        previous_label_token_log_proba = tf.gather(
            params=self.logproba,
            indices=self.preceded_label,
            axis=2,
            batch_dims=1,
        )
        # shape = [batch_size, max_logit_length, max_label_length + 1]
        return previous_label_token_log_proba

    @cached_property
    def blank_logproba(self) -> tf.Tensor:
        """Log-probability of the blank token at every logit.

        shape = [batch_size, max_logit_length]
        """
        return self.logproba[:, :, self.blank_token_index]

    def combine_transition_probabilities(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
        """Transforms logarithmic transition probabilities a and b into per-token log-probabilities.

        `a` is indexed by the state before the step at each logit and `b` by the state after it
        (the caller is expected to align them, e.g. like alpha[:, :-1] and beta[:, 1:] in the
        *_grad_term properties). Three kinds of steps are combined:
            - blank: any state of `a` -> closed state of `b` at the same label position;
            - repeated token: open -> open at the same label position;
            - diagonal: any state -> open state at the next label position.

        Args:
            a:      shape = [batch, DIMS_A, max_logit_length, max_label_length + 1, state]
            b:      shape = [batch, max_logit_length, max_label_length + 1, state, DIMS_B]

        Returns:    shape = [batch, DIMS_A, max_logit_length, num_tokens, DIMS_B]
        """
        assert len(a.shape) >= 4
        assert len(b.shape) >= 4
        assert a.shape[-1] == 2
        assert b.shape[3] == 2

        # Flatten the extra leading (a) and trailing (b) dims; they are restored at the end.
        dims_a = tf.shape(a)[1:-3]
        dims_b = tf.shape(b)[4:]
        a = tf.reshape(a, shape=[self.batch_size, -1, self.max_logit_length, self.max_label_length_plus_one, 2, 1])
        # shape = [batch_size, dims_a, max_logit_length, max_label_length + 1, state, 1]
        b = tf.reshape(b, shape=[self.batch_size, 1, self.max_logit_length, self.max_label_length_plus_one, 2, -1])
        # shape = [batch_size, 1, max_logit_length, max_label_length + 1, state, dims_b]

        # Either open or closed state from alpha and only closed state from beta
        ab_term = tf.reduce_logsumexp(a, 4) + b[:, :, :, :, 0]
        # shape = [batch_size, dims_a, max_logit_length, max_label_length + 1, dims_b]

        horizontal_blank_grad_term = \
            expand_many_dims(self.blank_logproba, axes=[1, 3]) + tf.reduce_logsumexp(ab_term, axis=3)
        # shape = [batch_size, dims_a, max_logit_length, dims_b]

        act = a[:, :, :, :, 1] + expand_many_dims(self.previous_label_token_log_proba, axes=[1, 4]) + b[:, :, :, :, 1]
        # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, dim_b]

        horizontal_non_blank_grad_term = self.select_from_act(act, self.preceded_label)
        # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b]

        input_tensor = a + expand_many_dims(self.any_to_open_diagonal_step_log_proba, axes=[1, 5]) + \
            tf.roll(b[:, :, :, :, 1:], shift=-1, axis=3)
        # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, states, dim_b]

        act = tf.reduce_logsumexp(input_tensor=input_tensor, axis=4)
        # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, dim_b]

        diagonal_non_blank_grad_term = self.select_from_act(act=act, label=self.label)
        # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b]

        non_blank_grad_term = logsumexp(horizontal_non_blank_grad_term, diagonal_non_blank_grad_term)
        # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b]

        blank_mask = self.blank_token_index == tf.range(self.num_tokens)
        # shape = [num_tokens]

        # The blank token column takes the blank term; every other token takes the non-blank terms.
        output = tf.where(
            condition=expand_many_dims(blank_mask, axes=[0, 1, 2, 4]),
            x=tf.expand_dims(horizontal_blank_grad_term, 3),
            y=non_blank_grad_term,
        )
        # shape = [batch, dim_a, max_logit_length, num_tokens, dim_b]
        output_shape = tf.concat(
            [
                tf.expand_dims(self.batch_size, axis=0),
                dims_a,
                tf.expand_dims(self.max_logit_length, axis=0),
                tf.expand_dims(self.num_tokens, axis=0),
                dims_b
            ],
            axis=0
        )
        output_reshaped = tf.reshape(output, shape=output_shape)
        # shape = [batch, DIMS_A, max_logit_length, num_tokens, DIMS_B]

        return output_reshaped

# Distribution strategy (TPU detection)

In [30]:
import tensorflow as tf

# Try to attach to a TPU runtime. `tpu_strategy` remains None when the
# resolver raises ValueError (the "Not using TPU" path).
tpu_strategy = None
try:
    # TPU detection.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])

    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    tpu_strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    print("Not using TPU")

# Note: eager execution used to be disabled here (via
# tensorflow.python.framework.ops.disable_eager_execution) because the LSTM
# layer can't use bfloat16 otherwise; it is currently left enabled.

Not using TPU


In [31]:
print(f"TensorFlow v{tf.__version__}")


TensorFlow v2.12.0


In [32]:
class MemoryUsageCallbackExtended(tf.keras.callbacks.Callback):
    """Keras callback that reports process memory and forces garbage collection.

    At the end of every epoch it prints the resident set size (RSS) of the
    current process in GB (1e9 bytes), then calls ``gc.collect()``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Print the current RSS in GB, then run a garbage-collection pass."""
        rss_bytes = int(psutil.Process(os.getpid()).memory_info().rss)
        print(f"Memory usage on epoch end: {rss_bytes/1e9:.2f} GB")
        gc.collect()

# Scheduler

In [33]:

class CosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a cosine decay with optional warmup.

    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
    SGDR: Stochastic Gradient Descent with Warm Restarts.

    For the idea of a linear warmup of our learning rate,
    see [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf).

    When we begin training a model, we often want an initial increase in our
    learning rate followed by a decay. If `warmup_target` is not None, this
    schedule applies a linear increase per optimizer step to our learning rate
    from `initial_learning_rate` to `warmup_target` for a duration of
    `warmup_steps`. Afterwards, it applies a cosine decay function taking our
    learning rate from `warmup_target` to `alpha * warmup_target` for a
    duration of `decay_steps`. If `warmup_target` is None we skip warmup and
    our decay will take our learning rate from `initial_learning_rate` to
    `alpha * initial_learning_rate`. Once the decay phase is over, the learning
    rate stays at its final value.
    It requires a `step` value to compute the learning rate. You can
    just pass a TensorFlow variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a warmup followed by a
    decayed learning rate when passed the current optimizer step. This can be
    useful for changing the learning rate value across different invocations of
    optimizer functions.

    Our warmup is computed as:

    ```python
    def warmup_learning_rate(step):
        completed_fraction = step / warmup_steps
        total_delta = warmup_target - initial_learning_rate
        return completed_fraction * total_delta + initial_learning_rate
    ```

    And our decay is computed as:

    ```python
    if warmup_target is None:
        initial_decay_lr = initial_learning_rate
    else:
        initial_decay_lr = warmup_target

    def decayed_learning_rate(step):
        step = min(step, decay_steps)
        cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
        decayed = (1 - alpha) * cosine_decay + alpha
        return initial_decay_lr * decayed
    ```

    Example usage without warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0.1
    lr_decayed_fn = CosineDecay(initial_learning_rate, decay_steps)
    ```

    Example usage with warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0
    warmup_steps = 1000
    target_learning_rate = 0.1
    lr_warmup_decayed_fn = CosineDecay(
        initial_learning_rate, decay_steps, warmup_target=target_learning_rate,
        warmup_steps=warmup_steps
    )
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate. `get_config` returns the constructor arguments so the
    schedule can be re-created from its config.

    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      floating dtype as `initial_learning_rate` (float32 when
      `initial_learning_rate` is an integer).
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        alpha=0.0,
        name=None,
        warmup_target=None,
        warmup_steps=0,
    ):
        """Applies cosine decay to the learning rate.

        Args:
          initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
            Python number. The initial learning rate. Integer values are
            promoted to `float32` when the schedule is evaluated.
          decay_steps: A scalar `int32` or `int64` `Tensor` or a Python int.
            Number of steps to decay over.
          alpha: A scalar `float32` or `float64` `Tensor` or a Python number.
            Minimum learning rate value for decay as a fraction of the rate the
            decay starts from (`warmup_target` if set, otherwise
            `initial_learning_rate`).
          name: String. Optional name of the operation.  Defaults to
            'CosineDecay'.
          warmup_target: None or a scalar `float32` or `float64` `Tensor` or a
            Python number. The target learning rate for our warmup phase. Will
            cast to the learning-rate datatype. Setting to None will skip
            warmup and begins decay phase from `initial_learning_rate`.
            Otherwise scheduler will warmup from `initial_learning_rate` to
            `warmup_target`.
          warmup_steps: A scalar `int32` or `int64` `Tensor` or a Python int.
            Number of steps to warmup over.
        """
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.alpha = alpha
        self.name = name
        self.warmup_steps = warmup_steps
        self.warmup_target = warmup_target

    def _decay_function(self, step, decay_steps, decay_from_lr, dtype):
        """Cosine-decayed rate at `step` (already clamped to [0, decay_steps]) starting from `decay_from_lr`."""
        with tf.name_scope(self.name or "CosineDecay"):
            completed_fraction = step / decay_steps
            tf_pi = tf.constant(math.pi, dtype=dtype)
            cosine_decayed = 0.5 * (1.0 + tf.cos(tf_pi * completed_fraction))
            decayed = (1 - self.alpha) * cosine_decayed + self.alpha
            return tf.multiply(decay_from_lr, decayed)

    def _warmup_function(self, step, warmup_steps, warmup_target, initial_learning_rate):
        """Linear interpolation from `initial_learning_rate` to `warmup_target` over `warmup_steps`."""
        with tf.name_scope(self.name or "CosineDecay"):
            completed_fraction = step / warmup_steps
            total_step_delta = warmup_target - initial_learning_rate
            return total_step_delta * completed_fraction + initial_learning_rate

    def __call__(self, step):
        """Return the learning rate for optimizer step `step`.

        Args:
          step: Scalar step counter; cast to the learning-rate dtype.

        Returns:
          Scalar floating-point learning-rate `Tensor`.
        """
        with tf.name_scope(self.name or "CosineDecay"):
            initial_learning_rate = tf.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate"
            )
            # A Python int converts to an integer tensor, but the cosine/warmup
            # arithmetic is not defined for integer dtypes (e.g.
            # `tf.constant(math.pi, dtype=tf.int32)` raises), so promote
            # non-floating learning rates to float32.
            if not initial_learning_rate.dtype.is_floating:
                initial_learning_rate = tf.cast(initial_learning_rate, tf.float32)
            dtype = initial_learning_rate.dtype
            decay_steps = tf.cast(self.decay_steps, dtype)
            global_step_recomp = tf.cast(step, dtype)

            if self.warmup_target is None:
                # Clamp so the rate stays at its final value after `decay_steps`.
                global_step_recomp = tf.minimum(global_step_recomp, decay_steps)
                return self._decay_function(
                    global_step_recomp,
                    decay_steps,
                    initial_learning_rate,
                    dtype,
                )

            warmup_target = tf.cast(self.warmup_target, dtype)
            warmup_steps = tf.cast(self.warmup_steps, dtype)

            # Clamp so the rate stays at its final value after warmup + decay.
            global_step_recomp = tf.minimum(global_step_recomp, decay_steps + warmup_steps)

            return tf.cond(
                global_step_recomp < warmup_steps,
                lambda: self._warmup_function(
                    global_step_recomp,
                    warmup_steps,
                    warmup_target,
                    initial_learning_rate,
                ),
                lambda: self._decay_function(
                    global_step_recomp - warmup_steps,
                    decay_steps,
                    warmup_target,
                    dtype,
                ),
            )

    def get_config(self):
        """Return the constructor arguments so the schedule can be re-created."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "alpha": self.alpha,
            "name": self.name,
            "warmup_target": self.warmup_target,
            "warmup_steps": self.warmup_steps,
        }


# Constants

In [34]:
def get_char_dict():
    """Build the character -> class-id vocabulary.

    Ids 0..58 are assigned to the characters of ``alphabet`` in order. Id 59
    is "P", the padding token (``Constants.LABEL_PAD``).
    """
    alphabet = " !#$%&'()*+,-./0123456789:;=?@[_abcdefghijklmnopqrstuvwxyz~"
    char_dict = {char: idx for idx, char in enumerate(alphabet)}
    char_dict["P"] = len(alphabet)  # == 59
    # Reserved but currently disabled: "SOS" = 60, "EOS" = 61.
    return char_dict


class Constants:
    """Static dataset constants: vocabulary, landmark selections, feature sizes.

    Landmark numbers below index the ``ROWS_PER_FRAME`` rows of a raw frame.
    ``POINT_LANDMARKS`` is the ordered subset of rows the model uses.
    ``LANDMARK_INDICES`` maps each named part to positions *within*
    ``POINT_LANDMARKS``.
    """

    ROWS_PER_FRAME = 543
    MAX_STRING_LEN = 50
    INPUT_PAD = -100.0

    # Vocabulary (char -> id) and its inverse; "P" is the padding id.
    char_dict = get_char_dict()
    LABEL_PAD = char_dict["P"]
    inv_dict = {v: k for k, v in char_dict.items()}

    NOSE = [1, 2, 98, 327]

    REYE = [33, 7, 163, 144, 145, 153, 154, 155, 133, 246, 161, 160, 159, 158, 157, 173]
    LEYE = [263, 249, 390, 373, 374, 380, 381, 382, 362, 466, 388, 387, 386, 385, 384, 398]

    LHAND = list(range(468, 489))
    RHAND = list(range(522, 543))

    LNOSE = [98]
    RNOSE = [327]

    LLIP = [84, 181, 91, 146, 61, 185, 40, 39, 37, 87, 178, 88, 95, 78, 191, 80, 81, 82]
    RLIP = [314, 405, 321, 375, 291, 409, 270, 269, 267, 317, 402, 318, 324, 308, 415, 310, 311, 312]

    POSE = [500, 502, 504, 501, 503, 505, 512, 513]
    LPOSE = [513, 505, 503, 501]
    RPOSE = [512, 504, 502, 500]

    POINT_LANDMARKS_PARTS = [LHAND, RHAND, LLIP, RLIP, LPOSE, RPOSE, NOSE, REYE, LEYE]
    POINT_LANDMARKS = [item for sublist in POINT_LANDMARKS_PARTS for item in sublist]
    parts = {
        "LLIP": LLIP,
        "RLIP": RLIP,
        "LHAND": LHAND,
        "RHAND": RHAND,
        "LPOSE": LPOSE,
        "RPOSE": RPOSE,
        "LNOSE": LNOSE,
        "RNOSE": RNOSE,
        "REYE": REYE,
        "LEYE": LEYE,
    }

    # Position of each part's landmarks inside POINT_LANDMARKS. Plain loops are
    # used because a comprehension in a class body cannot see class attributes
    # such as POINT_LANDMARKS (only its outermost iterable is evaluated in
    # class scope).
    LANDMARK_INDICES = {}  # type: ignore
    for _part, _part_landmarks in parts.items():
        LANDMARK_INDICES[_part] = []
        for _landmark in _part_landmarks:
            if _landmark in POINT_LANDMARKS:
                LANDMARK_INDICES[_part].append(POINT_LANDMARKS.index(_landmark))
    # Loop variables in a class body become class attributes. Delete them so
    # they do not leak into the Constants namespace.
    del _part, _part_landmarks, _landmark

    CENTER_LANDMARKS = LNOSE + RNOSE
    CENTER_INDICES = LANDMARK_INDICES["LNOSE"] + LANDMARK_INDICES["RNOSE"]

    NUM_NODES = len(POINT_LANDMARKS)
    NUM_INPUT_FEATURES = 2 * NUM_NODES  # (x, y) per point
    CHANNELS = 6 * NUM_NODES  # (x, y, dx, dy, dx2, dy2) per point


# Utils

In [35]:

# Seed all random number generators
def seed_everything(seed=42):
    """Seed ``PYTHONHASHSEED``, ``random``, NumPy and TensorFlow with ``seed``.

    NOTE: setting ``PYTHONHASHSEED`` at runtime only affects child processes.
    The current interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


def selected_columns(file_example):
    """Return interleaved x/y column names for ``Constants.POINT_LANDMARKS``.

    The parquet file ``file_example`` is read only for its column names. The
    x column of landmark ``k`` is taken to be column ``k + 1``. This
    presumably skips a leading non-landmark column (TODO confirm against the
    data). Its y column name is obtained by replacing every "x" in the name
    with "y".

    Returns:
        list of column names ordered x1, y1, x2, y2, ...
    """
    columns = pd.read_parquet(file_example).columns
    x_names = columns[[landmark + 1 for landmark in Constants.POINT_LANDMARKS]].tolist()
    interleaved = []
    for x_name in x_names:
        interleaved.extend((x_name, x_name.replace("x", "y")))
    return interleaved



def num_to_char_fn(y):
    """Map a sequence of class ids to characters; unknown ids map to ""."""
    lookup = Constants.inv_dict
    return list(map(lambda code: lookup.get(code, ""), y))


# A callback class to output a few transcriptions during training
class CallbackEval(tf.keras.callbacks.Callback):
    """Print a few target / prediction pairs from ``dataset`` after every epoch.

    Args:
        model: model whose (first) output is greedy-decoded.
        dataset: iterable of ``(features, labels)`` batches.
        num_examples: how many pairs to print (fewer if the dataset is smaller).
    """

    def __init__(self, model, dataset, num_examples=10):
        super().__init__()
        self.dataset = dataset
        self.model = model
        self.num_examples = num_examples

    def on_epoch_end(self, epoch: int, logs=None):
        predictions = []
        targets = []
        for X, y in self.dataset:
            batch_predictions = self.model(X)
            # Models from build_model1 return [final_output, internal_output];
            # only the final output is decoded.
            if isinstance(batch_predictions, (list, tuple)):
                batch_predictions = batch_predictions[0]
            predictions.extend(decode_batch_predictions(batch_predictions))
            for label in y:
                targets.append("".join(num_to_char_fn(label.numpy())))
        print("-" * 100)
        # Guard against datasets with fewer than num_examples samples.
        for i in range(min(self.num_examples, len(predictions))):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}, len: {len(predictions[i])}")
            print("-" * 100)


def decode_phrase(pred):
    """Greedy CTC decoding of a single sequence.

    Args:
        pred: per-frame class scores of shape (T, num_classes); the rank is
            fixed by the argmax over axis 1.

    Returns:
        1-D int64 tensor of class ids. It is the argmax path with consecutive
        repeats collapsed and blank (``Constants.LABEL_PAD``) ids removed.
    """
    best_path = tf.argmax(pred, axis=1)  # (T,)
    # Append one blank so the final run is always terminated. The old code
    # padded with 0, which is the " " class, so a trailing run of spaces was
    # silently dropped. A trailing blank run is removed by the mask anyway.
    blank = tf.constant(Constants.LABEL_PAD, dtype=best_path.dtype)
    padded = tf.concat([best_path, blank[None]], axis=0)
    # Keep the last element of every run of identical ids.
    run_ends = tf.where(tf.not_equal(padded[:-1], padded[1:]))[:, 0]
    collapsed = tf.gather(padded, run_ends)
    return tf.boolean_mask(collapsed, tf.not_equal(collapsed, blank), axis=0)


# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    """Greedy-decode every sequence in ``pred`` and return the decoded strings."""
    return ["".join(num_to_char_fn(decode_phrase(sequence).numpy())) for sequence in pred]




def code_to_label(label_code):
    """Convert a sequence of class ids to a string, dropping padding.

    Padding ids (``Constants.LABEL_PAD``, the "P" token) are skipped. Ids
    outside the vocabulary map to "", consistent with ``num_to_char_fn``,
    instead of raising ``KeyError``.
    """
    return "".join(
        Constants.inv_dict.get(code, "")
        for code in label_code
        if code != Constants.LABEL_PAD
    )


def convert_to_strings(batch_label_code):
    """Apply ``code_to_label`` to every label sequence of a batch."""
    return list(map(code_to_label, batch_label_code))


def global_metric(val_ds, model):
    """Corpus-level ``LevDistanceMetric`` score of ``model`` on ``val_ds``.

    Every ``(feature, label)`` batch updates a single metric, so N and D are
    accumulated over the whole dataset before the ratio is taken. The running
    batch count is printed as progress.

    Returns:
        the metric's ``result()`` as a NumPy scalar.
    """
    metric = LevDistanceMetric()
    for count, (feature, label) in enumerate(val_ds, start=1):
        print(count)
        logits = model(feature)
        # Models from build_model1 return [final_output, internal_output]; score the final one.
        if isinstance(logits, (list, tuple)):
            logits = logits[0]
        # The metric runs batch_edit_distance itself. The old extra call
        # decoded every batch twice and threw the result away.
        metric.update_state(label, logits)
    return metric.result().numpy()


def sparse_from_dense_ignore_value(dense_tensor):
    """Convert a padded dense label tensor to a ``tf.SparseTensor``.

    Entries equal to ``Constants.LABEL_PAD`` are omitted. The dense shape
    (as int64) is kept so the result lines up with the original layout.
    """
    keep = tf.not_equal(dense_tensor, Constants.LABEL_PAD)
    return tf.SparseTensor(
        indices=tf.where(keep),
        values=tf.boolean_mask(dense_tensor, keep),
        dense_shape=tf.shape(dense_tensor, out_type=tf.int64),
    )


def batch_edit_distance(y_true, y_logits):
    """Edit-distance statistics between greedy CTC decodes and padded labels.

    Args:
        y_true: (B, L) label ids padded with ``Constants.LABEL_PAD``.
        y_logits: (B, T, C) per-frame scores (transposed to time-major for
            the decoder). ``Constants.LABEL_PAD`` is the blank class.

    Returns:
        ``(score, N, D)``. N is the number of non-pad label tokens and D is
        the summed, unnormalized edit distance. score is
        ``clip((N - D) / N, 0, 1)``, or 0 when N == 0 (previously NaN).
    """
    blank = Constants.LABEL_PAD

    batch_size = tf.shape(y_logits)[0]
    num_frames = tf.shape(y_logits)[1]
    # ctc_greedy_decoder expects time-major input: (T, B, C).
    time_major = tf.transpose(y_logits, perm=[1, 0, 2])
    sequence_length = tf.fill(dims=[batch_size], value=num_frames)
    decoded, _ = tf.nn.ctc_greedy_decoder(
        tf.cast(time_major, tf.float32), sequence_length, blank_index=blank
    )
    hypothesis = decoded[0]  # one SparseTensor for the whole batch

    truth = tf.cast(sparse_from_dense_ignore_value(y_true), hypothesis.dtype)
    edit_dist = tf.edit_distance(hypothesis, truth, normalize=False)

    N = tf.reduce_sum(tf.cast(tf.not_equal(y_true, blank), tf.float32))
    D = tf.reduce_sum(edit_dist)
    # divide_no_nan: a batch without label tokens yields 0 instead of NaN.
    result = tf.clip_by_value(tf.math.divide_no_nan(N - D, N), 0.0, 1.0)
    return result, N, D


class LevDistanceMetric(tf.keras.metrics.Metric):
    """Streaming normalized Levenshtein score, ``clip((N - D) / N, 0, 1)``.

    ``count`` accumulates N, the number of non-pad label tokens, and
    ``distance`` accumulates D, the summed edit distance. Both come from
    ``batch_edit_distance``. ``sample_weight`` is accepted for API
    compatibility but ignored.
    """

    def __init__(self, name="Lev", **kwargs):
        super().__init__(name=name, **kwargs)
        self.distance = self.add_weight(name="dist", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_logits, sample_weight=None):
        # y_logits: per-frame scores. The greedy decoder only uses their argmax.
        _, N, D = batch_edit_distance(y_true, y_logits)
        self.distance.assign_add(D)
        self.count.assign_add(N)

    def result(self):
        # divide_no_nan: report 0 rather than NaN before any tokens were seen.
        score = tf.math.divide_no_nan(self.count - self.distance, self.count)
        return tf.clip_by_value(score, 0.0, 1.0)

    def reset_state(self):
        self.count.assign(0.0)
        self.distance.assign(0.0)



class SWA(tf.keras.callbacks.Callback):
    """Stochastic Weight Averaging callback.

    At the end of every epoch whose 0-based index is in ``swa_epochs``, the
    model weights are added to a running sum. At the end of training:

    1. The average of the collected snapshots is loaded into the model.
    2. If ``train_ds`` is given, ``train_steps`` batches are run in training
       mode to recompute running mean/var.
    3. The weights are saved to ``{save_name}-SWA.h5``.
    4. If ``valid_ds`` is given, the model is evaluated on it.
    """

    def __init__(
        self,
        save_name,
        swa_epochs=None,
        strategy=None,
        train_ds=None,
        valid_ds=None,
        train_steps=1000,
    ):
        super().__init__()
        # Copy so a caller-owned list (or a shared default) is never aliased.
        self.swa_epochs = [] if swa_epochs is None else list(swa_epochs)
        self.swa_weights = None
        # Number of snapshots actually summed into swa_weights.
        self.swa_count = 0
        self.save_name = save_name
        self.train_ds = train_ds
        self.valid_ds = valid_ds
        self.train_steps = train_steps
        self.strategy = strategy

    def train_step(self, iterator):
        """Run training-mode forward passes over ``iterator`` (no weight update)."""

        def step_fn(inputs):
            """The computation to run on each device."""
            x, _ = inputs
            self.model(x, training=True)

        for batch in iterator:
            if self.strategy is None:
                step_fn(batch)
            else:
                self.strategy.run(step_fn, args=(batch,))

    def on_epoch_end(self, epoch, logs=None):
        if epoch not in self.swa_epochs:
            return
        weights = self.model.get_weights()
        if self.swa_weights is None:
            self.swa_weights = weights
        else:
            for i, w in enumerate(weights):
                self.swa_weights[i] += w
        self.swa_count += 1

    def on_train_end(self, logs=None):
        if not self.swa_count:
            # Nothing collected: keep the final weights instead of crashing.
            if self.swa_epochs:
                print("SWA: no listed epoch was reached; keeping final weights.")
            return
        print("applying SWA...")
        # Average over the snapshots actually collected. Dividing by
        # len(swa_epochs) shrinks every weight whenever a listed epoch never
        # ran (early stopping, or an index >= the number of epochs).
        self.swa_weights = [w / self.swa_count for w in self.swa_weights]
        self.model.set_weights(self.swa_weights)
        if self.train_ds is not None:  # for the re-calculation of running mean and var
            self.train_step(self.train_ds.take(self.train_steps))
        print(f"save SWA weights to {self.save_name}-SWA.h5")
        self.model.save_weights(f"{self.save_name}-SWA.h5")
        if self.valid_ds is not None:
            self.model.evaluate(self.valid_ds)


class AWP(tf.keras.Model):
    """Keras Model with optional Adversarial Weight Perturbation (AWP) training.

    While ``self._train_counter < start_step``, the regular
    ``tf.keras.Model.train_step`` is used; after that, ``train_step_awp`` is
    used.

    Args (besides those forwarded to ``tf.keras.Model``):
        delta: scale of the per-variable weight perturbation.
        eps: added to each gradient norm in the perturbation denominator.
        start_step: training-step count from which AWP is enabled.
    """
    # Adversarial Weight Perturbation
    def __init__(self, *args, delta=0.1, eps=1e-4, start_step=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.delta = delta
        self.eps = eps
        self.start_step = start_step

    def train_step_awp(self, data):
        """One AWP step.

        1. Compute gradients of the loss at the current weights.
        2. Move every trainable variable by
           ``delta * g / (||g||_2 + eps)``, a per-variable normalized
           gradient step (absolute, not scaled by the weight norm).
        3. Compute gradients of the loss at the perturbed weights.
        4. Undo the perturbation (same delta recomputed from the step-1
           gradients; exact only up to floating-point rounding).
        5. Apply the step-3 gradients with the optimizer.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Step 1: gradients at the unperturbed weights.
        # NOTE(review): this loss is not loss-scaled, unlike new_loss below.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        params = self.trainable_variables
        params_gradients = tape.gradient(loss, self.trainable_variables)
        # Step 2: perturb the weights.
        # NOTE(review): a None gradient (variable not connected to the loss)
        # would make the addition below fail.
        for i in range(len(params_gradients)):
            grad = tf.zeros_like(params[i]) + params_gradients[i]
            delta = tf.math.divide_no_nan(
                self.delta * grad, tf.math.sqrt(tf.reduce_sum(grad**2)) + self.eps
            )
            self.trainable_variables[i].assign_add(delta)
        # Step 3: gradients at the perturbed weights (loss-scaled for mixed precision).
        with tf.GradientTape() as tape2:
            y_pred = self(x, training=True)
            new_loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            if hasattr(self.optimizer, "get_scaled_loss"):
                new_loss = self.optimizer.get_scaled_loss(new_loss)

        gradients = tape2.gradient(new_loss, self.trainable_variables)
        if hasattr(self.optimizer, "get_unscaled_gradients"):
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        # Step 4: restore the weights by subtracting the same delta.
        for i in range(len(params_gradients)):
            grad = tf.zeros_like(params[i]) + params_gradients[i]
            delta = tf.math.divide_no_nan(
                self.delta * grad, tf.math.sqrt(tf.reduce_sum(grad**2)) + self.eps
            )
            self.trainable_variables[i].assign_sub(delta)
        # Step 5: update with the adversarial gradients.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # self_loss.update_state(loss)
        # NOTE(review): metrics use y_pred from the *perturbed* forward pass.
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        # self._train_counter is presumably Keras' private per-model training
        # step counter (not defined in this class) - TODO confirm for the TF version used.
        return tf.cond(
            self._train_counter < self.start_step,
            lambda: super(AWP, self).train_step(data),
            lambda: self.train_step_awp(data),
        )


# Lev Callback

In [36]:
import Levenshtein as lev
import json
with open("/kaggle/input/asl-fingerspelling/character_to_prediction_index.json", "r") as f:
    character_map = json.load(f)  # character -> prediction index
# Inverse mapping used for decoding: prediction index -> character.
rev_character_map = dict(zip(character_map.values(), character_map.keys()))
class val_lev_callback(tf.keras.callbacks.Callback):
    """Run ``calculate_val_lev`` on ``val_set`` at the end of every epoch."""

    def __init__(self, val_set):
        super().__init__()
        self.val_set = val_set

    def on_epoch_end(self, epoch: int, logs=None):
        model, data = self.model, self.val_set
        calculate_val_lev(model, data)
        
def calculate_val_lev(model, val_set):
    """Print and return the corpus-level normalized Levenshtein score.

    ``val_set`` is indexed as ``val_set[i] -> (features, labels, ...)`` for
    ``i in range(len(val_set))``, e.g. a ``keras.utils.Sequence``. The first
    model output is greedy-decoded with ``decode_phrase``. Predictions and
    targets are both mapped to text through ``rev_character_map``; ids
    missing from it (e.g. padding) become "".

    Returns:
        ``(N - D) / N``, where N is the total target length and D is the
        summed Levenshtein distance. Returns 0.0 when there are no target
        characters.
    """
    preds = []
    targets = []
    for batch_idx in range(len(val_set)):
        # Fetch the batch once. Indexing twice loaded it twice and could pair
        # features and labels from different loads.
        batch = val_set[batch_idx]
        features, labels = batch[0], batch[1]
        preds_batch = model.predict(features, verbose=0)[0]
        for pred, label in zip(preds_batch, labels):
            preds.append("".join(rev_character_map.get(s, "") for s in decode_phrase(pred).numpy()))
            targets.append("".join(rev_character_map.get(s, "") for s in np.asarray(label)))

    N = sum(len(phrase) for phrase in targets)
    D = sum(lev.distance(p, t) for p, t in zip(preds, targets))
    lev_d = (N - D) / N if N else 0.0
    print("")
    print('Lev distance: ', lev_d)
    return lev_d

# Model

In [37]:

class CTCLossWrap(tf.keras.losses.Loss):
    """Keras loss adapter around ``classic_ctc_loss``.

    Labels are padded with ``pad_token_idx``, which is also the CTC blank
    index. Every logit sequence is treated as full length (T = axis 1).
    ``batch_size``, ``max_string_len``, ``output_dim``, ``output_steps`` and
    ``replicas`` are stored but not used by ``call``.
    """

    def __init__(self, pad_token_idx,batch_size,max_string_len,output_dim,output_steps,replicas):
        self.pad_token_idx = pad_token_idx
        self.batch_size = batch_size
        self.max_string_len = max_string_len
        self.output_steps = output_steps
        self.output_dim = output_dim
        self.replicas = replicas
        super().__init__()

    def call(self, labels, logits):
        # Label length = number of non-pad ids in each row.
        label_length = tf.math.count_nonzero(
            tf.not_equal(labels, self.pad_token_idx), axis=-1, dtype=tf.int32
        )
        num_sequences = tf.shape(logits)[0]
        logit_length = tf.fill([num_sequences], tf.shape(logits)[1])
        return classic_ctc_loss(
            labels=labels,
            logits=logits,
            label_length=label_length,
            logit_length=logit_length,
            blank_index=self.pad_token_idx,
        )


class ECA(tf.keras.layers.Layer):
    """Efficient Channel Attention.

    Scales each channel of a (B, T, C) input by a sigmoid gate. The gate is a
    bias-free 1-D convolution with ``kernel_size`` taps, run across the
    channel axis of the mask-aware temporal average of the input.
    """

    def __init__(self, kernel_size=5, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv1D(
            1, kernel_size=kernel_size, strides=1, padding="same", use_bias=False
        )
        # Weightless pooling layer, built once instead of on every call.
        self.pool = tf.keras.layers.GlobalAveragePooling1D()

    def call(self, inputs, mask=None):
        nn = self.pool(inputs, mask=mask)  # (B, C)
        nn = tf.expand_dims(nn, -1)  # (B, C, 1): channels act as the conv's sequence axis
        nn = self.conv(nn)
        nn = tf.squeeze(nn, -1)
        nn = tf.nn.sigmoid(nn)
        nn = nn[:, None, :]  # (B, 1, C) broadcasts over time
        return inputs * nn

    def get_config(self):
        # Required to re-create the layer from its config (model.save / from_config).
        config = super().get_config()
        config.update({"kernel_size": self.kernel_size})
        return config


class LateDropout(tf.keras.layers.Layer):
    """Dropout that is switched on only after ``start_step`` training calls.

    A non-trainable int32 counter (aggregation ONLY_FIRST_REPLICA) is
    incremented on every call with a truthy ``training`` argument. While it
    is below ``start_step``, inputs pass through unchanged. Afterwards
    ``Dropout(rate, noise_shape)`` is applied.
    """

    def __init__(self, rate, noise_shape=None, start_step=0, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.rate = rate
        self.noise_shape = noise_shape
        self.start_step = start_step
        self.dropout = tf.keras.layers.Dropout(rate, noise_shape=noise_shape)

    def build(self, input_shape):
        super().build(input_shape)
        agg = tf.VariableAggregation.ONLY_FIRST_REPLICA
        self._train_counter = tf.Variable(0, dtype="int32", aggregation=agg, trainable=False)

    def call(self, inputs, training=False):
        x = tf.cond(
            self._train_counter < self.start_step,
            lambda: inputs,
            lambda: self.dropout(inputs, training=training),
        )
        # Counts training calls; equals optimizer steps when called once per step.
        if training:
            self._train_counter.assign_add(1)
        return x

    def get_config(self):
        # Required to re-create the layer from its config (model.save / from_config).
        config = super().get_config()
        config.update(
            {"rate": self.rate, "noise_shape": self.noise_shape, "start_step": self.start_step}
        )
        return config


class CausalDWConv1D(tf.keras.layers.Layer):
    """Causal depthwise 1-D convolution.

    The time axis is left-padded with ``dilation_rate * (kernel_size - 1)``
    zeros before a ``valid`` DepthwiseConv1D (stride 1). As a result, output
    step t depends only on input steps <= t, and the length is preserved.
    """

    def __init__(
        self,
        kernel_size=17,
        dilation_rate=1,
        use_bias=False,
        depthwise_initializer="glorot_uniform",
        name="",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        left_pad = dilation_rate * (kernel_size - 1)
        self.causal_pad = tf.keras.layers.ZeroPadding1D((left_pad, 0), name=name + "_pad")
        self.dw_conv = tf.keras.layers.DepthwiseConv1D(
            kernel_size,
            strides=1,
            dilation_rate=dilation_rate,
            padding="valid",
            use_bias=use_bias,
            depthwise_initializer=depthwise_initializer,
            name=name + "_dwconv",
        )
        self.supports_masking = True

    def call(self, inputs):
        padded = self.causal_pad(inputs)
        return self.dw_conv(padded)
    


def Conv1DBlock(
    channel_size,
    kernel_size,
    dilation_rate=1,
    drop_rate=0.0,
    expand_ratio=2,
    # se_ratio=0.25,
    activation="swish",
    name=None,
):
    """
    efficient conv1d block, @hoyso48

    Returns a function ``apply(inputs)`` that builds, in order:
    Dense expansion to ``channels_in * expand_ratio`` (``activation``) ->
    CausalDWConv1D(kernel_size, dilation_rate) -> BatchNormalization -> ECA ->
    Dense projection to ``channel_size`` -> Dropout, if ``drop_rate > 0`` ->
    residual add with the input when ``channels_in == channel_size``.

    The Dropout uses ``noise_shape=(None, 1, 1)``, so one keep/drop decision
    is made per sample (the mask broadcasts over time and channels).

    Args:
        channel_size: output channels.
        kernel_size: depthwise kernel size.
        dilation_rate: dilation of the depthwise convolution.
        drop_rate: dropout rate; no Dropout layer is added when it is 0.
        expand_ratio: channel multiplier of the expansion Dense layer.
        activation: activation of the expansion Dense layer.
        name: prefix for the sub-layer names. When None, a unique id from
            ``get_uid("mbblock")`` is used. Calling the returned ``apply``
            twice would reuse the same names.
    """
    if name is None:
        name = str(tf.keras.backend.get_uid("mbblock"))

    def apply(inputs):
        channels_in = tf.keras.backend.int_shape(inputs)[-1]
        channels_expand = channels_in * expand_ratio

        skip = inputs

        # Expansion phase
        x = tf.keras.layers.Dense(
            channels_expand, use_bias=True, activation=activation, name=name + "_expand_conv"
        )(inputs)

        # Depthwise Convolution
        x = CausalDWConv1D(
            kernel_size, dilation_rate=dilation_rate, use_bias=False, name=name + "_dwconv"
        )(x)

        #x = tf.keras.layers.LayerNormalization(name=name + "_bn")(x)
        x = tf.keras.layers.BatchNormalization(name=name + "_bn")(x)

        x = ECA()(x)  # efficient channel attention

        x = tf.keras.layers.Dense(channel_size, use_bias=True, name=name + "_project_conv")(x)

        if drop_rate > 0:
            x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "_drop")(x)

        # Residual connection only when the channel count is unchanged.
        if channels_in == channel_size:
            x = tf.keras.layers.add([x, skip], name=name + "_add")
        return x

    return apply


class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention over a (B, T, dim) sequence.

    A single bias-free Dense produces q, k and v for all heads. The optional
    (B, T) mask is applied to the *keys* through the masked Softmax. The
    output goes through a bias-free Dense projection back to ``dim``.
    Assumes ``dim`` is divisible by ``num_heads`` (see ``tf.split`` sizes).
    """

    def __init__(self, dim=256, num_heads=4, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        # NOTE(review): scaled by the full model dim, not the conventional
        # per-head size (dim // num_heads) ** -0.5. Changing it would alter
        # the behavior of trained weights.
        self.scale = self.dim**-0.5
        self.num_heads = num_heads
        self.qkv = tf.keras.layers.Dense(3 * dim, use_bias=False)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.proj = tf.keras.layers.Dense(dim, use_bias=False)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        qkv = self.qkv(inputs)  # (B, T, 3 * dim)
        # (B, T, 3*dim) -> (B, T, H, 3*dim/H) -> (B, H, T, 3*dim/H). Each head
        # gets its own contiguous [q | k | v] slice.
        # NOTE: these weightless Keras layers are created on every call.
        qkv = tf.keras.layers.Permute((2, 1, 3))(
            tf.keras.layers.Reshape((-1, self.num_heads, self.dim * 3 // self.num_heads))(qkv)
        )
        q, k, v = tf.split(qkv, [self.dim // self.num_heads] * 3, axis=-1)

        attn = tf.matmul(q, k, transpose_b=True) * self.scale  # (B, H, T, T)

        if mask is not None:
            # (B, T) -> (B, 1, 1, T): broadcast over heads and query positions.
            mask = mask[:, None, None, :]

        attn = tf.keras.layers.Softmax(axis=-1)(attn, mask=mask)
        attn = self.drop1(attn)

        x = attn @ v  # (B, H, T, dim/H)
        # Merge heads back: (B, H, T, dim/H) -> (B, T, H, dim/H) -> (B, T, dim).
        x = tf.keras.layers.Reshape((-1, self.dim))(tf.keras.layers.Permute((2, 1, 3))(x))
        x = self.proj(x)
        return x


def TransformerBlock(
    dim=256, num_heads=4, expand=4, attn_dropout=0.2, drop_rate=0.2, activation="swish"
):
    """Return ``apply(inputs)`` building a pre-norm Transformer block.

    The block is LayerNorm -> MultiHeadSelfAttention -> Dropout -> residual
    add, followed by LayerNorm -> Dense(dim * expand, ``activation``) ->
    Dense(dim) -> Dropout -> residual add. Both Dropouts use
    ``noise_shape=(None, 1, 1)``, i.e. one decision per sample. The residual
    adds require the input's last dimension to equal ``dim``.
    """
    def apply(inputs):
        x = inputs
        x = tf.keras.layers.LayerNormalization()(x)
        x = MultiHeadSelfAttention(dim=dim, num_heads=num_heads, dropout=attn_dropout)(x)
        x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1))(x)
        x = tf.keras.layers.Add()([inputs, x])
        attn_out = x

        # Feed-forward sub-block.
        x = tf.keras.layers.LayerNormalization()(x)
        x = tf.keras.layers.Dense(dim * expand, use_bias=False, activation=activation)(x)
        x = tf.keras.layers.Dense(dim, use_bias=False)(x)
        x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1))(x)
        x = tf.keras.layers.Add()([attn_out, x])
        return x

    return apply

def build_model1(
    output_dim,
    max_len=64,
    dropout_step=0,
    dim=192,
    input_pad=-100,
    with_transformer=False,
    drop_rate=0.2,
):
    """Build the CTC recognition model.

    Architecture, as coded below:
      Masking(input_pad) -> Dense stem + BN -> 2 x (3 Conv1DBlocks, plus a
      TransformerBlock if ``with_transformer``) -> AvgPool1D(2, 2), which
      halves time -> BiLSTM -> BN -> Dense(output_dim), the intermediate
      ("internal") head. Dense(dim) of the intermediate softmax is added back
      to the features, then BN. When ``dim == 384`` another 2 x (3
      Conv1DBlocks [+ TransformerBlock]) follow. Then BiLSTM ->
      LateDropout(0.8, starting at ``dropout_step``) -> Dense(output_dim) ->
      log_softmax ("final").

    Args:
        output_dim: number of output classes (including the blank).
        max_len: number of input frames.
        dropout_step: training calls before LateDropout becomes active.
        dim: model width; 384 enables the extra conv stack.
        input_pad: value marking padded frames for the Masking layer.
        with_transformer: insert TransformerBlocks after conv groups.
        drop_rate: dropout rate inside Conv1DBlocks.

    Returns:
        ``tf.keras.Model`` with outputs ``[final, internal]``. Both are
        float32 log-probabilities of shape (B, max_len // 2, output_dim).
    """
    inp = tf.keras.Input(shape=(max_len, Constants.CHANNELS), dtype=tf.float32, name="inputs")
    # Frames whose features all equal input_pad are masked out.
    x = tf.keras.layers.Masking(mask_value=input_pad, input_shape=(max_len, Constants.CHANNELS))(
        inp
    )
    ksize = 17
    x = tf.keras.layers.Dense(dim, use_bias=False, name="stem_conv")(x)
    #x = tf.keras.layers.LayerNormalization(name="stem_bn")(x)
    x = tf.keras.layers.BatchNormalization(name="stem_bn")(x)

    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    if with_transformer:
        x = TransformerBlock(dim, expand=2)(x)

    #x = tf.keras.layers.AvgPool1D(2, 2)(x)


    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
    if with_transformer:
        x = TransformerBlock(dim, expand=2)(x)

    x = tf.keras.layers.AvgPool1D(2, 2)(x) # [B, max_len // 2, dim]

    lstm1 = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(units=dim//2), return_sequences=True)
    #lstm1=tf.keras.layers.LSTM(units=dim//2,return_sequences=True)
    x2 = tf.keras.layers.Bidirectional(lstm1)(x) #[B, T/2, 2 * (dim // 2)]

    # Intermediate prediction head ("internal" output).
    x2=tf.keras.layers.BatchNormalization()(x2)
    x2=tf.keras.layers.Dense(output_dim)(x2)
    soft=tf.keras.layers.Activation('softmax', dtype='float32')(x2)
    logsoft=tf.keras.layers.Activation('log_softmax',dtype='float32',name="internal")(x2)

    # Feed the intermediate class distribution back into the features
    # (resembles self-conditioned CTC).
    x=tf.keras.layers.Dense(dim)(soft)+x
    x=tf.keras.layers.BatchNormalization()(x)
    if dim == 384:  # for the 4x sized model
        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        if with_transformer:
            x = TransformerBlock(dim, expand=2)(x)

        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        x = Conv1DBlock(dim, ksize, drop_rate=drop_rate)(x)
        if with_transformer:
            x = TransformerBlock(dim, expand=2)(x)


    lstm2 = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(units=dim//2), return_sequences=True)
    #lstm2=tf.keras.layers.LSTM(units=dim//2,return_sequences=True)
    x = tf.keras.layers.Bidirectional(lstm2)(x)

    x = LateDropout(0.8, start_step=dropout_step)(x)

    x=tf.keras.layers.Dense(output_dim)(x)
    output = tf.keras.layers.Activation("log_softmax",name="final",dtype="float32")(x)  # log-probabilities, not raw logits

    model = tf.keras.Model(inp, outputs=[output,logsoft])
    return model


def get_model(output_dim, max_len, dim, input_pad, dropout_step=0, drop_rate=0., with_transformer=False):
    """Build the CTC model with ``build_model1``.

    Args:
        output_dim: number of output classes (including the blank).
        max_len: number of input frames.
        dim: model width.
        input_pad: padding value masked by the input Masking layer.
        dropout_step: training calls before the late dropout switches on.
        drop_rate: dropout rate inside the conv blocks.
        with_transformer: also insert TransformerBlocks. Previously this
            could not be enabled through get_model.

    Returns:
        the ``tf.keras.Model`` built by ``build_model1``.
    """
    return build_model1(
        output_dim,
        max_len=max_len,
        input_pad=input_pad,
        dim=dim,
        dropout_step=dropout_step,
        drop_rate=drop_rate,
        with_transformer=with_transformer,
    )


# Config

In [38]:
from functools import lru_cache

@lru_cache(maxsize=None)
def get_strategy():
    """Pick a ``tf.distribute`` strategy for the available hardware (cached).

    Preference order:

    1. TPU, if ``TPUClusterResolver()`` succeeds.
    2. ``MirroredStrategy``, for more than one GPU.
    3. The default strategy (one GPU or CPU).

    Returns:
        ``(strategy, replicas, is_tpu)``, where ``replicas`` is
        ``strategy.num_replicas_in_sync``.
    """
    # Count GPUs from the devices listed here instead of the notebook-global
    # `gpus`, so the function does not depend on another cell's state.
    logical_gpus = tf.config.list_logical_devices("GPU")

    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print("Running on TPU ", tpu.master())
    except ValueError:
        tpu = None
    is_tpu = tpu is not None

    if is_tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        print("All devices: ", tf.config.list_logical_devices('TPU'))
        strategy = tf.distribute.TPUStrategy(tpu)
        #disable_eager_execution()  # LSTM layer can't use bfloat16 unless we do this.
    elif logical_gpus:
        ngpu = len(logical_gpus)
        print("Num GPUs Available: ", ngpu)
        if ngpu > 1:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()
    else:
        print("Runing on CPU")
        strategy = tf.distribute.get_strategy()

    replicas = strategy.num_replicas_in_sync
    print(f"get strategy replicas: {replicas}")
    return strategy, replicas, is_tpu



In [39]:

class CFG:
    """Training configuration, used as a namespace of class attributes."""

    # These 3 variables are updated dynamically later by calling update_config_with_strategy.
    strategy = None  # type: ignore
    replicas = 1
    is_tpu = False

    save_output = True
    input_path = "/kaggle/input"
    output_path = "/kaggle/working"

    seed = 42
    verbose = 1  # 0) silent 1) progress bar 2) one line per epoch

    # max number of frames
    max_len = 256

    lr = 4e-4   # 5e-4
    weight_decay = 1e-4  # 4e-4
    epochs = 300

    batch_size = 128

    snapshot_epochs = []  # type: ignore
    # Average over the last quarter of training. Keras passes 0-based epoch
    # indices (0 .. epochs - 1) to callbacks, so the range stops at `epochs`.
    # Including `epochs` itself added an index that never occurs.
    swa_epochs = list(range(3 * (epochs // 4), epochs))

    fp16 = True

    awp = False
    awp_lambda = 0.15
    awp_start_epoch = 15
    dropout_start_epoch = 15
    resume = 0

    dim = 384

    comment = f"model-{dim}-seed{seed}"
    output_dim = 60
    num_eval = 6




In [40]:

def update_config_with_strategy(config):
    """Fill ``config.strategy/replicas/is_tpu`` and scale lr and batch size.

    If a notebook-global ``tpu_strategy`` exists and is not None (presumably
    created by a separate TPU setup cell; TODO confirm), it is used. In that
    case, on Colab, the input path is redirected to GCS and the output path
    to Google Drive. Otherwise ``get_strategy()`` picks the strategy.

    NOTE: ``lr`` and ``batch_size`` are multiplied by the replica count on
    every call, so call this only once per config.

    Returns:
        the updated ``config``.
    """
    # Looked up via globals(): a fresh kernel without the TPU setup cell has
    # no `tpu_strategy` name, and the bare reference raised NameError.
    tpu_strategy = globals().get("tpu_strategy")
    if tpu_strategy is not None:
        strategy = tpu_strategy
        # Ask the strategy instead of assuming an 8-core TPU.
        replicas = tpu_strategy.num_replicas_in_sync
        is_tpu = True
        if IN_COLAB:
            config.input_path = config.input_path.replace("/kaggle", "gs://asl-bucket71")
            config.output_path = config.output_path.replace("/kaggle", "/content/drive/MyDrive/kaggle")
    else:
        strategy, replicas, is_tpu = get_strategy()
    print("Strategy", strategy)

    config.strategy = strategy
    config.replicas = replicas
    config.is_tpu = is_tpu
    config.lr = config.lr * replicas
    config.batch_size = config.batch_size * replicas
    return config

# Training

In [41]:

def count_data_items(dataset):
    """Return the number of elements of ``dataset``, counted by iterating once."""
    return sum(1 for _ in dataset)


def interp1d_(x, target_len):
    """Resize ``x`` along axis 0 to ``max(1, target_len)`` rows.

    ``tf.image.resize`` treats axis 0 as height and axis 1 as width, with
    default (bilinear) interpolation. Axis 1 keeps its size.
    """
    out_size = [tf.maximum(1, target_len), tf.shape(x)[1]]
    return tf.image.resize(x, out_size)


def tf_nan_mean(x, axis=0, keepdims=False):
    """Mean of ``x`` over ``axis``, ignoring NaN entries.

    Slices that are entirely NaN give 0 / 0 = NaN.
    """
    is_missing = tf.math.is_nan(x)
    zeros = tf.zeros_like(x)
    total = tf.reduce_sum(tf.where(is_missing, zeros, x), axis=axis, keepdims=keepdims)
    count = tf.reduce_sum(tf.where(is_missing, zeros, tf.ones_like(x)), axis=axis, keepdims=keepdims)
    return total / count


def tf_nan_std(x, center=None, axis=0, keepdims=False):
    """NaN-ignoring standard deviation of ``x`` around ``center``.

    ``center`` defaults to the NaN-ignoring mean over ``axis``, with dims
    kept so it broadcasts against ``x``.
    """
    mu = tf_nan_mean(x, axis=axis, keepdims=True) if center is None else center
    deviation = x - mu
    variance = tf_nan_mean(deviation * deviation, axis=axis, keepdims=keepdims)
    return tf.math.sqrt(variance)


def flip_lr(x):
    """Mirror landmarks horizontally and swap left/right body parts.

    ``x`` holds (x, y) pairs on its last axis, with points on axis 1. Every x
    becomes ``1 - x``, then each left part's points are exchanged with the
    matching right part: hands, lips, pose, eyes, nose.

    When axis 1 has ``Constants.ROWS_PER_FRAME`` entries, the raw landmark
    numbers are used as indices. Otherwise the positions from
    ``Constants.LANDMARK_INDICES`` are used (points already reduced to
    ``Constants.POINT_LANDMARKS``).
    """
    use_raw_rows = x.shape[1] == Constants.ROWS_PER_FRAME

    def part_indices(part_name):
        if use_raw_rows:
            return getattr(Constants, part_name)
        return Constants.LANDMARK_INDICES[part_name]

    swap_pairs = (
        ("LHAND", "RHAND"),
        ("LLIP", "RLIP"),
        ("LPOSE", "RPOSE"),
        ("LEYE", "REYE"),
        ("LNOSE", "RNOSE"),
    )

    xs, ys = tf.unstack(x, axis=-1)
    # Points first so the index lists address axis 0: (points, T, 2).
    points_first = tf.transpose(tf.stack([1 - xs, ys], -1), [1, 0, 2])
    for left_name, right_name in swap_pairs:
        left = part_indices(left_name)
        right = part_indices(right_name)
        left_points = tf.gather(points_first, left, axis=0)
        right_points = tf.gather(points_first, right, axis=0)
        points_first = tf.tensor_scatter_nd_update(points_first, tf.constant(left)[..., None], right_points)
        points_first = tf.tensor_scatter_nd_update(points_first, tf.constant(right)[..., None], left_points)
    return tf.transpose(points_first, [1, 0, 2])


def resample(x, rate=(0.8, 1.2)):
    """Randomly stretch or compress ``x`` along its first (time) axis.

    A factor is drawn uniformly from ``[rate[0], rate[1])``. The new length
    is ``factor * T`` cast to int32 (truncated), and the resizing is
    delegated to ``interp1d_``.
    """
    rate = tf.random.uniform((), rate[0], rate[1])
    length = tf.shape(x)[0]
    new_size = tf.cast(rate * tf.cast(length, tf.float32), tf.int32)
    new_x = interp1d_(x, new_size)
    return new_x


def spatial_random_affine(
    xyz,
    scale=(0.8, 1.2),
    shear=(-0.1, 0.1),
    shift=(-0.1, 0.1),
    degree=(-20, 20),
):
    """Apply a random 2-D affine transform to the first two coordinates.

    Each stage runs only if its range is not None, in this order:

    - rotation: rotate by a uniform angle in ``degree`` about (0.5, 0.5).
    - scale: multiply *all* coordinates by one uniform factor. This is about
      the origin, not the centre.
    - shear: draw one uniform value; it is applied to either the x or the y
      term (50/50), and the other is set to 0.
    - shift: add one uniform scalar to every coordinate; x and y get the
      same offset.

    Channels after the first two (``xyz[..., 2:]``) are kept through the
    rotation/shear stages; with 2-channel input that slice is empty.
    """
    center = tf.constant([0.5, 0.5])
    if degree is not None:
        xy = xyz[..., :2]
        z = xyz[..., 2:]
        xy -= center
        degree = tf.random.uniform((), *degree)
        radian = degree / 180 * np.pi
        c = tf.math.cos(radian)
        s = tf.math.sin(radian)
        # Row-vector convention: [x, y] @ [[c, s], [-s, c]] = [x c - y s, x s + y c].
        rotate_mat = tf.identity(
            [
                [c, s],
                [-s, c],
            ]
        )
        xy = xy @ rotate_mat
        xy = xy + center
        xyz = tf.concat([xy, z], axis=-1)

    if scale is not None:
        scale = tf.random.uniform((), *scale)
        xyz = scale * xyz

    if shear is not None:
        xy = xyz[..., :2]
        z = xyz[..., 2:]
        shear_x = shear_y = tf.random.uniform((), *shear)
        # Shear along only one axis per call.
        if tf.random.uniform(()) < 0.5:
            shear_x = 0.0
        else:
            shear_y = 0.0
        shear_mat = tf.identity([[1.0, shear_x], [shear_y, 1.0]])
        xy = xy @ shear_mat
        xyz = tf.concat([xy, z], axis=-1)

    if shift is not None:
        shift = tf.random.uniform((), *shift)
        xyz = xyz + shift

    return xyz


def temporal_mask(x, size=(1, 15), mask_value=float("nan")):
    """Overwrite a random contiguous span of frames with ``mask_value``.

    Args:
        x: tensor shaped (T, num_points, 2); the fill shape below fixes the
            trailing dimension to 2.
        size: ``(min, max)`` span length. ``max`` is exclusive, as in
            ``tf.random.uniform``. When ``max`` exceeds ``T // 8`` it is
            capped at ``max(T // 8, 2)``.
        mask_value: value written into the span (NaN by default).

    Returns:
        ``x`` with frames ``[offset, offset + span)`` replaced.
    """
    l0 = tf.shape(x)[0]
    # Compute the cap without mutating `size`. The old code assigned into the
    # (shared, mutable) default list, so a single short sequence permanently
    # lowered the maximum span for every later call.
    requested_max = tf.constant(size[1], tf.int32)
    cap = l0 // 8
    max_size = tf.where(requested_max > cap, tf.maximum(cap, 2), requested_max)
    mask_size = tf.random.uniform((), size[0], max_size, dtype=tf.int32)
    mask_offset = tf.random.uniform((), 0, tf.clip_by_value(l0 - mask_size, 1, l0), dtype=tf.int32)
    x = tf.tensor_scatter_nd_update(
        x,
        tf.range(mask_offset, mask_offset + mask_size)[..., None],
        tf.fill([mask_size, tf.shape(x)[1], 2], mask_value),
    )
    return x


def spatial_mask(x, size=(0.05, 0.2), mask_value=float("nan")):
    """Blank out every point that falls inside a random axis-aligned square.

    A corner (offset_x, offset_y) is drawn uniformly from [0, 1) and a side
    length from ``size``. Points whose x and y both lie strictly inside
    (offset, offset + side) get ``mask_value`` in all of their coordinates.
    """
    off_y = tf.random.uniform(())
    off_x = tf.random.uniform(())
    side = tf.random.uniform((), *size)
    px = x[..., 0]
    py = x[..., 1]
    inside_x = (px > off_x) & (px < off_x + side)
    inside_y = (py > off_y) & (py < off_y + side)
    return tf.where((inside_x & inside_y)[..., None], mask_value, x)



def augment_fn(x):
    """Randomly augment one flattened landmark sequence.

    ``x`` is (T, F) with coordinate pairs along F. It is viewed as
    (T, F // 2, 2), and each augmentation fires independently with its own
    probability: resample 0.8 (factor range (0.5, 1.5)), flip_lr 0.5,
    spatial_random_affine 0.75, temporal_mask 0.5, spatial_mask 0.5.
    The result is flattened back to rank 2; its length may differ from T
    because of resampling.
    """
    points = tf.reshape(x, (tf.shape(x)[0], -1, 2))
    if tf.random.uniform(()) < 0.8:
        points = resample(points, (0.5, 1.5))
    if tf.random.uniform(()) < 0.5:
        points = flip_lr(points)
    if tf.random.uniform(()) < 0.75:
        points = spatial_random_affine(points)
    if tf.random.uniform(()) < 0.5:
        points = temporal_mask(points)
    if tf.random.uniform(()) < 0.5:
        points = spatial_mask(points)
    return tf.reshape(points, (tf.shape(points)[0], -1))


In [42]:

class Preprocess(tf.keras.layers.Layer):
    """Build per-frame features from a batch of flattened landmark sequences.

    ``call`` expects rank 3, (B, T, 2 * Constants.NUM_NODES), with coordinates
    ordered x1, y1, x2, y2, ...; a single (T, F) sequence needs a leading batch
    axis first. The output, (B, T, 3 * 2 * Constants.NUM_NODES), concatenates
    [positions, x[t+1] - x[t], x[t+2] - x[t]] along the last axis. The
    differences are zero-padded at the end of the sequence (all zeros when the
    sequence is too short), and every NaN in the result is replaced by 0.

    Args:
        max_len: stored on the layer; ``call`` does not use it.
        normalize: if True, subtract the mean of the ``Constants.CENTER_INDICES``
            landmarks (via ``tf_nan_mean``; 0.5 where that mean is NaN) and divide
            by ``tf_nan_std`` of all landmarks around that centre.
    """

    def __init__(self, max_len, normalize=False, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.center = Constants.CENTER_INDICES
        self.normalize = normalize

    @staticmethod
    def _lagged_diff(points, lag):
        """Return points[:, t + lag] - points[:, t], zero-padded by ``lag`` frames at the end."""
        return tf.cond(
            tf.shape(points)[1] > lag,
            lambda: tf.pad(points[:, lag:] - points[:, :-lag], [[0, 0], [0, lag], [0, 0], [0, 0]]),
            lambda: tf.zeros_like(points),
        )

    def call(self, x):
        # Unflatten to (B, T, NUM_NODES, 2) so differences act per coordinate.
        batch = tf.shape(x)[0]
        frames = tf.shape(x)[1]
        points = tf.reshape(x, (batch, frames, Constants.NUM_NODES, 2))

        if self.normalize:
            anchors = tf.gather(points, self.center, axis=2)
            mean = tf_nan_mean(anchors, axis=[1, 2], keepdims=True)
            mean = tf.where(tf.math.is_nan(mean), tf.constant(0.5, points.dtype), mean)
            std = tf_nan_std(points, center=mean, axis=[1, 2], keepdims=True)
            points = (points - mean) / std

        diff1 = self._lagged_diff(points, 1)
        diff2 = self._lagged_diff(points, 2)

        length = tf.shape(points)[1]
        flat_width = 2 * Constants.NUM_NODES
        # Layout: x1,y1,x2,y2,... | dx1,dy1,... (lag 1) | lag-2 differences
        features = tf.concat(
            [tf.reshape(part, (-1, length, flat_width)) for part in (points, diff1, diff2)],
            axis=-1,
        )
        return tf.where(tf.math.is_nan(features), tf.constant(0.0, features.dtype), features)


def pad_if_short(x, max_len):
    """Append ``Constants.INPUT_PAD`` rows to a (T, F) sequence up to ``max_len`` rows.

    Sequences with at least ``max_len`` rows are returned unchanged. Without
    this clamp, a longer input gives a negative pad length and the padding
    tensor cannot be created.

    Args:
        x: (T, F) tensor.
        max_len: target number of rows.

    Returns:
        (max(T, max_len), F) tensor with the dtype of ``x``.
    """
    pad_len = tf.maximum(max_len - tf.shape(x)[0], 0)
    padding = tf.fill((pad_len, tf.shape(x)[1]), tf.cast(Constants.INPUT_PAD, x.dtype))
    return tf.concat([x, padding], axis=0)


def shrink_if_long(x, max_len):
    """Resize a (T, F) sequence along time to exactly ``max_len`` rows when T > max_len.

    Shorter sequences are returned unchanged. The resize treats the sequence as
    a one-channel (T, F, 1) image and uses ``tf.image.resize``'s default method.
    """
    if tf.shape(x)[0] > max_len:
        as_image = tf.expand_dims(x, axis=-1)  # tf.image.resize needs a channel axis
        resized = tf.image.resize(as_image, (max_len, tf.shape(x)[1]))
        x = resized[..., 0]
    return x

def preprocess(x, max_len, do_pad=True):
    """Turn one raw (T, F) landmark sequence into model features.

    Steps: shrink along time to at most ``max_len`` rows (``shrink_if_long``),
    run the ``Preprocess`` layer (positions plus lag-1 and lag-2 differences,
    NaNs zeroed), then optionally pad to ``max_len`` rows.

    Args:
        x: one sequence, (T, F).
        max_len: maximum (and, with ``do_pad``, exact) number of rows.
        do_pad: pad here; pass False when batching pads instead (``padded_batch``).

    Returns:
        float32 tensor with min(T, max_len) rows, or max_len rows when ``do_pad``.
    """
    x = shrink_if_long(x, max_len=max_len)
    # Pin the layer to float32. Under a mixed-precision global policy (train_run
    # sets one), a default Keras layer autocasts its floating inputs to float16,
    # which would quantize coordinates and their small frame-to-frame differences.
    layer = Preprocess(max_len=max_len, dtype="float32")
    # The layer expects a batch: add a leading axis, then take the single item back.
    x = tf.cast(layer(x[None, ...])[0], tf.float32)

    if do_pad:  # skippable when the dataset pipeline pads whole batches
        x = pad_if_short(x, max_len=max_len)
    return x

In [43]:

def decode_tfrec(record_bytes):
    """Parse one serialized tf.train.Example into ``(coords, label)``.

    Returns:
        coords: float32 tensor reshaped to (-1, Constants.NUM_INPUT_FEATURES).
        label: the int64 "label" feature densified and cast to int32.
    """
    feature_spec = {
        "coordinates": tf.io.VarLenFeature(tf.float32),
        "label": tf.io.VarLenFeature(tf.int64),
    }
    parsed = tf.io.parse_single_example(record_bytes, feature_spec)
    coords = tf.reshape(
        tf.sparse.to_dense(parsed["coordinates"]), (-1, Constants.NUM_INPUT_FEATURES)
    )
    label = tf.cast(tf.sparse.to_dense(parsed["label"]), dtype=tf.int32)
    return (coords, label)

def ensure_shapes(x, y, batch_size, max_len):
    """Set (and check at runtime) the static shapes of one padded batch.

    Args:
        x: input batch, expected (batch_size, max_len, Constants.CHANNELS).
        y: label batch, expected (batch_size, Constants.MAX_STRING_LEN).
        batch_size: static batch size, or None when the last batch may be smaller
            (e.g. ``padded_batch(..., drop_remainder=False)``).
        max_len: static number of frames in ``x``.

    Returns:
        ``(x, y)`` with the same values and the static shapes set;
        ``tf.ensure_shape`` raises at runtime if a shape does not match.
    """
    # No tf.print here: in a tf.data map it would log on every batch.
    x = tf.ensure_shape(x, (batch_size, max_len, Constants.CHANNELS))
    y = tf.ensure_shape(y, (batch_size, Constants.MAX_STRING_LEN))
    return x, y

In [44]:

def get_dataset(
    filenames,
    input_path,
    max_len,
    batch_size=64,
    drop_remainder=False,
    augment=False,
    shuffle_buffer=None,
    repeat=False,
    use_tfrecords=True,
):
    """Build a batched tf.data pipeline over GZIP-compressed TFRecord files.

    Pipeline: parallel record reads (order may vary) -> decode_tfrec ->
    optional augment_fn on the coordinates -> preprocess without per-example
    padding -> optional shuffle -> padded_batch to (max_len, Constants.CHANNELS)
    inputs and (Constants.MAX_STRING_LEN,) labels -> static shape check -> prefetch.

    Args:
        filenames: TFRecord file paths.
        input_path: currently unused.
        max_len: number of frames every example is padded to.
        batch_size: examples per batch.
        drop_remainder: drop a final partial batch; this also allows the batch
            dimension to be fixed statically.
        augment: apply ``augment_fn`` to the coordinates.
        shuffle_buffer: shuffle buffer size, or None to skip shuffling.
        repeat: currently ignored (no ``repeat`` in this pipeline).
        use_tfrecords: currently unused; TFRecords are always read.

    Returns:
        A ``tf.data.Dataset`` of ``(inputs, labels)`` batches.
    """
    # Allow out-of-order output for throughput. ``with_options`` returns a NEW
    # dataset, so its result must be kept or the option is silently lost.
    options = tf.data.Options()
    options.deterministic = False

    ds = tf.data.TFRecordDataset(
        filenames, num_parallel_reads=tf.data.AUTOTUNE, compression_type="GZIP"
    )
    ds = ds.with_options(options)
    ds = ds.map(decode_tfrec, tf.data.AUTOTUNE)

    if augment:
        ds = ds.map(lambda x, y: (augment_fn(x), y), tf.data.AUTOTUNE)

    # Per-example padding is skipped; padded_batch below pads whole batches.
    ds = ds.map(lambda x, y: (preprocess(x, max_len=max_len, do_pad=False), y), tf.data.AUTOTUNE)

    if shuffle_buffer is not None:
        ds = ds.shuffle(shuffle_buffer)

    ds = ds.padded_batch(
        batch_size,
        padding_values=(
            tf.constant(Constants.INPUT_PAD, dtype=tf.float32),
            tf.constant(Constants.LABEL_PAD, dtype=tf.int32),
        ),
        padded_shapes=([max_len, Constants.CHANNELS], [Constants.MAX_STRING_LEN]),
        drop_remainder=drop_remainder,
    )

    # Pin static shapes (checked at runtime). The result of map must be kept.
    # Only a drop_remainder pipeline has a fixed batch size; otherwise the last
    # batch may be smaller, so the batch dimension stays unknown.
    static_batch = batch_size if drop_remainder else None
    ds = ds.map(
        lambda x, y: (
            tf.ensure_shape(x, (static_batch, max_len, Constants.CHANNELS)),
            tf.ensure_shape(y, (static_batch, Constants.MAX_STRING_LEN)),
        ),
        tf.data.AUTOTUNE,
    )
    return ds.prefetch(tf.data.AUTOTUNE)


In [45]:

def train_run(train_files, valid_files, config, num_train, num_valid, experiment_id=0, use_tfrecords=True, summary=False, evaluate_only=False):
    """Build the datasets, model, optimizer and callbacks, then train or only evaluate.

    Args:
        train_files: training TFRecord paths.
        valid_files: validation TFRecord paths, or None for no validation set.
        config: experiment settings. Reads fp16, is_tpu, input_path, max_len,
            batch_size, dropout_start_epoch, strategy, output_dim, dim, lr,
            epochs, weight_decay, awp, awp_start_epoch, awp_lambda, replicas,
            output_path, comment, verbose, resume, swa_epochs and save_output.
        num_train: number of training examples; sets steps per epoch.
        num_valid: number of validation examples (currently unused).
        experiment_id: suffix in checkpoint file names.
        use_tfrecords: currently unused; TFRecords are always read.
        summary: print the model summary and dataset specs first.
        evaluate_only: load the "-best.h5" checkpoint, evaluate on the
            validation set and return without training.

    Returns:
        ``(model, cv, history)``: ``cv`` is ``model.evaluate`` on the validation
        set (None without one); ``history`` is None when ``evaluate_only``.
    """
    if config.fp16:
        policy = "mixed_bfloat16" if config.is_tpu else "mixed_float16"
    else:
        policy = "float32"

    tf.keras.mixed_precision.set_global_policy(policy)
    print(f"\n... TWO IMPORTANT ASPECTS OF THE GLOBAL MIXED PRECISION POLICY:")
    print(f'\t--> COMPUTE DTYPE  : {tf.keras.mixed_precision.global_policy().compute_dtype}')
    print(f'\t--> VARIABLE DTYPE : {tf.keras.mixed_precision.global_policy().variable_dtype}')
    augment_train = True
    repeat_train = True
    shuffle_buffer = 16384 if config.is_tpu else 4096
    print("shuffle_buffer", shuffle_buffer)
    train_ds = get_dataset(
        train_files,
        input_path=config.input_path,
        max_len=config.max_len,
        batch_size=config.batch_size,
        drop_remainder=True,
        augment=augment_train,
        repeat=repeat_train,
        shuffle_buffer=shuffle_buffer,
        use_tfrecords=True,
    )
    if valid_files is not None:
        valid_ds = get_dataset(
            valid_files,
            input_path=config.input_path,
            max_len=config.max_len,
            batch_size=config.batch_size,
            use_tfrecords=True,
            drop_remainder=True
        )
    else:
        valid_ds = None

    # Materialize every validation batch once for the Levenshtein callback.
    # Guarded: iterating over None raised a TypeError when there was no validation set.
    valid_set_memory = [] if valid_ds is None else list(valid_ds)

    steps_per_epoch = num_train // config.batch_size
    dropout_step = config.dropout_start_epoch * steps_per_epoch
    strategy = config.strategy
    with strategy.scope():
        model = get_model(
            max_len=config.max_len,
            output_dim=config.output_dim,
            input_pad=Constants.INPUT_PAD,
            dim=config.dim,
            dropout_step=dropout_step,
            drop_rate=0.2
        )

        base_lr = config.lr
        lr_schedule = CosineDecay(
            initial_learning_rate=base_lr,
            decay_steps=int(steps_per_epoch * config.epochs),
            alpha=0.005,
            name=None,
            warmup_target=None,
            warmup_steps=0
        )

        # "Ranger": RectifiedAdam wrapped in Lookahead.
        radam = tfa.optimizers.RectifiedAdam(learning_rate=lr_schedule, weight_decay=config.weight_decay)
        opt = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
        awp_step = config.awp_start_epoch * steps_per_epoch
        if config.awp:
            model = AWP(model.input, model.output, delta=config.awp_lambda, eps=0., start_step=awp_step)
            print("Using AWP")

        def make_ctc_loss():
            # One independent loss object per model output.
            return CTCLossWrap(pad_token_idx=Constants.LABEL_PAD, batch_size=config.batch_size,
                               max_string_len=Constants.MAX_STRING_LEN,
                               output_dim=config.output_dim,
                               output_steps=config.max_len // 2, replicas=config.replicas)

        ctc_losses = [make_ctc_loss() for _ in range(2)]
        metrics = None if config.is_tpu else [LevDistanceMetric()]
        model.compile(
            optimizer=opt,
            loss=ctc_losses,
            loss_weights=[0.5, 0.5],
            metrics=metrics,
        )

    if summary:
        print()
        model.summary()
        print()
        print(train_ds, valid_ds)
        print()
    print(f"---------experiment {experiment_id}---------")
    print(f"train:{num_train} ")
    print()

    if evaluate_only:
        model.load_weights(f"{config.output_path}/{config.comment}-exp{experiment_id}-best.h5")
        cv = model.evaluate(valid_ds, verbose=config.verbose)
        return model, cv, None

    if config.resume:
        print(f"resume from epoch{config.resume}")
        model.load_weights(f"{config.output_path}/{config.comment}-exp{experiment_id}-last.h5")
        if train_ds is not None:
            model.evaluate(train_ds.take(steps_per_epoch))
        if valid_ds is not None:
            model.evaluate(valid_ds)

    # Constructed but currently not registered (see the commented-out appends below).
    tb_logger = tf.keras.callbacks.TensorBoard(
        log_dir=config.output_path,
    )
    sv_loss = tf.keras.callbacks.ModelCheckpoint(
        f"{config.output_path}/{config.comment}-exp{experiment_id}-best.h5",
        monitor="val_final_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode="min",
        save_freq="epoch",
    )

    memory_usage = MemoryUsageCallbackExtended()
    swa = SWA(
        f"{config.output_path}/{config.comment}-exp{experiment_id}",
        config.swa_epochs,
        strategy=strategy,
        train_ds=train_ds,
        valid_ds=valid_ds,
    )
    val_lev = val_lev_callback(valid_set_memory)
    callbacks = []
    if config.save_output:
        # callbacks.append(tb_logger)
        callbacks.append(val_lev)
        callbacks.append(swa)
        callbacks.append(sv_loss)
        callbacks.append(tf.keras.callbacks.TerminateOnNaN())
    # callbacks.append(memory_usage)

    history = model.fit(
        train_ds,
        epochs=config.epochs - config.resume,
        callbacks=callbacks,
        validation_data=valid_ds,
        verbose=config.verbose,
    )

    if config.save_output:  # reload the saved best weights checkpoint
        saved_based_model = f"{config.output_path}/{config.comment}-exp{experiment_id}-best.h5"
        if os.path.exists(saved_based_model):
            model.load_weights(saved_based_model)
        else:
            print(f"Warning: could not find {saved_based_model}")
    cv = model.evaluate(valid_ds, verbose=config.verbose) if valid_ds is not None else None
    return model, cv, history



In [46]:

def train(config, experiment_id=0, use_supplemental=True, use_chicago=True, evaluate_only=False):
    """Select TFRecord files, split train/validation and run one experiment.

    Files under ``<input_path>/sign-tfrecords`` are grouped by name ("train",
    "supp", "chicago"). Each group is sorted, so the validation split (the first
    ``config.num_eval`` files) is the same on every run. Only the training
    files are then shuffled.

    Args:
        config: experiment settings; a strategy is created if ``config.strategy`` is None.
        experiment_id: suffix used in checkpoint names.
        use_supplemental: include the "supp" files.
        use_chicago: include the "chicago" files.
        evaluate_only: forwarded to ``train_run``.

    Returns:
        Whatever ``train_run`` returns: ``(model, cv, history)``.
    """
    if config.strategy is None:
        update_config_with_strategy(config)
    print(f"using {config.replicas} replicas")
    print(f"batch size {config.batch_size}")
    print(f"learning rate {config.lr}")
    print(f"fp16={config.fp16}")
    seed_everything(config.seed)

    all_filenames = tf.io.gfile.glob(config.input_path + "/sign-tfrecords/*.tfrecord")

    # tf.io.gfile.glob does not promise an order; sort so the split is reproducible.
    regular = sorted(x for x in all_filenames if "train" in x)
    supp = sorted(x for x in all_filenames if "supp" in x)
    chicago = sorted(x for x in all_filenames if "chicago" in x)

    # Copy so extending below does not alias ``regular``.
    data_filenames = list(regular)
    if use_supplemental:
        data_filenames += supp
    if use_chicago:
        data_filenames += chicago
    print("Using TFRECORDS")

    valid_files = data_filenames[: config.num_eval]  # first part of the sorted list
    train_files = data_filenames[config.num_eval :]

    random.shuffle(train_files)  # shuffle only the training set

    # Hard-coded training-set sizes (multiples of 32) per source. They are
    # additive, which reproduces the previous totals (1912, 3567, 5505) and also
    # gives the right count for chicago without supplemental (previously 5505).
    num_train = 32 * (
        1912
        + (3567 - 1912) * bool(use_supplemental)
        + (5505 - 3567) * bool(use_chicago)
    )
    num_valid = 187 * 32

    return train_run(
        train_files,
        valid_files,
        config,
        num_train,
        num_valid,
        summary=False,
        experiment_id=experiment_id,
        use_tfrecords=True,
        evaluate_only=evaluate_only
    )



In [47]:
# Free unreachable Python objects, then reset Keras global state left by any
# earlier run in this kernel, before starting a new training run.
gc.collect()
tf.keras.backend.clear_session()

# Run training

In [None]:
# Reuse a `config` already defined in this kernel; otherwise build the default one.
try:
    config
except NameError:
    config = CFG()
# Show full TensorFlow tracebacks instead of the filtered ones.
tf.debugging.disable_traceback_filtering()
train(config, use_supplemental=True);

using 1 replicas
batch size 128
learning rate 0.0004
fp16=True
Using TFRECORDS

... TWO IMPORTANT ASPECTS OF THE GLOBAL MIXED PRECISION POLICY:
	--> COMPUTE DTYPE  : float16
	--> VARIABLE DTYPE : float32
shuffle_buffer 4096
---------experiment 0---------
train:176160 

Epoch 1/300


# Inference

In [None]:
import tensorflow as tf

In [None]:
class InferModel(tf.Module):
    """Serving wrapper: raw landmark frames in, one-hot character rows out.

    ``__call__`` has a fixed input signature (float32,
    (None, Constants.NUM_INPUT_FEATURES), named "inputs"), so the module can be
    exported with ``tf.saved_model.save`` and converted to TFLite.
    """

    def __init__(self, model, config=CFG):
        super().__init__()
        self.model = model
        self.max_len = config.max_len

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, Constants.NUM_INPUT_FEATURES), dtype=tf.float32, name="inputs")]
    )
    def __call__(self, inputs):
        """Decode one landmark sequence.

        Args:
            inputs: (T, Constants.NUM_INPUT_FEATURES) float32 frames; T may be 0.

        Returns:
            {"outputs": float32 one-hot rows of depth 59, one per token from
            ``decode_phrase``}. When no tokens are decoded, a single row for id 0
            is returned. Assumes ``decode_phrase`` yields a 1-D id vector.
        """
        frames = tf.cast(inputs, tf.float32)
        # Empty input -> a single all-zero frame. The test runs on a batched
        # (1, T, F) view, which is then unbatched again.
        batched = frames[None]
        batched = tf.cond(
            tf.shape(batched)[1] == 0,
            lambda: tf.zeros((1, 1, Constants.NUM_INPUT_FEATURES)),
            lambda: tf.identity(batched),
        )
        features = preprocess(batched[0], max_len=self.max_len)

        # NOTE(review): assumes the model returns a list of output heads; [0][0]
        # takes the first head's result for the single batch element.
        head = self.model(features[None], training=False)[0][0]

        token_ids = decode_phrase(head)
        token_ids = tf.cond(
            tf.shape(token_ids)[0] == 0,
            lambda: tf.zeros(1, tf.int32),
            lambda: tf.identity(token_ids),
        )
        return {"outputs": tf.one_hot(token_ids, depth=59, dtype=tf.float32)}


In [None]:

# Rebuild the architecture and load the best checkpoint of experiment 0.
config = CFG
experiment_id = 0

model = get_model(
    max_len=config.max_len,
    output_dim=config.output_dim,
    dim=config.dim,
    input_pad=Constants.INPUT_PAD,
)

saved_based_model = f"{config.input_path}/weights/{config.comment}-exp{experiment_id}-best.h5"
model.load_weights(saved_based_model)
print(f"model with weights {saved_based_model}")

In [None]:
# Sanity check: run the serving wrapper on one known training sequence.
import json

with open(config.input_path + "/asl-fingerspelling/character_to_prediction_index.json", "r") as f:
    character_map = json.load(f)
rev_character_map = {index: char for char, index in character_map.items()}

infer_keras_model = InferModel(model)

main_dir = config.input_path + '/asl-fingerspelling'
path = f'{main_dir}/train_landmarks/5414471.parquet'
cols = selected_columns(path)
df = pd.read_parquet(path, engine='auto', columns=cols)
seq_id = 1816796431
seq = df.loc[seq_id]
data = seq[cols].to_numpy()
print(f'input shape: {data.shape}, dtype: {data.dtype}')

output = infer_keras_model(data)["outputs"]
predicted_ids = np.argmax(output, axis=1)
prediction_str = "".join(rev_character_map.get(s, "") for s in predicted_ids)

print(prediction_str)

In [None]:
# Export the serving module, then convert it to TFLite with float16 weights.
SAVED_MODEL_PATH = config.output_path + "/infer_model"
tf.saved_model.save(infer_keras_model, SAVED_MODEL_PATH)

keras_model_converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
keras_model_converter.optimizations = [tf.lite.Optimize.DEFAULT]
keras_model_converter.target_spec.supported_types = [tf.float16]
# If conversion needs TF ops missing from TFLite builtins, enable them with:
# keras_model_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
# keras_model_converter.allow_custom_ops = True
tflite_model = keras_model_converter.convert()

TFLITE_FILE_PATH = config.output_path + "/model.tflite"
with open(TFLITE_FILE_PATH, "wb") as tflite_file:
    tflite_file.write(tflite_model)

# Record which parquet columns the model expects, next to the model file.
with open(config.output_path + "/inference_args.json", "w") as args_file:
    json.dump({"selected_columns": cols}, args_file)



In [None]:
# Load the converted model and rerun the sanity-check sequence through TFLite.
interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH)
REQUIRED_SIGNATURE = "serving_default"
REQUIRED_OUTPUT = "outputs"
found_signatures = list(interpreter.get_signature_list().keys())
if REQUIRED_SIGNATURE not in found_signatures:
    # Fail here with a clear message instead of inside get_signature_runner below.
    raise ValueError(f"Required input signature not found. Available: {found_signatures}")

prediction_fn = interpreter.get_signature_runner(REQUIRED_SIGNATURE)
output = prediction_fn(inputs=data)
prediction_str = "".join([rev_character_map.get(s, "") for s in np.argmax(output[REQUIRED_OUTPUT], axis=1)])
print(prediction_str)

In [None]:

!zip  submission.zip "/kaggle/working/model.tflite" "/kaggle/working/inference_args.json"

In [None]:
#!pip install /kaggle/input/tflite-wheels-2140/tflite_runtime_nightly-2.14.0.dev20230508-cp310-cp310-manylinux2014_x86_64.whl

In [None]:
import os


import json
import pandas as pd
import tflite_runtime.interpreter as tflite
import numpy as np
import time
from tqdm import tqdm
import Levenshtein as Lev
import glob
# Hide all GPUs from CUDA, presumably so the timing below reflects CPU-only inference.
# NOTE(review): CUDA reads this variable when it initializes. If TensorFlow already
# initialized the GPUs earlier in this kernel (the training cells import it), setting
# it here likely has no effect. Confirm, or set it before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
# Columns the exported model expects; the context manager closes the file handle.
with open('/kaggle/working/inference_args.json') as _inference_args_file:
    SEL_FEATURES = json.load(_inference_args_file)['selected_columns']

def load_relevant_data_subset(pq_path):
    """Read only the columns listed in ``SEL_FEATURES`` from a parquet file."""
    wanted_columns = SEL_FEATURES
    return pd.read_parquet(pq_path, columns=wanted_columns)

with open("/kaggle/input/asl-fingerspelling/character_to_prediction_index.json", "r") as f:
    character_map = json.load(f)
rev_character_map = dict((index, char) for char, index in character_map.items())


df_csv = pd.read_csv('/kaggle/input/asl-fingerspelling/train.csv')

# Landmark frames of the sequence described by row 0 of train.csv.
idx = 0
sample = df_csv.loc[idx]
sample_file_frames = load_relevant_data_subset('/kaggle/input/asl-fingerspelling/' + sample['path'])
loaded = sample_file_frames[sample_file_frames.index == sample['sequence_id']].values
print(loaded.shape)
frames = loaded


In [None]:

# Score the TFLite model on the first 6 training parquet files:
# (N - D) / N, where N is the total reference length and D the total
# Levenshtein distance, plus mean wall-clock time per sequence.
st = time.time()
count = 0
model_time = 0

N = 0
D = 0

# (file_id, sequence_id) -> reference phrase, built once instead of filtering
# all of train.csv for every sequence.
phrase_lookup = dict(zip(zip(df_csv["file_id"], df_csv["sequence_id"]), df_csv["phrase"]))

files = sorted(glob.glob('/kaggle/input/asl-fingerspelling/train_landmarks/*.parquet'))[:6]
for pq_path in files:
    fid = int(pq_path.split("/")[-1].split(".")[0])
    df = load_relevant_data_subset(pq_path)
    seq = df.index.drop_duplicates()
    for ind in tqdm(seq):
        # df.loc[[ind]] is always 2-D; df.loc[ind] is a 1-D Series for a one-frame sequence.
        loaded = df.loc[[ind]].values
        count += 1
        md_st = time.time()

        # out = infer_keras_model(loaded)["outputs"] # original model
        out = prediction_fn(inputs=loaded)[REQUIRED_OUTPUT]  # tflite

        model_time += time.time() - md_st

        prediction_str = "".join([rev_character_map.get(s, "") for s in np.argmax(out, axis=1)])
        assert out.ndim == 2
        assert out.shape[1] == 59
        assert out.dtype == np.float32
        assert np.all(np.isfinite(out))
        s1 = phrase_lookup[(fid, ind)]
        s2 = prediction_str
        n = len(s1)
        d = Lev.distance(s1, s2)
        N = N + n
        D = D + d

if N == 0:
    raise RuntimeError("No sequences were scored; check the parquet file list.")
lev = (N - D) / N
print(f'Lev: {lev:.4f}')
print(f'Mean time: {(time.time() - st)/count:.3f}')
print(f'Mean time only infer: {model_time/count:.3f}')
        
