In [16]:
import os
import json
import openai
import pickle
import warnings
import requests
import pandas as pd
from tqdm import tqdm
from typing import List
from azureml.core import Workspace
from azure.identity import DefaultAzureCredential
from azureml.core.authentication import ServicePrincipalAuthentication

from langchain.chains import LLMChain
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import  PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts.few_shot import FewShotPromptTemplate

from langchain.output_parsers import PydanticOutputParser, JsonOutputKeyToolsParser, CommaSeparatedListOutputParser
# Show every row when a DataFrame is displayed (no row truncation).
pd.set_option('display.max_rows', None)
# Suppress all Python warnings from this point on.
warnings.filterwarnings("ignore")
# Azure ML workspace handle created via Workspace.from_config().
ws = Workspace.from_config()


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


In [17]:
class BearerAuth(requests.auth.AuthBase):
    """Auth callable for ``requests`` that adds a bearer token header to a request."""

    def __init__(self, token):
        # Token text only; the "Bearer " prefix is prepended in __call__.
        self.token = token

    def __call__(self, r):
        """Write the ``authorization`` header on ``r`` and hand ``r`` back."""
        header_value = "Bearer " + self.token
        r.headers["authorization"] = header_value
        return r



    
def initialize_llm(model_name) -> AzureChatOpenAI:
    """Create an ``AzureChatOpenAI`` client for one of the supported deployments.

    The API key is read from the ``AZURE_OPENAI_API_KEY`` environment variable;
    it must never be committed to the notebook. Export it (or load it from a
    secret store) before calling this function.

    Side effects: sets ``openai.api_type`` / ``openai.api_version`` and the
    ``AZURE_OPENAI_KEY``, ``AZURE_OPENAI_ENDPOINT``, ``AZURE_OPENAI_API_BASE``,
    ``AZURE_OPENAI_API_VERSION`` and ``AZURE_OPENAI_DEPLOYMENT_NAME`` environment
    variables, and issues one request to the Azure management API.

    Args:
        model_name: Either ``"gpt-4o"`` or ``"gpt-4.1"``.

    Returns:
        The configured ``AzureChatOpenAI`` instance.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        EnvironmentError: If ``AZURE_OPENAI_API_KEY`` is not set.
    """
    # Per-model settings that were previously duplicated across two branches.
    model_settings = {
        "gpt-4o": {
            "api_version": "2024-05-01-preview",
            "deployment": "gpt-4o",
            "management_api_version": "2023-10-01-preview",
        },
        "gpt-4.1": {
            "api_version": "2024-12-01-preview",
            "deployment": "gpt-4.1",
            "management_api_version": "2024-12-01-preview",
        },
    }

    # Fail fast: previously an unknown name fell through to `return model`
    # and surfaced as an UnboundLocalError.
    if model_name not in model_settings:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {sorted(model_settings)}"
        )
    settings = model_settings[model_name]

    # The key used to be hard-coded here; it is now taken from the environment.
    api_key = os.environ.get("AZURE_OPENAI_API_KEY")
    if not api_key:
        raise EnvironmentError(
            "AZURE_OPENAI_API_KEY is not set; export it before calling initialize_llm()"
        )

    workspace = Workspace.from_config()
    keyvault = workspace.get_default_keyvault()
    credential = DefaultAzureCredential()
    workspacename = keyvault.get_secret("project-workspace-name")
    subscription_id = keyvault.get_secret("project-subscription-id")

    # Azure AD token / legacy openai module settings, kept as in the original flow.
    access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
    os.environ["AZURE_OPENAI_KEY"] = access_token.token
    openai.api_type = "azure_ad"
    os.environ["AZURE_OPENAI_ENDPOINT"] = f"https://{workspacename}openai.openai.azure.com/"
    openai.api_version = "2023-07-01-preview"

    os.environ["AZURE_OPENAI_API_BASE"] = "https://xqrojjmb2wjlqopopenai.openai.azure.com/"
    os.environ["AZURE_OPENAI_API_VERSION"] = settings["api_version"]
    os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = settings["deployment"]

    # NOTE(review): the deployment listing below was issued by the original code but
    # its response was never read. The call is kept (now with a timeout so it cannot
    # hang forever) in case it is relied on as a permission probe -- confirm and remove.
    management_api_version = settings["management_api_version"]
    url = f"https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{workspacename}-common/providers/Microsoft.CognitiveServices/accounts/{workspacename}openai/deployments?api-version={management_api_version}"
    management_token = credential.get_token("https://management.azure.com/.default")
    requests.get(url, auth=BearerAuth(management_token.token), timeout=30)

    print(f'Initializing Model : {os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"]}')
    model = AzureChatOpenAI(
        deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
        azure_endpoint=os.environ["AZURE_OPENAI_API_BASE"],
        openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        openai_api_key=api_key,
        max_tokens=4000,
        temperature=0.9,
        model_kwargs={"seed": 1337}
    )
    print(f'Model {model_name} Initialized')

    return model

