
# SQLCoder-7b-2
Run the cells below to run inference with SQLCoder-7b-2, our text-to-SQL LLM.

⭐️ [GitHub Repo](https://github.com/defog-ai/sqlcoder)

🤗 [Hugging Face Page](https://huggingface.co/defog/sqlcoder-7b-2)

## Setup

In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse



In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
torch.cuda.is_available()

True

In [3]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [4]:
print(available_memory)

15828320256


## Download the Model
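On Colab, a GPU runtime like the one used here (about 15.8GB of memory, as printed above) can load this model in float16; the three checkpoint shards total about 13.5GB. On other machines, the cell below loads the model in float16 when the GPU has more than 15GB of memory and otherwise falls back to 8-bit, which needs at least 8GB of VRAM. Loading in 4-bit needs at least 5GB of VRAM, though the cell below does not implement that option.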

This step can take around 5 minutes the first time, so please be patient :)
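The memory rules for picking a precision can be written as a small decision function. This is a sketch: `choose_precision` is a hypothetical helper, the 15GB float16 cut-off mirrors the `available_memory > 15e9` check in the loading cell, and the 8GB and 5GB cut-offs come from the guidance in this section rather than from measurements.

```python
GB = 1e9  # decimal gigabytes, to match the 15e9 check in the loading cell


def choose_precision(total_memory_bytes: float) -> str:
    """Suggest a load mode for SQLCoder-7b-2 from total GPU memory in bytes (hypothetical helper)."""
    if total_memory_bytes > 15 * GB:
        return "float16"  # same threshold as `available_memory > 15e9`
    if total_memory_bytes >= 8 * GB:
        return "8bit"
    if total_memory_bytes >= 5 * GB:
        return "4bit"
    raise ValueError("Under ~5GB of GPU memory: too little to load the model on this GPU")


# 15828320256 bytes is the total memory torch reported for the GPU above
print(choose_precision(15828320256))
```

In `transformers`, `"float16"` corresponds to `torch_dtype=torch.float16`, while `"8bit"` and `"4bit"` correspond to `quantization_config=BitsAndBytesConfig(load_in_8bit=True)` and `BitsAndBytesConfig(load_in_4bit=True)` respectively; both quantized modes rely on the `bitsandbytes` package installed in the Setup cell.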

In [5]:
from transformers import BitsAndBytesConfig

model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if available_memory > 15e9:
    # with more than 15GB of GPU memory, load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # otherwise, load it in 8-bit with bitsandbytes (uses less memory but is a bit slower)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        use_cache=True,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



## Set the Question & Prompt and Tokenize
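Feel free to replace the schema in the prompt below with your own. The `{question}` placeholder (it appears twice) is filled in later by `prompt.format(question=question)` inside `generate_query`.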
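Because the `prompt` template below is filled with `str.format`, every `{...}` in it is treated as a placeholder. If the schema you paste in contains literal braces (for example a column with `DEFAULT '{}'`), `prompt.format(question=...)` will raise. A minimal workaround, sketched here with a hypothetical schema and a hypothetical `fill_prompt` helper, is to substitute only the `{question}` placeholder with `str.replace`; doubling the literal braces in the schema (`'{{}}'`) would also work.

```python
# Hypothetical template: a shortened prompt plus a made-up table whose column default contains braces
TEMPLATE = """### Task
Generate a MySQL-compatible query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
CREATE TABLE settings (
  id INT NOT NULL,
  label VARCHAR(20) DEFAULT '{}'
);

### Answer
[SQL]
"""


def fill_prompt(template: str, question: str) -> str:
    """Substitute only the literal "{question}" placeholder; other braces are left untouched."""
    return template.replace("{question}", question)


example_question = "How many settings are there?"

# str.format parses the "{}" in the schema as a positional field and fails
try:
    TEMPLATE.format(question=example_question)
    format_error = None
except IndexError as exc:
    format_error = exc
print("str.format raised:", repr(format_error))

print(fill_prompt(TEMPLATE, example_question))
```

To adopt this in the notebook, you would call `fill_prompt(prompt, question)` in place of `prompt.format(question=question)` inside `generate_query`.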

In [34]:
prompt = """### Task
Generate a MySQL-compatible query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that the ID column of the products table is named 'id'
- Remember that cost is supply_price multiplied by quantity
- Remember that the query must be compatible with MySQL
- Remember that if I ask 'In which warehouses' you have to show warehouse names

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT, -- Unique ID for each product
  name VARCHAR(255) NOT NULL,
  code VARCHAR(64) NOT NULL,
  sku VARCHAR(45) DEFAULT NULL,
  unit_cost DECIMAL(24, 6) DEFAULT NULL,
  currency_id INT UNSIGNED NOT NULL DEFAULT 100,
  unit_id INT UNSIGNED DEFAULT NULL,
  invoice_type_id INT UNSIGNED DEFAULT NULL,
  popular_name VARCHAR(255) DEFAULT NULL,
  reorder_point DECIMAL(24, 6) DEFAULT NULL,
  barcode VARCHAR(13) DEFAULT NULL,
  description TEXT DEFAULT NULL,
  inventariable TINYINT UNSIGNED NOT NULL DEFAULT 1,
  active TINYINT UNSIGNED NOT NULL DEFAULT 1,
  ecommerce_excluded TINYINT UNSIGNED NOT NULL DEFAULT 0,
  expried TINYINT UNSIGNED NOT NULL DEFAULT 0,
  register_serial_numbers TINYINT UNSIGNED NOT NULL DEFAULT 0,
  favorite TINYINT UNSIGNED NOT NULL DEFAULT 0,
  deep DECIMAL(12, 6) DEFAULT NULL,
  deep_unit_id INT DEFAULT NULL,
  width DECIMAL(12, 6) DEFAULT NULL,
  width_unit_id INT DEFAULT NULL,
  height DECIMAL(12, 6) DEFAULT NULL,
  height_unit_id INT DEFAULT NULL,
  weight DECIMAL(12, 6) DEFAULT NULL,
  weight_unit_id INT DEFAULT NULL,
  packaging_type INT DEFAULT NULL,
  tariff_item VARCHAR(10) DEFAULT NULL,
  customs_unit_id INT UNSIGNED ZEROFILL DEFAULT NULL,
  group_id INT UNSIGNED DEFAULT NULL,
  line_id INT UNSIGNED DEFAULT NULL,
  accounting_account_id INT DEFAULT NULL,
  accounting_account_mining_id INT DEFAULT NULL,
  accounting_account_is_main TINYINT(1) DEFAULT 0,
  category_id INT UNSIGNED DEFAULT NULL,
  product_type_id INT DEFAULT NULL,
  type ENUM('producto','servicio','activo fijo') DEFAULT 'producto',
  brand_id INT UNSIGNED DEFAULT NULL,
  model VARCHAR(45) DEFAULT NULL,
  is_kit TINYINT NOT NULL DEFAULT 0,
  is_service TINYINT NOT NULL DEFAULT 0,
  is_fixed_asset TINYINT NOT NULL DEFAULT 0,
  score FLOAT DEFAULT 0,
  data_sheet_description TEXT DEFAULT NULL,
  expense_type_id INT DEFAULT NULL,
  catalog_included TINYINT NOT NULL DEFAULT 1,
  has_decimals TINYINT NOT NULL DEFAULT 0,
  created_at DATETIME DEFAULT NULL,
  created_user_id INT DEFAULT NULL,
  updated_at DATETIME DEFAULT NULL,
  update_user_id INT DEFAULT NULL,
  deleted_at DATETIME DEFAULT NULL,
  deleted_user_id INT DEFAULT NULL,
  delivery_days INT DEFAULT 1,
  is_material TINYINT NOT NULL DEFAULT 0,
  security_stock FLOAT DEFAULT NULL,
  lot_quantity FLOAT DEFAULT NULL,
  io_type_id INT DEFAULT NULL,
  default_provider INT DEFAULT NULL,
  is_internally_produced TINYINT DEFAULT 0,
  apply_importation TINYINT DEFAULT 0,
  make_to_order TINYINT DEFAULT NULL,
  property_account VARCHAR(64) DEFAULT NULL,
  weekly_production_capacity DECIMAL(24, 6) DEFAULT NULL,
  additional_accounting_account_id INT DEFAULT NULL,
  PRIMARY KEY (id),
  INDEX prod_indx(id)
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE warehouses (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  code VARCHAR(15) DEFAULT NULL,
  name VARCHAR(255) DEFAULT NULL,
  is_fixed_asset TINYINT(1) DEFAULT 0,
  description VARCHAR(255) DEFAULT NULL,
  email VARCHAR(56) DEFAULT NULL,
  phone VARCHAR(45) DEFAULT NULL,
  active TINYINT DEFAULT 1,
  is_third_party TINYINT(1) DEFAULT 0,
  manager_id INT DEFAULT NULL,
  subsidiary_id INT DEFAULT NULL,
  accounting_center_id INT DEFAULT NULL,
  accounting_account_id INT DEFAULT NULL,
  use_accounting TINYINT(1) DEFAULT 1,
  created_at DATETIME DEFAULT NULL,
  created_user_id INT DEFAULT NULL,
  updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  update_user_id INT DEFAULT NULL,
  deleted_at DATETIME DEFAULT NULL,
  deleted_user_id INT DEFAULT NULL,
  PRIMARY KEY (id),
  INDEX idx_name_code_id(id, code, name)
);

CREATE TABLE stocks (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  product_id INT UNSIGNED NOT NULL,
  warehouse_id INT UNSIGNED NOT NULL,
  third_party_available DECIMAL(24, 6) UNSIGNED NOT NULL DEFAULT 0.000000,
  self_blocked DECIMAL(24, 6) NOT NULL DEFAULT 0.000000,
  reorder_point DECIMAL(24, 6) DEFAULT NULL,
  actual_stock DECIMAL(24, 6) DEFAULT 0.000000,
  has_movements TINYINT NOT NULL DEFAULT 0,
  PRIMARY KEY (id)
);

ALTER TABLE stocks
  ADD INDEX fk_stocks_warehouse_id_warehouse_idx(warehouse_id);

ALTER TABLE stocks
  ADD UNIQUE INDEX product_id_warehouse_id(product_id, warehouse_id);

-- stocks.product_id can be joined with products.id
-- stocks.warehouse_id can be joined with warehouses.id

### Answer
Given the database schema, here is the MySQL-compatible query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

## Generate the SQL
Generation can be very slow on a T4 GPU in Colab, taking roughly 10-20 seconds per query; on faster GPUs, it takes about 1-2 seconds.
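To check these timings on your own hardware, you can wrap a call in a small wall-clock timer. This is a minimal sketch; `timed` is a hypothetical helper, and the measurement covers everything the wrapped function does (for `generate_query`, that is tokenization, generation, and decoding).

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Stand-in workload so the sketch runs anywhere; in the notebook you would
# pass generate_query and a question instead, e.g. timed(generate_query, question)
result, seconds = timed(sum, range(1_000_000))
print(f"result={result}, took {seconds:.3f}s")
```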

Ideally, you should use `num_beams=4` for best results, but because of memory constraints we stick to `num_beams=1` (greedy decoding, since `do_sample=False`) for now.
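To give a feel for what `num_beams` changes, here is a toy beam search in plain Python over a hand-made next-token table. The probabilities are hypothetical, not SQLCoder's, and Hugging Face's implementation adds details such as length penalties and end-of-sequence handling that this sketch omits. Each extra beam is another partial sequence the model must carry through generation, which is roughly why beam search needs more GPU memory.

```python
import math
from typing import Callable, Dict, List, Tuple

# A "model" here is just a function from a prefix (tuple of tokens) to next-token probabilities.
NextProbs = Callable[[Tuple[str, ...]], Dict[str, float]]

# Hypothetical toy distribution: greedy picks "A" first (0.6 > 0.4),
# but the most probable two-token sequence starts with "B".
TOY_TABLE: Dict[Tuple[str, ...], Dict[str, float]] = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.55, "y": 0.45},
    ("B",): {"z": 0.9, "w": 0.1},
}


def toy_next_probs(prefix: Tuple[str, ...]) -> Dict[str, float]:
    return TOY_TABLE[prefix]


def beam_search(next_probs: NextProbs, steps: int, num_beams: int) -> Tuple[List[str], float]:
    """Return the best sequence found and its log-probability; num_beams=1 is greedy decoding."""
    beams: List[Tuple[Tuple[str, ...], float]] = [((), 0.0)]
    for _ in range(steps):
        # Extend every surviving beam by every possible next token.
        candidates = [
            (seq + (tok,), logp + math.log(p))
            for seq, logp in beams
            for tok, p in next_probs(seq).items()
        ]
        # Keep only the num_beams best partial sequences by total log-probability.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    best_seq, best_logp = beams[0]
    return list(best_seq), best_logp


for k in (1, 2):
    seq, logp = beam_search(toy_next_probs, steps=2, num_beams=k)
    print(f"num_beams={k}: {seq} (p={math.exp(logp):.2f})")
```

With these numbers, greedy decoding (`num_beams=1`) returns `['A', 'x']` with probability 0.33, while `num_beams=2` finds `['B', 'z']` with probability 0.36.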

In [35]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=200,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Free cached GPU memory so that repeated calls don't run out of memory.
    # This matters most on Colab; memory management is much more straightforward
    # when running on an inference service.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
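`generate_query` keeps everything after the last `[SQL]` marker, because the decoded text still contains the prompt, which ends with `[SQL]`. A slightly more defensive variant is sketched below as a hypothetical `extract_sql` helper in plain Python: it also trims a closing `[/SQL]` tag in case the model emits one (an assumption, not something this notebook shows) and strips surrounding whitespace. Its result could then be passed to `sqlparse.format(..., reindent=True)` as before. Another common approach is to decode only the newly generated tokens, e.g. `generated_ids[0][inputs["input_ids"].shape[1]:]`, so the prompt never appears in the output.

```python
def extract_sql(decoded: str, open_tag: str = "[SQL]", close_tag: str = "[/SQL]") -> str:
    """Return the text after the last open_tag, cut at an optional close_tag (hypothetical helper)."""
    completion = decoded.split(open_tag)[-1]
    # Assumption: the model *may* close its answer with "[/SQL]"; if so, drop it and anything after it.
    completion = completion.split(close_tag, 1)[0]
    return completion.strip()


# Hypothetical decoded outputs, for illustration only
print(extract_sql("...[/QUESTION]\n[SQL]\nSELECT 1;\n[/SQL]\nextra text"))
print(extract_sql("...[/QUESTION]\n[SQL]\nSELECT name FROM warehouses;"))
```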

In [46]:
# question = "How much stock of 'GLOSS ITALIA DELUXE' do I have in each warehouse? Omit warehouses with 0 stock. Show warehouse names and product names."
# question = "In which warehouse do I have the least 'GLOSS ITALIA DELUXE' stock? Exclude zero stock. Also show the stock quantity. Remember that the query must be compatible with MySQL"
question = "In which warehouses do I have stock 10 for 'GLOSS ITALIA DELUXE'?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT s.warehouse_id
FROM stocks s
WHERE s.actual_stock >= 10
  AND s.product_id IN
    (SELECT p.id
     FROM products p
     WHERE p.name = 'GLOSS ITALIA DELUXE');
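The generated query reads "stock 10" as "at least 10 units" (`>= 10`) and returns `warehouse_id` values rather than the warehouse names the prompt asks for. One way to sanity-check such output before pointing it at MySQL is to run it, alongside a hand-written reference query, against a tiny in-memory database. The sketch below uses Python's built-in `sqlite3` as a stand-in for MySQL, a trimmed-down version of the `products`, `warehouses`, and `stocks` tables, and made-up rows; all of these are assumptions for illustration. Both queries happen to be valid in SQLite and MySQL, but SQLite will not catch MySQL-specific dialect issues.

```python
import sqlite3

# The query generated above, copied verbatim
GENERATED_SQL = """
SELECT s.warehouse_id
FROM stocks s
WHERE s.actual_stock >= 10
  AND s.product_id IN
    (SELECT p.id
     FROM products p
     WHERE p.name = 'GLOSS ITALIA DELUXE');
"""

# Hand-written reference that also returns warehouse names, as the prompt requests
REFERENCE_SQL = """
SELECT w.name, s.actual_stock
FROM stocks s
JOIN warehouses w ON w.id = s.warehouse_id
JOIN products p ON p.id = s.product_id
WHERE p.name = 'GLOSS ITALIA DELUXE'
  AND s.actual_stock >= 10
ORDER BY w.name;
"""


def build_toy_db() -> sqlite3.Connection:
    """Create a trimmed-down, hypothetical copy of the schema with made-up rows."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
        CREATE TABLE warehouses (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE stocks (
            id INTEGER PRIMARY KEY,
            product_id INTEGER NOT NULL,
            warehouse_id INTEGER NOT NULL,
            actual_stock NUMERIC DEFAULT 0
        );
        INSERT INTO products VALUES (1, 'GLOSS ITALIA DELUXE'), (2, 'OTHER PRODUCT');
        INSERT INTO warehouses VALUES (1, 'Central'), (2, 'North'), (3, 'South');
        INSERT INTO stocks VALUES
            (1, 1, 1, 12), (2, 1, 2, 5), (3, 1, 3, 10), (4, 2, 2, 50);
    """)
    return conn


conn = build_toy_db()
generated_rows = conn.execute(GENERATED_SQL).fetchall()
reference_rows = conn.execute(REFERENCE_SQL).fetchall()
print("generated:", generated_rows)
print("reference:", reference_rows)
conn.close()
```

On this toy data both queries select the same two warehouses (IDs 1 and 3), but only the reference query shows their names, which is the behavior the prompt's instructions ask for.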
