In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from data_utils import label2id, chunk_examples, data_augmenter, tokenizer_and_align, filter_data
from datasets import Dataset, load_from_disk, concatenate_datasets, DatasetDict
from model import get_model

In [None]:
# Path of the input JSON consumed by the cells below. The commented-out line is
# the alternative input; swap the two assignments to switch between them.
# NOTE(review): './data/augmented_data.json' is also the file written by the
# "Data augmentation" section further down, so that section must already have
# produced it before this path can be loaded on a fresh checkout.
# fn = './data/train.json'
fn = './data/augmented_data.json'

In [None]:
# Read the whole JSON file into memory as `data`; later cells treat it as a
# list of document dicts with 'tokens', 'labels' and 'trailing_whitespace'.
# The encoding is pinned to UTF-8 so the read does not depend on the platform's
# default locale encoding.
with open(fn, 'r', encoding='utf-8') as fp:
    data = json.load(fp)

In [None]:
# get_model is a project helper imported from model.py; its return value is
# unpacked into a model and the matching tokenizer.
# NOTE(review): the argument is presumably a Hugging Face Hub checkpoint id — confirm.
# Only `tokenizer` is referenced by the other visible cells.
model, tokenizer = get_model('sileod/deberta-v3-large-tasksource-nli')

### Dataset creation

In [None]:
%%time
# Wrap the list of document dicts in a `datasets.Dataset`; the %%time cell
# magic reports how long the conversion takes.
x = Dataset.from_list(data)

In [None]:
# Run tokenizer_and_align over every document using 16 worker processes; the
# tokenizer is handed to the helper through fn_kwargs. Then show the row count.
x = x.map(
    tokenizer_and_align,
    fn_kwargs={'tokenizer': tokenizer},
    num_proc=16,
)
print(len(x))

In [None]:
# Keep only the documents for which filter_data returns True (16 worker
# processes), then show how many rows remain.
# NOTE(review): the meaning of p=0.95 is defined in data_utils.filter_data —
# presumably a probability/threshold; confirm there.
x = x.filter(
    filter_data,
    fn_kwargs={'p': 0.95},
    num_proc=16,
)
print(len(x))

In [None]:
# Re-chunk the tokenized documents with chunk_examples. The map is batched
# (10 documents per call) and drops every input column, so the helper's output
# fully defines the new rows and the row count may differ from the input.
ds = x.map(
    chunk_examples,
    batched=True,
    batch_size=10,
    fn_kwargs={'max_len': 2048, 'buffer': 0},
    remove_columns=x.column_names,
    num_proc=1,
)

print(len(ds))

### Dataset saving and merging

In [None]:
# Write the chunked dataset to disk. The commented-out line is the alternative
# destination directory.
# NOTE(review): presumably toggled together with `fn` at the top of the
# notebook (train.json -> processed/data, augmented_data.json ->
# processed/augmented_data) — confirm.
# ds.save_to_disk('./data/processed/data/')
ds.save_to_disk('./data/processed/augmented_data/')

In [None]:
# Reload a dataset previously saved under './data/processed/data/' (the path
# used by the commented-out save_to_disk call in the cell above).
ds0 = load_from_disk('./data/processed/data/')

In [None]:
# Stack the reloaded dataset and the freshly built one into a single dataset
# (rows of ds0 first, then rows of ds).
ds_final = concatenate_datasets([ds0, ds])

In [None]:
# Split at document level: rows are later assigned to train or val by their
# `document_id`, so every row sharing an id ends up on the same side.
# The 20% validation ids are drawn from a seeded generator so that re-running
# the notebook reproduces the same split that gets saved to disk.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

# np.unique also sorts the ids, so the draw does not depend on row order.
unique_idx = np.unique(ds['document_id'] + ds0['document_id'])
val_idx = split_rng.choice(
    unique_idx,
    size=int(0.2 * unique_idx.shape[0]),
    replace=False,
)

# Everything that was not drawn for validation is used for training.
train_idx = np.setdiff1d(unique_idx, val_idx)

In [None]:
def get_train_sample(example):
    """Filter predicate: True when the row's document_id is in `train_idx`."""
    return example['document_id'] in train_idx


def get_val_sample(example):
    """Filter predicate: True when the row's document_id is in `val_idx`."""
    return example['document_id'] in val_idx

In [None]:
# Materialise the two splits by filtering the merged dataset with the
# document-id predicates (train first, then val), 16 worker processes each.
train_ds, val_ds = (
    ds_final.filter(predicate, num_proc=16)
    for predicate in (get_train_sample, get_val_sample)
)

print(len(train_ds), len(val_ds))

