The nqtables data is a subset of the original Natural Questions dataset in which the answer can be found in a table of the given Wikipedia article. The original Natural Questions dataset can be downloaded [here](https://ai.google.com/research/NaturalQuestions/download).

In [None]:
import re
import json
from tqdm import tqdm
import random
import collections

In [None]:
def find_table(tokens):
    """Locate every ``<Table>`` ... ``</Table>`` span in a token sequence.

    Tables are numbered in the order their opening tag is encountered.
    A token that sits inside nested tables is attributed to the innermost
    open table only, so an outer table does not contain the tokens of a
    table nested inside it.

    Args:
        tokens: list of string tokens, with ``'<Table>'`` / ``'</Table>'``
            marking table boundaries.

    Returns:
        A ``(table_locs, tables)`` tuple where ``table_locs[i]`` is the id of
        the table that token ``i`` belongs to (``-1`` when it is outside any
        table) and ``tables[k]`` is the list of tokens of table ``k``,
        including its opening and closing tags.
    """
    table_locs = [-1] * len(tokens)
    open_tables = []  # stack of ids of the tables that are currently open
    tables = []
    for i, token in enumerate(tokens):
        if token == '<Table>':
            open_tables.append(len(tables))
            tables.append([])
        if open_tables:
            innermost = open_tables[-1]
            table_locs[i] = innermost
            tables[innermost].append(token)
        # An unmatched closing tag must not pop an empty stack (IndexError).
        if token == '</Table>' and open_tables:
            open_tables.pop()
    return table_locs, tables

def extract_table(tokens):
    """Turn the tokens of one table into a list of rows of cell strings.

    The first and last token (the table's own opening/closing tags) are
    dropped. ``<Tr>`` / ``</Tr>`` delimit rows; any token starting with
    ``<Td`` or ``<Th`` opens a cell and ``</Td`` / ``</Th`` closes it. All
    other tokens are collected into the current cell and joined with single
    spaces. A cell whose opening tag carries ``colspan=N`` is repeated N
    times in its row so that columns stay aligned.

    Args:
        tokens: token list of a single table, including its outer tags.

    Returns:
        ``list[list[str]]`` - one inner list per completed ``<Tr>`` row.
    """
    data = []
    row = None
    cell = None
    colspan = 1
    for token in tokens[1:-1]:
        if token == '<Tr>':
            row = []
        elif token == '</Tr>':
            # Ignore a closing tag with no open row instead of storing None.
            if row is not None:
                data.append(row)
            row = None
        elif token[:3] in ('<Td', '<Th'):
            cell = []
            # re.search yields a match object; the previous int(re.findall(...))
            # always raised TypeError, so every colspan silently became 1.
            match = re.search(r'colspan=.*?(\d+)', token)
            colspan = max(1, int(match.group(1))) if match else 1
        elif token[:4] in ('</Td', '</Th'):
            # Tolerate a closing cell tag without a matching open row/cell.
            if row is None:
                row = []
            row += [' '.join(cell or [])] * colspan
            cell = None
            colspan = 1
        else:
            if cell is None:
                cell = []
            if row is None:
                row = []
            cell.append(token)
    return data

In [None]:
# Scan the simplified Natural Questions training file and collect the questions
# whose short / long answers lie entirely inside a table of the article.
# The file is read as UTF-8 explicitly so the result does not depend on the
# platform's default encoding.
with open('./data/original/NaturalQuestions/simplified-nq-train.jsonl', 'r', encoding='utf-8') as f_in:
    questions = []
    short_answer_in_table_questions = []
    long_answer_in_table_questions = []
    for idx, line in enumerate(tqdm(f_in)):
        sample = json.loads(line)
        raw_document_text = sample['document_text'].split(' ')
        list_document_text = []
        # Maps a token index in the raw document to its index in the filtered
        # document built below; -1 ("no answer") maps to itself.
        old_to_new_map = {-1: -1}
        num_token_added = 0
        for long_answer_candidate in sample['long_answer_candidates']:
            if not long_answer_candidate['top_level']:
                continue
            cand_start = long_answer_candidate['start_token']
            cand_end = long_answer_candidate['end_token']
            cand_tokens = raw_document_text[cand_start:cand_end]
            # Keep only candidates that contain at least one non-markup token.
            if not any(token[0] != '<' or token[-1] != '>' for token in cand_tokens):
                continue
            # end_token is exclusive, so it is mapped as well (cand_end + 1).
            for i in range(cand_start, cand_end + 1):
                old_to_new_map[i] = num_token_added + i - cand_start
            num_token_added += cand_end - cand_start
            list_document_text.append(cand_tokens)
        tokens = [token for document_text in list_document_text for token in document_text]
        document_text = ' '.join(tokens)
        table_locs, tables = find_table(tokens)
        tables = [extract_table(t) for t in tables]
        answer_in_table = []
        long_answer_in_table = []
        has_long = False
        has_short = False
        # table_locs holds table ids (0, 1, ...) with -1 meaning "not in a
        # table"; the previous `x == 1` test was only true for documents with
        # at least two tables.
        has_table = any(loc != -1 for loc in table_locs)

        for annotation in sample['annotations']:
            answer = annotation['long_answer']
            start = old_to_new_map.get(answer['start_token'], None)
            end = old_to_new_map.get(answer['end_token'], None)
            if start is None or end is None:
                print('missed long in:', idx)
                continue
            if answer['start_token'] != -1:
                if all(table_locs[i] != -1 for i in range(start, end)):
                    long_answer_in_table.append([table_locs[start], start, end])
                has_long = True
            if annotation['short_answers']:
                has_short = True
            for answer in annotation['short_answers']:
                start = old_to_new_map.get(answer['start_token'], None)
                end = old_to_new_map.get(answer['end_token'], None)
                if start is None or end is None:
                    print('missed short in:', idx)
                    continue
                if all(table_locs[i] != -1 for i in range(start, end)):
                    answer_in_table.append([table_locs[start], ' '.join(tokens[start:end]), start, end])
        if answer_in_table:
            short_answer_in_table_questions.append({
                'q': sample['question_text'],
                't': tables,
                'd': document_text,
                'list_d': list_document_text,
                'a': sample['annotations'],
                'a in table': answer_in_table
            })
        if long_answer_in_table:
            long_answer_in_table_questions.append({
                'q': sample['question_text'],
                'd': document_text,
                'list_d': list_document_text,
                'a': sample['annotations'],
                'a in table': long_answer_in_table
            })
        # Per-question statistics: [has_long, has_short, has_table, question].
        questions.append([has_long, has_short, has_table, sample['question_text']])

In [None]:
def get_answer_pattern(answer):
    """Build a regex source string that matches ``answer`` as a whole unit.

    The answer text is escaped literally. A ``\\b`` word boundary is added on
    a side only when the answer starts (or ends) with a word character, so
    answers that begin or end with punctuation can still be matched.

    Args:
        answer: the literal answer text; an empty string yields an empty
            pattern instead of raising IndexError.

    Returns:
        The regular-expression source as a ``str``.
    """
    pattern = re.escape(answer)
    # Slicing (rather than indexing) keeps an empty answer from raising.
    if re.match(r'\w', answer[:1]):
        pattern = r'\b' + pattern
    if re.match(r'\w', answer[-1:]):
        pattern += r'\b'
    return pattern
def find_answer_cell(answer, table):
    """Return the ``[row, col]`` of the best cell containing ``answer``.

    Every cell whose text matches the answer pattern is a candidate. The
    shortest matching cell wins; matches in row 0 (the header row) are
    penalised by 9999 characters so that a body cell is preferred whenever
    one exists. Ties go to the first candidate in row-major order.

    Args:
        answer: literal answer text to look for.
        table: list of rows, each a list of cell strings.

    Returns:
        ``[row_index, col_index]`` of the chosen cell, or ``None`` when no
        cell matches (or the answer is empty).
    """
    if not answer:
        return None
    # The pattern does not depend on the cell, so build and compile it once.
    pattern = re.compile(get_answer_pattern(answer))
    candidates = [
        (len(cell), [i, j])
        for i, row in enumerate(table)
        for j, cell in enumerate(row)
        if pattern.search(cell) is not None
    ]
    if not candidates:
        return None
    # min() returns the first minimal element, matching a stable sort's [0].
    best = min(candidates, key=lambda c: c[0] + 9999 if c[1][0] == 0 else c[0])
    return best[1]

In [None]:
# Deterministically hold out 10% of the short-answer-in-table questions as the
# dev split; everything else becomes training data.
random.seed(1)
total_num = len(short_answer_in_table_questions)
dev_num = int(0.1*total_num)
dev_idxes = set(random.sample(list(range(total_num)),dev_num))
train_examples, dev_examples = [], []
for qid, sample in enumerate(short_answer_in_table_questions):
    target_split = dev_examples if qid in dev_idxes else train_examples
    target_split.append(sample)
print('train:', len(train_examples))
print('dev:', len(dev_examples))

We experiment with two types of models. The first linearizes the table and treats the problem as a regular text QA task (MRQA); the second uses a Table QA model.

## Prepare for Text QA

In [None]:
# Every markup token that clean_context dropped or replaced by '[SEP]'.
tag_removed = set()
# Largest index used in the positional markers; later blocks reuse this index.
max_position = 49
def clean_context(context):
    """Replace HTML-like markup tokens with compact marker tokens.

    ``<P>``, ``<Table>`` and ``<Ul>`` become ``[P=i]``, ``[Tab=i]`` and
    ``[List=i]``, where ``i`` counts blocks of that kind and saturates at
    ``max_position``. Any other ``<...>`` token is recorded in the
    module-level ``tag_removed`` set and a run of such tokens that follows
    text collapses into a single ``[SEP]``. Empty tokens are skipped and
    all remaining tokens are kept unchanged.

    Args:
        context: list of string tokens.

    Returns:
        The cleaned list of tokens.
    """
    p_idx = 0
    t_idx = 0
    l_idx = 0
    last_tag = False  # True while the previously emitted token was a marker
    cleaned_context = []
    for token in context:
        if token == '<P>':
            last_tag = True
            cleaned_context.append(f'[P={p_idx}]')
            if p_idx < max_position:
                p_idx += 1
        elif token == '<Table>':
            last_tag = True
            cleaned_context.append(f'[Tab={t_idx}]')
            if t_idx < max_position:
                t_idx += 1
        elif token == '<Ul>':
            last_tag = True
            # Lists are numbered by their own counter (previously the table
            # counter was printed, so list markers carried the wrong index).
            cleaned_context.append(f'[List={l_idx}]')
            if l_idx < max_position:
                l_idx += 1
        # startswith/endswith are safe on '' (token[0] raised IndexError before
        # the empty-token branch below could be reached).
        elif token.startswith('<') and token.endswith('>'):
            tag_removed.add(token)
            if not last_tag:
                last_tag = True
                cleaned_context.append('[SEP]')
        elif token != '':
            last_tag = False
            cleaned_context.append(token)
    return cleaned_context
def prepare_examples_as_MRQA(all_examples, split='train'):
    """Convert table-answer samples into MRQA-style text QA examples.

    Two parallel example lists are produced: one whose context is the cleaned
    document text (limited to the first 50 entries of ``list_d`` for the
    ``'train'`` split) and one whose context contains only the document's
    table tokens.

    Args:
        all_examples: list of sample dicts providing ``'q'`` (question),
            ``'d'`` (space-joined document tokens), ``'list_d'`` (list of
            token lists) and ``'a in table'`` (entries whose element ``[1]``
            is the answer text).
        split: split name; used as the qid prefix and to decide whether the
            context is truncated.

    Returns:
        ``(processed_examples, processed_examples_only_table)`` - two lists of
        ``{'context': str, 'qas': [qa_dict]}`` where each ``qa_dict`` has the
        keys ``question``, ``qid``, ``answers`` and ``detected_answers``.
        Character spans in ``detected_answers`` are inclusive
        ``[start, end]`` pairs.
    """
    def _find_char_spans(answer_text, text):
        # get_answer_pattern only adds \b next to word characters, so answers
        # that start or end with punctuation (e.g. "$ 5", "( 1999 )") are
        # found too; an unconditional \b...\b could never match them.
        pattern = get_answer_pattern(answer_text)
        return [[loc.start(), loc.end()-1] for loc in re.finditer(pattern, text)]

    processed_examples = []
    processed_examples_only_table = []
    for qid, sample in enumerate(tqdm(all_examples)):
        qid = f'{split}-{qid}'
        all_answers = [a[1] for a in sample['a in table']]
        context = [token for x in sample['list_d'][:(50 if split=='train' else 99999)] for token in x]
        context = ' '.join(clean_context(context))
        table_tokens = [token for t in find_table(sample['d'].split(' '))[1] for token in t]
        table_only_context = ' '.join(clean_context(table_tokens))
        qas = {'question': sample['q'], 'qid': qid, 'answers': all_answers}
        table_only_qas = qas.copy()
        detected_answers = []
        table_only_detected_answers = []
        for a in sample['a in table']:
            answer_text = a[1]
            # An empty answer cannot be located meaningfully; skip it.
            if not answer_text:
                continue
            spans = _find_char_spans(answer_text, context)
            if spans:
                detected_answers.append({'char_spans': spans, 'text': answer_text})
            table_spans = _find_char_spans(answer_text, table_only_context)
            if table_spans:
                table_only_detected_answers.append({'char_spans': table_spans, 'text': answer_text})
        qas['detected_answers'] = detected_answers
        table_only_qas['detected_answers'] = table_only_detected_answers
        processed_examples.append({'context': context, 'qas': [qas]})
        processed_examples_only_table.append({'context': table_only_context, 'qas': [table_only_qas]})
    return processed_examples, processed_examples_only_table

In [None]:
# Build the linearized-text (MRQA) examples, full-document and table-only,
# for each split.
(
    train_prepared_MRQA_examples,
    train_prepared_tableonly_MRQA_examples,
) = prepare_examples_as_MRQA(train_examples, split='train')
(
    dev_prepared_MRQA_examples,
    dev_prepared_tableonly_MRQA_examples,
) = prepare_examples_as_MRQA(dev_examples, split='dev')

## Prepare for Table QA

In [None]:
import data_loader.table_utils as table_utils

def get_values(text):
    """Extract the numeric / date values that ``table_utils.parse_text`` finds.

    For each span returned by the parser only its first value is kept.

    Args:
        text: the string to parse.

    Returns:
        A list of ``[kind, [begin_index, end_index], payload]`` entries where
        ``kind`` is ``'date'`` (payload ``(year, month, day)``) or
        ``'number'`` (payload ``float_value``).
    """
    extracted = []
    for span in table_utils.parse_text(text):
        bounds = [span.begin_index, span.end_index]
        for first in span.values:
            if first.date is not None:
                kind = 'date'
                payload = (first.date.year, first.date.month, first.date.day)
            else:
                kind = 'number'
                payload = first.float_value
            extracted.append([kind, bounds, payload])
            break  # only the first value of each span is used
    return extracted

# Parse the numbers and dates in every cell and rank the numeric values
# within each column.
# NOTE: no whitespace or tag cleaning happens inside this function; cells are
# used exactly as they are passed in.
def process_table(table, min_consolidation_fraction=0.7):
    """Pad a table to a rectangle and annotate it with parsed values and ranks.

    Args:
        table: list of rows, each a list of cell strings. Must contain at
            least one row (``max()`` over the row lengths raises ValueError
            on an empty table).
        min_consolidation_fraction: forwarded unchanged to
            ``table_utils._consolidate_numeric_values``.

    Returns:
        A dict with the keys
        ``text`` (rows padded with ``''`` to the widest row; rows whose cells
        are all ``''`` are dropped),
        ``error`` (always ``None`` here),
        ``values`` (per cell: list of ``[kind, [begin, end], payload]``),
        ``value_ranks`` / ``value_inv_ranks`` (per cell rank within its
        column, 0 when the cell was not ranked) and
        ``index`` (per cell ``[i, j]`` coordinates in the returned table).
    """
    raw_data = table
    processed_data = []
    processed_values = []
    processed_ranks = []
    processed_inv_ranks = []
    error = None
    max_col_num = max([len(row) for row in raw_data])
    for i, row in enumerate(raw_data):
        # Drop rows that carry no text at all.
        if all([cell=='' for cell in row]):
            continue
        processed_row = []
        processed_row_values = []
        for j, cell in enumerate(row):
            processed_row.append(cell)
            processed_row_values.append(table_utils.parse_text(cell))
        # Pad short rows so that every kept row has max_col_num columns.
        processed_row += ['']*(max_col_num-len(row))
        processed_row_values += [[]]*(max_col_num-len(row))
        processed_data.append(processed_row)
        processed_values.append(processed_row_values)
        processed_ranks.append([0 for _ in range(max_col_num)])
        processed_inv_ranks.append([0 for _ in range(max_col_num)])
    for j in range(max_col_num):
        # Row 0 is left out of the ranking (it is treated as the header row
        # elsewhere in this file).
        column_values = {i:processed_values[i][j] for i in range(1,len(processed_values))}
        column_values = table_utils._consolidate_numeric_values(
            column_values,
            min_consolidation_fraction=min_consolidation_fraction,
        )
        try:
            # NOTE(review): element [1] of each consolidated entry is used as
            # the value to sort on - presumably entries are pairs; confirm
            # against table_utils.
            key_fn = table_utils.get_numeric_sort_key_fn([cell_value[1] for cell_value in column_values.values()])
            column_numeric_values_to_rank = {row_index: key_fn(value[1]) for row_index, value in column_values.items()}

            # Group row indices by sort key so that equal values share a rank.
            column_numeric_values_inv = collections.defaultdict(list)
            for row_index, value in column_numeric_values_to_rank.items():
                column_numeric_values_inv[value].append(row_index)

            unique_values = sorted(column_numeric_values_inv.keys())

            # Ranks are 1-based; the inverse rank counts from the largest value.
            for rank, value in enumerate(unique_values):
                for row_index in column_numeric_values_inv[value]:
                    processed_ranks[row_index][j] = rank + 1
                    processed_inv_ranks[row_index][j] = len(unique_values) - rank
        except ValueError:
            # The column could not be ranked; its ranks stay at 0.
            pass
    # Replace the raw parser spans with plain lists; only the first value of
    # each span is kept.
    for i in range(len(processed_values)):
        for j in range(max_col_num):
            values = []
            for value_span in processed_values[i][j]:
                span_index = [value_span.begin_index, value_span.end_index]
                for value in value_span.values:
                    value_type = 'date' if value.date is not None else 'number'
                    values.append([value_type,span_index, value.float_value if value_type=='number' else (value.date.year,value.date.month,value.date.day)])
                    break
            processed_values[i][j] = values
    return {
        'text': processed_data,
        'error': error,
        'values': processed_values,
        'value_ranks': processed_ranks,
        'value_inv_ranks': processed_inv_ranks,
        'index': [[[i,j] for j in range(len(processed_data[0]))] for i in range(len(processed_data))]
    }

In [None]:
# List markup that is stripped from cell and answer text before matching.
tags_to_remove = re.compile('<Ul>|</Li>|</Ul>')
def prepare_examples(all_examples, split='train', max_cell = 50, max_row=400, max_table = 400):
    """Turn table-answer samples into chunked Table QA examples.

    Each table of a sample is processed with ``process_table`` and split into
    chunks by the nested ``chunk_table`` generator. Tables referenced by an
    answer are chunked with that answer; all remaining tables of the sample
    are chunked with an empty answer.

    Args:
        all_examples: list of sample dicts providing ``'q'`` (question),
            ``'t'`` (list of tables, each a list of rows of cell strings) and
            ``'a in table'`` (entries whose ``[0]`` indexes into ``'t'`` and
            whose ``[1]`` is the answer text).
        split: split name, used as the qid prefix.
        max_cell: maximum number of whitespace-separated tokens kept per cell.
        max_row: token budget when collecting the columns of one row.
        max_table: token budget when adding further rows to a chunk.

    Returns:
        ``(processed_examples, missed_examples)`` - the list of chunk dicts
        and a list of ``[qid, sample]`` pairs for samples in which no answer
        could be located in any chunk.
    """
    def chunk_table(table, answer, tid):
        """Yield chunks of ``table`` (a ``process_table`` result).

        Every chunk starts with the header row (row 0 of the table). When the
        chunk does not start at column 0, column 0 is repeated as the first
        column. Sizes are counted in whitespace-separated tokens after each
        cell was truncated to ``max_cell`` tokens.
        """
        # This function is a generator, so this return simply ends the
        # iteration without yielding anything (callers never receive None).
        if len(table['text'])==0:
            return None     
        table_text = [[cell.split()[:max_cell] for cell in row] for row in table['text']]

        # A single-row table is chunked using that row as both header and data.
        i_start = 1 if len(table_text)>1 else 0
        headers = table_text[0]
        while i_start<len(table_text):
            tmp_table = []
            # Column 0 of the header and of the current row always count.
            tmp_size = len(headers[0])+len(table_text[i_start][0])
            j_start = 1 if len(table_text[i_start])>1 else 0
            i_end = i_start+1
            while j_start<len(table_text[i_start]):
                # Take as many consecutive columns as fit into max_row.
                j_end = j_start
                for j in range(j_start, len(table_text[i_start])):
                    if tmp_size+len(headers[j])+len(table_text[i_start][j])<=max_row:
                        j_end += 1
                        tmp_size += len(headers[j])+len(table_text[i_start][j])
                    else:
                        break
                if j_start!=0:
                    tmp_headers = [headers[0]]+headers[j_start:j_end]
                    tmp_table.append([table_text[i_start][0]]+table_text[i_start][j_start:j_end])
                else:
                    tmp_headers = headers[j_start:j_end]
                    tmp_table.append(table_text[i_start][j_start:j_end])
                # Only when the whole row fit (all columns from 1 on) are the
                # following rows added, as long as max_table is not exceeded.
                if j_start==1 and j_end==len(table_text[i_start]):
                    for row in table_text[i_start+1:]:
                        row_size = sum([len(cell) for cell in row])
                        if row_size+tmp_size<=max_table:
                            tmp_table.append(row)
                            i_end += 1
                            tmp_size += row_size
                        else:
                            break
                tmp_headers = [' '.join(cell) for cell in tmp_headers]
                tmp_table = [[' '.join(cell) for cell in row] for row in tmp_table]
                # i_map / j_map translate chunk coordinates back to coordinates
                # in the full table (chunk row 0 is always table row 0).
                if j_start!=0:
                    i_map = {0:0}
                    j_map = {0:0}
                    i_map.update({i0:i1 for i0,i1 in zip(range(1,1+i_end-i_start),range(i_start, i_end))})
                    j_map.update({j0:j1 for j0,j1 in zip(range(1,1+j_end-j_start),range(j_start, j_end))})
                else:
                    i_map = {0:0}
                    j_map = {0:0}
                    i_map.update({i0:i1 for i0,i1 in zip(range(1,1+i_end-i_start),range(i_start, i_end))})
                    j_map.update({j0:j1 for j0,j1 in zip(range(0,j_end),range(j_start, j_end))})
                if answer!='':
                    cell_index = find_answer_cell(answer, [tmp_headers]+tmp_table)
                else:
                    cell_index = None
                if cell_index is None:
                    # Answer not present in this chunk.
                    cell_index = [-1, -1]
                    cell_span = [-1,-1]
                else:
                    # Character offset of the answer inside the chosen cell;
                    # row 0 of the chunk is the header row.
                    if cell_index[0] == 0:
                        cell_start = re.search(get_answer_pattern(answer),tmp_headers[cell_index[1]]).start()
                    else:
                        cell_start = re.search(get_answer_pattern(answer),tmp_table[cell_index[0]-1][cell_index[1]]).start()
                    cell_span = [cell_start,cell_start+len(answer)]
                yield {
                    'table':{
                        'idx': tid,
                        'data': [tmp_headers]+tmp_table,
                        'index':[[table['index'][i_map[i]][j_map[j]] for j in range(len(j_map))] for i in range(len(i_map))],
                        'values':[[table['values'][i_map[i]][j_map[j]] for j in range(len(j_map))] for i in range(len(i_map))],
                        'value_ranks':[[table['value_ranks'][i_map[i]][j_map[j]] for j in range(len(j_map))] for i in range(len(i_map))],
                        'value_inv_ranks':[[table['value_inv_ranks'][i_map[i]][j_map[j]] for j in range(len(j_map))] for i in range(len(i_map))],
                    },
                    'answer': {
                        'text': answer if cell_index[0]!=-1 else '',
                        'index': cell_index,
                        'span': cell_span
                    },
                }
                # Reset the budget for the next column window of this row.
                tmp_size = len(headers[0])+len(table_text[i_start][0])
                tmp_table = []
                if j_end == len(table_text[i_start]):
                    break
                else:
                    # Advance the column window by at most 4 columns, so
                    # consecutive windows may overlap.
                    # NOTE(review): if not even one column fits (j_end ==
                    # j_start) the window does not advance and this loop
                    # would not terminate. With the defaults (max_cell=50,
                    # max_row=400) the first column always fits.
                    j_start = min([j_start+4,j_end])
                    i_end = i_start+1
            i_start = i_end

    processed_examples = []
    missed_examples = []
    for qid, sample in enumerate(tqdm(all_examples)):
        qid = f'{split}-{qid}'
        used_t = set()  # ids of the tables already chunked for this sample
        all_answers = [re.sub(' +',' ',tags_to_remove.sub('', a[1])) for a in sample['a in table']]
        found_answer = False
        question = sample['q']
        question_values = get_values(question)
        for a in sample['a in table']:
            # A table is chunked again for a later answer only while no
            # answer has been located yet.
            if a[0] in used_t and found_answer:
                continue
            table = sample['t'][a[0]]
            if len(table)==0:
                continue
            # Strip list markup and collapse runs of spaces in every cell.
            table = [[re.sub(' +',' ',tags_to_remove.sub('', cell)) for cell in row] for row in table]
            table = process_table(table)
            used_t.add(a[0])
            answer = re.sub(' +',' ',tags_to_remove.sub('', a[1]))
            for processed_example in chunk_table(table, answer, a[0]):
                # chunk_table is a generator and never yields None, so this
                # check does not fire.
                if processed_example is None:
                    break
                processed_example.update({
                    'qid': qid,
                    'question': question,
                    'question_values': question_values,
                    'all_answers': all_answers
                })
                processed_examples.append(processed_example)
                if processed_example['answer']['text']!='':
                    found_answer = True
        if not found_answer:
            missed_examples.append([qid, sample])
        # Chunk the remaining tables of the sample with an empty answer.
        for tid, table in enumerate(sample['t']):
            if tid in used_t:
                continue
            table = [[re.sub(' +',' ',tags_to_remove.sub('', cell)) for cell in row] for row in table]
            if len(table)==0:
                continue
            table = process_table(table)
            for processed_example in chunk_table(table, '', tid):
                if processed_example is None:
                    break
                processed_example.update({
                    'qid': qid,
                    'question': question,
                    'question_values': question_values,
                    'all_answers': all_answers
                })
                processed_examples.append(processed_example)
    return processed_examples, missed_examples

In [None]:
# Build the chunked Table QA examples for each split, together with the
# samples whose answer could not be located in any chunk.
(
    train_prepared_examples_row_first,
    train_missed_examples_row_first,
) = prepare_examples(train_examples, split='train')
(
    dev_prepared_examples_row_first,
    dev_missed_examples_row_first,
) = prepare_examples(dev_examples, split='dev')