In [None]:
!pip install -q -U datasets==2.16.1

In [None]:
!pip install -q -U bitsandbytes==0.42.0
!pip install -q -U peft==0.8.2
!pip install -q -U trl==0.7.10
!pip install -q -U accelerate==0.27.1
!pip install -q -U transformers==4.38.0

In [None]:
# !pip install huggingface_hub --q

In [None]:
# Inspect the available GPU; the model is later loaded with device_map={"": 0} (GPU 0).
!nvidia-smi

In [None]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub: paste an API token when prompted.
notebook_login()

In [None]:
# Hub repo id of the Gemma-2B SQL checkpoint that is loaded, fine-tuned and merged below.
checkpoint: str = "suriya7/Gemma2B-Finetuned-Sql-Generator"

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# QLoRA-style quantization: 4-bit NF4 weights with nested (double) quantization,
# and bfloat16 as the compute dtype for the quantized layers.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
model_id = checkpoint

# Load the checkpoint quantized per `bnb_config`, with every module placed on GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
)

# add_eos_token=True asks the tokenizer to append EOS to encoded text. That is wanted
# for training examples; inference prompts should not end in EOS.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token=True,
)

## Inference with the base checkpoint (before fine-tuning)

In [None]:
import torch

prompt_template = """
<start_of_turn>user
You are an intelligent AI specialized in generating SQL queries.
Your task is to assist users in formulating SQL queries to retrieve specific information from a database.
Please provide the SQL query corresponding to the given prompt and context:

Prompt:
find the price of laptop

Context:
CREATE TABLE products (
    product_id INT,
    product_name VARCHAR(100),
    category VARCHAR(50),
    price DECIMAL(10, 2),
    stock_quantity INT
);

INSERT INTO products (product_id, product_name, category, price, stock_quantity) 
VALUES 
    (1, 'Smartphone', 'Electronics', 599.99, 100),
    (2, 'Laptop', 'Electronics', 999.99, 50),
    (3, 'Headphones', 'Electronics', 99.99, 200),
    (4, 'T-shirt', 'Apparel', 19.99, 300),
    (5, 'Jeans', 'Apparel', 49.99, 150);<end_of_turn>
<start_of_turn>model
"""

prompt = prompt_template

# The tokenizer was loaded with add_eos_token=True. A trailing EOS would tell the
# model the sequence is already finished, so drop it from the prompt.
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids
if input_ids[0, -1].item() == tokenizer.eos_token_id:
    input_ids = input_ids[:, :-1]

# A 4-bit bitsandbytes model must not be moved with `.to()`: it already lives on
# the device chosen by device_map. Move the inputs to the model instead.
inputs = input_ids.to(model.device)

# Stop at the end of the model's turn as well as at EOS.
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
generated_ids = model.generate(
    inputs,
    max_new_tokens=1000,  # increase if answers get cut off
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=[tokenizer.eos_token_id, end_of_turn_id],
)

# Decode only the newly generated tokens, so the prompt text (which may itself
# contain the word "model") can never be mistaken for part of the answer.
completion = tokenizer.decode(generated_ids[0, inputs.shape[-1]:], skip_special_tokens=False)
model_answer = completion.split("<end_of_turn>")[0].replace(tokenizer.eos_token, "").strip()
print(model_answer)

In [None]:
from datasets import load_dataset, DatasetDict
from datasets import concatenate_datasets


# Load the text-to-SQL dataset used for fine-tuning (swap the repo id to use another one).
dataset = load_dataset(path='gretelai/synthetic_text_to_sql')

In [None]:
# Inspect the loaded dataset object (its splits and columns).
dataset

### Prompt Format for Google Gemma Model

In [None]:
# Illustration of Gemma's turn format: a user turn closed by <end_of_turn>,
# followed by an opened model turn that generation is expected to complete.
# NOTE: this value is only an example; `prompt_template` is reassigned later.
prompt_template = """
<start_of_turn>user
Answer the following question in a concise and informative manner:
 
Explain why the sky is blue<end_of_turn>
<start_of_turn>model
"""

