In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Ask IPython to display the value of every top-level expression in a cell,
# not only the last one. Later cells rely on this: they end with two bare
# expressions and expect both to be shown.
InteractiveShell.ast_node_interactivity = 'all'

In [None]:
import os 
import torch 
from torch import nn

import numpy as np 

import lib.utils as utils 
import lib.globals as globals
import lib.data_handling as handling

In [None]:
# Path of the training JSON file inside the project's data folder
# (globals.DATA_FOLDER); consumed below by handling.RawSquadDataset.
dataset_path = os.path.join(globals.DATA_FOLDER,'training_set.json')

In [None]:
# Load the embedding model and its vocabulary via the project helper.
# `vocab` is later handed to the WordLevel tokenizer and indexed by token
# string (vocab[globals.PAD_TOKEN]), so it behaves as a token -> id mapping.
# NOTE(review): `model` is not used in the visible cells — presumably needed
# further down; confirm.
model, vocab = utils.load_embedding_model()

In [None]:
# Parse the training file with the project's RawSquadDataset wrapper.
squad_dataset = handling.RawSquadDataset(dataset_path)

# Independent copy of the wrapper's DataFrame for the exploratory cells below.
# Note that hf_dataset is later built from squad_dataset.df itself, not from
# this copy.
df = squad_dataset.df.copy()

In [None]:
# Look up the row(s) belonging to one specific question id.
df.loc[df['question_id'].eq('56df85525ca0a614008f9bfe')]

In [None]:
# Split every `label_char` entry into its first and second component
# (used as start / end character positions throughout this notebook).
starts = np.array([span[0] for span in df['label_char']])
ends = np.array([span[1] for span in df['label_char']])

# start - end for every row; a value of zero means start == end.
s = np.subtract(starts, ends)

# Rows whose labelled span has identical start and end positions.
df.loc[s == 0]

In [None]:
from tokenizers import  Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, Strip, StripAccents
from tokenizers.pre_tokenizers import Punctuation
from tokenizers.pre_tokenizers import Sequence as PreSequence
from tokenizers.pre_tokenizers import Whitespace

from datasets import Dataset 


In [None]:

# Wrap the wrapper's DataFrame (squad_dataset.df, not the `df` copy) in a
# Hugging Face Dataset so a transform can be attached to it below.
hf_dataset = Dataset.from_pandas(squad_dataset.df)

In [None]:
# Word-level tokenizer over the embedding vocabulary; words missing from
# `vocab` are mapped to globals.UNK_TOKEN.
tokenizer = Tokenizer(
    WordLevel(vocab, unk_token=globals.UNK_TOKEN)
)

# Text normalisation applied before splitting, in this order.
# NOTE(review): the tokenizers documentation pairs StripAccents with an NFD
# normalizer placed before it; without NFD, precomposed accented characters
# may pass through unchanged — confirm whether NFD() should be prepended.
tokenizer.normalizer = Sequence([
    StripAccents(),
    Lowercase(),
    Strip(),
])

# Split on whitespace first, then on punctuation.
tokenizer.pre_tokenizer = PreSequence([
    Whitespace(),
    Punctuation(),
])

# Right-pad batches with the PAD token; padded positions get type id 1.
tokenizer.enable_padding(
    direction="right",
    pad_id=vocab[globals.PAD_TOKEN],
    pad_type_id=1,
    pad_token=globals.PAD_TOKEN,
)

In [None]:
from tokenizers import Encoding

