In [15]:
import os
import sys
import torch
from senteval import utils
from simcse.models import Pooler, MLPLayer, Similarity
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

In [16]:
# Run configuration: dataset name, project paths and compute device.
DATA_TYPE = "qed"
# Project root is the parent of the notebook's working directory.
PRO_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
TOKENIZER_PATH = os.path.join(PRO_PATH, 'fairseq_mo', 'utils')
# Prepend the project tokenizer directory to the import path exactly once,
# so re-running this cell does not keep growing sys.path.
if TOKENIZER_PATH not in sys.path:
    sys.path = [TOKENIZER_PATH, *sys.path]
from tokenizer import selfies_tokenizer, atomwise_tokenizer, MoTokenizer
CHECKPOINT_PATH = os.path.join(PRO_PATH, 'checkpoints', DATA_TYPE, 'simcse')
# Prefer the first GPU when one is available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [17]:
def sentemb_forward(
    cls,
    encoder,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """Encode a batch and return SimCSE-style sentence embeddings.

    Runs ``encoder`` on the inputs, pools its outputs with ``cls.pooler`` and,
    for the ``"cls"`` pooler, passes the pooled vector through ``cls.mlp``.

    Args:
        cls: The wrapping model; must expose ``config``, ``pooler_type``,
            ``pooler`` and ``mlp``.
        encoder: The underlying transformer; always called with
            ``return_dict=True``.
        labels: Accepted for signature compatibility with the training
            forward pass; unused here.
        output_hidden_states: Forwarded to the encoder. Hidden states are
            always requested when the pooler needs them
            (``avg_top2`` / ``avg_first_last``).
        return_dict: Return a model-output object (True) or a tuple (False);
            defaults to ``cls.config.use_return_dict``.

    Returns:
        If ``return_dict``: a ``BaseModelOutputWithPoolingAndCrossAttentions``
        whose ``pooler_output`` holds the sentence embeddings. Otherwise the
        tuple ``(last_hidden_state, pooler_output, *other_encoder_outputs)``,
        where the extras are the encoder's remaining non-None fields.
    """
    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict

    # These poolers read outputs.hidden_states, so they must be requested
    # regardless of what the caller asked for.
    pooler_needs_hidden = cls.pooler_type in ('avg_top2', 'avg_first_last')

    outputs = encoder(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=True if pooler_needs_hidden else output_hidden_states,
        return_dict=True,
    )

    pooler_output = cls.pooler(attention_mask, outputs)
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    if not return_dict:
        # ModelOutput only stores non-None fields, and an encoder built with
        # add_pooling_layer=False has pooler_output=None, so a positional
        # slice such as outputs[2:] would silently drop hidden_states.
        # Select the remaining fields by name instead.
        extras = tuple(
            value
            for key, value in outputs.items()
            if key not in ("last_hidden_state", "pooler_output")
        )
        return (outputs.last_hidden_state, pooler_output) + extras

    return BaseModelOutputWithPoolingAndCrossAttentions(
        pooler_output=pooler_output,
        last_hidden_state=outputs.last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=getattr(outputs, "attentions", None),
    )

In [18]:
class BertForExtract(BertPreTrainedModel):
    """BERT encoder plus a SimCSE pooling head, used to extract embeddings.

    ``forward(..., sent_emb=True)`` returns sentence embeddings through
    ``sentemb_forward``; ``sent_emb=False`` runs SimCSE's contrastive
    ``cl_forward``.
    """

    # Let from_pretrained load checkpoints that lack position_ids entries
    # without reporting them as missing keys.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, *model_args, **model_kargs):
        """Build the encoder and pooling head from a BERT ``config``.

        Extra positional/keyword arguments are accepted (e.g. when passed
        through by ``from_pretrained``) but ignored.
        """
        super().__init__(config)
        # No BERT pooling layer: the sentence vector comes from self.pooler / self.mlp.
        self.bert = BertModel(config, add_pooling_layer=False)

        self.pooler_type = "cls"
        self.pooler = Pooler("cls")
        # sentemb_forward applies this MLP on top of the "cls" pooler output.
        self.mlp = MLPLayer(config)
        # Not used by sentemb_forward; presumably needed by cl_forward.
        self.sim = Similarity(temp=0.05)
        self.init_weights()

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sent_emb=False,
        mlm_input_ids=None,
        mlm_labels=None,
    ):
        """Dispatch to the sentence-embedding or the contrastive forward pass.

        Args:
            sent_emb: If True, return ``sentemb_forward``'s output (embeddings).
                Otherwise run SimCSE's ``cl_forward`` with the MLM arguments.
            Other arguments are forwarded unchanged to the selected function.
        """
        if sent_emb:
            return sentemb_forward(self, self.bert,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # cl_forward was referenced but never defined or imported in this
        # notebook, so this branch raised NameError. Import it lazily from the
        # SimCSE package that already provides Pooler/MLPLayer/Similarity.
        # NOTE(review): SimCSE's cl_forward may read attributes this class does
        # not set (e.g. model_args, lm_head) for hard negatives / MLM — confirm
        # before training with it.
        from simcse.models import cl_forward

        return cl_forward(self, self.bert,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mlm_input_ids=mlm_input_ids,
            mlm_labels=mlm_labels,
        )

In [19]:
# Restore the fine-tuned encoder and its tokenizer from the same checkpoint
# directory; the model is moved to the configured device.
model = BertForExtract.from_pretrained(CHECKPOINT_PATH).to(DEVICE)
tokenizer = MoTokenizer.from_pretrained(CHECKPOINT_PATH)

In [20]:
# SentEval-style evaluation parameters. The setdefault calls below only fill
# in keys that were not given explicitly in the literal.
params = utils.dotdict({
    'usepytorch': True,
    'kfold': 10,
    'classifier': {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                   'tenacity': 5, 'epoch_size': 4},
})
params.setdefault('usepytorch', True)
params.setdefault('seed', 1111)

params.setdefault('batch_size', 128)
params.setdefault('nhid', 0)
params.setdefault('kfold', 5)

# Fall back to a minimal classifier config when none (or an empty one) is set.
if not params.get('classifier'):
    params['classifier'] = {'nhid': 0}

assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

In [21]:
# Load the token vocabulary: keep the first whitespace-separated field of each
# line (presumably a fairseq-style "<token> <count>" dictionary — confirm).
with open(os.path.join(os.getcwd(), DATA_TYPE, 'bin_data', 'dict.low.txt'), 'r', encoding='utf-8') as fp:
    data = [line.strip() for line in fp]
# Skip blank lines (e.g. a trailing newline): ''.split()[0] raises IndexError.
vocab = [item.split()[0] for item in data if item]
# Later cells index vocab[0]; fail early with a clear message instead.
assert vocab, "vocabulary file dict.low.txt contains no tokens"

In [22]:
def batcher(batch, max_length=None):
    """Embed a list of strings with the loaded SimCSE model.

    Uses the module-level ``tokenizer``, ``model`` and ``DEVICE``.

    Args:
        batch: Sequence of strings to encode.
        max_length: If given, inputs are truncated to at most this many
            tokens; otherwise they are only padded to the longest in the batch.

    Returns:
        The model's ``pooler_output`` moved to CPU — presumably one embedding
        row per input string (confirm against the pooler's output shape).
    """
    # Truncation is only enabled when a length limit is supplied.
    length_kwargs = {} if max_length is None else {'max_length': max_length, 'truncation': True}
    encoded = tokenizer.batch_encode_plus(
        batch,
        return_tensors='pt',
        padding=True,
        **length_kwargs,
    )

    # Move every input tensor to the model's device.
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # Only the pooled embedding is used, so hidden states are not requested.
    with torch.no_grad():
        outputs = model(**encoded, return_dict=True, sent_emb=True)

    return outputs.pooler_output.cpu()

In [23]:
# Embed the vocabulary chunk by chunk and map each token to its vector.
embeddings_dict = {}
for start in range(0, len(vocab), params.batch_size):
    batch = vocab[start:start + params.batch_size]
    embeddings = batcher(batch, max_length=50)
    # Pair every token with the matching row of the batch's embedding tensor.
    embeddings_dict.update(zip(batch, embeddings))

In [24]:
# Text serialization: a "<count> <dim>" header line, then one
# "<token> <v1> <v2> ..." line per embedded token.
file_content = [f"{len(embeddings_dict)} {len(embeddings_dict[vocab[0]])}"]
file_content.extend(
    token + ' ' + ' '.join(map(str, vector.tolist()))
    for token, vector in embeddings_dict.items()
)

In [25]:
# Write the serialized embeddings. The emb_data directory is created if it
# does not exist yet, which previously raised FileNotFoundError.
emb_dir = os.path.join(os.getcwd(), DATA_TYPE, 'emb_data')
os.makedirs(emb_dir, exist_ok=True)
with open(os.path.join(emb_dir, 'dict.emb'), 'w', encoding='utf-8') as fp:
    fp.write('\n'.join(file_content))