In [1]:
import pandas as pd
import gradio as gr
from custom_prompt_chain import *

In [2]:
# Load the evaluation questions that populate the demo's dropdown.
test_file_name = "dataset/test_dataset_all_tables.csv"
test_df = pd.read_csv(test_file_name)
questions = test_df.question.tolist()

# Drop two entries by position. Targets are deleted left to right, so the
# larger index goes first to keep the smaller one pointing at its original row.
del questions[10], questions[1]

In [10]:
def b_clicked(b):
    """Build the update for the dependent checkbox when the toggle changes.

    Used as the ``change`` handler of the ``sql_to_nl`` checkbox, with the
    ``need_insights`` checkbox as its single output.

    Args:
        b: Current value of the controlling checkbox (truthy when ticked).

    Returns:
        A Gradio update that makes the target interactive when ``b`` is
        truthy; otherwise one that disables the target and clears its value.
    """
    # `gr.update` is component-agnostic. The target here is a Checkbox, so the
    # Button-specific `gr.Button.update` was the wrong helper for it.
    if b:
        return gr.update(interactive=True)
    return gr.update(interactive=False, value=False)

# Demo UI: pick or type a question, optionally request a natural-language
# answer and detailed insights, then run `gen_sql` on the selection.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.Markdown("Select a question or type your own and then click **Run**")
    with gr.Row():
        # Gradio expects integer `scale` values and warns on a float such as
        # 2.5. The intended 2.5 : 1 : 1 split is written as 5 : 2 : 2
        # (assumes unscaled siblings were laid out as scale 1 - TODO confirm).
        question = gr.Dropdown(choices=questions,
                               label="Select or Type a question", allow_custom_value=True, scale=5)
        sql_to_nl = gr.Checkbox(label="Natural Language output", scale=2)
        # Starts disabled; `b_clicked` enables it once `sql_to_nl` is ticked.
        need_insights = gr.Checkbox(label="Detailed Insights", interactive=False, scale=2)

    btn = gr.Button("Run")

    sql_out = gr.Textbox(label="SQL Query")
    results_out = gr.Textbox(label="Results")
    insights_out = gr.Textbox(label="Insights")

    # Keep the "Detailed Insights" checkbox in sync with the NL-output toggle.
    sql_to_nl.change(fn=b_clicked, inputs=sql_to_nl, outputs=need_insights)
    # `gen_sql` receives (question, sql_to_nl, need_insights) and must return
    # three values, one per output textbox, in the order listed below.
    btn.click(fn=gen_sql, inputs=[question, sql_to_nl, need_insights], outputs=[sql_out, results_out, insights_out])