def transform(batch):
    """Tokenize a batch of examples and attach token-level answer labels.

    Relies on the module-level ``tokenizer`` to encode contexts and questions.

    Args:
        batch: mapping of column name -> list with one entry per example.
            Reads the columns ``'context'``, ``'question'``, ``'answer'`` and
            ``'label_char'``; every ``label_char`` entry is indexed as
            ``(start, end)`` character positions.

    Returns:
        dict of lists (one entry per example) holding the token ids and
        attention masks of context and question, the context token offsets,
        the raw context / question / answer texts, the token indices obtained
        from the character labels via ``char_to_token``, and the character
        labels themselves.
    """
    ctx_encodings = tokenizer.encode_batch(batch['context'])
    qst_encodings = tokenizer.encode_batch(batch['question'])

    char_starts = [span[0] for span in batch['label_char']]
    char_ends = [span[1] for span in batch['label_char']]

    # Map the character labels onto context token indices. The end label is
    # shifted by one so the lookup uses the last character inside the span
    # (the end position is treated as exclusive here).
    token_starts = [
        enc.char_to_token(pos) for enc, pos in zip(ctx_encodings, char_starts)
    ]
    token_ends = [
        enc.char_to_token(pos - 1) for enc, pos in zip(ctx_encodings, char_ends)
    ]

    return {
        'context_ids': [enc.ids for enc in ctx_encodings],
        'question_ids': [enc.ids for enc in qst_encodings],
        'context_mask': [enc.attention_mask for enc in ctx_encodings],
        'question_mask': [enc.attention_mask for enc in qst_encodings],
        'offsets': [enc.offsets for enc in ctx_encodings],
        'context_text': batch['context'],
        'question_text': batch['question'],
        'answer_text': batch['answer'],
        # A 'tokens' entry ([enc.tokens for enc in ctx_encodings]) is
        # intentionally left out of the output.
        'label_token_start': token_starts,
        'label_token_end': token_ends,
        'label_char_start': char_starts,
        'label_char_end': char_ends,
    }

# Register `transform` as the dataset's formatting function (Hugging Face
# `Dataset.set_transform`); items read from hf_dataset below come back with the
# keys produced by `transform`.
# NOTE(review): output_all_columns=False presumably hides the original columns
# that `transform` does not return — confirm against the datasets docs.
hf_dataset.set_transform(transform,output_all_columns=False)

In [None]:
# Dump one transformed example in full; the same index (13692) is inspected
# step by step in the next cell.
print(hf_dataset[13692])

In [None]:
# Round-trip check on a single example: go from the token labels back to a
# character span and compare it with the original character labels.
ex = hf_dataset[13692]
start_token, end_token = ex['label_token_start'], ex['label_token_end']

# Character span covered by the labelled tokens: start offset of the first
# token up to the end offset of the last one.
start_char = ex['offsets'][start_token][0]
end_char = ex['offsets'][end_token][1]

print(start_char, end_char, sep='\n')

# Two bare expressions: text recovered through the token offsets, then text
# cut directly with the character labels. They are expected to match when the
# char -> token alignment is correct.
ex['context_text'][start_char:end_char]
ex['context_text'][ex['label_char_start']:ex['label_char_end']]

# (A token-level view would need the 'tokens' key, which is currently
# commented out in `transform`.)


In [None]:
# Cross-check the character labels of `ex` against its token offsets: look for
# a token that starts exactly at the labelled start character and a token that
# ends exactly at the labelled end character.
start_c = ex['label_char_start']
end_c = ex['label_char_end']

# Transpose the list of (start, end) offsets into one tuple of starts and one
# tuple of ends.
starts, ends = zip(*ex['offsets'])

# tuple.index raises ValueError when the value is absent. Only that case is
# reported; anything else (e.g. KeyboardInterrupt) is allowed to propagate
# instead of being swallowed by a bare `except`. The index is set to None on
# failure so a stale value from a previous run cannot linger in the kernel.
try:
    start_idx = starts.index(start_c)
except ValueError:
    start_idx = None
    print('errore start')

try:
    end_idx = ends.index(end_c)
except ValueError:
    end_idx = None
    print('errore end')


In [None]:
# Display both character labels; the two bare expressions are each shown
# because ast_node_interactivity was set to 'all' in the first cell.
start_c
end_c