In [1]:
import nltk
import torch
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from peft import PeftConfig, PeftModel
from t5.dataset import load_spider_datasets
from t5.inference import inference, evaluate_result

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Sanity check for a GPU: the model cell below moves the model to CUDA when one is available.
# Left as the cell's last expression so Jupyter renders it directly (no print() needed).
torch.cuda.is_available()

True


In [3]:
# Fetch NLTK's "punkt" tokenizer data; when it is already installed this only reports
# "already up-to-date" (see log below).
# NOTE(review): presumably required by the Spider evaluation code — confirm.
nltk.download(info_or_id='punkt')

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [9]:
# ---- Run configuration ----

# Tokenizer vocabulary used for both encoding inputs and decoding predictions.
tokenizer_name = 't5-base'

# Fine-tuned T5 checkpoint; the soft-prompt adapter is loaded on top of it.
base_model_name = "RoxyRong/t5_base_finetuned_test_5"
# Soft-prompt (PEFT) adapter checkpoint.
peft_model_name = "RoxyRong/t5_base_soft_prompt_tune_v3"
# Intended switch for wrapping the base model with the soft-prompt adapter.
# NOTE(review): confirm the model-loading cell actually reads this flag.
eval_soft_prompt = True

# Predictions file passed to inference() and then scored by the evaluation cells.
result_path = 'results/predicted_result_t5_base_finetuned_v3.txt'

In [5]:
# Load the tokenizer and the fine-tuned base model, optionally wrap the model with the
# soft-prompt (PEFT) adapter, and move it to the available device.

tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, model_max_length=512)
model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

# Honor the `eval_soft_prompt` switch so the base model can also be evaluated on its own.
# PeftModel.from_pretrained reads the adapter config itself, so no separate
# PeftConfig load is needed here.
if eval_soft_prompt:
    model = PeftModel.from_pretrained(model, peft_model_name)

# Fall back to CPU instead of failing on a machine without a GPU.
# NOTE(review): confirm `inference` does not hard-code "cuda" for its input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Token indices sequence length is longer than the specified maximum sequence length for this model (613 > 512). Running this sequence through the model will result in indexing errors


In [6]:
# Load the Spider splits; only the third one (the dev split evaluated below) is kept.
# Unpack into private names rather than `_`: in IPython, assigning to `_` stops the
# automatic "last output" variables (`_`, `__`, `___`) from updating.
_first_split, _second_split, dev_spider = load_spider_datasets()
# Release the unused splits so they can be garbage-collected.
del _first_split, _second_split

In [10]:
# inference: run the model over every example in the Spider dev split.
# Presumably writes the predicted SQL to `result_path`, which the evaluation cells
# below read as the predictions file.
# The numbers printed below appear to be progress counters (every 100 examples).
inference(dev_spider, model, tokenizer, result_path)

0
100
200
300
400
500
600
700
800
900
1000


In [11]:
# evaluate: score the predictions stored in `result_path`.
# The returned report (shown below) lists execution accuracy, exact-match accuracy
# and partial-matching accuracy, broken down by difficulty (easy/medium/hard/extra/all).
evaluate_result(result_path)

