## Generating BERT table embeddings for the similarity-based retrieval of few-shot examples

In [1]:
import sys
import os
import pandas as pd
import numpy as np
import ast
import torch
import json
import itertools
import random

In [2]:
import os
import sys

# Make the project root importable so `from utils import ...` resolves, without
# hard-coding a user-specific absolute path. Assumes the kernel's working
# directory is the notebooks/ folder (the relative "../data" and "../output"
# paths used in this notebook rely on the same assumption).
# Guarded so re-running the cell does not append duplicate sys.path entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
print(repr(sys.path))  # inspect the module search path to confirm the project root is on it

['/Users/z003yzhj/Desktop/Projects/tabner/notebooks', '/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python39.zip', '/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9', '/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/lib-dynload', '', '/Users/z003yzhj/Desktop/Projects/TabNER/venv/lib/python3.9/site-packages', '/Users/z003yzhj/Desktop/Projects/llama', '/Users/z003yzhj/Desktop/Projects/tabner']


In [4]:
from transformers import BertTokenizer, BertModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from utils import get_correct_anno, process_single_table_gpt

In [6]:
# Tokenizer and encoder come from the same pretrained uncased BERT-base checkpoint.
bert_checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)
model = BertModel.from_pretrained(bert_checkpoint)

In [7]:
# One seed for every RNG the notebook imports (random, numpy, torch), so any
# stochastic step is reproducible, not just the train/test split.
seed_nr = 42
random.seed(seed_nr)
np.random.seed(seed_nr)
torch.manual_seed(seed_nr)
# Dedicated generator passed to random_split, independent of the global torch RNG.
generator = torch.Generator().manual_seed(seed_nr)

In [8]:
# Labeled NER tables. NOTE: despite the .csv extension, the file is parsed with
# json.load, so it must contain JSON. An explicit utf-8 encoding avoids
# depending on the platform's locale default.
# Presumably each entry holds the table at position [0] (that is how entries are
# accessed further down) — confirm against the dataset producer.
data_path = "../data/final_NER_labeled_dataset.csv"
with open(data_path, 'r', encoding='utf-8') as f:
    ner_tables = json.load(f)

In [9]:
# Entity classes for NER, mapped to integer label ids 1..7 in this fixed order.
entity_types = [
    'Activity',
    'Organisation',
    'ArchitecturalStructure',
    'Event',
    'Place',
    'Person',
    'Work',
]
labels_dict = {name: label_id for label_id, name in enumerate(entity_types, start=1)}
# Inverse lookup: integer label id -> entity class name.
labels_dict_rev = dict(zip(labels_dict.values(), labels_dict.keys()))
print(labels_dict)

{'Activity': 1, 'Organisation': 2, 'ArchitecturalStructure': 3, 'Event': 4, 'Place': 5, 'Person': 6, 'Work': 7}


### Generate embeddings for all tables and save them as a `.npy` file

In [10]:
# Sanity check on one table before embedding the whole dataset:
# the embedder should return a single 768-dimensional vector per table.
sample_table = ner_tables[84][0]
test = process_single_table_gpt(model, tokenizer, sample_table)
test.shape

torch.Size([768])

In [13]:
# Embed every table with BERT and cache the stacked result on disk. This is by
# far the most expensive cell, so a re-run loads the saved array unless
# RECOMPUTE_EMBEDDINGS is set to True.
EMBEDDINGS_PATH = '../output/bert_all_table_embeddings.npy'
RECOMPUTE_EMBEDDINGS = False

if os.path.exists(EMBEDDINGS_PATH) and not RECOMPUTE_EMBEDDINGS:
    all_table_embeddings = np.load(EMBEDDINGS_PATH)
else:
    # Inference only: disable autograd so no computation graph is kept per table.
    with torch.no_grad():
        table_embeddings = [
            process_single_table_gpt(model, tokenizer, entry[0]) for entry in ner_tables
        ]
    all_table_embeddings = np.array(table_embeddings)
    np.save(EMBEDDINGS_PATH, all_table_embeddings)

# Row i must be the embedding of ner_tables[i]; a stale cache from a different
# dataset would fail here instead of silently misaligning later lookups.
assert all_table_embeddings.shape[0] == len(ner_tables)

In [14]:
np.shape(all_table_embeddings)  # (number of tables, embedding dimension)

(51271, 768)

