In [1]:
from utils import *
import yaml

import logging
logging.getLogger().setLevel(logging.CRITICAL)

import warnings
warnings.filterwarnings('ignore')

# Path to the base experiment configuration.
args_file = "./args.yaml"

# safe_load is used instead of yaml.load: calling yaml.load without an explicit
# Loader is deprecated in PyYAML >= 5.1 and raises TypeError in PyYAML >= 6.0,
# and the full loader can construct arbitrary Python objects from the file.
with open(args_file) as f:
    args_dict = yaml.safe_load(f)

# Optional second YAML file whose keys override the base configuration.
# Leave empty to use the base configuration unchanged.
extra_args_file = ""

if extra_args_file:
    with open(extra_args_file) as f:
        # An empty YAML document parses to None; treat that as "no overrides".
        extra_args_dict = yaml.safe_load(f) or {}
    args_dict.update(extra_args_dict)

# Wrap the plain dict in the project's Dict2Obj helper (provided by `utils`).
args = Dict2Obj(args_dict)

## Load model

In [2]:
# Checkpoint under evaluation. Other checkpoints that were compared:
#   "./output/21May2020-15:37:15/checkpoint-126"
#   "./output/22May2020-09:27:40/checkpoint-420"
checkpoint_dir = "./output/22May2020-09:27:40/checkpoint-273"

# Point model, config and tokenizer at the checkpoint and switch the run to
# test-only mode with one example per batch.
inference_overrides = {
    'model_name_or_path': checkpoint_dir,
    'config_name': checkpoint_dir,
    'tokenizer_name': checkpoint_dir,
    'do_test': True,
    'do_train': False,
    'test_batch_size': 1,
}
for option, value in inference_overrides.items():
    args[option] = value

In [3]:
# Build the model wrapper and its trainer from `args`, which at this point
# names the checkpoint directory as model, config and tokenizer source.
# BartSystem and get_trainer come from the project's `utils` module.
bart_system = BartSystem(args)  # Load model
trainer = get_trainer(bart_system, args)  # get the trainer

In [4]:
# Move the model to the first GPU. The call is the cell's last expression, so
# the notebook displays the returned module's repr as the cell output.
bart_system.to('cuda:0')