In [None]:
def generate_prompt(data_point):
    """Build the Gemma chat-format training text for one text-to-SQL example.

    The template is dedented before formatting, so the training text is laid out
    flush-left exactly like the inference prompts used elsewhere in the notebook.
    Without this, every line of the training text (including ``<start_of_turn>``)
    would start with the function body's indentation.

    :param data_point: dict: Example with "sql_prompt", "sql_context" and "sql" keys.
    :return: dict: A new dict with all fields of ``data_point`` plus a "prompt" field.
    """
    import textwrap

    # Dedent the template itself (not the formatted text), so any indentation inside
    # the SQL context or query is preserved verbatim.
    prompt_template = textwrap.dedent(
        """
        <start_of_turn>user
        You are an intelligent AI specialized in generating SQL queries.
        Your task is to assist users in formulating SQL queries to retrieve specific information from a database.
        Please provide the SQL query corresponding to the given prompt and context:

        Prompt:
        {sql_prompt}

        Context:
        {sql_context}<end_of_turn>
        <start_of_turn>model
        {sql}<end_of_turn>
        """
    )

    prompt_text = prompt_template.format(
        sql_prompt=data_point["sql_prompt"],
        sql_context=data_point["sql_context"],
        sql=data_point["sql"],
    )
    # Return a new mapping instead of mutating the caller's row.
    return {**data_point, "prompt": prompt_text}

# Keep only the "train" split and attach the formatted "prompt" column to each row.
train_split = dataset['train']
dataset = train_split.map(generate_prompt)

# Show the resulting Dataset, now including the "prompt" column
dataset

In [None]:
# Eyeball one formatted training example (row 99).
print(dataset[99]['prompt'])

In [None]:
def _tokenize_prompts(batch):
    """Tokenize the "prompt" texts of a batch of rows with the notebook tokenizer."""
    return tokenizer(batch["prompt"])


# Shuffle with a fixed seed, then add token columns for the prompts in batches.
dataset = dataset.shuffle(seed=2000)
dataset = dataset.map(_tokenize_prompts, batched=True)

In [None]:
# Hold out 20% of the rows. A fixed seed makes the split reproducible across
# kernel restarts; it reuses the seed of the shuffle above.
dataset = dataset.train_test_split(test_size=0.2, seed=2000)
train_data = dataset["train"]
test_data = dataset["test"]

In [None]:
# Inspect the held-out split (the trainer below is only given the training split).
test_data

In [None]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
# Turn on activation checkpointing, then let PEFT prepare the quantized model
# for k-bit training (the prepared model is returned and rebound to `model`).
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model=model)

In [None]:
import bitsandbytes as bnb
def find_all_linear_names(model, cls=None):
  """Return the short names of all layers of type ``cls`` in ``model``.

  The result is meant for ``LoraConfig(target_modules=...)``. For each matching
  module only the last component of its dotted name is kept (for example
  ``layers.0.self_attn.q_proj`` -> ``q_proj``), duplicates are collapsed, and
  ``lm_head`` is excluded (relevant when matching plain ``torch.nn.Linear``).

  :param model: torch.nn.Module to scan via ``named_modules()``.
  :param cls: Layer class (or tuple of classes) to match; defaults to
      ``bnb.nn.Linear4bit`` for 4-bit quantized models.
  :return: list[str]: Sorted, de-duplicated module names.
  """
  if cls is None:
    cls = bnb.nn.Linear4bit
  lora_module_names = {
      name.split('.')[-1]
      for name, module in model.named_modules()
      if isinstance(module, cls)
  }
  # The output head is never a LoRA target.
  lora_module_names.discard('lm_head')
  # Sorted so the printed list and the resulting config are deterministic.
  return sorted(lora_module_names)