In [18]:
def get_query_for_specialty_layman(model : AzureChatOpenAI, search_query : str):
    """Ask the model for 5 paraphrases of ``search_query``.

    Args:
        model: Chat model placed in the middle of the prompt | model | parser chain.
        search_query: The web search query to paraphrase.

    Returns:
        Whatever the chain yields after ``CommaSeparatedListOutputParser`` has
        parsed the model response.
    """
    # The response is parsed as a comma separated list (the previously declared
    # pydantic response model was never used and has been removed).
    output_parser = CommaSeparatedListOutputParser()

    system_prompt =    """You are a helpful AI assistant specializing in healthcare. 
                        Your task is to paraphrase the web search queries provided by the user into 5 different ways. Please make sure that the queries
                        are of high quality and semantically represent the original query provided by the user.
                        
                        
                        PLEASE TAKE YOUR TIME IN UNDERSTANDING THE USER QUERY AND THEN GENERATE POSSIBLE PARAPHRASED AUGMENTED SEARCH QUERIES WHILE
                        RESPECTING THE ABOVE DEFINED CONDITION.
                        
                        PLEASE ONLY OUTPUT THE PARAPHRASED AUGMENTED SEARCH QUERIES NO OTHER TEXT.
                        
                        Format Instructions:
                        {format_instructions}
                        
                        search_query: {search_query}"""

    # Feedback : limit the number of words between 3 to 7

    # Both placeholders in the template are filled in at invoke time.
    prompt_template = PromptTemplate.from_template(
        template=system_prompt
    )

    chain = prompt_template | model | output_parser
    result = chain.invoke(input={"search_query": search_query, "format_instructions" : output_parser.get_format_instructions()})
    return result


In [19]:
def load_query_classified_specialty_filtered_diagnostic_dataset(dataset_path : str) -> dict:
    """Load the per-specialty JSON files of a directory into a single dictionary.

    Every file whose name ends in ``.json`` is parsed; only the first key/value
    pair of each file is used (key -> specialty, value -> its queries).

    Args:
        dataset_path: Directory holding the JSON files, with or without a
            trailing path separator.

    Returns:
        dict mapping specialty to queries, filled in sorted file-name order.
    """
    specialty_query_dict = {}

    # sorted() keeps the key order independent of the OS directory listing order.
    for file_name in sorted(os.listdir(dataset_path)):
        # endswith() avoids picking up files such as "x.json.bak".
        if not file_name.endswith('.json'):
            continue

        # os.path.join works whether or not dataset_path ends with a separator.
        with open(os.path.join(dataset_path, file_name), 'r') as file:
            data = json.load(file)

        if not data:
            # An empty JSON object has no specialty to register.
            continue

        specialty, queries = next(iter(data.items()))
        specialty_query_dict[specialty] = queries

    return specialty_query_dict


In [20]:
dataset_path_synthetic = '../../../datasets/datasets_augmented/augmentation_set3/gpt41_query_clasification_results/'

synthetic_specialty_query_dict = load_query_classified_specialty_filtered_diagnostic_dataset(dataset_path = dataset_path_synthetic)

print(f'Syenthetic Queries Dataset Size : {len(synthetic_specialty_query_dict)}')

dataset_path_ues = '../../../datasets/datasets_augmented/augmentation_set3/ues_keyword_nucc_classification/nucc_classification_by_specialties/'

ues_specialty_query_dict = load_query_classified_specialty_filtered_diagnostic_dataset(dataset_path = dataset_path_ues)

print(f'UES Queries Dataset Size : {len(ues_specialty_query_dict)}')

all_specialties_ues = list(ues_specialty_query_dict.keys())
all_specialties_synthetic_nucc = list(synthetic_specialty_query_dict.keys())
# Every UES specialty must be present in the NUCC list. The previous check only
# tested that the two lists share at least one specialty, which passes even when
# most UES specialties are missing.
print(f'Verifying UES labeled Specialties exist in NUCC specialty list : {set(all_specialties_ues).issubset(all_specialties_synthetic_nucc)}')

Syenthetic Queries Dataset Size : 590
UES Queries Dataset Size : 553
Verifying UES labeled Specialties exist in NUCC specialty list : True


# Find Distinct UES Queries In Each Specialty Group

In [21]:
import sys
from typing import Dict, List, Set, Tuple
sys.path.append('./')
from Query_Diversity_Selection_Algorithm import QueryDiversitySelector

In [22]:
def get_diverse_queries_by_specialty(load_unique_queries : bool, specialty_query_dict_queries_only : Dict[str, List[str]]):
    """Return the similarity-based query selection, either cached or freshly computed.

    Args:
        load_unique_queries: When true, the result is read from
            ``./similarity_results.json`` and nothing is recomputed.
        specialty_query_dict_queries_only: Specialty -> queries mapping handed to
            ``QueryDiversitySelector.run_comparison`` when recomputing.

    Returns:
        The similarity results (loaded from, or just written to, the JSON file).
    """
    # Cached path: hand back what an earlier run persisted.
    if load_unique_queries:
        with open('./similarity_results.json', 'r') as cached:
            return json.load(cached)

    embedding_model = '../../../../shekhar_tanwar/ICD-ICD-Triplet/model/NovaSearch_stella_en_1.5B_v5/'
    selector = QueryDiversitySelector(embedding_model = embedding_model)

    # run_comparison executes both algorithms; only the similarity output is kept.
    # if it throws an error, use k_cluster = None
    similarity_results, _cluster_results = selector.run_comparison(
        specialty_queries = specialty_query_dict_queries_only,
        similarity_threshold = 0.9,
        k_clusters = None,
    )

    # Persist so later runs can take the cached path above.
    with open('./similarity_results.json','w') as out_file:
        json.dump(similarity_results, out_file, indent = 4)

    return similarity_results

