In [23]:
import tensorflow as tf 
import re
import torch
import numpy as np

In [24]:
tf_path = '../../biobert_v1.1_pubmed/model.ckpt-1000000'

In [25]:
init_vars = tf.train.list_variables(tf_path)

In [26]:
excluded = ['BERTAdam','_power','global_step']
init_vars = list(filter(lambda x:all([True if e not in x[0] else False for e in excluded]),init_vars))

In [27]:
init_vars

[('bert/embeddings/LayerNorm/beta', [768]),
 ('bert/embeddings/LayerNorm/gamma', [768]),
 ('bert/embeddings/position_embeddings', [512, 768]),
 ('bert/embeddings/token_type_embeddings', [2, 768]),
 ('bert/embeddings/word_embeddings', [28996, 768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/beta', [768]),
 ('bert/encoder/layer_0/attention/output/LayerNorm/gamma', [768]),
 ('bert/encoder/layer_0/attention/output/dense/bias', [768]),
 ('bert/encoder/layer_0/attention/output/dense/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/key/bias', [768]),
 ('bert/encoder/layer_0/attention/self/key/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/query/bias', [768]),
 ('bert/encoder/layer_0/attention/self/query/kernel', [768, 768]),
 ('bert/encoder/layer_0/attention/self/value/bias', [768]),
 ('bert/encoder/layer_0/attention/self/value/kernel', [768, 768]),
 ('bert/encoder/layer_0/intermediate/dense/bias', [3072]),
 ('bert/encoder/layer_0/intermediate/dense/kernel',

In [28]:
names = []
arrays = []
for name, shape in init_vars:
    print("Loading TF weight {} with shape {}".format(name, shape))
    array = tf.train.load_variable(tf_path, name)
    names.append(name)
    arrays.append(array)

Loading TF weight bert/embeddings/LayerNorm/beta with shape [768]
Loading TF weight bert/embeddings/LayerNorm/gamma with shape [768]
Loading TF weight bert/embeddings/position_embeddings with shape [512, 768]
Loading TF weight bert/embeddings/token_type_embeddings with shape [2, 768]
Loading TF weight bert/embeddings/word_embeddings with shape [28996, 768]
Loading TF weight bert/encoder/layer_0/attention/output/LayerNorm/beta with shape [768]
Loading TF weight bert/encoder/layer_0/attention/output/LayerNorm/gamma with shape [768]
Loading TF weight bert/encoder/layer_0/attention/output/dense/bias with shape [768]
Loading TF weight bert/encoder/layer_0/attention/output/dense/kernel with shape [768, 768]
Loading TF weight bert/encoder/layer_0/attention/self/key/bias with shape [768]
Loading TF weight bert/encoder/layer_0/attention/self/key/kernel with shape [768, 768]
Loading TF weight bert/encoder/layer_0/attention/self/query/bias with shape [768]
Loading TF weight bert/encoder/layer_0/a

In [29]:
from pytorch_pretrained_bert  import BertConfig, BertForPreTraining

In [30]:
# Initialise PyTorch model
config = BertConfig.from_json_file('../../biobert_v1.1_pubmed/bert_config.json')
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)


Building PyTorch model from configuration: {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 28996
}



In [31]:

for name, array in zip(names, arrays):
    name = name.split('/')
    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
    # which are not required for using pretrained model
    if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
        print("Skipping {}".format("/".join(name)))
        continue
    pointer = model
    for m_name in name:
        if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
            l = re.split(r'_(\d+)', m_name)
        else:
            l = [m_name]
        if l[0] == 'kernel' or l[0] == 'gamma':
            pointer = getattr(pointer, 'weight')
        elif l[0] == 'output_bias' or l[0] == 'beta':
            pointer = getattr(pointer, 'bias')
        elif l[0] == 'output_weights':
            pointer = getattr(pointer, 'weight')
        else:
            pointer = getattr(pointer, l[0])
        if len(l) >= 2:
            num = int(l[1])
            pointer = pointer[num]
    if m_name[-11:] == '_embeddings':
        pointer = getattr(pointer, 'weight')
    elif m_name == 'kernel':
        array = np.transpose(array)
    try:
        assert pointer.shape == array.shape
    except AssertionError as e:
        e.args += (pointer.shape, array.shape)
        raise
    print("Initialize PyTorch weight {}".format(name))
    pointer.data = torch.from_numpy(array)

# Save pytorch-model


Initialize PyTorch weight ['bert', 'embeddings', 'LayerNorm', 'beta']
Initialize PyTorch weight ['bert', 'embeddings', 'LayerNorm', 'gamma']
Initialize PyTorch weight ['bert', 'embeddings', 'position_embeddings']
Initialize PyTorch weight ['bert', 'embeddings', 'token_type_embeddings']
Initialize PyTorch weight ['bert', 'embeddings', 'word_embeddings']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'output', 'LayerNorm', 'beta']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'output', 'LayerNorm', 'gamma']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'output', 'dense', 'bias']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'output', 'dense', 'kernel']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'self', 'key', 'bias']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', 'attention', 'self', 'key', 'kernel']
Initialize PyTorch weight ['bert', 'encoder', 'layer_0', '

In [38]:
print("Save PyTorch model to {}".format('../../biobert_v1.1_pubmed/'))
torch.save(model.state_dict(),'../../biobert_v1.1_pubmed/pytorch_weight')

Save PyTorch model to ../../biobert_v1.1_pubmed/
