In [76]:
import os
import pandas
from transformers import BertTokenizer, BertModel
import torch

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_type = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_type)

BERT = BertModel.from_pretrained(model_type)
BERT.to(device)
# BERT is only used as a frozen feature extractor here: disable dropout so the
# cell embeddings are deterministic.
BERT.eval()

f = '1-10006830-1.html.csv'

# The table CSVs are '#'-separated. `sep` must be passed by keyword
# (pandas >= 2.0 no longer accepts it positionally).
table = pandas.read_csv('all_csv/{}'.format(f), sep='#')
cols = table.columns

# representation[i, j] holds BERT's pooled output (element 1 of the model
# output) for a short sentence describing the cell at (row i, column j).
representation = torch.empty(len(table), len(cols), BERT.config.hidden_size, device=device)
with torch.no_grad():  # no training happens here; don't keep autograd graphs
    for i in range(len(table)):
        entry = table.iloc[i]
        for j, col in enumerate(cols):
            # `encode` inserts [CLS]/[SEP] itself; a literal '[CLS]' prefix in the
            # text would produce a doubled [CLS] token.
            text = 'in row {}, {} is {}'.format(i + 1, col, entry[col])
            inp = torch.tensor([tokenizer.encode(text, add_special_tokens=True)], device=device)
            representation[i, j] = BERT(inp)[1][0]

In [23]:
# Display the current `table` DataFrame (rich HTML repr in Jupyter).
table

Unnamed: 0,aircraft,description,max gross weight,total disk area,max disk loading
0,robinson r - 22,light utility helicopter,1370 lb (635 kg),497 ft square (46.2 m square),2.6 lb / ft square (14 kg / m square)
1,bell 206b3 jetranger,turboshaft utility helicopter,3200 lb (1451 kg),872 ft square (81.1 m square),3.7 lb / ft square (18 kg / m square)
2,ch - 47d chinook,tandem rotor helicopter,50000 lb (22680 kg),5655 ft square (526 m square),8.8 lb / ft square (43 kg / m square)
3,mil mi - 26,heavy - lift helicopter,123500 lb (56000 kg),8495 ft square (789 m square),14.5 lb / ft square (71 kg / m square)
4,ch - 53e super stallion,heavy - lift helicopter,73500 lb (33300 kg),4900 ft square (460 m square),15 lb / ft square (72 kg / m square)


In [18]:
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import numpy as np

class FC(nn.Module):
    """Fully-connected layer: ``nn.Linear``, then optional ReLU, then optional dropout.

    Args:
        in_size: input feature size.
        out_size: output feature size.
        dropout_r: dropout probability; no dropout module is created when it is <= 0.
        use_relu: whether to apply a ReLU after the linear map.
    """

    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
        super(FC, self).__init__()
        self.dropout_r = dropout_r
        self.use_relu = use_relu

        self.linear = nn.Linear(in_size, out_size)

        # Optional sub-modules are only registered when enabled, which is why
        # forward() re-checks the same flags before using them.
        if use_relu:
            self.relu = nn.ReLU(inplace=False)
        if dropout_r > 0:
            self.dropout = nn.Dropout(dropout_r, inplace=False)

    def forward(self, x):
        h = self.linear(x)
        h = self.relu(h) if self.use_relu else h
        h = self.dropout(h) if self.dropout_r > 0 else h
        return h


class FFN(nn.Module):
    """Position-wise feed-forward sub-layer: hidden -> ff (ReLU, optional dropout) -> hidden.

    Args:
        hidden_size: input and output feature size.
        ff_size: inner (expanded) feature size.
        dropout: dropout probability applied after the inner ReLU.
    """

    def __init__(self, hidden_size, ff_size, dropout):
        super(FFN, self).__init__()

        # Positional order is MLP(in_size, mid_size, out_size, ...).
        self.mlp = MLP(hidden_size, ff_size, hidden_size, dropout_r=dropout, use_relu=True)

    def forward(self, x):
        out = self.mlp(x)
        return out


