In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# local library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys  
# Put the parent directory on the import path (index 1, right after the first
# entry) so the `nl2sql_src` package one level above this notebook is importable.
sys.path.insert(1, '../')

In [59]:
from nl2sql_src.nl2sql_generic import Nl2sqlBq_rag
import pandas as pd
from pandas import DataFrame

import json

import vertexai
from vertexai.language_models import TextGenerationModel

In [60]:
# Project and dataset that qualify every table referenced by the generated SQL
# (tables appear as `cdii-poc.HHS_Program_Counts.<table>` in the outputs below).
project_id, dataset_id = "cdii-poc", "HHS_Program_Counts"

In [61]:
# Sample questions for the single-question walkthrough; the cell below picks
# one of them by index, so the order matters.
questions = [
    "How many people are enrolled in CalFresh?",
    "Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?",
    "What county has the greatest enrollment in WIC per capita?",
    # Disabled: "How many Black individuals are served across CalHHS programs?"
    "Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-Cal?",
]

In [62]:
# Select the question used by the single-question cells that follow.
question = questions[1]
# Bare expression: echoes the selected question as this cell's output.
question
# question = 'Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program? '

'Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?'

### Generate SQL with zero-shot prompting

In [63]:
# Build the NL2SQL client from a metadata cache that has already been created.
# Alternatives that were tried and are currently disabled:
#   metadata file: "../cache_metadata/updated-metadata.json"
#   model_name:    "code-bison"
meta_data_json_path = "../cache_metadata/metadata_cache.json"
nl2sqlbq_client = Nl2sqlBq_rag(
    project_id=project_id,
    dataset_id=dataset_id,
    metadata_json_path=meta_data_json_path,
    model_name="text-bison",
)


In [64]:
# Zero-shot generation for the selected question: run the table filter, then
# ask the client for SQL and print it.
print(question)
# The result is kept in `table_info`; a bare `table_info` line in the middle of
# a cell displays nothing (only a cell's last expression is shown), so that
# dead statement has been dropped.
table_info = nl2sqlbq_client.table_filter(question)

sql_query = nl2sqlbq_client.text_to_sql(question)
print(sql_query)

Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
 
SELECT
  Number AS County,
  ROUND(
    (CAST(_2_programs AS FLOAT64) + CAST(_3_programs AS FLOAT64) + CAST(_4_programs AS FLOAT64) + CAST(_5__programs AS FLOAT64)) /
    CAST(Person AS FLOAT64),
    2
  ) AS Coenrollment_Rate
FROM cdii-poc.HHS_Program_Counts.`calhhs-dashboard-2015-2020-annual-data-file`
WHERE
  Program = "CalFresh"
ORDER BY
  Coenrollment_Rate DESC
LIMIT 1;



## Using PostgreSQL and vector-DB indices to find the closest-matching queries for few-shot prompting

### Initialize the PostgreSQL database
Make sure the PostgreSQL instance and database have been created before running the next cell.

In [65]:
# Connect the NL2SQL client to the PostgreSQL instance used for few-shot lookup.
# Table name is 'documents'

import os
from getpass import getpass

PGPROJ = "cdii-poc"
PGLOCATION = 'us-central1'
PGINSTANCE = "cdii-demo-temp"
PGDB = "demodbcdii"
PGUSER = "postgres"
# Security: the password must not be stored in the notebook. It is read from the
# PGPASSWORD environment variable, falling back to an interactive prompt.
# NOTE(review): the password previously committed here should be rotated.
PGPWD = os.environ.get("PGPASSWORD") or getpass("PostgreSQL password: ")

nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)




### SQL generation with few-shot prompting

In [66]:
# Few-shot generation for the selected question: run the table filter, then
# ask the client for SQL using the few-shot path and print it.
print(question)
# The result is kept in `table_info`; a bare `table_info` line in the middle of
# a cell displays nothing (only a cell's last expression is shown), so that
# dead statement has been dropped.
table_info = nl2sqlbq_client.table_filter(question)

