In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("/workspace/kbqa/")  # go to parent dir

In [2]:
import ujson
import jsonlines
import networkx as nx
import pandas as pd
from tqdm import tqdm

from pathlib import Path

### Getting JSONL files

In [3]:
def read_jsonl(path):
    jsonl_reader = jsonlines.open(path)
    jsonl_reader_list = list(jsonl_reader)
    nx_graphs = []
    df = []
    for idx, line in tqdm(enumerate(jsonl_reader_list)):
        df.append(line)
    df = pd.DataFrame(df)
    return df

In [4]:
# getting the jsonl data
with_candidate = True
dataset, dataset_ = "MINTAKA", "mintaka"
dataset_type = f"{dataset_}_train_labeled"
subgraph_path = (
    f"/workspace/storage/new_subgraph_dataset/{dataset}/{dataset_type}.jsonl"
)

train_df = read_jsonl(subgraph_path)

98033it [00:00, 2104030.72it/s]


In [5]:
train_grouped = train_df.groupby("question")
len(train_grouped)

9872

In [6]:
dataset_type = f"{dataset_}_test_labeled"
subgraph_path = (
    f"/workspace/storage/new_subgraph_dataset/{dataset}/{dataset_type}.jsonl"
)
test_df = read_jsonl(subgraph_path)

28325it [00:00, 2004549.93it/s]


In [7]:
test_grouped = test_df.groupby("question")
len(test_grouped)

2815

In [8]:
dataset_type = f"{dataset_}_validation_labeled"
subgraph_path = (
    f"/workspace/storage/new_subgraph_dataset/{dataset}/{dataset_type}.jsonl"
)
val_df = read_jsonl(subgraph_path)

14286it [00:00, 2401211.31it/s]


In [9]:
val_grouped = val_df.groupby("question")
len(val_grouped)

1419

In [10]:
len(train_df), len(val_df), len(test_df)

(98033, 14286, 28325)

In [11]:
df = pd.concat([train_df, val_df, test_df])

In [12]:
def flatten_list(col):
    res = []
    for i in col:
        if len(i) == 0:
            res.append(None)
        else:
            res.append(i[0])
    return res


# clean our df
cols = ["answerEntity", "questionEntity", "groundTruthAnswerEntity"]
df[cols] = df[cols].apply(flatten_list)
df.to_csv(
    f"/workspace/storage/new_subgraph_dataset/{dataset}/{dataset_}_combined.jsonl",
    index=False,
)

In [28]:
correct_df = df[df["answerEntity"] == df["groundTruthAnswerEntity"]]
correct_df = correct_df.dropna()
correct_df

Unnamed: 0,id,question,answerEntity,questionEntity,groundTruthAnswerEntity,complexityType,graph
3,a9011ddf,What is the seventh tallest mountain in North ...,Q1153188,Q49,Q1153188,ordinal,"{'directed': True, 'multigraph': False, 'graph..."
8,2723bb1b,Which actor was the star of Titanic and was bo...,Q38111,Q44578,Q38111,intersection,"{'directed': True, 'multigraph': False, 'graph..."
18,88349c89,Which actor starred in Vanilla Sky and was mar...,Q37079,Q174346,Q37079,intersection,"{'directed': True, 'multigraph': False, 'graph..."
33,982450cf,Who is the youngest current US governor?,Q3105215,Q889821,Q3105215,superlative,"{'directed': True, 'multigraph': False, 'graph..."
46,fe541d01,Which US president has had the most votes?,Q6279,Q30,Q6279,superlative,"{'directed': True, 'multigraph': False, 'graph..."
...,...,...,...,...,...,...,...
28257,5bd6c9cc,What is the name of the main police officer in...,Q1138965,Q275950,Q1138965,generic,"{'directed': True, 'multigraph': False, 'graph..."
28267,3490b793,Who is the protagonist in the God of War series?,Q2291154,Q390137,Q2291154,generic,"{'directed': True, 'multigraph': False, 'graph..."
28279,9fc4810d,Who is the protagonist of God of War?,Q2291154,Q390137,Q2291154,generic,"{'directed': True, 'multigraph': False, 'graph..."
28288,5e745d31,Who is the protagonist in Halo?,Q652022,Q1747150,Q652022,generic,"{'directed': True, 'multigraph': False, 'graph..."