In [None]:
# Bundle both splits under the keys 'train' and 'val' and write them to disk
# together as one DatasetDict.
ds_to_save = DatasetDict({'train': train_ds, 'val': val_ds})
ds_to_save.save_to_disk('./data/processed/dataset_3')

### Data validation

In [None]:
# Spot check: pick a random document and a random window of `buffer` tokens,
# run every labelled (non-'O') token of that window through
# tokenizer_and_align on its own, and print the original token/label next to
# what the helper returned. Windows without labelled tokens print nothing.
idx = np.random.randint(0, len(data))
start = np.random.randint(0, len(data[idx]['tokens']))
buffer = 20
# idx, start = 0, 0

temp = data[idx]
window = slice(start, start + buffer)

for token, label, has_ws in zip(
    temp['tokens'][window],
    temp['labels'][window],
    temp['trailing_whitespace'][window],
):
    if label == 'O':
        continue

    # The helper receives a one-token "document" with the same three keys.
    ans = tokenizer_and_align(
        {'tokens': [token], 'labels': [label], 'trailing_whitespace': [has_ws]},
        tokenizer,
    )

    print(f"""
    Original: {token} {label}
    Transformed: {ans['tokens']} {ans['aligned_tokens']['input_ids']} {ans['aligned_labels']}
    """)

### Data augmentation

In [None]:
import spacy
from copy import deepcopy
from data_utils import random_augmentation

# Load spaCy's "en_core_web_sm" pipeline; spacy_tokenize in the next cell uses
# it to split augmented text back into tokens.
nlp = spacy.load("en_core_web_sm")

In [None]:
def spacy_tokenize(text):
    """Split ``text`` into a list of token strings using the spaCy tokenizer.

    Only token boundaries are needed here, so ``nlp.make_doc`` is used: it
    runs the tokenizer alone rather than the full processing pipeline that
    ``nlp(text)`` would execute for every augmented span.
    """
    return [token.text for token in nlp.make_doc(text)]

In [None]:
# Build one augmented copy of a document per labelled span.
#
# A span is a maximal run of consecutive tokens whose label is not 'O'. For
# each span the collected token strings are joined, sent to random_augmentation
# together with the label of the span's last token, and the result is
# re-tokenized with spaCy. A deep copy of the document then gets the span's
# tokens overwritten position by position and is appended to `augmented_docs`.
#
# A span is closed by the next 'O' token or by the end of the document, so a
# document that ends on a labelled token still gets its last span augmented.

def _augment_span(doc, positions, pieces, label):
    """Return a deep copy of ``doc`` with the tokens at ``positions`` replaced.

    ``pieces`` are the span's token strings as collected by the loop below and
    ``label`` is the label of the span's last token.
    """
    replacement = spacy_tokenize(
        random_augmentation(' '.join(pieces).strip(), label)
    )
    augmented = deepcopy(doc)
    # zip stops at the shorter sequence: when the replacement has fewer tokens
    # than the span the remaining original tokens stay in place, and surplus
    # replacement tokens are dropped. Only 'tokens' is modified.
    for position, new_token in zip(positions, replacement):
        augmented['tokens'][position] = new_token
    return augmented


augmented_docs = []

for doc in data:
    tokens_to_augment = []
    ids_to_replace = []
    prev_label = None

    for pos, (token, ws, label) in enumerate(zip(
        doc['tokens'],
        doc['trailing_whitespace'],
        doc['labels'],
    )):
        if label != 'O':
            # Inside a span: remember the token text, its position and label.
            tokens_to_augment.append(f' {token}' if ws else token)
            ids_to_replace.append(pos)
            prev_label = label
        elif tokens_to_augment:
            # First 'O' after a span: emit the augmented copy and start over.
            augmented_docs.append(
                _augment_span(doc, ids_to_replace, tokens_to_augment, prev_label)
            )
            tokens_to_augment = []
            ids_to_replace = []
            prev_label = None

    # The document ended while a span was still open: flush it as well.
    if tokens_to_augment:
        augmented_docs.append(
            _augment_span(doc, ids_to_replace, tokens_to_augment, prev_label)
        )

In [None]:
# Persist the augmented documents as a single JSON file.
with open('./data/augmented_data.json', 'w') as out_file:
    json.dump(augmented_docs, out_file)

### Rough work: ad-hoc inspection of tokenization and label alignment

In [None]:
# Attention-mask values at positions 710-799 of the first tokenized document.
x[0]['aligned_tokens']['attention_mask'][710:800]

In [None]:
# Input ids at positions 710-799 of the first tokenized document.
x[0]['aligned_tokens']['input_ids'][710:800]