In [23]:
# True -> reuse the cached ./similarity_results.json instead of recomputing it.
load_unique_queries = True
similarity_results = get_diverse_queries_by_specialty(load_unique_queries = load_unique_queries,  specialty_query_dict_queries_only = ues_specialty_query_dict)

In [24]:
all_specialties = list(similarity_results.keys())

In [25]:
len(all_specialties)

553

# Paraphrase Distinct UES Queries

In [26]:

def get_specialty_paraphrased_queries(load_paraphrased_result : bool, similarity_results : Dict[str, List[str]]):
    """Return paraphrased queries per specialty, either cached or generated by the LLM.

    Args:
        load_paraphrased_result: When true, read
            ``./specialty_paraphrased_queries_dict.json`` instead of calling the model.
        similarity_results: Specialty -> queries mapping; every query is
            paraphrased when generating.

    Returns:
        Tuple ``(paraphrased, errors)``. ``paraphrased`` maps specialty to the
        list of paraphrased queries. ``errors`` maps specialty to the source
        queries whose paraphrasing raised; it is an empty dict on the cached
        path because failures are not persisted.
    """
    if load_paraphrased_result:
        with open('./specialty_paraphrased_queries_dict.json','r') as file:
            specialty_paraphrased_queries_dict = json.load(file)

        # Previously this returned `_`, i.e. whatever the last notebook output
        # happened to be (or a NameError outside IPython).
        return specialty_paraphrased_queries_dict, {}

    model = initialize_llm(model_name="gpt-4.1")

    specialty_paraphrased_queries_dict = {}
    error_specialty_paraphrased_queries_dict = {}

    for specialty, queries in tqdm(similarity_results.items()):

        if len(queries) == 0:
            specialty_paraphrased_queries_dict[specialty] = []
            continue

        print(f'Generating Papraphrased Queries for Specialty : {specialty}')

        paraphrased_queries_list = []
        error_list = []
        for search_query in queries:
            try:
                result = get_query_for_specialty_layman(model = model, search_query = search_query)
                paraphrased_queries_list = paraphrased_queries_list + result
            except Exception:
                # Best effort: remember the failing query and carry on. `Exception`
                # (not a bare except) lets KeyboardInterrupt stop a long run.
                error_list.append(search_query)

        specialty_paraphrased_queries_dict[specialty] = paraphrased_queries_list
        error_specialty_paraphrased_queries_dict[specialty] = error_list

    # these are the paraphrased production queries by each specialty
    with open('./specialty_paraphrased_queries_dict.json','w') as file:
        json.dump(specialty_paraphrased_queries_dict, file, indent = 4)

    return specialty_paraphrased_queries_dict , error_specialty_paraphrased_queries_dict
        
# True -> read ./specialty_paraphrased_queries_dict.json instead of calling the LLM.
load_paraphrased_result = True                    
specialty_paraphrased_queries_dict, error_specialty_paraphrased_queries_dict =  get_specialty_paraphrased_queries(load_paraphrased_result = load_paraphrased_result, similarity_results = similarity_results)          

# Generating Summary Stats For Adding Paraphrased Production Queries To Synthetic Queries

In [71]:
def get_summary_stats(specialty_paraphrased_queries_dict : Dict[str, List[str]], ues_specialty_query_dict : Dict[str, List[str]] ) -> pd.DataFrame:
    """Tabulate per-specialty query counts from two specialty -> queries mappings.

    Counts taken from ``specialty_paraphrased_queries_dict`` are written to the
    ``Total_Queries_Synthetic`` column and counts taken from
    ``ues_specialty_query_dict`` to ``Total_Queries_Paraphrased``.
    NOTE(review): the column labels do not match the parameter names -- confirm
    which mapping each caller passes in which position.

    The frame built from ``ues_specialty_query_dict`` is the left side of a left
    merge, so only its specialties appear; a specialty missing from the other
    mapping gets a count of 0.

    Returns:
        DataFrame with columns ``Specialties``, ``Total_Queries_Paraphrased``,
        ``Total_Queries_Synthetic`` and ``Total_Queries`` (sum of the two counts).
    """
    def _count_frame(query_map, count_column):
        # One (specialty, number of queries) row per entry of the mapping.
        rows = [(name, len(items)) for name, items in query_map.items()]
        return pd.DataFrame(rows, columns = ['Specialties', count_column])

    right_counts = _count_frame(specialty_paraphrased_queries_dict, 'Total_Queries_Synthetic')
    left_counts = _count_frame(ues_specialty_query_dict, 'Total_Queries_Paraphrased')

    merged = left_counts.merge(right_counts, how = 'left', on = 'Specialties').fillna(0)
    merged['Total_Queries'] = merged['Total_Queries_Synthetic'] + merged['Total_Queries_Paraphrased']

    return merged

