In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from data_utils import label2id, chunk_examples, data_augmenter, tokenizer_and_align, filter_data
from datasets import Dataset, load_from_disk, concatenate_datasets, DatasetDict
from model import get_model

In [None]:
# Input file for this run. Point this at './data/train.json' to process the
# original (un-augmented) training data instead of the augmented file.
fn = './data/augmented_data.json'

In [None]:
# JSON interchange is UTF-8 by specification; pass the encoding explicitly so
# decoding does not depend on the platform's locale default.
with open(fn, 'r', encoding='utf-8') as fp:
    data = json.load(fp)

In [None]:
# get_model (from the local `model` module) returns a (model, tokenizer) pair;
# in the cells shown here only the tokenizer is used afterwards.
MODEL_CHECKPOINT = 'dslim/bert-large-NER'
model, tokenizer = get_model(MODEL_CHECKPOINT)

### Dataset creation

In [None]:
%%time
x = Dataset.from_list(data)

In [None]:
# Run tokenizer_and_align on every row (16 worker processes); the tokenizer is
# forwarded to it as a keyword argument. Note: this rebinds `x` in place.
x = x.map(
    tokenizer_and_align,
    fn_kwargs={'tokenizer': tokenizer},
    num_proc=16,
)
print(len(x))

In [None]:
# Keep only rows accepted by filter_data (p=0.95; its exact semantics live in
# data_utils) and report how many survive.
x = x.filter(
    filter_data,
    fn_kwargs={'p': 0.95},
    num_proc=16,
)
print(len(x))

In [None]:
# Scratch check: tokenize a raw (not pre-split) sentence without special
# tokens and display the resulting encoding.
tokenizer(
    'This is my name: Romit Jain',
    is_split_into_words=False,
    add_special_tokens=False,
    return_overflowing_tokens=True,
    max_length=2048,
)

In [None]:
# Same sentence, pre-split on single spaces so that each list element is
# treated as one word; the next cell inspects the word -> sub-token mapping.
t = tokenizer(
    'This is my name: Romit Jain'.split(' '),
    is_split_into_words=True,
    add_special_tokens=False,
    return_overflowing_tokens=True,
)

In [None]:
# Print the sub-word tokens of `t` alongside the word index each token maps to.
print(t.tokens(), t.word_ids())

In [None]:
# Re-shape the tokenized documents with chunk_examples (max_len=256; presumably
# one output row per chunk — see data_utils). All input columns are dropped,
# so `ds` only carries the columns chunk_examples returns.
ds = x.map(
    chunk_examples,
    batched=True,
    batch_size=10,
    num_proc=1,
    fn_kwargs={'max_len': 256},
    remove_columns=x.column_names,
)

print(len(ds))

### Dataset saving and merging

In [None]:
# Output location for this (augmented) run. When `fn` points at
# './data/train.json', save to './data/processed/data/' instead — that is the
# path the next cell loads.
ds.save_to_disk(dataset_path='./data/processed/augmented_data/')

In [None]:
# Reload the chunks saved by the run over the original data (presumably the
# './data/train.json' run).
ds0 = load_from_disk(dataset_path='./data/processed/data/')

In [None]:
# Stack the reloaded chunks (ds0) followed by this run's chunks (ds) row-wise.
ds_final = concatenate_datasets(dsets=[ds0, ds])

In [None]:
# Scratch check: draw 2 distinct values from [1, 2, 3, 4]. Note that this
# advances numpy's global random state.
np.random.choice(a=[1, 2, 3, 4], size=2, replace=False)

In [None]:
# Split by document rather than by chunk so that all chunks of a document end
# up in the same split (no train/val leakage). Presumably augmented documents
# keep their source's document_id, which keeps them with their originals too —
# verify against data_utils.chunk_examples.
VAL_FRACTION = 0.2
SPLIT_SEED = 42

# ds_final is exactly the dataset filtered below, so derive the ids from it.
# Because ds_final concatenates ds0 and ds, this equals the union of their ids.
unique_idx = np.unique(list(ds_final['document_id']))

# A dedicated, seeded Generator makes the split reproducible and independent of
# any other np.random calls made earlier in the notebook.
split_rng = np.random.default_rng(SPLIT_SEED)
val_idx = split_rng.choice(
    unique_idx,
    size=int(VAL_FRACTION * unique_idx.shape[0]),
    replace=False,
)

train_idx = np.setdiff1d(unique_idx, val_idx)

In [None]:
# Hash sets give O(1) membership tests, whereas `value in ndarray` scans the
# whole array for every example. .tolist() turns numpy scalars into plain
# Python values.
train_id_set = set(train_idx.tolist())
val_id_set = set(val_idx.tolist())


def get_train_sample(example):
    """Return True when the example's document_id is in the training split."""
    return example['document_id'] in train_id_set


def get_val_sample(example):
    """Return True when the example's document_id is in the validation split."""
    return example['document_id'] in val_id_set

In [None]:
# Materialize each split by keeping the examples whose document_id is accepted
# by the matching predicate (16 worker processes each).
train_ds = ds_final.filter(get_train_sample, num_proc=16)
val_ds = ds_final.filter(get_val_sample, num_proc=16)

