# Transformer Training

## Import packages

In [None]:
# Mount Google Drive (Colab only) so the trial videos and label folders under
# /content/drive/MyDrive/... can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install ultralytics

from IPython import display
display.clear_output()

In [None]:
# Import necessary packages
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dropout
import os
from tqdm import tqdm
import ultralytics
# Print an environment report (ultralytics / Python / torch versions, GPU, CPU, RAM, disk).
ultralytics.checks()

Ultralytics YOLOv8.0.120 🚀 Python-3.10.12 torch-2.0.1+cu118 CUDA:0 (Tesla T4, 15102MiB)
Setup complete ✅ (2 CPUs, 12.7 GB RAM, 24.1/78.2 GB disk)


## Helpers

In [None]:
def VideoBBs(videopath, yolo_model=None):
  """Run the YOLO object tracker over a video and collect every bounding box.

  Args:
    videopath: Path of the video file to track.
    yolo_model: Optional ultralytics model to use. Defaults to the
      notebook-global ``model``.

  Returns:
    (bbs, classes):
      bbs: float64 array of shape (num_boxes, 7) with columns
        [frame, class_id, x_center, y_center, bbwidth, bbheight, tracking_id].
        tracking_id is -1 on frames where the tracker assigned no IDs.
        Only real detections are returned (no placeholder rows).
      classes: dict mapping class_id -> class name as reported by the model.
  """
  yolo = model if yolo_model is None else yolo_model

  # stream=True yields one Results object per frame instead of keeping every
  # frame's results in memory at once (ultralytics warns about this for videos).
  results = yolo.track(source=videopath, tracker="bytetrack.yaml", stream=True)

  classes = None
  per_frame = []
  for frame, result in enumerate(results):
    if classes is None:
      classes = result.names
    boxes = result.boxes

    # x_center, y_center, bbwidth, bbheight of all boxes of this frame
    xywh = boxes.xywh.detach().cpu().numpy()
    n = len(xywh)
    cls = boxes.cls.detach().cpu().numpy().reshape((n, 1))
    # boxes.id only exists while the tracker is tracking; use -1 otherwise
    if boxes.is_track:
      tracking_id = boxes.id.detach().cpu().numpy().reshape((n, 1))
    else:
      tracking_id = np.full((n, 1), -1)
    frame_col = np.full((n, 1), frame)

    per_frame.append(np.concatenate((frame_col, cls, xywh, tracking_id), axis=1))

  if classes is None:
    # The video produced no frames; fall back to the model's class table.
    classes = yolo.names

  # Concatenate once at the end (repeated concatenation in the loop is quadratic).
  if per_frame:
    bbs = np.concatenate(per_frame, axis=0).astype(np.float64, copy=False)
  else:
    bbs = np.zeros((0, 7))
  return bbs, classes

In [None]:
def _read_video(videopath):
  """Read a video into a (frames, height, width, 3) uint8 array; return (frames, fps).

  The array is sized from CAP_PROP_FRAME_COUNT, which can overestimate the number
  of decodable frames; frames that cannot be read stay black so the array keeps
  one entry per reported frame.
  """
  cap = cv2.VideoCapture(videopath)
  try:
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the fractional rate (e.g. 29.97); truncating it changes playback speed.
    fps = cap.get(cv2.CAP_PROP_FPS)

    video = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
    for fc in range(frame_count):
      ok, frame = cap.read()
      if not ok:
        # cap.read() returns (False, None) here; assigning None would raise.
        break
      video[fc] = frame
  finally:
    cap.release()
  return video, fps


def _write_video(path, frames, fps):
  """Write uint8 frames of shape (n, height, width, 3) to an mp4v-encoded file."""
  height, width = int(frames.shape[1]), int(frames.shape[2])
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
  out = cv2.VideoWriter(path, fourcc, fps, (width, height), True)
  if not out.isOpened():
    raise IOError(f"Could not open video writer for '{path}'")
  try:
    for frame in frames:
      out.write(frame)
  finally:
    out.release()


