In [4]:
import tasti
import jsonlines
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from collections import defaultdict
from torchnlp.word_to_vector import FastText

class WikiSQLDataset(torch.utils.data.Dataset):
    """WikiSQL questions paired with two labels taken from each record's SQL.

    Every record of the JSON-lines file contributes one row holding the
    normalised question text, the ``agg`` entry of its ``sql`` object and the
    number of entries in ``sql['conds']``.

    The ``mode`` attribute selects what ``__getitem__`` returns: ``'input'``
    yields the averaged word embedding of the question, any other value yields
    the ``(agg, conds)`` pair.
    """

    def __init__(self, jsonl):
        """Read ``jsonl`` into a DataFrame and construct the FastText vectors.

        Args:
            jsonl: Path of a JSON-lines file whose records carry a
                ``question`` string and an ``sql`` object with ``agg`` and
                ``conds`` entries.
        """
        self.mode = 'input'
        data = []
        with jsonlines.open(jsonl) as reader:
            for obj in reader:
                sql = obj['sql']
                # Normalise the question: trim, lower-case, drop '?'.
                text = obj['question'].strip().lower()
                text = text.replace('?', '')  # TODO: ??
                data.append((text, sql['agg'], len(sql['conds'])))
        self.df = pd.DataFrame(data, columns=['text', 'agg', 'conds'])
        self.vectors = FastText(cache='./cache/vectors.pth')

    def __len__(self):
        """Return the number of questions that were read."""
        return len(self.df)

    def embed(self, text):
        """Return the mean word vector of ``text``.

        Args:
            text: Question string; tokens are separated by whitespace.

        Returns:
            A newly allocated tensor holding the element-wise mean of the
            vectors looked up for every token.
        """
        # split() collapses runs of whitespace, so repeated spaces no longer
        # add empty tokens that dilute the average. A blank text falls back to
        # the single empty token, which is what split(' ') produced for ''.
        words = text.split() or ['']
        # Stack into a fresh tensor instead of accumulating with += and /= on
        # the looked-up vector: in-place ops write through to whatever storage
        # the lookup returned and could corrupt the stored word vectors.
        return torch.stack([self.vectors[word] for word in words]).mean(dim=0)

    def __getitem__(self, idx):
        """Return the item at row label ``idx`` according to ``self.mode``.

        Returns:
            The question embedding when ``self.mode == 'input'``, otherwise
            the ``(agg, conds)`` pair stored for that row.
        """
        if self.mode == 'input':
            return self.embed(self.df.loc[idx, 'text'])
        return self.df.loc[idx, 'agg'], self.df.loc[idx, 'conds']
        
