In [2]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from ratsql.utils.analysis import cal_attention_flow
from ratsql.commands.infer import Inferer
from run_all import test_example, load_model, TestInfo
import _jsonnet
from search import read_data, match, show_results
import warnings

In [3]:
# Locations of the evaluation / inference outputs being compared.
project_dir = "/home/hkkang/NL2QGM"

# Outputs scored with the original evaluation setup (relative to project_dir).
original_eval_file = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_original_eval/bert_run_true_1-step_41600-eval.json"
original_infer_file = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_original_eval/bert_run_true_1-step_41600-infer.jsonl"

# Outputs scored with the custom ("my") evaluation setup (relative to project_dir).
my_eval_file = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-eval.json"
my_infer_file = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-infer.jsonl"

# Resolve every relative location against project_dir in a single pass:
# eval paths first, then infer paths.
ori_eval_path, my_eval_path, ori_infer_path, my_infer_path = [
    os.path.join(project_dir, relative_path)
    for relative_path in (
        original_eval_file,
        my_eval_file,
        original_infer_file,
        my_infer_file,
    )
]

In [4]:
def load_json_custom(path):
    """Load an eval JSON file and return its 'per_item' entry.

    The number of items is printed as a quick sanity check.

    Args:
        path: Path of a JSON file whose top-level object has a 'per_item' key.

    Returns:
        The value stored under 'per_item'.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the top-level object has no 'per_item' key.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(path) as f:
        result = json.load(f)['per_item']
    print(len(result))
    return result

def load_jsonl(path):
    """Load a JSON Lines file into a list, one decoded object per line.

    Blank (whitespace-only) lines are skipped, so a stray empty line does
    not abort the load with a JSONDecodeError.

    Args:
        path: Path of the .jsonl file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r') as f:
        # Iterate the file object directly: avoids materialising every raw
        # line up front the way readlines() does.
        results = [json.loads(line) for line in f if line.strip()]
    return results

def load_eval_files(eval_paths):
    """Load every eval file in `eval_paths` via `load_json_custom`.

    Returns a list with one loaded result per path, in the given order.
    """
    return list(map(load_json_custom, eval_paths))

def load_infer_files(infer_paths):
    """Load every infer file in `infer_paths` via `load_jsonl`.

    Returns a list with one loaded result per path, in the given order.
    """
    return list(map(load_jsonl, infer_paths))

def compare(e_data1, e_data2, i_data1, i_data2):
    """Report the examples on which two eval results disagree on 'exact'.

    For every index where the two 'exact' values differ, prints the index,
    the db_id and question (read from `i_data1`), the gold query and both
    predictions, then prints a summary count.

    Args:
        e_data1, e_data2: Per-item eval records with 'exact', 'gold' and
            'predicted' keys; must have equal length.
        i_data1, i_data2: Per-item infer records; must have the same length
            as the eval records. Only `i_data1` is read for printing.

    Returns:
        A tuple (left_correct_only, right_correct_only) of index lists:
        disagreeing indices where `e_data1`'s 'exact' is truthy, and the
        remaining disagreeing indices.
    """
    assert len(e_data1) == len(e_data2)
    assert len(i_data1) == len(i_data2)
    assert len(e_data1) == len(i_data1)

    left_correct_only, right_correct_only = [], []
    for idx, (left, right) in enumerate(zip(e_data1, e_data2)):
        # Guard clause: only disagreements are of interest.
        if not (left['exact'] != right['exact']):
            continue

        print(f"Idx:{idx} different! {bool(left['exact'])} - {right['exact']}")
        winners = left_correct_only if left['exact'] else right_correct_only
        winners.append(idx)

        # Detail print
        example = i_data1[idx]
        print(f"db_id: {example['db_id']}")
        print(f"question: {example['question']}")
        print(f"gold: {left['gold']}")
        print(f"pred1: {left['predicted']}")
        print(f"pred2: {right['predicted']}\n")

    print(f"Summary: {len(left_correct_only)} vs {len(right_correct_only)}")
    return left_correct_only, right_correct_only

