In [None]:
import torch
from torch import nn
# pip install pytorch_pretrained_bert
import pytorch_pretrained_bert
from pytorch_pretrained_bert import BertTokenizer, BertModel

In [None]:
# Local directories holding the downloaded BERT model archives and vocab files.
# Both values must end with '/' because the archive maps below build file
# paths by plain string concatenation (e.g. vocab_path + "<name>-vocab.txt").
model_path = '/home/model/'
vocab_path = '/home/vocab/'

In [None]:
# Point every pretrained-model name at its locally downloaded archive.
# Each entry is model_path + '<name>.tar.gz'; model_path is concatenated
# verbatim, so it has to end with a path separator.
pytorch_pretrained_bert.modeling.PRETRAINED_MODEL_ARCHIVE_MAP = {
    name: model_path + name + '.tar.gz'
    for name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
    )
}

In [None]:
# Point every pretrained-model name at its locally downloaded vocabulary file.
# Each entry is vocab_path + '<name>-vocab.txt'.
# NOTE(review): vocab_path is concatenated verbatim, so it must end with a
# path separator for these paths to resolve.
pytorch_pretrained_bert.tokenization.PRETRAINED_VOCAB_ARCHIVE_MAP = {
    name: vocab_path + name + "-vocab.txt"
    for name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
    )
}

### Text preprocessing (assumes the input is raw text)

In [None]:
def sentence_to_feature(sentence, seq_length=512, tokenizer=None):
    """Convert one raw sentence into fixed-length BERT input ids and mask.

    The sentence is tokenized, truncated so that it fits together with the
    surrounding '[CLS]' and '[SEP]' markers, mapped to vocabulary ids and
    right-padded with 0 up to ``seq_length``.

    Args:
        sentence: Raw text to encode.
        seq_length: Exact length of both returned lists. Must be at least 2
            so that '[CLS]' and '[SEP]' always fit.
        tokenizer: Optional object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``. When omitted, the
            'bert-base-uncased' ``BertTokenizer`` is loaded on first use and
            reused by later calls instead of being re-read for every sentence.

    Returns:
        Tuple ``(input_ids, input_mask)`` of two int lists of length
        ``seq_length``; the mask holds 1 for real tokens and 0 for padding.

    Raises:
        ValueError: If ``seq_length`` is smaller than 2.
    """
    if seq_length < 2:
        raise ValueError("seq_length must be >= 2 to hold [CLS] and [SEP]")

    if tokenizer is None:
        # Loading the vocabulary is expensive; keep one shared instance on the
        # function object (re-running the cell resets it).
        tokenizer = getattr(sentence_to_feature, "_default_tokenizer", None)
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            sentence_to_feature._default_tokenizer = tokenizer

    # Reserve two slots for the special tokens; without this cut, long
    # sentences would yield more than seq_length ids and ragged batches.
    tokenized_text = tokenizer.tokenize(sentence)[:seq_length - 2]
    tokens = ['[CLS]'] + tokenized_text + ['[SEP]']
    token_ids = list(tokenizer.convert_tokens_to_ids(tokens))

    # Zero-pad up to the sequence length.
    pad_count = seq_length - len(token_ids)
    input_ids = token_ids + [0] * pad_count
    input_mask = [1] * len(token_ids) + [0] * pad_count

    return input_ids, input_mask

In [None]:
def sentences_to_features(sentence_batch):
    """Encode a batch of raw sentences into BERT input tensors.

    Every sentence is passed through ``sentence_to_feature`` (with its default
    sequence length) and the per-sentence id and mask lists are stacked.
    The returned pair is meant to be fed as the batch argument of
    ``Akimoto_BERT.forward``.

    Args:
        sentence_batch: Iterable of raw sentences.

    Returns:
        Tuple ``(all_input_ids, all_input_mask)`` of ``torch.long`` tensors
        with one row per sentence.
    """
    encoded = [sentence_to_feature(sentence) for sentence in sentence_batch]
    id_rows = [ids for ids, _ in encoded]
    mask_rows = [mask for _, mask in encoded]
    return (
        torch.tensor(id_rows, dtype=torch.long),
        torch.tensor(mask_rows, dtype=torch.long),
    )

### Build the model on top of the BERT embedding

In [None]:
class Akimoto_BERT(nn.Module):
    """Skeleton model that exposes the last BERT layer as a token embedding.

    The pretrained 'bert-base-uncased' encoder is stored in ``self.bert``
    (optionally wrapped in ``torch.nn.DataParallel``); the remaining layers
    of the Akimoto model are expected to be added on top of it.
    """

    def __init__(self, data_parallel=True):
        """Load the pretrained BERT encoder and place it on the compute device.

        Args:
            data_parallel: When true, wrap the encoder in
                ``torch.nn.DataParallel``; otherwise use it directly.
        """
        # nn.Module has to be initialised before sub-modules can be assigned.
        super(Akimoto_BERT, self).__init__()
        # Use the GPU when one is present, otherwise fall back to the CPU
        # instead of failing on a hard-coded "cuda" device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        bert = BertModel.from_pretrained("bert-base-uncased").to(device=self.device)
        if data_parallel:
            self.bert = torch.nn.DataParallel(bert)
        else:
            self.bert = bert
        # other init from akimoto model
        # dropout, log_softmax...

    def forward(self, bert_batch):
        """Compute the BERT embedding for a batch.

        Args:
            bert_batch: Pair ``(input_ids, input_mask)`` of integer tensors,
                as produced by ``sentences_to_features``.

        Returns:
            The last element of the encoder's first output, i.e. the final
            encoded layer according to the pytorch_pretrained_bert
            ``BertModel`` documentation.
        """
        bert_ids, bert_mask = bert_batch
        # Inputs are built on the CPU; move them to where the encoder lives.
        bert_ids = bert_ids.to(self.device)
        bert_mask = bert_mask.to(self.device)

        # Single-sentence input: every token belongs to segment 0.
        segment_ids = torch.zeros_like(bert_mask)
        # BertModel.forward takes (input_ids, token_type_ids, attention_mask);
        # pass by keyword so the mask and the segment ids cannot be swapped.
        bert_last_layer = self.bert(
            bert_ids,
            token_type_ids=segment_ids,
            attention_mask=bert_mask,
        )[0][-1]  # this is the bert embedding
        return bert_last_layer