In [1]:
# use sciemb conda env: conda activate sciemb

from transformers import AutoTokenizer, AutoModel
from  cogdl.oag import oagbert
import torch
from datetime import datetime
import json
import glob 

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
class DataEncoder:
    """Encode paper metadata (title + abstract) into embedding vectors.

    One of three pretrained encoders is loaded at construction time and can be
    swapped later with ``update_encoding_model``:

    * ``'specter_hf'`` -- ``allenai/specter`` loaded through ``transformers``
    * ``'scibert'``    -- ``allenai/scibert_scivocab_cased`` loaded through ``transformers``
    * ``'oagbert'``    -- ``oagbert-v2`` loaded through ``cogdl``

    After loading, ``self.model_name``, ``self.tokenizer`` and ``self.model``
    are set and the model is switched to eval mode.
    """

    # Supported model name -> name of the method that loads it. Looked up with
    # getattr on the instance so __init__ and update_encoding_model share one table.
    _MODEL_LOADERS = {
        'specter_hf': 'load_specter_hf',
        'scibert': 'load_scibert',
        'oagbert': 'load_oagbert',
    }

    def __init__(self, model='specter_hf'):
        """Load the tokenizer and model selected by ``model``.

        Args:
            model: one of ``'specter_hf'``, ``'scibert'`` or ``'oagbert'``.

        Raises:
            ValueError: if ``model`` is not a supported name. (Previously this
                surfaced as an AttributeError on the missing ``self.model``.)
        """
        if model not in self._MODEL_LOADERS:
            raise ValueError(
                "Unknown model {!r}; expected one of {}".format(model, sorted(self._MODEL_LOADERS))
            )
        getattr(self, self._MODEL_LOADERS[model])()

        # self.model.to('cuda:1')
        self.model.eval()

    def update_encoding_model(self, model_name):
        """Replace the currently loaded encoder with ``model_name``.

        An unsupported name prints a message and leaves the current encoder
        untouched. After a successful load the new model is put into eval
        mode, matching what ``__init__`` does.
        """
        loader_name = self._MODEL_LOADERS.get(model_name)
        if loader_name is None:
            print('Invalid model name! Please try again with a valid  model name.')
            return
        getattr(self, loader_name)()
        self.model.eval()

    def load_specter_hf(self):
        """Load the ``allenai/specter`` tokenizer and model."""
        self.model_name = 'specter_hf'
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
        self.model = AutoModel.from_pretrained('allenai/specter')
        return

    def load_scibert(self):
        """Load the ``allenai/scibert_scivocab_cased`` tokenizer and model."""
        self.model_name = 'scibert'
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_cased')
        self.model = AutoModel.from_pretrained('allenai/scibert_scivocab_cased')
        return

    def load_oagbert(self):
        """Load the ``oagbert-v2`` tokenizer and model via ``cogdl``."""
        self.model_name = 'oagbert'
        self.tokenizer, self.model = oagbert("oagbert-v2")
        return

    @staticmethod
    def _write_jsonl(path, mode, document_dict, document_emb_dict):
        """Write one JSON line per entry of ``document_emb_dict`` to ``path``.

        Each line has the keys ``paper_id``, ``title`` and ``embedding``. An
        entry that cannot be serialised is reported with ``print`` and skipped;
        I/O errors are not swallowed, so a failing disk is noticed immediately.
        """
        with open(path, mode) as outfile:
            for idx, entry in enumerate(document_emb_dict):
                try:
                    line = json.dumps({
                        "paper_id": entry,
                        "title": (document_dict[entry].get('title') or ''),
                        "embedding": document_emb_dict[entry],
                    })
                except (TypeError, ValueError):
                    # Best effort: report the offending record and keep going.
                    print(idx, entry)
                    continue
                outfile.write(line + '\n')

    def _embed_text_batch(self, batch_ids, doc_batch_list):
        """Tokenise and encode one batch of texts; return ``{doc_id: embedding}``.

        The embedding is the slice at index 0 along the second axis of
        ``last_hidden_state`` (presumably the [CLS] position -- confirm against
        the model card), converted to a plain Python list.
        """
        inputs = self.tokenizer(doc_batch_list, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        # inputs = inputs.to('cuda:1')
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            result = self.model(**inputs)
        embeddings = result.last_hidden_state[:, 0, :]
        return {
            doc_id: embeddings[row].detach().cpu().numpy().tolist()
            for row, doc_id in enumerate(batch_ids)
        }

    def encode_batch_wise_using_oagbert(self, document_dict, to_save_loc):
        """Encode every document with the loaded OAG-BERT model, one at a time.

        Args:
            document_dict: mapping of document id -> metadata dict; only the
                ``'title'`` and ``'abstract'`` entries are read (missing or
                ``None`` values are treated as empty strings).
            to_save_loc: path of the JSONL output file. It is opened in append
                mode, so running this twice on the same path duplicates rows.

        Documents whose title + abstract is blank are skipped. Progress is
        printed every 2000 documents. The accumulated embeddings are appended
        to the file and cleared whenever a document whose index is a multiple
        of 4000 has just been encoded, and once more after the loop.

        Returns:
            dict of document id -> pooled-output embedding (as nested lists)
            for the documents encoded since the last in-loop dump only -- not
            the full set, which lives in ``to_save_loc``.
        """
        document_emb_dict = {}

        for idx, document_id in enumerate(document_dict):
            if idx % 2000 == 0:
                print("Encoded: {} @ {}".format(idx, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

            # document_dict is a dictionary keyed by document id.
            paper_title = (document_dict[document_id].get('title') or '')
            paper_abstract = (document_dict[document_id].get('abstract') or '')

            # Nothing to encode when title + abstract is blank.
            if not (paper_title + paper_abstract).strip():
                continue

            (input_ids, input_masks, token_type_ids, _masked_lm_labels,
             position_ids, position_ids_second,
             _masked_positions, _num_spans) = self.model.build_inputs(
                title=paper_title, abstract=paper_abstract, venue=[], authors=[], concepts=[], affiliations=[])

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                _sequence_output, pooled_output = self.model.bert.forward(
                    input_ids=torch.LongTensor(input_ids).unsqueeze(0),
                    token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
                    attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
                    output_all_encoded_layers=False,
                    checkpoint_activations=False,
                    position_ids=torch.LongTensor(position_ids).unsqueeze(0),
                    position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0)
                )

            document_emb_dict[document_id] = pooled_output.detach().cpu().numpy().tolist()

            if idx % 4000 == 0:
                print("Dumping 4k lines to file.")
                self._write_jsonl(to_save_loc, 'a', document_dict, document_emb_dict)
                document_emb_dict = {}

        # Flush whatever is left after the last in-loop dump.
        self._write_jsonl(to_save_loc, 'a', document_dict, document_emb_dict)

        return document_emb_dict

    def encode_batch_wise_using_specter(self, document_dict, to_save_loc, batch_size=20):
        """Encode every document in batches with the loaded tokenizer/model pair.

        Each document's text is ``title + sep_token + abstract`` (missing or
        ``None`` fields become empty strings), truncated to 512 tokens.

        Args:
            document_dict: mapping of document id -> metadata dict.
            to_save_loc: path of the JSONL output file; it is overwritten.
            batch_size: number of documents per forward pass (must be >= 1).

        Returns:
            dict of document id -> embedding (list of floats) for all documents.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {!r}".format(batch_size))

        # Small batch of document texts and the corresponding document ids.
        doc_batch_list = []
        batch_ids = []

        document_emb_dict = {}

        for idx, doc_id in enumerate(document_dict):
            if idx % 2000 == 0:
                print("Encoded: {} @ {}".format(idx, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

            doc_batch_list.append((document_dict[doc_id].get('title') or '') + self.tokenizer.sep_token + (document_dict[doc_id].get('abstract') or ''))
            batch_ids.append(doc_id)

            # Flush on a full batch. (Testing the loop index instead made the
            # very first batch contain a single document.)
            if len(batch_ids) == batch_size:
                document_emb_dict.update(self._embed_text_batch(batch_ids, doc_batch_list))
                doc_batch_list = []
                batch_ids = []

        # Encode the final, possibly partial, batch.
        if batch_ids:
            document_emb_dict.update(self._embed_text_batch(batch_ids, doc_batch_list))

        self._write_jsonl(to_save_loc, 'w', document_dict, document_emb_dict)
        return document_emb_dict

In [None]:
# Build an OAG-BERT encoder and run it over every task metadata file in ./data.
de = DataEncoder(model='oagbert')

all_task_files = glob.glob("./data/*.json")
for f in all_task_files:
    with open(f, 'r') as fin:
        data_dict = json.load(fin)

        # Output stem: the file's basename with the "paper_metadata_" and
        # ".json" substrings removed.
        out_file_name = (
            f.rsplit("/", 1)[-1]
            .replace("paper_metadata_", "")
            .replace(".json", "")
        )

        # The encoder appends its JSONL rows to this path; the returned dict
        # only holds the rows written since its last in-loop dump.
        sd = de.encode_batch_wise_using_oagbert(
            data_dict,
            './data/oag-embeddings/{}.jsonl'.format(out_file_name),
        )

file saved/oagbert-v2/config.json not found


Encoded: 0 @ 2022-06-17 13:18:57
Encoded: 2000 @ 2022-06-17 17:11:42
Encoded: 4000 @ 2022-06-17 20:14:49
Encoded: 6000 @ 2022-06-17 23:24:12


In [7]:
# Completion marker printed once this cell is executed.
print("Done")

Done


### Encode it again with the new oagbert embeddings in the oagbert conda environment

In [12]:
# Build a fresh OAG-BERT encoder (rebinds the notebook-level name `de`).
de = DataEncoder(model='oagbert')

# Collect every JSON file directly under ./data; glob does not guarantee any particular order.
all_task_files = glob.glob("./data/*.json")

In [14]:
# Display the full file list next to its last-element slice.
all_task_files, all_task_files[-1:]

(['./data/paper_metadata_mag_mesh.json',
  './data/paper_metadata_recomm.json',
  './data/paper_metadata_view_cite_read.json'],
 ['./data/paper_metadata_view_cite_read.json'])

In [17]:
# Encode only the last file of the list with the currently loaded encoder.
for f in all_task_files[-1:]:
    with open(f, 'r') as fin:
        data_dict = json.load(fin)

        # Output stem: the file's basename with the "paper_metadata_" and
        # ".json" substrings removed.
        out_file_name = (
            f.rsplit("/", 1)[-1]
            .replace("paper_metadata_", "")
            .replace(".json", "")
        )

        print("Starting to encode!")

        # Rows are appended to the JSONL file under newoag-embeddings.
        sd = de.encode_batch_wise_using_oagbert(
            data_dict,
            './data/trained_embs/newoag-embeddings/{}.jsonl'.format(out_file_name),
        )

Starting to encode!
Encoded: 0 @ 2022-06-25 13:41:24
Dumping 4k lines to file.
Encoded: 2000 @ 2022-06-25 14:10:26
Encoded: 4000 @ 2022-06-25 14:29:39
Dumping 4k lines to file.
Encoded: 6000 @ 2022-06-25 14:35:16
Encoded: 8000 @ 2022-06-25 14:40:49
Dumping 4k lines to file.
Encoded: 10000 @ 2022-06-25 14:46:13
Encoded: 12000 @ 2022-06-25 14:51:30
Dumping 4k lines to file.
Encoded: 14000 @ 2022-06-25 14:57:15
Encoded: 16000 @ 2022-06-25 15:02:59
Dumping 4k lines to file.
Encoded: 18000 @ 2022-06-25 15:08:36
Encoded: 20000 @ 2022-06-25 15:14:26
Dumping 4k lines to file.
Encoded: 22000 @ 2022-06-25 15:19:58
Encoded: 24000 @ 2022-06-25 15:25:28
Dumping 4k lines to file.
Encoded: 26000 @ 2022-06-25 15:31:14
Encoded: 28000 @ 2022-06-25 15:36:47
Dumping 4k lines to file.
Encoded: 30000 @ 2022-06-25 15:42:16
Encoded: 32000 @ 2022-06-25 15:47:32
Dumping 4k lines to file.
Encoded: 34000 @ 2022-06-25 16:16:08
Encoded: 36000 @ 2022-06-25 16:33:39
Dumping 4k lines to file.
Encoded: 38000 @ 2022-06-

In [18]:
# Completion marker printed once this cell is executed.
print("Done!")

Done!
