In [1]:
# built in
import json
import random
# append to path to allow relative imports
import sys
sys.path.append("..")

# 3rd party
import pandas as pd
from tqdm import tqdm
from transformers import BertForTokenClassification, BertTokenizer
import torch

# own
from utils.parse import ParseUtils

# Prep Data

In [2]:
# Sentence-splitting knobs. Neither constant is referenced by the cells visible
# in this notebook -- presumably they mirror values used inside
# ParseUtils.shorten_sentences; TODO confirm.
MAX_LENGTH = 64 # maximum number of words per sentence.
OVERLAP = 20 # number of overlapping words when a sentence longer than MAX_LENGTH is split into several sentences.

# Upper bound on the number of raw rows kept from train.csv (applied as a slice).
MAX_SAMPLE = None # set a small number for experimentation, set None to keep every row (production).

In [3]:
# Locations of the label table and of the per-publication JSON files.
TRAIN_CSV = '../../data/coleridgeinitiative-show-us-the-data/train.csv'
TRAIN_DATA = '../../data/coleridgeinitiative-show-us-the-data/train'

# Read the label table and keep only the first MAX_SAMPLE rows
# (a bound of None keeps every row).
train = pd.read_csv(TRAIN_CSV).iloc[:MAX_SAMPLE]
print(f'No. raw training rows: {len(train)}')

No. raw training rows: 19661


In [4]:
# Collapse the table to one row per publication: keep the first title and
# concatenate every dataset title/label of that publication with '|'.
train = (
    train
    .groupby('Id')
    .agg(
        pub_title=('pub_title', 'first'),
        dataset_title=('dataset_title', '|'.join),
        dataset_label=('dataset_label', '|'.join),
        cleaned_label=('cleaned_label', '|'.join),
    )
    .reset_index()
)

print(f'No. grouped training rows: {len(train)}')

No. grouped training rows: 14316


In [5]:
# Preview the first rows of the grouped table (one row per publication Id).
train.head()

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label
0,0007f880-0a9b-492d-9a58-76eb0b0e0bd7,The Impact of ICT Training on Income Generatio...,Program for the International Assessment of Ad...,Program for the International Assessment of Ad...,program for the international assessment of ad...
1,0008656f-0ba2-4632-8602-3017b44c2e90,Finnish Ninth Graders’ Gender Appropriateness ...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...
2,000e04d6-d6ef-442f-b070-4309493221ba,Economic Research Service: Specialized Agency...,Agricultural Resource Management Survey,Agricultural Resources Management Survey,agricultural resources management survey
3,000efc17-13d8-433d-8f62-a3932fe4f3b8,Risk factors and global cognitive status relat...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI|Alzheimer's Disease Neuroimaging Initiati...,adni|alzheimer s disease neuroimaging initiati...
4,0010357a-6365-4e5f-b982-582e6d32c3ee,Timelines of COVID-19 Vaccines,SARS-CoV-2 genome sequence,genome sequence of COVID-19,genome sequence of covid 19


In [6]:
# Map each publication Id to the parsed content of its JSON file.
papers = {}
for paper_id in train['Id'].unique():
    paper_path = f'{TRAIN_DATA}/{paper_id}.json'
    with open(paper_path, 'r') as paper_file:
        papers[paper_id] = json.load(paper_file)

In [7]:
cnt_pos, cnt_neg = 0, 0 # number of sentences that contain/not contain labels
ner_data = []

# An unlabeled sentence is kept as a negative sample only if it mentions one
# of these keywords (case-insensitive substring match).
NEGATIVE_KEYWORDS = ('data', 'study')

