In [1]:
# use sciemb conda env: conda activate sciemb

import glob
import json
import os
from datetime import datetime

from transformers import AutoTokenizer, AutoModel
from cogdl import oagbert
import torch

Failed to load C version of sampling, use python version instead.
Failed to load fast version of SpMM, use torch.scatter_add instead.


In [2]:
class DataEncoder:
    """Embed paper metadata (title + abstract) with a pretrained scientific-text encoder.

    Supported model names and what they load:
        'specter_hf' -> Hugging Face 'allenai/specter'
        'scibert'    -> Hugging Face 'allenai/scibert_scivocab_cased'
        'oagbert'    -> CogDL oagbert("oagbert-v2")
    """

    # Public model name -> name of the method that loads its tokenizer and model.
    # Dispatching through one table keeps __init__ and update_encoding_model in sync
    # (the old hand-written chain in update_encoding_model called a missing method).
    _LOADERS = {
        'specter_hf': 'load_specter_hf',
        'scibert': 'load_scibert',
        'oagbert': 'load_oagbert',
    }

    def __init__(self, model='specter_hf'):
        """Load the encoder named ``model`` and switch it to inference mode.

        Args:
            model: one of 'specter_hf', 'scibert', 'oagbert'.

        Raises:
            ValueError: if ``model`` is not a supported name.
        """
        if model not in self._LOADERS:
            raise ValueError('Invalid model name {!r}; expected one of {}.'.format(
                model, sorted(self._LOADERS)))
        # To run on a GPU, move self.model (and every batch of inputs) to the device.
        self._load(model)

    def _load(self, model_name):
        """Run the loader registered for ``model_name``, then put the model in eval mode."""
        getattr(self, self._LOADERS[model_name])()
        # For torch modules, eval() disables dropout so embeddings are deterministic.
        self.model.eval()

    def update_encoding_model(self, model_name):
        """Replace the current tokenizer/model with the encoder named ``model_name``.

        An unknown name prints a message and leaves the current encoder untouched.
        """
        if model_name not in self._LOADERS:
            print('Invalid model name! Please try again with a valid  model name.')
            return
        self._load(model_name)

    def load_specter_hf(self):
        """Load the SPECTER tokenizer and model from the Hugging Face hub."""
        self.model_name = 'specter_hf'
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
        self.model = AutoModel.from_pretrained('allenai/specter')

    def load_scibert(self):
        """Load the cased SciBERT tokenizer and model from the Hugging Face hub."""
        self.model_name = 'scibert'
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_cased')
        self.model = AutoModel.from_pretrained('allenai/scibert_scivocab_cased')

    def load_oagbert(self):
        """Load the CogDL OAG-BERT v2 tokenizer and model."""
        self.model_name = 'oagbert'
        self.tokenizer, self.model = oagbert("oagbert-v2")

    def _encode_texts(self, texts, max_length=512):
        """Encode a list of strings in one forward pass.

        Returns one embedding (list of floats) per input string: the final hidden
        state at position 0 (the [CLS] position for BERT-style tokenizers).

        NOTE(review): assumes the model output has ``last_hidden_state`` (true for
        Hugging Face AutoModel); confirm this holds for the CogDL oagbert model.
        """
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=max_length)
        # Inference only: skip building the autograd graph (saves memory and time).
        with torch.no_grad():
            result = self.model(**inputs)
        embeddings = result.last_hidden_state[:, 0, :]
        return embeddings.detach().cpu().tolist()

    @staticmethod
    def _write_embeddings(document_emb_dict, to_save_loc):
        """Write ``{paper_id: embedding}`` to ``to_save_loc`` as JSON Lines."""
        with open(to_save_loc, 'w') as outfile:
            for idx, (paper_id, embedding) in enumerate(document_emb_dict.items()):
                try:
                    line = json.dumps({"paper_id": paper_id, "embedding": embedding})
                except (TypeError, ValueError) as err:
                    # Best-effort: report an unserializable entry and keep going.
                    # I/O errors are not caught, so a failed write is not silently lost.
                    print(idx, paper_id, err)
                    continue
                outfile.write(line + '\n')

    def encode_batch_wise_using_specter(self, document_dict, to_save_loc, batch_size=20,
                                        max_length=512, log_every=2000):
        """Embed every document in ``document_dict`` and save the embeddings as JSONL.

        Each document is encoded as ``title + tokenizer.sep_token + abstract``; a
        missing or falsy title/abstract contributes ``''``.

        Args:
            document_dict: mapping of paper id -> dict with optional 'title' and
                'abstract' entries.
            to_save_loc: output path; one JSON object per line with keys
                "paper_id" and "embedding".
            batch_size: number of documents per forward pass (must be >= 1).
            max_length: tokenizer truncation length.
            log_every: print a progress line every ``log_every`` documents
                (0 disables progress output).

        Returns:
            dict mapping paper id -> embedding (list of floats).

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

        sep = self.tokenizer.sep_token
        document_emb_dict = {}
        batch_texts, batch_ids = [], []

        for idx, (paper_id, doc) in enumerate(document_dict.items()):
            if log_every and idx % log_every == 0:
                print("Encoded: {} @ {}".format(idx, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

            batch_texts.append((doc.get('title') or '') + sep + (doc.get('abstract') or ''))
            batch_ids.append(paper_id)

            # Flush only when the batch is full, so every batch except the last
            # one holds exactly batch_size documents.
            if len(batch_ids) == batch_size:
                document_emb_dict.update(zip(batch_ids, self._encode_texts(batch_texts, max_length)))
                batch_texts, batch_ids = [], []

        if batch_ids:  # final, partially filled batch
            document_emb_dict.update(zip(batch_ids, self._encode_texts(batch_texts, max_length)))

        self._write_embeddings(document_emb_dict, to_save_loc)
        return document_emb_dict

In [4]:
# Load the default SPECTER encoder once; it is reused for every metadata file below.
de = DataEncoder()

# glob's result order depends on the file system; sort so files are processed
# (and logged) in a reproducible order.
all_task_files = sorted(glob.glob("./data/*.json"))

In [5]:
# Display the metadata files that will be encoded.
all_task_files

['./data/paper_metadata_mag_mesh.json',
 './data/paper_metadata_recomm.json',
 './data/paper_metadata_view_cite_read.json']

In [None]:
OUT_DIR = './data/shf-embeddings'

# Create the output directory up front. The encoder only opens its output file
# after encoding finishes, which takes hours per file, so a missing directory
# would otherwise discard all of that work.
os.makedirs(OUT_DIR, exist_ok=True)

for f in all_task_files:
    # Read the metadata, then release the file handle before the long-running encode.
    with open(f, 'r') as fin:
        data_dict = json.load(fin)

    # e.g. './data/paper_metadata_recomm.json' -> 'recomm'
    out_file_name = os.path.splitext(os.path.basename(f))[0].replace("paper_metadata_", "")

    sd = de.encode_batch_wise_using_specter(
        data_dict, os.path.join(OUT_DIR, '{}.jsonl'.format(out_file_name)))

Encoded: 0 @ 2022-06-17 13:19:10
Encoded: 2000 @ 2022-06-17 13:36:27
Encoded: 4000 @ 2022-06-17 13:54:24
Encoded: 6000 @ 2022-06-17 14:12:06
Encoded: 8000 @ 2022-06-17 14:29:53
Encoded: 10000 @ 2022-06-17 14:47:01
Encoded: 12000 @ 2022-06-17 15:04:05
Encoded: 14000 @ 2022-06-17 15:21:59
Encoded: 16000 @ 2022-06-17 15:39:58
Encoded: 18000 @ 2022-06-17 15:57:17
Encoded: 20000 @ 2022-06-17 16:15:10
Encoded: 22000 @ 2022-06-17 16:32:58
Encoded: 24000 @ 2022-06-17 16:50:28
Encoded: 26000 @ 2022-06-17 17:08:48
Encoded: 28000 @ 2022-06-17 17:28:32
Encoded: 30000 @ 2022-06-17 17:47:26
Encoded: 32000 @ 2022-06-17 18:07:29
Encoded: 34000 @ 2022-06-17 18:27:13
Encoded: 36000 @ 2022-06-17 18:46:49
Encoded: 38000 @ 2022-06-17 19:06:17
Encoded: 40000 @ 2022-06-17 19:27:14
Encoded: 42000 @ 2022-06-17 19:48:03
Encoded: 44000 @ 2022-06-17 20:09:24
Encoded: 46000 @ 2022-06-17 20:30:42
Encoded: 48000 @ 2022-06-17 20:51:27
Encoded: 0 @ 2022-06-17 20:56:55
Encoded: 2000 @ 2022-06-17 21:16:04
Encoded: 4000 

In [7]:
# Simple completion marker.
print("Done")

Done


In [None]:
## Experiment (unfinished): try adding titles to check whether the result makes sense

In [None]:
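# NOTE(review): unfinished draft kept for reference only; it would not run as written.
# The inner `for` body is not indented, `tqdm` is never imported, and the read loop
# never stores an embedding (`document_emb_dict['paper_id']` is a bare lookup).
# The input path "./data/" + out_file_name also omits the embeddings directory and
# the ".jsonl" extension.
# all_task_files = glob.glob("./data/*.json")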

# for f in all_task_files:
#     with open(f, 'r') as fin:
#         data_dict = json.load(fin)
#         out_file_name = f.rsplit("/", 1)[-1]
#         out_file_name = out_file_name.replace("paper_metadata_", "")
#         out_file_name = out_file_name.replace(".json", "")
        
#         document_emb_dict = {}

#         with open("./data/" + out_file_name, 'r') as fi:
#             for line in tqdm(fi, desc='reading embeddings from file...'):
#             line_json = json.loads(line)
#             document_emb_dict['paper_id']
        
#          with open("./data/Title_" + out_file_name, 'w') as outfile:
#             for _, entry in enumerate(document_emb_dict):
#                 try:
#                     outfile.write(json.dumps({"paper_id": entry, "embedding": document_emb_dict[entry]}) + '\n')
#                 except:
#                     print(_, entry)
#                     continue
#         sd = de.encode_batch_wise_using_specter(data_dict, './data/shf-embeddings/{}.jsonl'.format(out_file_name))