In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

class BertBiEncoder(nn.Module):
    """Bi-encoder with separate encoder towers for mentions and candidates.

    Each tower encodes a batch of token ids; the hidden state at sequence
    position 0 is used as the vector for each row.
    """

    def __init__(self, mention_bert, candidate_bert):
        """
        Args:
            mention_bert: encoder module used when ``is_mention=True``.
            candidate_bert: encoder module used when ``is_mention=False``.
                Both are called as ``model(input_ids, attention_mask=...)`` and
                their first output is assumed to be per-token hidden states of
                shape (batch, seq_len, hidden) -- TODO confirm for non-BERT models.
        """
        super().__init__()
        self.mention_bert = mention_bert
        self.candidate_bert = candidate_bert

    def _encode(self, model, input_ids, attention_mask):
        """Run ``model`` on one batch and return the position-0 hidden state per row."""
        # Index with [0] rather than tuple-unpacking: transformers >= 4 returns a
        # ModelOutput (a dict subclass), and unpacking that yields its *keys*.
        # Integer indexing works for both ModelOutput and plain tuples.
        hidden = model(input_ids, attention_mask=attention_mask)[0]
        return hidden[:, 0, :]

    def forward(self, input_ids, attention_mask, is_mention=True, shard_bsz=None):
        """Encode a batch of token ids into one vector per row.

        Args:
            input_ids: token-id tensor, assumed (batch, seq_len).
            attention_mask: mask aligned with ``input_ids`` (or None).
            is_mention: use the mention tower if True, else the candidate tower.
            shard_bsz: if given, encode at most this many rows per model call
                (to bound peak memory) and concatenate the results in order.

        Returns:
            Tensor with one row per input row: the tower's hidden state at
            position 0.

        Raises:
            ValueError: if ``shard_bsz`` is not positive.
        """
        model = self.mention_bert if is_mention else self.candidate_bert

        if shard_bsz is None:
            return self._encode(model, input_ids, attention_mask)

        if shard_bsz <= 0:
            raise ValueError(f"shard_bsz must be a positive integer, got {shard_bsz}")

        shards = []
        for start in range(0, input_ids.size(0), shard_bsz):
            stop = start + shard_bsz
            mask = None if attention_mask is None else attention_mask[start:stop]
            shards.append(self._encode(model, input_ids[start:stop], mask))
        return torch.cat(shards, dim=0)



class BertCandidateGenerator(object):
    """Trains a bi-encoder for candidate generation with in-batch negatives.

    Args:
        biencoder: module called as
            ``biencoder(input_ids, attention_mask=..., is_mention=...)`` that
            returns one vector per input row (e.g. ``BertBiEncoder``). It is
            moved to ``device``.
        pages: stored as ``self.pages``; not used by ``train``.
        device: torch device to run on.
    """

    def __init__(self, biencoder, pages, device="cpu"):
        self.model = biencoder.to(device)
        self.pages = pages
        self.device = device

    def _pad(self, sequences):
        """Right-pad token-id lists with 0 into a (batch, max_len) LongTensor on ``self.device``."""
        return pad_sequence(
            [torch.LongTensor(tokens) for tokens in sequences],
            batch_first=True,
            padding_value=0,
        ).to(self.device)

    def train(self,
              mention_dataset,
              candidate_dataset,
              inbatch=True,
              lr=1e-5,
              batch_size=32,
              max_ctxt_len=32,
              max_title_len=50,
              max_desc_len=100,
             ):
        """Run one pass over ``mention_dataset`` using in-batch negative sampling.

        For each batch, mentions and their gold candidate pages are encoded by
        the bi-encoder. The score matrix ``mention_reps @ candidate_reps.T`` is
        trained with cross-entropy so that mention ``i`` scores candidate ``i``
        highest; the other candidates in the batch act as negatives.

        Args:
            mention_dataset: object whose
                ``batch(batch_size=..., max_ctxt_len=...)`` yields
                ``(input_ids, labels)`` pairs, where ``input_ids`` is a list of
                token-id sequences.
            candidate_dataset: object whose
                ``get_pages(labels, max_title_len=..., max_desc_len=...)``
                returns token-id sequences. It is assumed to return one sequence
                per label, in label order; the diagonal target relies on this.
            inbatch: only in-batch negatives are implemented; False raises.
            lr: Adam learning rate.
            batch_size: forwarded to ``mention_dataset.batch``.
            max_ctxt_len: forwarded to ``mention_dataset.batch``.
            max_title_len: forwarded to ``candidate_dataset.get_pages``.
            max_desc_len: forwarded to ``candidate_dataset.get_pages``.

        Returns:
            list[float]: the loss of each optimisation step, in order.

        Raises:
            NotImplementedError: if ``inbatch`` is False.
        """
        if not inbatch:
            raise NotImplementedError("only in-batch negative training is implemented")

        mention_batch = mention_dataset.batch(batch_size=batch_size, max_ctxt_len=max_ctxt_len)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # Enable training-mode layers (e.g. dropout); pretrained models load in eval mode.
        self.model.train()

        all_loss = []
        for input_ids, labels in mention_batch:
            inputs = self._pad(input_ids)
            candidate_input_ids = candidate_dataset.get_pages(
                labels, max_title_len=max_title_len, max_desc_len=max_desc_len)
            candidate_inputs = self._pad(candidate_input_ids)

            # Encode token ids with the two towers; padding id 0 is masked out.
            mention_reps = self.model(inputs, attention_mask=inputs > 0, is_mention=True)
            candidate_reps = self.model(candidate_inputs, attention_mask=candidate_inputs > 0,
                                        is_mention=False)

            # scores[i, j]: mention i vs. candidate j; the gold pair is on the diagonal.
            scores = mention_reps.mm(candidate_reps.t())
            target = torch.arange(scores.size(0), device=scores.device)
            loss = F.cross_entropy(scores, target, reduction="mean")

            # Clear gradients from the previous step before backpropagating.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            all_loss.append(loss.item())

        return all_loss
        

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

# Smoke test: encode a toy batch of token ids with the mention tower.
# NOTE(review): ids 0-3 are placeholders, not real tokenizer output; row 0 starts
# with id 0, so its position-0 token is masked out by the mask below.
input_ids = torch.tensor([[0, 1, 2], [1, 2, 3]])
attention_mask = input_ids > 0  # id 0 is treated as padding

mention_bert = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
candidate_bert = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
biencoder = BertBiEncoder(mention_bert, candidate_bert)
biencoder.eval()

# Inference only: skip autograd bookkeeping so the output carries no graph.
with torch.no_grad():
    mention_reps = biencoder(input_ids, attention_mask)
mention_reps