In [3]:
import os
from rank_bm25 import BM25Okapi
import pandas as pd
import numpy as np

raw_dir = 'raw'
processed_dir = 'data'
BATCH_SIZE = 10
if not os.path.exists(raw_dir):
    os.mkdir(raw_dir)

if not os.path.exists(processed_dir):
    os.mkdir(processed_dir)

In [4]:
raw_doc = 'docdev-stopstem.xml_1.out'
raw_qrels = 'msmarco-docdev-qrels.tsv'
raw_queries = 'queries.docdev.tsv'

preproc_doc = 'docs.tsv'
preproc_qrels = 'qrels.tsv'
preproc_queries = 'queries.tsv'

In [5]:
def read_local_files(raw_dir=raw_dir):
    df = pd.read_csv(os.path.join(raw_dir, 'docdev-stopstem.xml_1.out'), names=['query_id', 'q0', 'doc_id', 'rank', 'score', 'text'],  sep=' ')
    df_qrels = pd.read_csv(os.path.join(raw_dir, 'msmarco-docdev-qrels.tsv'), sep=' ', names=['query_id', 'rank_0', 'doc_id', 'rank_1'])
    df_queries = pd.read_csv(os.path.join(raw_dir, 'queries.docdev.tsv'), sep='\t', names=['query_id', 'data'])
    df_docs = pd.read_csv(os.path.join(raw_dir, 'docs.tsv'), sep='\t', names=['id', 'doc_id', 'data'])

    df_qrels['id'] = df_qrels.index
    df_qrels = df_qrels[['id', 'query_id', 'doc_id']]

    df_queries['id'] = df_queries.index
    df_queries = df_queries[['id', 'query_id', 'data']]
    return df_docs, df_qrels, df_queries

In [6]:
def create_joined_file(df_docs, df_qrels, df_queries, path_processed_joined=None):
    joined_df = df_qrels.merge(df_queries, on='query_id').merge(df_docs, on='doc_id', how='left')[['query_id', 'data_x', 'doc_id', 'data_y']]
    joined_df.rename(columns={'data_x':'query_data', 'data_y':'doc_data'}, inplace=True)
    if path_processed_joined:
        joined_df.to_csv(path_processed_joined, sep='\t', index=None, header=None)
    return joined_df

### Reading files from raw dir. Creating first joined view for batch iterator

In [13]:
df_docs, df_qrels, df_queries = read_local_files()
joined_df = create_joined_file(df_docs, df_qrels, df_queries)

In [14]:
joined_df.head()

Unnamed: 0,query_id,query_data,doc_id,doc_data
0,2,androgen receptor define,D1650436,"""From Wikipedia, the free encyclopedianavigati..."
1,1215,3 levels of government in canada and their res...,D1202771,Immigration & Citizenship Canadian Government ...
2,1288,3/5 of 60,D1547717,Science & Mathematics Mathematics What is 3/5 ...
3,1576,60x40 slab cost,D1313702,"""Forum Dock Side (Discussion) PB Open Water Co..."
4,2235,bethel university was founded in what year,D2113408,73% of our students get into one of their top ...


### Baseline DEV_0. We will use bm25 to predict mrr100 on documents only
Queries in this this case are empty strings

In [15]:
df_docs, df_qrels, df_queries = read_local_files()
print(len(df_docs), len(df_qrels), len(df_queries))
joined_df = create_joined_file(df_docs, df_qrels, df_queries)
joined_df.head(3)

5185 5193 5193