# train_idx and val_idx partition every document_id in ds_final, so each
# example must land in exactly one split.
assert len(train_ds) + len(val_ds) == len(ds_final), (
    'train/val filters did not partition ds_final'
)

print(len(train_ds), len(val_ds))

In [None]:
# Bundle both splits under the keys 'train' / 'val' and persist them together.
splits = {'train': train_ds, 'val': val_ds}
ds_to_save = DatasetDict(splits)
ds_to_save.save_to_disk('./data/processed/dataset_2')

### Data validation

In [None]:
# Spot-check tokenizer_and_align: pick a random document and a random start
# offset, then feed each word in a window of up to `buffer` words through the
# aligner one at a time, printing the original word/label next to the result.
# For a deterministic check, set `idx, start = 0, 0` after the random draws;
# to inspect entity words only, skip iterations whose label is 'O'.
idx = np.random.randint(0, len(data))
start = np.random.randint(0, len(data[idx]['tokens']))
buffer = 2000

temp = data[idx]
window = slice(start, start + buffer)
columns = (
    temp['tokens'][window],
    temp['labels'][window],
    temp['trailing_whitespace'][window],
)

for tokens, labels, ws in zip(*columns):
    local = dict(tokens=[tokens], labels=[labels], trailing_whitespace=[ws])
    ans = tokenizer_and_align(local, tokenizer)

    print(f"""
    Original: {tokens} {labels}
    Transformed: {ans['tokens']} {ans['aligned_tokens']['input_ids']} {ans['aligned_labels']}
    """)

### Data augmentation

In [None]:
import spacy
from copy import deepcopy
from data_utils import random_augmentation

# Small English spaCy pipeline; spacy_tokenize below uses it to re-tokenize
# augmented text.
nlp = spacy.load(name="en_core_web_sm")

In [None]:
def spacy_tokenize(text):
    """Run `text` through the module-level spaCy pipeline `nlp` and return the
    list of token strings, in order."""
    return [tok.text for tok in nlp(text)]

In [None]:
# Augment entity spans with `random_augmentation`.
#
# For every source document, each maximal run of consecutive non-'O' tokens is
# one entity span. The span's text is rebuilt from its tokens and trailing
# whitespace, augmented (given the label of the span's last token), re-tokenized
# with spaCy, and written back over the span's positions in a deep copy of the
# document. Each span yields one augmented document; source documents are never
# modified.
#
# NOTE(review): write-back is positional, so if the augmented text has fewer
# tokens than the span the remaining original tokens are kept, and if it has
# more the extras are dropped. `labels` and `trailing_whitespace` are copied
# unchanged, so they stay the same length as `tokens`; any other fields of the
# document are copied unchanged as well. Assumes `tokens`, `labels` and
# `trailing_whitespace` have equal length in every document — TODO confirm.

# Read the un-augmented training data explicitly instead of reusing `data`:
# `data` may have been loaded from './data/augmented_data.json', which the next
# cell overwrites, so augmenting it would compound augmentations run after run.
AUGMENT_SOURCE_FN = './data/train.json'

with open(AUGMENT_SOURCE_FN, 'r', encoding='utf-8') as fp:
    source_docs = json.load(fp)


def _entity_spans(labels):
    """Yield (positions, last_label) for each maximal run of non-'O' labels.

    A run that reaches the end of the document is yielded too, even though no
    closing 'O' label follows it.
    """
    span = []
    for pos, label in enumerate(labels):
        if label != 'O':
            span.append(pos)
        elif span:
            yield span, labels[span[-1]]
            span = []
    if span:
        yield span, labels[span[-1]]


def _span_text(doc, positions):
    """Rebuild the text of the tokens at `positions`, appending one space after
    each token whose trailing_whitespace flag is set, then strip the ends."""
    return ''.join(
        doc['tokens'][pos] + (' ' if doc['trailing_whitespace'][pos] else '')
        for pos in positions
    ).strip()


def _augment_span(doc, positions, label):
    """Return a deep copy of `doc` whose tokens at `positions` are replaced by
    the augmented span, or None when the augmentation yields no tokens."""
    augmented_text = random_augmentation(_span_text(doc, positions), label)
    # spaCy can emit whitespace-only tokens (e.g. for repeated spaces); never
    # write those into entity positions.
    new_tokens = [tok for tok in spacy_tokenize(augmented_text) if tok.strip()]
    if not new_tokens:
        # An unchanged copy would only duplicate the source document.
        return None

    new_doc = deepcopy(doc)
    for pos, new_token in zip(positions, new_tokens):
        new_doc['tokens'][pos] = new_token
    return new_doc


augmented_docs = []

for doc in tqdm(source_docs, desc='augmenting'):
    for positions, last_label in _entity_spans(doc['labels']):
        new_doc = _augment_span(doc, positions, last_label)
        if new_doc is not None:
            augmented_docs.append(new_doc)

In [None]:
# Serialize all augmented documents into one JSON file; 'w' mode truncates any
# output from a previous run.
with open('./data/augmented_data.json', 'w') as out_file:
    json.dump(augmented_docs, out_file)