In [1]:
import pandas as pd
import re
import multiprocessing

In [2]:
# Read the ground truth files for SOTAB (v1) -- disabled; the SOTABv2 files in the next cell are used instead
# cta_train_gt = pd.read_csv('data/CTA/CTA_training_gt.csv')
# cta_val_gt = pd.read_csv('data/CTA/CTA_validation_gt.csv')
# cta_test_gt = pd.read_csv('data/CTA/CTA_test_gt.csv')

In [2]:
# Load the SOTABv2 CTA ground-truth tables, one CSV per split
# (training, validation, test), in that order.
cta_train_gt, cta_val_gt, cta_test_gt = map(
    pd.read_csv,
    (
        'SOTAB-v2/CTA/sotab_v2_cta_training_set.csv',
        'SOTAB-v2/CTA/sotab_v2_cta_validation_set.csv',
        'SOTAB-v2/CTA/sotab_v2_cta_test_set.csv',
    ),
)

In [4]:
def _labels_by_table(gt_frame):
    """Nest a ground-truth frame as {table_name: {column_index: label}}.

    Args:
        gt_frame: DataFrame with 'table_name', 'column_index' and 'label'
            columns (one row per annotated column).

    Returns:
        dict mapping each table name to a dict of column index -> label.
        If a (table, column) pair occurs more than once, the last row wins.
    """
    labels = {}
    # zip over the three columns instead of iterrows(): same values, but no
    # per-row Series is constructed.
    for table_name, column_index, label in zip(
        gt_frame['table_name'], gt_frame['column_index'], gt_frame['label']
    ):
        labels.setdefault(table_name, {})[column_index] = label
    return labels


# Ground-truth lookup for every split: gt[split][table_name][column_index] -> label
gt = {
    'train': _labels_by_table(cta_train_gt),
    'val': _labels_by_table(cta_val_gt),
    'test': _labels_by_table(cta_test_gt),
}

In [5]:
# Encode every annotated column as a 'table_name|column_index|label' key,
# one list of keys per split (train, validation, test).
cta_train_cols, cta_val_cols, cta_test_cols = (
    (
        split_gt['table_name']
        + '|'
        + split_gt['column_index'].map(str)
        + '|'
        + split_gt['label']
    ).tolist()
    for split_gt in (cta_train_gt, cta_val_gt, cta_test_gt)
)

In [9]:
# Label vocabulary, in order of first appearance in the validation ground truth.
# NOTE(review): only the validation split is consulted; a label that occurs
# solely in train or test would make type_labels.index(label) raise ValueError
# in get_table_column -- confirm the validation split covers every label.
type_labels = cta_val_gt['label'].unique().tolist()
print(len(type_labels))

In [6]:
#Simple Preprocessing

def clean_text(text):
    """Flatten a cell value into a single whitespace-normalised string.

    Dicts contribute their values (keys are ignored) and lists their items;
    both are cleaned recursively and joined with single spaces. A missing
    scalar (anything pd.isnull reports as null) becomes ''. Everything else
    is converted with str(), runs of spaces are collapsed to one space, and
    leading/trailing whitespace is stripped.
    """
    # Containers: clean each member, then glue the pieces together.
    if isinstance(text, dict):
        text = ' '.join(clean_text(value) for value in text.values())
    elif isinstance(text, list):
        text = ' '.join(clean_text(item) for item in text)

    # Missing values carry no text.
    if pd.isnull(text):
        return ''

    # Only runs of the space character are collapsed; tabs/newlines inside
    # the text are left untouched.
    collapsed = re.sub(' +', ' ', str(text))
    return collapsed.strip()

In [7]:
# Prepare format of input datasets for Doduo models: table_id, [labels], data, label_ids, column_index
def get_table_column(column):
    """Build one Doduo input record for a single annotated column.

    Args:
        column: key of the form 'table_name|column_index|label' (the format
            of the entries in cta_train_cols / cta_val_cols / cta_test_cols).

    Returns:
        list: [table_id, [label], data, label_ids, column_index] where
            data is the column's cell values flattened by clean_text,
            label_ids is a one-hot list over type_labels, and
            column_index is the index exactly as it appears in the key
            (a string).

    Raises:
        ValueError: if the key does not split into exactly three fields, or
            if the label is not present in type_labels.
    """
    file_name, column_index, label = column.split('|')

    # Decide which split directory holds the table. gt[split] is keyed by
    # table name, so membership is a hash lookup; previously every call
    # rebuilt a list of all table names and scanned it linearly.
    if file_name in gt['train']:
        path = 'SOTAB-v2/CTA/Train/' + file_name  # Path for train tables
    elif file_name in gt['val']:
        path = 'SOTAB-v2/CTA/Validation/' + file_name  # Path for validation tables
    else:
        path = 'SOTAB-v2/CTA/Test/' + file_name  # Path for test tables

    # NOTE: the whole table file is parsed for each of its annotated columns.
    df = pd.read_json(path, compression='gzip', lines=True)

    # One-hot encode the label over the global label vocabulary.
    y = [0] * len(type_labels)
    y[type_labels.index(label)] = 1

    return [
        file_name,  # table_id
        [label],  # [labels]
        clean_text(df.iloc[:, int(column_index)].tolist()),  # data
        y,  # label_ids
        column_index,
    ]

In [10]:
# Build the records for every split with a pool of 20 worker processes.
# The context manager guarantees the workers are torn down even when one of
# the map() calls raises; on the success path the pool is still closed and
# joined explicitly so the workers exit gracefully first.
with multiprocessing.Pool(processes=20) as pool:
    train_result = pool.map(get_table_column, cta_train_cols)
    val_result = pool.map(get_table_column, cta_val_cols)
    test_result = pool.map(get_table_column, cta_test_cols)
    pool.close()
    pool.join()

In [12]:
# Wrap each split's records in a DataFrame. Note that the validation records
# are stored under the key 'dev'.
cta = {
    split_name: pd.DataFrame(
        split_rows,
        columns=['table_id', 'labels', 'data', 'label_ids', 'column_index'],
    )
    for split_name, split_rows in (
        ('train', train_result),
        ('dev', val_result),
        ('test', test_result),
    )
}

In [None]:
# Reuse the 'mlb' entry stored in one of the DODUO-provided dataset pickles.
# NOTE(review): the file read here is the TURL relation-extraction pickle, so
# its 'mlb' was presumably not built from type_labels -- confirm that the
# downstream code does not depend on its contents.
# pickle.load can execute arbitrary code: only open files from a trusted source.
import pickle

with open('data/turl-datasets/table_rel_extraction_serialized.pkl', mode="rb") as f:
    train = pickle.load(f)

cta.update(mlb=train['mlb'])

In [None]:
# Show the assembled test split (bare last expression, so the notebook renders it).
cta['test']

In [29]:
import pickle

# Serialise the whole cta dict (the split DataFrames plus any extra entries
# added above) to disk. The context manager closes the file even if
# pickle.dump raises part-way through.
file_name = 'data/sotabv2/table_col_type_serialized.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(cta, f)