sql_query = nl2sqlbq_client.text_to_sql_fewshot(question)
print(sql_query)

Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT Number AS county, 
       COALESCE(SUM(CAST(Person AS INT64)), 0) AS total_calfresh_recipients,
       COALESCE(SUM(CAST(_2_programs AS INT64)), 0) AS calfresh_coenrolled_in_2_programs,
       COALESCE(SUM(CAST(_3_programs AS INT64)), 0) AS calfresh_coenrolled_in_3_programs,
       COALESCE(SUM(CAST(_4_programs AS INT64)), 0) AS calfresh_coenrolled_in_4_programs,
       COALESCE(SUM(CAST(_5__programs AS INT64)), 0) AS calfresh_coenrolled_in_5_programs,
       (COALESCE(SUM(CAST(_2_programs AS INT64)), 0) + 
        COALESCE(SUM(CAST(_3_programs AS INT64)), 0) + 
        COALES

In [67]:
# nl2sqlbq_client_raw = Nl2sqlBq(project_id=project_id,
#                            dataset_id=dataset_id,
#                            metadata_json_path = "metadata_cache.json",
#                            model_name="text-bison"
#                           )
# question = questions[1]
# print(question)
# table_info = nl2sqlbq_client_raw.table_filter(question)
# table_info

# sql_query = nl2sqlbq_client_raw.text_to_sql(question)
# print(sql_query)

In [68]:
# Full question set for the batch run, in conversation order. Several entries
# are follow-ups ("How many of them ...", "What about ...") whose wording
# refers back to the question before them.
all_questions = [
    # CalFresh enrollment and co-enrollment
    "How many people are enrolled in CalFresh?",
    "How many of them live in Los Angeles County?",
    "How has participation in CalFresh changed since 2015?",
    "How do CalFresh program participation trends differ by race and ethnicity?",
    "How have these race and ethnicity trends changed over time?",
    "Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?",
    "What about three or more additional programs?",
    "Which programs have the highest co-enrollment with CalFresh?",
    # WIC
    "What county has the greatest enrollment in WIC per capita?",
    "Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?",
    "How do infant mortality rates, low birthweight rates, and preterm and very preterm rates compare to WIC enrollment rates by county?",
    # Race/ethnicity across CalHHS programs
    "How many Black individuals are served across CalHHS programs?",
    "What is the breakdown by program?",
    "Has this changed over time?",
    "Change over time by program?",
    # Medi-Cal providers
    "Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-Cal?",
    "What is the ratio of non-suspended doctors to Medi-Cal members by County?",
    "What about the ratio to licensed facilities?",
]

In [69]:
# Batch run: generate few-shot SQL for every entry of `all_questions`.
#   results          - one {'question', 'generated_sql'} record per question, in
#                      order; the SQL stays "SQL not generated" when generation fails.
#   failed_questions - the questions for which generation raised an exception.
print(len(all_questions))
results = []
failed_questions = []
# A dedicated loop name keeps the batch from overwriting the global `question`
# that the single-question cells read.
for batch_question in all_questions:
    print(batch_question)
    record = {'question': batch_question, 'generated_sql': "SQL not generated"}
    try:
        table_info = nl2sqlbq_client.table_filter(batch_question)

        sql_query = nl2sqlbq_client.text_to_sql_fewshot(batch_question)
        print(sql_query)
        record['generated_sql'] = sql_query
    except Exception as exc:
        # Best effort: report the failure and continue with the next question.
        # Catching `Exception` instead of using a bare `except:` lets
        # KeyboardInterrupt/SystemExit propagate, so the run can be interrupted.
        failed_questions.append(batch_question)
        print(f"SQL generation failed ({type(exc).__name__}): {exc}")

    results.append(record)

18
How many people are enrolled in CalFresh?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
How many people are enrolled in CalFresh?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT 
  fileyear,
  Number,
  CalFresh
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`
WHERE 
  Program = "CalFresh";

How many of them live in Los Angeles County?
Table Filter -  ['20230912_births_final_county_month_sup']
How many of them live in Los Angeles County?
Table Filter -  ['20230912_births_final_county_month_sup']
Table name  20230912_births_final_county_month_sup
 
SELECT 
  SUM(CAST(Count AS INT64)) AS total_births
FROM cdii-poc.HHS_Program_Counts.`20230912_births_final_county_month_sup`
WHERE 
  County = 'Los Angeles'
  AND Geography_Type = 'Residence';

How has participation in CalFresh changed since 2015?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
H

Traceback (most recent call last):
  File "/home/jupyter/nl2sql-fiserv/temp/lib-nl2sql/nl_2_sql_lib/final_lib/notebooks/../nl2sql_src/nl2sql_generic.py", line 314, in text_to_sql_fewshot
    table_json = self.metadata_json[table_name]
KeyError: 'low-and-very-low-birthweight-by-race-ethnicity-2014-2018\npreterm-and-very-preterm-births-by-raceethnicity-2010-2018'


Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Which county has the greatest proportion of CalFresh recipients co-enrolled in at least one additional program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT Number AS county, 
       COALESCE(SUM(CAST(Person AS INT64)), 0) AS total_calfresh_recipients,
       COALESCE(SUM(CAST(_2_programs AS INT64)), 0) AS calfresh_coenrolled_in_2_programs,
       COALESCE(SUM(CAST(_3_programs AS INT64)), 0) AS calfresh_coenrolled_in_3_programs,
       COALESCE(SUM(CAST(_4_programs AS INT64)), 0) AS calfresh_coenrolled_in_4_programs,
       COALESCE(SUM(CAST(_5__programs AS INT64)), 0) AS calfresh_coenrolled_in_5_programs,
       (COALESCE(SUM(CAST(_2_programs AS INT64)), 0) + 
        COALESCE(SUM(CAST(_3_programs AS INT64)), 0) + 
        COALESCE(SUM(CAST(_4_programs AS INT64)), 0) + 
        COALESCE(SUM(CAST(_5__programs AS INT64)), 0)) / 
       COALE

Traceback (most recent call last):
  File "/home/jupyter/nl2sql-fiserv/temp/lib-nl2sql/nl_2_sql_lib/final_lib/notebooks/../nl2sql_src/nl2sql_generic.py", line 314, in text_to_sql_fewshot
    table_json = self.metadata_json[table_name]
KeyError: '2021-2022-part-c-wic-redemptions-by-vendor-county-with-family-counts'


Table Filter -  ['wic-redemption-by-county-by-participant-category-data-2010-2018']
How do infant mortality rates, low birthweight rates, and preterm and very preterm rates compare to WIC enrollment rates by county?
Table Filter -  ['wic-redemption-by-county-by-participant-category-data-2010-2018']
Table name  wic-redemption-by-county-by-participant-category-data-2010-2018
 
WITH WIC_DATA AS (
    SELECT
        Vendor_Location,
        Year_Month,
        _Number_of_Participants_Redeemed_
    FROM `cdii-poc.HHS_Program_Counts.wic-redemption-by-county-by-participant-category-data-2010-2018`````
    WHERE
        Participant_Category = 'Infant'
),
INFANT_MORTALITY_DATA AS (
    SELECT
        County,
        Year,
        Infant_Mortality_Rate
    FROM `cdii-poc.HHS_Program_Counts.infant-mortality-rates-by-county-2010-2018
),
LOW_BIRTHWEIGHT_DATA AS (
    SELECT
        County,
        Year,
        Low_Birthweight_Rate
    FROM `cdii-poc.HHS_Program_Counts.low-birthweight-rates-by-coun

Traceback (most recent call last):
  File "/home/jupyter/nl2sql-fiserv/temp/lib-nl2sql/nl_2_sql_lib/final_lib/notebooks/../nl2sql_src/nl2sql_generic.py", line 314, in text_to_sql_fewshot
    table_json = self.metadata_json[table_name]
KeyError: 'calhhs-dashboard-2015-2020-annual-data-file\ncalhhs-dashboard-2015-2020-july-data-file'


Table Filter -  ['2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts\n2019-2020-part-b-wic-redemptions-by-vendor-county-with-family-counts\nwic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c']
Has this changed over time?
Table Filter -  ['2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts\n2019-2020-part-b-wic-redemptions-by-vendor-county-with-family-counts\nwic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c']
Table name  2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts
2019-2020-part-b-wic-redemptions-by-vendor-county-with-family-counts
wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c
Change over time by program?


Traceback (most recent call last):
  File "/home/jupyter/nl2sql-fiserv/temp/lib-nl2sql/nl_2_sql_lib/final_lib/notebooks/../nl2sql_src/nl2sql_generic.py", line 314, in text_to_sql_fewshot
    table_json = self.metadata_json[table_name]
KeyError: '2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts\n2019-2020-part-b-wic-redemptions-by-vendor-county-with-family-counts\nwic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c'


Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Change over time by program?
Table Filter -  ['calhhs-dashboard-2015-2020-annual-data-file']
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT fileyear,
       Program,
       SUM(Person) AS total_enrolled
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`
GROUP BY 1,2
ORDER BY fileyear, total_enrolled DESC;

Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-Cal?
Table Filter -  ['medi-cal-and-calfresh-enrollment']
Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-Cal?
Table Filter -  ['medi-cal-and-calfresh-enrollment']
Table name  medi-cal-and-calfresh-enrollment
 
SELECT
  County,
  (
    COUNT(DISTINCT provider.NPI) /
    SUM(CAST(Number_of_Beneficiaries AS FLOAT64))
  ) * 100 AS provider_to_beneficiary_ratio,
  RANK() OVER (ORDER BY provider_to_beneficiary_ratio DESC) AS highest_rank,


In [70]:
# Generate SQL for a specific question and add to the list
# tmp = {}
# question = all_questions[15]
# print(question)
# table_info = nl2sqlbq_client.table_filter(question)
# table_info

# sql_query = nl2sqlbq_client.text_to_sql_fewshot(question)
# print(sql_query)
# tmp['question'] = question
# tmp['generated_sql'] = sql_query
    
# results.append(tmp)

In [71]:
# Echo the collected question/SQL records as the cell output.
results

[{'question': 'How many people are enrolled in CalFresh?',
  'generated_sql': ' \nSELECT \n  fileyear,\n  Number,\n  CalFresh\nFROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`\nWHERE \n  Program = "CalFresh";\n'},
 {'question': 'How many of them live in Los Angeles County?',
  'generated_sql': " \nSELECT \n  SUM(CAST(Count AS INT64)) AS total_births\nFROM cdii-poc.HHS_Program_Counts.`20230912_births_final_county_month_sup`\nWHERE \n  County = 'Los Angeles'\n  AND Geography_Type = 'Residence';\n"},
 {'question': 'How has participation in CalFresh changed since 2015?',
  'generated_sql': ' \nSELECT\n  fileyear,\n  SUM(CAST(Person AS INT64)) AS total_enrolled\nFROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`\nWHERE\n  Program = "CalFresh"\nGROUP BY\n  fileyear\nORDER BY\n  fileyear;\n'},
 {'question': 'How do CalFresh program participation trends differ by race and ethnicity?',
  'generated_sql': " \nSELECT\n  fileyear,\n  COALES

In [72]:
# Number of records collected by the batch run (one per question).
len(results)

18

In [73]:
# Tabulate the batch records (one row per question) and save them as
# 'output.csv' in the working directory, without the index column.
columns = ['question', 'generated_sql']
df = pd.DataFrame(data=results, columns=columns)
df.to_csv(path_or_buf='output.csv', index=False)

In [74]:
# Render the question/SQL table built from `results`.
df

Unnamed: 0,question,generated_sql
0,How many people are enrolled in CalFresh?,"\nSELECT \n fileyear,\n Number,\n CalFresh..."
1,How many of them live in Los Angeles County?,\nSELECT \n SUM(CAST(Count AS INT64)) AS tot...
2,How has participation in CalFresh changed sinc...,"\nSELECT\n fileyear,\n SUM(CAST(Person AS I..."
3,How do CalFresh program participation trends d...,"\nSELECT\n fileyear,\n COALESCE(SUM(White),..."
4,How have these race and ethnicity trends chang...,SQL not generated
5,Which county has the greatest proportion of Ca...,"\nSELECT Number AS county, \n COALESCE(..."
6,What about three or more additional programs?,"\nSELECT\n Program,\n COALESCE(SUM(CAST(_3_..."
7,Which programs have the highest co-enrollment ...,"\nSELECT\n Program,\n COALESCE(SUM(Person),..."
8,What county has the greatest enrollment in WIC...,"\nSELECT\n Vendor_Location,\n COALESCE(\n ..."
9,Which five counties have the lowest number of ...,SQL not generated


In [75]:
# Questions to regenerate, each paired with the table name that is passed
# explicitly to the generator in the cells below.
fq1, table_name1 = (
    "How have these race and ethnicity trends changed over time?",
    'calhhs-dashboard-2015-2020-annual-data-file',
)
fq2, table_name2 = (
    "Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?",
    'wic-redemption-by-county-by-participant-category-data-2010-2018',
)
fq3, table_name3 = (
    "What is the breakdown by program?",
    'calhhs-dashboard-2015-2020-annual-data-file',
)
fq4, table_name4 = (
    "Has this changed over time?",
    'calhhs-dashboard-2015-2020-annual-data-file',
)
fq5, table_name5 = (
    "What about the ratio to licensed facilities?",
    'calhhs-dashboard-2015-2020-annual-data-file',
)

In [76]:
# Retry the first follow-up question, this time passing its table name
# explicitly to the few-shot generator, and print the resulting SQL.
regen_question, regen_table = fq1, table_name1
sql_query = nl2sqlbq_client.text_to_sql_fewshot(
    regen_question,
    table_name=regen_table,
)
print(sql_query)

How have these race and ethnicity trends changed over time?
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT
  fileyear,
  SUM(White) AS total_white,
  SUM(Black) AS total_black,
  SUM(Hispanic) AS total_hispanic,
  SUM(Asian_PI) AS total_asian,
  SUM(Native_American) AS total_native_american
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`
GROUP BY
  fileyear
ORDER BY
  fileyear;



In [77]:
# Retry the second follow-up question, this time passing its table name
# explicitly to the few-shot generator, and print the resulting SQL.
regen_question, regen_table = fq2, table_name2
sql_query = nl2sqlbq_client.text_to_sql_fewshot(
    regen_question,
    table_name=regen_table,
)
print(sql_query)

Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?
Table name  wic-redemption-by-county-by-participant-category-data-2010-2018
 
SELECT Vendor_Location,
       (vendor_cnt/total_participants)*100 AS vendor_participants_ratio
FROM (
    SELECT TRIM(Vendor_Location) AS Vendor_Location,
           COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64)), 0) AS total_participants
    FROM `cdii-poc.HHS_Program_Counts.wic-redemption-by-county-by-participant-category-data-2010-2018``
    GROUP BY Vendor_Location) AS participants
JOIN (
    SELECT TRIM(COUNTY) AS COUNTY,
           COUNT(VENDOR) AS vendor_cnt
    FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-wic-authorized-vendors
    GROUP BY COUNTY
    HAVING COUNTY IS NOT NULL) AS vendors
ON UPPER(participants.Vendor_Location) = UPPER(vendors.COUNTY)
WHERE (vendor_cnt/total_participants)*100 IS NOT NULL
ORDER BY vendor_participants_ratio ASC
LIMIT 5;



In [78]:
# Retry the third follow-up question, this time passing its table name
# explicitly to the few-shot generator, and print the resulting SQL.
regen_question, regen_table = fq3, table_name3
sql_query = nl2sqlbq_client.text_to_sql_fewshot(
    regen_question,
    table_name=regen_table,
)
print(sql_query)

What is the breakdown by program?
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT
  Program,
  COALESCE(SUM(CAST(Person AS INT64)), 0) AS total_enrolled
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`
GROUP BY
  Program
ORDER BY
  total_enrolled DESC;



In [79]:
# Retry the fourth follow-up question, this time passing its table name
# explicitly to the few-shot generator, and print the resulting SQL.
regen_question, regen_table = fq4, table_name4
sql_query = nl2sqlbq_client.text_to_sql_fewshot(
    regen_question,
    table_name=regen_table,
)
print(sql_query)

Has this changed over time?
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT
  fileyear,
  COALESCE(SUM(SAFE_CAST(White AS INT64)), 0) AS total_white,
  COALESCE(SUM(SAFE_CAST(Black AS INT64)), 0) AS total_black,
  COALESCE(SUM(SAFE_CAST(Hispanic AS INT64)), 0) AS total_hispanic,
  COALESCE(SUM(SAFE_CAST(Asian_PI AS INT64)), 0) AS total_asian,
  COALESCE(SUM(SAFE_CAST(Native_American AS INT64)), 0) AS total_native_american
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file`
GROUP BY
  fileyear
ORDER BY
  fileyear;



In [80]:
# Retry the fifth follow-up question, this time passing its table name
# explicitly to the few-shot generator, and print the resulting SQL.
regen_question, regen_table = fq5, table_name5
sql_query = nl2sqlbq_client.text_to_sql_fewshot(
    regen_question,
    table_name=regen_table,
)
print(sql_query)

What about the ratio to licensed facilities?
Table name  calhhs-dashboard-2015-2020-annual-data-file
 
SELECT
  annual_file.Number,
  (
    COUNT(DISTINCT provider.License_Number) /
    SUM(CAST(annual_file.Person AS INT64))
  ) * 100 AS ratio_to_licensed_facilities
FROM `cdii-poc.HHS_Program_Counts.calhhs-dashboard-2015-2020-annual-data-file` AS annual_file
JOIN
  cdii-poc.HHS_Program_Counts.calhhs_medi-cal_managed_care_provider_listing AS provider
ON
  annual_file.Number = provider.County
WHERE
  annual_file.Level = 'County'
  AND provider.RecordType = 'Provider'
GROUP BY
  annual_file.Number
ORDER BY
  ratio_to_licensed_facilities DESC
LIMIT 5;



In [81]:
# Questions for which the batch run raised an exception during SQL generation.
failed_questions

['How have these race and ethnicity trends changed over time?',
 'Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?',
 'What is the breakdown by program?',
 'Has this changed over time?']