In [2]:
import numpy as np
import torch
import joblib

import config
import dataset
import engine
from model import EntityModel

import json
from tqdm import tqdm
import re

In [3]:
with open("../input/test_ner.json") as json_file:
    test_data = json.load(json_file)
    test_paper_ids = list(test_data.keys())

In [4]:
meta_data = joblib.load("meta.bin")
enc_tag = meta_data["enc_tag"]
num_tag = len(list(enc_tag.classes_))

In [5]:
device = torch.device("cuda")
model = EntityModel(num_tag)
model.load_state_dict(torch.load(config.MODEL_PATH))
model.to(device)
print("Model loading completed!")

Model loading completed!


In [6]:
# Main BERT tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [7]:
for index, paper_id in enumerate(test_paper_ids):
    print(index, '\t', paper_id)

0 	 2100032a-7c33-4bff-97ef-690822c43466
1 	 2f392438-e215-4169-bebf-21ac4ff253e1
2 	 3f316b38-1a24-45a9-8d8c-4e05a42257c6
3 	 8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60


In [8]:
# Testing if aligning text will be possible
print(test_data["2100032a-7c33-4bff-97ef-690822c43466"][30])
print()
print(tokenizer(test_data["2100032a-7c33-4bff-97ef-690822c43466"][30], is_split_into_words=True).word_ids())

['the', 'three', 'SNPs', 'previously', 'associated', 'with', 'education', 'were', 'variants', 'included', 'on', 'commercially', 'available', 'microarrays', 'and', 'thus', 'were', 'imputed', 'into', 'their', 'datasets', 'Rietveld', 'et', 'al', '2013', 'In', 'the', 'COGENT1', 'studies', 'SNPs', 'were', 'imputed', 'using', 'HapMap3', 'reference', 'panels', 'as', 'previously', 'described', 'Lencz', 'et', 'al', '2013', 'COGENT2', 'samples', 'that', 'did', 'not', 'have', 'genotypes', 'for', 'the', 'SNPs', 'of', 'interest', 'were', 'imputed', 'using', 'IMPUTE2', 'Howie', 'et', 'al', '2009', '.']

[None, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 13, 13, 14, 15, 16, 17, 17, 18, 19, 20, 20, 20, 21, 21, 21, 21, 22, 23, 24, 25, 26, 27, 27, 27, 28, 29, 29, 29, 30, 31, 31, 32, 33, 33, 33, 33, 34, 35, 36, 37, 38, 39, 39, 40, 41, 42, 43, 43, 43, 44, 45, 46, 47, 48, 49, 49, 49, 50, 51, 52, 52, 52, 53, 54, 55, 56, 56, 57, 58, 58, 58, 59, 60, 61, 62, 63, None]


In [10]:
# Placeholders
paper_token_results = {}
paper_tokenized_sentences = {}
paper_words_ids = {}

# Main inference loop
for paper_id in tqdm(test_paper_ids):
    sentences = test_data[paper_id]
    tokenized_sentences = []
    tokenized_sentences_word_ids = []
    aligned_results = []
    for sentence in sentences:
        tokenized_output = tokenizer(sentence, is_split_into_words=True)
        tokenized_sentences.append(tokenized_output['input_ids'])
        tokenized_sentences_word_ids.append(tokenized_output.word_ids())
    paper_tokenized_sentences[paper_id] = tokenized_sentences
    paper_words_ids[paper_id] = tokenized_sentences_word_ids

    # Inference dataset
    test_dataset = dataset.EntityDataset(
        texts = sentences,
        tags = [[0] * len(sentence) for sentence in sentences]
    )

    with torch.no_grad():
        test_dataset
        tags = []
        results = []
        for data in test_dataset:
            for k, v in data.items():
                data[k] = v.to(device).unsqueeze(0)
            tag, _ = model(**data)
            tags.append(tag)
        for tag, tokenized_sentence in zip(tags, tokenized_sentences):
            result = enc_tag.inverse_transform(tag.argmax(2).cpu().numpy().reshape(-1))[:len(tokenized_sentence)]
            results.append(result)
        paper_token_results[paper_id] = results

100%|██████████| 91/91 [00:00<00:00, 705.01it/s]
100%|██████████| 713/713 [00:00<00:00, 1019.00it/s]
100%|██████████| 263/263 [00:00<00:00, 1417.50it/s]
100%|██████████| 186/186 [00:00<00:00, 1282.48it/s]


In [21]:
label_all_tokens = True