### Getting CSV file

In [15]:
# loading in the dataset

if dataset_ == "mintaka":
    train_res = pd.read_csv(f"/workspace/storage/{dataset_}_seq2seq/train.csv")
    val_res = pd.read_csv(f"/workspace/storage/{dataset_}_seq2seq/validation.csv")
    test_res = pd.read_csv(f"/workspace/storage/{dataset_}_seq2seq/test.csv")
    res_csv = pd.concat([train_res, val_res, test_res])
    res_csv["target"] = res_csv["target"].apply(lambda x: x.strip("['']"))
    res_csv["question"] = res_csv["question"].str.strip()
else:  # sqwd
    res_csv = pd.read_csv(f"/workspace/storage/{dataset_}_seq2seq/results.csv")
    res_csv["target"] = res_csv["target"].apply(lambda x: x.strip("['']"))
    res_csv["question"] = res_csv["question"].str.strip()

res_csv.head()

Unnamed: 0,question,target,answer_0,answer_1,answer_2,answer_3,answer_4,answer_5,answer_6,answer_7,...,answer_191,answer_192,answer_193,answer_194,answer_195,answer_196,answer_197,answer_198,answer_199,target_out_of_vocab
0,What man was a famous American author and also...,Mark Twain,Mark Twain,Mark Twain,Harriet Beecher Stowe,Charles Dickens,William Faulkner,Mark Twain,Harriet Beecher Stowe,H. G. Wells,...,Henry James,Theodore Sturgeon,H. P. Lovecraft,Stephen Crane,Horatio Bottomley,William Faulkner,Mark Twain,Edgar Allan Poe.,Horatio Parker,False
1,How many Academy Awards has Jake Gyllenhaal be...,1,1,1,1,1,1,1,2,1,...,7,11,12,One,13,128,215,128,128,False
2,"Who is older, The Weeknd or Drake?",Drake,The Weeknd,Drake,Drake,The Weeknd,Drake,Drake,Drake,The Weeknd,...,Draco,DJ Khaled,TWiG,"Drake,",Weeknd,TWENTY,TWiT,Twice as old,"The Weeknd,",False
3,How many children did Donald Trump have?,5,2,3,2,3,2,3,2,4,...,13,7,9,4 children,11,8,10,12,13,False
4,Is the main hero in Final Fantasy IX named Kuja?,No,Yes,Yes,Yes,Yes,Yes,Yes,Yes,Yes,...,Yu Yu Hakukui,The Final Fantasy IX.,Is Kuja the Hero,The Answer Is No,Is Final Fantasy VIII,Yu Yu Hakuku,"Yep, yes",YYYY,The Final Fantasy VII Final Fantasy,False


### Original/raw top1 & top200

In [18]:
top200 = 0
top1 = 0
res_csv_grouped = res_csv.groupby("question")
for name, group in tqdm(res_csv_grouped):
    gold_answers = group["target"].values[0].split(",")

    if str(group["answer_0"].values[0]) in str(group["target"].values[0]):
        top1 += 1

    group_filtered = group.drop(["question", "target_out_of_vocab", "target"], axis=1)
    for gold_answer in gold_answers:
        if gold_answer in group_filtered.values[0]:
            top200 += 1
            continue

100%|██████████| 4000/4000 [00:01<00:00, 2596.54it/s]


In [19]:
top1 / len(res_csv), top200 / len(res_csv)

(0.27925, 0.65125)

### Filter the data in jsonl and csv 

In [21]:
# getting questions in result csv that exist in our jsonl
res_filtered = []
for index, row in tqdm(res_csv.iterrows()):
    question = row[0].strip()
    curr_ques_df = df[df["question"] == question]
    if len(curr_ques_df) > 0:
        res_filtered.append(row)