In [28]:
final_stats = get_summary_stats(specialty_paraphrased_queries_dict = specialty_paraphrased_queries_dict, ues_specialty_query_dict = ues_specialty_query_dict)

In [29]:

# to remove
# ambulance
# art therapist
# case managercare coordinator
# doula_doula	
# drama therapist
# indian
# 

# Specialties whose combined query count is below 250.
# NOTE(review): the variable name says 200 but the threshold applied is 250.
# .copy() gives an independent frame, so adding a column below does not write
# into a slice of `final_stats` (chained assignment).
final_stats_less_200_diagnostic_quries = final_stats[final_stats['Total_Queries'] < 250].copy()
# Text before the first underscore of the specialty label.
final_stats_less_200_diagnostic_quries['Specialties_Trimmed'] = final_stats_less_200_diagnostic_quries['Specialties'].apply(lambda x : x.split('_')[0])
final_stats_less_200_diagnostic_quries = final_stats_less_200_diagnostic_quries[~final_stats_less_200_diagnostic_quries['Specialties_Trimmed'].isin(['ambulance','art therapist','case managercare coordinator','drama therapist','indian'])]
# Also drop every trimmed label that merely contains 'indian'.
final_stats_less_200_diagnostic_quries = final_stats_less_200_diagnostic_quries[~final_stats_less_200_diagnostic_quries['Specialties_Trimmed'].str.contains('indian')]

In [30]:
final_stats_less_200_diagnostic_quries.head()

Unnamed: 0,Specialties,Total_Queries_Synthetic,Total_Queries_Paraphrased,Total_Queries,Specialties_Trimmed
0,acupuncturist_acupuncturist,0,0,0,acupuncturist
1,advanced practice midwife_advanced practice mi...,3,5,8,advanced practice midwife
2,allergy & immunology_allergy & immunology,38,30,68,allergy & immunology
4,allergy & immunology_clinical & laboratory imm...,41,30,71,allergy & immunology
8,anesthesiology_anesthesiology,5,5,10,anesthesiology


In [31]:
final_stats_less_200_diagnostic_quries.shape

(488, 5)

In [32]:
final_stats_less_200_diagnostic_quries.to_csv('./final_stats_less_200_diagnostic_quries.csv')

## NOTE: These are the specialties for which the number of queries, after the diagnostic filter and the addition of paraphrased production queries, is less than 250.

# Feedback : 

Regenerate Diagnostic Queries (using Stage_2_step1_query_generator_07142025.py ) for the Specialties_Subspecialties (total 116) identified in final_stats_zero_diagnostic_quries

In [16]:
# NOTE(review): `final_stats_zero_diagnostic_quries` is not defined in the visible
# part of this notebook; on a fresh kernel this line would raise NameError --
# confirm where that frame is built.
focus_list = list(final_stats_zero_diagnostic_quries['Specialties'])

In [68]:
with open('specialty_focus_list.pkl','wb') as file:
    pickle.dump(focus_list, file)

In [69]:
len(focus_list)

116

In [33]:
# from pathlib import Path
# import shutil

# # Define source and destination folders
# source_folder = Path("../../../datasets/datasets_augmented/augmentation_set3/")
# destination_folder = Path("../../../datasets/datasets_augmented/augmentation_set3/iteration1/")

# # Define the file extension to move
# extension = "*.json"  # Example: move all .jpg files

# # Create the destination folder if it doesn't exist
# destination_folder.mkdir(parents=True, exist_ok=True)

# # Iterate and move files
# for file_path in source_folder.glob(extension):
#     try:
#         shutil.move(file_path, destination_folder / file_path.name)
#         print(f"Moved: {file_path.name}")
#     except shutil.Error as e:
#         print(f"Error moving {file_path.name}: {e}")
#     except Exception as e:
#         print(f"An unexpected error occurred while moving {file_path.name}: {e}")

# Adding Synthetic Queries and Paraphrased Production Queries

In [34]:
import os
import json

In [51]:
def get_synthetic_queries(path = '../../../datasets/datasets_augmented/augmentation_set4/iteration1/'):
    """Load the synthetic queries stored as one JSON file per specialty.

    Args:
        path: Directory to scan, with or without a trailing separator. Defaults
            to the folder this cell originally had hard-coded.

    Returns:
        dict mapping the first key of each ``*.json`` file to that key's value,
        filled in sorted file-name order.
    """
    specialty_query_dict = {}

    # sorted(): stable key order regardless of the OS directory listing order.
    for file_name in sorted(os.listdir(path)):
        # endswith() skips look-alikes such as "x.json.bak".
        if not file_name.endswith('.json'):
            continue

        with open(os.path.join(path, file_name), 'r') as file:
            data_specialty = json.load(file)

        if not data_specialty:
            # Empty JSON object: no specialty to register.
            continue

        specialty, queries = next(iter(data_specialty.items()))
        specialty_query_dict[specialty] = queries

    return specialty_query_dict

In [55]:
def get_paraphrased_production_queries():
    """Read ``./specialty_paraphrased_queries_dict.json`` and return its parsed content."""
    with open('./specialty_paraphrased_queries_dict.json', 'r') as handle:
        return json.load(handle)

