In [None]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

# Single source of truth for the checkpoint id, so the model and the tokenizer
# below are always loaded from the same repository.
base_model = "codellama/CodeLlama-7b-hf"

# Load the base weights quantized to 8-bit. The quantization settings are
# passed through an explicit BitsAndBytesConfig (imported above) rather than
# the bare `load_in_8bit=True` keyword, which transformers has deprecated in
# favour of `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    device_map="auto",  # automatic device placement of the model's modules
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [None]:
from peft import PeftModel
# Attach the PEFT adapter stored in the checkpoint directory to the base model.
# Note: this rebinds `model`, so re-running the cell without reloading the base
# model would wrap an already-wrapped model.
adapter_checkpoint = "sql-code-llama/checkpoint-160"
model = PeftModel.from_pretrained(model, adapter_checkpoint)

In [None]:
def generate_prompt(data_point):
    """Render the text-to-SQL prompt for a single example.

    Args:
        data_point: Mapping with a "question" entry (the natural-language
            question) and a "context" entry (the table context shown to the
            model).

    Returns:
        The prompt string: a fixed instruction, then "### Input:" and
        "### Context:" sections filled from ``data_point``, ending with an
        empty "### Response:" section for the model to complete.

    Raises:
        KeyError: If "question" or "context" is missing from ``data_point``.
    """
    question = data_point["question"]
    context = data_point["context"]
    instruction = (
        "You are a powerful text-to-SQL model. "
        "Your job is to answer questions about a database. "
        "You are given a question and context regarding one or more tables. "
        "You must output the SQL query that answers the question."
    )
    return (
        f"{instruction}\n\n"
        f"### Input:\n{question}\n\n"
        f"### Context:\n{context}\n\n"
        "### Response:\n"
    )

In [None]:
import json
# Load the validation examples. JSON text is UTF-8, so decode it explicitly
# instead of relying on the platform's default encoding (which is not UTF-8
# everywhere and would corrupt or reject non-ASCII characters).
with open('spider_create_context_val_db_id.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Build one prompt per example, in file order, so that the i-th generated
# output later corresponds to the i-th entry of `data`.
prompts = [generate_prompt(data_point) for data_point in data]

In [None]:
# The batched tokenizer call further down uses padding=True, which needs a pad
# token; reuse the end-of-sequence token for that purpose.
# NOTE(review): presumably this tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
batch_size = 12  # You can adjust the batch size based on your GPU capacity
outputs = []

# A decoder-only model continues each row from its last position, so shorter
# prompts must be padded on the LEFT; with right padding the model would be
# asked to continue from pad tokens. Pin it explicitly instead of relying on
# the tokenizer's default.
tokenizer.padding_side = "left"

model.eval()
# The file is opened in append mode, so re-running this cell adds a second copy
# of the results; delete or rename the file first for a clean run.
with torch.no_grad(), open('codellama7b-valoutputs2.txt', 'a', encoding='utf-8') as f:
    for i in range(0, len(prompts), batch_size):
        print("i is: ", i)
        batch_inputs = prompts[i:i+batch_size]

        # Tokenize the batch and move the tensors to the model's device; the
        # tokenizer returns CPU tensors.
        batch_model_inputs = tokenizer(
            batch_inputs, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        # After padding every row has the same width, so a single prompt
        # length is enough to separate the prompt from the generated tokens.
        prompt_length = batch_model_inputs["input_ids"].shape[1]

        batch_outputs = model.generate(
            **batch_model_inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Keep only the newly generated tokens and decode them to text.
        batch_decoded_outputs = tokenizer.batch_decode(
            batch_outputs[:, prompt_length:], skip_special_tokens=True
        )
        outputs.extend(batch_decoded_outputs)
        # Show just this batch; printing the whole accumulated list every
        # iteration makes the cell output grow quadratically.
        print(batch_decoded_outputs)
        # One line per example: newlines inside a generation are flattened to
        # spaces so line k of the file corresponds to prompt k.
        for b in batch_decoded_outputs:
            f.write(b.replace("\n", " ") + "\n")
        # Persist finished batches so an interrupted run keeps its results.
        f.flush()

# Now, `outputs` is a list containing the generated text for every prompt, in order