In [5]:
# Load both evaluation outputs and both inference outputs, then list the
# examples on which the two evaluations disagree.
(ori_eval, my_eval), (ori_infer, my_infer) = (
    load_eval_files([ori_eval_path, my_eval_path]),
    load_infer_files([ori_infer_path, my_infer_path]),
)
compare(ori_eval, my_eval, ori_infer, my_infer)

1034
1034
Idx:51 different! False - True
db_id: pets_1
question: Find number of pets owned by students who are older than 20.
gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid WHERE T1.age  >  20
pred1: SELECT Count(*) FROM Student JOIN Has_Pet ON Student.StuID = Has_Pet.StuID JOIN Pets ON Has_Pet.PetID = Pets.PetID WHERE Student.Age > 'terminal'
pred2: SELECT Count(*) FROM Student JOIN Has_Pet ON Student.StuID = Has_Pet.StuID JOIN Pets ON Has_Pet.PetID = Pets.PetID WHERE Student.Age > 'terminal'

Idx:52 different! False - True
db_id: pets_1
question: How many pets are owned by students that have an age greater than 20?
gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid WHERE T1.age  >  20
pred1: SELECT Count(*) FROM Student JOIN Has_Pet ON Student.StuID = Has_Pet.StuID JOIN Pets ON Has_Pet.PetID = Pets.PetID WHERE Student.Age > 'terminal'
pred2: SELECT Count(*) FROM Student JOIN Has_Pet ON Student.StuID = Has_Pet.Stu

([203,
  204,
  206,
  209,
  210,
  213,
  214,
  215,
  216,
  217,
  218,
  219,
  220,
  221,
  225,
  226,
  227,
  228,
  229,
  230,
  233,
  234,
  235,
  236,
  237,
  238,
  239,
  240,
  245,
  246,
  247,
  248,
  250,
  253,
  254,
  255,
  256,
  451,
  452,
  487,
  488,
  544,
  707,
  888,
  889,
  904,
  905,
  906,
  907,
  908,
  909,
  912,
  913,
  944,
  945],
 [51, 52, 61, 62, 63, 79, 105, 146, 314, 353])

In [6]:
# Compare: eval / infer output locations for the BERT and ELECTRA runs.
project_dir = "/home/hkkang/NL2QGM/"

# BERT (paths are relative to project_dir)
bert_eval_0 = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-eval.json"
bert_infer_0 = "logdir/spider_bert_run_no_join_cond_seed_0/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-infer.jsonl"
bert_eval_1 = "logdir/spider_bert_run_no_join_cond_seed_2/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=2,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-eval.json"
bert_infer_1 = "logdir/spider_bert_run_no_join_cond_seed_2/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=2,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-infer.jsonl"
bert_eval_2 = "logdir/spider_bert_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-eval.json"
bert_infer_2 = "logdir/spider_bert_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/bert_run_true_1-step_41600-infer.jsonl"
# combine: absolute paths, in run order
bert_evals = [
    os.path.join(project_dir, rel_path)
    for rel_path in (bert_eval_0, bert_eval_1, bert_eval_2)
]
bert_infers = [
    os.path.join(project_dir, rel_path)
    for rel_path in (bert_infer_0, bert_infer_1, bert_infer_2)
]

# ELECTRA
# Earlier checkpoint selection, kept commented out for reference:
# electra_eval_1 = "logdir/spider_electra_run_no_join_cond_seed_1/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=1,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_27000-eval.json"
# electra_infer_1 = "logdir/spider_electra_run_no_join_cond_seed_1/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=1,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_27000-infer.jsonl"
# electra_eval_2 = "logdir/spider_electra_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_64000-eval.json"
# electra_infer_2 = "logdir/spider_electra_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_64000-infer.jsonl"
# electra_eval_3 = "logdir/spider_electra_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_72000-eval.json"
# electra_infer_3 = "logdir/spider_electra_run_no_join_cond_seed_3/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=3,join_cond=false/ie_dirs_my_eval/electra_run_true_1-step_72000-infer.jsonl"
electra_eval_1 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_18000-eval.json"
electra_infer_1 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_18000-infer.jsonl"
electra_eval_2 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_25000-eval.json"
electra_infer_2 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_25000-infer.jsonl"
electra_eval_3 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_26000-eval.json"
electra_infer_3 = "logdir/spider_electra_run_no_join_cond_seed_0_squad_batch_32/bs=8,lr=7.4e-04,bert_lr=1.0e-05,end_lr=0e0,seed=0,join_cond=false/ie_dirs/electra_run_true_1-step_26000-infer.jsonl"


