In [1]:
import torch
from torch import nn
import numpy as np
import random
import tensorflow as tf

random.seed(93)
np.random.seed(93)
torch.manual_seed(93)

torch.cuda.is_available()

2023-12-17 17:36:16.772880: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-17 17:36:16.772913: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-17 17:36:16.773871: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-17 17:36:16.778781: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


True

In [2]:
tf_chkpt = "./models/cased_L-12_H-768_A-12/bert_model.ckpt"
tf.train.list_variables(tf_chkpt)[:21]

[('bert/embeddings/LayerNorm/beta', [768]),
 ('bert/embeddings/LayerNorm/gamma', [768]),
 ('bert/embeddings/position_embeddings', [512, 768]),
 ('bert/embeddings/token_type_embeddings', [2, 768]),
 ('bert/embeddings/word_embeddings', [28996, 768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/beta', [768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/gamma', [768]),
 ('bert/encoder/layer_0/attention/output/dense/bias', [768]),
 ('bert/encoder/layer_0/attention/output/dense/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/key/bias', [768]),
 ('bert/encoder/layer_0/attention/self/key/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/query/bias', [768]),
 ('bert/encoder/layer_0/attention/self/query/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/value/bias', [768]),
 ('bert/encoder/layer_0/attention/self/value/kernel', [768, 768]),
 ('bert/encoder/layer_0/intermediate/dense/bias', [3072]),
 ('bert/encoder/layer_0/intermediate/dense/kernel',

In [3]:
class BertConfig:

    def __init__(self, vocab_size: int, vocab_pad: int = 0, d_model: int = 768,
                 seq_len: int = 512, attention_heads = 12, layer_norm_eps=1e-5):
        self.vocab_size = vocab_size
        self.vocab_pad = vocab_pad
        self.d_model = d_model
        self.seq_len = seq_len
        self.layer_norm_eps = layer_norm_eps
        self.attention_head = attention_heads


class BertEmbedding(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.word_embedding = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model, padding_idx=config.vocab_pad)
        self.segment_embedding = nn.Embedding(num_embeddings=3, embedding_dim=config.d_model, padding_idx=2)
        self.pos_embedding = nn.Embedding(num_embeddings=config.seq_len, embedding_dim=config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, seq, seq_seg):
        embedding = self.word_embedding(seq) + self.pos_embedding(seq) + self.segment_embedding(seq_seg)
        return self.dropout(embedding)


def load_tf_var(chpt: str, src_var: str, target: nn.parameter.Parameter, processor=lambda x: x):
    src_val = tf.train.load_variable(chpt, src_var)
    src_val = processor(src_val)
    src_val = torch.from_numpy(src_val).float()
    target.copy_(src_val)


def load_embeddings(tf_chk, embedding: BertEmbedding):
    with torch.no_grad():
        # layer norm
        load_tf_var(tf_chk, "bert/embeddings/LayerNorm/gamma", embedding.layer_norm.weight)
        load_tf_var(tf_chk, "bert/embeddings/LayerNorm/beta", embedding.layer_norm.bias)

        # token embeddings
        embedding_shape = embedding.word_embedding.weight.shape
        load_tf_var(tf_chk, "bert/embeddings/word_embeddings", embedding.word_embedding.weight)
        load_tf_var(tf_chk, "bert/embeddings/position_embeddings", embedding.pos_embedding.weight)
        load_tf_var(tf_chk, "bert/embeddings/token_type_embeddings", embedding.segment_embedding.weight,
                    lambda matrix: np.vstack([matrix, np.zeros(embedding_shape[1])]))

In [4]:
config = BertConfig(vocab_size=28996)
embedding = BertEmbedding(config)
load_embeddings(tf_chkpt, embedding)
with torch.no_grad():
    example = embedding(torch.IntTensor([[1, 2]]), seq_seg=torch.IntTensor([[1, 2]]))
example

tensor([[[-0.0139, -0.0442, -0.0051,  ..., -0.0000, -0.0369, -0.0127],
         [-0.0161, -0.0390,  0.0132,  ..., -0.0043, -0.0360, -0.0300]]])

In [5]:
with torch.no_grad():
    example = embedding(torch.IntTensor([[1, 2]]), seq_seg=torch.IntTensor([[1, 1]]))
example

tensor([[[-0.0139, -0.0442, -0.0051,  ..., -0.0126, -0.0369, -0.0127],
         [-0.0189, -0.0433,  0.0095,  ..., -0.0070, -0.0384, -0.0326]]])

In [5]:
class MultiHeadAttention(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        assert config.d_model % config.attention_head == 0

        self.heads = config.attention_head
        self.size_per_head = config.d_model // self.heads
        self.d_model = config.d_model

        self.q_proj = nn.Linear(in_features=self.d_modelm, out_features=self.d_model)