class MLP(nn.Module):
    """Two-layer perceptron: ``FC`` (in -> mid, optional ReLU/dropout), then a plain linear mid -> out.

    Args:
        in_size: input feature size.
        mid_size: hidden feature size.
        out_size: output feature size.
        dropout_r: dropout probability inside the first (``FC``) layer.
        use_relu: whether the first layer applies a ReLU.
    """

    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
        super(MLP, self).__init__()

        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        hidden = self.fc(x)
        return self.linear(hidden)


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Unlike ``torch.nn.LayerNorm``, this divides by the (unbiased) standard
    deviation plus ``eps`` instead of by ``sqrt(var + eps)``.

    Args:
        size: size of the normalised (last) dimension.
        eps: small constant added to the standard deviation.
    """

    def __init__(self, size, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.eps = eps

        self.a_2 = nn.Parameter(torch.ones(size))   # gain
        self.b_2 = nn.Parameter(torch.zeros(size))  # bias

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2
    
    
class MHAtt(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        head_num: number of attention heads.
        hidden_size: feature size of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
        hidden_size_head: per-head feature size; must satisfy
            ``head_num * hidden_size_head == hidden_size``.

    Raises:
        ValueError: if ``head_num * hidden_size_head != hidden_size``.
    """

    def __init__(self, head_num, hidden_size, dropout, hidden_size_head):
        super(MHAtt, self).__init__()
        # forward() splits heads with view(..., -1, head_num, hidden_size_head).
        # With inconsistent sizes that either raises there or, worse, silently
        # regroups features across sequence positions.
        if head_num * hidden_size_head != hidden_size:
            raise ValueError(
                'head_num * hidden_size_head must equal hidden_size, '
                'got {} * {} != {}'.format(head_num, hidden_size_head, hidden_size)
            )
        self.head_num = head_num
        self.hidden_size = hidden_size
        self.hidden_size_head = hidden_size_head
        self.linear_v = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_merge = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout, inplace=False)

    def _split_heads(self, x, n_batches):
        """Reshape (batch, seq, hidden_size) to (batch, head_num, seq, hidden_size_head)."""
        return x.view(n_batches, -1, self.head_num, self.hidden_size_head).transpose(1, 2)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            v, k, q: (batch, seq, hidden_size) tensors; ``k`` and ``v`` share a
                sequence length.
            mask: None, or a bool tensor broadcastable to
                (batch, head_num, seq_q, seq_k) in which True marks keys to ignore.

        Returns:
            Tensor of shape (batch, seq_q, hidden_size).
        """
        n_batches = q.size(0)

        v = self._split_heads(self.linear_v(v), n_batches)
        k = self._split_heads(self.linear_k(k), n_batches)
        q = self._split_heads(self.linear_q(q), n_batches)

        atted = self.att(v, k, q, mask)
        # Merge heads back: (batch, head_num, seq, head) -> (batch, seq, hidden_size).
        atted = atted.transpose(1, 2).contiguous().view(n_batches, -1, self.hidden_size)

        return self.linear_merge(atted)

    def att(self, value, key, query, mask):
        """Scaled dot-product attention on head-split tensors."""
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Large negative fill so masked keys get ~0 weight after softmax.
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)
    

class SA(nn.Module):
    """Post-norm Transformer encoder layer.

    Self-attention, then a feed-forward sub-layer; each is followed by dropout,
    a residual connection and ``LayerNorm``.

    Args:
        hidden_size: model feature size.
        head_num: number of attention heads.
        ff_size: inner size of the feed-forward sub-layer.
        dropout: dropout probability used throughout.
        hidden_size_head: per-head feature size for ``MHAtt``.
    """

    def __init__(self, hidden_size, head_num, ff_size, dropout, hidden_size_head):
        super(SA, self).__init__()

        self.mhatt = MHAtt(head_num, hidden_size, dropout, hidden_size_head)
        self.ffn = FFN(hidden_size, ff_size, dropout)

        self.dropout1 = nn.Dropout(dropout, inplace=False)
        self.norm1 = LayerNorm(hidden_size)

        self.dropout2 = nn.Dropout(dropout, inplace=False)
        self.norm2 = LayerNorm(hidden_size)

    def forward(self, x, x_mask):
        attended = self.dropout1(self.mhatt(x, x, x, x_mask))
        x = self.norm1(x + attended)

        fed = self.dropout2(self.ffn(x))
        return self.norm2(x + fed)

In [82]:
from allennlp.nn import util

class NumGNN(nn.Module):
    """Graph network that mixes node features along a binary relation.

    In each step, node i receives a sigmoid-gated sum of the other nodes' features.
    Features of senders j with ``graph[b, i, j] != 0`` go through ``_dd_node_fc_left``
    and all other senders go through ``_dd_node_fc_right``. The sum is divided by
    the number of other nodes, added to a self projection, and passed through ReLU.

    Args:
        node_dim: feature size of every node (input and output).
        iteration_steps: number of message-passing rounds.
    """

    def __init__(self, node_dim, iteration_steps=1):
        super(NumGNN, self).__init__()

        self.node_dim = node_dim
        self.iteration_steps = iteration_steps

        self._node_weight_fc = torch.nn.Linear(node_dim, 1, bias=True)
        self._self_node_fc = torch.nn.Linear(node_dim, node_dim, bias=True)

        self._dd_node_fc_left = torch.nn.Linear(node_dim, node_dim, bias=False)
        self._dd_node_fc_right = torch.nn.Linear(node_dim, node_dim, bias=False)

    def forward(self, d_node, graph):
        """Run ``iteration_steps`` rounds of message passing.

        Args:
            d_node: node features of shape (batch, num_nodes, node_dim).
            graph: numeric 0/1 adjacency (assumed integer, e.g. long) of shape
                (batch, >= num_nodes, >= num_nodes); only the leading
                (num_nodes, num_nodes) block is used. As built in this notebook,
                graph[b, i, j] is presumably 1 when node j ranks strictly below node i.

        Returns:
            Updated node features with the same shape as ``d_node``.
        """
        d_node_len = d_node.size(1)
        sub_graph = graph[:, :d_node_len, :d_node_len]

        # 1 everywhere except the diagonal: a node never sends a message to itself.
        off_diag = 1 - torch.eye(d_node_len, dtype=torch.long, device=d_node.device)
        off_diag = off_diag.unsqueeze(0).expand(d_node.size(0), -1, -1)

        dd_graph_left = off_diag * sub_graph
        dd_graph_right = off_diag * (1 - sub_graph)

        # Number of other nodes each node hears from; clamped to 1 so an isolated
        # node (e.g. a single-node graph) does not divide by zero.
        neighbor_num = (dd_graph_left.sum(-1) + dd_graph_right.sum(-1)).float()
        neighbor_num = neighbor_num.masked_fill(neighbor_num < 1, 1.0)

        # Explicit boolean masks instead of allennlp's replace_masked_values:
        # newer allennlp releases require BoolTensor masks and fail on these 0/1
        # long tensors, while older ones multiply by the mask. masked_fill gives
        # the same result for 0/1 masks regardless of the allennlp version.
        no_left_edge = dd_graph_left == 0
        no_right_edge = dd_graph_right == 0

        for _ in range(self.iteration_steps):
            # Scalar gate per node: (batch, num_nodes).
            d_node_weight = torch.sigmoid(self._node_weight_fc(d_node)).squeeze(-1)
            # sender_weight[b, i, j] = gate of sender j, repeated for every receiver i.
            sender_weight = d_node_weight.unsqueeze(1).expand(-1, d_node_len, -1)

            self_d_node_info = self._self_node_fc(d_node)

            left_weight = sender_weight.masked_fill(no_left_edge, 0.0)
            dd_node_info_left = torch.matmul(left_weight, self._dd_node_fc_left(d_node))

            right_weight = sender_weight.masked_fill(no_right_edge, 0.0)
            dd_node_info_right = torch.matmul(right_weight, self._dd_node_fc_right(d_node))

            agg_d_node_info = (dd_node_info_left + dd_node_info_right) / neighbor_num.unsqueeze(-1)

            d_node = F.relu(self_d_node_info + agg_d_node_info)

        return d_node

In [83]:
# Contextualise the cell embeddings with one Transformer layer. Dim 0 of
# `representation` (table rows) acts as the batch, so attention runs across
# the columns of each row.
# NOTE(review): this rebinds `representation`, so re-running the cell applies
# the layer a second time; re-run the BERT cell first.
self_attention = SA(
    hidden_size=768,
    head_num=4,
    ff_size=768,
    dropout=0.1,
    hidden_size_head=768 // 4,
).to(device)

representation = self_attention(representation, None)

def detect_number(string):
    """Parse a table cell as a number when it looks like one.

    Args:
        string: a cell value. Usually a ``str``; pandas yields int/float
            (including numpy scalars) for columns it parsed as numeric.

    Returns:
        ``float`` when the value is already a real number (bools excluded), or
        when it is a string without '-' whose first space-separated token parses
        with ``float``; otherwise the value unchanged.
    """
    import numbers

    if not isinstance(string, str):
        # Numeric cells from pandas-inferred columns; `'-' in value` would raise
        # TypeError on them. Other non-text values are passed through.
        if isinstance(string, numbers.Real) and not isinstance(string, bool):
            return float(string)
        return string
    # '-' occurs in scores and hyphenated names ("3 - 6", "robinson r - 22"),
    # so such cells are treated as text.
    if '-' in string:
        return string
    first_token = string.split(' ')[0]
    try:
        return float(first_token)
    except ValueError:
        return string
import scipy.stats as ss

# Find the columns whose every cell parses as a number and rank their rows.
# `idxs` holds the positions (in `cols`) of those numeric columns.
outputs = []
idxs = []
for i, col in enumerate(cols):
    values = [detect_number(cell) for cell in table[col]]
    if all(isinstance(value, float) for value in values):
        outputs.append(ss.rankdata(values))
        idxs.append(i)
assert outputs, 'no all-numeric column found in this table'

# number_rank: (num_numeric_cols, num_rows), the rank of each row within its column.
number_rank = torch.FloatTensor(outputs)
# graph[c, i, j] == 1 iff, in numeric column c, row j has a strictly lower rank
# than row i. Broadcasting handles any number of numeric columns.
graph_mask = number_rank.unsqueeze(1) < number_rank.unsqueeze(-1)
graph = graph_mask.long().to(device)

# Nodes of graph c are the cells of numeric column c across all rows, giving
# (num_numeric_cols, num_rows, hidden). `representation` is laid out as
# (row, column, hidden), so select the columns on dim 1 and move them first.
# Indexing dim 0 with column indices would pick rows instead.
d_nodes = representation[:, idxs].transpose(0, 1)

In [78]:
# Sanity check: the GNN node features are expected on `device` (same as `graph`).
d_nodes.device

device(type='cuda', index=0)

In [80]:
# Sanity check: the comparison graph is expected on `device` (same as `d_nodes`).
graph.device

device(type='cuda', index=0)

In [85]:
# One round of NumGNN message passing over `graph`, starting from `d_nodes`.
gnn = NumGNN(node_dim=768, iteration_steps=1).to(device)

d_node = gnn(d_nodes, graph)

In [3]:
import json
import re
import pandas

def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# Table ids for every evaluation split of the Table-Fact-Checking data.
train_ids = _load_json('../Table-Fact-Checking/data/train_id.json')
val_ids = _load_json('../Table-Fact-Checking/data/val_id.json')
test_ids = _load_json('../Table-Fact-Checking/data/test_id.json')
simple_test_ids = _load_json('../Table-Fact-Checking/data/simple_test_id.json')
complex_test_ids = _load_json('../Table-Fact-Checking/data/complex_test_id.json')
    
def collect_cols(string):
    """Collect the non-negative indices listed in a statement's annotations.

    Scans ``string`` for payloads that start right after a ';' and end at the
    next '#', decodes each payload with ``json.loads`` and gathers the values.
    The payload is presumably a JSON list of ints (e.g. ``#text;[1, 2]#``);
    verify against the data file.

    Args:
        string: a statement containing ``#...;...#`` annotations.

    Returns:
        Sorted list of the distinct decoded values that are ``>= 0``; empty when
        the statement has no annotation.

    Raises:
        json.JSONDecodeError: if a payload is not valid JSON.
        TypeError: if a payload decodes to something that is not iterable.
    """
    in_payload = False
    buf = ''
    found = []
    for c in string:
        if c == ';':
            in_payload = True
            buf = ''
            continue

        if in_payload:
            if c == '#':
                in_payload = False
                found.extend(json.loads(buf))
                buf = ''
            else:
                buf += c

    # sorted() gives a well-defined order. list(set(...)) yields hash-table order,
    # which depends on insertion order when values collide (e.g. 8 and 0).
    return sorted(v for v in set(found) if v >= 0)
        
with open('data/full_cleaned_aggressive.json') as f:
    full_cleaned = json.load(f)

# Annotation markup in statements: '#<text>;<payload>#'. This is the same regex
# the statement-cleaning re.sub(..., r'\1', ...) calls use, which keep only <text>.
# Both character classes need '+' to match multi-character spans.
pattern = r'#([^;]+);([^#]+)#'
"""
for ids, split in zip([train_ids, val_ids, test_ids, simple_test_ids, complex_test_ids], \
                      ['train', 'val', 'test', 'simple_test', 'complex_test']):
    examples = []
    for t_id in ids:
        table = pandas.read_csv('all_csv/{}'.format(t_id), '#')
        columns = table.columns
        
        pair = full_cleaned[t_id]
        for sent, label in zip(pair[0], pair[1]):
            cols = collect_cols(sent)
            if len(cols) == 0:
                cols = [0, 1, 2]
            else:
                if 0 not in cols:
                    cols.insert(0, 0)
            
            text = ''
            for i in range(len(table)):
                text += 'row {} is : '.format(i + 1)
                entry = table.iloc[i]
                for col in cols:
                    text += '{} is {} , '.format(columns[col], entry[col])
                
                if i < len(table) - 1:
                    text = text[:-2] + ' . '
                else:
                    text = text[:-2]
            
            sent = re.sub(r'#([^;]+);([^#]+)#', r'\1', sent).lower()
            
            examples.append((t_id, text, pair[-1], sent, label))

    with open('data/{}_baseline_examples.json'.format(split), 'w') as f:
        json.dump(examples, f, indent=2)
"""   

# For every split, group the cleaned statements by table id. Per table, store
# [statements, column lists, pair[1], pair[-1]] and write the result to
# 'data/<split>_examples.json'.
split_sources = (
    ('train', train_ids),
    ('val', val_ids),
    ('test', test_ids),
    ('simple_test', simple_test_ids),
    ('complex_test', complex_test_ids),
)
for split, ids in split_sources:
    sents = {}
    for t_id in ids:
        pair = full_cleaned[t_id]
        for sent in pair[0]:
            cols = collect_cols(sent)
            # No referenced columns: fall back to the first three. Otherwise make
            # sure column 0 is always included, in front.
            if not cols:
                cols = [0, 1, 2]
            elif 0 not in cols:
                cols.insert(0, 0)

            # Drop the '#text;payload#' markup, keeping only the text.
            sent = re.sub(r'#([^;]+);([^#]+)#', r'\1', sent).lower()
            if t_id in sents:
                sents[t_id][0].append(sent)
                sents[t_id][1].append(cols)
            else:
                sents[t_id] = [[sent], [cols], pair[1], pair[-1]]

    with open('data/{}_examples.json'.format(split), 'w') as f:
        json.dump(sents, f, indent=2)

    total = sum(len(entry[1]) for entry in sents.values())
    print('totally {} table with {} statements for {}'.format(len(sents), total, split))

totally 13182 table with 92283 statements for train
totally 1696 table with 12792 statements for val
totally 1695 table with 12779 statements for test
totally 833 table with 4171 statements for simple_test
totally 862 table with 8608 statements for complex_test


In [9]:
import json
import pandas

with open('../Table-Fact-Checking/data/test_id.json') as f:
    test_ids = json.load(f)

# Peek at the first test table: its parsed contents and the dtypes pandas
# inferred for each column.
for t_id in test_ids:
    # The table CSVs are '#'-separated. `sep` must be passed by keyword
    # (pandas >= 2.0 no longer accepts it positionally).
    table = pandas.read_csv('all_csv/{}'.format(t_id), sep='#')
    print(table)
    columns = table.columns
    print(table.dtypes)
    break

       outcome               date                  location surface  \
0       winner         2 may 1999    coatzacoalcos , mexico    hard   
1       winner       11 july 1999      felixstowe , england   grass   
2  runner - up    6 february 2000  wellington , new zealand    hard   
3  runner - up        28 may 2000     el paso , texas , usa    hard   
4  runner - up  14 september 2003           spoleto , italy    clay   
5       winner    6 february 2005  wellington , new zealand    hard   
6       winner     28 august 2005              jesi , italy    hard   
7       winner    5 february 2006       taupo , new zealand    hard   
8       winner   12 february 2006  wellington , new zealand    hard   
9       winner      20 april 2008         mazatlán , mexico    hard   

     opponent in final                      score  
0      candice jairala          3 - 6 6 - 3 7 - 5  
1         karen nugent                6 - 4 6 - 4  
2    mirielle dittmann  6 - 7 (5) 6 - 1 6 - 7 (5)  
3        e