In [72]:
specialty_query_dict = get_synthetic_queries()
paraphrased_production_queries = get_paraphrased_production_queries()
# The synthetic mapping is passed as `specialty_paraphrased_queries_dict` and the
# paraphrased mapping as `ues_specialty_query_dict`; the keyword names do not
# describe the data being passed here.
final_stats = get_summary_stats(specialty_paraphrased_queries_dict = specialty_query_dict, ues_specialty_query_dict = paraphrased_production_queries)

# Analysis

In [73]:
final_stats.head(10)

Unnamed: 0,Specialties,Total_Queries_Paraphrased,Total_Queries_Synthetic,Total_Queries
0,acupuncturist_acupuncturist,0,310.0,310.0
1,advanced practice midwife_advanced practice mi...,5,396.0,401.0
2,allergy & immunology_allergy & immunology,30,379.0,409.0
3,allergy & immunology_allergy,80,382.0,462.0
4,allergy & immunology_clinical & laboratory imm...,30,376.0,406.0
5,ambulance_air transport,0,374.0,374.0
6,ambulance_land transport,0,356.0,356.0
7,anesthesiology_addiction medicine,75,397.0,472.0
8,anesthesiology_anesthesiology,5,354.0,359.0
9,anesthesiology_critical care medicine,55,360.0,415.0


In [76]:
final_stats[final_stats['Total_Queries_Synthetic'] == min(final_stats['Total_Queries_Synthetic'])]

Unnamed: 0,Specialties,Total_Queries_Paraphrased,Total_Queries_Synthetic,Total_Queries
147,emergency medicine_medical toxicology,133,0.0,133.0
416,psychiatry & neurology_brain injury medicine,10,0.0,10.0


In [77]:
final_stats[final_stats['Total_Queries_Synthetic'] < 200]

Unnamed: 0,Specialties,Total_Queries_Paraphrased,Total_Queries_Synthetic,Total_Queries
147,emergency medicine_medical toxicology,133,0.0,133.0
416,psychiatry & neurology_brain injury medicine,10,0.0,10.0


# Feedback:
emergency medicine_medical toxicology and psychiatry & neurology_brain injury medicine have 0 synthetic queries.


Considering augmentation_set3/gpt41_query_clasification_results for adding additional queries 


In [78]:
# Load the augmentation_set3 classification results for the first specialty to backfill.
with open('../../../datasets/datasets_augmented/augmentation_set3/gpt41_query_clasification_results/emergency medicine_medical toxicology.json', 'r') as file:
    em_mt = json.load(file)

# Same for the second specialty.
with open('../../../datasets/datasets_augmented/augmentation_set3/gpt41_query_clasification_results/psychiatry & neurology_brain injury medicine.json', 'r') as file:
    pn_nbim = json.load(file)

# Add / overwrite both specialties in the in-memory synthetic query dictionary.
specialty_query_dict['emergency medicine_medical toxicology'] = em_mt['emergency medicine_medical toxicology']
specialty_query_dict['psychiatry & neurology_brain injury medicine'] = pn_nbim['psychiatry & neurology_brain injury medicine']        

In [87]:
final_stats = get_summary_stats(specialty_paraphrased_queries_dict = specialty_query_dict, ues_specialty_query_dict = paraphrased_production_queries)

In [88]:
final_stats.head()

Unnamed: 0,Specialties,Total_Queries_Paraphrased,Total_Queries_Synthetic,Total_Queries
0,acupuncturist_acupuncturist,0,310,310
1,advanced practice midwife_advanced practice mi...,5,396,401
2,allergy & immunology_allergy & immunology,30,379,409
3,allergy & immunology_allergy,80,382,462
4,allergy & immunology_clinical & laboratory imm...,30,376,406


In [89]:
final_stats[final_stats['Total_Queries_Synthetic'] == min(final_stats['Total_Queries_Synthetic'])]

Unnamed: 0,Specialties,Total_Queries_Paraphrased,Total_Queries_Synthetic,Total_Queries
462,radiologic technologist_nuclear medicine techn...,10,232,242


# Combine and Sample Queries 

In [26]:
import pandas as pd
import numpy as np
import json

def get_synthetic_queries(path = '../../../datasets/datasets_augmented/augmentation_set4_v40/iteration1/'):
    """Load synthetic queries from JSON files (one file per specialty).

    Args:
        path: Directory to scan, with or without a trailing separator. Defaults
            to the folder this cell originally had hard-coded.

    Returns:
        dict mapping the first key of each ``*.json`` file to that key's value,
        filled in sorted file-name order.
    """
    specialty_query_dict = {}

    # Sorted listing -> reproducible key order across machines and file systems.
    json_files = sorted(name for name in os.listdir(path) if name.endswith('.json'))

    for file_name in json_files:
        # os.path.join tolerates a missing trailing separator on `path`.
        with open(os.path.join(path, file_name), 'r') as file:
            data_specialty = json.load(file)

        if not data_specialty:
            # Empty JSON object: nothing to register.
            continue

        specialty, queries = next(iter(data_specialty.items()))
        specialty_query_dict[specialty] = queries

    return specialty_query_dict

