In [20]:
!pip install pyecharts
!pip install transformers

import torch
import torch.nn.functional as F

Collecting transformers
  Downloading transformers-4.36.0-py3-none-any.whl.metadata (126 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting huggingface-hub<1.0,>=0.19.3 (from transformers)
  Downloading huggingface_hub-0.19.4-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers)
  Downloading tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.1-cp39-cp39-macosx_10_7_x86_64.whl.metadata (3.8 kB)
Collecting fsspec>=2023.5.0 (from huggingface-hub<1.0,>=0.19.3->transformers)
  Downloading fsspec-2023.12.2-py3-none-any.whl.metadata (6.8 kB)
Downloading transformers-4.36.0-py3-none-any.whl (8.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m00:01[0m:00:01[0m
[?25hDownloading huggingface_

In [29]:
import numpy as np 
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from pyecharts import options as opts
from pyecharts.charts import Graph
from torch.utils.data import Dataset,TensorDataset,DataLoader
from keras_preprocessing.sequence import pad_sequences
from transformers import BertTokenizer, BertForSequenceClassification

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings('ignore')

In [22]:
# Load the marked-sentence table into a DataFrame (the sample output below shows the columns
# start_entity, end_entity, start_entity_type, end_entity_type and marked_sentence).
# NOTE(review): this is an absolute, machine-specific path -- point it at your own copy of the CSV.
marked_sentence_df = pd.read_csv('/Users/zhiyi/Desktop/540/covid19-knowledge-graph/marked_sentence.csv')

In [23]:
marked_sentence_df.sample(5)

Unnamed: 0,start_entity,end_entity,start_entity_type,end_entity_type,marked_sentence
31553,CD27,CD4,Gene,Gene,We employed high content colour flow cytometry...
36696,ANPEP,ENPEP,Gene,Gene,"Among them, start_entity and DPP4 are the know..."
7037,POMC,"Influenza, Human",Gene,Disease,Background end_entity antigenic point-of-care ...
12714,CCL8,CXCL11,Gene,Gene,"Our analysis identified start_entity, CXCL10, ..."
19578,Chickenpox,Guillain-Barre Syndrome,Disease,Disease,"Matei Bal in 2015 Alena-Andreea Popa, Georgeta..."


In [24]:
def Build_graph(df,relation=False,repulsion=40,title='COVID-19 knowledge graph',labelShow=False):
    """Render an interactive pyecharts graph of entity pairs inside the notebook.

    Every distinct value found in ``start_entity`` / ``end_entity`` becomes a node
    and every row of ``df`` becomes one link from its start entity to its end entity.

    Args:
        df: DataFrame with the columns ``start_entity``, ``end_entity``,
            ``start_entity_type`` and ``end_entity_type``; when ``relation`` is
            true it must also contain a ``pred`` column.
        relation: when true, each link carries the row's ``pred`` as its ``value``.
        repulsion: forwarded unchanged to ``Graph.add``.
        title: chart title.
        labelShow: whether node labels are displayed.

    Returns:
        Whatever ``Graph.render_notebook()`` returns.

    Raises:
        KeyError: if an entity type is not one of 'Disease', 'Gene', 'Chemical'.
    """
    starts = df['start_entity']
    ends = df['end_entity']

    # Entity -> type lookup. The first occurrence of each entity is used and the
    # type seen in the end column overrides the one seen in the start column.
    entity_type_dic = dict(df.drop_duplicates(['start_entity']).set_index(['start_entity'])['start_entity_type'])
    entity_type_dic.update(dict(df.drop_duplicates(['end_entity']).set_index(['end_entity'])['end_entity_type']))

    color = {'Disease':'#FF7F50','Gene':'#48D1CC','Chemical':'#B3EE3A'}
    cate =  {'Disease':0,'Gene':1,'Chemical':2}
    # One legend category per entity type, listed in the same order as the indices in `cate`.
    categories = [{'name': type_name, 'itemStyle': {'normal': {'color': color[type_name]}}} for type_name in cate]

    # Number of rows that mention each entity, computed in a single pass over the
    # two columns instead of filtering the whole frame once per entity. A row whose
    # start and end are the same entity is counted once, hence the `ends != starts` mask.
    degree = starts.value_counts().add(ends[ends != starts].value_counts(), fill_value=0).to_dict()

    nodes = [
        {'name': entity,
         # log-scaled size, floored to a whole number and never smaller than 10
         'symbolSize': max(10, np.log1p(degree.get(entity, 0))*10//1),
         'category': cate[entity_type_dic[entity]]}
        for entity in set(starts) | set(ends)
    ]

    # Iterate the columns positionally: label-based `df.loc[i, col]` yields a Series
    # rather than a scalar when the index contains duplicate labels.
    if relation:
        links = [{'source': source, 'target': target, 'value': pred}
                 for source, target, pred in zip(starts, ends, df['pred'])]
    else:
        links = [{'source': source, 'target': target} for source, target in zip(starts, ends)]

    g = (
        Graph()
        .add('', nodes, links,categories, repulsion=repulsion,label_opts=opts.LabelOpts(is_show=labelShow))
        .set_global_opts(title_opts=opts.TitleOpts(title=title),legend_opts=opts.LegendOpts(orient='vertical', pos_left='2%', pos_top='40%',legend_icon='circle'))
        .render_notebook()
        )
    return g

In [25]:
# Preview the topology on a random 100-row subsample of the table.
# NOTE: sample() is called without random_state, so the rows drawn (and the picture) change on every run.
g = Build_graph(marked_sentence_df.sample(100),title='subsample of topology graph')
# last expression of the cell: displays the rendered graph
g

In [26]:
class Args:
    """Run-time settings read by the prediction helpers in this notebook."""
    # entity-type pair to classify, written '<start type>-<end type>'
    task_type: str = 'chemical-disease'
    # token-id sequences are padded / truncated to this length
    max_seq_len: int = 64
    # batch size handed to the prediction data loader
    bs: int = 64

class Conf:
    """Static lookup of the relation codes accepted for each ordered entity-type pair.

    Keys are written '<start type>-<end type>'; each value lists the relation
    codes treated as valid for that direction.
    """
    # some information can be found in:
    # Percha B, Altman R B. A global network of biomedical relationships derived from text[J]. Bioinformatics, 2018, 34(15): 2614-2624.
    relation_type = {
        'chemical-disease': ['T', 'C', 'Sa', 'Pr', 'Pa', 'J'],
        'disease-chemical': ['Mp'],
        'chemical-gene': ['A+', 'A-', 'B', 'E+', 'E-', 'E', 'N'],
        'gene-chemical': ['O', 'K', 'Z'],
        'gene-disease': ['U', 'Ud', 'D', 'J', 'Te', 'Y', 'G'],
        'disease-gene': ['Md', 'X', 'L'],
        'gene-gene': ['B', 'W', 'V+', 'E+', 'E', 'I', 'H', 'Rg', 'Q'],
    }

# instantiate the two settings holders used by the cells below
args, conf = Args(), Conf()
# use the first CUDA device when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# last expression of the cell: shows which device was picked
device

device(type='cpu')

In [32]:
# load pretrained Bert model
def Bert_model(taskType,bertPath,dataDir='/Users/zhiyi/Desktop/540/covid19-knowledge-graph'):
    """Load the label table, tokenizer and sequence classifier for one task.

    Args:
        taskType: task name; the label table is read from ``<dataDir>/<taskType>_label.csv``.
        bertPath: location passed to ``from_pretrained`` for both the tokenizer and the model.
        dataDir: directory that holds the ``*_label.csv`` files. The default is the
            location the notebook was originally written against, so existing
            two-argument calls behave as before.

    Returns:
        (label_df, tokenizer, model). The classifier is created with one output
        per distinct value in the ``label`` column of the label table.
    """
    import os  # local import so the cell does not depend on a separate import cell

    label_df = pd.read_csv(os.path.join(dataDir, '%s_label.csv' % taskType))
    # do_lower_case=False: the tokenizer is asked not to lower-case the input text
    tokenizer = BertTokenizer.from_pretrained(bertPath,do_lower_case=False)
    model = BertForSequenceClassification.from_pretrained(bertPath, num_labels=label_df['label'].nunique())
    return label_df,tokenizer,model

# build data loader
def Data_loader(x,y=None,bs=128,shuffle=False,numWorkers=0):
    """Wrap one or two tensors in a ``TensorDataset`` and return a ``DataLoader`` over it.

    Args:
        x: first tensor of the dataset.
        y: optional second tensor; when given, every batch contains a slice of both.
        bs: batch size.
        shuffle: forwarded to ``DataLoader``.
        numWorkers: forwarded to ``DataLoader`` as ``num_workers``.

    Returns:
        The constructed ``DataLoader``.
    """
    tensors = (x,) if y is None else (x, y)
    return DataLoader(
        dataset=TensorDataset(*tensors),
        batch_size=bs,
        shuffle=shuffle,
        num_workers=numWorkers,
    )

def Prepare_predict_data(tokenizer,bs):
    """Build the prediction loader for the sentences matching ``args.task_type``.

    Reads the notebook globals ``marked_sentence_df`` and ``args``. Rows are
    selected when their lower-cased start / end entity types equal the two halves
    of ``args.task_type``; rows with a missing type never match.

    Because the order of the two entities cannot be confirmed, every sentence is
    encoded twice: as written, and with the ``start_entity`` / ``end_entity``
    placeholders swapped.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        bs: batch size for the returned loader.

    Returns:
        (sentences, loader): the selected ``marked_sentence`` values as an array,
        and a loader whose batches are ``(ids, reverse_ids)`` pairs in the same
        row order as ``sentences``.
    """
    task_parts = args.task_type.split('-')
    start_type, end_type = task_parts[0], task_parts[1]

    # Compute the row selection once; both encodings are derived from the same rows.
    selected = (marked_sentence_df['start_entity_type'].str.lower() == start_type) & \
               (marked_sentence_df['end_entity_type'].str.lower() == end_type)
    marked_sentences = marked_sentence_df.loc[selected, 'marked_sentence']

    def swap_placeholders(sentence):
        # three-step swap through a temporary token so the two placeholders do not collide
        return sentence.replace('start_entity','init_start_entity').replace('end_entity','start_entity').replace('init_start_entity','end_entity')

    def encode(sentences):
        # convert tokens to ids, then pad / truncate at the end to args.max_seq_len
        token_ids = sentences.apply(lambda text: tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))).tolist()
        return torch.LongTensor(pad_sequences(token_ids, args.max_seq_len, truncating='post', padding='post'))

    ids = encode(marked_sentences)
    # we cannot confirm order of entities, so predict two possibilities
    reverse_ids = encode(marked_sentences.apply(swap_placeholders))

    predict_data_loader = Data_loader(ids, reverse_ids, bs=bs)
    return marked_sentences.values,predict_data_loader

def Predict():
    """Classify the relation of every entity pair matching ``args.task_type``.

    Reads the notebook globals ``args``, ``conf``, ``device`` and
    ``marked_sentence_df``. Each selected sentence is scored twice -- as marked
    and with the two entity placeholders swapped -- because the direction of the
    pair is not known in advance. ``Filter`` then decides which direction to keep.

    Returns:
        (label_df, pred_df): the label table returned by ``Bert_model`` and the
        rows of ``marked_sentence_df`` joined with their predictions. Sentences
        for which neither direction gave a valid relation are dropped. Where the
        reversed direction was chosen, ``start_entity`` / ``end_entity`` and
        their type columns are swapped so the row reads in the predicted direction.
    """
    task_parts = args.task_type.split('-')
    reverse_task_type = task_parts[1] + '-' + task_parts[0]

    def Filter(x):
        # Decide which direction to trust for one row; returns 'init_pred',
        # 'reverse_pred' or 'uncorrect' (the latter rows are discarded below).
        init_ok = x['init_pred'] in conf.relation_type[args.task_type]
        reverse_ok = x['reverse_pred'] in conf.relation_type[reverse_task_type]
        if init_ok and reverse_ok:
            # both directions are valid relations: keep the more confident one, ties go to the original order
            return 'init_pred' if x['init_pred_prob'] >= x['reverse_pred_prob'] else 'reverse_pred'
        if init_ok:
            return 'init_pred'
        if reverse_ok:
            return 'reverse_pred'
        return 'uncorrect'

    def top_class(batch_ids):
        # Highest-probability class and its softmax probability for one batch of token ids.
        # NOTE(review): no attention_mask is passed although the ids are zero-padded, and
        # no special tokens are added upstream. Left unchanged on purpose -- confirm how
        # the checkpoint was fine-tuned before altering the input format.
        logits = model(input_ids=batch_ids)[0]
        prob, label = torch.max(F.softmax(logits, 1), 1)
        return list(prob.cpu().numpy()), list(label.cpu().numpy())

    label_df,tokenizer,model = Bert_model(args.task_type,'/Users/zhiyi/Desktop/540/covid19-knowledge-graph/%s/'%args.task_type)
    marked_sentences,predict_data_loader = Prepare_predict_data(tokenizer,args.bs)
    model = model.to(device)
    model.eval()

    preds = []
    preds_prob = []
    reverse_preds = []
    reverse_preds_prob = []
    # Pure inference: disabling autograd avoids building a graph for every forward pass.
    with torch.no_grad():
        for data in tqdm(predict_data_loader):
            ids,reverse_ids = [t.to(device) for t in data]
            batch_prob, batch_pred = top_class(ids)
            preds.extend(batch_pred)
            preds_prob.extend(batch_prob)
            batch_prob, batch_pred = top_class(reverse_ids)
            reverse_preds.extend(batch_pred)
            reverse_preds_prob.extend(batch_prob)

    pred_df = pd.DataFrame({'marked_sentence':marked_sentences,'init_pred':preds,'init_pred_prob':preds_prob,'reverse_pred':reverse_preds,'reverse_pred_prob':reverse_preds_prob})
    # map label(0, 1, 2...) to raw label(T, C, Sa...)
    label_map = dict(label_df.set_index(['label'])['label_raw'])
    pred_df['init_pred'] = pred_df['init_pred'].replace(label_map)
    pred_df['reverse_pred'] = pred_df['reverse_pred'].replace(label_map)

    # judge the order of a pair of entities
    pred_df['filter'] = pred_df.apply(Filter, axis=1)
    use_reverse = pred_df['filter']=='reverse_pred'
    pred_df['pred'] = pred_df['reverse_pred'].where(use_reverse, pred_df['init_pred'])
    pred_df['pred_prob'] = pred_df['reverse_pred_prob'].where(use_reverse, pred_df['init_pred_prob'])
    pred_df = pred_df.loc[pred_df['filter']!='uncorrect']

    # The join key is the sentence text. Identical sentences receive identical
    # predictions, so keep one prediction row per sentence; otherwise k copies of a
    # sentence would come out of the merge as k*k rows.
    pred_df = pred_df.drop_duplicates('marked_sentence')
    pred_df = marked_sentence_df.merge(pred_df,how='inner',on='marked_sentence')

    # Where the reversed direction won, swap the entities and their types.
    swap = pred_df['filter']=='reverse_pred'
    if swap.any():
        pred_df.loc[swap, ['start_entity','end_entity','start_entity_type','end_entity_type']] = \
            pred_df.loc[swap, ['end_entity','start_entity','end_entity_type','start_entity_type']].values

    torch.cuda.empty_cache()
    return label_df,pred_df

In [None]:
# chemical-disease relation prediction
args.task_type = 'chemical-disease'
c_d_label_df,c_d_pred_df = Predict()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/15 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [None]:
# chemical-disease relation theme
c_d_label_df

In [None]:
# chemical-disease classification results
c_d_pred_df.sample(5)

In [None]:
# chemical-COVID-19 relations
g = Build_graph(c_d_pred_df.loc[(c_d_pred_df['start_entity']=='COVID-19')|(c_d_pred_df['end_entity']=='COVID-19')],relation=True,repulsion=800,title='chemical-COVID-19 knowledge graph',labelShow=True)
g

In [None]:
# 1. Which chemicals have an Sa (side effect/adverse event) relation to COVID-19 and relevant diseases?

# due to a large number of negative words, these predictions may be wrong
c_d_pred_df.loc[c_d_pred_df['pred']=='Sa']

In [None]:
# 1. Which chemicals have a Pa (alleviates, reduces), Pr (prevents, suppresses) or T (treatment/therapy, incl. investigatory) relation to COVID-19 and relevant diseases?

c_d_pred_df.loc[(c_d_pred_df['end_entity']=='COVID-19')&(c_d_pred_df['pred'].isin(['Pa','Pr','T']))]