# Quantized model

In [39]:
from mlx_lm import load, generate 
import os
import subprocess

In [42]:
def generate_response(model_path, prompt, max_tokens=500) -> None:
    """Load the model at model_path and print its response to prompt."""
    # load the model and tokenizer
    model, tokenizer = load(model_path)

    # wrap the prompt in a single-turn chat message list
    messages = [{"role": "user", "content": prompt}]

    # apply the chat template using the tokenizer
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # generate the response; verbose=True prints the generated text and the
    # timing statistics, so nothing is returned
    generate(
        model, tokenizer,
        prompt=formatted_prompt,
        max_tokens=max_tokens,
        verbose=True,
    )

In [43]:
mlx_model = "/Users/tesla/side/finetune_llm/mlx_model"
prompt = "What is SQL"
max_tokens = 500

generate_response(mlx_model, prompt, max_tokens)

Prompt: <s>[INST] What is SQL[/INST]
SQL (Structured Query Language) is a programming language used to manage and manipulate relational databases. It was first introduced by IBM in the 1970s and has since become a standard for database management systems. SQL is used to create, modify, and query databases, as well as to define and manipulate the logical structure of a database.

SQL is a declarative language, meaning that it is used to specify what data is needed, rather than how to retrieve it. It is composed of several types of statements, including:

1. Data Definition Language (DDL) statements: These are used to create, alter, and drop database objects such as tables, views, and indexes.
2. Data Manipulation Language (DML) statements: These are used to insert, update, delete, and select data from tables.
3. Data Control Language (DCL) statements: These are used to control access to the database, such as granting or revoking privileges.

SQL is an essential tool for developers, data

## Fine-tuned with QLoRA

In [8]:
def run_command_with_live_output(command: list[str]) -> None:
    """
    Courtesy of ChatGPT:
    Runs a command and prints its output line by line as it executes.

    Args:
        command (list[str]): The command and its arguments to be executed.

    Returns:
        None
    """
    # merge stderr into stdout so that a full stderr pipe cannot block the
    # child process while only stdout is being read
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )

    # print the output line by line until the stream is closed
    for line in process.stdout:
        print(line.strip())

    # wait for the process to exit
    process.wait()


## Prompt 1 

In [44]:
fused_model = "lora_fused_model"
adapter_path = "adapters.npz" 
max_tokens_str = str(max_tokens)

prompt = """Retrieve the names and species of all animals that weigh more than 500 kg
CREATE TABLE animals (id INT, name VARCHAR(20), species VARCHAR(20), weight DECIMAL(5, 2)); 
INSERT INTO animals (id, name, species, weight) VALUES (1, 'Elephant', 'African', 6000), 
(2, 'Lion', 'African', 400), (3, 'Hippo', 'African', 3000), (4, 'Tiger', 'Asian', 300), 
(5, 'Crane', 'African', 100), (6, 'Rhinoceros', 'African', 2000), (7, 'Zebra', 'African', 450), 
(8, 'Giraffe', 'African', 1200), (9, 'Bear', 'Eurasian', 600), (10, 'Crocodile', 'Nile', 700);"""

### Expected code
The expected response for the above prompt is:
```
SELECT name, species FROM animals WHERE weight > 500;
```
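
The expected query can be checked against the prompt's own schema. The block below is a minimal sketch, assuming an in-memory SQLite database as a stand-in for whichever engine the dataset targets; `run_query` is a hypothetical helper, not part of the notebook. SQLite does not enforce the `DECIMAL(5, 2)` precision, so values such as 6000 are accepted here, whereas a stricter engine may reject them.

```python
import sqlite3

# Schema and rows copied from the prompt above.
SCHEMA = """
CREATE TABLE animals (id INT, name VARCHAR(20), species VARCHAR(20), weight DECIMAL(5, 2));
INSERT INTO animals (id, name, species, weight) VALUES (1, 'Elephant', 'African', 6000),
(2, 'Lion', 'African', 400), (3, 'Hippo', 'African', 3000), (4, 'Tiger', 'Asian', 300),
(5, 'Crane', 'African', 100), (6, 'Rhinoceros', 'African', 2000), (7, 'Zebra', 'African', 450),
(8, 'Giraffe', 'African', 1200), (9, 'Bear', 'Eurasian', 600), (10, 'Crocodile', 'Nile', 700);
"""


def run_query(schema: str, query: str) -> list[tuple]:
    """Hypothetical helper: build a throwaway database and run one query."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        return conn.execute(query).fetchall()
    finally:
        conn.close()


rows = run_query(SCHEMA, "SELECT name, species FROM animals WHERE weight > 500;")
for name, species in rows:
    print(name, species)
```

Six animals exceed 500 kg in this data, and `Crane` (100 kg) is not among them, which gives a reference point for judging any result table a model prints alongside its query.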

### Adapter file output

In [48]:
# define command
command = [
    'python', 'scripts/lora.py',
    '--model', mlx_model,
    '--adapter-file', adapter_path,
    '--max-tokens', max_tokens_str,
    '--prompt', prompt,
]

# run command and print results continuously
run_command_with_live_output(command)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading pretrained model
Total parameters 1133.580M
Trainable parameters 0.852M
Loading datasets
Generating
Retrieve the names and species of all animals that weigh more than 500 kg
CREATE TABLE animals (id INT, name VARCHAR(20), species VARCHAR(20), weight DECIMAL(5, 2));
INSERT INTO animals (id, name, species, weight) VALUES (1, 'Elephant', 'African', 6000),
(2, 'Lion', 'African', 400), (3, 'Hippo', 'African', 3000), (4, 'Tiger', 'Asian', 300),
(5, 'Crane', 'African', 100), (6, 'Rhinoceros', 'African', 2000), (7, 'Zebra', 'African', 450),
(8, 'Giraffe', 'African', 1200), (9, 'Bear', 'Eurasian', 600), (10, 'Crocodile', 'Nile', 700);
SELECT name, species FROM animals WHERE weight > 500;
This query retrieves the names and species of all animals that weigh more than 500 kg. It first creates a table named 'animals' with four columns: id, name, species, and weight. It then inserts data into this table. Finally, it uses the WHERE clause to filter the data and select only the rows where the 
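
The `huggingface/tokenizers` warning at the top of this output names its own remedy: set `TOKENIZERS_PARALLELISM` explicitly. A minimal sketch of doing that from the notebook before launching the subprocess is shown below; the child interpreter call is only a hypothetical check that the variable is inherited, and choosing `"false"` rather than `"true"` is an assumption.

```python
import os
import subprocess
import sys

# Set the variable before any child process is launched; children inherit
# the parent's environment by default.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hypothetical check: a child interpreter reports the value it inherited.
child = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['TOKENIZERS_PARALLELISM'])"],
    capture_output=True,
    text=True,
    check=True,
)
inherited = child.stdout.strip()
print(inherited)
```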

### Mistral-7B-Instruct-v0.3-4bit

In [49]:
generate_response(mlx_model, prompt, max_tokens)

Prompt: <s>[INST] Retrieve the names and species of all animals that weigh more than 500 kg
CREATE TABLE animals (id INT, name VARCHAR(20), species VARCHAR(20), weight DECIMAL(5, 2)); 
INSERT INTO animals (id, name, species, weight) VALUES (1, 'Elephant', 'African', 6000), 
(2, 'Lion', 'African', 400), (3, 'Hippo', 'African', 3000), (4, 'Tiger', 'Asian', 300), 
(5, 'Crane', 'African', 100), (6, 'Rhinoceros', 'African', 2000), (7, 'Zebra', 'African', 450), 
(8, 'Giraffe', 'African', 1200), (9, 'Bear', 'Eurasian', 600), (10, 'Crocodile', 'Nile', 700);[/INST]
To retrieve the names and species of all animals that weigh more than 500 kg from the provided table, you can use the SQL `SELECT` statement with a `WHERE` clause. Here's the SQL query:

```sql
SELECT name, species FROM animals WHERE weight > 500;
```

Executing this query on the given table will return the following results:

```
name     | species
---------|-------
Elephant | African
Hippo    | African
Rhinoceros| African
Crane    

### Fine-tuned Mistral-7B-Instruct-v0.3-4bit

In [50]:
generate_response(fused_model, prompt, max_tokens)

Prompt: <s>[INST] Retrieve the names and species of all animals that weigh more than 500 kg
CREATE TABLE animals (id INT, name VARCHAR(20), species VARCHAR(20), weight DECIMAL(5, 2)); 
INSERT INTO animals (id, name, species, weight) VALUES (1, 'Elephant', 'African', 6000), 
(2, 'Lion', 'African', 400), (3, 'Hippo', 'African', 3000), (4, 'Tiger', 'Asian', 300), 
(5, 'Crane', 'African', 100), (6, 'Rhinoceros', 'African', 2000), (7, 'Zebra', 'African', 450), 
(8, 'Giraffe', 'African', 1200), (9, 'Bear', 'Eurasian', 600), (10, 'Crocodile', 'Nile', 700);[/INST]

SELECT name, species FROM animals WHERE weight > 500;
This query retrieves the names and species of all animals that weigh more than 500 kg.
Prompt: 259 tokens, 238.570 tokens-per-sec
Generation: 38 tokens, 33.753 tokens-per-sec
Peak memory: 7.776 GB


## Prompt 2

In [51]:
prompt = """Count the number of unique users who engaged with virtual tours in 'January'?
CREATE TABLE user_activity (user_id INT, activity_date DATE); 
INSERT INTO user_activity (user_id, activity_date) VALUES (1, '2023-01-01'), (2, '2023-01-05'), 
(3, '2023-01-07'), (4, '2023-01-10'), (5, '2023-01-12');"""

### Expected code
The expected response for the above prompt is:
```
SELECT COUNT(DISTINCT user_id) FROM user_activity WHERE EXTRACT(MONTH FROM activity_date) = 1;
```

### Mistral-7B-Instruct-v0.3-4bit

In [52]:
generate_response(mlx_model, prompt, max_tokens)

Prompt: <s>[INST] Count the number of unique users who engaged with virtual tours in 'January'?
CREATE TABLE user_activity (user_id INT, activity_date DATE); 
INSERT INTO user_activity (user_id, activity_date) VALUES (1, '2023-01-01'), (2, '2023-01-05'), 
(3, '2023-01-07'), (4, '2023-01-10'), (5, '2023-01-12');[/INST]
To count the number of unique users who engaged with virtual tours in 'January', you can use the SQL query below:

```sql
SELECT COUNT(DISTINCT user_id) as unique_users
FROM user_activity
WHERE activity_date LIKE '2023-01-%';
```

This query selects the count of distinct user_id from the user_activity table where the activity_date starts with '2023-01-'. This will give you the number of unique users who engaged with virtual tours in January.

Please note that the year in the example is 2023, if you have a different year, you should replace it accordingly.

If you want to get the users who engaged with virtual tours in a specific year, you can replace the '2023-01-%' with 

### Fine-tuned Mistral-7B-Instruct-v0.3-4bit

In [53]:
generate_response(fused_model, prompt, max_tokens)

Prompt: <s>[INST] Count the number of unique users who engaged with virtual tours in 'January'?
CREATE TABLE user_activity (user_id INT, activity_date DATE); 
INSERT INTO user_activity (user_id, activity_date) VALUES (1, '2023-01-01'), (2, '2023-01-05'), 
(3, '2023-01-07'), (4, '2023-01-10'), (5, '2023-01-12');[/INST]

SELECT COUNT(DISTINCT user_id) FROM user_activity WHERE activity_date LIKE '2023-01-%';
This query counts the number of unique users who engaged with virtual tours in 'January' by filtering the user_activity table for rows where the activity_date is in January. The DISTINCT keyword is used to remove duplicates, and the LIKE operator is used to match any date in January.
Prompt: 135 tokens, 216.642 tokens-per-sec
Generation: 100 tokens, 33.565 tokens-per-sec
Peak memory: 7.776 GB
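
Both models filter with `LIKE '2023-01-%'` where the expected answer uses `EXTRACT(MONTH FROM activity_date) = 1`. The sketch below compares the two approaches on the prompt's data, assuming an in-memory SQLite database. SQLite has no `EXTRACT`, so `strftime('%m', ...)` is substituted as a stand-in for the expected query; that substitution is an assumption of this sketch, not part of the dataset.

```python
import sqlite3

# Schema and rows copied from the prompt above.
SCHEMA = """
CREATE TABLE user_activity (user_id INT, activity_date DATE);
INSERT INTO user_activity (user_id, activity_date) VALUES (1, '2023-01-01'), (2, '2023-01-05'),
(3, '2023-01-07'), (4, '2023-01-10'), (5, '2023-01-12');
"""

# Query produced by the fine-tuned model, copied from the output above.
LIKE_QUERY = (
    "SELECT COUNT(DISTINCT user_id) FROM user_activity "
    "WHERE activity_date LIKE '2023-01-%';"
)

# Stand-in for the expected query: strftime('%m', ...) replaces EXTRACT(MONTH ...).
MONTH_QUERY = (
    "SELECT COUNT(DISTINCT user_id) FROM user_activity "
    "WHERE strftime('%m', activity_date) = '01';"
)

conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
like_count = conn.execute(LIKE_QUERY).fetchone()[0]
month_count = conn.execute(MONTH_QUERY).fetchone()[0]
conn.close()

print(like_count, month_count)
```

The counts agree on this data, but the two filters are not equivalent in general: the month-based filter matches January of any year, while the `LIKE` pattern matches only January 2023 and depends on the engine allowing `LIKE` on a date column, which SQLite does because it stores dates as text.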


## Prompt 3

In [58]:
prompt = """How many hospitals are there in each state in the United States?
CREATE TABLE Hospitals (HospitalID INT, HospitalName VARCHAR(255), State VARCHAR(255)); 
INSERT INTO Hospitals (HospitalID, HospitalName, State) VALUES (1, 'Alaska General Hospital', 'Alaska'), (2, 'Bama Healthcare', 'Alabama'), (3, 'California Care', 'California'), (4, 'Delaware Hospital', 'Delaware'), (5, 'Florida Medical Center', 'Florida');"""

### Expected code
The expected response for the above prompt is:
```
SELECT State, COUNT(*) FROM Hospitals GROUP BY State;
```

### Mistral-7B-Instruct-v0.3-4bit

In [59]:
generate_response(mlx_model, prompt, max_tokens)

Prompt: <s>[INST] How many hospitals are there in each state in the United States?
CREATE TABLE Hospitals (HospitalID INT, HospitalName VARCHAR(255), State VARCHAR(255)); 
INSERT INTO Hospitals (HospitalID, HospitalName, State) VALUES (1, 'Alaska General Hospital', 'Alaska'), (2, 'Bama Healthcare', 'Alabama'), (3, 'California Care', 'California'), (4, 'Delaware Hospital', 'Delaware'), (5, 'Florida Medical Center', 'Florida');[/INST]
To provide the data for all 50 states, I'll provide you with the data for the first 10 states, and you can continue the pattern for the remaining states. Here's the SQL code for the first 10 states:

```sql
INSERT INTO Hospitals (HospitalID, HospitalName, State) VALUES
(11, 'Georgia General Hospital', 'Georgia'),
(12, 'Hawaii Medical Center', 'Hawaii'),
(13, 'Idaho Healthcare', 'Idaho'),
(14, 'Illinois Hospital', 'Illinois'),
(15, 'Indiana Medical Center', 'Indiana'),
(16, 'Iowa Health Services', 'Iowa'),
(17, 'Kansas Medical Center', 'Kansas'),
(18, 'Kentu

### Fine-tuned Mistral-7B-Instruct-v0.3-4bit

In [60]:
generate_response(fused_model, prompt, max_tokens)

Prompt: <s>[INST] How many hospitals are there in each state in the United States?
CREATE TABLE Hospitals (HospitalID INT, HospitalName VARCHAR(255), State VARCHAR(255)); 
INSERT INTO Hospitals (HospitalID, HospitalName, State) VALUES (1, 'Alaska General Hospital', 'Alaska'), (2, 'Bama Healthcare', 'Alabama'), (3, 'California Care', 'California'), (4, 'Delaware Hospital', 'Delaware'), (5, 'Florida Medical Center', 'Florida');[/INST]

SELECT COUNT(*) FROM Hospitals WHERE State = 'Alaska';
SELECT COUNT(*) FROM Hospitals WHERE State = 'Alabama';
SELECT COUNT(*) FROM Hospitals WHERE State = 'California';
SELECT COUNT(*) FROM Hospitals WHERE State = 'Delaware';
SELECT COUNT(*) FROM Hospitals WHERE State = 'Florida';
This query calculates the number of hospitals in each state by counting the number of rows in the Hospitals table where the State column matches the state name.
Prompt: 131 tokens, 210.787 tokens-per-sec
Generation: 120 tokens, 33.814 tokens-per-sec
Peak memory: 7.776 GB
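
The fine-tuned output above answers with one `COUNT(*)` query per state instead of a single `GROUP BY`. The sketch below, assuming an in-memory SQLite database, shows that the two approaches agree on the prompt's data but diverge once a state outside the hardcoded list appears. The sixth row (`'Example Hospital'`, `'Georgia'`) is hypothetical and added only for illustration.

```python
import sqlite3

# Schema and rows copied from the prompt above.
SCHEMA = """
CREATE TABLE Hospitals (HospitalID INT, HospitalName VARCHAR(255), State VARCHAR(255));
INSERT INTO Hospitals (HospitalID, HospitalName, State) VALUES
(1, 'Alaska General Hospital', 'Alaska'), (2, 'Bama Healthcare', 'Alabama'),
(3, 'California Care', 'California'), (4, 'Delaware Hospital', 'Delaware'),
(5, 'Florida Medical Center', 'Florida');
"""

GROUPED = "SELECT State, COUNT(*) FROM Hospitals GROUP BY State;"

# States hardcoded in the fine-tuned model's five queries.
STATES = ["Alaska", "Alabama", "California", "Delaware", "Florida"]

conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)

# Expected approach: one query, one row per state.
grouped = dict(conn.execute(GROUPED).fetchall())

# Fine-tuned model's approach: one query per hardcoded state.
per_state = {
    state: conn.execute(
        "SELECT COUNT(*) FROM Hospitals WHERE State = ?", (state,)
    ).fetchone()[0]
    for state in STATES
}

# Hypothetical extra row: a state that the hardcoded list does not cover.
conn.execute("INSERT INTO Hospitals VALUES (6, 'Example Hospital', 'Georgia')")
grouped_after = dict(conn.execute(GROUPED).fetchall())
missed = sorted(set(grouped_after) - set(STATES))
conn.close()

print(grouped == per_state)
print(missed)
```

The per-state queries also drop the state name from each result, so the caller has to track which count belongs to which state.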


## Prompt 4

In [73]:
prompt = """What is the average life expectancy in each country?
CREATE TABLE countries (country_id INT, country_name VARCHAR(255)); INSERT INTO countries VALUES (1, 'CountryA'), (2, 'CountryB'); CREATE TABLE life_expectancy (country_id INT, life_expectancy DECIMAL(5,2), year INT); INSERT INTO life_expectancy VALUES (1, 75.5, 2020), (2, 80.3, 2020);"""

### Expected code
The expected response for the above prompt is:
```
SELECT countries.country_name, AVG(life_expectancy) AS avg_life_expectancy FROM countries INNER JOIN life_expectancy ON countries.country_id = life_expectancy.country_id GROUP BY countries.country_name;
```

### Mistral-7B-Instruct-v0.3-4bit

In [74]:
generate_response(mlx_model, prompt, max_tokens)

Prompt: <s>[INST] What is the average life expectancy in each country?
CREATE TABLE countries (country_id INT, country_name VARCHAR(255)); INSERT INTO countries VALUES (1, 'CountryA'), (2, 'CountryB'); CREATE TABLE life_expectancy (country_id INT, life_expectancy DECIMAL(5,2), year INT); INSERT INTO life_expectancy VALUES (1, 75.5, 2020), (2, 80.3, 2020);[/INST]
It seems like you're trying to create a simple SQL database structure to store data about the life expectancy of different countries. However, I can't provide the actual average life expectancy for each country as it changes over time and varies based on numerous factors. Here's how your SQL tables would look:

```sql
CREATE TABLE countries (
    country_id INT PRIMARY KEY,
    country_name VARCHAR(255)
);

CREATE TABLE life_expectancy (
    country_id INT,
    life_expectancy DECIMAL(5,2),
    year INT,
    FOREIGN KEY (country_id) REFERENCES countries(country_id)
);

INSERT INTO countries VALUES (1, 'CountryA');
INSERT INTO c

### Fine-tuned Mistral-7B-Instruct-v0.3-4bit

In [75]:
generate_response(fused_model, prompt, max_tokens)

Prompt: <s>[INST] What is the average life expectancy in each country?
CREATE TABLE countries (country_id INT, country_name VARCHAR(255)); INSERT INTO countries VALUES (1, 'CountryA'), (2, 'CountryB'); CREATE TABLE life_expectancy (country_id INT, life_expectancy DECIMAL(5,2), year INT); INSERT INTO life_expectancy VALUES (1, 75.5, 2020), (2, 80.3, 2020);[/INST]

SELECT c.country_name, le.life_expectancy FROM countries c JOIN life_expectancy le ON c.country_id = le.country_id WHERE le.year = 2020;
This query joins the countries table with the life_expectancy table on the country_id column, and filters the results to only include rows where the year is 2020. It then selects the country_name and life_expectancy columns from the resulting table.
Prompt: 128 tokens, 257.002 tokens-per-sec
Generation: 107 tokens, 34.094 tokens-per-sec
Peak memory: 7.776 GB
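
The fine-tuned query above filters on `year = 2020` instead of averaging per country. The sketch below, assuming an in-memory SQLite database, runs both queries on the prompt's data and again after adding a second year for one country. The 2021 row is hypothetical and exists only to show where the two queries diverge.

```python
import sqlite3

# Schema and rows copied from the prompt above.
SCHEMA = """
CREATE TABLE countries (country_id INT, country_name VARCHAR(255));
INSERT INTO countries VALUES (1, 'CountryA'), (2, 'CountryB');
CREATE TABLE life_expectancy (country_id INT, life_expectancy DECIMAL(5,2), year INT);
INSERT INTO life_expectancy VALUES (1, 75.5, 2020), (2, 80.3, 2020);
"""

# Expected query, copied from the "Expected code" section.
EXPECTED = (
    "SELECT countries.country_name, AVG(life_expectancy) AS avg_life_expectancy "
    "FROM countries INNER JOIN life_expectancy "
    "ON countries.country_id = life_expectancy.country_id "
    "GROUP BY countries.country_name;"
)

# Query produced by the fine-tuned model, copied from the output above.
GENERATED = (
    "SELECT c.country_name, le.life_expectancy FROM countries c "
    "JOIN life_expectancy le ON c.country_id = le.country_id WHERE le.year = 2020;"
)

conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)

expected_before = dict(conn.execute(EXPECTED).fetchall())
generated_before = dict(conn.execute(GENERATED).fetchall())

# Hypothetical extra row: a second year of data for CountryA.
conn.execute("INSERT INTO life_expectancy VALUES (1, 77.5, 2021)")
expected_after = dict(conn.execute(EXPECTED).fetchall())
generated_after = dict(conn.execute(GENERATED).fetchall())
conn.close()

print(expected_before, generated_before)
print(expected_after, generated_after)
```

With a single year per country the two queries return the same values, so the prompt's data alone cannot distinguish them. Once a second year exists, only the `AVG ... GROUP BY` form reflects it.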
