In [211]:
import requests
from ast import literal_eval
from transformers import AutoTokenizer
from tqdm import tqdm
import srsly
import pandas as pd

### Utility function

In [8]:
def promtify(input_text, entity_type):
    PROMPT=f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction:
    Given a sentence, extract {entity_type} entities from it by highlighting them with <mark> and </mark>. If not present, output the same sentence.
    ### Input:
    {input_text}
    ### Response:

    '''
    return PROMPT

def load_tokenizer(model_name="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # new tokens
    new_tokens = ["<mark>", "</mark>"]

    # check if the tokens are already in the vocabulary
    new_tokens = set(new_tokens) - set(tokenizer.vocab.keys())

    # add the tokens to the tokenizer vocabulary
    tokenizer.add_tokens(list(new_tokens))

    return tokenizer

def get_response(text):
    url = "http://127.0.0.1:5000/v1/completions"
    headers = {
        "Content-Type": "application/json"
    }

    data = {
        "prompt": text,
        "max_tokens": 200,
        "temperature": 0.7,
        "typical_p": 1,
        "seed": 10,
        "return_full_text": False,
        "repetition_penalty": 1.15,
        "repetition_penalty_range":1024,
        "guidance_scale": 1
    }

    response = requests.post(url, headers=headers, json=data)

    answer = literal_eval(response.text)['choices'][0]['text']
    #answer = response.text
    return answer

def get_answer(text, entity_type='disease'):
    prompt = promtify(text, entity_type)

    #return literal_eval(answer)
    return get_response(prompt)


tokenizer = load_tokenizer()



### Prepare data

In [162]:
# read lines from 'datasets/NCBI-disease_test.txt'
with open('datasets/BC5CDR-disease_test.txt', 'r') as f:
    lines = f.readlines()

In [163]:
idx = 1

lines[idx]

'The authors describe the case of a 56 - year - old woman with chronic , severe <mark>heart failure</mark> secondary to <mark>dilated cardiomyopathy</mark> and absence of significant <mark>ventricular arrhythmias</mark> who developed <mark>QT prolongation</mark> and <mark>torsade de pointes</mark> <mark>ventricular tachycardia</mark> during one cycle of intermittent low dose ( 2 . 5 mcg / kg per min ) dobutamine . \n'

In [164]:
line_processed = lines[idx].replace('<mark>', '').replace('</mark>', '')
line_processed

'The authors describe the case of a 56 - year - old woman with chronic , severe heart failure secondary to dilated cardiomyopathy and absence of significant ventricular arrhythmias who developed QT prolongation and torsade de pointes ventricular tachycardia during one cycle of intermittent low dose ( 2 . 5 mcg / kg per min ) dobutamine . \n'

In [165]:
ans = get_answer(line_processed)

ans

'\n    The authors describe the case of a 56 - year - old woman with <mark>chronic</mark>, <mark>severe heart failure</mark> secondary to <mark>dilated cardiomyopathy</mark> and <mark>absence of significant ventricular arrhythmias</mark> who developed <mark>QT prolongation</mark> and <mark>torsade de pointes ventricular tachycardia</mark> during one cycle of intermittent low dose ( 2 . 5 mcg / kg per min ) dobutamine .'

In [166]:
ans.strip()

'The authors describe the case of a 56 - year - old woman with <mark>chronic</mark>, <mark>severe heart failure</mark> secondary to <mark>dilated cardiomyopathy</mark> and <mark>absence of significant ventricular arrhythmias</mark> who developed <mark>QT prolongation</mark> and <mark>torsade de pointes ventricular tachycardia</mark> during one cycle of intermittent low dose ( 2 . 5 mcg / kg per min ) dobutamine .'

In [167]:
ans = get_answer(line_processed)


In [168]:
# Predict
pred_rows = []
for line in lines:
    line_processed = line.replace('<mark>', '').replace('</mark>', '')

In [None]:
line = lines[10]

# Eval based on BIO tagging
entity_type_short = 'DIS'

# convert entity between <mark> and </mark> to BIO tagging
tokens = tokenizer.tokenize(line)

In [2]:
# token next to <mark> becomes B-<entity_type_short>
# tokens from B-<entity_type_short> to one before </mark> becomes I-<entity_type_short>
# other tokens are O
# remore <mark> and </mark> from tokens 
def get_bio_tagging(tokens, entity_type_short='DIS'):
    tags = ['O'] * len(tokens)

    # mark <mark> and </mark> tokens as 'MS' and 'ME' in the bio_tagging
    tags = ['MS' if token == '<mark>' else 'ME' if token == '</mark>' else tag for (tag, token) in zip(tags, tokens)]

    i = 0
    n = len(tags)
    
    while i < n:
        if tags[i] == 'MS':
            # Start from the next token after 'MS'
            start = i + 1
            while start < n and tags[start] != 'ME':
                start += 1
            
            # Now 'start' should be at 'ME' or out of bounds
            if start < n and tags[start] == 'ME':
                if start - i == 2:
                    # Only one token between 'MS' and 'ME'
                    tags[i + 1] = 'B' + '-' + entity_type_short
                else:
                    # More than one token between 'MS' and 'ME'
                    tags[i + 1] = 'B' + '-' + entity_type_short
                    for j in range(i + 2, start):
                        tags[j] = 'I' + '-' + entity_type_short
            i = start # Continue from the end of this segment
        else:
            i += 1

    # remove all 'MS' and 'ME' tokens
    tags = [tag for tag in tags if tag not in ['MS', 'ME']]
    
    return tags


In [170]:
print(tokens)

['risk', 'of', 'transient', '<mark>', 'hyper', '##amm', '##one', '##mic', 'encephalopathy', '</mark>', 'in', '<mark>', 'cancer', '</mark>', 'patients', 'who', 'received', 'continuous', 'infusion', 'of', '5', '-', 'fluorouracil', 'with', 'the', 'complication', 'of', '<mark>', 'dehydration', '</mark>', 'and', '<mark>', 'infection', '</mark>', '.']


In [171]:
len(tokens)

35

In [172]:
bio_tags = get_bio_tagging(tokens, entity_type_short)
print(bio_tags)

['O', 'O', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'O', 'B-DIS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DIS', 'O', 'B-DIS', 'O']


In [173]:
len(bio_tags)

27

In [174]:
tokens_processed = [token for token in tokens if token not in ['<mark>', '</mark>']]  

for i, (token, bio_tag) in enumerate(zip(tokens_processed, bio_tags)):
    print(f'{i} |  {token} | {bio_tag}')

0 |  risk | O
1 |  of | O
2 |  transient | O
3 |  hyper | B-DIS
4 |  ##amm | I-DIS
5 |  ##one | I-DIS
6 |  ##mic | I-DIS
7 |  encephalopathy | I-DIS
8 |  in | O
9 |  cancer | B-DIS
10 |  patients | O
11 |  who | O
12 |  received | O
13 |  continuous | O
14 |  infusion | O
15 |  of | O
16 |  5 | O
17 |  - | O
18 |  fluorouracil | O
19 |  with | O
20 |  the | O
21 |  complication | O
22 |  of | O
23 |  dehydration | B-DIS
24 |  and | O
25 |  infection | B-DIS
26 |  . | O


In [175]:
y_true = []
for line in lines:
    tokens = tokenizer.tokenize(line)
    bio_tags = get_bio_tagging(tokens, entity_type_short='DIS')
    y_true.append(bio_tags)

In [176]:
print(y_true[1])

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'O', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'B-DIS', 'I-DIS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


### Predicting

In [177]:
lines_pred = []

for line in tqdm(lines):
    line_processed = line.replace('<mark>', '').replace('</mark>', '')
    try:
        ans = get_answer(line_processed)
    except SyntaxError:
        ans = 'ERROR: SyntaxError'
    lines_pred.append(ans.strip())

100%|██████████| 4797/4797 [1:38:36<00:00,  1.23s/it]  


In [133]:
line_processed

'Occasional missense mutations in ATM were also found in tumour DNA from patients with B - cell non - Hodgkins lymphomas ( B - NHL ) and a B - NHL cell line . \n'

In [134]:
ans = get_answer(line_processed)

In [136]:
ans

'Internal Server Error'

In [141]:
lines_pred

['Clustering of missense mutations in the <mark>ataxia - telangiectasia</mark> gene in a sporadic <mark>T - cell leukaemia</mark> .',
 'Ataxia - telangiectasia (<mark>A - T</mark>) is a recessive multi - system disorder caused by mutations in the ATM gene at 11q22 - q23 (ref . 3).',
 'The risk of <mark>cancer</mark> , especially <mark>lymphoid neoplasias</mark> , is substantially elevated in A - T patients and has long been associated with chromosomal instability .',
 'By analysing <mark>tumour DNA</mark> from patients with <mark>sporadic T - cell prolymphocytic leukaemia</mark> (<mark>T - PLL</mark>) , a rare clonal malignancy with similarities to a mature <mark>T - cell leukaemia</mark> seen in A - T , we demonstrate a high frequency of <mark>ATM mutations</mark> in <mark>T - PLL</mark> .',
 'In marked contrast to the ATM mutation pattern in <mark>A - T</mark> , the most frequent nucleotide changes in this leukaemia were missense mutations .',
 'These clustered in the region correspo

In [178]:
y_preds = []
for line in lines_pred:
    tokens = tokenizer.tokenize(line)
    bio_tags = get_bio_tagging(tokens, entity_type_short='DIS')
    y_preds.append(bio_tags)

In [181]:
for y in y_preds[:10]:
    print(y)


['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DIS', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'O', 'O', 'B-DIS', 'I-DIS', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'I-DIS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DIS', 'I-DIS', 'I-DIS', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O'

In [156]:
import srsly

### Evaluation

In [123]:
# https://github.com/mdadda/nerval
from nerval import crm

In [179]:
cr, cm, cm_labels = crm(y_true, y_preds, scheme='BIO')
print(cr)

True Entities: 4422
Pred Entities: 8158 

True Entities with 3 or more tags: 1139
Pred Entities with 3 or more tags: 2992 

True positives:  1606
False positives (true = 'O'):  6387
False positives (true <> pred):  165
ToT False positives:  6552
False negatives:  2651 

              precision  recall  f1_score  true_entities  pred_entities
DIS                0.20    0.36      0.26       4,422.00       7,993.00
DIS__              0.00    0.00      0.00           0.00         165.00
micro_avg          0.20    0.36      0.26       4,422.00       8,158.00
macro_avg          0.10    0.18      0.13       4,422.00       8,158.00
weighted_avg       0.20    0.36      0.26       4,422.00       8,158.00


In [157]:
cm

array([[ 359,   57,  541],
       [   0,    0,    0],
       [1286,    0,    0]])

In [158]:
cr

Unnamed: 0,precision,recall,f1_score,true_entities,pred_entities
DIS,0.22,0.38,0.28,957.0,1645.0
DIS__,0.0,0.0,0.0,0.0,57.0
micro_avg,0.21,0.38,0.27,957.0,1702.0
macro_avg,0.11,0.19,0.14,957.0,1702.0
weighted_avg,0.22,0.38,0.28,957.0,1702.0


In [159]:
import os
import srsly

In [180]:
output_dir = 'results/mixtral_4bit'
dataset = 'bc5cdr_disease'

cr.to_csv(os.path.join(output_dir, dataset + '_cr.csv'))
srsly.write_jsonl(os.path.join(output_dir, dataset + '_y_preds.jsonl'), y_preds)

### Partial Eval

In [None]:
'''
https://pypi.org/project/eval4ner/
'''

In [3]:
text = 'John Jones and Peter Peters came to York'
tokens = ['John', 'Jones', 'and', 'Peter', 'Peters', 'came', 'to', 'York']
tokens_bio = ['B-PER', 'I-PER', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-LOC']

gt = [('PER', 'John Jones'), ('PER', 'Peter Peters'), ('LOC', 'York')]
preds = [('PER', 'John Jones'), ('PER', 'Peter Peters'), ('LOC', 'York')]

In [14]:
def bio_to_entities(tokens, tokens_bio):
    entities, entity, label = [], [], None
    for token, bio in zip(tokens, tokens_bio):
        if bio.startswith('B-'):
            if entity: entities.append((label, ' '.join(entity)))
            entity, label = [token], bio[2:]
        elif bio.startswith('I-') and label:
            entity.append(token)
        else:
            if entity: entities.append((label, ' '.join(entity)))
            entity, label = [], None
    if entity: entities.append((label, ' '.join(entity)))
    return entities

In [5]:
gt = bio_to_entities(tokens, tokens_bio)
gt

[('PER', 'John Jones'), ('PER', 'Peter Peters'), ('LOC', 'York')]

In [111]:
import eval4ner.muc as muc
import os


In [159]:
dataset = 'NLMGene'
model_name = 'mixtral_4bit'

In [160]:
# read lines from 'datasets/NCBI-disease_test.txt'
with open(f'datasets/{dataset}_test.txt', 'r') as f:
    lines = f.readlines()

In [161]:
lines_processed = [l.replace('<mark>', '').replace('</mark>', '') for l in lines]
len(lines_processed)

1128

In [162]:
y_true = []
for line in lines:
    tokens = tokenizer.tokenize(line)
    bio_tags = get_bio_tagging(tokens, entity_type_short='DIS')
    y_true.append(bio_tags)

In [163]:
preds_filepath = f'results/{model_name}/{dataset}_test_y_preds.jsonl'
y_preds = [y for y in srsly.read_jsonl(preds_filepath)]
len(y_preds)

1128

In [164]:
y_true_ent = [y for y in map(bio_to_entities, [tokenizer.tokenize(line) for line in lines_processed], y_true)]

In [165]:
y_true_ent[:10]

[[('DIS', 'dec - 205'),
  ('DIS', 'major histocompatibility complex class i'),
  ('DIS', 'cd8')],
 [('DIS', 'major histocompatibility complex ( mhc ) class i'),
  ('DIS', 'dec - 205')],
 [('DIS', 'ovalbumin'),
  ('DIS', 'ova'),
  ('DIS', '##de ##c -'),
  ('DIS', 'by'),
  ('DIS', 'by'),
  ('DIS', 'to'),
  ('DIS', ','),
  ('DIS', 'and')],
 [('DIS', 'ova'),
  ('DIS', 'mhc class i'),
  ('DIS', 'transporter of antigenic peptides'),
  ('DIS', 'tap')],
 [('DIS', '##de ##c -'), ('DIS', ':'), ('DIS', ',')],
 [('DIS', 'ova'),
  ('DIS', 'tcr'),
  ('DIS', 'cd8'),
  ('DIS', 'ova'),
  ('DIS', 'tap')],
 [('DIS', '##de ##c -'), ('DIS', ':'), ('DIS', 'with')],
 [('DIS', 'cd40'), ('DIS', '##de ##c -'), ('DIS', ':')],
 [('DIS', 'cd8'),
  ('DIS', '##cd'),
  ('DIS', 'of interleukin'),
  ('DIS', 'and interferon'),
  ('DIS', 'to')],
 [('DIS', 'dec - 205'), ('DIS', 'mhc class i')]]

In [166]:
y_preds_ent = [y for y in map(bio_to_entities, [tokenizer.tokenize(line) for line in lines_processed], y_preds)]
len(y_preds_ent)

1128

In [167]:
y_preds_ent[:10]

[[('GEN', 'dec - 205')],
 [('GEN', 'dcs'), ('GEN', 'mhc')],
 [('GEN', 'ova')],
 [],
 [('GEN', 'ova'), ('GEN', 'ova protein')],
 [('GEN', 'ova')],
 [('GEN', 'alpha ##de ##c - 205 :'), ('GEN', 'ova'), ('GEN', 'ova')],
 [('GEN', ','), ('GEN', ',')],
 [],
 [('GEN', 'dec - 205')]]

In [208]:
results = muc.evaluate_all(y_preds_ent, y_true_ent * 1, lines_processed, verbose=False)


 NER evaluation scores:
  strict mode, Precision=0.0488, Recall=0.0488, F1:0.0488
   exact mode, Precision=0.3791, Recall=0.2256, F1:0.2692
 partial mode, Precision=0.3810, Recall=0.2275, F1:0.2710
    type mode, Precision=0.0488, Recall=0.0488, F1:0.0488


In [209]:
print('\n', 'NER evaluation scores:')
for mode, res in results.items():
    print("{:>8s} mode, Precision={:<6.4f}, Recall={:<6.4f}, F1:{:<6.4f}"
            .format(mode, res['precision'], res['recall'], res['f1_score']))


 NER evaluation scores:
  strict mode, Precision=0.0488, Recall=0.0488, F1:0.0488
   exact mode, Precision=0.3791, Recall=0.2256, F1:0.2692
 partial mode, Precision=0.3810, Recall=0.2275, F1:0.2710
    type mode, Precision=0.0488, Recall=0.0488, F1:0.0488


In [210]:
df_result = pd.DataFrame(results).T

In [206]:
df_result.iloc[:,:3] = round(df_result.iloc[:,:3]*100,3)
df_result['count'] = df_result['count'].astype(int)
df_result

Unnamed: 0,precision,recall,f1_score,count
strict,25.283,21.417,22.17,940
exact,25.283,21.417,22.17,940
partial,25.949,22.121,22.801,940
type,26.615,22.825,23.433,940


In [187]:
df_result.to_csv(preds_filepath[:-13] + 'muc.csv')