# **Study Project:** *Transformer model for prediction of grasping movements*

### Setup

In [1]:
# Imports
import tensorflow as tf

2024-04-11 13:53:20.984034: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def deserialize(serialized_example):
    """
    Function to deserialize tensors from bytes.
    """

    feature_description = {
        'context': tf.io.FixedLenFeature([], tf.string),
        'input': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string)
    }

    example = tf.io.parse_single_example(serialized_example, feature_description)
    context = tf.io.parse_tensor(example['context'], out_type=tf.float64)
    x = tf.io.parse_tensor(example['input'], out_type=tf.float64)
    target = tf.io.parse_tensor(example['target'], out_type=tf.float64)

    return context, x, target


def compute_mask(inputs, padding_token=0):
    return tf.cast(tf.not_equal(inputs, padding_token), tf.float64)


# Define custom functions
def masked_loss(label, pred, pad_token=-2):
    mask = label != pad_token
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
    return loss


def masked_accuracy(label, pred, pad_token=-2):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != pad_token

    match = match & mask

    match = tf.cast(match, dtype=tf.float64)
    mask = tf.cast(mask, dtype=tf.float64)
    return tf.reduce_sum(match)/tf.reduce_sum(mask)

tf.keras.utils.get_custom_objects()['masked_loss'] = masked_loss
tf.keras.utils.get_custom_objects()['masked_accuracy'] = masked_accuracy

### Load Dataset

In [3]:
# Load tensorflow dataset
train_ds_path = "./data/train_ds.zip"
test_ds_path = "./data/test_ds.zip"

# Create a TFRecordDataset from the saved file
train_dataset = tf.data.TFRecordDataset(train_ds_path, compression_type='GZIP')
test_dataset = tf.data.TFRecordDataset(test_ds_path, compression_type='GZIP')

# Deserialize the zipped dataset
train_dataset = train_dataset.map(deserialize)
test_dataset = test_dataset.map(deserialize)

### Load Model

In [4]:
# Load the model from zip file
model_path = "./models/transformer"
PAD = -2

# Import the model
transformer = tf.keras.models.load_model(model_path)

### Application

In [5]:
MAX_SEQ_LEN = max(con.shape[1] for con, x, t in test_dataset) # get maximum sequence length of ds

class Predictor(tf.Module):
    def __init__(self, transformer):
        self.transformer = transformer

    def __call__(self, bbox_sequence, max_length):

        assert isinstance(bbox_sequence, tf.Tensor)

        # For first frame with shape (8) expand dim for (SEQ_LEN, 8)
        if len(bbox_sequence.shape) == 1:
            bbox_sequence = bbox_sequence[tf.newaxis, :]

        # If max_length is longer than current seq_len, just take the seq_len
        # -> does not interfere with pad_length down below
        max_length = min(len(bbox_sequence), max_length)

        # live application is just one continuous sequence, but transformer trained on batches
        # -> add batch_size dimension for shape (BS, SEQ_LEN, 8); necessary for inference?
        encoder_input = bbox_sequence[tf.newaxis, :, :]

        # Input start and output end tokens
        start = tf.constant([-333], dtype=tf.int32)
        end = tf.constant([-1], dtype=tf.int32)

        # `tf.TensorArray` is required here (instead of a Python list), so that the
        # dynamic-loop can be traced by `tf.function`.
        output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
        output_array = output_array.write(0, start)

        for i in tf.range(max_length):
            # -> add batch_size dimension necessary for inference?
            output = output_array.stack()[tf.newaxis, :, :]

            # dynamically pad output
            pad_length = len(bbox_sequence) - int(output_array.size())
            paddings = tf.constant([[0, 0], [0, pad_length], [0, 0]])
            output = tf.pad(output, paddings, "CONSTANT")
            output = tf.cast(output, tf.float32)
            # mask necessary for prediction or only for training?
            
            predictions = self.transformer(inputs=(encoder_input, output), mask=None, training=False)

            # Select the last token from the `seq_len` dimension.
            predictions = predictions[:, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.

            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), dtype=tf.int32)

            # Concatenate the `predicted_id` to the output which is given to the
            # decoder as its input.
            output_array = output_array.write(i+1, predicted_id[0])

            if predicted_id == end:
                break

        output = output_array.stack()[tf.newaxis, :, :]
        output = tf.cast(output, tf.float32)

        # `tf.function` prevents us from using the attention_weights that were
        # calculated on the last iteration of the loop.
        # So, recalculate them outside the loop.
        self.transformer(inputs=(encoder_input, output[:, :-1, :]), mask=None, training=False)
        #attention_weights = self.transformer.decoder.last_attn_scores

        return output #, attention_weights

2024-04-11 13:53:37.874603: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [1]
	 [[{{node Placeholder/_0}}]]


### Predict

In [9]:
predictor = Predictor(transformer)

# Iterate through all batches in the DS
for batch, (context, x, target) in enumerate(test_dataset):

    # Iterate through all sequences (videos) in the batch
    for i, seq in enumerate(context):
        seq_len = len(seq)
        angles = predictor(bbox_sequence=seq, max_length=seq_len)
        pred = angles.numpy().flatten().tolist()
        pred = pred[1:]
        true = target[i].numpy().flatten().tolist()
        mse = tf.keras.losses.mean_squared_error(true, pred)

        # Print predictions
        print(f'Batch {batch}, iteration {i}')
        print(f'Prediction: {pred}')
        print(f'Ground truth: {true}')
        print("Mean Squared Error:", round(mse.numpy(), 2))
        print()
        break # debugging: print only one sequence in batch

Batch 0, iteration 0
Prediction: [171.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 347.0, 48.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 84.0, 48.0]
Ground truth: [143.0, 156.0, 104.0, 104.0, 102.0, 95.0, 92.0, 101.0, 43.0, -1.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0]
Mean Squared Error: 86559.73

Batch 1, iteration 0
Prediction: [347.0, 347.0, 347.0, 347.0, 347.0, 347