In [None]:
import jsonlines
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from collections import defaultdict
from torchnlp.word_to_vector import FastText

In [None]:
class WikiSQLDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl):
        data = []
        with jsonlines.open(jsonl) as reader:
            for obj in reader:
                sql = obj['sql']
                label = (sql['agg'], len(sql['conds']))
                text = obj['question'].strip().lower()
                text = text.replace('?', '') # TODO: ??
                data.append((text, sql['agg'], len(sql['conds'])))
        self.df = pd.DataFrame(data, columns=['text', 'agg', 'conds'])
        self.vectors = FastText(cache='../text/vectors.pth')

    def __len__(self):
        return len(self.df)

    def embed(self, text):
        words = text.split(' ')
        emb = self.vectors[words[0]]
        for word in words[1:]:
            emb += self.vectors[word]
        emb /= len(words)
        return emb

    def __getitem__(self, idx):
        text = self.df.loc[idx, 'text']
        agg = self.df.loc[idx, 'agg']
        conds = self.df.loc[idx, 'conds']
        return self.embed(text), (agg, conds)

In [None]:
sql_dataset = WikiSQLDataset('../text/data/dev.jsonl')

In [5]:
import tasti
import jsonlines
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from collections import defaultdict
from torchnlp.word_to_vector import FastText

class WikiSQLDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl):
        self.mode = 'input'
        data = []
        with jsonlines.open(jsonl) as reader:
            for obj in reader:
                sql = obj['sql']
                label = (sql['agg'], len(sql['conds']))
                text = obj['question'].strip().lower()
                text = text.replace('?', '') # TODO: ??
                data.append((text, sql['agg'], len(sql['conds'])))
        self.df = pd.DataFrame(data, columns=['text', 'agg', 'conds'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.loc[idx, 'text']
        agg = self.df.loc[idx, 'agg']
        conds = self.df.loc[idx, 'conds']
        if self.mode == 'input':
            return text
        else:
            return agg, conds

class FastTextEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vectors = FastText(cache='./vectors.pth')
        
    def forward(self, sentences):
        embs = []
        for x in sentences:
            words = x.split(' ')
            emb = self.vectors[words[0]]
            for word in words[1:]:
                emb += self.vectors[word]
            emb /= len(words)
            embs.append(emb.reshape(1, -1))
        embs = torch.cat(embs, dim=0)
        return embs

class WikiSQLOfflineIndex(tasti.Index):
    def get_target_dnn(self):
        model = torch.nn.Identity()
        return model
        
    def get_embedding_dnn(self):
        model = FastTextEmbedder()
        return model
    
    def get_target_dnn_dataset(self):
        sql_dataset = WikiSQLDataset('../text/data/dev.jsonl')
        sql_dataset.mode = 'input'
        return sql_dataset
    
    def get_embedding_dnn_dataset(self):
        sql_dataset = WikiSQLDataset('../text/data/dev.jsonl')
        sql_dataset.mode = 'input'
        return sql_dataset
    
    def override_target_dnn_cache(self, target_dnn_cache):
        sql_dataset = WikiSQLDataset('../text/data/dev.jsonl')
        sql_dataset.mode = 'output'
        return sql_dataset
    
    def is_close(self, label1, label2):
        return label1 == label2

In [6]:
config = tasti.IndexConfig()
config.do_mining = False
config.do_training = False
config.do_infer = True
config.nb_buckets = 500
config.batch_size = 1

index = WikiSQLOfflineIndex(config)
index.init()

HBox(children=(FloatProgress(value=0.0, description='Inference', max=8421.0, style=ProgressStyle(description_w…




RandomBucketter: 100%|██████████| 124/124 [00:00<00:00, 972.63it/s]
FPFBucketter: 100%|██████████| 375/375 [00:00<00:00, 1232.96it/s]
100%|██████████| 8421/8421 [00:00<00:00, 79137.99it/s]


HBox(children=(FloatProgress(value=0.0, description='Target DNN Invocations', max=500.0, style=ProgressStyle(d…




In [16]:
class WikiSQLAggregateQuery(tasti.AggregateQuery):
    def score(self, target_dnn_output):
        return target_dnn_output[1]
    
class WikiSQLSUPGPrecisionQuery(tasti.SUPGPrecisionQuery):
    def score(self, target_dnn_output):
        return 1.0 if target_dnn_output[0] == 0 else 0.0

In [17]:
query = NightStreetAggregateQuery(index)
query.execute()

HBox(children=(FloatProgress(value=0.0, description='Propagation', max=8421.0, style=ProgressStyle(description…


Results
Initial Estimate: 13008.502667350944
Debiased Estimate: 11533.960855797119
Samples: 8142


{'initial_estimate': 13008.502667350944,
 'debiased_estimate': 11533.960855797119,
 'samples': 8142}

In [18]:
query = NightStreetSUPGPrecisionQuery(index)
query.execute()

HBox(children=(FloatProgress(value=0.0, description='Propagation', max=8421.0, style=ProgressStyle(description…


Results
Initial Estimate: 5754.345978265899
Debiased Estimate: 5658
idxs: [   0    1    2 ... 8418 8419 8420]
shape: (5658,)


{'initial_estimate': 5754.345978265899,
 'debiased_estimate': 5658,
 'idxs': array([   0,    1,    2, ..., 8418, 8419, 8420]),
 'shape': (5658,)}