In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import bisect

  from .autonotebook import tqdm as notebook_tqdm


Load Dataset in

In [2]:
from datasets import load_dataset

ds = load_dataset("gretelai/synthetic_text_to_sql")
print(type(ds))
print(f"Train info: {ds['train']}")
print(f"Test info: {ds['test']}")


<class 'datasets.dataset_dict.DatasetDict'>
Train info: Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100000
})
Test info: Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 5851
})


Setup LLaDA tokenizer

In [3]:
label_tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
# The LLaDA architecture theoretically supports both left-padding and right-padding. 
# However, the sampling code implementation is simpler with left-padding.
if label_tokenizer.padding_side != 'left':
    label_tokenizer.padding_side = 'left'

# If the padding ID equals the mask ID, you need to modify our generate function to achieve correct inference.
assert label_tokenizer.pad_token_id != 126336

input_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Testing features

In [None]:
training_data = ds['train']
needed_cols = ['id','sql_context', 'sql_prompt', 'sql']
training_data = training_data.select_columns(needed_cols)
prompt = f"Schema:\n{training_data['sql_context'][0]}\n\nPrompt:\n{training_data['sql_prompt'][0]}"
# encoded_outputs = tokenizer(
#     prompt,
#     add_special_tokens=False,
#     padding=True,
#     return_tensors="pt"
# )
# print(encoded_outputs['input_ids'].shape)
# print(encoded_outputs['attention_mask'].shape)

output_lens= []
max, min = 0, 1e6
for instance in training_data['sql']:
    # print(instance)
    tokenized_target = label_tokenizer(
        instance,
        truncation=False,
        padding=False
    )
    sql_len = float(len(tokenized_target["input_ids"]))
    output_lens.append(sql_len)
    if sql_len > max:
        max = sql_len
    if sql_len < min:
        min = sql_len
    # print(f"len(sql): {sql_len}")
print(f"Average SQL length (in tokens): {sum(output_lens)/len(output_lens)}")
print(f"Max SQL length (in tokens): {max}")
print(f"Min SQL length (in tokens): {min}")
# messages = [{"role": "user", "content": prompt} for prompt in prompts]
# prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages]

SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;
SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;
SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';
SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;
SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021;
SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY SpeciesName;
DELETE FROM Program_Outcomes WHERE program

KeyboardInterrupt: 

In [12]:
num_classes = 6
bucket_labels, bins = pd.qcut(output_lens, q=num_classes, retbins=True, labels=False)
print(f"Bucket Boundaries: {bins}")
print(len(bucket_labels))

Bucket Boundaries: [  4.  17.  22.  28.  36.  48. 219.]
5000


In [6]:
queries = training_data[0]["sql_context"].split(";")
queries = [q.strip() for q in queries]
queries = [q for q in queries if "CREATE TABLE" in q]
context = "; ".join(queries)
print(context)

CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE)


In [7]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

model.eval()
with torch.no_grad():
    input = tokenizer(training_data[0:2]['sql_context'], training_data[0:2]['sql_prompt'], padding='max_length', return_tensors='pt', truncation=True)
    print(input["input_ids"].shape)
    outputs = model(**input)
    print(outputs.last_hidden_state.shape)
    cls_output = outputs.last_hidden_state[:, 0, :]
    print(cls_output.shape)



torch.Size([2, 512])
torch.Size([2, 512, 768])
torch.Size([2, 768])


In [8]:
class ContextPredictor(nn.Module):
    def __init__(self, dropout = 0.3, bert_requires_grad = False):
        super().__init__()
        self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
        for param in self.bert.parameters():
            param.requires_grad = bert_requires_grad

        self.seq = nn.Sequential(
            nn.Linear(
                in_features=768,
                out_features=256
            ),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Linear(
                in_features=256,
                out_features=64
            ),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64,3),
        )
    
    def forward(self, input_ids, attention_mask):
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = bert_out.last_hidden_state[:, 0, :]
        return self.seq(cls_output)

In [10]:
context_model = ContextPredictor()

model.eval()
print(training_data)
print()
with torch.no_grad():
    input = tokenizer(training_data[0:10]['sql_context'], training_data[0:10]['sql_prompt'], return_tensors='pt', truncation=True, padding='max_length')
    output = context_model(**input)
    print(output)

Dataset({
    features: ['id', 'sql_context', 'sql_prompt', 'sql'],
    num_rows: 100000
})

tensor([[-0.0228,  0.6681,  0.5718],
        [ 0.2924,  0.5012,  0.5976],
        [ 0.4934,  0.3726,  0.0985],
        [ 0.1952,  0.4844,  0.2544],
        [ 0.5096,  0.4364,  0.1956],
        [ 0.3712,  0.4075,  0.3696],
        [ 0.7025,  0.0497,  0.4119],
        [ 0.2468,  0.4677,  0.3036],
        [ 0.4765,  0.4512, -0.0360],
        [ 0.5447,  0.4783,  0.5373]])