# The context manager closes the progress bar even when the loop raises or the
# cell is interrupted.
with tqdm(total=len(train)) as pbar:
    # index=False drops the row index, so nothing has to be bound to a throwaway
    # name and the builtin `id` is no longer shadowed in the notebook namespace.
    for paper_id, dataset_label in train[['Id', 'dataset_label']].itertuples(index=False):
        # paper
        paper = papers[paper_id]

        # labels
        labels = [ParseUtils.clean_training_text(label)
                  for label in dataset_label.split('|')]

        # sentences: de-duplicate while preserving first-seen order, so the
        # generated samples come out in the same order on every run (a set's
        # iteration order depends on string hashing, which changes between
        # kernel restarts).
        # NOTE(review): this passes a list instead of a set -- assumes
        # ParseUtils.shorten_sentences only iterates over its argument; confirm.
        sentences = list(dict.fromkeys(
            ParseUtils.clean_training_text(sentence)
            for section in paper
            for sentence in section['text'].split('.')
        ))
        sentences = ParseUtils.shorten_sentences(sentences) # make sentences short
        # only accept sentences with length > 10 chars
        sentences = [sentence for sentence in sentences if len(sentence) > 10]

        for sentence in sentences:
            is_positive, tags = ParseUtils.tag_sentence(sentence, labels)
            if is_positive:
                # positive sample: the sentence contains at least one label
                cnt_pos += 1
                ner_data.append(tags)
            elif any(word in sentence.lower() for word in NEGATIVE_KEYWORDS):
                # negative sample: no label, but the sentence looks data-related
                ner_data.append(tags)
                cnt_neg += 1

        # progress bar
        pbar.update(1)
        pbar.set_description(f"Training data size: {cnt_pos} positives + {cnt_neg} negatives")

# shuffling (currently disabled)
#random.shuffle(ner_data)

Training data size: 3257 positives + 35706 negatives:   7%|▋         | 1001/14316 [00:30<02:54, 76.10it/s]

In [12]:
# Persist the tagged sentences as JSON Lines: one object per sentence holding
# the parallel 'tokens' and 'tags' sequences.
with open('train_ner.json', 'w') as out_file:
    for tagged_sentence in ner_data:
        sentence_words, sentence_tags = zip(*tagged_sentence)
        record = {'tokens' : sentence_words, 'tags' : sentence_tags}
        out_file.write(json.dumps(record))
        out_file.write('\n')

In [7]:
class NERData:
    """In-memory container for NER sentences read from a JSON Lines file.

    ``self.data`` is a list of sentences; every sentence is a list of
    ``(token, tag)`` tuples.
    """

    def __init__(self):
        self.data = list()

    def from_json(self, filename:str, overwrite:bool=False,
                  max_lines=1001, verbose:bool=True):
        """Load sentences from ``filename`` into ``self.data``.

        Each non-blank line must be a JSON object of the form
        ``{"tokens": ["A", "short", "sentence"], "tags": ["0", "0", "0"]}``
        and becomes ``[("A", "0"), ("short", "0"), ("sentence", "0")]``.

        Args:
            filename: Path of the JSON Lines file to read.
            overwrite: When data is already loaded, discard it and reload
                instead of raising.
            max_lines: Maximum number of lines to read from the file; ``None``
                reads the whole file. The default keeps the historical cap of
                1001 lines.
            verbose: Print a progress counter for every parsed line.

        Raises:
            ValueError: If data is already present and ``overwrite`` is False.
        """
        if self.data:
            if overwrite:
                self.data = list()
            else:
                raise ValueError(
                    'Data is present. If you want to overwrite it, '
                    'run this function again with overwrite=True.')

        # The context manager guarantees the file is closed even when a line
        # fails to parse; iterating the handle streams it line by line instead
        # of loading the whole file into memory first.
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                if max_lines is not None and i >= max_lines:
                    break

                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue

                if verbose:
                    print('Reading data ... {}\r'.format(i), end='')

                sentence = json.loads(line)

                # Pair every token with its tag; each parsed sentence becomes
                # one entry in the overall data list.
                self.data.append(
                    list(zip(sentence["tokens"], sentence["tags"])))

    def get_sentences(self, limit=100):
        """
        Convert entries of self.data into single-string sentences,
        with words separated by a blank space.

        Args:
            limit: Maximum number of sentences to return, taken from the
                start of ``self.data``; ``None`` returns all of them. The
                default keeps the historical cap of 100.
        """
        return [ " ".join(token for token, _ in tupled_sentence)
                for tupled_sentence in self.data[:limit] ]

In [8]:
# Re-load the tagged sentences from disk into an NERData container.
# Note that `ner_data` stops being the plain list built earlier from here on.
ner_data = NERData()
ner_data.from_json(filename='train_ner.json')