Unnamed: 0,query_id,query_data,doc_id,doc_data
0,2,androgen receptor define,D1650436,"""From Wikipedia, the free encyclopedianavigati..."
1,1215,3 levels of government in canada and their res...,D1202771,Immigration & Citizenship Canadian Government ...
2,1288,3/5 of 60,D1547717,Science & Mathematics Mathematics What is 3/5 ...


In [16]:
corpus = list(joined_df['doc_data'].values)
tokenized_corpus = [corp.split(" ") for corp in corpus]

bm25 = BM25Okapi(tokenized_corpus)

In [17]:
def predicting_100_top_doc_id(joined_df, bm25):
    predict_on_query = []
    predicted_scores = []
    for query_id, query in zip(list(joined_df['query_id']), list(joined_df['query_data'])):
        query = query.split()
        score = bm25.get_scores(query)
        top_100_scores = np.round(score[score >= pd.Series(score).nlargest(100).values[-1]], 1)# score[score >= pd.Series(score).nlargest(100).values[-1]]
        top_100_doc_id = joined_df['doc_id'][score >= pd.Series(score).nlargest(100).values[-1]].values
        score_dicts = [{doc:score} for (doc, score) in zip(top_100_doc_id, top_100_scores)]
        predict_on_query.append(list(top_100_doc_id))
        predicted_scores.append(list(top_100_scores))
    return predict_on_query, predicted_scores

In [18]:
joined_df['predict_100'], joined_df['predict_100_score'] = predicting_100_top_doc_id(joined_df, bm25)
joined_df.head(3)

Unnamed: 0,query_id,query_data,doc_id,doc_data,predict_100,predict_100_score
0,2,androgen receptor define,D1650436,"""From Wikipedia, the free encyclopedianavigati...","[D1650436, D344906, D993875, D2339274, D802481...","[26.0, 6.1, 4.1, 7.7, 10.0, 5.3, 4.9, 5.8, 5.3..."
1,1215,3 levels of government in canada and their res...,D1202771,Immigration & Citizenship Canadian Government ...,"[D1202771, D1361055, D22461, D2339274, D217536...","[24.7, 20.8, 19.4, 20.4, 19.5, 19.2, 21.2, 20...."
2,1288,3/5 of 60,D1547717,Science & Mathematics Mathematics What is 3/5 ...,"[D1547717, D815091, D855092, D3461202, D628872...","[26.6, 10.1, 8.5, 9.7, 9.1, 9.6, 8.1, 9.0, 8.8..."


In [19]:
from rank_eval import Qrels, Run, evaluate
qrels = Qrels()
qrels.add_multi(q_ids=[str(i) for i in list(joined_df['query_id'].values)], 
                doc_ids=[[i] for i in joined_df['doc_id']],
                scores=[[1.0] for i in range(len(joined_df))])

run = Run()
run.add_multi(q_ids=[str(i) for i in list(joined_df['query_id'].values)], 
                doc_ids=[i for i in joined_df['predict_100']],
                scores=[i for i in joined_df['predict_100_score']])

evaluate(qrels, run, ["mrr@100"])

0.5366385958226686

### Baseline DEV_1. We will use bm25 to predict mrr100 of docs + ideal queries

In [20]:
df_docs, df_qrels, df_queries = read_local_files()
print(len(df_docs), len(df_qrels), len(df_queries))
joined_df = create_joined_file(df_docs, df_qrels, df_queries)
joined_df.head(3)

5185 5193 5193


Unnamed: 0,query_id,query_data,doc_id,doc_data
0,2,androgen receptor define,D1650436,"""From Wikipedia, the free encyclopedianavigati..."
1,1215,3 levels of government in canada and their res...,D1202771,Immigration & Citizenship Canadian Government ...
2,1288,3/5 of 60,D1547717,Science & Mathematics Mathematics What is 3/5 ...


In [21]:
joined_df['old_doc_data'] = joined_df['doc_data']
joined_df['doc_data'] = joined_df['query_data'] + joined_df['doc_data']
joined_df.head(3)

Unnamed: 0,query_id,query_data,doc_id,doc_data,old_doc_data
0,2,androgen receptor define,D1650436,"androgen receptor define""From Wikipedia, the ...","""From Wikipedia, the free encyclopedianavigati..."
1,1215,3 levels of government in canada and their res...,D1202771,3 levels of government in canada and their res...,Immigration & Citizenship Canadian Government ...
2,1288,3/5 of 60,D1547717,3/5 of 60Science & Mathematics Mathematics Wha...,Science & Mathematics Mathematics What is 3/5 ...


In [22]:
corpus = list(joined_df['doc_data'].values)
tokenized_corpus = [corp.split(" ") for corp in corpus]

bm25 = BM25Okapi(tokenized_corpus)

In [23]:
# predict 100 docs on the added queries and docs
joined_df['predict_100'], joined_df['predict_100_score'] = predicting_100_top_doc_id(joined_df, bm25)
joined_df.head(3)

Unnamed: 0,query_id,query_data,doc_id,doc_data,old_doc_data,predict_100,predict_100_score
0,2,androgen receptor define,D1650436,"androgen receptor define""From Wikipedia, the ...","""From Wikipedia, the free encyclopedianavigati...","[D1650436, D344906, D2339274, D802481, D104368...","[26.1, 6.0, 7.7, 10.0, 5.3, 5.0, 4.7, 5.5, 4.8..."
1,1215,3 levels of government in canada and their res...,D1202771,3 levels of government in canada and their res...,Immigration & Citizenship Canadian Government ...,"[D1202771, D1361055, D22461, D2339274, D217536...","[35.1, 20.8, 19.4, 20.4, 19.5, 19.2, 21.2, 20...."
2,1288,3/5 of 60,D1547717,3/5 of 60Science & Mathematics Mathematics Wha...,Science & Mathematics Mathematics What is 3/5 ...,"[D1547717, D815091, D855092, D3461202, D628872...","[27.2, 10.1, 8.5, 9.7, 9.1, 9.6, 8.1, 9.0, 8.8..."


In [24]:
from rank_eval import Qrels, Run, evaluate
qrels = Qrels()
qrels.add_multi(q_ids=[str(i) for i in list(joined_df['query_id'].values)], 
                doc_ids=[[i] for i in joined_df['doc_id']],
                scores=[[1.0] for i in range(len(joined_df))])

run = Run()
run.add_multi(q_ids=[str(i) for i in list(joined_df['query_id'].values)], 
                doc_ids=[i for i in joined_df['predict_100']],
                scores=[i for i in joined_df['predict_100_score']]) # [[1.0] * 100 for i in range(len(joined_df))])

In [25]:
evaluate(qrels, run, ["mrr@100"])

0.8558183092923586

### Will we try to get all the documents?

In [3]:
import csv
import random
import gzip
import os
from collections import defaultdict

# The query string for each topicid is querystring[topicid]
# tsv: qid, query
querystring = {} 
with gzip.open("data/msmarco-doctrain-queries.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for [qid, querystring_of_topicid] in tsvreader:
        querystring[qid] = querystring_of_topicid

# In the corpus tsv, each docid occurs at offset docoffset[docid]
# tsv: docid, offset_trec, offset_tsv
docoffset = {}
with gzip.open("data/msmarco-docs-lookup.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for [docid, _, offset] in tsvreader:
        docoffset[docid] = int(offset)

# For each topicid, the list of positive docids is qrel[topicid]
# TREC qrels format - qid - docid - rel == 1
qrel = {}
train_docs = []
with gzip.open("data/msmarco-doctrain-qrels.tsv.gz", 'rt', encoding='utf8') as f:
    new_reader = csv.reader(f)#, delimiter="\t")
    for line in new_reader:
        [qid, _, docid, rel] = line[0].split()
        assert rel == "1"
        train_docs.append(docid)
        if qid in qrel:
            qrel[qid].append(docid)
        else:
            qrel[qid] = [docid]

train_docs_set = set(train_docs)

In [4]:
import sys
import csv
csv.field_size_limit(sys.maxsize)

train_docs = {}
with gzip.open("data/msmarco-docs.tsv.gz", 'rt', encoding='utf8') as f:
    new_reader = csv.reader(f, delimiter="\t")
    for line in new_reader:
        docid, url, title, body = line
        if docid in train_docs_set:
            train_docs_set.discard(docid)
            train_docs[docid] = body
            if not train_docs_set:
                break

In [5]:
train_docs_df = pd.DataFrame({'doc_id':train_docs.keys(), 'data':train_docs.values()})
train_docs_df.head(3)

Unnamed: 0,doc_id,data
0,D1359209,Check for Lice Nits Edited by Mian Sheilette O...
1,D3233725,Dogo Argentino Miscellaneous The Dogo Argentin...
2,D1885729,"Weeds invade lawns and gardens, robbing the go..."


In [6]:
df_qrels_train = pd.read_csv(os.path.join(raw_dir, 'msmarco-doctrain-qrels.tsv'), sep=' ', names=['query_id', 'rank_0', 'doc_id', 'rank_1'])
df_queries_train = pd.read_csv(os.path.join(raw_dir, 'queries.doctrain.tsv'), sep='\t', names=['query_id', 'data'])

df_qrels_train['id'] = df_qrels_train.index
df_qrels_train = df_qrels_train[['id', 'query_id', 'doc_id']]
df_queries_train['id'] = df_queries_train.index
df_queries_train = df_queries_train[['id', 'query_id', 'data']]

In [2]:
joined_df_train = create_joined_file(train_docs_df, df_qrels_train, df_queries_train)
joined_df_train.head(3)

NameError: name 'create_joined_file' is not defined

In [None]:
corpus = list(joined_df_train['doc_data'].values)
tokenized_corpus = [corp.split(" ") for corp in corpus]

bm25 = BM25Okapi(tokenized_corpus)

In [1]:
from rank_eval import Qrels, Run, evaluate
qrels = Qrels()
qrels.add_multi(q_ids=[str(i) for i in list(joined_df_train['query_id'].values)], 
                doc_ids=[[i] for i in joined_df_train['doc_id']],
                scores=[[1.0] for i in range(len(joined_df_train))])

NameError: name 'joined_df_train' is not defined

In [None]:
run = Run()
run.add_multi(q_ids=[str(i) for i in list(joined_df_train['query_id'].values)], 
                doc_ids=[i for i in joined_df_train['predict_100']],
                scores=[i for i in joined_df_train['predict_100_score']])

evaluate(qrels, run, ["mrr@100"])

In [117]:
class DFIterator:
    def __init__(self, df, batch_size=1):
        self.df = df
        self.batch_size = batch_size
        self.min, self.max = 0, len(self.df)

    def __iter__(self):
        self.curr_id = self.min
        return self

    def df_slice(self, start_pos, end_pos):
        data = self.df.iloc[start_pos:end_pos]#.values
        assert (len(data) == end_pos - start_pos), f"{len(data)} != {end_pos} - {start_pos}"
        return data

    def __next__(self):
        if self.curr_id > self.max:
            raise StopIteration
        start_pos = self.curr_id
        end_pos = min(start_pos + self.batch_size, self.max)
        data = self.df_slice(start_pos, end_pos)
        self.curr_id += self.batch_size
        return data

In [136]:
class QueryModel():
    def __init__(self, search_engine, model_type):
        self.model_type = model_type
        self.search_engine = search_engine
        
    def expand_document(self, doc, query=None):
        if self.model_type == 'Dummy':
            return query
        else:
            pass

In [137]:
def generate_query(df_batch, model):
    documents = df_batch['doc_text']
    generated_qs = []
    if 'query' in df_batch.columns:
        queries = df_batch['query']
        for doc, query in zip(documents, queries):
            generated_q = model.expand_document(doc, query)
            generated_qs.append([generated_q])
    else:
        for doc in documents:
            generated_q = model.expand_document(doc)
            generated_qs.append([generated_q])   
    # TODO: put generated queries into DB
    return generated_qs

In [158]:
if __name__=='__main__':
    df_docs, df_qrels, df_queries = read_local_files(raw_dir)
    joined_df = create_joined_file(df_docs, df_qrels, df_queries)
    joined_df.fillna('simple_text', inplace=True) # Just in our toy example
    bm25 = BM25Okapi(joined_df['doc_text']) # learning search engine on every document
    model = QueryModel(bm25, 'Dummy')
    iterator = DFIterator(joined_df, BATCH_SIZE)
    for batch in iterator:
        generated_queries = generate_query(batch, model)
        print(generated_queries)
        break

[[' androgen receptor define'], ['3 levels of government in canada and their responsibilities'], ['3/5 of 60'], ['60x40 slab cost'], ['bethel university was founded in what year'], ['does suddenlink carry espn3'], ['explain what a bone scan is and what it is used for.'], ['is the louisiana sales tax 4.75'], ['ludacris net worth'], ['the hormone that does the opposite of calcitonin is']]


In [166]:
batch = [(82, 34015, 'average cost of maid service', 'D35147', 'simple_text'), (111, 47864, 'awareness of others definition', 'D804804', 'simple_text'), (265, 93649, 'clutch baseball definition', 'D239506', 'simple_text'), (386, 135079, 'definition of hamburger', 'D1818481', 'simple_text'), (578, 185009, 'famous people who was born on pi day', 'D791803', 'simple_text'), (717, 237936, 'how is volcano formed', None, None), (928, 320970, 'how much is a kia sorento', 'D2914677', 'simple_text'), (1550, 543251, 'weather in anacortes wa', 'D2196243', 'simple_text'), (2209, 742022, 'what is dysgraphia?', 'D5552', 'simple_text'), (2503, 829050, 'what is the lowest point in netherlands', 'D2562508', 'simple_text'), (2600, 856171, 'what is tipm', 'D2662922', 'simple_text'), (3220, 997533, 'where is the sea of cortez located', 'D3030351', 'simple_text'), (3734, 1041948, 'what is the dare font', 'D235550', 'simple_text'), (3795, 1047548, 'what is qd?', 'D184911', 'simple_text'), (4837, 1095278, 'average temp in fort worth in march', None, None), (4949, 1097023, 'how many inches are in a yard and what is', 'D1433675', 'simple_text')]
df = pd.DataFrame(batch)
df.shape #, columns=[['id', 'query_id', 'query', 'doc_id', 'doc_text']])
df.rename(columns={0:'a', 1:'b', 3:'c', 4:'d', 5:'f'})

Unnamed: 0,a,b,2,c,d
0,82,34015,average cost of maid service,D35147,simple_text
1,111,47864,awareness of others definition,D804804,simple_text
2,265,93649,clutch baseball definition,D239506,simple_text
3,386,135079,definition of hamburger,D1818481,simple_text
4,578,185009,famous people who was born on pi day,D791803,simple_text
5,717,237936,how is volcano formed,,
6,928,320970,how much is a kia sorento,D2914677,simple_text
7,1550,543251,weather in anacortes wa,D2196243,simple_text
8,2209,742022,what is dysgraphia?,D5552,simple_text
9,2503,829050,what is the lowest point in netherlands,D2562508,simple_text


In [169]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 16 entries, 0 to 15
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   0       16 non-null     int64 
 1   1       16 non-null     int64 
 2   2       16 non-null     object
 3   3       14 non-null     object
 4   4       14 non-null     object
dtypes: int64(2), object(3)
memory usage: 768.0+ bytes


In [167]:
import pandas as pd
    
def QueryGenerator():
    def __init__(self, model, table_queries, table_qrels, db_path=DB_PATH):
        self.model = model
        self.connection = sqlite3.connect(db_path)
        self.table_queries = table_queries
        self.table_qrels = table_qrels
        self.cursor = self.connection.cursor()
        
    def generate_query(self, batch, write=True):
        df_batch = pd.DataFrame(batch)
        if df_batch.shape[1] == 5:
            df_batch.rename(columns={0:'id', 1:'query_id', 3:'query', 4:'doc_id', 5:'doc_text'}, inplace=True)
        elif df_batch.shape[1] == 3:
            df_batch.rename(columns={0:'id', 1:'doc_id', 2:'doc_text'}, inplace=True)
        documents = df_batch['doc_text']
        generated_qs = []
        if 'query' in df_batch.columns:
            queries = df_batch['query']
            for doc, query in zip(documents, queries):
                generated_qs.append([query])
        else:
            pass   
        if write:
            self.write_to_db(generated_qs, df_batch)
        return generated_qs
    
    def write_to_db(self, generated_queries, df_batch):
        q_queries = f"INSERT INTO {self.table_queries} (ID, QUERY_ID, DATA) VALUES ({df_batch['id']}, {df_batch['query_id']}, {generated_queries});"
        q_qrels = f"INSERT INTO {self.table_qrels} (ID, QUERY_ID, DOC_ID) VALUES ({df_batch['id']}, {df_batch['query_id']}, {df_batch['doc_id']});"
        self.cursor.execute(q_queries)
        self.cursor.execute(q_qrels)

In [None]:
gener = QueryGenerator(db_path, 'Dummy_model')

In [None]:
# import sys
# import sqlite3
# import pandas as pd
# from utils import DB_PATH, EXPERIMENT_INIT, logger
# from iterator import DBIterator
#
#
# class QueryGenerator:
#     def __init__(self, model, experiment_name, db_path=DB_PATH):
#         logger.info(f"Starting {experiment_name} experiment!")
#         self.model = model
#         self.connection = sqlite3.connect(db_path)
#         self.experiment_name = experiment_name
#         self.table_queries = f"QUERIES_{experiment_name}"
#         self.table_qrels = f"QRELS_{experiment_name}"
#         self.joined_view = f"JOINED_{experiment_name}"
#         self.init_experiment()
#
#     def init_experiment(self):
#         logger.info(f"Initializing DB experiments: "
#                     f"creating {self.table_queries}, {self.table_qrels}, {self.joined_view}")
#         self.connection.executescript(
#             EXPERIMENT_INIT.format(
#                 TABLE_QUERIES=self.table_queries,
#                 TABLE_QRELS=self.table_qrels,
#                 JOINED_VIEW=self.joined_view,
#             )
#         )
#
#     def generate_query(self, batch, write=True):
#         """
#         :param batch: DOCS table content: List of (id, doc_d, data)
#         :param write: write generated queries to db or not (write by default)
#         :return: List of generated queries: 1 query for 1 doc
#
#         For each document model generates only ONE string:
#         1 string = 1 query
#         1 string = multiple queries concatenated in one string with ';' separator
#         """
#         # TODO Discuss:
#         #  1 doc = 1 "string" format
#         #  is it necessary to return generated_qs
#
#         df_batch = pd.DataFrame(batch, columns=["id", "doc_id", "data"])
#         df_batch["query_id"] = df_batch["doc_id"]
#         documents = df_batch["data"]
#
#         if self.model:
#             generated_qs = self.model(documents)
#         else:
#             generated_qs = [""] * len(documents)
#
#         assert len(generated_qs) == len(documents)
#
#         if write:
#             self.write_to_db(generated_qs, df_batch)
#         return generated_qs
#
#     def write_to_db(self, generated_queries, df_batch):
#         logger.debug(f"Writing generated queries ({len(df_batch)} samples) "
#                      f"to DB (tables: {self.table_queries}, {self.table_qrels})")
#
#         queries_to_db = df_batch[["id", "query_id"]].copy()
#         queries_to_db["data"] = generated_queries
#         queries_to_db["data"] = queries_to_db["data"].astype("string")
#         queries_to_db.to_sql(
#             self.table_queries, self.connection, if_exists="append", index=False
#         )
#
#         qrels_to_db = df_batch[["id", "query_id", "doc_id"]].copy()
#         qrels_to_db.to_sql(
#             self.table_qrels, self.connection, if_exists="append", index=False
#         )
#
#
# if __name__ == "__main__":
#     experiment = sys.argv[1]
#     docs_table = "DOCS"
#     dummy_model = None
#     q_generator = QueryGenerator(dummy_model, experiment)
#     test_iterator = DBIterator(docs_table, batch_size=1024)
#     for batch_data in test_iterator:
#         q_generator.generate_query(batch_data)
import sqlite3
import pandas as pd
from utils import DB_PATH
from iterator import DBIterator


class QueryGenerator():
    def __init__(self, model, experiment_name, db_path=DB_PATH):
        self.model = model
        self.connection = sqlite3.connect(db_path)
        self.experiment_name = experiment_name
        self.table_queries = 'QUERIES_' + experiment_name
        self.table_qrels = 'QRELS_' + experiment_name
        self.joined_view = 'JOINED_' + experiment_name
        self.init_experiment()


    def init_experiment(self):
        self.connection.execute(('drop table if exists {table};'.format(table=self.table_queries)))
        self.connection.execute(('drop table if exists {table};'.format(table=self.table_qrels)))
        self.connection.execute(('drop view if exists {table};'.format(table=self.joined_view)))
        self.init_view()
        create_table = 'create table {table}({field_1} {type_1} primary key, {field_2} {type_2}, {field_3} {type_3});'
        create_index = 'create {option} index {ind}_{i}_id on {table}({field});'
        self.connection.execute(create_table.format(table=self.table_queries, field_1='id', type_1='number', field_2='query_id', type_2='number', field_3='data', type_3='text'))
        self.connection.execute(create_table.format(table=self.table_qrels, field_1='id', type_1='number', field_2='query_id', type_2='number', field_3='doc_id', type_3='number'))
        self.connection.execute(create_index.format(option='unique', ind=self.table_queries, i='1', table=self.table_queries, field='query_id'))
        self.connection.execute(create_index.format(option='', ind=self.table_qrels, i='1', table=self.table_qrels, field='query_id'))
        self.connection.execute(create_index.format(option='', ind=self.table_qrels, i='2', table=self.table_qrels, field='doc_id'))


    def init_view(self):
        self.connection.execute(('drop view if exists {table};'.format(table=self.joined_view)))
        self.connection.execute(("""
        CREATE VIEW {view} AS SELECT QR.ID ID, Q.QUERY_ID, Q.DATA QUERY_DATA, D.DOC_ID, D.DATA DOC_DATA FROM {qrels} QR
        LEFT JOIN DOCS D ON QR.DOC_ID = D.DOC_ID
        LEFT JOIN {queries} Q ON QR.QUERY_ID = Q.QUERY_ID;""".format(view=self.joined_view, queries=self.table_queries, qrels=self.table_qrels)))


    def generate_query(self, batch, write=True):
        df_batch = pd.DataFrame(batch)
        if df_batch.shape[1] == 5:
            df_batch.rename(columns={0: 'id', 1: 'query_id', 2: 'query', 3: 'doc_id', 4: 'data'}, inplace=True)
        elif df_batch.shape[1] == 3:
            df_batch.rename(columns={0: 'id', 1: 'doc_id', 2: 'data'}, inplace=True)
        documents = df_batch['data']
        generated_qs = []
        if 'query' in df_batch.columns:
            queries = df_batch['query']
            for doc, query in zip(documents, queries):
                generated_qs.append(query)
        else:
            pass
        if write:
            self.write_to_db(generated_qs, df_batch)
        return generated_qs

    def write_to_db(self, generated_queries, df_batch):
        queries_to_db = df_batch[['id', 'query_id']].copy()
        queries_to_db['data'] = generated_queries
        queries_to_db['data'] = queries_to_db['data'].astype('string')
        queries_to_db.to_sql(self.table_queries, self.connection, if_exists='append', index=False)
        qrels_to_db = df_batch[['id', 'query_id', 'doc_id']].copy()
        qrels_to_db.to_sql(self.table_qrels, self.connection, if_exists='append', index=False)


if __name__ == "__main__":
    experiment_name = 'DEV_0'
    table_name = 'JOINED'
    bs = 1024
    shuffle_data = True #  if sys.argv[3] == 'True' else False
    test_iterator = DBIterator(table_name, bs, shuffle_data)
    q_generator = QueryGenerator('Dummy_Model', experiment_name)
    for batch in test_iterator:
        q_generator.generate_query(batch)