def tokenize_and_align_labels(tokenized_input, labels, word_ids):
    # for i, label in enumerate(labels):
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(labels[word_idx])
        else:
            label_ids.append(labels[word_idx] if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids

In [22]:
# Aligning labels loop
paper_aligned_labels = {}

for paper_id in tqdm(test_paper_ids):
    sentences = test_data[paper_id]
    paper_aligned_labels[paper_id] = []
    for i, sentence in enumerate(sentences):
        aligned_labels = tokenize_and_align_labels(paper_tokenized_sentences[paper_id][i], 
                                                   paper_token_results[paper_id][i], 
                                                   paper_words_ids[paper_id][i])
        paper_aligned_labels[paper_id].append(aligned_labels)

100%|██████████| 4/4 [00:00<00:00, 78.23it/s]


### 3 Main output

1. <code>paper_token_results</code>: NER tag results with B, I and O tokens

2. <code>paper_tokenized_sentences</code>: output for the tokenization process e.g. 101, 102, etc.

3. <code>paper_words_ids</code>: word ids for each word in a tokenized sentence

4. <code>paper_aligned_labels</code>: containing aligned labels (B, I, O)

In [37]:
# Locating positive predictions
print("Positive predictions\n")
for paper_id in test_paper_ids:
    print(f"Paper id: {paper_id}")
    for index, pred in enumerate(paper_token_results[paper_id]):
        if 'B' in pred[1:-1]:
            print(f'Index: ', index)
            print('\t', paper_aligned_labels[paper_id][index])
            print(len(paper_aligned_labels[paper_id][index]))
            print('\t', tokenizer.convert_ids_to_tokens(paper_tokenized_sentences[paper_id][index]))
            print(len(tokenizer.convert_ids_to_tokens(paper_tokenized_sentences[paper_id][index])))
            
            # Getting the positive labels
            result = tokenizer.convert_ids_to_tokens(paper_tokenized_sentences[paper_id][index])
            positive_token_list = []
            for position, token, tag in zip(range(len(result)), tokenizer.convert_ids_to_tokens(paper_tokenized_sentences[paper_id][index]), paper_aligned_labels[paper_id][index]):
                if tag in ['B', 'I'] and position not in [0, 1, len(result)-1, len(result)-2]:
                    positive_token_list.append(token)
            print(positive_token_list)
            print()

    print("\n")

Positive predictions

Paper id: 2100032a-7c33-4bff-97ef-690822c43466
Index:  84
	 [-100, 'B', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', -100]
70
	 ['[CLS]', 'the', 'laboratory', 'for', 'ne', '##uro', 'imaging', 'at', 'the', 'university', 'of', 'southern', 'california', 'finally', 'several', 'publicly', 'available', 'data', '##set', '##s', 'were', 'included', 'we', 'kindly', 'thank', 'the', 'investigative', 'teams', 'and', 'staff', '##s', 'of', 'the', 'pediatric', 'imaging', 'ne', '##uro', '##co', '##gni', '##tion', 'and', 'genetics', 'ping', 'study', 'the', 'alzheimer', 's', 'disease', 'ne', '##uro', '##ima', '##ging', 'initiative', 'ad', '##ni', 'project', 'and', 'the', 'studies', 'who', 'made', '

**Notes: **

It looks like using the word_ids() method as a medium to align words is not that great, let's just use the tokenized sentence

In [15]:
# Trying out the positive sentence from the last paper
print(paper_aligned_labels['8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60'][66])

[-100, 'B', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', -100, 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', -100, 'O', -100, 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', -100, 'O', 'O', 'O', 'O', 'O', 'O', -100, -100, 'B', 'I', 'I', 'I', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B', -100]


In [None]:
# Checking len
faulty_results = []
for paper_id in test_paper_ids:
    tokenized_sentences_len = [len(sentence) for sentence in test_data[paper_id]]
    tokens_len = [len(tokens) for tokens in paper_token_results[paper_id]]
    for index, i, j in zip(range(len(tokenized_sentences_len)), tokenized_sentences_len, tokens_len):
        if 'B' in paper_token_results[paper_id][index][1:-1]:
            if i != j:
                j -= 2
            try:
                assert i == j
            except AssertionError:
                print(paper_id)
                print(i, j)
                # print(index)
                print(test_data[paper_id][index])
                print(paper_token_results[paper_id][index])
                faulty_results.append(paper_id)
# print('Faulty results:')
for item in list(set(faulty_results)):
    print('\t', item)

In [None]:
# Checking len
faulty_results = []
for paper_id in test_paper_ids:
    tokenized_sentences_len = [len(sentence) for sentence in paper_tokenized_sentences[paper_id]]
    tokens_len = [len(tokens) for tokens in paper_token_results[paper_id]]
    for index, i, j in zip(range(len(tokenized_sentences_len)), tokenized_sentences_len, tokens_len):
        if 'B' in paper_token_results[paper_id][index][1:-1]:
            print(f"Index: {index}")
            if i != j:
                j -= 2
            try:
                assert i == j
            except AssertionError:
                print(paper_id)
                print(i, j)
                # print(index)
                print(config.TOKENIZER.convert_ids_to_tokens(paper_tokenized_sentences[paper_id][index]))
                print(paper_token_results[paper_id][index])
                faulty_results.append(paper_id)
# print('Faulty results:')
for item in list(set(faulty_results)):
    print('\t', item)

In [38]:
preds_dict = {}

for paper_id in test_paper_ids:
    preds_dict[paper_id] = []
    print(f'Paper ID:\t{paper_id}')
    print("")
    for index, result in enumerate(paper_token_results[paper_id]):
        preds_preds = []
        if 'B' in result[1:-1]:
            print(test_data[paper_id][index])
            print(len(test_data[paper_id][index]))
            print(result)
            print(len(result))
            preds = []
            for i, item in enumerate(result[1:-1]):
                try:
                    if item == 'B' or item == 'I':
                        preds.append(test_data[paper_id][index][i])
                except IndexError:
                    pass
            preds = [item for item in preds if item != '.']    
            preds_dict[paper_id].append(" ".join(preds))
            print(preds)
            print("\n\n")
            # preds = [item for item in preds if item != '.']

Paper ID:	2100032a-7c33-4bff-97ef-690822c43466

['the', 'Laboratory', 'for', 'Neuro', 'Imaging', 'at', 'the', 'University', 'of', 'Southern', 'California', 'Finally', 'several', 'publicly', 'available', 'datasets', 'were', 'included', 'we', 'kindly', 'thank', 'the', 'investigative', 'teams', 'and', 'staffs', 'of', 'the', 'Pediatric', 'Imaging', 'Neurocognition', 'and', 'Genetics', 'PING', 'study', 'the', 'Alzheimer', 's', 'Disease', 'Neuroimaging', 'Initiative', 'ADNI', 'project', 'and', 'the', 'studies', 'who', 'made', 'their', 'data', 'available', 'in', 'dbGaP', '.']
54
['B' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O'
 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O'
 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'B' 'I' 'I' 'I' 'I' 'I' 'I' 'I' 'I'
 'I' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'B']
64
['the', 'studies', 'who', 'made', 'their', 'data', 'available', 'in', 'dbGaP']



Paper ID:	2f392438-e215-4169-bebf-21ac4ff253e1

['at', 'the', 'Organization

In [43]:
preds_dict = {}

for paper_id in test_paper_ids:
    preds_dict[paper_id] = []
    print(f'Paper ID:\t{paper_id}')
    print("")
    for index, result in enumerate(paper_token_results[paper_id]):
        preds_preds = []
        if 'B' in result[1:-1]:
            print(test_data[paper_id][index])
            print(len(test_data[paper_id][index]))
            print(result)
            print(len(result))
            preds = []
            for i, item in enumerate(result[1:-1]):
                try:
                    if item == 'B' or item == 'I':
                        preds.append(test_data[paper_id][index][i])
                except IndexError:
                    pass
            preds = [item for item in preds if item != '.']    
            preds_dict[paper_id].append(" ".join(preds))
            print(preds)
            print("\n\n")
            preds = [item for item in preds if item != '.']

Paper ID:	2100032a-7c33-4bff-97ef-690822c43466

['the', 'Laboratory', 'for', 'Neuro', 'Imaging', 'at', 'the', 'University', 'of', 'Southern', 'California', 'Finally', 'several', 'publicly', 'available', 'datasets', 'were', 'included', 'we', 'kindly', 'thank', 'the', 'investigative', 'teams', 'and', 'staffs', 'of', 'the', 'Pediatric', 'Imaging', 'Neurocognition', 'and', 'Genetics', 'PING', 'study', 'the', 'Alzheimer', 's', 'Disease', 'Neuroimaging', 'Initiative', 'ADNI', 'project', 'and', 'the', 'studies', 'who', 'made', 'their', 'data', 'available', 'in', 'dbGaP', '.']
54
['B' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O'
 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O'
 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'B' 'I' 'I' 'I' 'I' 'I' 'I' 'I' 'I'
 'I' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'O' 'B']
64
['the', 'studies', 'who', 'made', 'their', 'data', 'available', 'in', 'dbGaP']



Paper ID:	2f392438-e215-4169-bebf-21ac4ff253e1

['at', 'the', 'Organization

In [44]:
preds_dict_str = {
    test_paper_ids[0]: "|".join(preds_dict[test_paper_ids[0]]).lower(),
    test_paper_ids[1]: "|".join(preds_dict[test_paper_ids[1]]).lower(),
    test_paper_ids[2]: "|".join(preds_dict[test_paper_ids[2]]).lower(),
    test_paper_ids[3]: "|".join(preds_dict[test_paper_ids[3]]).lower()
}

In [45]:
preds_dict_str

{'2100032a-7c33-4bff-97ef-690822c43466': 'the studies who made their data available in dbgap',
 '2f392438-e215-4169-bebf-21ac4ff253e1': 'and science study timss begun in 2002|science study timss this report comparative indicators|trends in international mathematics and science study|international mathematics and science study timss information|trends in international mathematics and science study|trends in international mathematics and science study|and science study timss in 2007 on|trends in international mathematics and science study',
 '3f316b38-1a24-45a9-8d8c-4e05a42257c6': 'national weather service|the shapefile',
 '8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60': 'ruccs united states department'}