In [70]:
import os
from transformers import TFBertModel, BertTokenizer
import tensorflow as tf
import numpy as np

In [22]:
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
model = TFBertModel.from_pretrained("bert-base-multilingual-cased")

In [106]:
MAX_LENGTH = 70

In [69]:
first = "The cat is on the table."
second = "That is a very fat cat by the way."
third = "Instead, that is a big dog who is sleeping."

In [107]:
encoded_first = tf.convert_to_tensor([tokenizer.encode(first, max_length = MAX_LENGTH, 
                                                       pad_to_max_length=True, padding_side='right')])
encoded_second = tf.convert_to_tensor([tokenizer.encode(second, max_length = MAX_LENGTH, 
                                                        pad_to_max_length=True, padding_side='right')])
encoded_third = tf.convert_to_tensor([tokenizer.encode(third, max_length = MAX_LENGTH, 
                                                       pad_to_max_length=True, padding_side='right')])

In [108]:
max_seq_len = max(encoded_first.shape[1], encoded_second.shape[1], encoded_third.shape[1])

In [109]:
max_seq_len

70

In [110]:
last_hidden_states_first = model(encoded_first)[0]
last_hidden_states_second = model(encoded_second)[0]
last_hidden_states_third = model(encoded_third)[0]

In [111]:
last_hidden_states_first

<tf.Tensor: shape=(1, 70, 768), dtype=float32, numpy=
array([[[-0.15480314, -0.44193774,  0.8005953 , ...,  1.0991361 ,
          0.17631558, -0.17418985],
        [-0.03439797, -0.33064383,  0.25020745, ...,  0.6079755 ,
          0.2316818 , -0.32114932],
        [-0.0583407 , -0.2696473 ,  0.1768523 , ...,  0.61290264,
          0.24749398, -0.29331076],
        ...,
        [-0.14542915, -0.9514578 ,  1.0559226 , ...,  1.1773176 ,
          0.62122995, -0.09426124],
        [-0.10198347, -0.9444886 ,  1.087671  , ...,  1.1444662 ,
          0.5994523 , -0.08976214],
        [-0.06503955, -0.94576645,  1.0640883 , ...,  1.1529771 ,
          0.5711324 , -0.0948147 ]]], dtype=float32)>

In [112]:
last_hidden_states_second

<tf.Tensor: shape=(1, 70, 768), dtype=float32, numpy=
array([[[-6.70508891e-02, -4.45779800e-01,  9.40871835e-01, ...,
          1.08377862e+00,  1.65808022e-01, -1.19882293e-01],
        [-2.82645702e-01, -1.92244276e-01,  3.87238681e-01, ...,
          3.92304659e-01, -5.57910264e-01, -4.74223763e-01],
        [-6.97368741e-01, -3.58008325e-01,  1.45396039e-01, ...,
          7.27568746e-01, -4.63806361e-01, -4.39120159e-02],
        ...,
        [ 5.43940812e-04, -1.07205355e+00,  1.04225087e+00, ...,
          1.04811192e+00,  5.94487071e-01, -2.48051807e-02],
        [ 3.71246934e-02, -1.05915940e+00,  1.07118678e+00, ...,
          1.01368797e+00,  5.83640397e-01, -1.30290352e-02],
        [ 6.30370155e-02, -1.05388796e+00,  1.04737628e+00, ...,
          1.02272177e+00,  5.64887583e-01, -1.07908435e-02]]],
      dtype=float32)>

In [113]:
last_hidden_states_third

<tf.Tensor: shape=(1, 70, 768), dtype=float32, numpy=
array([[[-0.1631954 , -0.34210354,  0.60113823, ...,  0.8355855 ,
          0.1429974 ,  0.11953822],
        [-0.5517919 ,  0.10710899,  0.84756947, ...,  0.7586643 ,
         -0.17052856,  0.02101097],
        [-0.47430706, -0.33143467,  0.2719422 , ...,  0.52472717,
         -0.14685166,  0.90198225],
        ...,
        [-0.10447182, -1.0151714 ,  0.87278354, ...,  0.90464   ,
          0.59818995,  0.10899384],
        [-0.06288485, -1.0121874 ,  0.89962924, ...,  0.86582094,
          0.58338547,  0.11547725],
        [-0.03538071, -1.0143714 ,  0.87945837, ...,  0.8712214 ,
          0.56396484,  0.12210149]]], dtype=float32)>

In [117]:
averaged_first = tf.reduce_mean(last_hidden_states_first, axis=1)
averaged_second = tf.reduce_mean(last_hidden_states_second, axis=1)
averaged_third = tf.reduce_mean(last_hidden_states_third, axis=1)

In [123]:
res = np.average([averaged_first, averaged_second, averaged_third], axis=0)

In [129]:
res.shape

(1, 768)

In [130]:
final_res = []

for elem in res[0]:
    final_res.append(float(elem))

In [136]:
len(final_res)

768