In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Make IPython display the value of every top-level expression in a cell,
# not only the last one. Later cells rely on this to show several results.
InteractiveShell.ast_node_interactivity = 'all'

In [None]:
import os 
import torch 
from torch import nn

import numpy as np 

import src.utils as utils 
import src.globals as globals
import src.data_handler as handling

from datasets import Dataset 

from tokenizers import BertWordPieceTokenizer


In [None]:
# Path of the training-set JSON file inside the project's data folder.
dataset_path = os.path.join(globals.DATA_FOLDER,'training_set.json')

In [None]:
# Load two objects through the project helper, bound here as `model` and `vocab`.
# NOTE(review): neither name is used in the cells visible here — confirm this
# load is still needed before keeping it in the Run-All path.
model, vocab = utils.load_embedding_model()

In [None]:
# Build the project's raw SQuAD dataset wrapper from the JSON file.
squad_dataset = handling.RawSquadDataset(dataset_path)

# Take a copy of the wrapper's training frame; `df` is what gets handed to
# `Dataset.from_pandas` afterwards.
df = squad_dataset.train_df.copy()

In [None]:

# Wrap the pandas frame in a Hugging Face `datasets.Dataset`.
hf_dataset = Dataset.from_pandas(df)

In [None]:
# Vocabulary file handed to the WordPiece tokenizer, looked up in the data folder.
vocab_path = os.path.join(globals.DATA_FOLDER,'bert-base-uncased-vocab.txt')

In [None]:
# WordPiece tokenizer built from the vocabulary file, lower-casing its input.
wp_tokenizer = BertWordPieceTokenizer(vocab_path, lowercase=True)
# Pad on the right; padded positions receive token type id 1.
wp_tokenizer.enable_padding(direction="right", pad_type_id=1)
# Truncate every encoding to at most 512 tokens.
# NOTE(review): no truncation strategy is passed, so the library default is
# applied to the (question, context) pairs — confirm it trims the intended side.
wp_tokenizer.enable_truncation(512)

In [None]:
from tokenizers import Encoding

def transform(batch):
    """Tokenize a batch of (question, context) pairs on access.

    Written to be registered with ``Dataset.set_transform``. Each question is
    paired with its context and encoded in one call to
    ``wp_tokenizer.encode_batch``.

    Args:
        batch: Mapping from column name to a list of values. It must provide
            ``'question'``, ``'context'``, ``'answer'`` and ``'label_char'``.
            Every ``'label_char'`` entry is indexed as ``(start, end)``.

    Returns:
        A dict of lists with one entry per example:

        * ``'ids'``, ``'mask'``, ``'offsets'``, ``'sequence_ids'``,
          ``'type_ids'``: the matching attributes of each encoding.
        * ``'context_text'``, ``'question_text'``, ``'answer_text'``: the raw
          strings, passed through unchanged.
        * ``'label_char_start'``, ``'label_char_end'``: the two components of
          each ``'label_char'`` entry.

    Raises:
        KeyError: If ``batch`` lacks one of the required columns.
    """
    pairs = list(zip(batch['question'], batch['context']))
    encoded: list[Encoding] = wp_tokenizer.encode_batch(pairs)

    # The label components used to be computed and then thrown away, although
    # other cells read them back from transformed examples under these keys.
    label_spans = batch['label_char']

    return {
        'ids': [e.ids for e in encoded],
        'mask': [e.attention_mask for e in encoded],
        'offsets': [e.offsets for e in encoded],
        'sequence_ids': [e.sequence_ids for e in encoded],
        'type_ids': [e.type_ids for e in encoded],
        'context_text': batch['context'],
        'question_text': batch['question'],
        'answer_text': batch['answer'],
        'label_char_start': [span[0] for span in label_spans],
        'label_char_end': [span[1] for span in label_spans],
    }

# Register `transform` so it is applied whenever rows are accessed.
# output_all_columns=False: only the keys returned by `transform` are exposed.
hf_dataset.set_transform(transform,output_all_columns=False)

In [None]:
# Peek at two transformed rows (indices 4 and 5).
print(hf_dataset[4:6])

In [None]:
# Spot-check one hard-coded example: slice the context with token offsets and
# with character labels, and show both next to the stored answer text.
# NOTE(review): 'label_token_start', 'label_token_end', 'context_ids' and
# 'context_tokens' are not among the keys built by `transform` in the visible
# cells — confirm this cell still runs against the transformed dataset.
ex = hf_dataset[57912]
start_token = ex['label_token_start']
end_token = ex['label_token_end']
# Character span: start offset of the first labelled token to the end offset
# of the last one.
start_char = ex['offsets'][start_token][0]
end_char = ex['offsets'][end_token][1]

print(start_char)
print(end_char)

# Presumably these three values are meant to agree — TODO confirm.
ex['context_text'][start_char:end_char]
ex['context_text'][ex['label_char_start']:ex['label_char_end']]
ex['answer_text']

# Presumably one id per context token — verify against the data handler.
len(ex['context_ids']) == len(ex['context_tokens'])


In [None]:
# Sanity check: find the tokens whose character offsets coincide with the
# labelled span and compare them with the answer's first and last tokens.
# NOTE(review): 'context_tokens' and 'answer_tokens' are not among the keys
# built by `transform` in the visible cells — confirm where they come from.
start_c = ex['label_char_start']
end_c = ex['label_char_end']

# Split the per-token (start, end) offsets into two parallel tuples.
starts, ends = zip(*ex['offsets'])

# tuple.index raises ValueError when no token boundary matches the label.
# Catch only that, and record the miss as None so the checks below can neither
# hit a NameError nor silently reuse an index left over from an earlier run.
try:
    start_idx = starts.index(start_c)
except ValueError:
    start_idx = None
    print('errore start')

try:
    end_idx = ends.index(end_c)
except ValueError:
    end_idx = None
    print('errore end')

# Kept as top-level expressions so the notebook displays both results; each
# one evaluates to False when its boundary was not found.
start_idx is not None and ex['context_tokens'][start_idx] == ex['answer_tokens'][0]
end_idx is not None and ex['context_tokens'][end_idx] == ex['answer_tokens'][-1]