4000it [00:49, 80.32it/s]


In [22]:
res_filtered = pd.DataFrame(res_filtered)

In [21]:
res_filtered.to_csv(
    f"/workspace/storage/{dataset_}_seq2seq/results_filtered.csv", index=False
)

In [22]:
res_filtered.head()

Unnamed: 0,question,target,answer_0,answer_1,answer_2,answer_3,answer_4,answer_5,answer_6,answer_7,...,answer_191,answer_192,answer_193,answer_194,answer_195,answer_196,answer_197,answer_198,answer_199,target_out_of_vocab
0,What is the seventh tallest mountain in North ...,Mount Lucania,Mount McKinley,Mount McKinley,Mount St. Elias,Mount Rainier,Denali,Mount McKinley,Denali,Mount Rainier,...,Mount McKinlay,Ben Nevis,Mt. Whitney,Kangchenjunga Mountain,Mt. Massive,Mount Hood,Mt. Marcy,Cascade Peak,Mount McLoughlin,False
1,Which actor was the star of Titanic and was bo...,Leonardo DiCaprio,Leonardo DiCaprio,Leonardo DiCaprio,Leonardo DiCaprio,Leonardo DiCaprio,Meryl Streep,Matthew McConaughey,Leonardo Di Caprio,Robert Pattinson,...,Kevin Spacey,Kate Winslet.,Leonardo di Caprio,Robert Pattinson.,James Franco,Samuel L. Jackson.,Ryan Reynolds,Harrison Ford,Leonardo Di Caprio,False
2,Which actor starred in Vanilla Sky and was mar...,Tom Cruise,Tom Hanks,Tom Cruise,Tom Hanks,Tom Cruise,Tom Cruise,Tom Hanks,Tom Cruise,Tom Hanks,...,Tom Hanks.,Dustin Hoffman,Matt Damon,Will Smith,Harrison Ford,Tom Cruise.,James Franco,Russell Crowe,"Tom Cruise, Jr.",False
4,Who is the youngest current US governor?,Ron DeSantis,Bobby Jindal,Jon Corzine,Rick Perry,Jennifer Granholm,Bobby Jindal,Steve Beshear,Kay Ivey,Mike Pence,...,Jay Nixon,Scott Walker,Rick Perry,Chris Christie,Mike Pence.,Jennifer Granholm.,Gary Herbert,"Scott Walker, Jr.",Scott Walker,False
5,Which US president has had the most votes?,Joe Biden,Donald Trump,George W. Bush,Barack Obama,Donald Trump,John F. Kennedy,George Washington,Barack Obama,Theodore Roosevelt,...,John F Kennedy.,Franklin D. Roosevelt,John F Kennedy,George W Bush,Ronald Reagan,Bill Clinton,Barack Obama.,George Washington,Harry S Truman,False


### Final re-ranking

In [24]:
def graph_to_sequence(subgraph, node_names):
    # getting adjency matrix and weight info
    adj_matrix = nx.adjacency_matrix(subgraph).todense().tolist()
    edge_data = subgraph.edges.data()

    # adding our edge info
    for edge in edge_data:
        i, j, data = edge
        i, j = int(i), int(j)
        adj_matrix[i][j] = data["label"]

    sequence = []
    # for adjency matrix, i, j means node i -> j
    for i, row in enumerate(adj_matrix):
        from_node = node_names[i]  # from node (node i)
        for j, edge_info in enumerate(row):
            to_node = node_names[j]
            if edge_info == 0:  # no endge from_node -> to_node
                # sequence.extend([from_node, "None", to_node])
                pass
            else:
                sequence.extend([from_node, edge_info, to_node])
    sequence = ",".join(str(node) for node in sequence)
    return sequence

In [25]:
def get_node_names(subgraph):
    node_names = [subgraph.nodes[node]["label"] for node in subgraph.nodes()]
    return node_names