In [None]:
# Names of the 4-bit linear layers that will receive LoRA adapters.
modules = find_all_linear_names(model=model)
print(modules)

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA hyper-parameters, applied to every layer name collected in `modules`.
# NOTE(review): lora_alpha/r = 32/100 gives a small adapter scale; confirm intended.
lora_settings = dict(
    task_type="CAUSAL_LM",
    target_modules=modules,
    r=100,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
lora_config = LoraConfig(**lora_settings)

# Wrap the prepared model with trainable LoRA adapters.
model = get_peft_model(model, lora_config)

In [None]:
trainable, total = model.get_nb_trainable_parameters()
# Share of parameters updated by LoRA (same expression order as before).
percentage = trainable / total * 100
print(f"Trainable: {trainable} | total: {total} | Percentage: {percentage:.4f}%")

In [None]:
from trl import SFTTrainer
import transformers
import torch

# DataCollatorForLanguageModeling sets the label of every pad-token position to -100.
# If pad == eos, the EOS that ends each training example (add_eos_token=True) is
# masked out of the loss too, and the model never learns to stop. So only fall
# back to EOS when the tokenizer has no pad token at all.
# NOTE(review): if the loaded tokenizer already uses EOS as pad, switch to a
# dedicated pad token from its vocabulary.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model,
    train_dataset=train_data,  # the "train" part of the split above
    dataset_text_field="prompt",
    args=transformers.TrainingArguments(
        learning_rate=2e-4,
        output_dir="Gemma",
        per_device_train_batch_size=1,
        num_train_epochs=1,
        logging_strategy="steps",
        save_strategy="steps",
        logging_steps=10,
        save_steps=100000,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.98,
        adam_epsilon=1e-8,
        # warmup_steps expects an integer step count; a 3% warmup is a ratio.
        warmup_ratio=0.03,
        # NOTE(review): bnb_4bit_compute_dtype is bfloat16 while AMP runs in fp16;
        # on bf16-capable GPUs consider bf16=True instead.
        fp16=True,
        seed=42,
        save_total_limit=1,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

In [None]:
# The KV cache is not needed while training; switch it off before fitting.
model.config.use_cache = False

# Run the fine-tuning loop configured above.
trainer.train()

In [None]:
# Directory that receives the trained LoRA adapter (and is reloaded for merging).
new_model: str = "gemma-model"

In [None]:
# Save the trained adapter to the directory named by `new_model`, the same path the
# merge step reloads from, instead of repeating the literal.
trainer.model.save_pretrained(new_model)

### Merge the LoRA adapters into the base model and save the merged model

In [None]:
# Reload the base checkpoint without quantization (fp16, on GPU 0) and merge the
# trained LoRA adapter from `new_model` into its weights.
base_model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
merged_model = PeftModel.from_pretrained(base_model, new_model)
merged_model = merged_model.merge_and_unload()

# Configure padding *before* saving, so the saved tokenizer matches the one used
# here. EOS is only a fallback when no pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Save the merged model and its tokenizer side by side.
merged_model.save_pretrained("merged_model", safe_serialization=True)
tokenizer.save_pretrained("merged_model")

### Test the merged model on a new sample

In [None]:
# The prompt must end exactly at "<start_of_turn>model\n" (no trailing spaces),
# matching the layout of the training examples.
prompt_template = """
<start_of_turn>user
You are an intelligent AI specialized in generating SQL queries.
Your task is to assist users in formulating SQL queries to retrieve specific information from a database.
Please provide the SQL query corresponding to the given prompt and context:

Prompt:
What is the total production of oil from the onshore fields in the Beaufort Sea?

Context:
CREATE TABLE beaufort_sea_oil_production (field VARCHAR(255), year INT, production FLOAT);<end_of_turn>
<start_of_turn>model
"""

prompt = prompt_template

# The tokenizer was loaded with add_eos_token=True. A trailing EOS would tell the
# model the sequence is already finished, so drop it from the prompt.
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids
if input_ids[0, -1].item() == tokenizer.eos_token_id:
    input_ids = input_ids[:, :-1]

# merged_model was loaded with device_map={"": 0}; send the inputs to wherever it lives.
inputs = input_ids.to(merged_model.device)

# Stop at the end of the model's turn as well as at EOS.
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
generated_ids = merged_model.generate(
    inputs,
    max_new_tokens=1000,  # increase if answers get cut off
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=[tokenizer.eos_token_id, end_of_turn_id],
)

# Decode only the newly generated tokens, so the prompt text (which may itself
# contain the word "model") can never be mistaken for part of the answer.
completion = tokenizer.decode(generated_ids[0, inputs.shape[-1]:], skip_special_tokens=False)
model_answer = completion.split("<end_of_turn>")[0].replace(tokenizer.eos_token, "").strip()
print(model_answer)