In [None]:
# Aligned labels at positions 710-799 of the first tokenized document.
x[0]['aligned_labels'][710:800]

In [None]:
def tokenize_row(example):
    """Tokenize one document and derive a label id for every tokenizer token.

    The document text is rebuilt from ``example["tokens"]``, with a single
    space inserted after each token whose ``trailing_whitespace`` flag is
    truthy. A label is recorded for every character of that text (inserted
    spaces get "O"). The text is then tokenized with offsets, and each
    tokenizer token takes the label of the character at its start offset,
    after skipping one leading whitespace character.

    Reads the notebook-level names ``tokenizer``, ``label2id`` and ``config``.

    Args:
        example: mapping with parallel sequences under "tokens", "labels" and
            "trailing_whitespace".

    Returns:
        dict with the tokenizer's ``input_ids``, ``attention_mask`` and
        ``offset_mapping``; ``labels`` (one label id per entry of
        ``offset_mapping``); ``length`` (number of input ids); ``target_num``
        (number of source tokens whose label is in ``config['target_cols']``);
        ``group`` (1 if ``target_num`` > 0, else 0); and ``token_map`` (for
        each character, the index of its source token, or -1 for an inserted
        space).
    """
    pieces = []
    token_map = []
    char_labels = []
    target_num = 0

    for idx, (t, l, ws) in enumerate(
        zip(example["tokens"], example["labels"], example["trailing_whitespace"])
    ):
        pieces.append(t)
        # One label / source-token index per character of the token.
        char_labels.extend([l] * len(t))
        token_map.extend([idx] * len(t))

        if l in config['target_cols']:
            target_num += 1

        if ws:
            pieces.append(" ")
            char_labels.append("O")
            token_map.append(-1)

    text = "".join(pieces)
    # max_length is fixed at 2048 here; adjust if needed.
    tokenized = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=2048)

    outside_id = label2id["O"]
    token_labels = []

    for start_idx, end_idx in tokenized.offset_mapping:
        # A (0, 0) offset carries no text span; label it "O".
        if start_idx == 0 and end_idx == 0:
            token_labels.append(outside_id)
            continue

        # If the token's span starts on whitespace, look at the next character.
        if text[start_idx].isspace():
            start_idx += 1

        try:
            token_labels.append(label2id[char_labels[start_idx]])
        except (IndexError, KeyError):
            # Offset past the end of the text, or a label missing from
            # label2id: fall back to "O" instead of skipping the token, so
            # that `labels` stays aligned one-to-one with `input_ids`.
            token_labels.append(outside_id)

    return {
        "input_ids": tokenized.input_ids,
        "attention_mask": tokenized.attention_mask,
        "offset_mapping": tokenized.offset_mapping,
        "labels": token_labels,
        "length": len(tokenized.input_ids),
        "target_num": target_num,
        "group": 1 if target_num > 0 else 0,
        "token_map": token_map,
    }

In [None]:
# Settings read by the scratch tokenize_row helper: a source token counts as a
# target when its label is listed in 'target_cols'.
# NOTE(review): 'valid_stride' and 'max_length' are not referenced by any
# visible cell; tokenize_row truncates at a literal 2048, not at 'max_length'.
config = {
    'target_cols': ['TARGET'],
    'valid_stride': False,
    'max_length': 512
}

In [None]:
# Run the scratch tokenize_row helper on the row at index 9 of `x` so that its
# output can be compared with the chunked dataset below.
example = x[9]
result = tokenize_row(example)

In [None]:
# Decode the ids produced by tokenize_row back to text to eyeball the round trip.
tokenizer.decode(result['input_ids'])

In [None]:
# Print every token of `result` whose label id is non-zero, decoded to text.
for token_id, label_id in zip(result['input_ids'], result['labels']):
    if label_id == 0:
        continue
    print(f'{tokenizer.decode(token_id)} => {label_id}')

In [None]:
# Same listing for row 9 of the chunked dataset, skipping label ids 0 and -100.
for token_id, label_id in zip(ds[9]['input_ids'], ds[9]['labels']):
    if label_id in (0, -100):
        continue
    print(f'{tokenizer.decode(token_id)} => {label_id}')

In [None]:
# Show the raw tokenizer output for a single sample word.
tokenizer('Nathalie')

In [None]:
# Decode the aligned input ids of the first tokenized document back to text.
tokenizer.decode(x[0]['aligned_tokens']['input_ids'])

In [None]:
# Display the full first row of the tokenized dataset.
x[0]

In [None]:
# Display the first raw document as loaded from the JSON file.
data[0]