In [26]:
def get_graph_seqs(graphs):
    graph_seq = []
    for graph in graphs:
        try:
            graph_obj = nx.readwrite.json_graph.node_link_graph(graph)
            graph_node_names = get_node_names(graph_obj)
            curr_seq = graph_to_sequence(graph_obj, graph_node_names)
        except KeyError:
            print("ERROR NO LABEL!")
            curr_seq = "ERROR_NO_LABEL"
        except nx.NetworkXError:
            print("ERROR EMPTY GRAPHS!")
            curr_seq = "ERROR_EMPTY_GRAPH"
        graph_seq.append(curr_seq)
    return graph_seq

In [None]:
import torch
import pandas as pd
import torch
import numpy as np
from transformers import RobertaTokenizer, BertTokenizer, BertModel, RobertaModel
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

model_type = "roberta-large"


class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()
        if model_type == "roberta-large" or model_type == "roberta-case":
            self.model = RobertaModel.from_pretrained(model_type)
        else:
            self.model = BertModel.from_pretrained(model_type)
        self.dropout = nn.Dropout(dropout)
        if model_type == "bert-large-cased" or model_type == "roberta-large":
            dim = 1024
        else:
            dim = 768
        self.linear = nn.Linear(dim, 2)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.model(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer


model = torch.load(
    "/workspace/storage/subgraph_classify_models/mintaka/candidates_True/BertClassifier_mintaka_roberta-large_sampler.pt"
)
tokenizer = RobertaTokenizer.from_pretrained(model_type)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [19]:
from kbqa.wikidata import WikidataEntityToLabel

entity2label = WikidataEntityToLabel()

  from .autonotebook import tqdm as notebook_tqdm


In [77]:
res_csv_grouped = res_filtered.groupby(["question"])
jsonl_grouped = df.groupby(["question"])

In [78]:
def beams_answers_to_labels(beams, entity2label):
    labels_entity = {}
    for beam in beams:
        labels_entity[entity2label.get_label(beam)] = beam
    return labels_entity

In [None]:
from ast import literal_eval
from unidecode import unidecode


def try_literal_eval(s):
    try:
        return literal_eval(s)
    except ValueError:
        return s

In [80]:
total_correct = 0
extra_correct = 0

for name, group in tqdm(res_csv_grouped):
    name = name.strip()
    all_beams = group.iloc[0, 3:-1].values.tolist()  # all 200 beams
    pred_answer = group["answer_0"].values[0]

    if dataset_ == "sqwd":
        target = group["target"].values[0].strip("['']")
    else:
        target = group["target"].values[0]

    # get rid of accents
    target, pred_answer = unidecode(target), unidecode(pred_answer)
    if target == pred_answer:  # correct answer
        total_correct += 1
    else:  # wrong answer
        with torch.no_grad():
            # find the beams that exist in jsonl to get our subgraphs
            beams_jsonl = jsonl_grouped.get_group(name)

            # get all answer entity along with their label
            beams_jsonl_answer = beams_jsonl["answerEntity"].tolist()
            beams_jsonl_labels_entities = beams_answers_to_labels(
                beams_jsonl_answer, entity2label
            )
            beams_jsonl_labels = beams_jsonl_labels_entities.keys()
            beams_jsonl_entities = beams_jsonl_labels_entities.values()

            # out of 200 beams, get beams that exist in jsonl
            existing_beams_labels = list(
                set(all_beams).intersection(beams_jsonl_labels)
            )
            existing_beams_entities = [
                beams_jsonl_labels_entities[x] for x in existing_beams_labels
            ]
            existing_beams = beams_jsonl[
                beams_jsonl["answerEntity"].isin(existing_beams_entities)
            ]

            # get subgraphs and their sequences
            subgraphs = existing_beams["graph"].tolist()
            subgraphs = [try_literal_eval(subgraph) for subgraph in subgraphs]
            answers = existing_beams["answerEntity"].tolist()
            graph_seqs = get_graph_seqs(subgraphs)

            curr_max = 0
            best_pred_answer = None
            for seq, answer in zip(graph_seqs, answers):
                seq_tok = tokenizer(
                    seq,
                    padding="max_length",
                    max_length=512,
                    truncation=True,
                    return_tensors="pt",
                )
                mask = seq_tok["attention_mask"].to(device)
                input_id = seq_tok["input_ids"].squeeze(1).to(device)
                output = model(input_id, mask)
                correct_pred = output.argmax(dim=1).item()

                # get the highest predicted correct sequence/beam answer
                if correct_pred == 1:
                    correct_prob = output.cpu().detach().numpy()[0][1]

                    if correct_prob > curr_max:
                        curr_max = correct_prob
                        best_pred_answer = answer

        # all subgraphs are predicted to be wrong
        if best_pred_answer is None:
            continue

        best_pred_label = list(beams_jsonl_labels_entities.keys())[
            list(beams_jsonl_labels_entities.values()).index(best_pred_answer)
        ]
        best_pred_label = unidecode(best_pred_label)
        if best_pred_label == target:
            extra_correct += 1

  0%|          | 35/14106 [00:01<09:57, 23.55it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


  6%|▋         | 912/14106 [01:04<11:47, 18.64it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


  7%|▋         | 964/14106 [01:08<19:24, 11.29it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 15%|█▍        | 2062/14106 [02:31<15:40, 12.80it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 15%|█▌        | 2176/14106 [02:37<05:20, 37.23it/s]

ERROR EMPTY GRAPHS!


 19%|█▊        | 2623/14106 [03:02<15:02, 12.73it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 19%|█▉        | 2694/14106 [03:06<12:06, 15.70it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 20%|█▉        | 2792/14106 [03:11<08:18, 22.71it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 34%|███▍      | 4822/14106 [04:55<05:55, 26.08it/s]

ERROR EMPTY GRAPHS!


 37%|███▋      | 5214/14106 [05:15<06:33, 22.57it/s]

ERROR EMPTY GRAPHS!


 37%|███▋      | 5286/14106 [05:19<04:54, 29.96it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 38%|███▊      | 5420/14106 [05:28<07:41, 18.83it/s]

ERROR EMPTY GRAPHS!


 42%|████▏     | 5940/14106 [06:00<06:07, 22.24it/s]

ERROR EMPTY GRAPHS!


 43%|████▎     | 6020/14106 [06:05<06:57, 19.36it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 44%|████▎     | 6163/14106 [06:13<09:21, 14.14it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 44%|████▍     | 6173/14106 [06:14<07:50, 16.85it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 44%|████▍     | 6262/14106 [06:19<09:26, 13.85it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 57%|█████▋    | 8027/14106 [07:56<05:52, 17.25it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 57%|█████▋    | 8095/14106 [08:00<06:42, 14.95it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 59%|█████▊    | 8267/14106 [08:09<04:35, 21.18it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 66%|██████▌   | 9271/14106 [08:56<04:36, 17.50it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 67%|██████▋   | 9507/14106 [09:10<03:11, 23.96it/s]

ERROR EMPTY GRAPHS!


 78%|███████▊  | 10954/14106 [10:21<03:21, 15.66it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 78%|███████▊  | 11003/14106 [10:23<02:13, 23.28it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 79%|███████▊  | 11082/14106 [10:28<03:24, 14.76it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 80%|███████▉  | 11228/14106 [10:32<02:36, 18.34it/s] 

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 87%|████████▋ | 12213/14106 [11:37<02:21, 13.38it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 90%|█████████ | 12710/14106 [12:06<01:08, 20.30it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


 99%|█████████▉| 13959/14106 [13:28<00:07, 19.02it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


100%|█████████▉| 14069/14106 [13:35<00:03, 11.76it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


100%|█████████▉| 14071/14106 [13:35<00:03, 11.12it/s]

ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!
ERROR EMPTY GRAPHS!


100%|██████████| 14106/14106 [13:37<00:00, 17.25it/s]


In [81]:
total_correct / len(res_csv_grouped)

0.29739118105770596

In [82]:
extra_correct

1336

In [83]:
(total_correct + extra_correct) / len(res_csv_grouped)

0.39210265135403377