class Embedder(nn.Module):
    """Two-layer MLP mapping an input feature vector to an embedding.

    Args:
        nb_out: Width of the output embedding.
        nb_in: Width of the input features. The default of 300 is the value
            that used to be hard-coded in the first linear layer.
        nb_hidden: Width of the hidden layer. The default of 128 is the value
            that used to be hard-coded.
    """

    def __init__(self, nb_out=128, nb_in=300, nb_hidden=128):
        super().__init__()
        # Kept as a positional nn.Sequential so the parameter names
        # (mlp.0.*, mlp.2.*) stay the same as before.
        self.mlp = nn.Sequential(
            nn.Linear(nb_in, nb_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(nb_hidden, nb_out),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.mlp(x)

class WikiSQLOfflineIndex(tasti.Index):
    """Index over the WikiSQL dev questions whose labels are read from file.

    NOTE: each dataset-returning hook builds a new ``WikiSQLDataset``, so the
    JSON-lines file is re-read and the FastText vectors are constructed again
    on every call.
    """

    def _build_wikisql_dataset(self, mode):
        """Create the dev-split dataset and switch it to ``mode``."""
        dataset = WikiSQLDataset('../text/data/dev.jsonl')
        dataset.mode = mode
        return dataset

    def get_target_dnn(self):
        """Return an identity module as the target DNN."""
        return torch.nn.Identity()

    def get_embedding_dnn(self):
        """Return a freshly initialised ``Embedder`` with default sizes."""
        return Embedder()

    def get_target_dnn_dataset(self):
        """Return the dev dataset in ``'input'`` mode."""
        return self._build_wikisql_dataset('input')

    def get_embedding_dnn_dataset(self):
        """Return the dev dataset in ``'input'`` mode."""
        return self._build_wikisql_dataset('input')

    def override_target_dnn_cache(self, target_dnn_cache):
        """Return the dev dataset in ``'output'`` mode.

        The ``target_dnn_cache`` argument is ignored; the returned dataset
        yields the ``(agg, conds)`` pair for each index.
        """
        return self._build_wikisql_dataset('output')

    def is_close(self, label1, label2):
        """Return the result of comparing the two labels with ``==``."""
        return label1 == label2
    
class WikiSQLAggregateQuery(tasti.AggregateQuery):
    """Aggregate query scored by element 1 of the target output."""

    def score(self, target_dnn_output):
        """Return ``target_dnn_output[1]``.

        With the ``(agg, conds)`` pairs produced by ``WikiSQLDataset`` in
        output mode this is the number of conditions.
        """
        second_field = target_dnn_output[1]
        return second_field
    
class WikiSQLSUPGPrecisionQuery(tasti.SUPGPrecisionQuery):
    """Precision query that selects records whose first field equals 0."""

    def score(self, target_dnn_output):
        """Return 1.0 when ``target_dnn_output[0] == 0``, otherwise 0.0.

        With the ``(agg, conds)`` pairs produced by ``WikiSQLDataset`` in
        output mode the first field is the record's ``agg`` value.
        """
        first_field = target_dnn_output[0]
        if first_field == 0:
            return 1.0
        return 0.0
    
class WikiSQLOfflineConfig(tasti.IndexConfig):
    """Index configuration with every stage flag enabled.

    NOTE(review): the exact meaning of each field is defined by
    ``tasti.IndexConfig``, which is not visible here; the grouping below is
    based on the attribute names only.
    """

    def __init__(self):
        super().__init__()

        # Stage switches: all four are turned on.
        self.do_mining = True
        self.do_training = True
        self.do_infer = True
        self.do_bucketting = True

        # Size settings.
        self.nb_train = 500
        self.nb_buckets = 500
        self.batch_size = 1

In [5]:
# Build the index configuration and the index object from the classes defined
# in the previous cell.
config = WikiSQLOfflineConfig()
index = WikiSQLOfflineIndex(config)
# Run the index initialisation. The recorded output below shows the FastText
# download and load followed by progress bars for embedding, bucketting,
# target-DNN, triplet-dataset, training and inference steps.
# NOTE(review): presumably these correspond to the do_* flags set in
# WikiSQLOfflineConfig; confirm against tasti.Index.init.
index.init()

wiki.en.vec: 6.60GB [02:37, 41.8MB/s]                               
  0%|          | 0/2519371 [00:00<?, ?it/s]Skipping token 2519370 with 1-dimensional vector ['300']; likely a header
100%|██████████| 2519371/2519371 [03:12<00:00, 13061.42it/s]


HBox(children=(FloatProgress(value=0.0, description='Embedding DNN', max=8421.0, style=ProgressStyle(descripti…




RandomBucketter: 100%|██████████| 749/749 [00:00<00:00, 1684.35it/s]
FPFBucketter: 100%|██████████| 2250/2250 [00:01<00:00, 1687.15it/s]
100%|██████████| 8421/8421 [00:00<00:00, 17217.30it/s]


HBox(children=(FloatProgress(value=0.0, description='Target DNN', max=3000.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Triplet Dataset Init', max=3000.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Training Step', max=12000.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Inference', max=8421.0, style=ProgressStyle(description_w…

RandomBucketter: 100%|██████████| 124/124 [00:00<00:00, 1702.35it/s]
FPFBucketter:  47%|████▋     | 177/375 [00:00<00:00, 1767.55it/s]




FPFBucketter: 100%|██████████| 375/375 [00:00<00:00, 1727.97it/s]
100%|██████████| 8421/8421 [00:00<00:00, 96504.95it/s]


HBox(children=(FloatProgress(value=0.0, description='Target DNN Invocations', max=500.0, style=ProgressStyle(d…


