In [35]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import torch.nn.functional as F

Load the dataset

In [2]:
# Load the `gretelai/synthetic_text_to_sql` dataset and show what each split contains.
# `load_dataset` comes from the notebook's import cell, so it is not re-imported here.
ds = load_dataset("gretelai/synthetic_text_to_sql")
print(type(ds))
print(f"Train info: {ds['train']}")
print(f"Test info: {ds['test']}")

<class 'datasets.dataset_dict.DatasetDict'>
Train info: Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100000
})
Test info: Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 5851
})


Set up the LLaDA tokenizer

In [3]:
# Load the tokenizer published with the 'GSAI-ML/LLaDA-8B-Instruct' checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct',
    trust_remote_code=True,
)

# LLaDA can in theory be used with either left- or right-padding; left-padding is
# forced here because it keeps the sampling code simpler.
tokenizer.padding_side = 'left'

# 126336 is the mask ID. If the padding ID were equal to it, the generate function
# would have to be modified to achieve correct inference, so fail fast instead.
assert tokenizer.pad_token_id != 126336

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Testing features

In [20]:
# Keep only the columns used to build prompts and to compare against the target SQL.
needed_cols = ['id', 'sql_context', 'sql_prompt', 'sql']
training_data = ds['train'].select_columns(needed_cols)

# Index the row first and the column second: `training_data[0]` reads one example,
# whereas `training_data['sql_context'][0]` requests the whole column to use one value.
first_example = training_data[0]
prompt = f"Schema:\n{first_example['sql_context']}\n\nPrompt:\n{first_example['sql_prompt']}"

# Tokenize the single prompt and inspect the tensor shapes that come back.
encoded_outputs = tokenizer(
    prompt,
    add_special_tokens=False,
    padding=True,
    return_tensors="pt"
)
print(encoded_outputs['input_ids'].shape)
print(encoded_outputs['attention_mask'].shape)

# TODO: switch to batched prompts wrapped in the chat template, along the lines of:
# messages = [{"role": "user", "content": prompt} for prompt in prompts]
# prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages]

torch.Size([1, 180])
torch.Size([1, 180])


In [34]:
# NOTE: this rebinds `tokenizer` from the LLaDA tokenizer to the DistilBERT one;
# the later cells that feed DistilBERT-based models rely on this binding.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

model.eval()

# Read one example by row instead of pulling whole columns to take element 0.
first_example = training_data[0]

with torch.no_grad():
    # Named `encoded_inputs` rather than `input` so the builtin `input()` is not
    # shadowed for the rest of the kernel session.
    encoded_inputs = tokenizer(
        first_example['sql_context'],
        first_example['sql_prompt'],
        return_tensors='pt',
        truncation=True,
    )
    print(encoded_inputs["input_ids"].shape)
    outputs = model(**encoded_inputs)
    print(outputs.last_hidden_state.shape)
    # Hidden state at sequence position 0, used as the per-example summary vector.
    cls_output = outputs.last_hidden_state[:, 0, :]
    print(cls_output.shape)



torch.Size([1, 176])
torch.Size([1, 176, 768])
torch.Size([1, 768])


In [50]:
class ContextPredictor(nn.Module):
    """DistilBERT encoder followed by a small MLP head that emits one value per input.

    The hidden state at sequence position 0 of the encoder output is passed through
    Linear(768, 256) -> ReLU -> Dropout -> Linear(256, 64) -> ReLU -> Linear(64, 1) -> ReLU,
    so the result has shape (batch, 1) and is never negative.

    Args:
        dropout: Drop probability of the dropout layer inside the head.
        bert_requires_grad: Value assigned to ``requires_grad`` on every encoder
            parameter; ``False`` (the default) freezes the encoder.
    """

    def __init__(self, dropout = 0.3, bert_requires_grad = False):
        super().__init__()
        self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
        for param in self.bert.parameters():
            param.requires_grad = bert_requires_grad

        self.seq = nn.Sequential(
            nn.Linear(
                in_features=768,
                out_features=256
            ),
            nn.ReLU(),
            # Use the constructor argument; it was previously ignored in favour of a
            # hard-coded 0.3.
            nn.Dropout(dropout),

            nn.Linear(
                in_features=256,
                out_features=64
            ),
            nn.ReLU(),
            nn.Linear(64,1),
            # NOTE(review): this final ReLU clamps negative pre-activations to 0 (and
            # passes no gradient for them) -- confirm this is intended for the target.
            nn.ReLU()
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the head's output for each example.

        Args:
            input_ids: Token IDs forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Tensor of shape (batch, 1) with non-negative values.
        """
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Summarise each sequence by the hidden state of its first token.
        cls_output = bert_out.last_hidden_state[:, 0, :]
        return self.seq(cls_output)

In [55]:
context_model = ContextPredictor()

# Switch the predictor itself to eval mode so its dropout layer is disabled for this
# check; calling `model.eval()` only affected the separate standalone DistilBERT.
context_model.eval()

# Read one example by row instead of pulling whole columns to take element 0.
first_example = training_data[0]

with torch.no_grad():
    # Named `encoded_inputs` rather than `input` so the builtin `input()` is not
    # shadowed for the rest of the kernel session.
    encoded_inputs = tokenizer(
        first_example['sql_context'],
        first_example['sql_prompt'],
        return_tensors='pt',
        truncation=True,
    )
    output = context_model(**encoded_inputs)
    print(output)

tensor([[0.]])
