<a href="https://colab.research.google.com/github/vkrisvasan/llamaKV/blob/main/LearnGrerelAIdataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# This notebook demonstrates how to load and process the "gretelai/synthetic_text_to_sql" dataset,
# specifically focusing on the "Retail" domain. It includes steps for filtering, stratified sampling,
# and selecting specific columns for further analysis or model training.
# save the HF_TOKEN in the collab secrets and provide access to Notebook
# https://gretel.ai/blog/synthetic-text-to-sql-dataset

!pip install datasets -q

In [8]:
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import pandas as pd

dataSetName = "gretelai/synthetic_text_to_sql"
dataSetDomain = "retail"
maxRowtobeselected = 100

# Load the dataset
dataset = load_dataset(dataSetName, split="train")
# Convert to pandas DataFrame for easier manipulation
dataset_df = dataset.to_pandas()

print("\nCount of Unique values in each column of dataset " , dataSetName, " is:\n" ,dataset_df.nunique())

# Filter for the "Retail" domain
retail_dataset = dataset.filter(lambda example: example['domain'] == dataSetDomain)

print ("\nLength of the dataset " , dataSetName , " for domain " , dataSetDomain , " is ", len(retail_dataset))

# Convert to pandas DataFrame for easier manipulation
retail_df = retail_dataset.to_pandas()
print("\nCount of Unique values in each column of dataset " , dataSetName, " is:\n" ,retail_df.nunique())

# Assume we want equal distribution based on 'sql_prompt'
# Get unique values in 'sql_prompt' and calculate the number of samples per group
unique_domains = retail_df['sql_prompt'].unique()
print(len(unique_domains))
samples_per_group = maxRowtobeselected // len(unique_domains)

# Stratified sampling using 'sql_prompt' to ensure equal distribution
stratified_df = pd.concat([
    group.sample(n=min(samples_per_group, len(group)), random_state=42)
    for _, group in retail_df.groupby('sql_prompt')
])

# If fewer than maxRowtobeselected rows are collected due to small groups, sample additional rows randomly
if len(stratified_df) < maxRowtobeselected:
    additional_samples = retail_df.drop(stratified_df.index).sample(n=maxRowtobeselected-len(stratified_df), random_state=42)
    stratified_df = pd.concat([stratified_df, additional_samples])

# Convert back to a Hugging Face Dataset
final_dataset = dataset.from_pandas(stratified_df)

# Select specific columns
final_dataset = final_dataset.map(lambda example: {
    'sql_prompt': example['sql_prompt'],
    'sql_context': example['sql_context'],
    'sql': example['sql']
},remove_columns=dataset.column_names)

# Verify the result by printing it
print(final_dataset)
print(final_dataset[0])


Count of Unique values in each column of dataset  gretelai/synthetic_text_to_sql  is:
 id                            100000
domain                           100
domain_description               100
sql_complexity                     8
sql_complexity_description         8
sql_task_type                      4
sql_task_type_description          4
sql_prompt                    100000
sql_context                    89766
sql                            99271
sql_explanation                99777
dtype: int64

Length of the dataset  gretelai/synthetic_text_to_sql  for domain  retail  is  979

Count of Unique values in each column of dataset  gretelai/synthetic_text_to_sql  is:
 id                            979
domain                          1
domain_description              1
sql_complexity                  7
sql_complexity_description      7
sql_task_type                   4
sql_task_type_description       4
sql_prompt                    979
sql_context                   924
sql           

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['sql_prompt', 'sql_context', 'sql', '__index_level_0__'],
    num_rows: 100
})
{'sql_prompt': 'Determine the number of products manufactured using sustainable materials', 'sql_context': 'CREATE TABLE product_materials (product_id INT, material TEXT, is_sustainable BOOLEAN);', 'sql': 'SELECT COUNT(*) as num_products_with_sustainable_materials FROM product_materials WHERE is_sustainable = TRUE;', '__index_level_0__': 199}


In [9]:
from datasets import load_dataset
dataset = load_dataset("gretelai/synthetic_text_to_sql", split = "train")
# Filter for "Retail" domain, select first 50 rows, and choose specific columns
retail_dataset = dataset.filter(lambda example: example['domain'] == 'retail') \
                        .select(range(50)) \
                        .map(lambda example: {'sql_prompt': example['sql_prompt'],
                                              'sql_context': example['sql_context'],
                                              'sql': example['sql']},
                             remove_columns=dataset.column_names)  # Remove all original columns
print(retail_dataset)
print(retail_dataset[0])

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Dataset({
    features: ['sql_prompt', 'sql_context', 'sql'],
    num_rows: 50
})
{'sql_prompt': 'What is the percentage of products that are made from recycled materials?', 'sql_context': 'CREATE TABLE products(product_id INT, is_recycled BOOLEAN);', 'sql': 'SELECT (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM products)) as percentage FROM products WHERE is_recycled = TRUE;'}


In [10]:
# Iterate over the dataset and print each row in a readable format
for i, row in enumerate(retail_dataset):
    print(f"Row {i+1}:")
    for key, value in row.items():
        print(f"  {key}: {value}")
    print("-" * 20)

Row 1:
  sql_prompt: What is the percentage of products that are made from recycled materials?
  sql_context: CREATE TABLE products(product_id INT, is_recycled BOOLEAN);
  sql: SELECT (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM products)) as percentage FROM products WHERE is_recycled = TRUE;
--------------------
Row 2:
  sql_prompt: How many circular supply chain partners does each brand work with, by country?
  sql_context: CREATE TABLE Brands (id INT, brand VARCHAR(255), country VARCHAR(255)); INSERT INTO Brands (id, brand, country) VALUES (1, 'BrandA', 'USA'), (2, 'BrandB', 'Canada'), (3, 'BrandC', 'Mexico'); CREATE TABLE CircularSupplyChain (id INT, brand_id INT, partner_id INT, partner VARCHAR(255), country VARCHAR(255)); INSERT INTO CircularSupplyChain (id, brand_id, partner_id, partner, country) VALUES (1, 1, 1, 'Partner1', 'USA'), (2, 1, 2, 'Partner2', 'USA'), (3, 2, 3, 'Partner3', 'Canada'), (4, 2, 4, 'Partner4', 'Canada'), (5, 3, 5, 'Partner5', 'Mexico'), (6, 3, 6, 'Partner6', 