def Align(videopath, to_path, bounding_boxes, width=1920, height=1120):
  """Center every frame of a video on its bounding-box center and save the result.

  Each frame is pasted onto a larger black canvas, shifted so that the point
  (bounding_boxes[idx, 0], bounding_boxes[idx, 1]) lands in the middle of the
  canvas. The canvas is just large enough for the largest shift over all rows.
  Frames whose row contains NaN (or that have no row) are left black.

  Args:
    videopath: Source video file.
    to_path: Target path; the result is written to
      ``to_path[:-4] + '_corrected.mp4'``.
    bounding_boxes: 2-D array of per-frame rows starting with
      (x_center, y_center) in source-frame pixels; any NaN in a row means
      "no information for this frame".
    width, height: Size of the source frames in pixels; must match the video.

  Raises:
    ValueError: if the video's frame size differs from (height, width).
    IOError: if the output video cannot be opened for writing.
  """
  video, fps = _read_video(videopath)
  if len(video) and video.shape[1:3] != (height, width):
    raise ValueError(
        f"Video frames are {video.shape[2]}x{video.shape[1]}, but Align was "
        f"called with width={width}, height={height}")

  bbs = np.asarray(bounding_boxes, dtype=float)
  center_x = width / 2
  center_y = height / 2

  # Largest distance of any box center from the frame center. Rows with NaN
  # contribute nothing; initial=0.0 also covers an empty box array.
  known = ~np.isnan(bbs).any(axis=1)
  max_x = np.abs(bbs[known, 0] - center_x).max(initial=0.0)
  max_y = np.abs(bbs[known, 1] - center_y).max(initial=0.0)

  canvas_h = int(height + max_y * 2)
  canvas_w = int(width + max_x * 2)
  # uint8 directly: the pasted frames are uint8 and so is the written video.
  video_corrected = np.zeros((len(video), canvas_h, canvas_w, 3), dtype=np.uint8)

  # Top-left corner of an unshifted frame on the canvas.
  start_row = (canvas_h - height) // 2
  start_col = (canvas_w - width) // 2

  for idx, frame in enumerate(tqdm(video)):
    if idx >= len(bbs) or not known[idx]:
      continue  # no information on the fruit: leave this frame black

    # Positive when the box center is right of / below the frame center.
    x_offset = int(bbs[idx, 0] - center_x)
    y_offset = int(bbs[idx, 1] - center_y)
    row0 = start_row - y_offset
    col0 = start_col - x_offset

    # With the canvas sized from these same boxes this cannot trigger; it guards
    # against negative slice starts silently wrapping around.
    if row0 < 0 or col0 < 0 or row0 + height > canvas_h or col0 + width > canvas_w:
      raise RuntimeError(
          f"Frame {idx} does not fit on the canvas (row {row0}, col {col0})")

    video_corrected[idx, row0:row0 + height, col0:col0 + width] = frame

  _write_video(to_path[:-4] + '_corrected.mp4', video_corrected, fps)

## Input

1. Get the bounding boxes of all objects in each raw video and filter for the hand and target fruit bounding boxes (Input X).
2. Center the target fruit in every frame of each video (file name `TARGET##.mp4`, e.g. `banana02.mp4`) and save the corrected videos to another folder.
3. Get the bounding boxes of the hand in the centered videos (Labels Y).

The bounding boxes of the hand and target fruit in the raw (uncorrected) videos are the inputs X for the transformer training.
The bounding boxes of the hand in the corrected (centered) videos are the corresponding labels Y for the transformer training.

In [None]:
# YOLO model; VideoBBs uses it (as the global `model`) for detection + tracking.
model = ultralytics.YOLO('yolov8n.pt')
folderpath = '/content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/0 Trials/'
labelpath = '/content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/1 Labels/'


def _drop_empty_boxes(bbs):
  """Remove rows with zero box width or height (e.g. an all-zero placeholder row).

  Expects columns [video, frame, class, x_center, y_center, bbwidth, bbheight, trackingID].
  Such rows carry no position information and would otherwise count as detections.
  """
  return bbs[(bbs[:, 5] > 0) & (bbs[:, 6] > 0)]


def _class_id(classes, name):
  """Return the numeric id of class `name` in a {id: name} table (first match)."""
  for class_id, class_name in classes.items():
    if class_name == name:
      return class_id
  raise ValueError(f"Class '{name}' is not known to the detector")


def _forward_fill_xywh(obj_bbs, frame_count):
  """Per-frame (x_center, y_center, bbwidth, bbheight) of one object.

  obj_bbs has columns [video, frame, x_center, y_center, bbwidth, bbheight].
  A frame with exactly one detection uses it. A frame with zero or several
  detections repeats the previous frame's box (last-known-value interpolation),
  so the result is NaN only before the object's first detection.
  """
  xywh = np.full((frame_count, 4), np.nan)
  for frame in range(frame_count):
    rows = obj_bbs[obj_bbs[:, 1] == frame]
    if rows.shape[0] == 1:
      xywh[frame] = rows[0, 2:6]
    elif frame > 0:
      xywh[frame] = xywh[frame - 1]
  return xywh