def get_paraphrased_production_queries():
    """Return the parsed content of ``./specialty_paraphrased_queries_dict.json``."""
    source = './specialty_paraphrased_queries_dict.json'
    with open(source, 'r') as stream:
        loaded = json.load(stream)
    return loaded

In [2]:
def get_summary_stats(specialty_query_dict, paraphrased_production_queries):
    """
    Count synthetic and paraphrased queries for every specialty.

    The specialties of ``specialty_query_dict`` drive the rows; a specialty that
    is missing from ``paraphrased_production_queries`` gets a paraphrased count of 0.

    Returns:
        DataFrame with one row per specialty and the columns ``Specialties``,
        ``Total_Queries_Paraphrased``, ``Total_Queries_Synthetic``, ``Total_Queries``.
    """
    def _stats_row(name, synthetic_queries):
        n_synthetic = len(synthetic_queries)
        # Paraphrased queries may not exist for every specialty.
        n_paraphrased = len(paraphrased_production_queries.get(name, []))
        return {
            'Specialties': name,
            'Total_Queries_Paraphrased': n_paraphrased,
            'Total_Queries_Synthetic': n_synthetic,
            'Total_Queries': n_synthetic + n_paraphrased,
        }

    return pd.DataFrame([_stats_row(name, items) for name, items in specialty_query_dict.items()])

In [3]:
def _allocate_sample_counts(paraphrased_count, synthetic_count, total_available, target,
                            paraphrased_threshold, equal_weight_threshold, paraphrased_priority_ratio):
    """Return ``(n_paraphrased, n_synthetic, strategy)`` for a single specialty."""
    # Not enough queries overall: take everything.
    if total_available < target:
        return paraphrased_count, synthetic_count, "Use all available (insufficient total)"

    # No paraphrased queries: synthetic only.
    if paraphrased_count == 0:
        return 0, min(target, synthetic_count), "Synthetic only (no paraphrased available)"

    # Few paraphrased queries: keep all of them, top up with synthetic.
    if paraphrased_count <= paraphrased_threshold:
        n_synthetic = min(max(0, target - paraphrased_count), synthetic_count)
        return paraphrased_count, n_synthetic, f"All paraphrased + synthetic (paraphrased <= {paraphrased_threshold})"

    # Many paraphrased queries: split evenly, then fill any shortfall.
    if paraphrased_count > equal_weight_threshold:
        half = target // 2
        n_paraphrased = min(half, paraphrased_count)
        n_synthetic = min(half, synthetic_count)
        shortfall = target - (n_paraphrased + n_synthetic)
        if shortfall > 0:
            # Fill from whichever source still has spare queries, synthetic first.
            # The earlier branch test was inverted, so a synthetic shortfall was
            # never covered by the spare paraphrased queries.
            extra_synthetic = min(shortfall, synthetic_count - n_synthetic)
            n_synthetic += extra_synthetic
            shortfall -= extra_synthetic
            n_paraphrased += min(shortfall, paraphrased_count - n_paraphrased)
        return n_paraphrased, n_synthetic, f"Equal weight (paraphrased > {equal_weight_threshold})"

    # Default: prefer paraphrased up to the configured share, fill with synthetic,
    # then fall back to more paraphrased if synthetic could not cover the rest.
    n_paraphrased = min(int(target * paraphrased_priority_ratio), paraphrased_count)
    n_synthetic = min(target - n_paraphrased, synthetic_count)
    shortfall = target - (n_paraphrased + n_synthetic)
    if shortfall > 0:
        n_paraphrased += min(shortfall, paraphrased_count - n_paraphrased)
    return n_paraphrased, n_synthetic, "Prioritize paraphrased, fill with synthetic"


def sample_queries_per_specialty(df, target_queries_per_specialty=250, paraphrased_threshold=50,
                                 equal_weight_threshold=200, paraphrased_priority_ratio=0.7):
    """
    Decide how many paraphrased and synthetic queries to sample per specialty.

    Rules, applied in this order (the first match wins):
    1. total available < target: use all queries (no sampling)
    2. paraphrased == 0: sample only from synthetic
    3. paraphrased <= paraphrased_threshold: take all paraphrased, fill with synthetic
    4. paraphrased > equal_weight_threshold: half the target from each type; a
       shortfall on one side is filled from the other
    5. otherwise: take up to ``paraphrased_priority_ratio`` of the target from
       paraphrased, fill the rest with synthetic (then with more paraphrased)

    Args:
        df: DataFrame with columns ['Specialties', 'Total_Queries_Paraphrased', 'Total_Queries_Synthetic', 'Total_Queries']
        target_queries_per_specialty: Target number of queries per specialty (default: 250)
        paraphrased_threshold: Threshold for taking all paraphrased queries (default: 50)
        equal_weight_threshold: Paraphrased count above which both types get equal weight (default: 200)
        paraphrased_priority_ratio: Share of the target preferred from paraphrased in the default rule (default: 0.7)

    Returns:
        DataFrame with one row per specialty holding the original counts, the
        sampled counts, the paraphrased ratio, the strategy label and whether
        the target was met.
    """
    results = []

    for _, row in df.iterrows():
        paraphrased_count = int(row['Total_Queries_Paraphrased'])
        synthetic_count = int(row['Total_Queries_Synthetic'])
        total_available = int(row['Total_Queries'])

        sampled_paraphrased, sampled_synthetic, sampling_strategy = _allocate_sample_counts(
            paraphrased_count,
            synthetic_count,
            total_available,
            target_queries_per_specialty,
            paraphrased_threshold,
            equal_weight_threshold,
            paraphrased_priority_ratio,
        )

        total_sampled = sampled_paraphrased + sampled_synthetic
        paraphrased_ratio = sampled_paraphrased / total_sampled if total_sampled > 0 else 0

        results.append({
            'Specialties': row['Specialties'],
            'Original_Paraphrased': paraphrased_count,
            'Original_Synthetic': synthetic_count,
            'Original_Total': total_available,
            'Sampled_Paraphrased': sampled_paraphrased,
            'Sampled_Synthetic': sampled_synthetic,
            'Total_Sampled': total_sampled,
            'Paraphrased_Ratio': round(paraphrased_ratio, 3),
            'Sampling_Strategy': sampling_strategy,
            'Target_Met': total_sampled == target_queries_per_specialty
        })

    return pd.DataFrame(results)

