In [None]:
from datasets import load_from_disk

# Load the chat-formatted SQL dataset previously saved to disk with `save_to_disk`.
dataset = load_from_disk(dataset_path="./data/sql_chats")

# Bare expression so the notebook renders the dataset summary as the cell output.
dataset

In [None]:
import torch

# Pick the compute backend in priority order: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(device)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hub repository id = organisation prefix + checkpoint name.
provider = "HuggingFaceTB/"
model_version = "SmolLM2-360M-Instruct"
model_name = f"{provider}{model_version}"

# Where the fine-tuned checkpoint is written, and the tags describing this run.
finetune_name = f"./results/{model_version}"
finetune_tags = ["smol-sql", "v1"]

# Base causal LM, downloaded into a per-model cache directory.
# use_cache=False is forwarded to the model configuration.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,
    cache_dir=f"./models/{model_version}",
)
model = model.to(device)

# Matching tokenizer, cached separately from the model weights.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=f"./tokenizers/{model_version}",
)

In [None]:
# from peft import LoraConfig

# rank_dimension = 6
# lora_alpha = 8
# lora_dropout = 0.05

# peft_config = LoraConfig(
#     r=rank_dimension,
#     lora_alpha=lora_alpha,
#     lora_dropout=lora_dropout,
#     bias="none",
#     target_modules="all-linear",
#     task_type="CAUSAL_LM",
# )

In [None]:
from trl import SFTConfig, SFTTrainer

# fp16 mixed precision and the fused AdamW kernel are only enabled on CUDA, so
# this cell still runs when `device` resolved to "mps" or "cpu".
use_cuda = device == "cuda"

args = SFTConfig(
    output_dir=finetune_name,
    num_train_epochs=2,
    # One sequence per step, accumulated over 32 steps before each optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
    gradient_checkpointing=True,
    optim="adamw_torch_fused" if use_cuda else "adamw_torch",
    learning_rate=5e-5,
    warmup_ratio=0.1,
    # The plain "constant" schedule takes no warmup steps, so warmup_ratio above
    # would be silently ignored; "constant_with_warmup" honours it.
    lr_scheduler_type="constant_with_warmup",
    fp16=use_cuda,
    dataloader_num_workers=6,
    # Checkpoint every 500 steps, keeping only the two most recent checkpoints.
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    packing=True,
    push_to_hub=False,
    report_to="none",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    },
)

# Supervised fine-tuning on the "train" split; the tokenizer is handed over as the
# processing class so the trainer can tokenize the dataset itself.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
)

In [None]:
import gc

# Collect unreachable Python objects first so any tensors they still hold are
# freed, then ask PyTorch to release its unused cached CUDA memory. In the
# opposite order, memory freed by the collector would stay in the cache.
gc.collect()
torch.cuda.empty_cache()

# Run the fine-tuning loop configured on the trainer.
trainer.train()

In [None]:
# Persist the fine-tuned model to the run's output directory (the same location
# the trainer uses when no directory is given).
trainer.save_model(args.output_dir)

In [None]:
# from peft import AutoPeftModelForCausalLM
# from transformers import pipeline

# model = AutoPeftModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path=args.output_dir,
#     torch_dtype=torch.float16,
#     low_cpu_mem_usage=True,
# )

# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(
#     args.output_dir, safe_serialization=True, max_shard_size="2GB"
# )

# tokenizer = AutoTokenizer.from_pretrained(finetune_name)
# model = AutoPeftModelForCausalLM.from_pretrained(
#     finetune_name, device_map="auto", torch_dtype=torch.float16
# )
# pipe = pipeline(
#     "text-generation", model=merged_model, tokenizer=tokenizer, device=device
# )

# del model
# del trainer
# torch.cuda.empty_cache()

In [None]:
from transformers import pipeline

# Reload the fine-tuned checkpoint for inference. use_cache=False was only wanted
# for training with gradient checkpointing; here the key/value cache is enabled
# explicitly so generation does not recompute the whole prefix for every new
# token. The explicit value also overrides the setting saved with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=finetune_name,
    use_cache=True,
).to(device)


# Tokenizer is loaded from the same fine-tuned output directory as the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=finetune_name)

# Text-generation pipeline wrapping the reloaded model and tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [None]:
def test_inference(prompt):
    prompt = pipe.tokenizer.apply_chat_template(
        prompt,
        tokenize=False,
        add_generation_prompt=True,
    )
    outputs = pipe(
        prompt,
    )
    return outputs[0]["generated_text"][len(prompt) :].strip()


# Fixed instruction placed in the system turn.
system_prompt = "You are a SQL assistant"

# Natural-language question the model should answer with a SQL query.
sql_prompt = "What is the total volume of timber sold by each salesperson, sorted by salesperson?"

# Schema and sample rows, one SQL statement per literal; adjacent literals are
# concatenated by Python into a single string.
sql_context = (
    "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); "
    "INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); "
    "CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); "
    "INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');"
)

# User turn: the question followed by the schema it refers to.
user = f"{sql_prompt}\n\nDatabase schema:\n{sql_context}"

# Chat messages in the order they are fed to the chat template.
prompt = [
    {"content": system_prompt, "role": "system"},
    {"content": user, "role": "user"},
]

print(test_inference(prompt))