# Only .mp4 trials are used. The file name is the target class followed by a
# two-character trial number (e.g. 'banana02.mp4'). Sorted for a stable video index.
folder = sorted(f for f in os.listdir(folderpath)
                if os.path.isfile(os.path.join(folderpath, f)) and f.endswith('.mp4'))

# Per-video rows of
#   input:  video, frame, x_center_f, y_center_f, bbwidth_f, bbheight_f,
#           x_center_h, y_center_h, bbwidth_h, bbheight_h            (raw video)
#   labels: video, frame, x_center_h, y_center_h, bbwidth_h, bbheight_h (corrected video)
input_parts = []
label_parts = []

# Iterate through every video file in the folder
for i, video in enumerate(folder):

  print(f"Video {i+1}/{len(folder)}: {video}\n")

  # Number of frames = number of rows for this video
  cap = cv2.VideoCapture(folderpath+video)
  frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  cap.release()

  # 1. All boxes of the raw video, prefixed with the video index:
  # (video, frame, class, x_center, y_center, bbwidth, bbheight, trackingID)
  bbs, classes = VideoBBs(folderpath+video)
  bbs = np.concatenate((np.full((len(bbs), 1), i), bbs), axis=1)
  bbs = _drop_empty_boxes(bbs)

  # 2. Filter for target fruit and hand ('person') by class, keeping
  # (video, frame, x_center, y_center, bbwidth, bbheight).
  # TODO: use trackingID to follow one object. For now a frame with more than one
  # detection of a class is treated like a missed detection. A new trackingID is
  # assigned whenever detection drops out for a frame, so this would require
  # matching new IDs to the last tracked position (closest box) and interpolating
  # when nothing is detected.
  target_class = _class_id(classes, video[:-6])
  hand_class = _class_id(classes, 'person')
  target_bbs = bbs[bbs[:, 2] == target_class][:, [0, 1, 3, 4, 5, 6]]
  hand_bbs = bbs[bbs[:, 2] == hand_class][:, [0, 1, 3, 4, 5, 6]]

  # 3. Input X: last-known-value interpolation for fruit and hand, independently
  # (NaN marks "not detected yet").
  vid_input = np.full((frameCount, 10), np.nan)
  vid_input[:, 0] = i
  vid_input[:, 1] = np.arange(frameCount)
  vid_input[:, 2:6] = _forward_fill_xywh(target_bbs, frameCount)
  vid_input[:, 6:] = _forward_fill_xywh(hand_bbs, frameCount)

  # 4. Align: center the video on the target fruit, then detect again
  Align(folderpath+video, labelpath+video, vid_input[:, 2:6])
  bbs_corrected, classes = VideoBBs(labelpath+video[:-4]+'_corrected.mp4')
  bbs_corrected = np.concatenate((np.full((len(bbs_corrected), 1), i), bbs_corrected), axis=1)
  bbs_corrected = _drop_empty_boxes(bbs_corrected)
  hand_bbs_corrected = bbs_corrected[bbs_corrected[:, 2] == hand_class][:, [0, 1, 3, 4, 5, 6]]

  # 5. Labels Y: hand boxes in the corrected video, same interpolation
  vid_labels = np.full((frameCount, 6), np.nan)
  vid_labels[:, 0] = i
  vid_labels[:, 1] = np.arange(frameCount)
  vid_labels[:, 2:] = _forward_fill_xywh(hand_bbs_corrected, frameCount)

  input_parts.append(vid_input)
  label_parts.append(vid_labels)


# 6. Bind input X and labels Y of all videos together
input = np.concatenate(input_parts, axis=0) if input_parts else np.zeros((0, 10))
labels = np.concatenate(label_parts, axis=0) if label_parts else np.zeros((0, 6))

# Only keep rows where target, hand and corrected hand are all known. Because of
# the forward fill, NaNs only occur before an object's first detection in each
# video, so this drops the leading frames of every video (not just the first).
complete = ~np.isnan(input).any(axis=1) & ~np.isnan(labels).any(axis=1)
input = input[complete]
labels = labels[complete]

Video 1/1: banana02.mp4





    causing potential out-of-memory errors for large sources or long-running streams/videos.

    Usage:
        results = model(source=..., stream=True)  # generator of Results objects
        for r in results:
            boxes = r.boxes  # Boxes object for bbox outputs
            masks = r.masks  # Masks object for segment masks outputs
            probs = r.probs  # Class probabilities for classification outputs

