In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import bitsandbytes

# Hugging Face model id, plus one local download cache shared by the tokenizer
# and the model so both sets of files land in the same directory.
model_name = "defog/llama-3-sqlcoder-8b"
MODEL_CACHE_DIR = "/model"

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=MODEL_CACHE_DIR)

# trust_remote_code is intentionally not enabled: it executes arbitrary Python
# shipped in the Hub repo.
# NOTE(review): assumes the checkpoint's config names an architecture built into
# transformers (LlamaForCausalLM) so no remote code is needed — confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # load weights in half precision to save GPU memory
    device_map="auto",  # let accelerate place layers on the available device(s)
    cache_dir=MODEL_CACHE_DIR,
)
    

# Llama-3 chat-format prompt for sqlcoder. `{question}` is substituted with
# str.format, so any literal braces added to this template must be doubled.
# The user turn carries the question and the schema. The assistant turn is then
# opened and pre-seeded with "```sql", so the model's continuation is the query.
# The table uses a single composite primary key; a table may declare only one.
# NOTE(review): the meaning of each *_id column in the DDL comments is inferred
# from the column names — confirm against the real `home` table.
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE home (
  day_id INTEGER,     -- Day of the month of the reading (1-31)
  month_id INTEGER,   -- Month of the reading (1-12)
  year_id INTEGER,    -- Year of the reading (e.g. 2024)
  hour_id INTEGER,    -- Hour of the reading (0-23)
  minute_id INTEGER,  -- Minute of the reading (0-59)
  second_id INTEGER,  -- Second of the reading (0-59)
  air_conditioner_temp INTEGER,  -- Air conditioner temperature
  in_temp INTEGER,   -- Indoor temperature
  out_temp INTEGER,  -- Outdoor temperature
  now_temp INTEGER,  -- Current temperature
  PRIMARY KEY (year_id, month_id, day_id, hour_id, minute_id, second_id)
);<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



In [None]:
import sqlparse

def generate_query(question, max_new_tokens=400):
    """Ask the sqlcoder model to translate a natural-language question into SQL.

    Uses the module-level ``prompt`` template, ``tokenizer`` and ``model``.

    Args:
        question: Natural-language question, substituted into ``prompt`` via
            ``str.format(question=...)``.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated SQL up to (not including) the first ``;``, with any
        closing code fence and surrounding whitespace removed.
    """
    updated_prompt = prompt.format(question=question)
    # The template already begins with <|begin_of_text|>, so stop the tokenizer
    # from prepending a second BOS token. Send inputs to the model's own device
    # instead of a hard-coded "cuda" so device_map placement is respected.
    inputs = tokenizer(
        updated_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Greedy decoding (no sampling, single beam). Sampling knobs such as
    # temperature/top_p are deliberately omitted: greedy decoding ignores them
    # and passing them only produces generation-config warnings.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
    )
    # Decode only the continuation. Text that is part of the prompt (including
    # the user's question) can then never be mistaken for model output.
    completion = tokenizer.decode(
        generated_ids[0, prompt_len:], skip_special_tokens=True
    )

    # Release cached GPU memory between calls so repeated generations in one
    # kernel (e.g. on Colab) are less likely to run out of memory.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # Keep only the first statement, and drop a closing ``` fence in case the
    # model ended the block without emitting a ';'.
    sql = completion.split(";")[0].split("```")[0]
    return sql.strip()

In [None]:
# Example question (Korean): "Tell me my home's temperature on June 18, 2024".
question = "2024년 6월18일 우리집 온도 알려줘"
generated_sql = generate_query(question)

# Pretty-print the SQL with one clause per line for readability.
pretty_sql = sqlparse.format(generated_sql, reindent=True)
print(pretty_sql)
