<a href="https://colab.research.google.com/github/stevec12/VTubers-Analysis/blob/main/CommentPrompting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Comment Prompting
This Jupyter notebook trains a basic transformer that responds to prompts in the way commenters on a YouTuber's videos would likely reply.

The YouTuber chosen for the demo is [Ceres Fauna](https://www.youtube.com/channel/UCO_aKKYxn4tvrqPjcTzZ6EQ), an English-language streamer whose comments are predominantly in English.

The channel ID is `UCO_aKKYxn4tvrqPjcTzZ6EQ`.

# Data Extraction
Comments are retrieved with the `YouTube Data API v3`. An API key linked to your account can be obtained through your personal Google (Developer) account.
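As a minimal sketch of handling the key, the helper below reads it from an environment variable and falls back to a hidden prompt, so the key is never echoed into the notebook output. The variable name `YOUTUBE_API_KEY` and the function `load_api_key` are assumptions made for illustration; they are not part of the YouTube Data API or of this notebook.

```python
import os
from getpass import getpass


def load_api_key(env=os.environ, prompt=getpass):
    """Return the API key from the (hypothetical) YOUTUBE_API_KEY variable,
    falling back to a hidden prompt when the variable is not set."""
    key = env.get("YOUTUBE_API_KEY", "").strip()
    if not key:
        # getpass hides the typed value, unlike input()
        key = prompt("Input API Key: ").strip()
    if not key:
        raise ValueError("An API key is required to call the YouTube Data API.")
    return key
```

The returned string could then be passed as `developerKey` when building the client.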

In [None]:
import googleapiclient.discovery
import googleapiclient.errors

import numpy as np
import pandas as pd
!pip install xlsxwriter
import xlsxwriter

Collecting xlsxwriter
  Downloading XlsxWriter-3.1.9-py3-none-any.whl (154 kB)
Installing collected packages: xlsxwriter
Successfully installed xlsxwriter-3.1.9


In [None]:
print("Input API Key: ")
api_key = input()

Input API Key: 


In [None]:
# Input target channel, example is @CeresFauna
channelID = 'UCO_aKKYxn4tvrqPjcTzZ6EQ'

In [None]:
api_service_name = "youtube"
api_version = "v3"
youtube = googleapiclient.discovery.build(api_service_name, api_version, developerKey=api_key)

In [None]:
def find_uploadedID(channelID):
  request = youtube.channels().list(
      part="contentDetails",
      id=channelID
    )
  response = request.execute()

  return response['items'][0]['contentDetails']['relatedPlaylists']['uploads']

In [None]:
uploadedID=find_uploadedID(channelID)

In [None]:
def find_uploaded(uploadedID):
  videoIDs = []
  request = youtube.playlistItems().list(
        part="contentDetails",
        playlistId = uploadedID,
        maxResults = 50
  )
  response = request.execute()
  for item in response['items']:
    videoIDs.append(item['contentDetails']['videoId'])
  while('nextPageToken' in response):
    request=youtube.playlistItems().list(
        part='contentDetails',
        playlistId=uploadedID,
        pageToken=response['nextPageToken'],
        maxResults=50)
    response = request.execute()
    for item in response['items']:
      videoIDs.append(item['contentDetails']['videoId'])

  return videoIDs

In [None]:
uploaded=find_uploaded(uploadedID)

In [None]:
def get_video_comments(videoID : str) -> pd.DataFrame:
  '''
  Given a videoID, return a pandas DataFrame with one row per comment
  (top-level comments and the replies returned with each thread).
  '''
  column_names = ['videoID','isTopLevel','topLevelID','commentID','authorDisplayName',
                  'likeCount','publishedAt','totalReplyCount','textOriginal']

  row_list = [] # Used to create list of dict of rows before conversion to dataframe, faster
  pageToken=''
  while(True):
    request=youtube.commentThreads().list(
        part="id,snippet,replies",
        videoId=videoID,
        pageToken=pageToken,
        maxResults=100
    )
    try:
      response=request.execute()
    except googleapiclient.errors.HttpError:
      break

    for commentThread in response['items']:
      # write top level comment
      topLevelID=commentThread['snippet']['topLevelComment']['id']
      commentID=topLevelID
      authorDisplayName=commentThread['snippet']['topLevelComment']['snippet']['authorDisplayName']
      likeCount=commentThread['snippet']['topLevelComment']['snippet']['likeCount']
      publishedAt=commentThread['snippet']['topLevelComment']['snippet']['publishedAt']
      totalReplyCount=commentThread['snippet']['totalReplyCount']
      textOriginal=commentThread['snippet']['topLevelComment']['snippet']['textOriginal']

      row_list.append({'videoID':videoID,'isTopLevel':True,'topLevelID':topLevelID,
                      'commentID':commentID,'authorDisplayName':authorDisplayName,
                      'likeCount':likeCount,'publishedAt':publishedAt,
                      'totalReplyCount':totalReplyCount,'textOriginal':textOriginal})

      # If any replies, write them as well
      if 'replies' in commentThread:
        for reply in commentThread['replies']['comments']:
          commentID=reply['id']
          authorDisplayName=reply['snippet']['authorDisplayName']
          likeCount=reply['snippet']['likeCount']
          publishedAt=reply['snippet']['publishedAt']
          textOriginal=reply['snippet']['textOriginal']

          row_list.append({'videoID':videoID,'isTopLevel':False,'topLevelID':topLevelID,
                           'commentID':commentID,'authorDisplayName':authorDisplayName,
                           'likeCount':likeCount,'publishedAt':publishedAt,
                           'totalReplyCount':totalReplyCount,'textOriginal':textOriginal})

    if 'nextPageToken' not in response:
      break
    else:
      pageToken=response['nextPageToken']

  return pd.DataFrame(row_list, columns=column_names)


In [None]:
def uploaded_comments_to_excel(file_name, uploaded = uploaded):
  '''
  Writes all comments in the Uploaded playlist to an excel file, as a single
  worksheet.
  '''
  column_names = ['videoID','isTopLevel','topLevelID','commentID','authorDisplayName',
                  'likeCount','publishedAt','totalReplyCount','textOriginal']
  comment_df = get_video_comments(uploaded[0])

  for videoID in uploaded[1:]:
    comment_df = pd.concat([comment_df, get_video_comments(videoID)])

  comment_df.to_excel(file_name, engine='xlsxwriter', index=False)


In [None]:
uploaded_comments_to_excel('ceres_fauna_comments_10_27_2023.xlsx')

# Preparing the Data
Preparing the data using TensorFlow preprocessing layers.

Here, we use the `ceres_fauna_comments_10_27_2023.xlsx` excel file generated earlier.

In [None]:
import tensorflow as tf
!pip install tensorflow_text
import tensorflow_text as text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

Collecting tensorflow_text
  Downloading tensorflow_text-2.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.5 MB)
Installing collected packages: tensorflow_text
Successfully installed tensorflow_text-2.14.0


In [None]:
# Load data into a pandas DataFrame
comments_df = pd.read_excel('ceres_fauna_comments_10_27_2023.xlsx')

In [None]:
# We filter out comments that do not contain at least two separate words
multiple_word_indices = np.char.find(comments_df['textOriginal'].to_numpy(dtype='str'), " ") > -1
multiple_word_series = comments_df.copy().loc[multiple_word_indices]['textOriginal']

comments_tensor = tf.convert_to_tensor(multiple_word_series.to_numpy(dtype='str'), dtype='string')

We split the data into train, validation, and test splits.

For reasonable training times, we use a 50/10/40 split.
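To make the split arithmetic concrete, here is a small sketch of how the three sizes are derived. The function name `split_sizes` and the example count of 1005 comments are illustrative assumptions; the fractions mirror the 50/10/40 split described above, with the test split taking whatever the two floored sizes leave over.

```python
import math


def split_sizes(n_examples, train_frac=0.5, val_frac=0.1):
    """Return (train, val, test) sizes; test absorbs the rounding remainder."""
    train = math.floor(train_frac * n_examples)
    val = math.floor(val_frac * n_examples)
    test = n_examples - train - val
    return train, val, test


print(split_sizes(1005))  # (502, 100, 403)
```

Because the test size is computed by subtraction, the three sizes always sum to the dataset length.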

In [None]:
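# reshuffle_each_iteration=False keeps the order fixed across epochs, so the
# take/skip splits below stay disjoint instead of being redrawn on every pass
comment_ds = tf.data.Dataset.from_tensor_slices(comments_tensor).shuffle(
    1000, seed=12, reshuffle_each_iteration=False)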

train_split = int(np.floor(0.5*len(comment_ds)))
val_split = int(np.floor(0.1*len(comment_ds)))
test_split = int(len(comment_ds) - train_split - val_split)

train_ds = comment_ds.take(train_split)
val_ds = comment_ds.skip(train_split).take(val_split)
test_ds = comment_ds.skip(train_split + val_split).take(test_split)

Generate the vocabulary following the [subword tokenizers](https://www.tensorflow.org/text/guide/subwords_tokenizer) tutorial.
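For intuition about what a WordPiece vocabulary is used for, the following is a simplified sketch of greedy longest-match-first subword splitting, where continuation pieces carry a `##` prefix. It is a toy illustration, not the `text.BertTokenizer` implementation; the function `wordpiece_tokenize` and the tiny `toy_vocab` are assumptions.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedily split one lowercase word into the longest matching pieces."""
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        # Try the longest remaining substring first, then shrink it
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # marks a continuation of the word
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"watch", "##along", "##ing", "stream"}
print(wordpiece_tokenize("watchalong", toy_vocab))  # ['watch', '##along']
print(wordpiece_tokenize("streaming", toy_vocab))   # ['stream', '##ing']
print(wordpiece_tokenize("kirin", toy_vocab))       # ['[UNK]']
```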

In [None]:
bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 8000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

In [None]:
vocab_file = 'vocab.txt'

In [None]:
%%time
vocab = bert_vocab.bert_vocab_from_dataset(
    train_ds.batch(1000).prefetch(2),
    **bert_vocab_args
)

# Save vocab to file

with open(vocab_file, 'w') as f:
  for token in vocab:
    print(token, file=f)

CPU times: user 2min 26s, sys: 419 ms, total: 2min 27s
Wall time: 2min 28s


In [None]:
vocab_size = 0
with open(vocab_file, "rb") as f:
    vocab_size = sum(1 for _ in f)

Tokenize, trim (to `MAX_TOKENS`), and pad the inputs, then form `((input, teacher), label)` Datasets, where the label is the teacher-forcing (decoder) input shifted by one token so that each position predicts the next token.

Then batch (batch size = `BATCH_SIZE`) and prefetch data.
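The shift between the decoder input and the label can be sketched on plain Python lists. The IDs for `[PAD]`, `[START]`, and `[END]` follow the order of `reserved_tokens`; the function `shift_for_teacher_forcing` and the three example token IDs are illustrative assumptions.

```python
PAD, START, END = 0, 2, 3  # positions of the reserved tokens in the vocabulary


def shift_for_teacher_forcing(token_ids, max_tokens):
    """Return (teacher, label): label[t] is the token that follows teacher[t]."""
    seq = [START] + token_ids[:max_tokens - 1] + [END]
    teacher = seq[:-1]  # drops the final token
    label = seq[1:]     # drops [START]

    def pad(ids):
        return ids + [PAD] * (max_tokens - len(ids))

    return pad(teacher), pad(label)


teacher, label = shift_for_teacher_forcing([1006, 2550, 205], max_tokens=6)
print(teacher)  # [2, 1006, 2550, 205, 0, 0]
print(label)    # [1006, 2550, 205, 3, 0, 0]
```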

In [None]:
tokenizer = text.BertTokenizer(vocab_file, **bert_tokenizer_params)

In [None]:
MAX_TOKENS = 128
BATCH_SIZE = 64

In [None]:
START = tf.argmax(tf.constant(reserved_tokens) == "[START]")
END = tf.argmax(tf.constant(reserved_tokens) == "[END]")

def add_start_end(ragged):
  count = ragged.bounding_shape()[0]
  starts = tf.fill([count,1,1], START)
  ends = tf.fill([count,1,1], END)

  return tf.concat([starts, ragged, ends], axis=1)

In [None]:
def prepare_batch(input_batch : tf.Tensor):
  '''
  Tokenize a batch of raw comment strings and return
  ((input, teacher), label): `input` is the encoder input, `teacher` is the
  decoder input, and `label` is `teacher` shifted by one token.
  All three are trimmed to MAX_TOKENS and 0-padded into dense tensors.
  '''
  # Tokenize

  in_tokenized = tokenizer.tokenize(input_batch)[:,:MAX_TOKENS-2,:]
  in_tokenized = add_start_end(in_tokenized) # Add [START],[END] to vectors

  te_tokenized = tokenizer.tokenize(input_batch)[:,:MAX_TOKENS-1,:]
  te_tokenized = add_start_end(te_tokenized)

  la_tokenized = tokenizer.tokenize(input_batch)[:,:MAX_TOKENS-1,:]
  la_tokenized = add_start_end(la_tokenized)

  # 0-Pad and convert to dense tensor
  in_tokenized = tf.squeeze(in_tokenized.to_tensor(shape=(BATCH_SIZE,MAX_TOKENS,1)))
  te_tokenized = tf.squeeze(te_tokenized[:,:-1,:].to_tensor(shape=(BATCH_SIZE,MAX_TOKENS,1)))
  la_tokenized = tf.squeeze(la_tokenized[:,1:,:].to_tensor(shape=(BATCH_SIZE,MAX_TOKENS,1)))
  # form Dataset
  output_batch = ((in_tokenized,te_tokenized),la_tokenized)

  return output_batch

In [None]:
def make_batches(ds):
  return (
      ds
      .batch(BATCH_SIZE)
      .map(prepare_batch, tf.data.AUTOTUNE)
      .prefetch(buffer_size=tf.data.AUTOTUNE))

In [None]:
train_batches = make_batches(train_ds)
val_batches = make_batches(val_ds)

Take a single batch as an example.

In [None]:
for (input,teacher), label in train_batches.take(1):
  break

In [None]:
print(input[0])
print(teacher[0])
print(label[0])

tf.Tensor(
[   2 1006 2550  205 1157   17 1461 1021 1001 1007 1076 1538  988 1564
    3    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0], shape=(128,), dtype=int64)
tf.Tensor(
[   2 1006 2550  205 1157   17 1461 1021 1001 1007 1076 1538  988 1564
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0

Convert tokens to vectors with a `tf.keras.layers.Embedding` layer and add positional encoding.
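As a rough illustration of the sinusoidal positional encoding, the sketch below computes a single row with the standard library only. It assumes an even `depth` and the same layout as the notebook's encoding (all sine terms first, then all cosine terms); the helper name `positional_encoding_row` is hypothetical.

```python
import math


def positional_encoding_row(position, depth):
    """Encoding for one position: sines in the first half, cosines in the second."""
    half = depth // 2
    # Frequencies fall geometrically from 1 toward 1/10000
    rates = [1 / (10000 ** (i / half)) for i in range(half)]
    angles = [position * rate for rate in rates]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


print(positional_encoding_row(0, 8))
# [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
row1 = positional_encoding_row(1, 8)
print(round(row1[0], 4), round(row1[4], 4))  # 0.8415 0.5403
```

Position 0 always encodes to zeros followed by ones, and each later position rotates the sine and cosine pairs at a different rate, which gives every position a distinct vector.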

In [None]:
def positional_encoding(length, depth):
  depth = depth/2

  positions = np.arange(length)[:, np.newaxis]     # (seq, 1)
  depths = np.arange(depth)[np.newaxis, :]/depth   # (1, depth)

  angle_rates = 1 / (10000**depths)         # (1, depth)
  angle_rads = positions * angle_rates      # (pos, depth)

  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)],
      axis=-1)

  return tf.cast(pos_encoding, dtype=tf.float32)

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model
    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
    self.pos_encoding = positional_encoding(length=2048, depth=d_model)

  def compute_mask(self, *args, **kwargs):
    return self.embedding.compute_mask(*args, **kwargs)

  def call(self, x):
    length = tf.shape(x)[1]
    x = self.embedding(x)
    # This factor sets the relative scale of the embedding and positional_encoding.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x = x + self.pos_encoding[tf.newaxis, :length, :]
    return x

In [None]:
embed = PositionalEmbedding(vocab_size=vocab_size, d_model=512)
te_emb = embed(teacher)
te_emb._keras_mask;
in_emb = embed(input)
in_emb._keras_mask;

In [None]:
class BaseAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.layernorm = tf.keras.layers.LayerNormalization()
    self.add = tf.keras.layers.Add()

In [None]:
class CrossAttention(BaseAttention):
  '''Attention from the decoder sequence `x` (query) over `context` (key and value).'''

  def call(self, x, context):
    attn_output, attn_scores = self.mha(
        query = x,
        key = context,
        value = context,
        return_attention_scores = True
    )
    # Cache the attention scores for plotting later.
    self.last_attn_scores = attn_scores

    x = self.add([x,attn_output])
    x = self.layernorm(x)
    return x

In [None]:
sample_ca = CrossAttention(num_heads=2, key_dim=512)
print(in_emb.shape)
print(sample_ca(in_emb, te_emb).shape)

(64, 128, 512)
(64, 128, 512)


In [None]:
class GlobalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query = x,
        key = x,
        value = x
    )
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

In [None]:
sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=512)
print(in_emb.shape)
print(sample_gsa(in_emb).shape)

(64, 128, 512)
(64, 128, 512)


In [None]:
class CausalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query = x,
        key = x,
        value = x,
        use_causal_mask = True
    )
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

In [None]:
sample_csa = CausalSelfAttention(num_heads=2, key_dim=512)
print(te_emb.shape)
print(sample_csa(te_emb).shape)

(64, 128, 512)
(64, 128, 512)


In [None]:
class FeedForward(tf.keras.layers.Layer):
  def __init__(self, d_model, dff, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model),
        tf.keras.layers.Dropout(dropout_rate)
    ])
    self.add = tf.keras.layers.Add()
    self.layer_norm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = self.add([x,self.seq(x)])
    x = self.layer_norm(x)
    return x

In [None]:
sample_ffn = FeedForward(512,2048)

print(te_emb.shape)
print(sample_ffn(te_emb).shape)

(64, 128, 512)
(64, 128, 512)


In [None]:
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self,*,d_model,num_heads,dff,dropout_rate=0.1):
    super().__init__()

    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate
    )
    self.ffn = FeedForward(d_model,dff)

  def call(self,x):
    x = self.self_attention(x)
    x = self.ffn(x)
    return x

In [None]:
sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8,dff=2048)
print(in_emb.shape)
print(sample_encoder_layer(in_emb).shape)

(64, 128, 512)
(64, 128, 512)


In [None]:
class Encoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(
        vocab_size = vocab_size, d_model = d_model
    )

    self.enc_layers = [
        EncoderLayer(d_model=d_model,
                     num_heads=num_heads,
                     dff=dff,
                     dropout_rate=dropout_rate)
        for _ in range(num_layers)
    ]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self,x):
    # `x` is token-IDs shape: (batch_size, seq_len)
    x = self.pos_embedding(x) # Shape '(batch_size, seq_len, d_model)'.

    # Add dropout
    x = self.dropout(x)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x)

    return x # Shape `(batch_size, seq_length, d_model)`

In [None]:
# Test encoder
sample_encoder = Encoder(num_layers=4, d_model=512, num_heads=8, dff=2048, vocab_size=vocab_size)
sample_encoder_output = sample_encoder(input,training=False)

print(in_emb.shape)
print(sample_encoder_output.shape) # Shape `(batch_size, input_seq_len, d_model)`

(64, 128, 512)
(64, 128, 512)


In [None]:
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self,*,d_model,num_heads,dff,dropout_rate=0.1):
    super(DecoderLayer,self).__init__()

    self.causal_self_attention = CausalSelfAttention(
        num_heads = num_heads,
        key_dim = d_model,
        dropout = dropout_rate
    )

    self.cross_attention = CrossAttention(
        num_heads = num_heads,
        key_dim = d_model,
        dropout = dropout_rate
    )

    self.ffn = FeedForward(d_model, dff)

  def call(self, x, context):
    x = self.causal_self_attention(x=x)
    x = self.cross_attention(x=x, context=context)

    # Cache last attention scores for plotting later
    self.last_attn_scores = self.cross_attention.last_attn_scores

    x = self.ffn(x) # Shape `(batch_size, seq_len, d_model)`
    return x

In [None]:
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)

sample_decoder_layer_output = sample_decoder_layer(x=te_emb, context=in_emb)

print(te_emb.shape)
print(in_emb.shape)
print(sample_decoder_layer_output.shape) # `(batch_size, seq_len, d_model)`

(64, 128, 512)
(64, 128, 512)
(64, 128, 512)


In [None]:
class Decoder(tf.keras.layers.Layer):
  def __init__(self,*,num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model)

    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads,
                     dff=dff, dropout_rate=dropout_rate)
        for _ in range(num_layers)
    ]
    self.last_attn_scores = None

  def call(self, x, context):
    # `x` is token-IDs shape (batch, target_seq_len)
    x = self.pos_embedding(x) # (batch_size, target_seq_len, d_model)

    x = self.dropout(x)

    for i in range(self.num_layers):
      x = self.dec_layers[i](x,context)

    self.last_attn_scores = self.dec_layers[-1].last_attn_scores

    # shape of x is (batch_size, target_seq_len, d_model)
    return x


In [None]:
sample_decoder = Decoder(num_layers=4, d_model=512, num_heads=8,
                         dff=2048, vocab_size=vocab_size)

output = sample_decoder(x=teacher, context=in_emb)

print(teacher.shape)
print(in_emb.shape)
print(output.shape)

(64, 128)
(64, 128, 512)
(64, 128, 512)


In [None]:
sample_decoder.last_attn_scores

<tf.Tensor: shape=(64, 8, 128, 128), dtype=float32, numpy=
array([[[[0.0666575 , 0.06636824, 0.06678679, ..., 0.        ,
          0.        , 0.        ],
         [0.06661712, 0.06667682, 0.0665133 , ..., 0.        ,
          0.        , 0.        ],
         [0.06665023, 0.06635889, 0.06676417, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.0078125 , 0.0078125 , 0.0078125 , ..., 0.0078125 ,
          0.0078125 , 0.0078125 ],
         [0.0078125 , 0.0078125 , 0.0078125 , ..., 0.0078125 ,
          0.0078125 , 0.0078125 ],
         [0.0078125 , 0.0078125 , 0.0078125 , ..., 0.0078125 ,
          0.0078125 , 0.0078125 ]],

        [[0.06671308, 0.06684019, 0.06685069, ..., 0.        ,
          0.        , 0.        ],
         [0.06643081, 0.06690039, 0.0667645 , ..., 0.        ,
          0.        , 0.        ],
         [0.0668325 , 0.0670995 , 0.06680577, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.0078125 , 0.00781

In [None]:
class Transformer(tf.keras.Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff,
               input_vocab_size, target_vocab_size, dropout_rate=0.1):
    super().__init__()
    self.encoder = Encoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                           dff=dff, vocab_size=input_vocab_size, dropout_rate=dropout_rate)
    self.decoder = Decoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                           dff=dff, vocab_size=target_vocab_size, dropout_rate=dropout_rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs):
    # To support Keras model '.fit', pass all inputs as first argument
    context, x = inputs

    context = self.encoder(context) # (batch_size, context_len, d_model)

    x = self.decoder(x, context) # (batch_size, target_len, d_model)

    logits = self.final_layer(x) # (batch_size, target_len, target_vocab_size)

    try:
      # Drop keras mask, so it doesn't scale losses/metrics
      del logits._keras_mask
    except AttributeError:
      pass

    # Return the final logits; attention scores are cached on the decoder
    return logits

## Hyperparameters

In [None]:
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

## Testing

In [None]:
transformer = Transformer(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                          dff=dff, input_vocab_size=vocab_size, target_vocab_size=vocab_size,
                          dropout_rate=dropout_rate)

In [None]:
output = transformer((input,teacher))
print(teacher.shape)
print(input.shape)
print(output.shape)

tf.Tensor(
[[   2 1006 2550 ...    0    0    0]
 [   2   50 1522 ...    0    0    0]
 [   2 1665   44 ...    0    0    0]
 ...
 [   2  998  996 ...    0    0    0]
 [   2 1009  987 ...    0    0    0]
 [   2 1006  990 ...    0    0    0]], shape=(64, 128), dtype=int64)
(64, 128)
(64, 128)
(64, 128, 7682)


In [None]:
attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores
print(attn_scores.shape) # batch, heads, target_seq, input_seq

(64, 8, 128, 128)


In [None]:
transformer.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_1 (Encoder)         multiple                  3622144   
                                                                 
 decoder_1 (Decoder)         multiple                  5733120   
                                                                 
 dense_38 (Dense)            multiple                  990978    
                                                                 
Total params: 10346242 (39.47 MB)
Trainable params: 10346242 (39.47 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [None]:
pred_token_vector = tf.argmax(output,axis=2)
pred_token = tokenizer.detokenize(pred_token_vector)
pred_phrases = tf.strings.reduce_join(pred_token,axis=1,separator=' ')
print(pred_phrases)

tf.Tensor(
[b'\xca\x99\xe0\xb2\xa5 dang\xe0\xb2\xa5\xe0\xb2\xa5 wtfiest\xf0\x9f\x91\x8d \xf0\x9f\x8c\xb2 update \xf0\x9f\x8d\x8d todays \xe4\xb8\x80 chills pace pace pace todays todays todays todays todays \xf0\x9f\x99\x8a\xe3\x81\x8b\xe3\x81\x8b bee bee bee bee bee bee \xe2\x9c\xa8 \xe2\x8c\x9b \xe2\x8c\x9b molyffeffe mob mob \xe6\xa0\xb9 seiso seiso\xe3\x81\x8b \xf0\x9f\xa5\x82 \xf0\x9f\xa5\x82 \xf0\x9f\xa5\x82\xe2\x9a\xa1 seiso seiso seiso seiso seiso placement placementry seisojected strong heck heck shulker updatetifftifftifftifftiff speaking rule rule turnedant turned\xf0\x9f\x90\xba\xf0\x9f\x90\xba\xf0\x9f\x90\xba\xf0\x9f\x90\xba\xf0\x9f\x90\xba amnesia wow rude rude\xf0\x9f\x91\x8d\xf0\x9f\x91\x8d\xf0\x9f\x91\x8d 2023 seiso\xf0\x9f\x90\x9b seiso seiso seiso\xe7\x9c\xa0\xe7\x9c\xa0\xf0\x9f\x90\x9b\xf0\x9f\x90\x9b\xf0\x9f\x90\x9b pair pair seiso seiso seiso seiso chapter\xe7\x9c\xa0\xe7\x9c\xa0\xf0\x9f\x90\x9b\xf0\x9f\x90\x9b behind behinduffuffufflla crying tree tree tail tail t

In [None]:
pred = tf.Variable(
[b'40ized got cards \xf0\x9f\x92\x94 above \xf0\x9f\x92\x94 attacked es watchalong\xe7\x92\xb0 summon summon showsgogogogogogogogo stuck stuck stuck stuck\xe5\x84\xaa\xe5\x84\xaarasedingedinggogo\xf0\x9f\xa5\xa5\xf0\x9f\xa5\xa5\xf0\x9f\xa5\xa5\xf0\x9f\xa5\x95\xf0\x9f\xa5\x95gogh \xf0\x9f\x92\x94 \xf0\x9f\x92\x94 heartbeat\xf0\x9f\xa5\x95gogogogo \xe1\x84\x92 \xe1\x84\x92edingedingeding\xe3\x83\xaa\xe3\x83\xaa\xe3\x83\xaarayray\xe3\x83\xaa\xe3\x83\xaa# platform platform platformedingeding#go near near noise \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92gogogogogogogoedingedinggo\xe3\x83\xaa\xe3\x83\xaa robot robot robot plane plane color \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 bounce\xe3\x80\x8e\xe3\x80\x8e \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 heartbeat heartbeat \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xf0\x9f\x8c\x88 \xf0\x9f\x8c\x88 \xf0\x9f\x8c\x88 \xe1\x84\x92\xe8\xb6\xb3 heartbeat',
 b'##edingannereding sweepzzy \xf0\x9f\x92\x94\xe7\x81\xab attacked attacked attackedndingnding summonndingnding horrible\xe6\x9c\xaa wholesome stucknding stuck stuck stuck stuck \xe4\xbf\xa1 stuckzzy honest \xf0\x9f\x98\x8ceding \xe4\xbf\xa1 \xe4\xbf\xa1nding yupedingedingeding \xe7\xa7\x81\xe5\x84\xaa \xe1\x84\x92 \xf0\x9f\xa4\x8e mythzzyedingnding going \xe1\x84\x92 \xe1\x84\x92eding ost mythedingeding\xe3\x83\xaa\xe3\x83\xaa\xe3\x83\xaarayray\xe3\x83\xaa\xe3\x83\xaa# platform platform platformedingeding# near near near noise \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92 \xe1\x84\x92gogogogo\xe3\x83\xaa\xe3\x83\xaaedingedingedinggo\xe3\x83\xaa \xe3\x81\x95 robot robot robot robot robot color \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 bounce\xe3\x80\x8e\xe3\x80\x8e \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 heartbeat heartbeat \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xe3\x81\x95 \xf0\x9f\x8c\x88 \xf0\x9f\x8c\x88 \xf0\x9f\x8c\x88 \xe1\x84\x92\xe8\xb6\xb3 heartbeat'])


In [None]:
for codes in pred:
  print(codes.numpy().decode())

40ized got cards üíî above üíî attacked es watchalongÁí∞ summon summon showsgogogogogogogogo stuck stuck stuck stuckÂÑ™ÂÑ™rasedingedinggogoü••ü••ü••ü•ïü•ïgogh üíî üíî heartbeatü•ïgogogogo ·Ñí ·Ñíedingedingeding„É™„É™„É™rayray„É™„É™# platform platform platformedingeding#go near near noise ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñígogogogogogogoedingedinggo„É™„É™ robot robot robot plane plane color „Åï „Åï „Åï „Åï „Åï „Åï „Åï bounce„Äé„Äé „Åï „Åï „Åï heartbeat heartbeat „Åï „Åï „Åï „Åï „Åï „Åï „Åï „Åï „Åï „Åï üåà üåà üåà ·ÑíË∂≥ heartbeat
##edingannereding sweepzzy üíîÁÅ´ attacked attacked attackedndingnding summonndingnding horribleÊú™ wholesome stucknding stuck stuck stuck stuck ‰ø° stuckzzy honest üòåeding ‰ø° ‰ø°nding yupedingedingeding ÁßÅÂÑ™ ·Ñí ü§é mythzzyedingnding going ·Ñí ·Ñíeding ost mythedingeding„É™„É™„É™rayray„É™„É™# platform platform platformedingeding# near near near noise ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñí ·Ñígogogogo„É™„É™edingedingedinggo„É™ „Åï robot robot robot robot r

## Training
Training uses the Adam optimizer with the custom learning rate schedule from the original [Transformer paper](https://arxiv.org/abs/1706.03762):

$$lrate = d_{model}^{-0.5}*\min\left(step_{num}^{-0.5},step_{num}*warmup\_steps^{-1.5}\right)$$
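To see the shape of this schedule, the formula can be evaluated directly in plain Python. This is a sketch under stated assumptions: the function name `transformer_lr` is hypothetical, the defaults reuse `d_model = 128` and `warmup_steps = 4000` from this notebook, and `step` must be at least 1 because the formula is undefined at 0.

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Linear warmup for `warmup_steps` steps, then inverse-square-root decay."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000):
    print(step, f"{transformer_lr(step):.2e}")
```

The two arguments of `min` are equal at `step == warmup_steps`, which is where the learning rate peaks before it starts to decay.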

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super().__init__()

    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)

    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, dtype=tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


In [None]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9,
                                     beta_2=0.98, epsilon=1e-9)

In [None]:
# Setup padding mask for calculating loss properly
def masked_loss(label, pred):
  mask = label != 0
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none'
  )
  loss = loss_object(label,pred)

  mask = tf.cast(mask, dtype=loss.dtype)
  loss *= mask

  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
  return loss

def masked_accuracy(label, pred):
  pred = tf.argmax(pred, axis=2)
  label = tf.cast(label, pred.dtype)
  match = label == pred

  mask = label != 0

  match = match & mask

  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(match)/tf.reduce_sum(mask)
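The effect of the padding mask can be checked on a tiny hand-made example. The sketch below is a plain-Python analogue of `masked_accuracy` for one sequence; the helper name `masked_accuracy_1d` and the token IDs are illustrative assumptions.

```python
def masked_accuracy_1d(labels, preds, pad_id=0):
    """Fraction of correct predictions, counting only non-padding labels."""
    kept = [(label, pred) for label, pred in zip(labels, preds) if label != pad_id]
    return sum(label == pred for label, pred in kept) / len(kept)


labels = [1006, 2550, 3, 0, 0, 0]
preds = [1006, 17, 3, 0, 0, 0]

naive = sum(l == p for l, p in zip(labels, preds)) / len(labels)
print(round(naive, 3))                              # 0.833
print(round(masked_accuracy_1d(labels, preds), 3))  # 0.667
```

Without the mask, the trivially correct padding positions inflate the score, which is why both the loss and the accuracy divide by the number of unmasked tokens.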

In [None]:
transformer.compile(loss=masked_loss, optimizer=optimizer,
                    metrics=[masked_accuracy])

In [None]:
transformer.fit(train_batches, epochs=20, validation_data=val_batches)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x7edc18c37820>

Saving and loading model weights manually.

In [None]:
model_checkpoint_path = 'transformer_1'

In [None]:
# Save weights
# transformer.save_weights(model_checkpoint_path)

In [None]:
# Load weights
transformer = Transformer(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                          dff=dff, input_vocab_size=vocab_size, target_vocab_size=vocab_size,
                          dropout_rate=dropout_rate)
transformer.load_weights(model_checkpoint_path)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7b63dacb4af0>

## Run Inference
Create a module that generates comments from prompts:
* Encode the prompt with `tokenizer`, trim it, add `[START]` and `[END]`, then pad; this is the encoder input
* Calculate the padding masks and look-ahead masks (handled here by `mask_zero=True` in the embedding and `use_causal_mask=True` in the decoder self-attention)
* The `decoder` predicts the next token by looking at the `encoder` output and its own previous output
* Concatenate the predicted token to the decoder input and pass it back to the decoder
* The decoder predicts each following token based on the tokens it has already predicted
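The loop described by these steps is greedy autoregressive decoding, which can be sketched without TensorFlow. In the sketch below, `greedy_decode` and `toy_model` are hypothetical: `toy_model` is a stand-in that simply echoes the prompt, where the real notebook would take the argmax of the transformer's logits.

```python
START, END = 2, 3  # IDs of the reserved [START] and [END] tokens


def greedy_decode(next_token_fn, prompt_ids, max_length):
    """Feed the growing output back to the model until [END] or max_length."""
    output = [START]
    for _ in range(max_length):
        next_id = next_token_fn(prompt_ids, output)
        output.append(next_id)
        if next_id == END:
            break
    return output


def toy_model(prompt_ids, decoded):
    """Hypothetical stand-in for the transformer: echo the prompt, then stop."""
    position = len(decoded) - 1
    return prompt_ids[position] if position < len(prompt_ids) else END


print(greedy_decode(toy_model, [10, 11, 12], max_length=8))  # [2, 10, 11, 12, 3]
```

Each iteration sees only the tokens produced so far, which is the same constraint the causal mask enforces during training.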

In [None]:
class Commentator(tf.Module):
  def __init__(self, tokenizers, transformer):
    self.tokenizers = tokenizers
    self.transformer = transformer

  def __call__(self, sentence, max_length=MAX_TOKENS):
    # Add '[START]' and '[END]' tokens to input sentence
    assert isinstance(sentence, tf.Tensor)
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]

    sentence = self.tokenizers.tokenize(sentence)[:,:MAX_TOKENS-2,:]
    sentence = tf.squeeze(add_start_end(sentence).to_tensor(shape=(1,MAX_TOKENS,1)),axis=2)

    encoder_input = sentence

    # Init output with '[START]' token
    out = self.tokenizers.tokenize(tf.constant(['']))
    start_end = add_start_end(out)[0]
    start = start_end[0][tf.newaxis]
    end = start_end[1][tf.newaxis]

    # `tf.TensorArray` is required so the dynamic loop can be traced by `tf.function`
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)

    for i in tf.range(max_length):
      output = tf.transpose(output_array.stack())
      output = tf.reshape(output,(1,output.shape[2],1))
      output = tf.concat([output, tf.zeros((1,MAX_TOKENS-output.shape[1],1),dtype='int64')], axis=1)
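      output = tf.squeeze(output, axis=2)
      predictions = self.transformer([encoder_input, output], training=False)

      # Select the last position along the `seq_len` dimension.
      # NOTE: `output` is zero-padded to MAX_TOKENS above, so index -1 is a
      # padding position rather than the most recently generated token. This is
      # a possible cause of the immediate `[END]` prediction; indexing position
      # `i` instead may be what is intended.
      predictions = predictions[:,-1:,:] # Shape `(batch_size, 1, vocab_size)`
      print(predictions[:,:,:20])  # Debug: logits of the first 20 vocabulary entries
      predicted_id = tf.argmax(predictions, axis=-1)

      # Concatenate `predicted_id` to the output given to the decoder as input
      print(f'Token ID: {predicted_id}\nToken: {self.tokenizers.detokenize(predicted_id)}')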
      output_array = output_array.write(i+1, predicted_id)

      if predicted_id == end:
        break

    output = tf.squeeze(tf.transpose(output_array.stack()), axis=0)
    # output shape `(1,tokens)`
    text = tf.strings.reduce_join(self.tokenizers.detokenize(output)[0], axis=0, separator=" ") # Shape: `()`

    tokens = self.tokenizers.detokenize(output)[0]
    print(f'Tokens: {tokens}')
    # `tf.function` prevents use of the attention weights calculated
    # on the last iteration of the loop, so recalculate them outside the loop
    self.transformer([encoder_input, output[:,:-1]], training=False)
    attention_weights = self.transformer.decoder.last_attn_scores

    return text, tokens, attention_weights

In [None]:
commentator = Commentator(tokenizer, transformer)

def print_comment(sentence, tokens):
  print(f'{"Input:":15s}: {sentence}')
  print(f'{"Prediction":15s}: {tokens.numpy().decode("utf-8")}')

sentence = 'I miss'
output_text, output_tokens, attention_weights = commentator(tf.constant(sentence))
print_comment(sentence, output_text)

tf.Tensor(
[[[-13.883633    -7.338079   -13.902933     3.5451252    2.908659
    -0.09571538  -7.1459603   -5.9502454   -8.734787   -11.164272
     0.7126359   -9.874228    -5.017789    -6.0291224  -11.039936
     2.7767124   -4.7800946   -1.6145772   -7.7028856   -5.9787483 ]]], shape=(1, 1, 20), dtype=float32)
Token ID: [[3]]
Token: <tf.RaggedTensor [[b'[END]']]>
Tokens: [b'[START]' b'[END]']
Input:         : I miss
Prediction     : [START] [END]


In [None]:
sentence2 = tf.constant('Take care of')
out_text2, out_toks2, attn_wts2 = commentator(sentence2)
print_comment(sentence2, out_text2)

tf.Tensor(
[[[-14.055527   -7.3926663 -14.074556    3.4394395   2.9511986
    -0.3863879  -7.2340217  -5.9365435  -8.750813  -11.356359
     0.6319177 -10.059809   -5.1890182  -6.3959117 -11.280291
     2.7238004  -4.6797585  -1.5573555  -7.863919   -6.0857697]]], shape=(1, 1, 20), dtype=float32)
Token ID: [[3]]
Token: <tf.RaggedTensor [[b'[END]']]>
Tokens: [b'[START]' b'[END]']
Input:         : b'Take care of'
Prediction     : [START] [END]
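
Both prompts above produce `[END]` immediately, with almost identical logits. One possible explanation, offered here as an assumption rather than a verified diagnosis, is that the decoder input is padded to `MAX_TOKENS` while the logits are read from index `-1`, which is a padding position. The sketch below shows only the indexing difference. `toy_model`, its behaviour at padding positions, the token ids, and the small `MAX_TOKENS` value are all hypothetical and chosen to mirror the output above.

```python
from typing import List

PAD, START, END = 0, 2, 3  # assumed token ids
MAX_TOKENS = 6             # hypothetical small value for the illustration


def pad(tokens: List[int]) -> List[int]:
    """Right-pad the decoder input with PAD up to MAX_TOKENS."""
    return tokens + [PAD] * (MAX_TOKENS - len(tokens))


def toy_model(padded: List[int]) -> List[int]:
    """Hypothetical decoder returning one predicted id per input position.

    Real positions predict `token + 3`; padding positions predict [END].
    """
    return [END if tok == PAD else tok + 3 for tok in padded]


generated = [START]
per_position = toy_model(pad(generated))

from_last_index = per_position[-1]                   # reads a padding position
from_last_real = per_position[len(generated) - 1]    # reads the last real token

print(per_position)      # [5, 3, 3, 3, 3, 3]
print(from_last_index)   # 3, i.e. [END]
print(from_last_real)    # 5
```

If this is the cause, reading the logits at the index of the last generated token (position `i` inside the loop) would be the corresponding change in `Commentator`.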


In [None]:
class TemperatureCommentator(tf.Module):
  def __init__(self, tokenizers, transformer):
    self.tokenizers = tokenizers
    self.transformer = transformer

  def __call__(self, sentence, temperature = 0.1, max_length=MAX_TOKENS):
    # Add '[START]' and '[END]' tokens to input sentence
    assert isinstance(sentence, tf.Tensor)
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]

    sentence = self.tokenizers.tokenize(sentence)[:,:MAX_TOKENS-2,:]
    sentence = tf.squeeze(add_start_end(sentence).to_tensor(shape=(1,MAX_TOKENS,1)),axis=2)
    #sentence = add_start_end(sentence).to_tensor(shape=(1,MAX_TOKENS,1))

    encoder_input = sentence

    # Init output with '[START]' token
    out = self.tokenizers.tokenize(tf.constant(['']))
    start_end = add_start_end(out)[0]
    start = start_end[0][tf.newaxis]
    end = start_end[1][tf.newaxis]

    # `tf.TensorArray` is required so the dynamic loop can be traced by `tf.function`
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)

    for i in tf.range(max_length):
      # output = tf.squeeze(tf.transpose(output_array.stack()), axis=0)
      output = tf.transpose(output_array.stack())
      output = tf.reshape(output,(1,output.shape[2],1))
      output = tf.concat([output, tf.zeros((1,MAX_TOKENS-output.shape[1],1),dtype='int64')], axis=1)
      output = tf.squeeze(output, axis=2)

      predictions = self.transformer((encoder_input, output), training=False)

      # Select the last position along the `seq_len` dimension and divide the
      # logits by the temperature (higher values give more random samples)
      predictions = tf.squeeze(predictions[:,-1:,:]/temperature, axis=0) # Shape `(1, vocab_size)`
      predicted_id = tf.random.categorical(predictions, num_samples=1)

      # Concatenate `predicted_id` to the output given to the decoder as input
      print(f'Pred Token: {self.tokenizers.detokenize(predicted_id)}')
      output_array = output_array.write(i+1, predicted_id)

      if predicted_id == end:
        break

    output = tf.squeeze(tf.transpose(output_array.stack()), axis=0)
    # output shape `(1,tokens)`
    text = tf.strings.reduce_join(self.tokenizers.detokenize(output)[0], axis=0, separator=" ") # Shape: `()`

    tokens = self.tokenizers.detokenize(output)[0]
    # `tf.function` prevents use of the attention weights calculated
    # on the last iteration of the loop, so recalculate them outside the loop
    self.transformer([encoder_input, output[:,:-1]], training=False)
    attention_weights = self.transformer.decoder.last_attn_scores

    return text, tokens, attention_weights
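
`TemperatureCommentator` differs from `Commentator` in one step: instead of taking the `argmax`, it divides the logits by `temperature` and samples from the resulting distribution (`tf.random.categorical` treats its input as unnormalized log-probabilities). The effect of the temperature can be sketched in plain Python. The logit values here are hypothetical, and `softmax_with_temperature` and `sample_token` are illustrative helpers, not part of the notebook.

```python
import math
import random
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits: List[float], temperature: float, rng: random.Random) -> int:
    """Draw one token id from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


toy_logits = [3.5, 2.9, 0.7, -1.6]  # hypothetical logits for a 4-token vocabulary

sharp = softmax_with_temperature(toy_logits, 0.1)  # close to argmax
flat = softmax_with_temperature(toy_logits, 1.5)   # closer to uniform

rng = random.Random(0)
samples = [sample_token(toy_logits, 1.5, rng) for _ in range(5)]

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in flat])
print(samples)
```

A temperature well below 1 concentrates almost all probability on the highest logit, while a value above 1, such as the `1.5` used in the cells below, spreads probability across more tokens and gives more varied output.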

In [None]:
temp_commentator = TemperatureCommentator(tokenizer, transformer)

def print_comment(sentence, tokens):
  print(f'{"Input:":15s}: {sentence}')
  print(f'{"Prediction":15s}: {tokens.numpy().decode("utf-8")}')

sentence = 'I miss'
output_text, output_tokens, attention_weights = temp_commentator(tf.constant(sentence),1.5)
print_comment(sentence, output_text)

Pred Token: <tf.RaggedTensor [[b'skipped']]>
Pred Token: <tf.RaggedTensor [[b'tricky']]>
Pred Token: <tf.RaggedTensor [[b'?']]>
Pred Token: <tf.RaggedTensor [[b'looking']]>
Pred Token: <tf.RaggedTensor [[b'on']]>
Pred Token: <tf.RaggedTensor [[b'catch']]>
Pred Token: <tf.RaggedTensor [[b'are']]>
Pred Token: <tf.RaggedTensor [[b'!']]>
Pred Token: <tf.RaggedTensor [[b'was']]>
Pred Token: <tf.RaggedTensor [[b'[END]']]>
Input:         : I miss
Prediction     : [START] skipped tricky ? looking on catch are ! was [END]


In [None]:
sentence = 'Take care of'
output_text, output_tokens, attention_weights = temp_commentator(tf.constant(sentence),1.5)
print_comment(sentence, output_text)

Pred Token: <tf.RaggedTensor [[b'fauna']]>
Pred Token: <tf.RaggedTensor [[b'!']]>
Pred Token: <tf.RaggedTensor [[b'you']]>
Pred Token: <tf.RaggedTensor [[b'you']]>
Pred Token: <tf.RaggedTensor [[b'i']]>
Pred Token: <tf.RaggedTensor [[b'glad']]>
Pred Token: <tf.RaggedTensor [[b',']]>
Pred Token: <tf.RaggedTensor [[b'what']]>
Pred Token: <tf.RaggedTensor [[b'fauna']]>
Pred Token: <tf.RaggedTensor [[b',']]>
Pred Token: <tf.RaggedTensor [[b'quickly']]>
Pred Token: <tf.RaggedTensor [[b'for']]>
Pred Token: <tf.RaggedTensor [[b'[END]']]>
Input:         : Take care of
Prediction     : [START] fauna ! you you i glad , what fauna , quickly for [END]
