In [1]:
import argparse
from collections import OrderedDict
import logging
import os
import time
from torch.multiprocessing import Pool, Process, set_start_method

import numpy as np
import pandas as pd
import torch
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm

from data_preprocess2 import split_overflow_table,fillup_table_w2i,remove_empty_cols,remove_empty_rows,tokenize_str
from utils import Config, loadpkl, make_dirs, mp, flatten_1_deg,pool_fn,savepkl,print_tableIDs
import models
from dataset import T2VDataset
from TREC_score import ndcg_pipeline

In [2]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--path",
                        help="path for the scores")
    parser.add_argument("-m", "--model_name")
    return parser.parse_args()

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# Prepare the df

In [3]:
class TREC_data_prep():
    def __init__(self, config, vocab):
        self.config = config
        self.vocab = vocab

    def convert2table(self, inp, typ):
        if typ == 'table':
            if np.array(inp).size != 0:
                inp = remove_empty_cols(remove_empty_rows(inp))
            if np.array(inp).size == 0:
#             if len(inp) == 0:
                inp = [[['<PAD>']]]
        if typ == 'query':
            inp = [[[j] for j in inp]]
        return inp
    
    def prepare_data(self, df, typ):
        pad_id = self.vocab.index('<PAD>')
        if typ == 'table':
            df[f"{typ}_pad"] = df[f"{typ}_tkn"].apply(lambda x: self.convert2table(eval(x),typ))
            df[f"{typ}_pad"] = df[f"{typ}_pad"].apply(lambda x: split_overflow_table(x))
            for index, row in df.iterrows():
                inp = row[f"{typ}_pad"]
                inp = fillup_table_w2i(self.vocab,inp)
                inp = [T2VDataset.pad_table(self.config['table_prep_params'], inp[i], pad_id) for i in range(len(inp))]
                inp = np.array(inp)
                print(index,inp.shape)
                rows2add = [row] * inp.shape[0]
                for i in range(inp.shape[0]):
                    rows2add[i][f"{typ}_pad"] = inp[i]
                df = df.drop(index)
                df = pd.concat([df, pd.DataFrame(rows2add)])
        elif typ == 'query':
            tables = {t:t for t in list(set(df[f"{typ}_tkn"]))}
            print(tables)
            for i in tables.keys():
                t = self.convert2table(eval(tables[i]),typ)
                t = fillup_table_w2i(self.vocab,[t])
                tables[i] = np.array(T2VDataset.pad_table(config['table_prep_params'], t[0], pad_id))
            df[f"{typ}_pad"] = df[f"{typ}_tkn"].apply(lambda x: tables[x])
        print(df.shape)
        return df
    
    def pipeline(self, baseline_f):
        baseline_f = self.prepare_data(baseline_f, 'table')
        baseline_f = self.prepare_data(baseline_f, 'query')
        return baseline_f

In [4]:
# args = get_args()
args = dotdict({"path":'output/6_29_13_36_22', 'model_name':'model_40.pt'})

config = Config()
config.load(os.path.join(args.path, 'config.toml'))

In [5]:
vocab = loadpkl(config['input_files']['vocab_path'])
device = torch.device(f"cuda:{config['gpu']}")

# vocab = loadpkl('/home/vibhav_student/2Dcnn/data/w_all_data/vocab_5-15_unk_qfix.pkl')
# device = torch.device('cuda:3')

In [109]:
baseline_f = pd.read_csv('/home/vibhav_student/2Dcnn/data/w_all_data/baseline_f_tq-tkn.csv')
baseline_f = TREC_data_prep(config, vocab).pipeline(baseline_f.iloc[:,:])
baseline_f.reset_index(drop=True,inplace=True) 
# baseline_f.to_csv('/home/vibhav_student/2Dcnn/data/w_all_data/baseline_f_tq-tkn-pad.csv',index=False)
savepkl('/home/vibhav_student/2Dcnn/data/w_all_data/baseline_f_tq-tkn-pad_qfix.csv',baseline_f)

