<a href="https://colab.research.google.com/github/robol921/Platform/blob/master/defog_sqlcoder_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQLCoder-7b-2
Run the cells below to run inference on our text-to-SQL LLM: SQLCoder-7b-2.

⭐️ [Github Repo](https://github.com/defog-ai/sqlcoder)

🤗 [Huggingface Page](https://huggingface.co/defog/sqlcoder-7b-2)

## Setup

In [65]:
!pip install torch transformers bitsandbytes accelerate sqlparse sqlglot



In [126]:
import re

import sqlglot
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig

In [67]:
# Check that PyTorch can see a CUDA device; later cells call torch.cuda APIs directly.
torch.cuda.is_available()

True

In [68]:
# Total memory of CUDA device 0, used further down to pick the model loading mode.
# Note: this is the device's total capacity (`total_memory`), not the amount currently free.
available_memory = torch.cuda.get_device_properties(0).total_memory
print(available_memory)

42474471424


## Download the Model
Use any model on Colab (or any system with >30GB VRAM on your own machine) to load this in f16. If unavailable, use a GPU with minimum 8GB VRAM to load this in 8bit, or with minimum 5GB of VRAM to load in 4bit.

This step can take around 5 minutes the first time. So please be patient :)

In [69]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Arguments shared by both loading modes.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run locally – confirm it is really required for this model.
load_kwargs = dict(
    trust_remote_code=True,
    device_map="auto",
    use_cache=True,
)

if available_memory > 15e9:
    # With at least 15 GB of GPU memory, load the weights in float16.
    load_kwargs["torch_dtype"] = torch.float16
else:
    # Otherwise load in 8 bits – this is a bit slower. The quantization settings
    # are passed through a BitsAndBytesConfig rather than the deprecated bare
    # `load_in_8bit=True` keyword.
    load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Set the Question & Prompt and Tokenize
Feel free to change the schema in the prompt below to your own schema

In [177]:
# Prompt template sent to the model. `{question}` is the only placeholder and is
# filled in with `str.format`, so any literal brace added to this text must be
# doubled (`{{` / `}}`). The text ends with the `[SQL]` marker; the generated
# query is whatever the model writes after it.
# The join hints at the end of the schema section only mention tables and
# columns that are declared in the schema above them.
prompt = """
You are an expert in SQL query generation. Generate a query for the given request
### Task
Generate a query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that the id column in products table is 'id'
- Remember that when tables has column 'deleted_at' like in 'sale_concepts' always should select where deleted_at is null
- Remember that if I ask 'In which warehouses' you have to show warehouses names
- Remember that if I ask 'In which sales' you have to show a 'Folio' column using CONCAT(sales.serie, '-', sales.folio)
- Remember to translate column alias to spanish and always use double quotes
- Remember you never have to use spaces in column alias
- Remember if I ask for the count of sold products on each sale you have to search for 'sale_concepts' table in which 'sale_id' is the related sale
- Remember the schema for table 'user' is 'central_systemTA' for any other table schema is 'datastorage' so when using 'user' in a query you have to add prefix 'central_systemTA.'

### Columns names
- When column is warehouses.name use Almacén as column title





### Database Schema
This query will run on a database whose schema is represented in this string:
--
-- Set character set the client will use to send SQL statements to the server
--
SET NAMES 'utf8';

--
-- Set default database
--
USE central_systemTA;

--
-- Create table `user`
--
CREATE TABLE central_systemTA.user (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  nickname VARCHAR(150) CHARACTER SET latin1 COLLATE latin1_swedish_ci NOT NULL,
  name VARCHAR(300) CHARACTER SET latin1 COLLATE latin1_swedish_ci NOT NULL,
  lastname VARCHAR(150) CHARACTER SET latin1 COLLATE latin1_swedish_ci NOT NULL,
  surname VARCHAR(150) CHARACTER SET latin1 COLLATE latin1_swedish_ci DEFAULT NULL,
  company INT NOT NULL,
  passw BLOB NOT NULL,
  picture VARCHAR(500) CHARACTER SET latin1 COLLATE latin1_swedish_ci DEFAULT NULL,
  status INT NOT NULL DEFAULT 0 COMMENT '1 Acceso correcto
2 primer acceso
3 usuario bloqueado
4 cambio de contraseña',
  created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
  deleted_at DATETIME DEFAULT NULL,
  updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  last_seen DATETIME DEFAULT NULL,
  is_online TINYINT(1) NOT NULL DEFAULT 0,
  session_id VARCHAR(255) DEFAULT NULL,
  token VARCHAR(6) DEFAULT NULL,
  is_admin TINYINT NOT NULL DEFAULT 0,
  created_user INT DEFAULT NULL,
  updated_user INT DEFAULT NULL,
  last_subsidiary_id INT DEFAULT NULL,
  type ENUM('E-Commerce','WebService','Support') DEFAULT NULL,
  time_zone_id INT DEFAULT 94,
  PRIMARY KEY (id)
);

--
-- Set default database
--
USE datastorage;

CREATE TABLE sales (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  cfdi_id INT UNSIGNED DEFAULT NULL,
  sale_type ENUM('national','international','N/A') NOT NULL DEFAULT 'N/A',
  advance_id INT DEFAULT NULL,
  advance_amount DECIMAL(24, 6) DEFAULT NULL,
  sale_process_id INT UNSIGNED NOT NULL,
  sale_charge_id INT UNSIGNED DEFAULT NULL,
  sale_debt_id INT DEFAULT NULL,
  movement_id INT UNSIGNED DEFAULT NULL,
  customer_id INT UNSIGNED NOT NULL,
  price_id INT UNSIGNED NOT NULL,
  required_date DATE DEFAULT NULL,
  sale_date DATETIME NOT NULL,
  sale_date_utc DATETIME DEFAULT NULL COMMENT '// Fecha usada solo para auxiliar a la elección de que fecha colocar en el expedition date de un cfdi de ingreso, se hizo esto por que sale_date ya no es seguro para saber si estan en utc o en gmt',
  stamp_date_last_attempt DATETIME DEFAULT NULL COMMENT '// Guarda la fecha de la última vez que se intento timbrar, esto para generar el mismo sello en caso de haber fallado el timbrado.',
  serie VARCHAR(25) DEFAULT NULL,
  folio INT UNSIGNED DEFAULT NULL,
  manual_folio VARCHAR(255) DEFAULT NULL,
  customer_order VARCHAR(255) DEFAULT NULL,
  subsidiary_id INT UNSIGNED DEFAULT NULL,
  warehouse_id INT UNSIGNED DEFAULT NULL,
  cfdi_usage_id INT UNSIGNED NOT NULL,
  currency_id INT UNSIGNED NOT NULL,
  reference TEXT DEFAULT NULL,
  addendum TEXT DEFAULT NULL,
  exchange_rate DECIMAL(24, 6) DEFAULT NULL,
  sale_cost DECIMAL(24, 6) DEFAULT NULL,
  sale_price DECIMAL(24, 6) NOT NULL,
  tax_price DECIMAL(24, 6) NOT NULL,
  withheld_tax_price DECIMAL(24, 6) NOT NULL,
  discount DECIMAL(24, 6) NOT NULL,
  total_price DECIMAL(24, 6) NOT NULL,
  payment_date DATE DEFAULT NULL,
  due_date DATE DEFAULT NULL,
  payment_method_id INT UNSIGNED NOT NULL DEFAULT 1,
  payment_way_id INT UNSIGNED NOT NULL DEFAULT 21,
  responsible_id INT UNSIGNED DEFAULT NULL,
  status_id TINYINT UNSIGNED DEFAULT 1,
  ecommerce_shopping_cart_id INT UNSIGNED DEFAULT NULL,
  requisition_id INT UNSIGNED DEFAULT NULL,
  journal_entry_id INT DEFAULT NULL,
  cn_cancel_id INT UNSIGNED DEFAULT NULL,
  type ENUM('Gestión de ventas','E-Commerce','WebService','Punto de venta') NOT NULL DEFAULT 'Gestión de ventas',
  cfdi_type_id INT UNSIGNED NOT NULL DEFAULT 1,
  addendum_id INT DEFAULT NULL,
  invoice_relations_id TINYINT DEFAULT NULL,
  project_id INT DEFAULT NULL,
  project_grouper_id INT DEFAULT NULL,
  project_concept_id INT DEFAULT NULL,
  manual_delivery_relation INT DEFAULT NULL,
  created_at DATETIME DEFAULT NULL,
  updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  deleted_at DATETIME DEFAULT NULL,
  created_user INT UNSIGNED DEFAULT NULL,
  updated_user INT UNSIGNED DEFAULT NULL,
  pos_cash_register_id INT DEFAULT NULL,
  total_tdc DECIMAL(17, 6) DEFAULT NULL,
  import INT DEFAULT NULL,
  commission_plan_id INT DEFAULT NULL,
  commission_amount_seller DECIMAL(12, 2) DEFAULT NULL,
  site_import_id INT DEFAULT NULL,
  uuid VARCHAR(255) DEFAULT NULL,
  text_exchange_rate VARCHAR(45) DEFAULT NULL,
  deleted_user INT DEFAULT NULL,
  restored_user INT DEFAULT NULL,
  restored_at DATETIME DEFAULT NULL,
  PRIMARY KEY (id)
);

CREATE TABLE sale_concepts (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  sale_id INT UNSIGNED NOT NULL,
  product_id INT UNSIGNED DEFAULT NULL,
  invoice_ps_id INT UNSIGNED NOT NULL,
  currency_id INT UNSIGNED NOT NULL,
  product_code VARCHAR(100) DEFAULT NULL,
  product_barcode VARCHAR(13) DEFAULT NULL,
  product_name TEXT DEFAULT NULL,
  quantity DECIMAL(24, 6) NOT NULL,
  unit_id INT UNSIGNED DEFAULT NULL,
  unit VARCHAR(10) DEFAULT NULL,
  currency_unit_price DECIMAL(24, 6) DEFAULT NULL,
  exchange_rate DECIMAL(24, 6) NOT NULL DEFAULT 1.000000,
  unit_cost DECIMAL(24, 6) DEFAULT NULL,
  unit_price DECIMAL(24, 6) NOT NULL,
  subtotal DECIMAL(24, 6) NOT NULL,
  tax_amount DECIMAL(24, 6) NOT NULL,
  tax_withholding_amount DECIMAL(24, 6) NOT NULL,
  discount_value DECIMAL(24, 6) DEFAULT NULL,
  discount_type TINYINT(1) DEFAULT NULL,
  discount_amount DECIMAL(24, 6) DEFAULT NULL,
  total DECIMAL(24, 6) NOT NULL, -- This is the total price of the sale concept, this is the data to search when asked by revenue of specific product
  is_manual TINYINT UNSIGNED NOT NULL DEFAULT 0,
  is_service TINYINT UNSIGNED DEFAULT 0 COMMENT '0 = Warehouse Product, 1 = Service.',
  inventariable INT DEFAULT NULL,
  is_storable TINYINT UNSIGNED NOT NULL DEFAULT 0,
  created_at DATETIME DEFAULT NULL,
  updated_at DATETIME DEFAULT NULL,
  created_user INT UNSIGNED DEFAULT NULL,
  updated_user INT DEFAULT NULL,
  is_kit TINYINT(1) DEFAULT 0,
  total_tdc DECIMAL(17, 6) DEFAULT NULL,
  service_order_id INT DEFAULT NULL,
  from_service_order INT DEFAULT 0,
  quantity_service INT DEFAULT 0,
  quantity_deliv_req INT DEFAULT 0,
  comment TEXT DEFAULT NULL,
  deleted_at DATETIME DEFAULT NULL,
  sale_concept_id INT DEFAULT NULL,
  fixed_price DECIMAL(24, 6) DEFAULT NULL,
  fixed_price_currency_id INT DEFAULT NULL,
  accounting_center_id INT DEFAULT NULL,
  production_delivery_date DATETIME DEFAULT NULL,
  import_motion_id INT DEFAULT NULL,
  is_subject_to_tax ENUM('01','02','03','04') DEFAULT NULL,
  expiration_date_id INT DEFAULT NULL,
  expiration_date DATE DEFAULT NULL,
  manual_expiration_date DATE DEFAULT NULL,
  property_account VARCHAR(64) DEFAULT NULL,
  manual_import_motion VARCHAR(255) DEFAULT NULL,
  PRIMARY KEY (id)
);

CREATE TABLE products (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT, -- Unique ID for each product
  name VARCHAR(255) NOT NULL,
  code VARCHAR(64) NOT NULL,
  sku VARCHAR(45) DEFAULT NULL,
  unit_cost DECIMAL(24, 6) DEFAULT NULL,
  currency_id INT UNSIGNED NOT NULL DEFAULT 100,
  unit_id INT UNSIGNED DEFAULT NULL,
  invoice_type_id INT UNSIGNED DEFAULT NULL,
  popular_name VARCHAR(255) DEFAULT NULL,
  reorder_point DECIMAL(24, 6) DEFAULT NULL,
  barcode VARCHAR(13) DEFAULT NULL,
  description TEXT DEFAULT NULL,
  inventariable TINYINT UNSIGNED NOT NULL DEFAULT 1,
  active TINYINT UNSIGNED NOT NULL DEFAULT 1,
  ecommerce_excluded TINYINT UNSIGNED NOT NULL DEFAULT 0,
  expried TINYINT UNSIGNED NOT NULL DEFAULT 0,
  register_serial_numbers TINYINT UNSIGNED NOT NULL DEFAULT 0,
  favorite TINYINT UNSIGNED NOT NULL DEFAULT 0,
  deep DECIMAL(12, 6) DEFAULT NULL,
  deep_unit_id INT DEFAULT NULL,
  width DECIMAL(12, 6) DEFAULT NULL,
  width_unit_id INT DEFAULT NULL,
  height DECIMAL(12, 6) DEFAULT NULL,
  height_unit_id INT DEFAULT NULL,
  weight DECIMAL(12, 6) DEFAULT NULL,
  weight_unit_id INT DEFAULT NULL,
  packaging_type INT DEFAULT NULL,
  tariff_item VARCHAR(10) DEFAULT NULL,
  customs_unit_id INT UNSIGNED ZEROFILL DEFAULT NULL,
  group_id INT UNSIGNED DEFAULT NULL,
  line_id INT UNSIGNED DEFAULT NULL,
  accounting_account_id INT DEFAULT NULL,
  accounting_account_mining_id INT DEFAULT NULL,
  accounting_account_is_main TINYINT(1) DEFAULT 0,
  category_id INT UNSIGNED DEFAULT NULL,
  product_type_id INT DEFAULT NULL,
  type ENUM('producto','servicio','activo fijo') DEFAULT 'producto',
  brand_id INT UNSIGNED DEFAULT NULL,
  model VARCHAR(45) DEFAULT NULL,
  is_kit TINYINT NOT NULL DEFAULT 0,
  is_service TINYINT NOT NULL DEFAULT 0,
  is_fixed_asset TINYINT NOT NULL DEFAULT 0,
  score FLOAT DEFAULT 0,
  data_sheet_description TEXT DEFAULT NULL,
  expense_type_id INT DEFAULT NULL,
  catalog_included TINYINT NOT NULL DEFAULT 1,
  has_decimals TINYINT NOT NULL DEFAULT 0,
  created_at DATETIME DEFAULT NULL,
  created_user_id INT DEFAULT NULL,
  updated_at DATETIME DEFAULT NULL,
  update_user_id INT DEFAULT NULL,
  deleted_at DATETIME DEFAULT NULL,
  deleted_user_id INT DEFAULT NULL,
  delivery_days INT DEFAULT 1,
  is_material TINYINT NOT NULL DEFAULT 0,
  security_stock FLOAT DEFAULT NULL,
  lot_quantity FLOAT DEFAULT NULL,
  io_type_id INT DEFAULT NULL,
  default_provider INT DEFAULT NULL,
  is_internally_produced TINYINT DEFAULT 0,
  apply_importation TINYINT DEFAULT 0,
  make_to_order TINYINT DEFAULT NULL,
  property_account VARCHAR(64) DEFAULT NULL,
  weekly_production_capacity DECIMAL(24, 6) DEFAULT NULL,
  additional_accounting_account_id INT DEFAULT NULL,
  PRIMARY KEY (id),
  INDEX prod_indx(id)
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE warehouses (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  code VARCHAR(15) DEFAULT NULL,
  name VARCHAR(255) DEFAULT NULL,
  is_fixed_asset TINYINT(1) DEFAULT 0,
  description VARCHAR(255) DEFAULT NULL,
  email VARCHAR(56) DEFAULT NULL,
  phone VARCHAR(45) DEFAULT NULL,
  active TINYINT DEFAULT 1,
  is_third_party TINYINT(1) DEFAULT 0,
  manager_id INT DEFAULT NULL,
  subsidiary_id INT DEFAULT NULL,
  accounting_center_id INT DEFAULT NULL,
  accounting_account_id INT DEFAULT NULL,
  use_accounting TINYINT(1) DEFAULT 1,
  created_at DATETIME DEFAULT NULL,
  created_user_id INT DEFAULT NULL,
  updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  update_user_id INT DEFAULT NULL,
  deleted_at DATETIME DEFAULT NULL,
  deleted_user_id INT DEFAULT NULL,
  PRIMARY KEY (id),
  INDEX idx_name_code_id(id, code, name)
);

CREATE TABLE stocks (
  id INT UNSIGNED NOT NULL AUTO_INCREMENT,
  product_id INT UNSIGNED NOT NULL,
  warehouse_id INT UNSIGNED NOT NULL,
  third_party_available DECIMAL(24, 6) UNSIGNED NOT NULL DEFAULT 0.000000,
  self_blocked DECIMAL(24, 6) NOT NULL DEFAULT 0.000000,
  reorder_point DECIMAL(24, 6) DEFAULT NULL,
  actual_stock DECIMAL(24, 6) DEFAULT 0.000000,
  has_movements TINYINT NOT NULL DEFAULT 0,
  PRIMARY KEY (id)
);

ALTER TABLE stocks
  ADD INDEX fk_stocks_warehouse_id_warehouse_idx(warehouse_id);

ALTER TABLE stocks
  ADD UNIQUE INDEX product_id_warehouse_id(product_id, warehouse_id);

-- sale_concepts.sale_id can be joined with sales.id
-- sale_concepts.product_id can be joined with products.id
-- sales.customer_id can be joined with customers.customer_id
-- sales.warehouse_id can be joined with warehouses.id
-- stocks.product_id can be joined with products.id
-- stocks.warehouse_id can be joined with warehouses.id

### Examples
- For question: 'Cuanto he vendido de RIMEL WATERPROOF?'
- Query should be: SELECT SUM(sale_concepts.total) AS total_sold
    FROM sale_concepts
    JOIN products ON sale_concepts.product_id = products.id
    WHERE products.name like '%RIMEL WATERPROOF%';
- Important: You shouldn't use ILIKE in the query


### Answer
Given the database schema, here is the MySQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

## Generate the SQL
This can be excruciatingly slow on a T4 in Colab, and can take 10-20 seconds per query. On faster GPUs, this will take ~1-2 seconds

Ideally, you should use `num_beams`=4 for best results. But because of memory constraints, we will stick to just 1 for now.

In [178]:
import sqlparse

def preprocess_query(query: str) -> str:
    """Collapse a SQL query onto one line and double-quote its ``AS`` aliases.

    Args:
        query: SQL text, possibly spanning several lines.

    Returns:
        The query with every run of whitespace replaced by a single space,
        surrounding whitespace stripped, and each unquoted alias that follows
        an upper-case ``AS`` wrapped in double quotes. An alias may consist of
        several words; it ends at the first character outside the alias
        alphabet (for example ``,`` or ``)``) or at the first reserved clause
        keyword, whichever comes first.

    Known limitations: whitespace inside string literals is collapsed too, and
    any ``AS <word>`` is treated as an alias, so ``CAST(x AS INT)`` becomes
    ``CAST(x AS "INT")``.
    """
    # Step 1: convert the query to a single line (drop newlines, squeeze spaces).
    query = re.sub(r'\s+', ' ', query).strip()

    # Step 2: quote the alias after each AS.
    # Keywords that can directly follow an alias and therefore terminate it.
    reserved = (
        "FROM", "WHERE", "ORDER", "GROUP", "HAVING", "LIMIT", "JOIN", "ON",
        "BY", "NULLS", "ASC", "DESC",
        "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "CROSS",
        "UNION", "USING", "OFFSET", "SELECT",
    )
    # The guard is applied to EVERY alias word (not only the first one), so a
    # multi-word alias cannot swallow the clause that follows it. The trailing
    # \b makes it reject whole keywords only: `ONLINE` or `BYTES` are valid
    # aliases even though they start with `ON` / `BY`. Keywords are matched
    # case-insensitively because SQL keywords are.
    not_reserved = r'(?!(?i:' + '|'.join(reserved) + r')\b)'
    alias_word = not_reserved + r'[A-Za-zÁÉÍÓÚÑáéíóúñ0-9_]+'
    pattern = r'\bAS\s+(' + alias_word + r'(?:\s+' + alias_word + r')*)\b'

    def replacer(match):
        alias = match.group(1).strip()
        return f'AS "{alias}"'

    return re.sub(pattern, replacer, query)

def generate_query(question):
    """Generate a SQL query that answers ``question``.

    The module-level ``prompt`` template is filled with the question and
    tokenized, and the model decodes deterministically (no sampling, a single
    beam, at most 200 new tokens).

    Args:
        question: Natural-language question inserted into the prompt.

    Returns:
        The text that follows the last ``[SQL]`` marker in the decoded output
        (the decoded text includes the prompt), re-indented by ``sqlparse``.
    """
    updated_prompt = prompt.format(question=question)
    # Put the inputs on the model's own device instead of a hard-coded "cuda",
    # so this keeps working wherever `device_map="auto"` placed the model.
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=200,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the CUDA cache so that more results can be generated without running
    # out of memory. This is particularly important on Colab – memory management
    # is much more straightforward when running on an inference service.
    # Guarded so the call is skipped when no CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [183]:
# Pick ONE question to run; the commented-out lines are alternative examples.
# question = "create a mysql query for: How many stock of 'GLOSS ITALIA DELUXE' do I have on each warehouse? Omit warehouses having 0 stock. Show me warehouses names and product names."
# question = "In which warehouse do I have less 'GLOSS ITALIA DELUXE' stock? exclude zero stock. Show also stock quantity. Remember that the query must be compatible with mySQL"
# question = "In which warehouses do I have stock 10 for product which name contains'TOM'? show warehouse name, product name and quantity. sort by warehouse and product"
# question = "show me deleted warehouses and deletion date. show the date formated like '03 de Enero de 2024'"
# question = "who have deleted warehouse TORREON"
# question = "show deleted stores deleted on august 2019 also show deleted dates"
# question = "In which sales I sold 'Ventilador industrial'?"
# question = "show me products sold in sale 'ESC-29' include quantities"
question = "show me products with quantities in sale PV-V-17"
# question = "Cuanto stock de 'RIMEL WATERPROOF' tengo en el almacen 'Durango'?"
# question = "Cuanto he vendido de 'TOMATE Premium'?"

generated_postgre_sql = generate_query(question)
# print(generated_postgre_sql)

# Optional pre-processing step that adds quotes around aliases:
# processed_query = preprocess_query(generated_postgre_sql)
# print(processed_query)

# The model output is read with sqlglot's Postgres dialect and rewritten as MySQL.
# The prompt allows the model to answer 'I do not know' (and it may emit other
# non-SQL text), so report sqlglot errors with the raw output instead of
# leaving the cell in an error state.
try:
    mysql_query = sqlglot.transpile(generated_postgre_sql, read="postgres", write="mysql")[0]
except sqlglot.errors.SqlglotError as exc:
    mysql_query = None
    print(f"Model output could not be transpiled to MySQL ({exc}):\n{generated_postgre_sql}")
else:
    print(mysql_query)

SELECT p.name, SUM(sc.quantity) AS total_quantity FROM sale_concepts AS sc JOIN products AS p ON sc.product_id = p.id JOIN sales AS s ON sc.sale_id = s.id WHERE s.warehouse_id = 17 GROUP BY p.name