In [4]:
def _draw_queries(rng, pool, n_requested):
    """Pick ``n_requested`` items from ``pool`` without replacement (all of them if too few)."""
    n_requested = int(n_requested)
    if n_requested <= 0:
        return []
    if len(pool) >= n_requested:
        return rng.choice(pool, size=n_requested, replace=False).tolist()
    # Fewer items than requested: return a copy so the caller's list is never aliased.
    return list(pool)


def generate_actual_samples(sampling_results, paraphrased_production_queries, specialty_query_dict, seed=None):
    """
    Generate actual query samples based on the sampling strategy.

    Args:
        sampling_results: Results from sample_queries_per_specialty function
        paraphrased_production_queries: Dictionary with paraphrased queries by specialty
        specialty_query_dict: Dictionary with synthetic queries by specialty
        seed: Optional seed for reproducible sampling. With the default ``None``
            the global ``np.random`` state is used, exactly as before.

    Returns:
        Dictionary mapping specialty to ``{'paraphrased': [...], 'synthetic': [...],
        'total_count': int}``, where ``total_count`` is the number of queries
        actually placed in the two lists.
    """
    # A dedicated generator when a seed is given; otherwise the legacy global state.
    rng = np.random.default_rng(seed) if seed is not None else np.random

    sampled_queries = {}

    for _, row in sampling_results.iterrows():
        specialty = row['Specialties']

        paraphrased_sample = []
        if specialty in paraphrased_production_queries:
            paraphrased_sample = _draw_queries(
                rng, paraphrased_production_queries[specialty], row['Sampled_Paraphrased']
            )

        synthetic_sample = []
        if specialty in specialty_query_dict:
            synthetic_sample = _draw_queries(
                rng, specialty_query_dict[specialty], row['Sampled_Synthetic']
            )

        sampled_queries[specialty] = {
            'paraphrased': paraphrased_sample,
            'synthetic': synthetic_sample,
            # Counted from the lists so it cannot disagree with the planned figure
            # when a specialty is missing or has fewer queries than planned.
            'total_count': len(paraphrased_sample) + len(synthetic_sample),
        }

    return sampled_queries

In [27]:
# Load both query sources.
specialty_query_dict = get_synthetic_queries()
paraphrased_production_queries = get_paraphrased_production_queries()

# Per-specialty counts derived from the loaded dictionaries.
final_stats = get_summary_stats(specialty_query_dict, paraphrased_production_queries)

# Overview of how the two sources overlap (dict key views support set algebra).
print(
    "Loaded data:",
    f"- Total specialties in synthetic data: {len(specialty_query_dict)}",
    f"- Total specialties in paraphrased data: {len(paraphrased_production_queries)}",
    f"- Specialties with both types: {len(specialty_query_dict.keys() & paraphrased_production_queries.keys())}",
    f"- Specialties with only synthetic: {len(specialty_query_dict.keys() - paraphrased_production_queries.keys())}",
    "",
    sep="\n",
)

# Work out how many queries of each type to draw per specialty.
sampling_results = sample_queries_per_specialty(final_stats, target_queries_per_specialty=250)

print("Sampling Results:")
#print("=" * 120)
#print(sampling_results.to_string(index=False))

# Aggregate view of the sampling plan.
print(
    "\nSummary:",
    f"Total specialties: {len(sampling_results)}",
    f"Specialties meeting target (250): {sampling_results['Target_Met'].sum()}",
    f"Average paraphrased ratio: {sampling_results['Paraphrased_Ratio'].mean():.3f}",
    f"Specialties with no paraphrased queries: {(sampling_results['Original_Paraphrased'] == 0).sum()}",
    f"Specialties with >200 paraphrased queries: {(sampling_results['Original_Paraphrased'] > 200).sum()}",
    sep="\n",
)

