### imports
***

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sentence_transformers import SentenceTransformer, models
from transformers import AdamW, get_linear_schedule_with_warmup

import multiprocessing as mp
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import csv
import random
import matplotlib.pyplot as plt
import sys
import pickle

### load data
***

In [2]:
data_path = '../../data_2020/'
model_path = 'models/'
other_path = 'others/'

In [3]:
use_cuda = torch.cuda.is_available()
no = "0"
if use_cuda:
    print("using cuda!")
device = torch.device("cuda:"+no) if use_cuda else torch.device("cpu")

# candidate_paper = pd.read_csv(data_path+"candidate_paper_pre.csv")
train_data = pd.read_csv(data_path+"train_pre.csv")
valid_data = pd.read_csv(data_path+"test_pre.csv")
# candidate_paper = candidate_paper[~candidate_paper['paper_id'].isnull()]

using cuda!


In [4]:
train_data = train_data.fillna('none')
valid_data = valid_data.fillna('none')
# candidate_paper = candidate_paper.fillna('none')

In [5]:
train_data.head()

Unnamed: 0,description_id,paper_id,description_text,key_text,key_text_pre,description_text_pre
0,77bef2,5c0f7919da562944ac759a0f,Angiogenesis is reflected as newly formed vess...,"Moreover, Wnt-1-inducible secreted protein-1 (...",moreover wnt-1-inducible secrete protein-1 wis...,angiogenesis reflect newly form vessel endothe...
1,42360e,5c1360beda56295a0896fda3,Cardiac fibrosis is a common process in remode...,There is evidence showing that the down-regula...,there evidence show down-regulation β-catenin ...,cardiac fibrosis common process remodel heart ...
2,9bf5e0,5d1b36e83a55ac0a0e8bb84e,"Agmatine, formed by the decarboxylation of L-a...","Agmatine, formed by the decarboxylation of L-a...",agmatine form decarboxylation l-arginine argin...,agmatine form decarboxylation l-arginine argin...
3,22e485,5d2709fd3a55ac2cfc28108f,The ob gene product leptin has been demonstrat...,"The aminoguanidine carboxylate, BVT.12777 (Fig...",the aminoguanidine carboxylate bvt.12777 figur...,the ob gene product leptin demonstrate activat...
4,30856c,55a392d1c91b587b095b6fcc,"Lauterbach M et al., have concluded at the end...","Lauterbach M et , have concluded at the end of...",lauterbach m et conclude end study germany ana...,lauterbach m et al. conclude end study germany...


In [6]:
valid_data.head()

Unnamed: 0,description_id,description_text,key_text,key_text_pre,description_text_pre
0,00032c,Refer to Table 2 or Methods for a brief descri...,Colons (:) indicated interaction terms..,colon indicate interaction terms..,refer table method brief description variable ...
1,000676,Sixty-nine female subjects with a mean age of ...,Handedness was evaluated according to the proc...,handedness evaluate accord procedure propose a...,sixty-nine sixty nine female subject mean age ...
2,000b24,Our behavioral and imaging findings differed f...,"Recently, Chiu et used a modified IGT, namel...",recently chiu et use modified IGT namely sooch...,our behavioral imaging finding differ previou ...
3,000c20,A novel Ehrlichia transmitted by Amblyomma ame...,"ruminantium, caused transient febrile illness,...",ruminantium cause transient febrile illness fo...,A novel ehrlichium transmit amblyomma american...
4,000c90,The dorsal fronto-striatal circuit plays an im...,"One of these functions is set-shifting, which ...",one function set-shifting set shifting refer a...,the dorsal fronto-striatal fronto striatal cir...


### model
***

In [7]:
word_embedding_model = models.BERT('scibert_scivocab_uncased/')
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)
encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model],
                              device=device)

In [8]:
def random_vec():
    return np.random.normal(0, 0.1, 768)

paper2embedding = {}
description2embedding = {}

# for i, r in tqdm(candidate_paper.iterrows(), total=candidate_paper.shape[0]):
#     paper2embedding[r['paper_id']] = {}
#     # abstract
#     ab = r['abstract']
#     if ab == 'NO_CONTENT' or ab == 'none':
#         paper2embedding[r['paper_id']]['abstract'] = random_vec()
#     else:
#         paper2embedding[r['paper_id']]['abstract'] = encoder.encode([ab])[0]
#     # title
#     title = r['title']
#     paper2embedding[r['paper_id']]['title'] = encoder.encode([title])[0]

In [9]:
# with open(other_path+'paper2embedding.pkl', 'wb') as f:
#     pickle.dump(paper2embedding, f)

In [10]:
for i, r in tqdm(train_data.iterrows(), total=train_data.shape[0]):
    description2embedding[r['description_id']+'_train'] = {}
    # description_text
    dcp = r['description_text']
    if dcp == 'none':
        description2embedding[r['description_id']+'_train']['description_text'] = random_vec()
    else:
        description2embedding[r['description_id']+'_train']['description_text'] = encoder.encode([dcp])[0]
    # key_text
    key = r['key_text']
    if key == 'none':
        description2embedding[r['description_id']+'_train']['key_text'] = random_vec()
    else:
        description2embedding[r['description_id']+'_train']['key_text'] = encoder.encode([key])[0]
        
for i, r in tqdm(valid_data.iterrows(), total=valid_data.shape[0]):
    description2embedding[r['description_id']+'_test'] = {}
    # description_text
    dcp = r['description_text']
    if dcp == 'none':
        description2embedding[r['description_id']+'_test']['description_text'] = random_vec()
    else:
        description2embedding[r['description_id']+'_test']['description_text'] = encoder.encode([dcp])[0]
    # key_text
    key = r['key_text']
    if key == 'none':
        description2embedding[r['description_id']+'_test']['key_text'] = random_vec()
    else:
        description2embedding[r['description_id']+'_test']['key_text'] = encoder.encode([key])[0]

HBox(children=(IntProgress(value=0, max=62974), HTML(value='')))




HBox(children=(IntProgress(value=0, max=34428), HTML(value='')))




In [11]:
with open(other_path+'description2embedding.pkl', 'wb') as f:
    pickle.dump(description2embedding, f)