video 1/1 (1/116) /content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/0 Trials/banana02.mp4: 384x640 1 orange, 1 bed, 1 tv, 1 laptop, 1 keyboard, 1 cell phone, 7.6ms
video 1/1 (2/116) /content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/0 Trials/banana02.mp4: 384x640 1 person, 1 orange, 1 bed, 1 tv, 1 laptop, 1 keyboard, 8.2ms
video 1/1 (3/116) /content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/0 Trials/banana02.mp4: 384x640 1 orange, 1 bed, 1 tv, 1 laptop, 1 keyboard, 9.5ms
video 1/1 (4/116) /content/drive/MyDrive/Studium/Se

In [None]:
# Inspect the assembled training data: shape and first 50 rows of X, then of Y.
print(input.shape, input[:50], "", sep="\n")
print(labels.shape, labels[:50], sep="\n")

(116, 10)
[[          0           0         nan         nan         nan         nan         nan         nan         nan         nan]
 [          0           1         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           2         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           3         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           4         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           5         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           6         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [          0           7         nan         nan         nan         nan        1227      1063.3      470.29      83.989]
 [    

## Model

### Positional Encoding

In [None]:
#display.Image('/content/drive/MyDrive/Studium/Semester M2/Study Project: Grasping/Media/pos_encoding.png', width=960, height=480)

In [None]:
# https://www.youtube.com/watch?v=ZMxVe-HK174
# Earlier functional version of the positional encoding, kept for reference only.
# The definition is wrapped in a string literal (the trailing ';' suppresses its
# display), so nothing here is executed; the PositionalEncoding layer class in
# the next cell is the implementation that is actually used.
"""
def PositionalEncoding(dims, max_seq_length):
    even_i = np.arange(0, dims, 2, dtype=float)
    denominator = tf.math.pow(10000, even_i/dims)
    position = np.arange(max_seq_length).reshape(max_seq_length, 1)
    even_PE = np.sin(position / denominator)
    odd_PE = np.cos(position / denominator)
    stacked = tf.stack([even_PE, odd_PE], 2)
    PE = tf.reshape(stacked, (max_seq_length, dims))

    return PE
""";

In [None]:
class PositionalEncoding(tf.keras.layers.Layer):
    """Add a fixed (non-trainable) sinusoidal positional encoding to its input.

    A float32 table of shape (1, position, d_model) is built once in
    ``__init__``. At call time it is cut to the input's sequence length
    (axis 1) and added element-wise. The first ceil(d_model / 2) channels hold
    sines and the remaining channels hold cosines (concatenated, not interleaved).

    Args:
        position: Maximum number of time steps the table covers.
        d_model: Number of feature channels of the inputs.
    """

    def __init__(self, position, d_model):
        super().__init__()
        self.positional_encoding = self.get_positional_encoding(position, d_model)

    def get_positional_encoding(self, sequence_length, input_dim):
        """Return the encoding table of shape (1, sequence_length, input_dim)."""
        positions = tf.range(sequence_length, dtype=tf.float32)[:, tf.newaxis]
        channels = tf.range(input_dim, dtype=tf.float32)[tf.newaxis, :]
        angles = self.get_angles(positions, channels, input_dim)

        # sin over the even-indexed angle columns, cos over the odd-indexed ones
        table = tf.concat(
            [tf.math.sin(angles[:, 0::2]), tf.math.cos(angles[:, 1::2])],
            axis=-1,
        )
        return tf.cast(table[tf.newaxis, ...], dtype=tf.float32)

    def get_angles(self, sequence_length, i, input_dim):
        """Angle pos / 10000^(2*floor(i/2)/input_dim) for every (position, channel) pair."""
        exponent = (2 * (i // 2)) / tf.cast(input_dim, tf.float32)
        return sequence_length * (1 / tf.pow(10000, exponent))

    def call(self, inputs):
        steps = tf.shape(inputs)[1]
        return inputs + self.positional_encoding[:, :steps, :]

### Encoder

In [None]:
# https://keras.io/examples/timeseries/timeseries_transformer_classification/
# Reference encoder block from the Keras example above, kept as an inert string
# literal (not executed). Note: it refers to `layers.*`, which this notebook does
# not import (`from tensorflow.keras import layers` would be needed to use it).

"""
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Normalization and Attention
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(x, x)
    x = layers.Dropout(dropout)(x)
    res = x + inputs

    # Feed Forward Part
    x = layers.LayerNormalization(epsilon=1e-6)(res)
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(x)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    return x + res
""";

In [None]:
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """One post-norm Transformer encoder block.

    Self-attention followed by a two-layer feed-forward net. Each sub-layer is
    followed by dropout, a residual connection and layer normalization.

    Args:
        d_model: Per-head key/query size of the attention layer and output width
            of the feed-forward net (must equal the inputs' last dimension for
            the second residual addition).
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward net.
        rate: Dropout rate after attention and after the feed-forward net.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        # key_dim is the per-head projection size; here it is d_model per head.
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = keras.Sequential([Dense(dff, activation='relu'), Dense(d_model)])

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training):
        # Self-attention sub-layer (query = value = inputs)
        attended = self.dropout1(self.mha(inputs, inputs), training=training)
        hidden = self.layernorm1(inputs + attended)

        # Position-wise feed-forward sub-layer
        transformed = self.dropout2(self.ffn(hidden), training=training)
        return self.layernorm2(hidden + transformed)

class TransformerEncoder(tf.keras.layers.Layer):
    """Linear projection to ``d_model`` followed by ``num_layers`` encoder blocks.

    Args:
        num_layers: Number of stacked TransformerEncoderLayer blocks.
        d_model: Width the inputs are projected to.
        num_heads: Attention heads per block.
        dff: Hidden width of each block's feed-forward net.
        input_dim: Accepted for interface symmetry; not used by this layer.
        rate: Dropout rate (after the projection and inside every block).
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_dim, rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.dense = Dense(d_model)
        self.enc_layers = [TransformerEncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]

        self.dropout = Dropout(rate)

    def call(self, x, training):
        hidden = self.dropout(self.dense(x), training=training)
        for block in self.enc_layers:
            hidden = block(hidden, training)
        return hidden

### Decoder

In [None]:
class TransformerDecoderLayer(tf.keras.layers.Layer):
    """One post-norm Transformer decoder block.

    Sub-layers, each followed by dropout, a residual connection and layer
    normalization:
      1. self-attention over ``inputs`` (masked with ``look_ahead_mask`` if given),
      2. cross-attention from that result to ``enc_output``,
      3. a two-layer feed-forward net.

    Keras ``MultiHeadAttention`` projects its output back to the query's last
    dimension, so this block's width equals the width of ``inputs``, which need
    not be ``d_model``. The feed-forward net's output layer is therefore sized
    from the input shape in ``build()`` rather than hard-coded. A fixed
    ``Dense(8)`` only worked for 8-feature inputs.

    Args:
        d_model: Per-head key/query size of both attention layers.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward net.
        rate: Dropout rate for the three sub-layers.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.dff = dff

        self.masked_mha1 = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.mha2 = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.layernorm3 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)

    def build(self, input_shape):
        # input_shape is the shape of `inputs` (the first call argument); the FFN
        # must map back to that width for the residual `out2 + ffn_output`.
        feature_dim = input_shape[-1]
        self.ffn = keras.Sequential([
            Dense(self.dff, activation='relu'),
            Dense(feature_dim),
        ])
        super(TransformerDecoderLayer, self).build(input_shape)

    def call(self, inputs, enc_output, training, look_ahead_mask=None):
        # 1. (Masked) self-attention
        attn1 = self.masked_mha1(inputs, inputs, attention_mask=look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(inputs + attn1)

        # 2. Cross-attention: queries from the decoder, keys/values from the encoder
        attn2 = self.mha2(out1, enc_output)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2)

        # 3. Feed-forward
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output)

        return out3

class TransformerDecoder(tf.keras.layers.Layer):
    """Stack of ``num_layers`` decoder blocks followed by a linear output head.

    Args:
        num_layers: Number of stacked TransformerDecoderLayer blocks.
        d_model: Per-head key/query size passed to every block.
        num_heads: Attention heads per block.
        dff: Hidden width of each block's feed-forward net.
        output_dim: Width of the final Dense projection (per time step).
        rate: Dropout rate inside every block.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, output_dim, rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.dec_layers = [TransformerDecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]

        self.final_layer = Dense(output_dim)

        # Not applied in call().
        self.dropout = Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask=None):
        hidden = x
        for block in self.dec_layers:
            hidden = block(hidden, enc_output, training, look_ahead_mask)
        return self.final_layer(hidden)

### Build

In [None]:
"""
def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
""";

In [None]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer mapping a feature sequence to per-step outputs.

    The input is first given a sinusoidal positional encoding. The encoder
    projects it to ``d_model`` and runs ``num_layers`` encoder blocks. The
    decoder then consumes the same position-encoded input (not a shifted target
    sequence), cross-attends to the encoder output, and projects to
    ``output_dim``.

    Args:
        num_layers: Number of encoder and of decoder blocks.
        d_model: Encoder width / attention key size.
        num_heads: Attention heads per block.
        dff: Feed-forward hidden width.
        sequence_length: Maximum sequence length covered by the positional table.
        input_dim: Number of input features per time step.
        output_dim: Number of output features per time step.
        rate: Dropout rate.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, sequence_length, input_dim, output_dim, rate=0.1):
        super().__init__()

        self.positional_encoder = PositionalEncoding(sequence_length, input_dim)
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, dff, input_dim, rate)
        self.decoder = TransformerDecoder(num_layers, d_model, num_heads, dff, output_dim, rate)

    def call(self, inp, training, look_ahead_mask=None):
        encoded_inp = self.positional_encoder(inp)
        return self.decoder(
            encoded_inp,
            self.encoder(encoded_inp, training),
            training,
            look_ahead_mask,
        )

## Training

In [None]:
"""
input_shape = x_train.shape[1:]

model = build_model(
    input_shape,
    head_size=256,
    num_heads=4,
    ff_dim=4,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["sparse_categorical_accuracy"],
)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]

model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=200,
    batch_size=64,
    callbacks=callbacks,
)

model.evaluate(x_test, y_test, verbose=1)
"""

## Prediction

In [None]:
# Define hyperparameters
num_layers = 4  # How often to stack encoder & decoder blocks
d_model = 128  # Encoder width (inputs are projected to this size) and attention key size
num_heads = 8  # Number of parallel attention heads in each MHA layer
dff = 512  # Hidden width of the feed-forward sub-layers
sequence_length = len(input)  # Number of time steps (frames) in the sequence
input_dim = 8  # x_center, y_center, bbwidth, bbheight of target fruit and of hand (video ID and frame dropped)
output_dim = 4  # x_center, y_center, bbwidth, bbheight of the hand per frame

# Instantiate the Transformer model
transformer = Transformer(num_layers, d_model, num_heads, dff, sequence_length, input_dim, output_dim)

# The whole dataset as one batch; drop the video-ID and frame columns (0 and 1).
batch_size = 1
input_data = input.reshape((batch_size, input.shape[0], input.shape[1]))[:, :, 2:]
assert input_data.shape[-1] == input_dim, "input columns do not match input_dim"

# Lower-triangular causal mask (1 = may attend): step t attends to steps <= t.
look_ahead_mask = tf.linalg.band_part(tf.ones((sequence_length, sequence_length)), -1, 0)

# Forward pass of the *untrained* model (weights are still randomly initialized),
# so these predictions are only a shape/plumbing check.
predictions = transformer(input_data, training=False, look_ahead_mask=look_ahead_mask)

In [None]:
# Summarize the forward pass: full prediction tensor, then the size of its batch axis.
report = f"Shape: {predictions.shape}, \nPredictions: \n{predictions}"
print(report)
print(len(predictions))  # length of the first (batch) axis

# TODO: guidance vector = current hand position - prediction

Shape: (1, 102, 4), 
Predictions: 
[[[     3.0885     0.36155     -0.7416     -1.1828]
  [     3.0816     0.34689    -0.71526     -1.1852]
  [     3.0828     0.34478    -0.71666     -1.1863]
  [     3.0844      0.3456    -0.71775     -1.1856]
  [     3.0792     0.35227    -0.71318     -1.1896]
  [     3.0795      0.3535    -0.71284     -1.1888]
  [     3.0693     0.36939    -0.70453     -1.1888]
  [     3.0702     0.36762    -0.70494     -1.1892]
  [     3.0652     0.37001    -0.70356     -1.1969]
  [     3.0628     0.37275    -0.70219     -1.2002]
  [     3.0122     0.31882    -0.62378        -1.2]
  [     3.0157      0.3128      -0.624     -1.1976]
  [     3.0168     0.30994    -0.62416     -1.1968]
  [      3.076     0.35157    -0.68577     -1.1863]
  [     3.0815      0.3472    -0.68762     -1.1828]
  [     3.0856     0.34341     -0.6919     -1.1796]
  [     3.0853     0.34469    -0.69211     -1.1792]
  [     3.0858     0.34468    -0.69181     -1.1773]
  [     3.0869     0.34457   