# combine: absolute paths, in checkpoint order
electra_evals = [
    os.path.join(project_dir, rel_path)
    for rel_path in (electra_eval_1, electra_eval_2, electra_eval_3)
]
electra_infers = [
    os.path.join(project_dir, rel_path)
    for rel_path in (electra_infer_1, electra_infer_2, electra_infer_3)
]

In [7]:
# Read in files
# The names `bert_evals` / `bert_infers` / `electra_evals` / `electra_infers`
# end up holding loaded data. Load from the individual path constants rather
# than from those names, so that re-running this cell re-reads the files
# instead of passing already-loaded data back into the loaders.
# BERT
bert_evals = load_eval_files(
    [os.path.join(project_dir, rel_path) for rel_path in (bert_eval_0, bert_eval_1, bert_eval_2)]
)
bert_infers = load_infer_files(
    [os.path.join(project_dir, rel_path) for rel_path in (bert_infer_0, bert_infer_1, bert_infer_2)]
)
# ELECTRA
electra_evals = load_eval_files(
    [os.path.join(project_dir, rel_path) for rel_path in (electra_eval_1, electra_eval_2, electra_eval_3)]
)
electra_infers = load_infer_files(
    [os.path.join(project_dir, rel_path) for rel_path in (electra_infer_1, electra_infer_2, electra_infer_3)]
)

1034
1034
1034
1034
1034
1034


In [8]:
def all_wrong(evals, idx):
    """Return True when no result in `evals` has 'exact' equal to True at `idx`.

    Args:
        evals: Sequence of per-item eval result lists.
        idx: Index of the example to check in every result list.
    """
    exact_flags = []
    for per_item in evals:
        exact_flags.append(per_item[idx]['exact'])
    # No flag comparing equal to True means every result got this example wrong.
    return exact_flags.count(True) == 0

In [9]:
# Derive the number of examples from the loaded results instead of
# hard-coding the size of one particular evaluation set.
num_examples = len(bert_evals[0])
assert all(len(per_item) == num_examples for per_item in bert_evals + electra_evals), \
    "every eval result must cover the same number of examples"

# Indices that every run of a model got wrong.
bert_all_wrong = [idx for idx in range(num_examples) if all_wrong(bert_evals, idx)]
electra_all_wrong = [idx for idx in range(num_examples) if all_wrong(electra_evals, idx)]
print(f"Bert all wrong ({len(bert_all_wrong)}): {bert_all_wrong}")
print(f"\nElectra all wrong ({len(electra_all_wrong)}): {electra_all_wrong}")
    