BartSystem(
  (model): BartForConditionalGeneration(
    (model): BartModel(
      (shared): Embedding(50265, 1024, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(50265, 1024, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): SelfAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (final_layer_n

## Get test data

In [5]:
# Validation split, tokenised with the checkpoint's tokenizer.
val_dataset = SummarizationDataset(
    bart_system.tokenizer,
    data_path=args.test_data_path,
    type_path="val",
    doc_size=args.doc_max_seq_length,
)

# Keep the dataset order (no shuffling) so batches can be paired with
# val_dataset.data entries by position.
dataloader = DataLoader(
    val_dataset,
    batch_size=args.test_batch_size,
    shuffle=False,
    drop_last=True,
)

HBox(children=(FloatProgress(value=0.0, description='Loading val data', max=543.0, style=ProgressStyle(descrip…




In [6]:
# Materialise every batch once so individual batches can be indexed later.
batches = list(dataloader)

In [7]:
# Inspect one validation example. Each val_dataset.data entry unpacks into five
# fields; going by the displayed output, `targets` is a list of indication
# strings and `source` is the raw label text. The *_t / *_mask fields are
# presumably the tokenised tensors -- confirm against SummarizationDataset.
i = 1
source, source_t, source_attention_mask, targets, targets_t = val_dataset.data[i]
print(targets)
source

['perennial allergic rhinitis']


'1 INDICATIONS AND USAGE <newline> ZETONNA Nasal Aerosol is a corticosteroid indicated for treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older. ( 1.1 ) <newline> 1.1 Treatment of Allergic Rhinitis <newline> ZETONNA (ciclesonide) Nasal Aerosol is indicated for the treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older.'

## Make single prediction

In [8]:
def gen_prediction(batch, bart_system, device='cuda:0', max_length=64, min_length=0, num_beams=1, do_sample=False, num_return_sequences=1):
    """Generate the set of predicted phrases for one pre-tokenised batch.

    Args:
        batch: mapping with 'source_ids' and 'source_mask' entries; both are
            moved to `device` with `.to(device)` before generation.
        bart_system: object exposing `.model` (with `generate` and
            `config.eos_token_id`) and `.tokenizer` (with `decode`).
        device: device the inputs are moved to.
        max_length: maximum generated length; 2 is added before it is passed
            to `generate`.
        min_length: minimum generated length; 1 is added before it is passed
            to `generate`.
        num_beams: forwarded to `generate`.
        do_sample: forwarded to `generate`.
        num_return_sequences: forwarded to `generate`.

    Returns:
        A set of lower-cased, whitespace-stripped strings obtained by splitting
        every decoded sequence on the literal '<sep>' marker.
    """
    model = bart_system.model
    tokenizer = bart_system.tokenizer

    generation_kwargs = dict(
        input_ids=batch['source_ids'].to(device),
        attention_mask=batch['source_mask'].to(device),
        # Decoding is started from the EOS token id taken from the model config.
        decoder_start_token_id=model.config.eos_token_id,
        # Offsets kept from the original code: generation starts at step 1 and
        # stops before max_length.
        max_length=max_length + 2,
        min_length=min_length + 1,
        early_stopping=True,
        num_beams=num_beams,
        do_sample=do_sample,
        num_return_sequences=num_return_sequences,
        repetition_penalty=1.0,  # 1.0 == no penalty
        length_penalty=1.0,  # 1.0 == no penalty
    )
    generated = model.generate(**generation_kwargs)

    gen_outputs = set()
    for token_ids in generated:
        # Drop ids 0, 1 and 2 before decoding (presumably the tokenizer's
        # special tokens -- id 1 is the padding index of the embeddings).
        kept_ids = [token_id for token_id in token_ids if token_id > 2]
        decoded = tokenizer.decode(kept_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True)
        for phrase in decoded.split('<sep>'):
            gen_outputs.add(phrase.strip().lower())

    return gen_outputs

In [9]:
# Sanity check on a single example: compare the gold targets with the model's
# prediction. This pairs batches[i] with val_dataset.data[i] by position, which
# relies on the dataloader being built with shuffle=False and a batch size of 1.
i = 4
source, source_t, source_attention_mask, targets, targets_t = val_dataset.data[i]

batch = batches[i]
prediction = gen_prediction(batch, bart_system)
print(targets)
print(prediction)

['type 2 diabetes mellitus']
{'type 2 diabetes mellitus'}


In [10]:
def gen_prediction_from_text(text, bart_system, device='cuda:0', max_length=64, min_length=0, num_beams=1, do_sample=False, num_return_sequences=1, max_source_length=512):
    """Generate the set of predicted phrases for a raw text string.

    The text is lower-cased, tokenised as a one-element batch padded to
    `max_source_length`, and then decoded with the same generation settings
    as `gen_prediction`.

    Args:
        text: raw input document.
        bart_system: object exposing `.model` (with `generate` and
            `config.eos_token_id`) and `.tokenizer` (with `batch_encode_plus`
            and `decode`).
        device: device the encoded inputs are moved to.
        max_length: maximum generated length; 2 is added before it is passed
            to `generate`.
        min_length: minimum generated length; 1 is added before it is passed
            to `generate`.
        num_beams: forwarded to `generate`.
        do_sample: forwarded to `generate`.
        num_return_sequences: forwarded to `generate`.
        max_source_length: `max_length` handed to the tokenizer for the input
            text. Defaults to 512, the value that used to be hard-coded here.

    Returns:
        A set of lower-cased, whitespace-stripped strings obtained by splitting
        every decoded sequence on the literal '<sep>' marker.
    """
    dct = bart_system.tokenizer.batch_encode_plus(
        [text.lower()],
        max_length=max_source_length,
        return_tensors="pt",
        pad_to_max_length=True,
    )

    summaries = bart_system.model.generate(
        input_ids=dct['input_ids'].to(device),
        attention_mask=dct['attention_mask'].to(device),
        # Decoding is started from the EOS token id taken from the model config.
        decoder_start_token_id=bart_system.model.config.eos_token_id,

        max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
        min_length=min_length + 1,  # +1 from original because we start at step=1
        early_stopping=True,

        num_beams=num_beams,
        do_sample=do_sample,
        num_return_sequences=num_return_sequences,
        repetition_penalty=1.0,  # no penalty
        length_penalty=1.0,  # no penalty
    )

    # Drop ids 0, 1 and 2 before decoding (presumably the tokenizer's special
    # tokens -- id 1 is the padding index of the embeddings).
    preds = [
        bart_system.tokenizer.decode([v for v in g if v > 2], skip_special_tokens=False, clean_up_tokenization_spaces=True)
        for g in summaries
    ]
    # One generated sequence may hold several phrases separated by '<sep>'.
    gen_outputs = {s.strip().lower() for p in preds for s in p.split('<sep>')}

    return gen_outputs

In [11]:
# Counterfactual check: the indication text of the ZETONNA validation example
# with a different product name substituted, run with default greedy decoding.
text = "1 INDICATIONS AND USAGE <newline> Preservative-free TIMOPTIC in OCUDOSE is a corticosteroid indicated for treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older. ( 1.1 ) <newline> 1.1 Treatment of Allergic Rhinitis <newline> Preservative-free TIMOPTIC in OCUDOSE is indicated for the treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older."
gen_prediction_from_text(text, bart_system)

{'perennial allergic rhinitis'}

In [12]:
# Made-up example with a bulleted list of two indications plus distractor
# sentences, decoded with a 16-wide beam instead of the default single beam.
text = "Pizza is used to help adults and children with: <newline> \
<bullet> chickenpox <newline> \
<bullet> hunger <newline> \
It also has applications in the roofing industry. Pizza is not indicated for patients with obesity."
gen_prediction_from_text(text, bart_system, num_beams=16, do_sample=False, num_return_sequences=1)

{'chickenpox', 'hunger'}

In [13]:
# The ZETONNA validation example passed in as raw text, to check that the
# text-based entry point gives a sensible answer for a document from the dataset.
text = "1 INDICATIONS AND USAGE <newline> ZETONNA Nasal Aerosol is a corticosteroid indicated for treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older. ( 1.1 ) <newline> 1.1 Treatment of Allergic Rhinitis <newline> ZETONNA (ciclesonide) Nasal Aerosol is indicated for the treatment of symptoms associated with seasonal and perennial allergic rhinitis in adults and adolescents 12 years of age and older."
gen_prediction_from_text(text, bart_system)

{'perennial allergic rhinitis'}

## Get all predictions

In [14]:
# Gold labels: one set of lower-cased target phrases per validation example,
# in dataset order.
targets = [
    {phrase.lower() for phrase in target}
    for source, source_t, source_attention_mask, target, targets_t in val_dataset.data
]

In [16]:
# Model output for every batch, decoded greedily (one beam, no sampling, one
# returned sequence), in the same order as `batches`.
predictions = [
    gen_prediction(batch, bart_system, num_beams=1, do_sample=False, num_return_sequences=1)
    for batch in tqdm(batches)
]

HBox(children=(FloatProgress(value=0.0, max=533.0), HTML(value='')))




In [19]:
# Cache targets and predictions on disk. Sets are not JSON-serialisable, so
# each one is written out as a list.
for out_path, label_sets in (
    ("./test_targets.json", targets),
    ("./test_predictions.json", predictions),
):
    with open(out_path, 'wt') as f:
        json.dump([list(labels) for labels in label_sets], f)

In [20]:
# Reload the cached results and turn every list back into a set so the
# evaluation cells can use set operations.
with open("./test_targets.json", 'rt') as f:
    targets = list(map(set, json.load(f)))
with open("./test_predictions.json", 'rt') as f:
    predictions = list(map(set, json.load(f)))

## Strict

In [21]:
# Strict evaluation: a prediction counts only if it is string-identical to a
# target. Counts are accumulated over all examples (micro-averaging).
tp = 0  # predicted phrases that are also targets
fp = 0  # predicted phrases that are not targets
fn = 0  # target phrases that were not predicted

for i, target in enumerate(targets):
    prediction = predictions[i]
    tp += len(target & prediction)
    fp += len(prediction - target)
    fn += len(target - prediction)

# Guard every denominator: with no true positives (or no data at all) the
# unguarded formulas raise ZeroDivisionError instead of reporting a score of 0.
recall = tp/(tp+fn) if (tp+fn) else 0.0
precision = tp/(tp+fp) if (tp+fp) else 0.0
f1 = 2*recall*precision/(recall+precision) if (recall+precision) else 0.0

{'recall': recall,
'precision': precision,
'f1': f1}

{'recall': 0.6779279279279279,
 'precision': 0.7496886674968867,
 'f1': 0.7120047309284447}

### ./output/21May2020-15:37:15/checkpoint-126

    {'recall': 0.6385135135135135,
     'precision': 0.7411764705882353,
     'f1': 0.6860254083484573}
     
### ./output/22May2020-09:27:40/checkpoint-273, do_sample=False, num_return_sequences=1

     {'recall': 0.6779279279279279,
     'precision': 0.7496886674968867,
     'f1': 0.7120047309284447}
     
### ./output/22May2020-09:27:40/checkpoint-273, do_sample=True, num_return_sequences=5

     {'recall': 0.7804054054054054,
     'precision': 0.5294117647058824,
     'f1': 0.6308602639963586}

### ./output/22May2020-09:27:40/checkpoint-273, do_sample=False, num_beams=16

     {'recall': 0.7274774774774775,
     'precision': 0.5273469387755102,
     'f1': 0.611452910553715}
     
### ./output/22May2020-09:27:40/checkpoint-420, do_sample=False, num_return_sequences=1

    {'recall': 0.6587837837837838,
     'precision': 0.7452229299363057,
     'f1': 0.6993424985056783}
     
### ./output/22May2020-09:27:40/checkpoint-420, do_sample=True, num_return_sequences=5

    {'recall': 0.7826576576576577,
     'precision': 0.5438184663536776,
     'f1': 0.641735918744229}
     
I've found that a more extensive search (either sampling several sequences or increasing the number of beams) is needed to retrieve all the indications from my test sentence (the pizza example). However, when the same settings are used for evaluation, recall increases but precision decreases by more, leading to a lower F1 score.

## Fuzzy

Use edit distance (Levenshtein distance) over characters and over subword tokens

Only slightly better scores than strict

To get a better metric I should assign results to EFO IDs and then compare these

In [38]:
!pip install fuzzywuzzy

Collecting fuzzywuzzy
  Using cached fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)
Installing collected packages: fuzzywuzzy
Successfully installed fuzzywuzzy-0.18.0


In [39]:
from fuzzywuzzy import fuzz

In [81]:
# Fuzzy evaluation on characters: a prediction and a target match when
# fuzz.ratio (0-100) is at least `threshold`.
threshold=80

tp = 0  # predictions that match at least one target
fp = 0  # predictions that match no target
fn = 0  # targets that are matched by no prediction
matched_targets = 0  # targets that are matched by at least one prediction

for i, ts in enumerate(targets):
    ts = list(ts)
    ps = list(predictions[i])
    # d[x, y] = similarity between target x and prediction y.
    d = np.zeros((len(ts),len(ps)), dtype=np.float64)
    for x,t in enumerate(ts):
        for y,p in enumerate(ps):
            d[x,y] = fuzz.ratio(t,p)
    hit = d >= threshold

    # Count matches separately per side. Deriving fn from the number of matched
    # *predictions* makes it negative (and inflates recall) whenever several
    # predictions match the same target. `any` is also defined for an empty
    # axis, unlike `max`.
    pred_hits = int(hit.any(axis=0).sum())
    target_hits = int(hit.any(axis=1).sum())

    tp += pred_hits
    fp += len(ps) - pred_hits
    matched_targets += target_hits
    fn += len(ts) - target_hits

# Recall is measured over targets, precision over predictions; denominators are
# guarded so an empty or all-miss evaluation reports 0 instead of raising.
recall = matched_targets/(matched_targets+fn) if (matched_targets+fn) else 0.0
precision = tp/(tp+fp) if (tp+fp) else 0.0
f1 = 2*recall*precision/(recall+precision) if (recall+precision) else 0.0

{'recall': recall,
'precision': precision,
'f1': f1}

{'recall': 0.7027027027027027,
 'precision': 0.7770859277708593,
 'f1': 0.7380248373743348}

### ./output/21May2020-15:37:15/checkpoint-126

    {'recall': 0.652027027027027,
     'precision': 0.7568627450980392,
     'f1': 0.7005444646098004}
     
### ./output/22May2020-09:27:40/checkpoint-273

     {'recall': 0.7027027027027027,
     'precision': 0.7770859277708593,
     'f1': 0.7380248373743348}

In [45]:
!pip install textdistance

Collecting textdistance
  Downloading textdistance-4.2.0-py3-none-any.whl (29 kB)
Installing collected packages: textdistance
Successfully installed textdistance-4.2.0


In [130]:
import textdistance

def fuzzy_match(targets,predictions,tokenizer):
    """Pairwise token-level similarity between target and predicted phrases.

    Every phrase is stripped and encoded with `tokenizer.encode` (without
    special tokens); the resulting id sequences are compared with
    textdistance's normalised Levenshtein similarity.

    Args:
        targets: iterable of target strings.
        predictions: iterable of predicted strings.
        tokenizer: object with an `encode(text, add_special_tokens=False)` method.

    Returns:
        A float64 array of shape (number of targets, number of predictions)
        whose [x, y] entry is the similarity of target x and prediction y.
    """
    def to_ids(phrase):
        return tokenizer.encode(phrase.strip(), add_special_tokens=False)

    target_ids = list(map(to_ids, targets))
    prediction_ids = list(map(to_ids, predictions))

    similarity = np.zeros((len(target_ids), len(prediction_ids)), dtype=np.float64)
    for row in range(len(target_ids)):
        for col in range(len(prediction_ids)):
            similarity[row, col] = textdistance.levenshtein.normalized_similarity(
                target_ids[row], prediction_ids[col]
            )
    return similarity

In [134]:
# Fuzzy evaluation on subword tokens: a prediction and a target match when the
# normalised Levenshtein similarity of their token ids is at least `threshold`.
threshold=0.8

tp = 0  # predictions that match at least one target
fp = 0  # predictions that match no target
fn = 0  # targets that are matched by no prediction
matched_targets = 0  # targets that are matched by at least one prediction

for i, ts in enumerate(targets):
    ts = list(ts)
    ps = list(predictions[i])

    # d[x, y] = similarity between target x and prediction y.
    d = fuzzy_match(ts,ps,bart_system.tokenizer)
    hit = d >= threshold

    # Count matches separately per side. Deriving fn from the number of matched
    # *predictions* makes it negative (and inflates recall) whenever several
    # predictions match the same target. `any` is also defined for an empty
    # axis, unlike `max`.
    pred_hits = int(hit.any(axis=0).sum())
    target_hits = int(hit.any(axis=1).sum())

    tp += pred_hits
    fp += len(ps) - pred_hits
    matched_targets += target_hits
    fn += len(ts) - target_hits

# Recall is measured over targets, precision over predictions; denominators are
# guarded so an empty or all-miss evaluation reports 0 instead of raising.
recall = matched_targets/(matched_targets+fn) if (matched_targets+fn) else 0.0
precision = tp/(tp+fp) if (tp+fp) else 0.0
f1 = 2*recall*precision/(recall+precision) if (recall+precision) else 0.0

{'recall': recall,
'precision': precision,
'f1': f1}

{'recall': 0.6801801801801802,
 'precision': 0.7521793275217933,
 'f1': 0.7143701951507984}