In [30]:
import itertools
import sys

import torch
import torch.nn as nn
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModel

import ContextPredictor


In [31]:
# Load the trained length predictor for inference.
# map_location="cpu": the checkpoint may contain tensors saved from a GPU;
# mapping them to CPU lets it load on machines without CUDA and avoids a
# pointless GPU round-trip, since load_state_dict copies into the model's own
# (CPU) parameters and the visible cells never move the model to a device.
# weights_only=True: the loaded object is passed straight to load_state_dict,
# so it should be a plain state_dict — refuse to unpickle anything else.
CHECKPOINT_PATH = '../saved_models/predictor_epoch_3.pt'
model = ContextPredictor.ContextPredictor()
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
)
model.eval()



ContextPredictor(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear

In [32]:
# Reload the saved splits from disk and build the two tokenizers:
#  - input_tokenizer encodes the predictor's input text,
#  - pred_tokenizer is used by tokenize_function below to count the tokens
#    of the reference SQL, which becomes the regression label.
# NOTE(review): trust_remote_code=True executes Python shipped with the hub
# repo; consider pinning `revision=` if this notebook is shared.
dataset_path = "../data"
reloaded = load_from_disk(dataset_path)

input_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased",
)
pred_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
)

def tokenize_function(example):
    """Encode a single (non-batched) example for the length predictor.

    The ``sql_prompt`` / ``sql_context`` fields are encoded as a text pair with
    ``input_tokenizer``, padded and truncated to exactly 512 tokens.  The
    reference ``sql`` is tokenized with ``pred_tokenizer`` (no truncation, no
    padding) and its token count, as a float, is stored under ``labels``.

    Returns the encoder output mapping with the extra ``labels`` entry.
    """
    target_ids = pred_tokenizer(
        example["sql"], truncation=False, padding=False
    )["input_ids"]

    encoded = input_tokenizer(
        example["sql_prompt"],
        example["sql_context"],
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    # Float label: the predictor is trained/evaluated as a regression target.
    encoded["labels"] = float(len(target_ids))
    return encoded

# Tokenize every split and keep only the tensors the predictor consumes.
# NOTE(review): remove_columns uses the "train" split's column names for all
# splits, which assumes every split shares the same schema.
tokenized_datasets = reloaded.map(
    tokenize_function,
    batched=False,
    remove_columns=reloaded["train"].column_names,
)

tokenized_datasets.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

eval_dataset = tokenized_datasets["test"]

# Shuffle so the preview below shows a random subset, but seed the sampler's
# generator so the same examples come out on every Restart & Run All.
EVAL_SHUFFLE_SEED = 42
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(EVAL_SHUFFLE_SEED),
)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map: 100%|██████████| 2917/2917 [00:01<00:00, 2064.61 examples/s]


In [None]:
# Spot-check the predictor on the first N_PREVIEW batches of the eval loader.
# itertools.islice stops after N_PREVIEW batches instead of materialising
# (and collating) the whole eval split the way list(loader)[:N] would.
# NOTE(review): the raw model output is multiplied by LENGTH_SCALE before it is
# compared with the un-normalised token-count label, which assumes the model
# was trained on lengths divided by 512 — confirm against the training code.
N_PREVIEW = 10
LENGTH_SCALE = 512

with torch.no_grad():
    for samples in itertools.islice(eval_dataloader, N_PREVIEW):
        input_ids = samples["input_ids"]
        attention_mask = samples["attention_mask"]
        labels = samples["labels"]

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Round to the nearest whole token rather than truncating toward zero.
        predictions = round(outputs.item() * LENGTH_SCALE)
        # Labels are float token counts; show both sides as integers.
        actuals = int(labels.item())
        print(f"Predicted SQL length: {predictions}, Actual SQL length: {actuals}")