In [1]:
from typing import List

import torch
from torch import nn
import numpy as np
import random
import tensorflow as tf
from torch.nn import ModuleList
from transformers import AutoTokenizer
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel

random.seed(93)
np.random.seed(93)
torch.manual_seed(93)

torch.cuda.is_available()

2023-12-20 20:46:11.811805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-20 20:46:11.811836: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-20 20:46:11.812763: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-20 20:46:11.817364: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


True

## Recreating BERT Github in PyTorch

In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
masked = 4
example_str = "Mothers give birth to children."
exampled_tokenized = tokenizer(example_str)
exampled_tokenized["input_ids"][masked] = tokenizer.mask_token_id
tokenizer.decode(exampled_tokenized["input_ids"])

'[CLS] Mothers give [MASK] to children. [SEP]'

In [3]:
exampled_tokenized

{'input_ids': [101, 4872, 1116, 1660, 103, 1106, 1482, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [4]:
tf_chkpt = "./models/cased_L-12_H-768_A-12/bert_model.ckpt"
tf.train.list_variables(tf_chkpt)[:21]

[('bert/embeddings/LayerNorm/beta', [768]),
 ('bert/embeddings/LayerNorm/gamma', [768]),
 ('bert/embeddings/position_embeddings', [512, 768]),
 ('bert/embeddings/token_type_embeddings', [2, 768]),
 ('bert/embeddings/word_embeddings', [28996, 768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/beta', [768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/gamma', [768]),
 ('bert/encoder/layer_0/attention/output/dense/bias', [768]),
 ('bert/encoder/layer_0/attention/output/dense/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/key/bias', [768]),
 ('bert/encoder/layer_0/attention/self/key/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/query/bias', [768]),
 ('bert/encoder/layer_0/attention/self/query/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/value/bias', [768]),
 ('bert/encoder/layer_0/attention/self/value/kernel', [768, 768]),
 ('bert/encoder/layer_0/intermediate/dense/bias', [3072]),
 ('bert/encoder/layer_0/intermediate/dense/kernel',

In [5]:
class BertConfig:

    def __init__(self, vocab_size: int, vocab_pad: int = 0, d_model: int = 768, inter_size: int = 3072,
                 inter_activation: str = "GELU", seq_len: int = 512, attention_heads=12,
                 encoder_layers=12, layer_norm_eps=1e-12, hidden_dropout=0.1,
                 attn_dropout=0.1):
        self.vocab_size = vocab_size
        self.vocab_pad = vocab_pad
        self.d_model = d_model
        self.seq_len = seq_len
        self.layer_norm_eps = layer_norm_eps
        self.attention_head = attention_heads
        self.hidden_dropout = hidden_dropout
        self.attn_dropout = attn_dropout
        self.inter_size = inter_size
        self.inter_activation = inter_activation
        self.encoder_layers = encoder_layers


class BertEmbedding(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.word_embedding = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model,
                                           padding_idx=config.vocab_pad)
        self.segment_embedding = nn.Embedding(num_embeddings=2, embedding_dim=config.d_model)
        self.pos_embedding = nn.Embedding(num_embeddings=config.seq_len, embedding_dim=config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=config.hidden_dropout)

    def forward(self, seq, seq_seg):
        seq_positions = torch.IntTensor(list(range(seq.shape[1])))
        embedding = self.word_embedding(seq) + self.pos_embedding(seq_positions) + self.segment_embedding(seq_seg)
        return self.dropout(self.layer_norm(embedding))


def load_tf_var(chpt: str, src_var: str, target: nn.parameter.Parameter, processor=lambda x: x):
    src_val = tf.train.load_variable(chpt, src_var)
    src_val = processor(src_val)
    src_val = torch.from_numpy(src_val).float()
    target.copy_(src_val)
    assert torch.sum(target - src_val) <= 0.0001


def load_embeddings(tf_chk, embedding: BertEmbedding):
    with torch.no_grad():
        # layer norm
        load_tf_var(tf_chk, "bert/embeddings/LayerNorm/gamma", embedding.layer_norm.weight)
        load_tf_var(tf_chk, "bert/embeddings/LayerNorm/beta", embedding.layer_norm.bias)

        # token embeddings
        load_tf_var(tf_chk, "bert/embeddings/word_embeddings", embedding.word_embedding.weight)
        load_tf_var(tf_chk, "bert/embeddings/position_embeddings", embedding.pos_embedding.weight)
        load_tf_var(tf_chk, "bert/embeddings/token_type_embeddings", embedding.segment_embedding.weight)

In [6]:
config = BertConfig(vocab_size=28996)
embedding = BertEmbedding(config)
load_embeddings(tf_chkpt, embedding)
embedding.eval()
with torch.no_grad():
    example = embedding(torch.IntTensor([exampled_tokenized["input_ids"]]),
                        seq_seg=torch.IntTensor([exampled_tokenized["token_type_ids"]]))
example[0]

tensor([[ 0.4496,  0.0977, -0.2074,  ...,  0.0578,  0.0406, -0.0951],
        [ 0.0540,  0.3217,  0.6037,  ...,  0.3489, -0.8150, -0.2603],
        [ 0.4450,  0.7442,  0.6840,  ...,  0.5870,  0.7651, -0.7093],
        ...,
        [ 0.8370, -0.7599,  0.0051,  ..., -0.2732, -0.8569, -0.5360],
        [-0.4275,  0.6968,  0.8957,  ...,  0.1489,  0.6208,  0.2714],
        [-0.3162,  0.1007,  0.1413,  ...,  0.5393, -0.4997,  0.3309]])

In [7]:
class MultiHeadAttention(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        assert config.d_model % config.attention_head == 0

        self.heads = config.attention_head
        self.size_per_head = config.d_model // self.heads
        self.d_model = config.d_model

        self.q_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.k_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.v_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.o_proj = nn.Linear(in_features=self.size_per_head * self.heads, out_features=self.d_model)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.output_dropout = nn.Dropout(config.hidden_dropout)
        self.output_norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_eps)

    def forward(self, seq: torch.Tensor, mask: torch.Tensor, output_attentions=False):
        k_proj = self.multihead_view(self.k_proj(seq), transpose=True)
        q_proj = self.multihead_view(self.q_proj(seq))
        v_proj = self.multihead_view(self.v_proj(seq))

        raw_scores = torch.matmul(q_proj, k_proj) / np.sqrt(q_proj.shape[-1])
        masked_scores = raw_scores.masked_fill(mask == 0, -10000)
        scaled_scores = torch.nn.functional.softmax(masked_scores, dim=-1)
        scaled_scores = self.attn_dropout(scaled_scores)
        results = torch.matmul(scaled_scores, v_proj).permute(0, 2, 1, 3)
        combined = results.reshape(seq.shape[0], seq.shape[1], -1)

        o_proj = self.output_dropout(self.o_proj(combined))
        new_embedding = self.output_norm(o_proj + seq)
        if not output_attentions:
            return new_embedding
        else:
            return new_embedding, scaled_scores

    def multihead_view(self, proj: torch.Tensor, transpose=False):
        proj_view = proj.view(proj.shape[0], proj.shape[1], self.heads, self.size_per_head)
        if not transpose:
            return proj_view.permute(0, 2, 1, 3)
        else:
            return proj_view.permute(0, 2, 3, 1)


class BertEncoderLayer(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.d_model = config.d_model
        self.intermediate = config.inter_size
        self.attention = MultiHeadAttention(config)
        self.intermediate_proj = nn.Linear(in_features=self.d_model, out_features=self.intermediate)
        self.intermediate_act = getattr(nn, config.inter_activation)()
        self.out_proj = nn.Linear(in_features=self.intermediate, out_features=self.d_model)
        self.out_dropout = nn.Dropout(config.hidden_dropout)
        self.out_norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_eps)

    def forward(self, seq: torch.Tensor, mask: torch.Tensor):
        attn_out = self.attention(seq, mask)
        inter_out = self.intermediate_act(self.intermediate_proj(attn_out))
        layer_out = self.out_dropout(self.out_proj(inter_out))
        return self.out_norm(layer_out + attn_out)


def load_multihead_attention(tf_chk: str, layer: int, attention: MultiHeadAttention):
    with torch.no_grad():
        # K
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/key/bias", attention.k_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/key/kernel", attention.k_proj.weight,
                    processor=lambda x: np.transpose(x))

        # Q
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/query/bias", attention.q_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/query/kernel", attention.q_proj.weight,
                    processor=lambda x: np.transpose(x))

        # Q
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/value/bias", attention.v_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/self/value/kernel", attention.v_proj.weight,
                    processor=lambda x: np.transpose(x))

        # Output
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/output/dense/bias", attention.o_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/output/dense/kernel", attention.o_proj.weight,
                    processor=lambda x: np.transpose(x))
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/output/LayerNorm/beta", attention.output_norm.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/attention/output/LayerNorm/gamma",
                    attention.output_norm.weight)


