# Load the pred_T file and the gold file

In [1]:
pred_file = 'analysis/predT_not_in_original_pmid_predTs.json'
gold_file = 'data/KD-DTI/raw/test.json'
pmid = 'data/KD-DTI/raw/relis_test.pmid'

In [2]:
import json

# load the pmids that the pred_Ts are not in the original articles
with open (pred_file, 'r') as f:
    pred_d_not_in_original = json.load(f)

print(len(pred_d_not_in_original))
print(pred_d_not_in_original[0])

823
{'id': '11169165', 'predT_not_in_original': ['monoamine oxidase type b (mao-b)', 'monoamine oxidase type a (mao-a)']}


In [3]:
# load the gold standard
with open (gold_file, 'r') as f:
    gold_d = json.load(f)

print(len(gold_d))
print(gold_d[pred_d_not_in_original[0]['id']])

pmids = []
with open (pmid, 'r') as f:
    for line in f:
        line = line.rstrip()
        pmids.append(line)

1159
{'title': 'Inhibition of rat brain monoamine oxidase activities by psoralen and isopsoralen: implications for the treatment of affective disorders.', 'abstract': 'Psoralen and isopsoralen, furocoumarins isolated from the plant Psoralea corylifolia L., were demonstrated to exhibit in vitro inhibitory actions on monoamine oxidase (MAO) activities in rat brain mitochondria, preferentially inhibiting MAO-A activity over MAO-B activity. This inhibition of enzyme activities was found to be dose-dependent and reversible. For MAO-A, the IC50 values are 15.2 +/- 1.3 microM psoralen and 9.0 +/- 0.6 microM isopsoralen. For MAO-B, the IC50 values are 61.8 +/- 4.3 microM psoralen and 12.8 +/- 0.5 microM isopsoralen. Lineweaver-Burk transformation of the inhibition data indicates that inhibition by both psoralen and isopsoralen is non-competitive for MAO-A. The Ki values were calculated to be 14.0 microM for psoralen and 6.5 microM for isopsoralen. On the other hand, inhibition by both psoralen

# Load the model

In [4]:
import torch
from src.transformer_lm_prompt import TransformerLanguageModelPrompt
m = TransformerLanguageModelPrompt.from_pretrained(
        "checkpoints/RE-DTI-BioGPT", 
        "checkpoint_avg.pt", 
        "data/KD-DTI/relis-bin",
        tokenizer='moses', 
        bpe='fastbpe', 
        bpe_codes="data/bpecodes",
        max_len_b=1024,
        beam=5)
m.cuda()

2023-04-25 23:54:30 | INFO | fairseq.file_utils | loading archive file checkpoints/RE-DTI-BioGPT
2023-04-25 23:54:30 | INFO | fairseq.file_utils | loading archive file data/KD-DTI/relis-bin
2023-04-25 23:54:32 | INFO | src.language_modeling_prompt | dictionary: 42384 types
2023-04-25 23:54:35 | INFO | fairseq.models.fairseq_model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': False, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': '../../src', 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': False

GeneratorHubInterface(
  (models): ModuleList(
    (0): TransformerLanguageModelPrompt(
      (decoder): TransformerDecoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(42393, 1024, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0-23): 24 x TransformerDecoderLayerBase(
            (dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (activation_dropout_module): FairseqDropout()
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwis

In [5]:
# because it's hard to use the moses tokenizer.decode() to show the different between the 4 (= 4#) and 4</w>
# so here using the tokenizer from HF to decode each generated token, and it doesn't include the learn0 - learn9, which is not a problem
from transformers import BioGptTokenizer

tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")

  from .autonotebook import tqdm as notebook_tqdm


# inference step by step

In [6]:
import numpy as np
import matplotlib.pyplot as plt

In [7]:
# a test data: {"pmid": ,
#               "id": ,
#               "title+abstract.lower()": ,
#               "text_tokens": ,
#               "pred_Ts": ,
#              "pred_Ts_tokens": ,}
#               "gold_triples": ,}

def get_test_data(id):
    prefix = torch.arange(42384, 42393)
    test_data = {}
    test_data['pmid'] = pmids[id]
    test_data['text'] = gold_d[test_data['pmid']]['title'].strip() + " " + gold_d[test_data['pmid']]['abstract']
    test_data['text'] = test_data['text'].lower().strip().replace('  ', ' ')
    test_data['text_tokens'] = m.encode(test_data['text'])
    test_data['text_tokens_with_prefix'] = torch.cat([test_data['text_tokens'], prefix], dim=-1).unsqueeze(0).cuda()
    try:
        test_data['pred_Ts'] = pred_d_not_in_original[id]['predT_not_in_original']
        test_data['pred_Ts_tokens'] = [m.encode(pred_T) for pred_T in test_data['pred_Ts']]
    except:
        test_data['pred_Ts'] = None
        test_data['pred_Ts_tokens'] = None
    test_data['gold_triples'] = gold_d[test_data['pmid']]['triples']
    test_data['gold_drugs'] = [tokenizer.encode(gold_triple['drug'].lower(), add_special_tokens=False, return_tensors='pt')for gold_triple in test_data['gold_triples']]
    test_data['gold_targets'] = [tokenizer.encode(gold_triple['target'].lower(), add_special_tokens=False, return_tensors='pt') for gold_triple in test_data['gold_triples']]
    test_data['gold_interaction'] = [tokenizer.encode(gold_triple['interaction'].lower(), add_special_tokens=False, return_tensors='pt') for gold_triple in test_data['gold_triples']]
    return test_data

In [8]:
# test_data = get_test_data(822)

# print(f'{{\n"pred_Ts": "{test_data["pred_Ts"][0]}",')
# print('"gold": {')
# for key, value in test_data["gold_triples"][0].items():
#     print(f'"{key}": "{value}",')
# print('},')
# # print(f'"gold": {test_data["gold_triples"][0]},')
# print(f'"text": "{test_data["text"]}"\n}}')

In [16]:
test_data = get_test_data(707)
len(*test_data['text_tokens_with_prefix'])

191

In [17]:
k = 1
for test_data_id in range(707, len(gold_d)):
    # initialize
    test_data = get_test_data(test_data_id)
    test_input = test_data['text_tokens_with_prefix']

    output_text = []
    prob = []
    ranking = []
    step = 0

    with torch.no_grad():
        m.models[0].decoder.eval()
        for new_triple in range(len(test_data['gold_triples'])):
            # the interaction between
            for i in range(3):
                
                step += 1

                out = m.models[0].decoder(test_input)

                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                _, top_k_indices = torch.topk(out[0][0][-1], k=k)
                top_k_tokens = [tokenizer.convert_ids_to_tokens([indice]) for indice in top_k_indices]
                top_k_probs = torch.softmax(out[0][0][-1][top_k_indices], dim=-1)
                top_k = [(token, prob.item()) for token, prob in zip(top_k_tokens, top_k_probs)]
                # print(f'The top-{k} most possible tokens are:\n{top_k}')
                next_token_id = 1
                test_input = torch.cat([test_input[0], top_k_indices[next_token_id-1].unsqueeze(0)], dim=-1).unsqueeze(0)
                output_text.append(top_k_indices[next_token_id-1])

                prob.append(softmax_out[top_k_indices[next_token_id-1]].item())
                ranking.append(next_token_id)

            # drug
            for id in test_data['gold_drugs'][new_triple][0]:        
                # print(f'output_text: {m.decode(output_text)}\n')
                step += 1

                out = m.models[0].decoder(test_input)
                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                customized_string_tokens = id.unsqueeze(0).cuda()
                test_input = torch.cat([test_input[0], customized_string_tokens], dim=-1).unsqueeze(0)
                output_text.append(customized_string_tokens.squeeze(0))

                customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
                sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
                prob.append(softmax_out[customized_string_tokens].item())
                ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
            
            # and
            step += 1

            out = m.models[0].decoder(test_input)
            softmax_out = torch.softmax(out[0][0][-1], dim=-1)
            customized_string_tokens = tokenizer.encode("and", add_special_tokens=False, return_tensors='pt').squeeze(0)
            test_input = torch.cat([test_input[0], customized_string_tokens.cuda()], dim=-1).unsqueeze(0)
            output_text.append(customized_string_tokens.squeeze(0))

            customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
            sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
            prob.append(softmax_out[customized_string_tokens].item())
            ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
            # print(f'output_text: {m.decode(output_text)}\n')

            # target
            for id in test_data['gold_targets'][new_triple][0]:        
                # print(f'output_text: {m.decode(output_text)}\n')
                step += 1

                out = m.models[0].decoder(test_input)
                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                customized_string_tokens = id.unsqueeze(0).cuda()
                test_input = torch.cat([test_input[0], customized_string_tokens], dim=-1).unsqueeze(0)
                output_text.append(customized_string_tokens.squeeze(0))

                customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
                sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
                prob.append(softmax_out[customized_string_tokens].item())
                ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
            
            # is
            step += 1

            out = m.models[0].decoder(test_input)
            softmax_out = torch.softmax(out[0][0][-1], dim=-1)
            customized_string_tokens = tokenizer.encode("is", add_special_tokens=False, return_tensors='pt').squeeze(0)
            test_input = torch.cat([test_input[0], customized_string_tokens.cuda()], dim=-1).unsqueeze(0)
            output_text.append(customized_string_tokens.squeeze(0))

            customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
            sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
            prob.append(softmax_out[customized_string_tokens].item())
            ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
            # print(f'{test_data["pmid"]}: ({new_triple + 1}/{len(test_data["gold_triples"])}) \noutput_text: {m.decode(output_text)}\n')

            # interaction
            for id in test_data['gold_interaction'][new_triple][0]:        
                # print(f'output_text: {m.decode(output_text)}\n')
                step += 1

                out = m.models[0].decoder(test_input)
                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                customized_string_tokens = id.unsqueeze(0).cuda()
                test_input = torch.cat([test_input[0], customized_string_tokens], dim=-1).unsqueeze(0)
                output_text.append(customized_string_tokens.squeeze(0))

                customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
                sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
                prob.append(softmax_out[customized_string_tokens].item())
                ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)

            # add . or ;
            if new_triple + 1 == len(test_data['gold_triples']):
                step += 1

                out = m.models[0].decoder(test_input)
                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                customized_string_tokens = torch.tensor(4).unsqueeze(0).cuda()
                test_input = torch.cat([test_input[0], customized_string_tokens.cuda()], dim=-1).unsqueeze(0)
                output_text.append(customized_string_tokens.squeeze(0))

                customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
                sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
                prob.append(softmax_out[customized_string_tokens].item())
                ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
                print(f'{test_data["pmid"]}: ({new_triple + 1}/{len(test_data["gold_triples"])}) \noutput_text: {m.decode(output_text)}\n')
                break
            
            else:
                step += 1

                out = m.models[0].decoder(test_input)
                softmax_out = torch.softmax(out[0][0][-1], dim=-1)
                customized_string_tokens = torch.tensor(44).unsqueeze(0).cuda()
                test_input = torch.cat([test_input[0], customized_string_tokens.cuda()], dim=-1).unsqueeze(0)
                output_text.append(customized_string_tokens.squeeze(0))

                customized_string_prob = out[0][0][-1][customized_string_tokens].clone()
                sorted_output, _ = torch.sort(out[0][0][-1], descending=True)
                prob.append(softmax_out[customized_string_tokens].item())
                ranking.append(torch.where(sorted_output == customized_string_prob)[0].item() + 1)
                print(f'{test_data["pmid"]}: ({new_triple + 1}/{len(test_data["gold_triples"])}) \noutput_text: {m.decode(output_text)}\n')


    drugs_in_original = []
    targets_in_original = []
    for i in range(len(test_data['gold_triples'])):
        if test_data['gold_triples'][i]['drug'].lower().strip().replace('  ', ' ') in test_data['text']:
            drugs_in_original.append(1)
        else:
            drugs_in_original.append(0)
        
        if test_data['gold_triples'][i]['target'].lower().strip().replace('  ', ' ') in test_data['text']:
            targets_in_original.append(1)
        else:
            targets_in_original.append(0)


    # Create some fake data.
    x = np.arange(step)
    y1 = prob
    y2 = ranking

    fig, (ax1, ax2) = plt.subplots(2, 1)
    fig.suptitle(f'Pmid: {test_data["pmid"]}')

    ax1.plot(x, y1, '.-')
    ax1.set_ylabel('Probability')

    ax2.plot(x, y2, '.-')
    ax2.set_xlabel('step')
    ax2.set_ylabel('Ranking')

    marks = [0]* (step-1)
    mark = False

    # 0 for parttern tokens, 1 for drugs ,2 for targets, 3 for interaction
    for i, token in enumerate(output_text):
        if token != 6 and output_text[i-1] == 45:
            marks[i] = 1
            mark = True
            continue
        if token == 8 or token == 21 or token == 4:
            continue
        if token != 6 and output_text[i-1] == 8:
            marks[i] = 2
            mark = True
            continue
        if token != 6 and output_text[i-1] == 21:
            marks[i] = 3
            mark = True
            continue
        if token == 44:
            mark = False
            continue
        if mark:
            marks[i] = marks[i-1]
            

    # if marks[x] == 1, then using hollow circle for the plot, if marks[x] == 2, then using hollow triangle for the plot, if marks[x] == 3, then using star for the plot.
    for i in range(step-1):
        if marks[i] == 1:
            ax1.plot(x[i], y1[i], marker='o', color='white', markeredgecolor='blue')
            ax2.plot(x[i], y2[i], marker='o', color='white', markeredgecolor='blue')
            if y2[i] > 5:
                ax1.plot(x[i], y1[i], marker='o', color='white', markeredgecolor='red')
                ax2.plot(x[i], y2[i], marker='o', color='white', markeredgecolor='red')

        if marks[i] == 2:
            ax1.plot(x[i], y1[i], marker='^', color='white', markeredgecolor='blue')
            ax2.plot(x[i], y2[i], marker='^', color='white', markeredgecolor='blue')
            if y2[i] > 5:
                ax1.plot(x[i], y1[i], marker='^', color='red')
                ax2.plot(x[i], y2[i], marker='^', color='red')

        if marks[i] == 3:
            ax1.plot(x[i], y1[i], marker='x', color='white', markeredgecolor='blue')
            ax2.plot(x[i], y2[i], marker='x', color='white', markeredgecolor='blue')
            if y2[i] > 5:
                ax1.plot(x[i], y1[i], marker='x', color='red')
                ax2.plot(x[i], y2[i], marker='x', color='red')
    plt.text(-0.05, 2.45, f"1: exist,    0: no", transform=plt.gca().transAxes)
    plt.text(-0.05, 2.35, f"drugs: {drugs_in_original}", transform=plt.gca().transAxes)
    plt.text(-0.05, 2.25, f"targets: {targets_in_original}", transform=plt.gca().transAxes)

    if 1 in targets_in_original:
        plt.savefig(f'analysis/img/goden_truth_forcing/target_in_original/{test_data["pmid"]}-{1 in drugs_in_original}-{1 in targets_in_original}.png')
    else:
        plt.savefig(f'analysis/img/goden_truth_forcing/{test_data["pmid"]}-{1 in drugs_in_original}-{1 in targets_in_original}.png')

    # plt.savefig(f'analysis/img/goden_truth_forcing/{test_data["pmid"]}-{1 in targets_in_original}.png')

    plt.close()
    print (f'{test_data_id + 1} / {len(gold_d)}')
    # break
    # plt.show()

16366516: (1/1) 
output_text: the interaction between aripiprazole and dopamine d2 receptor is antagonist.

708 / 1159
17296815: (1/1) 
output_text: the interaction between sunitinib and vascular endothelial growth factor receptor 3 is inhibitor.

709 / 1159
17105867: (1/1) 
output_text: the interaction between testosterone and mineralocorticoid receptor is ligand.

710 / 1159
25274603: (1/1) 
output_text: the interaction between flecainide and ryanodine receptor 2 is inhibitor.

711 / 1159
10076535: (1/1) 
output_text: the interaction between bicalutamide and androgen receptor is antagonist.

712 / 1159
1588924: (1/1) 
output_text: the interaction between ambenonium and acetylcholinesterase is inhibitor.

713 / 1159
11969359: (1/1) 
output_text: the interaction between amrinone and tumor necrosis factor is inhibitor.

714 / 1159
21088277: (1/1) 
output_text: the interaction between nitroxoline and methionine aminopeptidase 2 is inhibitor.

715 / 1159
6186851: (1/1) 
output_text: the i

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [62,0,0],

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