# Draw the actual queries according to the plan.
print("\nGenerating actual query samples...")
sampled_queries = generate_actual_samples(sampling_results, paraphrased_production_queries, specialty_query_dict)

# Persist the plan (CSV) and the drawn queries (JSON).
print("\nSaving results...")
sampling_results.to_csv('sampling_results.csv', index=False)
with open('../../../datasets/datasets_augmented/final_dataset_v40/sampled_queries.json', 'w') as f:
    json.dump(sampled_queries, f, indent=2)

print(
    "Results saved to:",
    "- sampling_results.csv: Detailed sampling statistics",
    "- sampled_queries.json: Actual sampled queries by specialty",
    sep="\n",
)


Loaded data:
- Total specialties in synthetic data: 594
- Total specialties in paraphrased data: 553
- Specialties with both types: 551
- Specialties with only synthetic: 43

Sampling Results:

Summary:
Total specialties: 594
Specialties meeting target (250): 593
Average paraphrased ratio: 0.108
Specialties with no paraphrased queries: 165
Specialties with >200 paraphrased queries: 16

Generating actual query samples...

Saving results...
Results saved to:
- sampling_results.csv: Detailed sampling statistics
- sampled_queries.json: Actual sampled queries by specialty


In [28]:
def flatten_queries_by_specialty(sampled_queries : dict):
    """Merge each specialty's 'paraphrased' and 'synthetic' lists into one list.

    Args:
        sampled_queries: Mapping of specialty -> dict holding the keys
            ``'paraphrased'`` and ``'synthetic'``.

    Returns:
        dict mapping specialty -> paraphrased queries followed by synthetic queries.
    """
    return {
        specialty: buckets.get('paraphrased') + buckets.get('synthetic')
        for specialty, buckets in sampled_queries.items()
    }

In [29]:
specialty_query_dict = flatten_queries_by_specialty(sampled_queries = sampled_queries)

In [31]:
all_specialties = list(specialty_query_dict.keys())

In [32]:
# Write the flattened specialty -> queries mapping to the final dataset file.
with open('../../../datasets/datasets_augmented/final_dataset_v40/final_dataset_v40.json', 'w') as f:
    json.dump(specialty_query_dict, f, indent=4)

In [34]:
def split_input_data(input_file, output_dir, num_chunks):
    """Split a specialty -> queries JSON file into ``num_chunks`` chunk files.

    Chunks are written to ``<output_dir>/splits/gpt_specialties_split_<i>.json``.
    If that folder already holds exactly ``num_chunks`` JSON files and all of the
    expected chunk files are among them, they are loaded and returned instead of
    being regenerated. The cache is not checked against the content of
    ``input_file``; delete the folder to force a re-split.

    Args:
        input_file: Path of the JSON file holding a specialty -> queries object.
        output_dir: Directory under which the ``splits`` folder is created.
        num_chunks: Number of chunks to produce; must be at least 1.

    Returns:
        List with one dict per chunk, in chunk-index order. When there are fewer
        specialties than the chunking needs, trailing chunks are empty dicts.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    splits_dir = os.path.join(output_dir, "splits")
    chunk_paths = [
        os.path.join(splits_dir, f"gpt_specialties_split_{i}.json")
        for i in range(num_chunks)
    ]

    # List the directory itself. The earlier code iterated over the characters of
    # the path string, so the cached branch could never be taken.
    existing_json = []
    if os.path.isdir(splits_dir):
        existing_json = [name for name in os.listdir(splits_dir) if name.endswith('.json')]

    # Reuse only when the folder holds exactly this chunk layout. A count check
    # alone would accept leftovers from a run with a different naming scheme.
    cache_is_usable = (
        len(existing_json) == num_chunks
        and all(os.path.isfile(path) for path in chunk_paths)
    )

    all_split_files = []

    if cache_is_usable:
        for chunk_path in chunk_paths:
            with open(chunk_path, 'r') as f:
                all_split_files.append(json.load(f))
        return all_split_files

    os.makedirs(splits_dir, exist_ok=True)

    with open(input_file, 'r') as f:
        specialty_data = json.load(f)

    specialties = list(specialty_data.keys())
    # Ceiling division: every chunk except possibly the last ones is full.
    chunk_size = len(specialties) // num_chunks + (1 if len(specialties) % num_chunks else 0)

    for i, chunk_path in enumerate(chunk_paths):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, len(specialties))

        chunk_data = {specialty: specialty_data[specialty] for specialty in specialties[start_idx:end_idx]}
        all_split_files.append(chunk_data)

        with open(chunk_path, 'w') as f:
            json.dump(chunk_data, f, indent=4)

    return all_split_files

In [37]:
input_file = '../../../datasets/datasets_augmented/final_dataset_v40/final_dataset_v40.json'
# NOTE(review): split_input_data appends "splits/" to output_dir itself, so with this
# value the chunks are written under ".../final_dataset_v40/splits/splits/" -- confirm
# that the nested folder is intended.
output_dir = '../../../datasets/datasets_augmented/final_dataset_v40/splits/'
    
num_chunks = 4
# NOTE(review): file_index is not used anywhere in the visible notebook.
file_index = 0
    

all_split_files = split_input_data(input_file, output_dir, num_chunks)