<a href="https://colab.research.google.com/github/pavankumarbalijepalli/pr-phi2-vs-defog/blob/main/inf_phi2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Data Exploration Imports
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import transformers
from peft import LoraConfig, PeftConfig
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          TrainingArguments,
                          pipeline,
                          logging)
from sklearn.metrics import (accuracy_score,
                             classification_report,
                             confusion_matrix)

In [2]:
!nvidia-smi

Wed Feb 21 21:34:30 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 441.37       Driver Version: 441.37       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 166... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   69C    P8     7W /  N/A |    729MiB /  6144MiB |      2%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|    0  

In [43]:
test_df = pd.read_csv('../data/test.csv').drop("Unnamed: 0", axis=1)
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7858 entries, 0 to 7857
Data columns (total 11 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   question         7858 non-null   object
 1   context          7858 non-null   object
 2   answer           7858 non-null   object
 3   table_count      7858 non-null   int64 
 4   sub_query_count  7858 non-null   int64 
 5   joins_count      7858 non-null   int64 
 6   where_count      7858 non-null   int64 
 7   group_by_count   7858 non-null   int64 
 8   columns_count    7858 non-null   int64 
 9   complexity       7858 non-null   int64 
 10  difficulty       7858 non-null   object
dtypes: int64(7), object(4)
memory usage: 675.4+ KB


In [102]:
from sklearn.model_selection import train_test_split
df_a, df_b = train_test_split(test_df, test_size=0.2, stratify=test_df["difficulty"])
df_a, df_c = train_test_split(df_a, test_size=0.2, stratify=df_a["difficulty"])

In [103]:
df_a['difficulty'].value_counts(), df_b['difficulty'].value_counts(), df_c['difficulty'].value_counts()

(difficulty
 easy      3484
 medium    1519
 hard        25
 Name: count, dtype: int64,
 difficulty
 easy      1089
 medium     475
 hard         8
 Name: count, dtype: int64,
 difficulty
 easy      872
 medium    380
 hard        6
 Name: count, dtype: int64)

In [6]:
def generate_test_prompt(data_point):
    return f"""### Task
Generate a SQL query to answer the following question:
`{data_point['question']}`

### Database Schema
The query will run on a database with the following schema:
{data_point['context']}

### Answer
Given the database schema, here is the SQL query that answers `{data_point['question']}`:
```sql""".strip()

In [9]:
from llama_cpp import Llama

phi2 = Llama(model_path="../phi2_sqlcoder_f16.gguf")

llama_model_loader: loaded meta data with 19 key-value pairs and 453 tensors from ../phi2_sqlcoder_f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = phi2
llama_model_loader: - kv   1:                               general.name str              = Phi2
llama_model_loader: - kv   2:                        phi2.context_length u32              = 2048
llama_model_loader: - kv   3:                      phi2.embedding_length u32              = 2560
llama_model_loader: - kv   4:                   phi2.feed_forward_length u32              = 10240
llama_model_loader: - kv   5:                           phi2.block_count u32              = 32
llama_model_loader: - kv   6:                  phi2.attention.head_count u32              = 32
llama_model_loader: - kv   7:               phi2.attention.head_count_kv u32              =

In [81]:
from datetime import datetime

def predict(df, llm, out):
    for point in tqdm(df.iloc, desc="No. of rows", total=df.shape[0]):
        start = datetime.now()
        prompt = generate_test_prompt(point)
        out['prompt'].append(prompt)
        out['actu'].append(point['answer'])
        result = llm(prompt=prompt,
            max_tokens = 50,
            temperature = 0.2,
            stop = ['```'])
        answer = result
        end = datetime.now()
        out['inf_time'].append((end - start).total_seconds())
        out['pred'].append(answer['choices'][0]['text'].strip())
        out['temperature'].append(0.2)
        out['difficulty'].append(point['difficulty'])
        out['token_in'].append(result['usage']['prompt_tokens'])
        out['token_out'].append(result['usage']['completion_tokens']+1)
        out['tokens_per_sec'].append(result['usage']['completion_tokens']/((end - start).total_seconds()))
    return out

In [104]:
out = {"prompt": [], "pred": [], "actu": [], "inf_time": [], "temperature": [], "difficulty": [], "token_in": [], "token_out": [], "tokens_per_sec": []}

out = predict(df_c, phi2, out)
json.dump(out, open("phi2_eval_df_c.json", "w"))

No. of rows:   0%|       | 0/1258 [00:00<?, ?it/s]Llama.generate: prefix-match hit

llama_print_timings:        load time =   11457.89 ms
llama_print_timings:      sample time =       6.65 ms /    21 runs   (    0.32 ms per token,  3156.95 tokens per second)
llama_print_timings: prompt eval time =    9239.41 ms /    97 tokens (   95.25 ms per token,    10.50 tokens per second)
llama_print_timings:        eval time =    4370.24 ms /    20 runs   (  218.51 ms per token,     4.58 tokens per second)
llama_print_timings:       total time =   13702.39 ms /   117 tokens
No. of rows:   0%| | 1/1258 [00:13<4:47:14, 13.71sLlama.generate: prefix-match hit

llama_print_timings:        load time =   11457.89 ms
llama_print_timings:      sample time =       5.77 ms /    18 runs   (    0.32 ms per token,  3118.50 tokens per second)
llama_print_timings: prompt eval time =    7916.35 ms /    83 tokens (   95.38 ms per token,    10.48 tokens per second)
llama_print_timings:        eval time =    3641.99

In [112]:
del phi2
del out

In [107]:
sqlc = Llama(model_path="../sqlcoder-7b-q5_k_m.gguf")

llama_model_loader: loaded meta data with 22 key-value pairs and 291 tensors from ../sqlcoder-7b-q5_k_m.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = .
llama_model_loader: - kv   2:                       llama.context_length u32              = 16384
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              

In [115]:
out = {"prompt": [], "pred": [], "actu": [], "inf_time": [], "temperature": [], "difficulty": [], "token_in": [], "token_out": [], "tokens_per_sec": []}

out = predict(df_c, sqlc, out)
json.dump(out, open("sqlc_eval_df_c.json", "w"))

No. of rows:   0%|       | 0/1258 [00:00<?, ?it/s]Llama.generate: prefix-match hit

llama_print_timings:        load time =  138797.67 ms
llama_print_timings:      sample time =       8.40 ms /    32 runs   (    0.26 ms per token,  3811.34 tokens per second)
llama_print_timings: prompt eval time =   19408.26 ms /    94 tokens (  206.47 ms per token,     4.84 tokens per second)
llama_print_timings:        eval time =    8400.02 ms /    31 runs   (  270.97 ms per token,     3.69 tokens per second)
llama_print_timings:       total time =   27927.12 ms /   125 tokens
No. of rows:   0%| | 1/1258 [00:27<9:45:10, 27.93sLlama.generate: prefix-match hit

llama_print_timings:        load time =  138797.67 ms
llama_print_timings:      sample time =       8.25 ms /    31 runs   (    0.27 ms per token,  3755.75 tokens per second)
llama_print_timings: prompt eval time =   22683.84 ms /   103 tokens (  220.23 ms per token,     4.54 tokens per second)
llama_print_timings:        eval time =    8678.70

In [110]:
generate_test_prompt(test_df.iloc[0]), test_df.iloc[0]['answer']

('### Task\nGenerate a SQL query to answer the following question:\n`who directed with\xa0original air date\xa0being november18,1995`\n\n### Database Schema\nThe query will run on a database with the following schema:\nCREATE TABLE table_14637853_3 (directed_by VARCHAR, original_air_date VARCHAR)\n\n### Answer\nGiven the database schema, here is the SQL query that answers `who directed with\xa0original air date\xa0being november18,1995`:\n```sql',
 'SELECT directed_by FROM table_14637853_3 WHERE original_air_date = "November18,1995"')

In [120]:
hard_a = df_a[df_a['difficulty']=='hard']

In [121]:
hard_b = df_b[df_b['difficulty']=='hard']

In [125]:
hard_questions = pd.concat([hard_a, hard_b])

In [128]:
from datetime import datetime

def predict(df, llm, out):
    for point in tqdm(df.iloc, desc="No. of rows", total=df.shape[0]):
        start = datetime.now()
        prompt = generate_test_prompt(point)
        out['prompt'].append(prompt)
        out['actu'].append(point['answer'])
        result = llm(prompt=prompt,
            max_tokens = 150,
            temperature = 0.2,
            stop = ['```'])
        answer = result
        end = datetime.now()
        out['inf_time'].append((end - start).total_seconds())
        out['pred'].append(answer['choices'][0]['text'].strip())
        out['temperature'].append(0.2)
        out['difficulty'].append(point['difficulty'])
        out['token_in'].append(result['usage']['prompt_tokens'])
        out['token_out'].append(result['usage']['completion_tokens']+1)
        out['tokens_per_sec'].append(result['usage']['completion_tokens']/((end - start).total_seconds()))
    return out

In [129]:
from llama_cpp import Llama

phi2 = Llama(model_path="../phi2_sqlcoder_f16.gguf")
out = {"prompt": [], "pred": [], "actu": [], "inf_time": [], "temperature": [], "difficulty": [], "token_in": [], "token_out": [], "tokens_per_sec": []}

out = predict(hard_questions, phi2, out)
json.dump(out, open("phi2_eval_hard.json", "w"))

llama_model_loader: loaded meta data with 19 key-value pairs and 453 tensors from ../phi2_sqlcoder_f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = phi2
llama_model_loader: - kv   1:                               general.name str              = Phi2
llama_model_loader: - kv   2:                        phi2.context_length u32              = 2048
llama_model_loader: - kv   3:                      phi2.embedding_length u32              = 2560
llama_model_loader: - kv   4:                   phi2.feed_forward_length u32              = 10240
llama_model_loader: - kv   5:                           phi2.block_count u32              = 32
llama_model_loader: - kv   6:                  phi2.attention.head_count u32              = 32
llama_model_loader: - kv   7:               phi2.attention.head_count_kv u32              =

In [130]:
sqlc = Llama(model_path="../sqlcoder-7b-q5_k_m.gguf")
out = {"prompt": [], "pred": [], "actu": [], "inf_time": [], "temperature": [], "difficulty": [], "token_in": [], "token_out": [], "tokens_per_sec": []}

out = predict(hard_questions, sqlc, out)
json.dump(out, open("sqlc_eval_hard.json", "w"))

llama_model_loader: loaded meta data with 22 key-value pairs and 291 tensors from ../sqlcoder-7b-q5_k_m.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = .
llama_model_loader: - kv   2:                       llama.context_length u32              = 16384
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              

In [None]:
phi2(prompt=

In [132]:
schema = """CREATE TABLE DimDate (
  DateID int, -- Unique Identifier for Dates in the Dimension Table
  Calendar date, -- Represents the Calendar Date for DimDate Table
  DayNumberOfWeek tinyint, -- Day Number of Week Indicator
  EnglishDayNameOfWeek nvarchar, -- Represents the day of the week in English for the given date
  DayNumberOfMonth tinyint, -- Represents the day number of the month in the DimDate table
  DayNumberOfYear smallint, -- Represents the day number of the year in the DimDate table
  WeekNumberOfYear tinyint, -- Represents the week number of the year for the given date
  EnglishMonthName nvarchar, -- English Month Name Column represents the month in English
  MonthNumberOfYear tinyint, -- Represents the month number of the year in the DimDate table
  CalendarQuarter tinyint, -- Represents the Calendar Quarter associated with the date
  CalendarYear smallint, -- Represents the Calendar Year in the DimDate table
  CalendarSemester tinyint, -- Indicates the Semester of the Calendar Year
  ShortEnglishMonthName varchar, -- Short English Month Name Column: Displays the month name in a shortened English format
  MonthNum varchar, -- Represents the month number in the date format
  WeekNum varchar, -- Represents the week number associated with the date
  WeekDisplay varchar, -- Week Display Format for Date Representation
  FinancialYear int, -- Financial Year Representing the Year for Financial Analysis
  FinancialQuarter int, -- Represents the Financial quarter associated with the date
  PRIMARY KEY (DateID)
);

CREATE TABLE DimProductSubCategory (
  ProductSubCategoryID int, -- Unique Identifier for Product Sub-Category
  ProductCategoryID int, -- Unique Identifier for Product Category in Sub-Category Table
  SubCategoryName varchar, -- SubCategoryName represents the name of a product subcategory
  PRIMARY KEY (ProductSubCategoryID),
  FOREIGN KEY (DimProductSubCategory.ProductCategoryID) REFERENCES (DimProductCategory.ProductCategoryID)
);

CREATE TABLE DimOrders (
  OrderID int, -- Unique Identifier for Orders
  OrderDate date, -- Order Date Represents the Time When an Order Was Placed
  CustomerID int, -- Unique Identifier for Customers in Orders
  ClientIDNumber varchar, -- Unique Identifier for Clients in DimOrders
  ClientGUID varchar, -- Unique Identifier for Clients in DimOrders
  PRIMARY KEY (OrderID),
  FOREIGN KEY (DimOrders.ClientID) REFERENCES (DimClient.ClientID),
  FOREIGN KEY (DimOrders.EmployeeID) REFERENCES (DimEmployee.EmployeeID),
  FOREIGN KEY (DimOrders.ShipmentID) REFERENCES (DimShipment.ShipmentID),
  FOREIGN KEY (DimOrders.CustomerID) REFERENCES (DimCustomer.CustomerID)
);

CREATE TABLE DimCustomer (
  CustomerID int, -- Unique Identifier for Customers
  FirstName nvarchar, -- First Name of Customers
  MobileNumber nvarchar, -- Mobile Number for Customers
  ClientIDNumber varchar, -- Unique Identifier for Customers
  ClientGUID varchar, -- Unique Identifier for Each Customer
  PRIMARY KEY (CustomerID),
  FOREIGN KEY (DimCustomer.EmployeeID) REFERENCES (DimEmployee.EmployeeID)
);

CREATE TABLE DimProductCategory (
  ProductCategoryID int, -- Unique Identifier for Product Categories
  CategoryName varchar, -- Category Name representing the product's classification
  ClientIDNumber nvarchar, -- Unique Client ID Number for Product Category
  ClientGUID varchar, -- Unique Identifier for Client in Product Category Dimension
  PRIMARY KEY (ProductCategoryID)
);

CREATE TABLE FactSales (
  ProductID int, -- Unique Identifier for Products in Sales Facts
  EmployeeID int, -- Unique Identifier for Sales Employees
  UnitPrice float, -- Represents the unit price for each sale in the FactSales table
  ExtendedAmount float, -- Represents the extended amount in float format
  UnitPriceDiscountPct float, -- Represents the percentage discount applied to the unit price in sales transactions
  DiscountAmount float, -- Represents the discount amount applied to a sale
  ProductStandardCost float, -- Represents the standard cost of each product in the sales fact table
  TotalProductCost float, -- Represents the total product cost in the FactSales table
  SalesAmount float, -- Sales Amount in FactSales Table
  TaxAmt float, -- Represents the tax amount in the FactSales table
  Freight float, -- Freight Float represents additional charges for shipping products
  OrderDate date, -- Date of the Order Placed
  DueDate datetime, -- Represents the due date for each sales transaction
  ShipDate datetime, -- Represents the Date when the Shipment was Scheduled
  StoreID int, -- Unique Identifier for Stores in Sales Facts
  CustomerID int, -- Unique Identifier for Customers in Sales Facts
  ChannelID int, -- Unique Identifier for Sales Channels
  AuditID int, -- Unique Identifier for Audit Records in FactSales
  DateID int, -- Unique Identifier for Sales Dates
  ClientID int, -- Unique Identifier for Clients in FactSales
  OrderQuantity int, -- Order Quantity represents the number of items purchased in a sale
  UnitsSold int, -- Number of Units Sold in Sales Facts
  UnitsIn int, -- Number of Units Sold in FactSales
  UnitsOut int, -- Units Out represents the number of products sold in a specific time period
  ProductCategoryID int, -- Represents the Product Category ID for each sale in the FactSales table
  SalesOrderID int, -- Unique Identifier for Sales Orders
  SalesOrderDetailID int, -- Unique Identifier for Sales Order Details
  CarrierTrackingNumber nvarchar, -- Unique Identifier for Tracking Shipments
  SpecialOfferID int, -- Unique Identifier for Special Offers in Sales Facts
  LineTotal float, -- LineTotal represents the total amount for each sale line in the FactSales table
  ClientIDNumber varchar, -- Unique Identifier for Clients in Sales Facts
  ClientGUID varchar, -- Unique Identifier for Clients in FactSales
  TimeZone varchar, -- Represents the time zone associated with the sales data
  BucketRange varchar, -- Represents the sales bucket range
  Returns int, -- Represents the number of sales transactions
  ReturnID int, -- Unique Identifier for Return Transactions
  PRIMARY KEY (),
  FOREIGN KEY (FactSales.ProductID) REFERENCES (DimProduct.ProductID),
  FOREIGN KEY (FactSales.EmployeeID) REFERENCES (DimEmployee.EmployeeID),
  FOREIGN KEY (FactSales.StoreID) REFERENCES (DimStore.StoreID),
  FOREIGN KEY (FactSales.CustomerID) REFERENCES (DimCustomer.CustomerID),
  FOREIGN KEY (FactSales.ChannelID) REFERENCES (DimChannel.ChannelID),
  FOREIGN KEY (FactSales.AuditID) REFERENCES (DimAudit.AuditID),
  FOREIGN KEY (FactSales.DateID) REFERENCES (DimDate.DateID),
  FOREIGN KEY (FactSales.ClientID) REFERENCES (DimClient.ClientID)
);

CREATE TABLE DimRegion (
  RegionID int, -- Unique Identifier for Regions
  Level4Value varchar, -- represents City names
  Level3Value varchar, -- represents State Names
  Level2Value varchar, -- represents Country Names
  ClientIDNumber varchar, -- Unique Identifier for Clients
  ClientGuid varchar, -- Unique Identifier for Clients in DimRegion
  Level1Value nvarchar, -- Represents Region Names (N, S, E, W)
  PRIMARY KEY (RegionID)
);

CREATE TABLE DimProduct (
  ProductID int, -- Unique Identifier for Products
  ProductName varchar, -- Product Name represents the unique name of the item in the DimProduct table
  Colour nvarchar, -- Represents the product's colour
  BrandID int, -- Unique Identifier for Product Brand
  ProductSubCategoryID int, -- Identifies the Subcategory of a Product
  PRIMARY KEY (ProductID),
  FOREIGN KEY (DimProduct.BrandID) REFERENCES (DimBrand.BrandID),
  FOREIGN KEY (DimProduct.OrderID) REFERENCES (DimOrders.OrderID),
  FOREIGN KEY (DimProduct.ProductCategoryID) REFERENCES (DimProductCategory.ProductCategoryID)
);

CREATE TABLE DimStore (
  StoreID int, -- Unique Identifier for Each Store
  StoreUUID int, -- Unique Identifier for Each Store
  StoreName varchar, -- Store Name represents the name of the store in the DimStore table
  YearOpened varchar, -- Year when the store was opened
  SquareFt int, -- Represents the square footage of a store
  BussinessType varchar, -- Represents the business type associated with the store
  City varchar, -- City representing the location of the store
  State varchar, -- Represents the state associated with the store
  Country varchar, -- Country Represents the Country of the DimStore
  PromotionID int, -- Unique Identifier for Promotions
  WarehouseID int, -- Unique Identifier for the Warehouse in the DimStore Table
  RegionID int, -- RegionID represents the unique identifier for the geographical area associated with the store
  ClientIDNumber varchar, -- Unique Identifier for Clients
  ClientGuid varchar, -- Unique Identifier for Clients in the DimStore Table
  StoreType varchar, -- Description: Indicates the type of store associated with each record in the DimStore table
  PRIMARY KEY (StoreID),
  FOREIGN KEY (DimStore.PromotionID) REFERENCES (DimPromotion.PromotionID),
  FOREIGN KEY (DimStore.CustomerID) REFERENCES (DimCustomer.CustomerID),
  FOREIGN KEY (DimStore.WarehouseID) REFERENCES (DimWarehouse.WarehouseID),
  FOREIGN KEY (DimStore.RegionID) REFERENCES (DimRegion.RegionID)
);

CREATE TABLE DimSupplier (
  SupplierID int, -- Unique Identifier for Suppliers
  CompanyName varchar, -- Company Name represents the name of the supplier
  Address varchar, -- Address of Supplier
  PostalCode int, -- Postal Code represents the area code for suppliers
  ClientIDNumber nvarchar, -- Unique Identifier for Client in Supplier Relationship
  ClientGUID varchar, -- Unique Identifier for Client in DimSupplier Table
  SupplierCategoryID int, -- Identifies the Supplier Category ID
  PaymentDays int, -- Represents the number of days it takes for a supplier to receive payment
  PRIMARY KEY (SupplierID)
);

CREATE TABLE DimBrand (
  BrandID int, -- Unique Identifier for Brands
  BrandName varchar, -- Brand Name of the Entity
  PRIMARY KEY (BrandID)
);
"""

question = "Which store has maximum sales for Maza Spl"

prompt = f"""### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
The query will run on a database with the following schema:
{schema}

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql"""

In [134]:
sqlc(prompt=prompt,
    max_tokens = 150,
    temperature = 0.2,
    stop = ['```'])

ValueError: Requested tokens (2742) exceed context window of 512