# __3.2. Evaluate model using the PICKLE corpus__

Goal:
- Mask entities from the Pickle corpus to evaluate model performance

Considerations:
- Only mask unigrams

Log:
- 11/30/23:
  - The PICKLE dataset is derived from IOB-format data in:
    - `hpc.msu.edu:/mnt/research/ShiuLab/serena_kg/PICKLE_250_abstracts_entities_and_relations_FINAL_05Jul2023`
  - The derived data is copied from:
    - `hpc.msu.edu:/mnt/research/compbiol_shiu/kg/1_data_proc`
  - Will eventually move `kg:/1_data_proc/script_1_1_parse_brat.ipynb` to be `3_1` in this repo.
    - Moved.

## ___Setup___

In [1]:
from pathlib import Path
from tqdm import tqdm

from spacy.lang.en import English
from spacy.tokens import DocBin
from spacy.util import compile_infix_regex
from spacy.tokenizer import Tokenizer
from spacy.lang.char_classes import \
      ALPHA, ALPHA_LOWER, ALPHA_UPPER, CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS

from datasets import load_dataset
from transformers import BertTokenizerFast, BertForMaskedLM, pipeline

In [2]:
proj_dir   = Path.home() / "projects/plantbert"
work_dir   = proj_dir / "3_eval_with_pickle"
pickle_dir = work_dir / "pickle"

# Vanilla model
dir1       = proj_dir / "1_vanilla_bert" 
model1_dir = dir1 / "models/"
ckpt1_dir  = model1_dir / "checkpoint-35500"

# Filtered model

## ___Load dataset___

### Get PICKLE data

The data were processed into spaCy format and obtained with:

```bash
scp shius@hpc.msu.edu:/mnt/research/compbiol_shiu/kg/1_data_proc/*.spacy ./
```

Information on saving data can be found in [Training Pipelines & Models](https://spacy.io/usage/training#training-data).
- Specifically, the section on [preparing training data](https://spacy.io/usage/training#training-data) indicates that a `.spacy` file is saved as a [`DocBin`](https://spacy.io/api/docbin) object.
- Note that the `spacy` files are generated with a custom tokenizer which needs to be loaded.
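The notebook calls `custom_tokenizer(nlp)` in cell In [4], but the cell that defines it is not part of this export. The imports in In [1] (`compile_infix_regex`, `Tokenizer`, and the `char_classes` constants) match the spaCy documentation's recipe for changing infix rules, and the PICKLE tokens keep hyphenated words such as `Bensulfuron-methyl` and `15-day` whole. The following is a minimal sketch, assuming the original follows that recipe (dropping the rule that splits on hyphens between letters); the actual definition may differ.

```python
from spacy.lang.char_classes import (
    ALPHA, ALPHA_LOWER, ALPHA_UPPER, CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS)
from spacy.lang.en import English
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex


def custom_tokenizer(nlp):
    """Hypothetical reconstruction: a tokenizer that keeps hyphenated words
    (e.g. 'Bensulfuron-methyl') as single tokens."""
    infixes = (
        LIST_ELLIPSES
        + LIST_ICONS
        + [
            r"(?<=[0-9])[+\-\*^](?=[0-9-])",
            r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
                al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES),
            r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
            # The default rule that splits on hyphens between letters is omitted.
            r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
        ]
    )
    infix_re = compile_infix_regex(infixes)
    # Reuse the default prefix/suffix rules and exceptions; only infixes change.
    return Tokenizer(
        nlp.vocab,
        rules=nlp.Defaults.tokenizer_exceptions,
        prefix_search=nlp.tokenizer.prefix_search,
        suffix_search=nlp.tokenizer.suffix_search,
        infix_finditer=infix_re.finditer,
        token_match=nlp.tokenizer.token_match,
        url_match=nlp.tokenizer.url_match,
    )
```

With this tokenizer, `English()` followed by `nlp.tokenizer = custom_tokenizer(nlp)` yields `Bensulfuron-methyl` as one token, while parentheses and a sentence-final `.` are still split off.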

In [4]:
nlp = English()
nlp.tokenizer = custom_tokenizer(nlp)

In [5]:
train_db = DocBin().from_disk(pickle_dir / "train.spacy")
dev_db   = DocBin().from_disk(pickle_dir / "dev.spacy")
test_db  = DocBin().from_disk(pickle_dir / "test.spacy")

In [6]:
docs = [doc for doc in train_db.get_docs(nlp.vocab)] + \
       [doc for doc in dev_db.get_docs(nlp.vocab)] + \
       [doc for doc in test_db.get_docs(nlp.vocab)]
len(docs)

233

### Helper function to get masked sentence list

In [49]:
def mask_docs(docs):
  '''Mask unigram entities in docs, one entity at a time.

  Sentences are split on "." tokens; tokens after the last "." in a doc are
  not processed.

  Args:
    docs (list of Doc): spaCy docs with entity annotations (here, PICKLE docs).
  Returns:
    masked_sents (list): nested list where each sublist contains the original
      token and the sentence with that token replaced by "[MASK]".
  '''

  masked_sents = []
  for doc in tqdm(docs):
    # Build entity dictionary: {start: [end, text, label]}
    edict = {}
    for ent in doc.ents:
      if ent.start not in edict:
        edict[ent.start] = [ent.end, ent.text, ent.label_]
      else:
        print("ERR: Duplicate start index", ent.start)

    test_sent = [] # tokens of the current sentence
    start_tidx = 0

    # Go through each token in the doc
    for tidx, token in enumerate(doc):
      if token.text != ".":
        test_sent.append(token)
      else:
        # Go through edict to see if anything in this sentence needs masking.
        #   start_tidx: doc-level index of the first token of this sentence.
        #   tidx: doc-level index of the "." that ends this sentence.
        #   sidx: doc-level index of each token in this sentence.
        for sidx in range(start_tidx, tidx):
          if sidx in edict:

            # unigram entity
            if edict[sidx][0] - sidx == 1:
              # Copy the sentence so each entity gets its own masked sentence
              # (a sentence can contain more than one entity).
              test_sent_tmp = [t.text for t in test_sent]

              # midx: index of the token to mask, relative to the sentence start
              midx = sidx - start_tidx
              ori_txt = test_sent_tmp[midx]
              test_sent_tmp[midx] = "[MASK]"

              # add masked sentence to list
              masked_sents.append([ori_txt, " ".join(test_sent_tmp)])

        # reset variables
        test_sent = []
        start_tidx = tidx + 1

  return masked_sents

In [50]:
masked_sents = mask_docs(docs)
len(masked_sents)

100%|██████████| 233/233 [00:00<00:00, 8915.18it/s]


2475

In [51]:
masked_sents[:5]

[['Bensulfuron-methyl',
  '[MASK] ( BSM ) is widely used in paddy soil for weed control'],
 ['BSM',
  'Bensulfuron-methyl ( [MASK] ) is widely used in paddy soil for weed control'],
 ['BSM',
  'In this study , we have found significant effects of [MASK] on the infestation of Bemisia tabaci , Myzus persicae , and Tobacco mosaic virus ( TMV ) in Nicotiana tabacum'],
 ['TMV',
  'In this study , we have found significant effects of BSM on the infestation of Bemisia tabaci , Myzus persicae , and Tobacco mosaic virus ( [MASK] ) in Nicotiana tabacum'],
 ['BSM', 'The soil was treated with [MASK] before the pest inoculation']]
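`mask_docs()` walks spaCy tokens directly. The same logic can be sketched on plain token lists and `(start, end)` entity spans, which makes it easy to unit-test without loading the PICKLE files. This is a hypothetical variant, not the notebook's function: `mask_unigrams` is an illustrative name, and unlike `mask_docs()` it also processes a final sentence that is not followed by a `.` token.

```python
from typing import List, Sequence, Tuple


def mask_unigrams(tokens: Sequence[str],
                  ents: Sequence[Tuple[int, int]],
                  mask: str = "[MASK]") -> List[List[str]]:
    """Return [original_token, masked_sentence] pairs, one per unigram entity.

    tokens: token texts of one document.
    ents: (start, end) token spans of entities, end exclusive.
    Sentences are split on "." tokens, which are left out of the output.
    """
    # Start indices of single-token entities only.
    unigram_starts = {start for start, end in ents if end - start == 1}
    # Sentence boundaries: every "." plus the end of the document, so a final
    # sentence without a trailing "." is still processed.
    boundaries = [i for i, t in enumerate(tokens) if t == "."] + [len(tokens)]

    masked = []
    sent_start = 0
    for boundary in boundaries:
        sent = list(tokens[sent_start:boundary])
        for offset, tok in enumerate(sent):
            if sent_start + offset in unigram_starts:
                masked_sent = sent.copy()  # one masked copy per entity
                masked_sent[offset] = mask
                masked.append([tok, " ".join(masked_sent)])
        sent_start = boundary + 1
    return masked
```

For a spaCy doc, the inputs would be `[t.text for t in doc]` and `[(e.start, e.end) for e in doc.ents]`.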

## ___Set up pipeline using plant-bert-vanilla-cased___

### Load model, tokenizer

In [37]:
model1     = BertForMaskedLM.from_pretrained(ckpt1_dir)
tokenizer1 = BertTokenizerFast.from_pretrained(model1_dir)

### Set fill mask pipeline

In [43]:
fill_mask1 = pipeline("fill-mask", model=model1, tokenizer=tokenizer1)

### Fill mask for masked_sents
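This step is not implemented yet in the notebook. One way to score the model is top-k accuracy: the fraction of `[original_token, masked_sentence]` pairs whose original token appears among the model's top k predictions. The sketch below assumes the Hugging Face fill-mask pipeline output (a list of dicts with a `token_str` key, highest score first); `top_k_accuracy` is a hypothetical helper, and the `fill_mask` argument can be any callable with that output shape, which keeps the function testable without loading a model.

```python
from typing import Callable, Dict, List, Sequence


def top_k_accuracy(masked_sents: Sequence[Sequence[str]],
                   fill_mask: Callable[[str], List[Dict]],
                   k: int = 5,
                   lowercase: bool = True) -> float:
    """Fraction of [original_token, masked_sentence] pairs whose original
    token is among the first k predictions returned by fill_mask."""
    if not masked_sents:
        return 0.0
    hits = 0
    for ori_txt, sent in masked_sents:
        # Predictions are ordered from highest to lowest score.
        preds = [p["token_str"].strip() for p in fill_mask(sent)[:k]]
        target = ori_txt
        if lowercase:
            # The vanilla tokenizer lowercases input, so compare case-insensitively.
            preds = [p.lower() for p in preds]
            target = target.lower()
        if target in preds:
            hits += 1
    return hits / len(masked_sents)
```

A call such as `top_k_accuracy(masked_sents, fill_mask1, k=5)` would score the vanilla model. The pipeline returns 5 predictions by default; for a larger k, a wrapper like `lambda s: fill_mask1(s, top_k=10)` passes the pipeline's `top_k` argument.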

## ___Testing___

### PICKLE doc testing

In [None]:
test_doc    = docs[0]
type(test_doc)

spacy.tokens.doc.Doc

In [None]:
test_doc

Bensulfuron-methyl (BSM) is widely used in paddy soil for weed control. BSM residue in the soil has been known to inhibit the growth of sensitive crop plants. However, it is unknown whether BSM residue can affect the agrosystem in general. In this study, we have found significant effects of BSM on the infestation of Bemisia tabaci, Myzus persicae, and Tobacco mosaic virus (TMV) in Nicotiana tabacum. The soil was treated with BSM before the pest inoculation. The herbicide-treated tobaccos showed resistance to B. tabaci, but this resistance could not be detected until 15-day post-infestation when smaller number of adults B. tabaci appeared. In M. persicae assay, the longevity of all development stages of insects, and the fecundity of insects were not significantly affected when feeding on BSM-treated plants. In TMV assay, the BSM treatment also reduced virus-induced lesions in early infection time. However, the titer of TMV in BSM treated plants increased greatly over time and was over 4

In [None]:
# Check behavior of tokenizer
test_span1 = test_doc[11:14]
test_span2 = test_doc[112:115]

# Note: here the period is its own token
print(test_span1.text)
for token in test_span1:
  print(" token=", token.text, token.idx)

# But in the following case, it is not.
print(test_span2.text)
for token in test_span2:
  print(" token=", token.text, token.idx)

weed control.
 token= weed 58
 token= control 63
 token= . 70
adults B. tabaci
 token= adults 620
 token= B. 627
 token= tabaci 630


In [None]:
for ent in test_doc.ents:
  print(ent.text, ent.label_, ent.start, ent.end)

Bensulfuron-methyl Organic_compound_other 0 1
BSM Organic_compound_other 2 3
BSM residue Organic_compound_other 14 16
BSM residue Organic_compound_other 37 39
BSM Organic_compound_other 56 57
Bemisia tabaci Multicellular_organism 61 63
Myzus persicae Multicellular_organism 64 66
Tobacco mosaic virus Virus 68 71
TMV Virus 72 73
Nicotiana tabacum Multicellular_organism 75 77
BSM Organic_compound_other 83 84
tobaccos Multicellular_organism 91 92
B. tabaci Multicellular_organism 95 97
B. tabaci Multicellular_organism 113 115
M. persicae Multicellular_organism 118 120
TMV Virus 147 148
BSM Organic_compound_other 151 152
TMV Virus 167 168
BSM Organic_compound_other 169 170
BSM Organic_compound_other 194 195
jasmonic acid Plant_hormone 198 200
JA Plant_hormone 201 202
salicylic acid Plant_hormone 204 206
SA Plant_hormone 207 208
tobacco Multicellular_organism 211 212
JA and SA signaling pathways Biochemical_pathway 223 228
NtWIPK DNA 231 232
NtPR1a DNA 233 234
NtPAL DNA 236 237
NtPR1a DNA 238

In [None]:
test_doc[61], test_doc[62]

(Bemisia, tabaci)

In [None]:
# Build entity dictionary: {start: [end, text, label]}
test_edict = {}
for ent in test_doc.ents:
  if ent.start not in test_edict:
    test_edict[ent.start] = [ent.end, ent.text, ent.label_]
  else:
    print("ERR: Duplicate start index", ent.start)

In [None]:
test_edict

{0: [1, 'Bensulfuron-methyl', 'Organic_compound_other'],
 2: [3, 'BSM', 'Organic_compound_other'],
 14: [16, 'BSM residue', 'Organic_compound_other'],
 37: [39, 'BSM residue', 'Organic_compound_other'],
 56: [57, 'BSM', 'Organic_compound_other'],
 61: [63, 'Bemisia tabaci', 'Multicellular_organism'],
 64: [66, 'Myzus persicae', 'Multicellular_organism'],
 68: [71, 'Tobacco mosaic virus', 'Virus'],
 72: [73, 'TMV', 'Virus'],
 75: [77, 'Nicotiana tabacum', 'Multicellular_organism'],
 83: [84, 'BSM', 'Organic_compound_other'],
 91: [92, 'tobaccos', 'Multicellular_organism'],
 95: [97, 'B. tabaci', 'Multicellular_organism'],
 113: [115, 'B. tabaci', 'Multicellular_organism'],
 118: [120, 'M. persicae', 'Multicellular_organism'],
 147: [148, 'TMV', 'Virus'],
 151: [152, 'BSM', 'Organic_compound_other'],
 167: [168, 'TMV', 'Virus'],
 169: [170, 'BSM', 'Organic_compound_other'],
 194: [195, 'BSM', 'Organic_compound_other'],
 198: [200, 'jasmonic acid', 'Plant_hormone'],
 201: [202, 'JA', 'Pla

In [None]:
# list with masked sentences
test_masked_sents = []
test_sent = [] # list of tokens
start_tidx = 0

# Go through each token in the doc
for tidx, token in enumerate(test_doc):
  if token.text != ".":
    test_sent.append(token)
  else:
    # Go through test_edict to see if anything in this sentence needs masking.
    #   start_tidx: doc-level index of the first token of this sentence.
    #   tidx: doc-level index of the "." that ends this sentence.
    #   sidx: doc-level index of each token in this sentence.
    for sidx in range(start_tidx, tidx):
      if sidx in test_edict:

        # unigram entity
        if test_edict[sidx][0] - sidx == 1:
          # Copy the sentence so each entity gets its own masked sentence
          # (a sentence can contain more than one entity).
          test_sent_tmp = [t.text for t in test_sent]

          # midx: index of the token to mask, relative to the sentence start
          midx = sidx - start_tidx
          test_sent_tmp[midx] = "[MASK]"

          # add masked sentence to list
          test_masked_sents.append(" ".join(test_sent_tmp))

    # reset variables
    test_sent = []
    start_tidx = tidx + 1


In [None]:
test_masked_sents

['[MASK] ( BSM ) is widely used in paddy soil for weed control',
 'Bensulfuron-methyl ( [MASK] ) is widely used in paddy soil for weed control',
 'In this study , we have found significant effects of [MASK] on the infestation of Bemisia tabaci , Myzus persicae , and Tobacco mosaic virus ( TMV ) in Nicotiana tabacum',
 'In this study , we have found significant effects of BSM on the infestation of Bemisia tabaci , Myzus persicae , and Tobacco mosaic virus ( [MASK] ) in Nicotiana tabacum',
 'The soil was treated with [MASK] before the pest inoculation',
 'The herbicide-treated [MASK] showed resistance to B. tabaci , but this resistance could not be detected until 15-day post-infestation when smaller number of adults B. tabaci appeared',
 'In [MASK] assay , the BSM treatment also reduced virus-induced lesions in early infection time',
 'In TMV assay , the [MASK] treatment also reduced virus-induced lesions in early infection time',
 'However , the titer of [MASK] in BSM treated plants inc

### Test loading the tokenizer

In [38]:
example = "Cytokinins are plant hormones that promote cell division, or " +\
          "cytokinesis, in plant roots and shoots."

input_ids = tokenizer1(example)["input_ids"]
for idx, input_id in enumerate(input_ids):
  print(idx, tokenizer1.convert_ids_to_tokens(input_id))

0 [CLS]
1 cytokinins
2 are
3 plant
4 hormones
5 that
6 promote
7 cell
8 division
9 ,
10 or
11 cytokinesis
12 ,
13 in
14 plant
15 roots
16 and
17 shoots
18 .
19 [SEP]


In [40]:
test_mask_id = 3  # index into test_list, which excludes [CLS] and [SEP]
test_list = tokenizer1.convert_ids_to_tokens(input_ids)[1:-1]
test_list[test_mask_id] = "[MASK]"
test_str = " ".join(test_list)
test_str

'cytokinins are plant [MASK] that promote cell division , or cytokinesis , in plant roots and shoots .'

In [41]:
# Even though extra spaces are added before "," and ".", the number of
# tokens remains the same.
len(tokenizer1(test_str)["input_ids"])

20
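A single `[MASK]` stands for exactly one wordpiece, so an entity that the tokenizer splits into several pieces can never be recovered by one-mask filling. BERT's pre-tokenizer splits on punctuation, so an entity such as `Bensulfuron-methyl` becomes at least three pieces. A minimal sketch for screening `masked_sents` before scoring, assuming a `tokenize` callable such as `tokenizer1.tokenize`; `single_piece_only` is a hypothetical helper name.

```python
from typing import Callable, List, Sequence


def single_piece_only(masked_sents: Sequence[Sequence[str]],
                      tokenize: Callable[[str], List[str]]) -> List[List[str]]:
    """Keep [original_token, masked_sentence] pairs whose original token is
    exactly one wordpiece, i.e. recoverable by a single [MASK]."""
    return [[ori_txt, sent] for ori_txt, sent in masked_sents
            if len(tokenize(ori_txt)) == 1]
```

Reporting accuracy on both the full list and the filtered list (e.g. `single_piece_only(masked_sents, tokenizer1.tokenize)`) separates model errors from cases the single-mask setup cannot solve.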

In [48]:
# Indices into input_ids, which include the special tokens [CLS] (position 0)
# and [SEP] (position -1)
to_mask = [1, 3, 4, 7, 8, 11, 14, 15, 17]

for mask_idx in to_mask:
  input_ids_tmp = input_ids.copy()
  input_ids_tmp[mask_idx] = tokenizer1.mask_token_id
  txt = " ".join(tokenizer1.convert_ids_to_tokens(input_ids_tmp)[1:-1])
  print(txt)
  for pred in fill_mask1(txt):
    print(f"  {pred['token_str']}, score:{pred['score']:.4f}")

[MASK] are plant hormones that promote cell division , or cytokinesis , in plant roots and shoots .
  there, score:0.6489
  cytokinins, score:0.0486
  what, score:0.0454
  brassinosteroids, score:0.0294
  they, score:0.0230
cytokinins are [MASK] hormones that promote cell division , or cytokinesis , in plant roots and shoots .
  plant, score:0.7351
  the, score:0.0778
  endogenous, score:0.0133
  common, score:0.0105
  important, score:0.0077
cytokinins are plant [MASK] that promote cell division , or cytokinesis , in plant roots and shoots .
  hormones, score:0.8892
  cytokinins, score:0.0284
  regulators, score:0.0175
  factors, score:0.0079
  phytohormones, score:0.0068
cytokinins are plant hormones that promote [MASK] division , or cytokinesis , in plant roots and shoots .
  cell, score:0.9443
  the, score:0.0137
  nuclear, score:0.0039
  mitotic, score:0.0033
  their, score:0.0025
cytokinins are plant hormones that promote cell [MASK] , or cytokinesis , in plant roots and shoots .