<a href="https://colab.research.google.com/github/tariksghiouri/BookStoreMVC/blob/master/SQLCoder_8b_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install torch transformers bitsandbytes accelerate sqlparse pyodbc psycopg2 mysql-connector-python -q


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.4/309.4 kB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m334.7/334.7 kB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.4/19.4 MB[0m [31m67.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import sqlparse
import pyodbc
import psycopg2
import mysql.connector

In [None]:
torch.cuda.is_available()
available_memory = torch.cuda.get_device_properties(0).total_memory
print(available_memory)


In [32]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

In [None]:

if available_memory > 20e9:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        load_in_4bit=True,
        device_map="auto",
        use_cache=True,
    )


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [25]:
def extract_table_names(schema):
    table_name_pattern_1 = re.compile(r'CREATE TABLE `(\w+)`', re.IGNORECASE)
    table_name_pattern_2 = re.compile(r'CREATE TABLE "(\w+)"', re.IGNORECASE)
    table_name_pattern_3 = re.compile(r'CREATE TABLE (\w+\.\w+)', re.IGNORECASE)
    table_names_1 = table_name_pattern_1.findall(schema)
    table_names_2 = table_name_pattern_2.findall(schema)
    table_names_3 = table_name_pattern_3.findall(schema)
    return table_names_1 + table_names_2 + table_names_3


prompt_template = """
### Instructions:
Your task is to convert a question into a vlid and executable SQL query, given a {databasetype} database schema.
Adhere to these rules:

  - if the question does not specify which columns select all cloumns from the table
  - **Deliberately go through the question and database schema word by word** to appropriately answer the question
  - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
  - When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{schema} \n {context}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""

def get_context_query(database_type, table_names):
    if database_type == "SQLSERVER":
        return f"""
USE [{dataBaseName}];
GO

SELECT
    ' -- ' + fk_table.name + '.' + fk_col.name + ' can be joined with ' + pk_table.name + '.' + pk_col.name AS constraint_description
FROM
    sys.foreign_key_columns AS fk
    INNER JOIN sys.tables AS fk_table ON fk.parent_object_id = fk_table.object_id
    INNER JOIN sys.columns AS fk_col ON fk.parent_object_id = fk_col.object_id AND fk.parent_column_id = fk_col.column_id
    INNER JOIN sys.tables AS pk_table ON fk.referenced_object_id = pk_table.object_id
    INNER JOIN sys.columns AS pk_col ON fk.referenced_object_id = pk_col.object_id AND fk.referenced_column_id = pk_col.column_id
WHERE
    fk_table.name IN ({', '.join([f"'{table}'" for table in table_names])})
    OR pk_table.name IN ({', '.join([f"'{table}'" for table in table_names])})
ORDER BY
    fk_table.name, fk_col.name;
"""
    elif database_type == "POSTGRES":
        print("in postgres case")
        return f"""
WITH fk_info AS (
    SELECT
        conname AS constraint_name,
        conrelid::regclass AS table_from,
        a.attname AS column_from,
        confrelid::regclass AS table_to,
        af.attname AS column_to
    FROM
        pg_constraint AS c
    JOIN
        pg_attribute AS a ON a.attnum = ANY(c.conkey) AND a.attrelid = c.conrelid
    JOIN
        pg_attribute AS af ON af.attnum = ANY(c.confkey) AND af.attrelid = c.confrelid
    WHERE
        c.contype = 'f'
        AND (conrelid::regclass::text IN ({', '.join([f"'{table}'" for table in table_names])})
             OR confrelid::regclass::text IN ({', '.join([f"'{table}'" for table in table_names])})
        )
)
SELECT
    '-- ' || table_from || '.' || column_from || ' can be joined with ' || table_to || '.' || column_to AS join_statement
FROM
    fk_info;
"""
    elif database_type == "MYSQL":
        return f"""
SELECT
    CONCAT(
        kcu1.TABLE_NAME, '.', kcu1.COLUMN_NAME,
        ' can be joined with ',
        kcu2.TABLE_NAME, '.', kcu2.COLUMN_NAME
    ) AS constraints
FROM
    information_schema.KEY_COLUMN_USAGE kcu1
    JOIN information_schema.REFERENTIAL_CONSTRAINTS rc
    ON kcu1.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
    AND kcu1.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
    JOIN information_schema.KEY_COLUMN_USAGE kcu2
    ON kcu2.CONSTRAINT_NAME = rc.UNIQUE_CONSTRAINT_NAME
    AND kcu2.CONSTRAINT_SCHEMA = rc.UNIQUE_CONSTRAINT_SCHEMA
WHERE
    kcu1.TABLE_SCHEMA = '{dataBaseName}'
    AND kcu1.TABLE_NAME IN ({', '.join([f"'{table}'" for table in table_names])})
    AND kcu2.TABLE_NAME IN ({', '.join([f"'{table}'" for table in table_names])});
"""
    else:
        return ""

def get_context(database_type, table_names, connection_params):
    def get_sqlserver_context_query_result(query, server, database, username, password):
        conn_str = f'DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
        conn = pyodbc.connect(conn_str)
        cursor = conn.cursor()
        cursor.execute(query)
        result = cursor.fetchall()
        conn.close()
        return "\n".join([row[0] for row in result])

    def get_postgres_context_query_result(query, host, database, user, password, port):
        conn = psycopg2.connect(host=host, database=database, user=user, password=password, port=port)
        cursor = conn.cursor()
        cursor.execute(query)
        result = cursor.fetchall()
        conn.close()
        return "\n".join([row[0] for row in result])

    def get_mysql_context_query_result(query, host, database, user, password):
        conn = mysql.connector.connect(host=host, database=database, user=user, password=password)
        cursor = conn.cursor()
        cursor.execute(query)
        result = cursor.fetchall()
        conn.close()
        return "\n".join([row[0] for row in result])

    context_query = get_context_query(database_type, table_names)

    if database_type == "SQLSERVER":
        return get_sqlserver_context_query_result(
            context_query,
            connection_params['server'],
            connection_params['database'],
            connection_params['username'],
            connection_params['password']
        )
    elif database_type == "POSTGRES":
        return get_postgres_context_query_result(
            context_query,
            connection_params['host'],
            connection_params['database'],
            connection_params['user'],
            connection_params['password'],
            connection_params['port']
        )
    elif database_type == "MYSQL":
        return get_mysql_context_query_result(
            context_query,
            connection_params['host'],
            connection_params['database'],
            connection_params['user'],
            connection_params['password']
        )
    else:
        raise ValueError("Unsupported database type")

def generate_query(dataBaseType,question, schema, context):
    # torch.cuda.empty_cache()
    # torch.cuda.synchronize()
    updated_prompt = prompt_template.format(question=question, schema=schema, context=context,databasetype= dataBaseType)


    print("updated_prompt: ", updated_prompt)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
        top_p=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


    return outputs[0].split("```sql")[1].split(";")[0]

# Example schema from the prompt
example_schema = """
CREATE TABLE "res_company_users_rel" ("cid" integer NOT NULL, "user_id" integer NOT NULL);
 CREATE TABLE "res_config" ("id" integer DEFAULT nextval('res_config_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "res_config_installer" ("id" integer DEFAULT nextval('res_config_installer_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
CREATE TABLE "hr_department_hr_leave_mandatory_day_rel" ("hr_leave_mandatory_day_id" integer NOT NULL, "hr_department_id" integer NOT NULL);
 CREATE TABLE "hr_departure_reason" ("id" integer DEFAULT nextval('hr_departure_reason_id_seq'::regclass) NOT NULL, "sequence" integer, "reason_code" integer, "create_uid" integer, "write_uid" integer, "name" jsonb NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "hr_departure_wizard" ("id" integer DEFAULT nextval('hr_departure_wizard_id_seq'::regclass) NOT NULL, "departure_reason_id" integer NOT NULL, "employee_id" integer NOT NULL, "create_uid" integer, "write_uid" integer, "departure_date" date NOT NULL, "departure_description" text, "create_date" timestamp without time zone, "write_date" timestamp without time zone, "set_date_end" boolean, "release_campany_car" boolean, "cancel_leaves" boolean, "archive_allocation" boolean);
 CREATE TABLE "hr_employee" ("id" integer DEFAULT nextval('hr_employee_id_seq'::regclass) NOT NULL, "resource_id" integer NOT NULL, "company_id" integer NOT NULL, "resource_calendar_id" integer, "message_main_attachment_id" integer, "color" integer, "department_id" integer, "job_id" integer, "address_id" integer, "work_contact_id" integer, "work_location_id" integer, "user_id" integer, "parent_id" integer, "coach_id" integer, "private_state_id" integer, "private_country_id" integer, "country_id" integer, "children" integer, "country_of_birth" integer, "bank_account_id" integer, "km_home_work" integer, "departure_reason_id" integer, "create_uid" integer, "write_uid" integer, "name" character varying, "job_title" character varying, "work_phone" character varying, "mobile_phone" character varying, "work_email" character varying, "private_street" character varying, "private_street2" character varying, "private_city" character varying, "private_zip" character varying, "private_phone" character varying, "private_email" character varying, "lang" character varying, "gender" character varying, "marital" character varying, "spouse_complete_name" character varying, "place_of_birth" character varying, "ssnid" character varying, "sinid" character varying, "identification_id" character varying, "passport_id" character varying, "permit_no" character varying, "visa_no" character varying, "certificate" character varying, "study_field" character varying, "study_school" character varying, "emergency_contact" character varying, "emergency_phone" character varying, "employee_type" character varying NOT NULL, "barcode" character varying, "pin" character varying, "private_car_plate" character varying, "spouse_birthdate" date, "birthday" date, "visa_expire" date, "work_permit_expiration_date" date, "departure_date" date, "employee_properties" jsonb, "additional_note" text, "notes" text, "departure_description" text, "active" boolean, "work_permit_scheduled_activity" boolean, "create_date" timestamp without time zone, "write_date" timestamp without time zone, "attendance_manager_id" integer, "last_attendance_id" integer, "last_check_in" timestamp without time zone, "last_check_out" timestamp without time zone, "contract_id" integer, "vehicle" character varying, "first_contract_date" date, "contract_warning" boolean, "mobility_card" character varying, "leave_manager_id" integer, "expense_manager_id" integer, "hourly_cost" numeric);
 CREATE TABLE "hr_employee_category" ("id" integer DEFAULT nextval('hr_employee_category_id_seq'::regclass) NOT NULL, "color" integer, "create_uid" integer, "write_uid" integer, "name" character varying NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
CREATE TABLE "spreadsheet_dashboard_group" ("id" integer DEFAULT nextval('spreadsheet_dashboard_group_id_seq'::regclass) NOT NULL, "sequence" integer, "create_uid" integer, "write_uid" integer, "name" jsonb NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "spreadsheet_dashboard_share" ("id" integer DEFAULT nextval('spreadsheet_dashboard_share_id_seq'::regclass) NOT NULL, "dashboard_id" integer NOT NULL, "create_uid" integer, "write_uid" integer, "access_token" character varying NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_assign_serial" ("id" integer DEFAULT nextval('stock_assign_serial_id_seq'::regclass) NOT NULL, "move_id" integer, "next_serial_count" integer NOT NULL, "create_uid" integer, "write_uid" integer, "next_serial_number" character varying NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_backorder_confirmation" ("id" integer DEFAULT nextval('stock_backorder_confirmation_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "show_transfers" boolean, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_backorder_confirmation_line" ("id" integer DEFAULT nextval('stock_backorder_confirmation_line_id_seq'::regclass) NOT NULL, "backorder_confirmation_id" integer, "picking_id" integer, "create_uid" integer, "write_uid" integer, "to_backorder" boolean, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_change_product_qty" ("id" integer DEFAULT nextval('stock_change_product_qty_id_seq'::regclass) NOT NULL, "product_id" integer NOT NULL, "product_tmpl_id" integer NOT NULL, "create_uid" integer, "write_uid" integer, "new_quantity" numeric NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_conflict_quant_rel" ("stock_inventory_conflict_id" integer NOT NULL, "stock_quant_id" integer NOT NULL);
 CREATE TABLE "stock_inventory_adjustment_name" ("id" integer DEFAULT nextval('stock_inventory_adjustment_name_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "inventory_adjustment_name" character varying, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_inventory_adjustment_name_stock_quant_rel" ("stock_inventory_adjustment_name_id" integer NOT NULL, "stock_quant_id" integer NOT NULL);
 CREATE TABLE "stock_inventory_conflict" ("id" integer DEFAULT nextval('stock_inventory_conflict_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_inventory_conflict_stock_quant_rel" ("stock_inventory_conflict_id" integer NOT NULL, "stock_quant_id" integer NOT NULL);
 CREATE TABLE "stock_inventory_warning" ("id" integer DEFAULT nextval('stock_inventory_warning_id_seq'::regclass) NOT NULL, "create_uid" integer, "write_uid" integer, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "stock_inventory_warning_stock_quant_rel" ("stock_inventory_warning_id" integer NOT NULL, "stock_quant_id" integer NOT NULL);
CREATE TABLE "hr_leave_accrual_plan" ("id" integer DEFAULT nextval('hr_leave_accrual_plan_id_seq'::regclass) NOT NULL, "time_off_type_id" integer, "company_id" integer, "carryover_day" integer, "create_uid" integer, "write_uid" integer, "name" character varying NOT NULL, "transition_mode" character varying NOT NULL, "accrued_gain_time" character varying NOT NULL, "carryover_date" character varying NOT NULL, "carryover_month" character varying, "added_value_type" character varying, "active" boolean, "is_based_on_worked_time" boolean, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "hr_leave_allocation" ("id" integer DEFAULT nextval('hr_leave_allocation_id_seq'::regclass) NOT NULL, "holiday_status_id" integer NOT NULL, "employee_id" integer, "employee_company_id" integer, "manager_id" integer, "parent_id" integer, "approver_id" integer, "mode_company_id" integer, "department_id" integer, "category_id" integer, "accrual_plan_id" integer, "create_uid" integer, "write_uid" integer, "private_name" character varying, "state" character varying, "holiday_type" character varying NOT NULL, "allocation_type" character varying NOT NULL, "date_from" date NOT NULL, "date_to" date, "lastcall" date, "nextcall" date, "notes" text, "active" boolean, "multi_employee" boolean, "already_accrued" boolean, "create_date" timestamp without time zone, "write_date" timestamp without time zone, "number_of_days" double precision, "overtime_id" integer);
 CREATE TABLE "hr_leave_employee_type_report" ("id" bigint, "employee_id" integer, "active_employee" boolean, "number_of_days" double precision, "department_id" integer, "leave_type" integer, "holiday_status" text, "state" character varying, "date_from" timestamp without time zone, "date_to" timestamp without time zone, "company_id" integer);
 CREATE TABLE "hr_leave_mandatory_day" ("id" integer DEFAULT nextval('hr_leave_mandatory_day_id_seq'::regclass) NOT NULL, "company_id" integer NOT NULL, "color" integer, "resource_calendar_id" integer, "create_uid" integer, "write_uid" integer, "name" character varying NOT NULL, "start_date" date NOT NULL, "end_date" date NOT NULL, "create_date" timestamp without time zone, "write_date" timestamp without time zone);
 CREATE TABLE "hr_leave_report" ("id" bigint, "leave_id" integer, "employee_id" integer, "name" character varying, "number_of_days" double precision, "leave_type" text, "category_id" integer, "department_id" integer, "holiday_status_id" integer, "state" character varying, "holiday_type" character varying, "date_from" timestamp without time zone, "date_to" timestamp without time zone, "company_id" integer);
 CREATE TABLE "hr_leave_report_calendar" ("id" integer, "name" text, "start_datetime" timestamp without time zone, "stop_datetime" timestamp without time zone, "employee_id" integer, "state" character varying, "department_id" integer, "duration" double precision, "company_id" integer, "job_id" integer, "tz" character varying, "is_striked" boolean, "is_hatched" boolean);
"""


In [26]:
dataBaseType = "POSTGRES"
dataBaseName = "o2maroc"

conn_params = {
    'host': '141.95.164.109',
    'database': 'o2maroc',
    'user': 'odoo',
    'password': 'odoo17@2023',
    'port': 10019
}

In [27]:
table_names = extract_table_names(example_schema)
print(table_names)
context_result = get_context(dataBaseType, table_names, conn_params)
# print(context_result)

in postgres case


In [29]:
', '.join([f"'{table}'" for table in table_names])



In [36]:
question = "Give me all the information about the  employees  that work in administration departement "
generated_sql1 = generate_query(dataBaseType,question, example_schema, context_result)
print(sqlparse.format(generated_sql1, reindent=True))

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [30]:
torch.cuda.empty_cache()


In [33]:
print(f"Number of input tokens: {len(tokenizer.tokenize(example_schema+prompt_template+context_result))}")

Number of input tokens: 6630


In [34]:
print(f"Number of tokens: {len(tokenizer.tokenize('hello world, I`m tarik how are you doing? I hope everyone is doing great'))}")

Number of tokens: 19


In [None]:
`