In [None]:
import pickle
from squad_example import SquadExample  # NOTE(review): presumably needed so pickle can resolve the class

# Load the preprocessed SQuAD examples and pick a single one to walk through.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('data/squad_examples.pickle', 'rb') as f:
    squad_examples = pickle.load(f)

example = squad_examples[27]
print(example)

# 2. Preparing Data for the SQuAD Task

   * To use the SQuAD data with BERT, it must first be tokenized.
   * BERT provides a basic BertTokenizer for this task.
        * Sub-word (WordPiece) tokenization is used.
        * A mapping between original tokens and sub-word tokens must be generated.
   * Some preprocessing is done to improve the answer span.
   * To make full use of long documents, a sliding-window approach is adopted.

### 2.1 Initial Parameters

In [None]:
from pytorch_pretrained_bert import BertTokenizer

# Pre-trained checkpoint whose WordPiece vocabulary is used for tokenization.
bert_model = 'bert-base-multilingual-cased'

# The checkpoint is a *cased* model, so its input must keep its original
# casing: do_lower_case has to be False to match the vocabulary.
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)

# The maximum total input sequence length after WordPiece tokenization,
# counting [CLS], the question tokens and both [SEP] tokens. Longer
# documents are split into overlapping windows (see section 2.6).
max_seq_length = 128
# When splitting up a long document into chunks, how far to advance between chunks.
doc_stride = 64
# The maximum number of question tokens kept; longer questions are truncated.
max_query_length = 64


### 2.2  A look at BERT Tokenizer


 *  The tokenizer splits each whitespace-separated token into one or more WordPiece sub-tokens.

In [None]:
# Show how the tokenizer splits four consecutive document tokens
# (starting at position 15) into WordPiece sub-tokens.
for token in example.doc_tokens[15:19]:
    print("Token : ", token)
    for sub_token in tokenizer.tokenize(token):
        print("     Sub-Token : ", sub_token)

###  2.3. Mapping between the original token and sub_word tokens

In [None]:
# Bidirectional maps between whitespace tokens and WordPiece sub-tokens:
#   tok_to_orig_index[j] -> index of the original token that sub-token j came from
#   orig_to_tok_index[i] -> index of the first sub-token of original token i
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []  # flat list of sub-word tokens for the whole document

for orig_idx, orig_token in enumerate(example.doc_tokens):
    orig_to_tok_index.append(len(all_doc_tokens))
    pieces = tokenizer.tokenize(orig_token)
    all_doc_tokens.extend(pieces)
    tok_to_orig_index.extend([orig_idx] * len(pieces))

print('token_to_original map : ', tok_to_orig_index[:40])
print('original_to_token_map : ', orig_to_tok_index[:40])
print('sub_word tokens : ', all_doc_tokens[:30])

### 2.4. Update answer position to sub_word token index

In [None]:
# Convert the answer span from whitespace-token indices
# (example.start_position / example.end_position) to indices into
# all_doc_tokens. Only meaningful for training examples with an answer.
tok_start_position = orig_to_tok_index[example.start_position]

if example.end_position < len(example.doc_tokens) - 1:
    # The answer's last sub-token is the one just before the first
    # sub-token of the following original token.
    tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
else:
    # The answer ends on the document's last original token, so it runs
    # through the last sub-token.
    tok_end_position = len(all_doc_tokens) - 1

print('start position : token-based : ', example.start_position,
      ' , sub-word token: ', tok_start_position)
print('end   position : token-based : ', example.end_position,
      ' , sub-word token: ', tok_end_position)

# Show the answer as original tokens and as sub-word tokens, derived from
# the computed positions rather than hard-coded indices.
print(example.doc_tokens[example.start_position:example.end_position + 1])
print(all_doc_tokens[tok_start_position:tok_end_position + 1])

### 2.5 Improving answer span
    
   * Question: What year was John Smith born?
     Context: The leader was John Smith (1895-1943).
     Answer: 1895
    
   - Original whitespace token containing the answer :  "(1895-1943)."

   - WordPiece tokenization :  "( 1895 - 1943 ) ."

      -> Narrow the answer span to the sub-tokens that exactly match the tokenized answer text :  "1895".


In [None]:
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Return (start, end) sub-token indices that tightly match the answer.

    Every sub-span of doc_tokens[input_start:input_end + 1] is tried, with
    start positions left to right and, for each start, end positions from
    longest to shortest. The first span whose space-joined tokens equal the
    tokenized answer text is returned. If none matches, the input span is
    returned unchanged.
    """
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return new_start, new_end
    return input_start, input_end


tok_start_position, tok_end_position = _improve_answer_span(
    all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
    example.orig_answer_text)

print('start : ', tok_start_position)
print('end   : ', tok_end_position)

### 2.6 Sliding Window Approach 
   * If the document is longer than the space left in the sequence,
     take windows of up to max_seq_length minus the question and special tokens,
     advancing the start of each window by 'doc_stride' sub-tokens.

In [None]:
import collections

# A window over all_doc_tokens: `start` is its first sub-token index and
# `length` the number of sub-tokens it covers.
_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

# 1. Tokenize the question, keeping at most max_query_length sub-tokens.
#    max_seq_length / max_query_length / doc_stride come from the parameters
#    cell (2.1) and are deliberately not re-assigned here.
query_tokens = tokenizer.tokenize(example.question_text)[:max_query_length]

# Room left for document tokens after the question and [CLS] [SEP] [SEP].
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
assert max_tokens_for_doc > 0, "max_seq_length leaves no room for document tokens"

# 2. Slide a window of at most max_tokens_for_doc sub-tokens over the
#    document, advancing by doc_stride, until a window reaches the end.
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
    length = min(len(all_doc_tokens) - start_offset, max_tokens_for_doc)
    doc_spans.append(_DocSpan(start=start_offset, length=length))
    if start_offset + length == len(all_doc_tokens):
        break
    start_offset += min(length, doc_stride)

print('length of original document : ', len(all_doc_tokens))
print("doc_stride", doc_stride)
print('Document Spans : ', doc_spans)
# Print up to the first two windows; short documents may yield only one.
for span_name, span in zip(('First_span', 'Second_span'), doc_spans):
    print('\n{} \n'.format(span_name),
          all_doc_tokens[span.start:span.start + span.length])


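### 2.7 Check whether a window gives a token its maximum context
   * A token can appear in several overlapping windows. For evaluation, the window in which the token has the most surrounding context (on both sides) is preferred.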

In [None]:
def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    #  Document: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    # "maximum context"
    #    =  minimum(  left context,   right context )
    

    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


# Check if tokens is max context

# For every token of the first window, record whether that window is the
# token's max-context window. Keys are positions within the window.
doc_span = doc_spans[0]
window_positions = range(doc_span.start, doc_span.start + doc_span.length)

tokens = [all_doc_tokens[pos] for pos in window_positions]
token_is_max_context = {
    offset: _check_is_max_context(doc_spans, 0, pos)
    for offset, pos in enumerate(window_positions)
}

print(tokens)
print(token_is_max_context)

### Make an input instance from each document span

   * Inputs:
        * input_ids:
            - a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary
        * token_type_ids:
            - a torch.LongTensor of shape [batch_size, sequence_length] with the values in [0, 1]. Type 0 corresponds to a "sentence A" and type 1 corresponds to a "sentence B" token
        * attention_mask : 
            - a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. only positions with value 1's are attended in the sentences.
        * start_positions : 
            - position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
        * end_positions: 
            - position of the last token for the labeled span: torch.LongTensor of shape [batch_size].

In [None]:
# Build the model input for a single document window:
#   [CLS] question tokens [SEP] window tokens [SEP]
# segment_ids are 0 for the question part and 1 for the document part.
# (Only the first window is processed here.)
doc_span_index = 0
doc_span = doc_spans[doc_span_index]

tokens = []
token_to_orig_map = {}     # position in `tokens` -> original whitespace-token index
token_is_max_context = {}  # position in `tokens` -> is this its max-context window?
segment_ids = []

# [CLS] + question tokens + [SEP] all belong to segment 0
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
    tokens.append(token)
    segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)

# Document tokens of this window (segment 1)
for i in range(doc_span.length):
    split_token_index = doc_span.start + i

    # Positions in `tokens` are offset by the window start and the question
    # prefix, so a separate map back to original tokens is needed.
    token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
    token_is_max_context[len(tokens)] = _check_is_max_context(
        doc_spans, doc_span_index, split_token_index)

    tokens.append(all_doc_tokens[split_token_index])
    segment_ids.append(1)

# Closing separator
tokens.append("[SEP]")
segment_ids.append(1)

# max_tokens_for_doc was chosen so the whole sequence fits in max_seq_length.
assert len(tokens) <= max_seq_length

print(tokens)
print(token_is_max_context)
print(segment_ids)

In [None]:
input_ids = tokenizer.convert_tokens_to_ids(tokens)

# Attention mask: 1 for real tokens, 0 for padding.
input_mask = [1] * len(input_ids)

# Zero-pad every sequence up to max_seq_length. Each list is padded from its
# own length and rebuilt (not appended to in place), so re-running this cell
# does not over-pad segment_ids, which was created in the previous cell.
input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
input_mask = input_mask + [0] * (max_seq_length - len(input_mask))
segment_ids = segment_ids + [0] * (max_seq_length - len(segment_ids))

assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length

# Training only: locate the answer inside this window. The answer and the
# window must be compared in the same index space, i.e. sub-word token
# indices (tok_*_position vs. doc_span), not whitespace-token indices.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
print(' ans_st ', tok_start_position, ' ans_ed ', tok_end_position,
      ' doc_st ', doc_start, ' doc_ed ', doc_end)

# Offset of the first document token in `tokens`: [CLS] + question + [SEP].
doc_offset = len(query_tokens) + 2

if (tok_start_position < doc_start or
        tok_end_position < doc_start or
        tok_start_position > doc_end or
        tok_end_position > doc_end):
    # The window does not contain the whole answer: label both positions as
    # 0 ([CLS]) instead of pointing outside the window.
    print("answer is not in the document")
    start_position = 0
    end_position = 0
else:
    # sub-word index -> index within the window -> index within `tokens`
    start_position = tok_start_position - doc_start + doc_offset
    end_position = tok_end_position - doc_start + doc_offset

# Index of `example` within squad_examples (it was selected by index above).
example_index = squad_examples.index(example)
unique_id = example.qas_id


In [None]:
class InputFeatures(object):
    """A single set of features of data for one document window.

    Attributes (all stored unchanged from the constructor arguments):
        unique_id: identifier of this feature.
        example_index: index of the source example.
        doc_span_index: index of the document window this feature was built from.
        tokens: list of token strings forming the full input sequence.
        token_to_orig_map: dict, position in `tokens` -> original token index.
        token_is_max_context: dict, position in `tokens` -> bool.
        input_ids, input_mask, segment_ids: integer sequences fed to the model.
        start_position, end_position: answer span positions within `tokens`;
            None when no answer span is given (e.g. non-training data).
    """
    def __init__(self,
                unique_id,
                example_index,
                doc_span_index,
                tokens,
                token_to_orig_map,
                token_is_max_context,
                input_ids,
                input_mask,
                segment_ids,
                start_position=None,
                end_position=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Multi-line, human-readable dump of this feature's own fields."""
        s = "\n"
        s += "unique_id: {} \n".format(self.unique_id)
        s += "example_index: {} \n".format(self.example_index)
        s += "doc_span_index: {} \n".format(self.doc_span_index)
        s += "tokens: {} \n".format(" ".join(self.tokens))
        s += "tokens_to_origin_map: {} \n".format(" ".join(
            "{}:{}".format(x, y) for (x, y) in self.token_to_orig_map.items()))
        s += "token_is_max_content: {} \n".format(" ".join(
            "{}:{}".format(x, y) for (x, y) in self.token_is_max_context.items()))
        s += "input_ids: {} \n".format(" ".join(str(x) for x in self.input_ids))
        s += "input_mask: {}\n".format(" ".join(str(x) for x in self.input_mask))
        s += "segment_ids: {}\n".format(" ".join(str(x) for x in self.segment_ids))
        # The answer text can only be shown when a span is present (training).
        if self.start_position is not None and self.end_position is not None:
            s += "answer_text : {}\n".format(
                " ".join(self.tokens[self.start_position:(self.end_position + 1)]))
        s += "start_position: {}\n".format(self.start_position)
        s += "end_position : {}\n ".format(self.end_position)
        return s


# Bundle everything computed above for this window into one feature object.
# NOTE(review): the name `input` shadows Python's built-in input(); it is kept
# because later cells may refer to it.
input = InputFeatures(
    unique_id=unique_id, example_index=example_index, doc_span_index=doc_span_index,
    tokens=tokens, token_to_orig_map=token_to_orig_map,
    token_is_max_context=token_is_max_context,
    input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids,
    start_position=start_position, end_position=end_position,
)

print(input)

In [None]:
import pickle

global_step = 0

# The cache file name encodes the model name and the preprocessing
# parameters, so features built with different settings do not collide.
cached_train_features_file = 'data/train-v1.1.json_{}_{}_{}_{}'.format(
    list(filter(None, bert_model.split('/'))).pop(), str(max_seq_length),
    str(doc_stride), str(max_query_length))

# No cell above defines `train_features`; only a single feature (`input`) is
# built in this notebook, so cache it as a one-element list.
# TODO: replace with the features for every example/window once the full
# preprocessing loop is in place.
train_features = [input]

with open(cached_train_features_file, 'wb') as writer:
    pickle.dump(train_features, writer)
   