Reading data ... 0Reading data ... 1Reading data ... 2Reading data ... 3Reading data ... 4Reading data ... 5Reading data ... 6Reading data ... 7Reading data ... 8Reading data ... 9Reading data ... 10Reading data ... 11Reading data ... 12Reading data ... 13Reading data ... 14Reading data ... 15Reading data ... 16Reading data ... 17Reading data ... 18Reading data ... 19Reading data ... 20Reading data ... 21Reading data ... 22Reading data ... 23Reading data ... 24Reading data ... 25Reading data ... 26Reading data ... 27Reading data ... 28Reading data ... 29Reading data ... 30Reading data ... 31Reading data ... 32Reading data ... 33Reading data ... 34Reading data ... 35Reading data ... 36Reading data ... 37Reading data ... 38Reading data ... 39Reading data ... 40Reading data ... 41Reading data ... 42Reading data ... 43Reading data ... 44Reading data ... 45Reading data ... 46Reading data ... 47Reading data ... 48Reading data ... 49Reading da

In [16]:
# Use the first loaded sentence (words joined by single spaces) as a one-item
# sample input for the model experiments below.
text_batch = ner_data.get_sentences()[0]

# Init Model

In [None]:
# Load the pretrained 'bert-base-uncased' weights into a token-classification
# model. No num_labels is passed, so the library's default label count applies.
model = BertForTokenClassification.from_pretrained('bert-base-uncased')

In [None]:
# Models are initialized in eval mode by default. We can call model.train() to put it in train mode.
model.train()

In [None]:
# Use PyTorch's own AdamW: the AdamW class shipped with transformers is
# deprecated in favour of it. eps and weight_decay are pinned to the values the
# transformers implementation used by default, so the update rule is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

In [37]:
# Create the tokenizer in this cell: it is first needed here, and relying on a
# definition from a later cell breaks "Restart Kernel -> Run All".
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Round-trip the sample sentence (encode -> decode -> tokenize) to inspect the
# word pieces the model sees, including the special tokens added by encode().
tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(text_batch)))
tokens, len(tokens)

(['[CLS]',
  'after',
  'acquisition',
  'of',
  'this',
  'last',
  'd',
  '##ki',
  'data',
  '##set',
  'at',
  '8',
  'months',
  'of',
  'age',
  'mice',
  'were',
  'then',
  'sacrificed',
  'for',
  'his',
  '##to',
  '##logical',
  'analysis',
  '[SEP]'],
 25)

In [38]:
from transformers import BertTokenizer

# Build the tokenizer matching the pretrained checkpoint, then encode the
# sample sentence as PyTorch tensors (padded, truncated if too long).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoding = tokenizer(
    text_batch,
    return_tensors='pt',
    padding=True,
    truncation=True,
)
inputs, attention_mask = encoding['input_ids'], encoding['attention_mask']

In [39]:
# Show the number of (token, tag) pairs in the first loaded sentence next to
# its text, for comparison with the word-piece count printed earlier.
len(ner_data.data[0]), text_batch

(19,
 'After acquisition of this last DKI dataset at 8 months of age mice were then sacrificed for histological analysis')

In [None]:
# Token classification needs one label per input token, i.e. a tensor with the
# same (batch, seq_len) shape as the input ids -- not one label per sentence.
# NOTE(review): 'O' is assumed to be the "outside" tag produced by
# ParseUtils.tag_sentence; every other tag is mapped to class 1 to match the
# model's two output labels. Confirm against the tag vocabulary.
OUTSIDE_TAG = 'O'

piece_labels = []
for word, tag in ner_data.data[0]:
    # A word may be split into several word pieces; each piece inherits the
    # word's label so labels stay aligned with the encoded sentence.
    n_pieces = len(tokenizer.tokenize(word))
    piece_labels.extend([0 if tag == OUTSIDE_TAG else 1] * n_pieces)

# -100 is the index ignored by the cross-entropy loss; it masks the [CLS] and
# [SEP] positions the tokenizer adds around the sentence.
labels = torch.tensor([[-100] + piece_labels + [-100]])
assert labels.shape == inputs.shape, (
    f'label shape {tuple(labels.shape)} does not match input shape {tuple(inputs.shape)}')

# Clear gradients left over from a previous run of this cell before the
# backward pass, otherwise they accumulate.
optimizer.zero_grad()
outputs = model(inputs, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
