In [None]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline, BitsAndBytesConfig
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain

# Relative path of the cache directory.
CACHE_DIR = "../cache"

# Checkpoint identifier handed to `from_pretrained`;
# "chavinlo/alpaca-native" is the alternative noted alongside it.
model_name_or_path = "meta-llama/Llama-2-7b-hf"

# Quantisation settings passed as `quantization_config` when the model is
# loaded. `load_in_8bit=True` is the 8-bit alternative, left disabled here.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Load the tokenizer and the quantised model for `model_name_or_path`.
# Both calls now receive `cache_dir=CACHE_DIR`: the constant was defined in
# the config cell but never used, so downloads silently went to the library's
# default cache location instead of the directory configured for this notebook.
tokenizer = LlamaTokenizer.from_pretrained(
    model_name_or_path,
    cache_dir=CACHE_DIR,
)

base_model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    cache_dir=CACHE_DIR,
    # Quantisation settings built in the config cell.
    quantization_config=bnb_config,
    # Let the loader place the weights on the available devices.
    device_map="auto",
)

In [None]:
# Text-generation pipeline wrapping the loaded model and tokenizer.
pipe = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    # Cap on the total sequence length (prompt tokens plus generated tokens).
    max_length=500,
    # `temperature` and `top_p` only have an effect when sampling is enabled,
    # so switch it on explicitly rather than depending on whatever generation
    # defaults the chosen checkpoint happens to ship with.
    do_sample=True,
    temperature=0.3,
    top_p=0.95,
    repetition_penalty=1.2,
)


# Prompt skeleton with three placeholders: {Table}, {Columns} and {question}.
# Written as explicit per-line literals so the leading newline and the
# trailing space after "question :" are visible rather than implicit.
template = (
    "\n"
    "Write a SQL Query given the table name {Table} and columns as a list {Columns} for the given question : \n"
    "{question}.\n"
)

# Bind the template text to the names of its three placeholders.
prompt = PromptTemplate(
    input_variables=["Table", "question", "Columns"],
    template=template,
)

# Wrap the transformers pipeline in LangChain's LLM adapter, then pair it
# with the prompt to form the chain.
local_llm = HuggingFacePipeline(pipeline=pipe)
llm_chain = LLMChain(llm=local_llm, prompt=prompt)


def get_llm_response(tble: str, question: str, cols: list, verbose: bool = True):
    """Run the prompt chain for one table/question pair and return its output.

    Args:
        tble: Table name substituted for ``{Table}`` in the prompt.
        question: Natural-language request substituted for ``{question}``.
        cols: Column names substituted for ``{Columns}``.
        verbose: When true (the default, which keeps the earlier behaviour),
            print the response before returning it. Pass ``False`` to avoid
            duplicated output when the call is the last expression of a cell.

    Returns:
        The value produced by ``llm_chain.run`` for the filled-in prompt.
    """
    # Reuse the module-level chain instead of building a new LLMChain on each
    # call; the local rebuild shadowed that chain and left it unused.
    response = llm_chain.run({"Table": tble, "question": question, "Columns": cols})
    if verbose:
        print(response)
    return response

In [None]:
# Example 1: count employees filtered by band and manager ID.
question = "Query the count of employees in band L6 with 239045 as the manager ID"
tble = "employee"
cols = ["id", "name", "date_of_birth", "band", "manager_id"]
get_llm_response(tble=tble, question=question, cols=cols)

## Answer: SELECT COUNT(*) FROM employee WHERE band='L6' AND manager_id=239045;

In [None]:
# Example 2: count employees filtered by band and age.
question = "Query the count of employees in band L6 and over 40 years of age"
tble = "employee"
cols = ["id", "name", "date_of_birth", "band", "manager_id"]
get_llm_response(tble=tble, question=question, cols=cols)

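## Answer: SELECT COUNT(*) FROM employee WHERE band='L6' AND date_of_birth>=(CURDATE() - INTERVAL 40 YEAR);
## NOTE(review): the date comparison above looks inverted -- `date_of_birth >= CURDATE() - INTERVAL 40 YEAR` matches people aged 40 or younger; "over 40 years of age" needs `date_of_birth < (CURDATE() - INTERVAL 40 YEAR)`.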