0 (1, 15, 5, 1)
1 (1, 15, 5, 1)
2 (1, 15, 5, 1)
3 (1, 15, 5, 1)
4 (1, 15, 5, 1)
5 (1, 15, 5, 1)
6 (1, 15, 5, 1)
7 (5, 15, 5, 1)
8 (1, 15, 5, 1)
9 (1, 15, 5, 1)
10 (1, 15, 5, 1)
11 (2, 15, 5, 1)
12 (2, 15, 5, 1)
13 (1, 15, 5, 1)
14 (1, 15, 5, 1)
15 (2, 15, 5, 1)
16 (1, 15, 5, 1)
17 (1, 15, 5, 1)
18 (1, 15, 5, 1)
19 (2, 15, 5, 1)
20 (1, 15, 5, 1)
21 (2, 15, 5, 1)
22 (1, 15, 5, 1)
23 (1, 15, 5, 1)
24 (1, 15, 5, 1)
25 (1, 15, 5, 1)
26 (1, 15, 5, 1)
27 (1, 15, 5, 1)
28 (1, 15, 5, 1)
29 (1, 15, 5, 1)
30 (1, 15, 5, 1)
31 (1, 15, 5, 1)
32 (1, 15, 5, 1)
33 (1, 15, 5, 1)
34 (1, 15, 5, 1)
35 (7, 15, 5, 1)
36 (1, 15, 5, 1)
37 (3, 15, 5, 1)
38 (2, 15, 5, 1)
39 (1, 15, 5, 1)
40 (5, 15, 5, 1)
41 (1, 15, 5, 1)
42 (2, 15, 5, 1)
43 (3, 15, 5, 1)
44 (1, 15, 5, 1)
45 (1, 15, 5, 1)
46 (1, 15, 5, 1)
47 (4, 15, 5, 1)
48 (1, 15, 5, 1)
49 (1, 15, 5, 1)
50 (1, 15, 5, 1)
51 (1, 15, 5, 1)
52 (1, 15, 5, 1)
53 (16, 15, 5, 1)
54 (1, 15, 5, 1)
55 (2, 15, 5, 1)
56 (1, 15, 5, 1)
57 (1, 15, 5, 1)
58 (1, 15, 5, 1)
59 (1,

461 (1, 15, 5, 1)
462 (2, 15, 5, 1)
463 (1, 15, 5, 1)
464 (4, 15, 5, 1)
465 (1, 15, 5, 1)
466 (1, 15, 5, 1)
467 (1, 15, 5, 1)
468 (2, 15, 5, 1)
469 (2, 15, 5, 1)
470 (4, 15, 5, 1)
471 (4, 15, 5, 1)
472 (1, 15, 5, 1)
473 (1, 15, 5, 1)
474 (7, 15, 5, 1)
475 (1, 15, 5, 1)
476 (3, 15, 5, 1)
477 (24, 15, 5, 1)
478 (24, 15, 5, 1)
479 (13, 15, 5, 1)
480 (2, 15, 5, 1)
481 (1, 15, 5, 1)
482 (1, 15, 5, 1)
483 (1, 15, 5, 1)
484 (2, 15, 5, 1)
485 (1, 15, 5, 1)
486 (1, 15, 5, 1)
487 (2, 15, 5, 1)
488 (1, 15, 5, 1)
489 (1, 15, 5, 1)
490 (1, 15, 5, 1)
491 (2, 15, 5, 1)
492 (3, 15, 5, 1)
493 (2, 15, 5, 1)
494 (3, 15, 5, 1)
495 (1, 15, 5, 1)
496 (2, 15, 5, 1)
497 (1, 15, 5, 1)
498 (1, 15, 5, 1)
499 (1, 15, 5, 1)
500 (3, 15, 5, 1)
501 (3, 15, 5, 1)
502 (1, 15, 5, 1)
503 (3, 15, 5, 1)
504 (2, 15, 5, 1)
505 (1, 15, 5, 1)
506 (1, 15, 5, 1)
507 (3, 15, 5, 1)
508 (1, 15, 5, 1)
509 (3, 15, 5, 1)
510 (4, 15, 5, 1)
511 (1, 15, 5, 1)
512 (8, 15, 5, 1)
513 (1, 15, 5, 1)
514 (1, 15, 5, 1)
515 (8, 15, 5, 1)
516 (2,

916 (1, 15, 5, 1)
917 (1, 15, 5, 1)
918 (1, 15, 5, 1)
919 (1, 15, 5, 1)
920 (1, 15, 5, 1)
921 (1, 15, 5, 1)
922 (1, 15, 5, 1)
923 (2, 15, 5, 1)
924 (1, 15, 5, 1)
925 (2, 15, 5, 1)
926 (2, 15, 5, 1)
927 (2, 15, 5, 1)
928 (2, 15, 5, 1)
929 (1, 15, 5, 1)
930 (2, 15, 5, 1)
931 (1, 15, 5, 1)
932 (16, 15, 5, 1)
933 (1, 15, 5, 1)
934 (1, 15, 5, 1)
935 (1, 15, 5, 1)
936 (4, 15, 5, 1)
937 (1, 15, 5, 1)
938 (2, 15, 5, 1)
939 (20, 15, 5, 1)
940 (2, 15, 5, 1)
941 (2, 15, 5, 1)
942 (2, 15, 5, 1)
943 (1, 15, 5, 1)
944 (1, 15, 5, 1)
945 (1, 15, 5, 1)
946 (1, 15, 5, 1)
947 (12, 15, 5, 1)
948 (2, 15, 5, 1)
949 (12, 15, 5, 1)
950 (1, 15, 5, 1)
951 (1, 15, 5, 1)
952 (4, 15, 5, 1)
953 (1, 15, 5, 1)
954 (1, 15, 5, 1)
955 (1, 15, 5, 1)
956 (1, 15, 5, 1)
957 (1, 15, 5, 1)
958 (1, 15, 5, 1)
959 (1, 15, 5, 1)
960 (4, 15, 5, 1)
961 (1, 15, 5, 1)
962 (1, 15, 5, 1)
963 (2, 15, 5, 1)
964 (2, 15, 5, 1)
965 (1, 15, 5, 1)
966 (8, 15, 5, 1)
967 (3, 15, 5, 1)
968 (1, 15, 5, 1)
969 (10, 15, 5, 1)
970 (2, 15, 5, 1)
971 (

1352 (2, 15, 5, 1)
1353 (1, 15, 5, 1)
1354 (10, 15, 5, 1)
1355 (2, 15, 5, 1)
1356 (3, 15, 5, 1)
1357 (2, 15, 5, 1)
1358 (1, 15, 5, 1)
1359 (1, 15, 5, 1)
1360 (1, 15, 5, 1)
1361 (1, 15, 5, 1)
1362 (1, 15, 5, 1)
1363 (1, 15, 5, 1)
1364 (1, 15, 5, 1)
1365 (1, 15, 5, 1)
1366 (1, 15, 5, 1)
1367 (1, 15, 5, 1)
1368 (1, 15, 5, 1)
1369 (1, 15, 5, 1)
1370 (1, 15, 5, 1)
1371 (1, 15, 5, 1)
1372 (1, 15, 5, 1)
1373 (2, 15, 5, 1)
1374 (2, 15, 5, 1)
1375 (2, 15, 5, 1)
1376 (4, 15, 5, 1)
1377 (2, 15, 5, 1)
1378 (2, 15, 5, 1)
1379 (3, 15, 5, 1)
1380 (4, 15, 5, 1)
1381 (2, 15, 5, 1)
1382 (1, 15, 5, 1)
1383 (2, 15, 5, 1)
1384 (4, 15, 5, 1)
1385 (2, 15, 5, 1)
1386 (3, 15, 5, 1)
1387 (2, 15, 5, 1)
1388 (5, 15, 5, 1)
1389 (2, 15, 5, 1)
1390 (2, 15, 5, 1)
1391 (2, 15, 5, 1)
1392 (2, 15, 5, 1)
1393 (3, 15, 5, 1)
1394 (2, 15, 5, 1)
1395 (1, 15, 5, 1)
1396 (2, 15, 5, 1)
1397 (1, 15, 5, 1)
1398 (1, 15, 5, 1)
1399 (1, 15, 5, 1)
1400 (1, 15, 5, 1)
1401 (1, 15, 5, 1)
1402 (2, 15, 5, 1)
1403 (1, 15, 5, 1)
1404 (1, 15

1783 (1, 15, 5, 1)
1784 (1, 15, 5, 1)
1785 (1, 15, 5, 1)
1786 (1, 15, 5, 1)
1787 (1, 15, 5, 1)
1788 (1, 15, 5, 1)
1789 (1, 15, 5, 1)
1790 (2, 15, 5, 1)
1791 (2, 15, 5, 1)
1792 (1, 15, 5, 1)
1793 (1, 15, 5, 1)
1794 (1, 15, 5, 1)
1795 (1, 15, 5, 1)
1796 (1, 15, 5, 1)
1797 (1, 15, 5, 1)
1798 (2, 15, 5, 1)
1799 (1, 15, 5, 1)
1800 (4, 15, 5, 1)
1801 (1, 15, 5, 1)
1802 (4, 15, 5, 1)
1803 (2, 15, 5, 1)
1804 (1, 15, 5, 1)
1805 (1, 15, 5, 1)
1806 (2, 15, 5, 1)
1807 (1, 15, 5, 1)
1808 (1, 15, 5, 1)
1809 (1, 15, 5, 1)
1810 (1, 15, 5, 1)
1811 (1, 15, 5, 1)
1812 (1, 15, 5, 1)
1813 (1, 15, 5, 1)
1814 (1, 15, 5, 1)
1815 (1, 15, 5, 1)
1816 (1, 15, 5, 1)
1817 (1, 15, 5, 1)
1818 (1, 15, 5, 1)
1819 (1, 15, 5, 1)
1820 (1, 15, 5, 1)
1821 (1, 15, 5, 1)
1822 (1, 15, 5, 1)
1823 (1, 15, 5, 1)
1824 (2, 15, 5, 1)
1825 (1, 15, 5, 1)
1826 (1, 15, 5, 1)
1827 (1, 15, 5, 1)
1828 (1, 15, 5, 1)
1829 (1, 15, 5, 1)
1830 (1, 15, 5, 1)
1831 (1, 15, 5, 1)
1832 (1, 15, 5, 1)
1833 (1, 15, 5, 1)
1834 (1, 15, 5, 1)
1835 (1, 15,

2215 (1, 15, 5, 1)
2216 (1, 15, 5, 1)
2217 (1, 15, 5, 1)
2218 (1, 15, 5, 1)
2219 (1, 15, 5, 1)
2220 (1, 15, 5, 1)
2221 (1, 15, 5, 1)
2222 (1, 15, 5, 1)
2223 (1, 15, 5, 1)
2224 (1, 15, 5, 1)
2225 (1, 15, 5, 1)
2226 (1, 15, 5, 1)
2227 (1, 15, 5, 1)
2228 (1, 15, 5, 1)
2229 (1, 15, 5, 1)
2230 (1, 15, 5, 1)
2231 (1, 15, 5, 1)
2232 (1, 15, 5, 1)
2233 (1, 15, 5, 1)
2234 (8, 15, 5, 1)
2235 (1, 15, 5, 1)
2236 (1, 15, 5, 1)
2237 (1, 15, 5, 1)
2238 (1, 15, 5, 1)
2239 (1, 15, 5, 1)
2240 (1, 15, 5, 1)
2241 (1, 15, 5, 1)
2242 (1, 15, 5, 1)
2243 (1, 15, 5, 1)
2244 (1, 15, 5, 1)
2245 (1, 15, 5, 1)
2246 (2, 15, 5, 1)
2247 (1, 15, 5, 1)
2248 (1, 15, 5, 1)
2249 (1, 15, 5, 1)
2250 (1, 15, 5, 1)
2251 (2, 15, 5, 1)
2252 (1, 15, 5, 1)
2253 (1, 15, 5, 1)
2254 (2, 15, 5, 1)
2255 (4, 15, 5, 1)
2256 (1, 15, 5, 1)
2257 (1, 15, 5, 1)
2258 (2, 15, 5, 1)
2259 (1, 15, 5, 1)
2260 (2, 15, 5, 1)
2261 (1, 15, 5, 1)
2262 (1, 15, 5, 1)
2263 (1, 15, 5, 1)
2264 (2, 15, 5, 1)
2265 (2, 15, 5, 1)
2266 (1, 15, 5, 1)
2267 (1, 15,

2646 (2, 15, 5, 1)
2647 (1, 15, 5, 1)
2648 (1, 15, 5, 1)
2649 (1, 15, 5, 1)
2650 (1, 15, 5, 1)
2651 (1, 15, 5, 1)
2652 (1, 15, 5, 1)
2653 (1, 15, 5, 1)
2654 (1, 15, 5, 1)
2655 (1, 15, 5, 1)
2656 (1, 15, 5, 1)
2657 (1, 15, 5, 1)
2658 (1, 15, 5, 1)
2659 (1, 15, 5, 1)
2660 (1, 15, 5, 1)
2661 (1, 15, 5, 1)
2662 (1, 15, 5, 1)
2663 (6, 15, 5, 1)
2664 (2, 15, 5, 1)
2665 (2, 15, 5, 1)
2666 (2, 15, 5, 1)
2667 (5, 15, 5, 1)
2668 (2, 15, 5, 1)
2669 (4, 15, 5, 1)
2670 (2, 15, 5, 1)
2671 (3, 15, 5, 1)
2672 (2, 15, 5, 1)
2673 (2, 15, 5, 1)
2674 (1, 15, 5, 1)
2675 (2, 15, 5, 1)
2676 (1, 15, 5, 1)
2677 (1, 15, 5, 1)
2678 (1, 15, 5, 1)
2679 (6, 15, 5, 1)
2680 (4, 15, 5, 1)
2681 (8, 15, 5, 1)
2682 (3, 15, 5, 1)
2683 (2, 15, 5, 1)
2684 (2, 15, 5, 1)
2685 (1, 15, 5, 1)
2686 (2, 15, 5, 1)
2687 (6, 15, 5, 1)
2688 (2, 15, 5, 1)
2689 (2, 15, 5, 1)
2690 (1, 15, 5, 1)
2691 (2, 15, 5, 1)
2692 (12, 15, 5, 1)
2693 (2, 15, 5, 1)
2694 (3, 15, 5, 1)
2695 (1, 15, 5, 1)
2696 (3, 15, 5, 1)
2697 (1, 15, 5, 1)
2698 (1, 15

3077 (1, 15, 5, 1)
3078 (1, 15, 5, 1)
3079 (1, 15, 5, 1)
3080 (1, 15, 5, 1)
3081 (1, 15, 5, 1)
3082 (1, 15, 5, 1)
3083 (1, 15, 5, 1)
3084 (1, 15, 5, 1)
3085 (2, 15, 5, 1)
3086 (1, 15, 5, 1)
3087 (1, 15, 5, 1)
3088 (1, 15, 5, 1)
3089 (1, 15, 5, 1)
3090 (1, 15, 5, 1)
3091 (1, 15, 5, 1)
3092 (1, 15, 5, 1)
3093 (1, 15, 5, 1)
3094 (1, 15, 5, 1)
3095 (1, 15, 5, 1)
3096 (2, 15, 5, 1)
3097 (2, 15, 5, 1)
3098 (2, 15, 5, 1)
3099 (2, 15, 5, 1)
3100 (2, 15, 5, 1)
3101 (2, 15, 5, 1)
3102 (2, 15, 5, 1)
3103 (2, 15, 5, 1)
3104 (2, 15, 5, 1)
3105 (2, 15, 5, 1)
3106 (2, 15, 5, 1)
3107 (2, 15, 5, 1)
3108 (2, 15, 5, 1)
3109 (2, 15, 5, 1)
3110 (2, 15, 5, 1)
3111 (1, 15, 5, 1)
3112 (2, 15, 5, 1)
3113 (3, 15, 5, 1)
3114 (1, 15, 5, 1)
3115 (1, 15, 5, 1)
3116 (3, 15, 5, 1)
3117 (2, 15, 5, 1)
3118 (2, 15, 5, 1)
3119 (3, 15, 5, 1)
(5596, 46)
{"['running_shoes']": "['running_shoes']", "['infections_treatment']": "['infections_treatment']", "['company_income_statements']": "['company_income_statements']", "['dise

In [60]:
# print_tableIDs(baseline_f['table_pad'][5],vocab)

[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]
[['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>'], ['<PAD>']]


# Prepare the model and features

In [6]:
# baseline_f = pd.read_csv('/home/vibhav_student/2Dcnn/data/w_all_data/baseline_f_tq-tkn-pad.csv')
# baseline_f['table_pad'] = baseline_f['table_pad'].apply(eval)
# baseline_f['query_pad'] = baseline_f['query_pad'].apply(eval)
baseline_f = loadpkl('/home/vibhav_student/2Dcnn/data/w_all_data/baseline_f_tq-tkn-pad_qfix.csv')

In [7]:
model = models.create_model(
    config['model_props']['type'],
    params=(len(vocab), config['model_params']['embedding_dim'])
)
model.to(device)
state_dict = torch.load(os.path.join(args.path, args.model_name))
model.linear_layers = torch.nn.Sequential(*(list(model.linear_layers.children())[:2]))
state_dict_ = OrderedDict({i: state_dict[i] for i in list(model.state_dict())})
model.load_state_dict(state_dict_)
print(torch.equal(list(model.parameters())[0], state_dict_['embeddings.weight']))
# print(model)
model.eval()

True


Table2Vec(
  (embeddings): Embedding(2533532, 100)
  (cnn_layers): Sequential(
    (0): Conv2d(100, 128, kernel_size=(3, 2), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear_layers): Sequential(
    (0): Linear(in_features=1536, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

In [8]:
def cos_sim(table, query):
        sim = cosine_similarity(np.array(table).reshape(1, -1), np.array(query).reshape(1, -1))
        return np.array(sim).reshape(-1)

def get_ft(inp, device):
    model_inp = torch.tensor(inp.tolist()).to(device)
    return model(model_inp).cpu().detach().numpy().tolist()

In [9]:
baseline_f['table_ft'] = get_ft(baseline_f['table_pad'],device)
baseline_f['query_ft'] = get_ft(baseline_f['query_pad'],device)
baseline_f['cos_sim'] = baseline_f.apply(lambda x: cos_sim(x['table_ft'], x['query_ft']), axis=1)

In [10]:
baseline_f.shape

(5596, 50)

In [11]:
baseline_f

Unnamed: 0,query_id,query,table_id,row,col,nul,in_link,out_link,pgcount,tImp,...,resim,query_l,rel,table_tkn,query_tkn,table_pad,query_pad,table_ft,query_ft,cos_sim
0,1,world interest rates Table,table-0875-680,8,2,0,31,21,51438,1.000000,...,0.281130,4,0,"[[['risk'], []], [[], []], [[], []], [[], []],...","['world_interest_rates', 'table']","[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.6285071969032288],[0.5665705800056458],[1.0]
1,1,world interest rates Table,table-1020-619,4,3,0,18,0,324,1.000000,...,0.710250,4,0,"[[['headline'], [], []], [['core'], [], []], [...","['world_interest_rates', 'table']","[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.6257492303848267],[0.5665705800056458],[1.0]
2,1,world interest rates Table,table-0288-531,3,5,0,23,22,26419,0.500000,...,0.033680,4,0,"[[['stocks'], [], [], [], []], [['bonds'], [],...","['world_interest_rates', 'table']","[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.710820734500885],[0.5665705800056458],[1.0]
3,1,world interest rates Table,table-0288-530,4,5,1,23,22,26419,0.500000,...,0.033680,4,0,"[[['stocks'], [], [], [], []], [['bonds'], [],...","['world_interest_rates', 'table']","[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.8392351865768433],[0.5665705800056458],[1.0]
4,1,world interest rates Table,table-1000-57,2,2,0,38,1,2268,1.000000,...,0.279899,4,0,"[[['t_bills'], ['return']], [['15-year_dated_s...","['world_interest_rates', 'table']","[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.3828888237476349],[0.5665705800056458],[1.0]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5591,60,games age,table-0221-887,22,4,0,87,73,4080,0.045455,...,0.031775,2,2,"[[[], ['shot_&_shell'], ['roger_nord'], ['worl...",['games_age'],"[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.9950838685035706],[0.45424747467041016],[1.0]
5592,60,games age,table-0221-887,22,4,0,87,73,4080,0.045455,...,0.031775,2,2,"[[[], ['shot_&_shell'], ['roger_nord'], ['worl...",['games_age'],"[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.9950838685035706],[0.45424747467041016],[1.0]
5593,60,games age,table-0609-879,31,3,0,25,1,579,1.000000,...,0.013236,2,2,"[[['a_tale_the_desert'], [], ['egenesis']], [[...",['games_age'],"[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.47282683849334717],[0.45424747467041016],[1.0]
5594,60,games age,table-0609-879,31,3,0,25,1,579,1.000000,...,0.013236,2,2,"[[['a_tale_the_desert'], [], ['egenesis']], [[...",['games_age'],"[[[2384110], [2384110], [2384110], [2384110], ...","[[[2384110], [2384110], [2384110], [2384110], ...",[0.47282683849334717],[0.45424747467041016],[1.0]


# Run the trec model

In [12]:
class TREC_model():
    def __init__(self, data, output_dir, config):
        self.data = data
        self.config = config
        self.file_path = os.path.join(output_dir, config['trec']['file_name'])
        self.prep_data()
        make_dirs(output_dir)

    def prep_data(self):
        x_bf = ['row', 'col', 'nul', 'in_link', 'out_link', 'pgcount', 'tImp', 'tPF', 'leftColhits', 'SecColhits', 'bodyhits', 'PMI', 'qInPgTitle', 'qInTableTitle', 'yRank', 'csr_score', 'idf1',
                'idf2', 'idf3', 'idf4', 'idf5', 'idf6', 'max', 'sum', 'avg', 'sim', 'emax', 'esum', 'eavg', 'esim', 'cmax', 'csum', 'cavg', 'csim', 'remax', 'resum', 'reavg', 'resim', 'query_l']
        x_smf = ['cos_sim']
        x_f = x_bf
        y_f = ['rel']
        if self.config['trec']['semantic_f']:
            x_f += x_smf

        self.X = self.data[x_f]
        self.y = self.data[y_f]

    def train(self):
        kfold = KFold(5, True, 42)
        for i, indices in enumerate(kfold.split(self.X)):
            train_idx, test_idx = indices
            X_train, X_test, y_train, y_test = self.X.iloc[train_idx], self.X.iloc[
                test_idx], self.y.iloc[train_idx], self.y.iloc[test_idx]
            df = self.makeModel_getdf(X_train, X_test, y_train, y_test)
            df.to_csv(f"{self.file_path}{i}.txt",
                      sep=' ', index=False, header=False)

    def makeModel_getdf(self, X_train, X_test, y_train, y_test):
        self.clf = RandomForestClassifier(
            n_estimators=1000,
            max_features=3,
            random_state=42)
        self.clf.fit(X_train, y_train.values.ravel())
        X_test = mp(X_test, self.score_mp, 20)
#         print(X_test)
        df = self.generate_trec_df(self.generate_filtered_df(X_test, y_test))
        return df

    def score_mp(self, X_test):
        X_test['model_score'] = X_test.apply(
            lambda x: self.getScore(x), axis=1)
        return X_test

    def getScore(self, row):
        arr = self.clf.predict_proba(np.array(row).reshape(1, -1))
        return arr[0][1] + 2 * arr[0][2]

    def generate_filtered_df(self, X, y):
        df = pd.concat([
            self.data.iloc[list(X.index)][['query_id', 'query', 'table_id']],
            X['model_score']
        ], axis=1)
        return df

    def generate_trec_df(self, df):
        l = []
        dic = dict(df.query_id.value_counts())
        for i in dic:
            for j in range(1, dic[i] + 1):
                l.append(j)

        df_temp = pd.DataFrame()
        df_temp['query_id'] = df['query_id']
        df_temp['Q0'] = 'Q0'
        df_temp['table_id'] = df['table_id']
        df_temp['rank'] = l
        df_temp['score'] = df['model_score']
        df_temp['smarttable'] = 'smarttable'
        return df_temp

In [13]:
trec_path = os.path.join(args.path, f'TREC_results_lstlyr-ft_{args.model_name[:-3]}')
trec_model = TREC_model(data=baseline_f, output_dir=trec_path, config=config)
trec_model.train()

In [14]:
def ndcg_pipeline(path, trec_path, query_file_path):
    d = pd.DataFrame()
    for i in range(5):
        d_ = pd.read_csv(f"{path}{i}.txt", sep=' ', header=None)
        d_.drop(columns=[3], inplace=True)
        d = pd.concat([d, d_])

    d_sorted = d.sort_values(by=[0, 4], ascending=[True, False])

    d_sorted_filtered = pd.DataFrame()
    for i in range(1, 61):
        d_temp = d_sorted[d_sorted[0] == i]
        d_temp = d_temp.drop_duplicates(subset=2, keep="last")
        d_temp = d_temp.iloc[:20]
        d_sorted_filtered = pd.concat([d_sorted_filtered, d_temp])

    l = []
    for i in range(60):
        for j in range(1, 21):
            l.append(j)
    d_sorted_filtered[3] = l
    d_sorted_filtered = d_sorted_filtered[list(range(6))]
    d_sorted_filtered.to_csv(
        f'{path}all.txt', sep=' ', index=False, header=False)

    command = os.popen(
        f"{trec_path} -m ndcg_cut.5,10,15,20 {query_file_path} {path}all.txt")
    result = command.read()
    command.close()
    return result

In [15]:
ndcg_score = ndcg_pipeline(
    trec_model.file_path,
    '../trec_eval/trec_eval',
    '../global_data/qrels.txt')

In [16]:
print(ndcg_score)

ndcg_cut_5            	all	0.7698
ndcg_cut_10           	all	0.7646
ndcg_cut_15           	all	0.7768
ndcg_cut_20           	all	0.7987



In [17]:
!cat ./output/6_23_10_10_22/config.toml | grep comment
!../trec_eval/trec_eval -m ndcg_cut.5,10,15,20 ../global_data/qrels.txt ./output/6_23_10_10_22/TREC_results_lstlyr-ft_model_5/LTR_k5_all.txt

comment = "Running 6_12_20_41_11 with lr=3e-4, early stopping and lr scheduler(0.5 reduces), 1CNN+1FC"
ndcg_cut_5            	all	0.7698
ndcg_cut_10           	all	0.7646
ndcg_cut_15           	all	0.7768
ndcg_cut_20           	all	0.7987