Bert all wrong (253): [5, 6, 16, 17, 43, 44, 47, 48, 50, 64, 65, 66, 68, 83, 84, 94, 96, 97, 98, 101, 102, 109, 110, 116, 123, 124, 129, 130, 131, 132, 133, 138, 140, 142, 154, 158, 159, 160, 161, 162, 166, 167, 168, 171, 172, 174, 175, 177, 178, 180, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 224, 225, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 253, 254, 255, 256, 257, 258, 265, 336, 407, 408, 413, 423, 448, 451, 452, 453, 454, 459, 460, 470, 479, 480, 485, 486, 487, 488, 493, 494, 495, 500, 505, 520, 526, 533, 534, 536, 537, 539, 541, 542, 549, 550, 557, 558, 559, 560, 561, 571, 572, 575, 576, 578, 583, 584, 591, 592, 637, 642, 645, 646, 691, 707, 709, 712, 713, 716, 717, 724, 725, 726, 728, 736, 737, 738, 739, 744, 745, 748, 749, 754, 755, 757, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 792, 793, 798, 799, 809

In [10]:
# Split the "all runs wrong" examples into: BERT only, ELECTRA only, and both.
only_bert_all_wrong = []
only_electra_all_wrong = []
together_all_wrong = []
# Iterate over the actual number of loaded examples rather than a
# hard-coded evaluation-set size.
for idx in range(len(bert_evals[0])):
    # Evaluate each model once per example instead of once per condition.
    bert_failed = all_wrong(bert_evals, idx)
    electra_failed = all_wrong(electra_evals, idx)

    if bert_failed and not electra_failed:
        print(f"idx:{idx} only bert all wrong!")
        only_bert_all_wrong.append(idx)
    elif electra_failed and not bert_failed:
        print(f"idx:{idx} only electra all wrong!")
        only_electra_all_wrong.append(idx)
    elif bert_failed and electra_failed:
        together_all_wrong.append(idx)

print(f"\nOnly bert all wrong:{only_bert_all_wrong}")
print(f"\nOnly electra all wrong:{only_electra_all_wrong}")
print(f"\nTogether all wrong:{together_all_wrong}")

print(f"\nSummary: only_bert_all_wrong:{len(only_bert_all_wrong)}")
print(f"Summary: only_electra_all_wrong:{len(only_electra_all_wrong)}")
print(f"Summary: together_all_wrong:{len(together_all_wrong)}")

idx:4 only electra all wrong!
idx:6 only bert all wrong!
idx:16 only bert all wrong!
idx:50 only bert all wrong!
idx:60 only electra all wrong!
idx:61 only electra all wrong!
idx:62 only electra all wrong!
idx:63 only electra all wrong!
idx:68 only bert all wrong!
idx:79 only electra all wrong!
idx:80 only electra all wrong!
idx:83 only bert all wrong!
idx:84 only bert all wrong!
idx:94 only bert all wrong!
idx:95 only electra all wrong!
idx:121 only electra all wrong!
idx:122 only electra all wrong!
idx:132 only bert all wrong!
idx:140 only bert all wrong!
idx:150 only electra all wrong!
idx:158 only bert all wrong!
idx:165 only electra all wrong!
idx:173 only electra all wrong!
idx:176 only electra all wrong!
idx:180 only bert all wrong!
idx:209 only bert all wrong!
idx:210 only bert all wrong!
idx:213 only bert all wrong!
idx:214 only bert all wrong!
idx:215 only bert all wrong!
idx:216 only bert all wrong!
idx:217 only bert all wrong!
idx:218 only bert all wrong!
idx:219 only bert 

In [11]:
# Show results
def print_result(eval_datas, infer_datas, idx):
    """Print the question, gold query and every run's prediction for one example.

    The db_id and question are read from the first entry of `infer_datas`
    and the gold query from the first entry of `eval_datas`; one
    "PRED<n>" line is printed per entry of `eval_datas`.

    Args:
        eval_datas: Sequence of per-item eval result lists (records with
            'gold' and 'predicted' keys), one list per run.
        infer_datas: Sequence of per-item infer result lists (records with
            'db_id' and 'question' keys), one list per run.
        idx: Index of the example to show.
    """
    # Read from the arguments, not from same-looking module-level names, so
    # the function shows the data it was actually called with.
    infer_datum = infer_datas[0][idx]
    eval_datum = eval_datas[0][idx]
    print(f"idx: {idx}")
    print(f"db_id: {infer_datum['db_id']}")
    print(f"question: {infer_datum['question']}\n")
    print(f"GOLD: {eval_datum['gold']}")
    # One line per run, numbered from 1, for however many runs were passed.
    for run_number, per_item in enumerate(eval_datas, start=1):
        print(f"PRED{run_number}: {per_item[idx]['predicted']}")
    # Two blank lines separate consecutive examples.
    print("\n")

In [37]:
# Examples to inspect, and the model family whose results are shown.
indices = [731]
eval_data, infer_data = electra_evals, electra_infers

for example_idx in indices:
    print_result(eval_data, infer_data, example_idx)

idx: 731
db_id: world_1
question: Give the mean GNP and total population of nations which are considered US territory.

GOLD: SELECT avg(GNP) ,  sum(population) FROM country WHERE GovernmentForm  =  "US Territory"
PRED1: SELECT Avg(country.GNP), Sum(country.Population) FROM country WHERE country.GovernmentForm = 'terminal'
PRED2: SELECT Avg(country.GNP), Sum(country.Population) FROM country WHERE country.GovernmentForm = 'terminal'
PRED3: SELECT Avg(country.GNP), Sum(country.Population) FROM country WHERE country.GovernmentForm = 'terminal'