def load_bert_encoder(tf_chk: str, layer: int, encoder: BertEncoderLayer):
    load_multihead_attention(tf_chk, layer, encoder.attention)
    with torch.no_grad():
        # Intermediate
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/intermediate/dense/bias", encoder.intermediate_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/intermediate/dense/kernel", encoder.intermediate_proj.weight,
                    processor=lambda x: np.transpose(x))

        # Output
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/output/dense/bias", encoder.out_proj.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/output/dense/kernel", encoder.out_proj.weight,
                    processor=lambda x: np.transpose(x))
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/output/LayerNorm/beta", encoder.out_norm.bias)
        load_tf_var(tf_chk, f"bert/encoder/layer_{layer}/output/LayerNorm/gamma", encoder.out_norm.weight)

In [8]:
layer_0 = BertEncoderLayer(config)
load_bert_encoder(tf_chkpt, 0, layer_0)
layer_0.eval()
with torch.no_grad():
    example_layer0 = layer_0(example, torch.IntTensor([exampled_tokenized["attention_mask"]]))
example_layer0[0]

tensor([[ 0.2949, -0.0026, -0.1349,  ...,  0.0461, -0.0124, -0.0396],
        [ 0.4209,  0.6647,  0.4904,  ...,  0.4522, -0.9139, -0.2006],
        [ 0.0837,  0.8036,  0.8864,  ...,  0.3865,  0.7934, -0.3622],
        ...,
        [ 1.1550, -0.6692,  0.2890,  ..., -0.7546, -0.9407, -0.4982],
        [-0.0672,  0.2081,  0.3474,  ...,  0.3738,  0.4269,  0.0777],
        [-0.1206, -0.1941,  0.1503,  ...,  0.6953, -0.6948,  0.4040]])

