In [3]:
import numpy as np
import pandas as pd
import gc
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from gpt_j.dataset import load_spider_datasets
from gpt_j.gpt_j import GPTJForCausalLM
from t5.inference import evaluate_result
# Fetch the NLTK 'punkt' package (no-op when it is already up to date).
# NOTE(review): nothing in this notebook calls nltk directly — presumably one
# of the imported project helpers needs punkt; confirm before removing.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# Load the Spider data; only the third returned value is kept (as dev_spider).
# The cells below read its 'prompt' and 'db_id' columns.
# NOTE(review): presumably the two discarded values are the other dataset
# splits — confirm in gpt_j.dataset.load_spider_datasets.
_, _, dev_spider = load_spider_datasets()

In [5]:
# parameters
#
# Everything a reader might want to tune for this run lives here:
#   tokenizer_name  - checkpoint identifier for the tokenizer
#   base_model_name - checkpoint identifier for the 8-bit GPT-J weights
#   result_path     - text file the predictions are written to, one
#                     "<prediction>\t<db_id>" line per example

tokenizer_name = 'EleutherAI/gpt-j-6B'
base_model_name = 'hivemind/gpt-j-6B-8bit'
# Plain string literal: the original f-string had no placeholders.
result_path = 'results/predicted_result_gptj_base.txt'

In [None]:
# model
#
# Build the tokenizer and model from the names defined in the parameters cell,
# so changing a checkpoint there actually takes effect here (the literals used
# to be duplicated in this cell).

# Left padding keeps every prompt flush against the position where generation
# starts when prompts of different lengths are batched together.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="left")
# Reuse the EOS token as the padding token so that padding=True works in the
# batched tokenizer call.
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

model = GPTJForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)
model = model.to("cuda")

In [None]:
# Batched greedy generation over the whole dev set.
#
# For each batch of `step` prompts: tokenize, generate up to 128 new tokens,
# cut the answer out of the decoded text, and append one
# "<prediction>\t<db_id>" line per example to result_path.

step = 10
num_examples = len(dev_spider)  # derive the bound from the data, not a literal

# Start from an empty results file so re-running this cell does not append a
# second copy of the predictions to the ones already written.
open(result_path, 'w', encoding='utf-8').close()

for batch_start in range(0, num_examples, step):
    # Release memory held by the previous batch before allocating new tensors.
    gc.collect()
    torch.cuda.empty_cache()

    print(batch_start)

    # Rows of this batch; reused below so every prediction is paired with the
    # db_id of the row it was generated from.
    batch = dev_spider.iloc[batch_start:batch_start + step]

    # Tokenize prompts and move the tensors to the GPU next to the model.
    inputs = tokenizer(
        list(batch['prompt']),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Greedy decoding. The attention mask is passed explicitly: the batch is
    # left-padded and the pad token is the EOS token, so generate() cannot
    # tell padding from content on its own. `temperature` and `early_stopping`
    # were dropped because they have no effect when do_sample=False and a
    # single beam is used.
    output_tokens = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    output_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

    # The decoded text still contains the prompt: keep what follows the last
    # 'Answer:' marker, up to the first '###' separator.
    outputs = []
    for text in output_text:
        answer = text.split('Answer:')[-1].split('###')[0].strip()
        # The results file holds one tab-separated record per line, so line
        # breaks or tabs inside a prediction would corrupt every later record.
        answer = answer.replace('\r', ' ').replace('\n', ' ').replace('\t', ' ').strip()
        outputs.append(answer)

    with open(result_path, 'a', encoding='utf-8') as f:
        # Pair predictions with the db_id of *this batch's* rows. The previous
        # code indexed dev_spider with the batch-local position, which gave
        # every batch the db_ids of the first `step` rows.
        for output, db_id in zip(outputs, batch['db_id']):
            f.write(output + '\t' + db_id + '\n')

In [None]:
# Score the predictions file written by the generation loop.
# NOTE(review): evaluate_result is imported from t5.inference; the metric it
# computes and the file layout it expects are not visible here — confirm that
# it reads one "<prediction>\t<db_id>" record per line.
evaluate_result(result_path)