('                     easy                 medium               '
 'hard                 extra                all                 \n'
 'count                250                  440                  '
 '174                  170                  1034                \n'
 'execution            0.588                0.380                '
 '0.282                0.176                0.380               \n'
 '\n'
 'exact match          0.632                0.384                '
 '0.276                0.135                0.385               \n'
 '\n'
 '---------------------PARTIAL MATCHING ACCURACY----------------------\n'
 'select               0.960                0.923                '
 '0.987                0.882                0.940               \n'
 'select(no AGG)       0.977                0.940                '
 '1.000                0.882                0.955               \n'
 'where                0.918                0.846                '
 '0.531                0.565         

In [56]:
# Cross-check: run the Spider evaluation script directly on the predictions file
# and show the summary table at the end of its output.
import pprint
import subprocess
import sys

eval_path = "third_party/spider/evaluation.py"
gold = "third_party/spider/evaluation_examples/gold_example.txt"
pred = result_path
db_dir = "spider/database"
table = "spider/tables.json"
etype = "all"

# The script prints a long log before its summary; keep only the trailing part,
# which starts at the accuracy table header.
SUMMARY_TAIL_CHARS = 4633

# Pass arguments as a list (no shell) so paths containing spaces, quotes or shell
# metacharacters are handled safely. Use the kernel's own interpreter so the script
# sees the same installed packages as this notebook.
cmd = [
    sys.executable, eval_path,
    "--gold", gold,
    "--pred", pred,
    "--db", db_dir,
    "--table", table,
    "--etype", etype,
]
result = subprocess.run(cmd, capture_output=True, text=True)

# Surface script failures instead of silently printing an empty or partial stdout.
if result.returncode != 0:
    raise RuntimeError(
        f"Spider evaluation failed (exit code {result.returncode}):\n{result.stderr}"
    )

pprint.pprint(result.stdout[-SUMMARY_TAIL_CHARS:])

('                     easy                 medium               '
 'hard                 extra                all                 \n'
 'count                250                  440                  '
 '174                  170                  1034                \n'
 'execution            0.588                0.380                '
 '0.282                0.176                0.380               \n'
 '\n'
 'exact match          0.632                0.384                '
 '0.276                0.135                0.385               \n'
 '\n'
 '---------------------PARTIAL MATCHING ACCURACY----------------------\n'
 'select               0.960                0.923                '
 '0.987                0.882                0.940               \n'
 'select(no AGG)       0.977                0.940                '
 '1.000                0.882                0.955               \n'
 'where                0.918                0.846                '
 '0.531                0.565         

In [31]:
# Encode every gold SQL query of the dev split as a fixed-length tensor of T5 token
# ids (padded/truncated to LABEL_MAX_LENGTH), cast to float.
# NOTE: these are vocabulary indices, not learned model embeddings; statistics over
# them describe token ids only.
LABEL_MAX_LENGTH = 100

label_embedding = []
for query in dev_spider['query']:
    # return_tensors='pt' yields a tensor with a leading batch dimension of 1.
    inputs = tokenizer.encode(
        query,
        max_length=LABEL_MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=False,
        return_tensors='pt'
    )
    embedding = inputs.float()
    label_embedding.append(embedding)
    

In [18]:
# Inspect the encoding of a single example: the first gold query of the dev split.
query = dev_spider['query'].iloc[0]
# Subword pieces, kept for manual inspection alongside the ids below.
tokens = tokenizer.tokenize(query)
# Padded/truncated token ids (length 100) as a PyTorch tensor.
inputs = tokenizer.encode(
    query,
    return_tensors='pt',
    return_attention_mask=False,
    truncation=True,
    padding='max_length',
    max_length=100,
)

In [33]:
# Element-wise mean, over all dev queries, of the float-cast token-id vectors built in
# the `label_embedding` cell (each entry has a leading batch dim of 1, so the result
# keeps shape (1, L)).
# NOTE(review): `average_embeddings` was previously defined by a cell that no longer
# exists (NameError on a fresh kernel). This reconstruction as a plain mean is presumed
# from the displayed (1, 100) output; verify it matches the original computation.
average_embeddings = torch.cat(label_embedding, dim=0).mean(dim=0, keepdim=True)
average_embeddings

tensor([[7.6830e+01, 2.2278e+04, 1.3605e+04, 3.1774e+03, 5.6087e+03, 7.4438e+03,
         4.6589e+03, 7.2769e+03, 4.7087e+03, 5.5371e+03, 5.1068e+03, 4.7747e+03,
         5.0000e+03, 4.6756e+03, 4.4861e+03, 5.3498e+03, 4.5753e+03, 3.9869e+03,
         4.6533e+03, 4.0747e+03, 3.7121e+03, 3.9079e+03, 3.7692e+03, 3.0347e+03,
         3.3331e+03, 2.7068e+03, 2.4735e+03, 2.8643e+03, 2.5881e+03, 2.6125e+03,
         2.4479e+03, 2.2896e+03, 1.5905e+03, 2.3633e+03, 1.9401e+03, 1.4892e+03,
         2.2747e+03, 1.7833e+03, 1.5376e+03, 1.9648e+03, 1.8355e+03, 1.6025e+03,
         1.5095e+03, 1.4514e+03, 1.3454e+03, 1.2976e+03, 1.5645e+03, 1.5560e+03,
         1.2821e+03, 1.2193e+03, 1.1046e+03, 1.4017e+03, 1.2231e+03, 1.0214e+03,
         8.8721e+02, 9.6897e+02, 1.2687e+03, 9.9763e+02, 9.4790e+02, 9.8317e+02,
         9.3616e+02, 1.0790e+03, 8.0039e+02, 6.7364e+02, 4.9599e+02, 8.0332e+02,
         5.6892e+02, 6.6995e+02, 5.0928e+02, 6.1764e+02, 6.0084e+02, 5.9382e+02,
         5.3180e+02, 5.0104e