In [9]:
class BertEncoder(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.embeddings = BertEmbedding(config)
        self.encoder_layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(self, seq: torch.Tensor, mask: torch.Tensor, seq_seg: torch.Tensor):
        word_embeddings = self.embeddings(seq, seq_seg)
        final_embeddings = word_embeddings
        for layer in self.encoder_layers:
            final_embeddings = layer(final_embeddings, mask)
        return final_embeddings


class FirstTokenPooler(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.pool_proj = nn.Linear(in_features=config.d_model, out_features=config.d_model)
        self.pool_act = nn.Tanh()

    def forward(self, seq: torch.Tensor):
        return self.pool_act(self.pool_proj(seq[:, 0, :]))


def convert_tf_bert(tf_check: str, encoder: BertEncoder):
    load_embeddings(tf_check, encoder.embeddings)
    for i, layer in enumerate(encoder.encoder_layers):
        load_bert_encoder(tf_check, i, layer)

In [10]:
config = BertConfig(vocab_size=28996)
bert_encoder = BertEncoder(config)
convert_tf_bert(tf_chkpt, bert_encoder)
bert_encoder.eval()
with torch.no_grad():
    example = bert_encoder(torch.IntTensor([exampled_tokenized["input_ids"]]),
                           seq_seg=torch.IntTensor([exampled_tokenized["token_type_ids"]]),
                           mask=torch.IntTensor([exampled_tokenized["attention_mask"]]))
example

tensor([[[ 0.3731, -0.0692, -0.0403,  ..., -0.0555,  0.2904,  0.1068],
         [ 0.3540, -0.4470,  0.6067,  ..., -0.2113,  0.0843, -0.0635],
         [ 0.2962, -0.2089,  0.4200,  ...,  0.2340, -0.0422, -0.1688],
         ...,
         [ 0.4827, -0.2090,  0.2774,  ...,  0.0042,  0.2757,  0.1399],
         [ 0.1953, -0.2299, -0.1022,  ...,  0.1420,  0.1598,  0.0443],
         [ 0.1453, -0.0737, -0.1196,  ..., -0.2397,  0.6364, -0.0327]]])

In [11]:
class BertMLMHead(nn.Module):

    def __init__(self, config: BertConfig, word_embedding: nn.Embedding):
        super().__init__()
        self.transform = nn.Linear(in_features=config.d_model, out_features=config.d_model)
        self.transform_act = getattr(nn, config.inter_activation)()
        self.layer_norm = nn.LayerNorm(normalized_shape=config.d_model)
        self.word_embedding = word_embedding
        self.bias = nn.Parameter(
            data=torch.rand(config.vocab_size) * 2 * np.sqrt(config.d_model) - np.sqrt(config.d_model),
            requires_grad=True)

    def forward(self, seq):
        transformed_embeddings = self.layer_norm(self.transform_act(self.transform(seq)))
        scores = torch.matmul(transformed_embeddings, self.word_embedding.weight.transpose(0, 1))
        return scores + self.bias


def load_mlm_head(tf_check: str, head: BertMLMHead):
    with torch.no_grad():
        # Transform
        load_tf_var(tf_check, "cls/predictions/transform/LayerNorm/beta", head.layer_norm.bias)
        load_tf_var(tf_check, "cls/predictions/transform/LayerNorm/gamma", head.layer_norm.weight)
        load_tf_var(tf_check, "cls/predictions/transform/dense/bias", head.transform.bias)
        load_tf_var(tf_check, "cls/predictions/transform/dense/kernel", head.transform.weight,
                    processor=lambda x: np.transpose(x))

        # Bias
        load_tf_var(tf_check, "cls/predictions/output_bias", head.bias)

In [12]:
mlm_head = BertMLMHead(config, word_embedding=bert_encoder.embeddings.word_embedding)
load_mlm_head(tf_chkpt, mlm_head)
mlm_head.eval()

BertMLMHead(
  (transform): Linear(in_features=768, out_features=768, bias=True)
  (transform_act): GELU(approximate='none')
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (word_embedding): Embedding(28996, 768, padding_idx=0)
)

In [13]:
with torch.no_grad():
    example_representation = bert_encoder(seq=torch.IntTensor([exampled_tokenized["input_ids"]]),
                                          mask=torch.IntTensor([exampled_tokenized["attention_mask"]]),
                                          seq_seg=torch.IntTensor(exampled_tokenized["token_type_ids"]))
    scores = mlm_head(example_representation)
prediction = torch.argmax(scores, dim=-1)
prediction

tensor([[ 119,  119, 1116, 1660, 3485, 1106, 1482,  119,  119]])

In [14]:
tokenizer.decode(prediction[0, masked])

'birth'

## Loading Transformers Weights

In [15]:
transformers_bert = BertForMaskedLM.from_pretrained("bert-base-cased")
list(transformers_bert.base_model.children())

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[BertEmbeddings(
   (word_embeddings): Embedding(28996, 768, padding_idx=0)
   (position_embeddings): Embedding(512, 768)
   (token_type_embeddings): Embedding(2, 768)
   (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   (dropout): Dropout(p=0.1, inplace=False)
 ),
 BertEncoder(
   (layer): ModuleList(
     (0-11): 12 x BertLayer(
       (attention): BertAttention(
         (self): BertSelfAttention(
           (query): Linear(in_features=768, out_features=768, bias=True)
           (key): Linear(in_features=768, out_features=768, bias=True)
           (value): Linear(in_features=768, out_features=768, bias=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (output): BertSelfOutput(
           (dense): Linear(in_features=768, out_features=768, bias=True)
           (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
       )
       (intermediate): BertIntermediat

In [16]:
def load_transformers_embeddings(tf_embedding: BertEmbeddings, embeddings: BertEmbedding):
    with torch.no_grad():
        embeddings.word_embedding.weight.copy_(tf_embedding.word_embeddings.weight)
        embeddings.pos_embedding.weight.copy_(tf_embedding.position_embeddings.weight)
        embeddings.segment_embedding.weight.copy_(tf_embedding.token_type_embeddings.weight)


def load_transformers_encoders(tf_layers: ModuleList, layers: List[BertEncoderLayer]):
    with torch.no_grad():
        for tf_l, l in zip(tf_layers, layers):
            # Linear Weights
            l.attention.k_proj.weight.copy_(tf_l.attention.self.key.weight)
            l.attention.q_proj.weight.copy_(tf_l.attention.self.query.weight)
            l.attention.v_proj.weight.copy_(tf_l.attention.self.value.weight)
            l.attention.o_proj.weight.copy_(tf_l.attention.output.dense.weight)

            # Linear Bias
            l.attention.k_proj.bias.copy_(tf_l.attention.self.key.bias)
            l.attention.q_proj.bias.copy_(tf_l.attention.self.query.bias)
            l.attention.v_proj.bias.copy_(tf_l.attention.self.value.bias)
            l.attention.o_proj.bias.copy_(tf_l.attention.output.dense.bias)

            # Attention Norm
            l.attention.output_norm.weight.copy_(tf_l.attention.output.LayerNorm.weight)
            l.attention.output_norm.bias.copy_(tf_l.attention.output.LayerNorm.bias)

            # Intermediate Linear
            l.intermediate_proj.weight.copy_(tf_l.intermediate.dense.weight)
            l.intermediate_proj.bias.copy_(tf_l.intermediate.dense.bias)

            # Output Linear + Norm
            l.out_proj.weight.copy_(tf_l.output.dense.weight)
            l.out_proj.bias.copy_(tf_l.output.dense.bias)
            l.out_norm.weight.copy_(tf_l.output.LayerNorm.weight)
            l.out_norm.bias.copy_(tf_l.output.LayerNorm.bias)


def load_transformers_base_bert(tf_bert: BertModel, bert_base: BertEncoder):
    load_transformers_embeddings(tf_bert.embeddings, bert_base.embeddings)
    load_transformers_encoders(tf_bert.encoder.layer, bert_base.encoder_layers)

In [17]:
load_transformers_base_bert(transformers_bert.base_model, bert_encoder)
bert_encoder.eval()
with torch.no_grad():
    example = bert_encoder(torch.IntTensor([exampled_tokenized["input_ids"]]),
                           seq_seg=torch.IntTensor([exampled_tokenized["token_type_ids"]]),
                           mask=torch.IntTensor([exampled_tokenized["attention_mask"]]))
example

tensor([[[ 0.3731, -0.0692, -0.0403,  ..., -0.0555,  0.2904,  0.1068],
         [ 0.3540, -0.4470,  0.6067,  ..., -0.2113,  0.0843, -0.0635],
         [ 0.2962, -0.2089,  0.4200,  ...,  0.2340, -0.0422, -0.1688],
         ...,
         [ 0.4827, -0.2090,  0.2774,  ...,  0.0042,  0.2757,  0.1399],
         [ 0.1953, -0.2299, -0.1022,  ...,  0.1420,  0.1598,  0.0443],
         [ 0.1453, -0.0737, -0.1196,  ..., -0.2397,  0.6364, -0.0327]]])

In [18]:
tf_example = transformers_bert.base_model(input_ids=torch.IntTensor([exampled_tokenized["input_ids"]]),
                                          attention_mask=torch.IntTensor([exampled_tokenized["attention_mask"]]),
                                          token_type_ids=torch.IntTensor(
                                              exampled_tokenized["token_type_ids"])).last_hidden_state
tf_example

tensor([[[ 0.3731, -0.0692, -0.0403,  ..., -0.0555,  0.2904,  0.1068],
         [ 0.3540, -0.4470,  0.6067,  ..., -0.2113,  0.0843, -0.0635],
         [ 0.2962, -0.2089,  0.4200,  ...,  0.2340, -0.0422, -0.1688],
         ...,
         [ 0.4827, -0.2090,  0.2774,  ...,  0.0042,  0.2757,  0.1399],
         [ 0.1953, -0.2299, -0.1022,  ...,  0.1420,  0.1598,  0.0443],
         [ 0.1453, -0.0737, -0.1196,  ..., -0.2397,  0.6364, -0.0327]]],
       grad_fn=<NativeLayerNormBackward0>)

In [19]:
torch.mean(torch.abs(example - tf_example))

tensor(1.4656e-07, grad_fn=<MeanBackward0>)

In [20]:
torch.mean(torch.abs(example_representation - example))

tensor(0.)

In [21]:
tf_mlm = transformers_bert(input_ids=torch.IntTensor([exampled_tokenized["input_ids"]]),
                                          attention_mask=torch.IntTensor([exampled_tokenized["attention_mask"]]),
                                          token_type_ids=torch.IntTensor(
                                              exampled_tokenized["token_type_ids"]))
tf_prediction = torch.argmax(tf_mlm.logits, dim=-1)
tf_prediction

tensor([[ 119,  119, 1116, 1660, 3485, 1106, 1482,  119,  119]])

In [22]:
prediction

tensor([[ 119,  119, 1116, 1660, 3485, 1106, 1482,  119,  119]])