In [14]:
# Hold out a fixed-size test set; the training split gets every remaining table.
# Deriving the train size from len(ner_tables), instead of hard-coding 49271,
# keeps the split valid if the dataset size changes.
# The generator is re-seeded first so re-running this cell reproduces the same
# split (otherwise its state has already advanced from the previous run).
N_TEST_TABLES = 2000
train_set, test_set = torch.utils.data.random_split(
    ner_tables,
    [len(ner_tables) - N_TEST_TABLES, N_TEST_TABLES],
    generator=generator.manual_seed(seed_nr),
)

In [15]:
# Embedding rows of the training tables, ordered like train_set.indices.
train_embeddings = np.take(all_table_embeddings, train_set.indices, axis=0)

### Find the 5 most similar training tables for each test table, save their indices, and then generate prompt demonstrations

In [17]:
# For every held-out test table, score all training tables against it and keep
# the TOP_K best matches.
# NOTE(review): the score is a raw dot product, not cosine similarity, so
# embedding norms influence the ranking. L2-normalise the rows first if cosine
# similarity is intended.
TOP_K = 5
similar_sets = []
for idx in test_set.indices:
    scores = np.dot(all_table_embeddings[idx], train_embeddings.T)
    # argsort is ascending, so the last TOP_K positions are the best matches
    # (most similar last). The values are positions in train_embeddings /
    # train_set.indices, not indices into ner_tables.
    top_k_train_positions = np.argsort(scores, axis=0)[-TOP_K:]
    similar_sets.append(tuple(top_k_train_positions))

In [20]:
expand = list(itertools.chain.from_iterable(similar_sets))  # all similar-set members, flattened in order

#### Some tables appear in several similar sets, so we take the set of unique table indices across all of them. For each of these tables we extract the ground-truth annotations and prepare them as prompt input, storing them in a dict `{tab_id: [example_rows, example_NER_annotations]}`.

In [21]:
len({*expand})  # number of distinct training tables referenced by the similar sets

4675

In [22]:
# Ground-truth rows and NER annotations for every training table that appears in
# at least one similar set, keyed by its index into ner_tables.
# Many tables occur in several sets, so only unique positions are processed.
# dict.fromkeys keeps first-seen order, which is the order in which a
# per-occurrence loop would first insert each key. This assumes get_correct_anno
# is deterministic, so skipping repeats does not change any stored value.
# NOTE(review): this dict is not written to disk in this notebook, yet
# ../output/similar_tables_examples.json is read afterwards — TODO confirm
# where that file is produced.
similar_tables = {}
for train_pos in dict.fromkeys(expand):
    tab_idx = train_set.indices[train_pos]
    table = ner_tables[tab_idx][0]
    example_rows, example_NER_annotations, _ = get_correct_anno(table, labels_dict_rev)
    similar_tables[tab_idx] = [example_rows, example_NER_annotations]

In [107]:
# Precomputed few-shot examples: JSON object keyed by table index (as a string,
# since JSON object keys are strings). Explicit utf-8 so decoding does not
# depend on the platform locale.
with open("../output/similar_tables_examples.json", "r", encoding="utf-8") as f:
    similar = json.load(f)

In [43]:
list(similar)[15]  # 16th key in file order; JSON keys load as strings

'746'

### For the random selection of examples, we keep only tables with more than 4 distinct entity types. Their indices are then passed to the function `get_examples` to retrieve the example rows and NER annotations.

In [88]:
# Keys of the tables whose annotations cover more than 4 distinct entity types.
# val[1] is the list of annotation dicts; only dicts carrying a "type" key count.
inter = [
    key
    for key, val in similar.items()
    if len({el["type"] for el in val[1] if "type" in el}) > 4
]

In [89]:
list(inter)  # table keys with more than 4 distinct entity types

['25549',
 '5310',
 '25984',
 '34246',
 '7961',
 '36025',
 '3975',
 '1907',
 '1910',
 '1909',
 '14471',
 '9266',
 '27347',
 '27343',
 '49690',
 '6930',
 '10877',
 '7995',
 '13357',
 '13381',
 '8129',
 '46299',
 '37770',
 '42900',
 '6174',
 '15689',
 '38509',
 '33123',
 '26313',
 '35188',
 '44573',
 '34545',
 '49452',
 '43367',
 '45702',
 '45594',
 '45624',
 '33747',
 '34520',
 '5365']