# NOTE(review): share=True publishes a public *.gradio.live URL with no
# authentication configured here; confirm this exposure is intended.
demo.launch(share=True)

  question = gr.Dropdown(choices=questions,


Running on local URL:  http://127.0.0.1:7890
Running on public URL: https://3794d024fd5fac534d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)






In [None]:
# GCS location of the data dictionary. The `gsutil cp` cell below expands this
# variable, so it has to be defined for that cell to work on a fresh kernel.
DATA_DICT_LOCATION_GCS = "gs://fiserv/fiserv_data_dict.json"

In [2]:
# Copy the data dictionary from GCS into the working directory as
# data_dictionary.json. IPython substitutes $DATA_DICT_LOCATION_GCS from the
# notebook namespace, so that variable must be defined before running this cell.
!gsutil cp $DATA_DICT_LOCATION_GCS data_dictionary.json

Copying gs://fiserv/fiserv_data_dict.json...
- [1 files][  1.7 KiB/  1.7 KiB]                                                
Operation completed over 1 objects/1.7 KiB.                                      


In [1]:
# Settings forwarded to AskBQ by the app's "Save" handler.
log_bucket = "nl2sql-logs"
# Passed as `data_dict_loc` only when the data-dictionary checkbox is ticked.
data_dict_loc = "fiserv_data_dict.json"
# Passed as `postprocessors` only when the postprocessor checkbox is ticked.
postprocessors = ["case_handler_transform"]

In [3]:
from nl2sql.nl2sql import AskBQ

In [2]:
#@title App Setup

import gradio as gr
import pygwalker as pyg
from nl2sql.nl2sql import AskBQ
from pygments import highlight
from pygments.lexers import SqlLexer
from pygments.formatters import HtmlFormatter

def _parse_table_names(raw):
  """Split a comma separated string into stripped, non-empty table names."""
  # Skip blanks so a trailing or doubled comma does not yield an empty name.
  return [name.strip() for name in (raw or "").split(',') if name.strip()]


def _sql_html(sql):
  """Render a SQL string as inline-styled HTML, without the theme background."""
  return highlight(
    sql,
    SqlLexer(),
    HtmlFormatter(style='native', noclasses=True)
  ).replace("background: #202020", "")


def _failure_html(result):
  """Build the HTML shown in place of the SQL when no query was produced.

  Lists every non-None `exception` found in `result.logs.values()`.
  """
  import html  # stdlib; local import keeps this cell's import list unchanged

  # str() keeps the join safe if an entry holds an exception object rather
  # than text; escaping stops '<' / '&' in messages being read as markup.
  messages = [
    html.escape(str(entry.exception))
    for entry in result.logs.values()
    if entry.exception is not None
  ]
  return "<br>".join(["QUERY GENERATION FAILED :"] + messages)


def _run_question(question, obj):
  """Handle a click on "Run".

  Args:
    question: Text from the question dropdown.
    obj: The callable stored in the `askbq` state (set by the "Save" handler).

  Returns:
    Five values matching the click outputs, in order: SQL (or failure) HTML,
    result data for the dataframe, BI HTML, an update that reveals the output
    group, and the execution id text.
  """
  result = obj(question)
  if result.latest_sql:
    sql_markup = _sql_html(result.latest_sql)
    bi_markup = pyg.walk(result.latest_data, return_html=True)
  else:
    sql_markup = _failure_html(result)
    bi_markup = "QUERY GENERATION FAILED"
  return [
    sql_markup,
    result.latest_data,
    bi_markup,
    gr.update(visible=True),
    f"Execution ID : {result.id}"
  ]


def _init_askbq(location, project_id, dataset_id, table_names, result_row_limit,
                enum_option_limit, include_data_dict, include_postprocessor):
  """Handle a click on "Save".

  Builds the AskBQ instance from the config form. The parameter order matches
  the inputs list passed to `input_obj_init.click`.

  Returns:
    Four values matching the click outputs, in order: the AskBQ instance for
    the state, a collapsed accordion, and two updates that reveal the question
    group and the status text.
  """
  return [
    AskBQ(
      location = location,
      project_id = project_id,
      dataset_id = dataset_id,
      table_names = _parse_table_names(table_names),
      enum_option_limit = enum_option_limit,
      result_row_limit = result_row_limit,
      log_bucket=log_bucket,
      data_dict_loc=data_dict_loc if include_data_dict else None,
      postprocessors=postprocessors if include_postprocessor else []
    ),
    gr.Accordion(label="Initial Config", open=False),
    gr.update(visible=True),
    gr.update(visible=True),
  ]


with gr.Blocks() as demo:
  gr.Markdown("# NL2SQL Experiment Demo")
  # Holds the AskBQ instance created by "Save" for later "Run" clicks.
  askbq = gr.State()

  with gr.Accordion("Initial Config") as accordion_init_config:
    input_location = gr.Textbox(
      label="Location",
      value="us-central1")
    input_project_id = gr.Textbox(
      label="Project ID",
      value="poc-project-guleria")
    input_dataset_id = gr.Textbox(
      label="Dataset ID",
      value="genai_poc")
    input_table_names = gr.Textbox(
      label="Comma Separated Table Names",
      value= "authorizations_search"
    )
    input_enum_option_limit = gr.Number(
      label="Maximum number of distinct items in a column to consider it an enum",
      value=20,
      minimum = 0,
      step=1,
      precision=0
    )
    input_result_row_limit = gr.Number(
      label="Maximum Row Limit",
      value=1000,
      minimum = 0,
      step=1,
      precision=0
    )
    input_include_data_dict = gr.Checkbox(
      label="Leverage predefined Data Dictionary?", value=False
    )
    input_include_postprocessor = gr.Checkbox(
      label="Leverage predefined Postprocessor for case correction?", value=False
    )
    input_obj_init = gr.Button(value="Save", variant='primary')

  # Hidden until the configuration has been saved.
  with gr.Group(visible = False) as group_input:
    question = gr.Dropdown(
      label="Natural Language Question",
      multiselect = False,
      allow_custom_value = True,
      show_label = True,
      choices = [
        "What were the sales in the last week?",
      ]
    )
    input_run_nl = gr.Button(value="Run", variant='primary', interactive=False)

  # Status line: a prompt after "Save", then the execution id after each run.
  text_buffer = gr.Textbox(
    value="Enter a question and click on 'Run'",
    interactive=False,
    container=False,
    visible=False
  )

  # Hidden until the first run returns.
  with gr.Group(visible = False) as group_output:
    with gr.Tab("Results"):
      with gr.Row(equal_height=True):
        html_sql = gr.HTML()
        dataframe_results = gr.Dataframe()
    with gr.Tab("BI"):
      html_results = gr.HTML()

  # "Run" is clickable only while the question field is non-empty.
  question.change(
    lambda q : gr.update(interactive=bool(q)),
    [question],
    [input_run_nl]
  )

  input_run_nl.click(
    _run_question,
    [question, askbq],
    [html_sql, dataframe_results, html_results, group_output, text_buffer]
  )

  input_obj_init.click(
    _init_askbq,
    [
      input_location,
      input_project_id,
      input_dataset_id,
      input_table_names,
      input_result_row_limit,
      input_enum_option_limit,
      input_include_data_dict,
      input_include_postprocessor
    ],
    [
      askbq,
      accordion_init_config,
      group_input,
      text_buffer
    ]
  )

import os

# Optional basic auth for the demo. Credentials are read from the environment
# so that no username/password is stored in the notebook; when either variable
# is unset the app launches without auth, as it did before.
# NOTE(review): a credential pair used to sit here in a commented-out `auth=`
# argument - rotate it if it was ever in real use.
_demo_auth_user = os.environ.get("NL2SQL_DEMO_USER")
_demo_auth_password = os.environ.get("NL2SQL_DEMO_PASSWORD")

# debug=True keeps this cell running until it is interrupted.
demo.queue().launch(
  share=False,
  quiet=True,
  debug=True,
  inline=False,
  auth=(
    (_demo_auth_user, _demo_auth_password)
    if _demo_auth_user and _demo_auth_password
    else None
  ),
)

Running on local URL:  http://127.0.0.1:7860


Traceback (most recent call last):
  File "/opt/homebrew/lib/python3.11/site-packages/gradio/queueing.py", line 407, in call_prediction
    output = await route_utils.call_process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/gradio/route_utils.py", line 226, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/gradio/blocks.py", line 1550, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/gradio/blocks.py", line 1185, in call_function
    prediction = await anyio.to_thread.run_sync(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Keyboard interruption in main thread... closing server.


