In [1]:
import pandas as pd
import numpy as np
import os
import jinja2 as j2
import json, logging
from llm import Models

In [2]:
# SECURITY: an API key used to be hardcoded on this line in plain text. Any key
# that was ever committed in this notebook must be treated as leaked and
# revoked/rotated. Credentials are now taken from the environment, with an
# interactive (non-echoing) prompt as a fallback.
# NOTE: the variable is named OPENAI_API_KEY because the value is handed to an
# OpenAI-compatible client, but the base URL configured further down points at
# Google's generativelanguage endpoint, so this is expected to be a Gemini key.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass

    os.environ["OPENAI_API_KEY"] = getpass("Gemini API key (stored in OPENAI_API_KEY): ")

# Model identifiers for the two Gemini variants.
GEMINI_PRO_MODEL = "models/gemini-2.5-pro-preview-05-06"
GEMINI_FLASH_MODEL = "models/gemini-2.5-flash-preview-04-17-thinking"

In [5]:
from sft_dataset_generator import SFTDatasetGenerator
from types import SimpleNamespace

# Send log records both to a file and to the notebook output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("sft_dataset_generator.log"),
        logging.StreamHandler()
    ]
)

# Run configuration, collected on a namespace so it reads like parsed CLI args.
args = SimpleNamespace()
args.api_key = os.getenv("OPENAI_API_KEY")
args.model = Models.GEMINI_FLASH_MODEL.value
args.base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
args.csv_file = "../../data/processed/doctor_patient_conversations.csv"
args.max_examples = 5000
args.min_examples = 5000
# NOTE(review): min_q_early_dx is not read by any of the visible cells — confirm
# whether it is still needed.
args.min_q_early_dx = 3000
args.output_path = f"../../data/training/sft_dataset_{args.max_examples}.json"

generator = SFTDatasetGenerator(api_key=args.api_key, model_name=args.model, base_url=args.base_url)

logging.info(f"Loading conversations from: {args.csv_file}")
source_conversations = generator.load_conversations_from_csv(args.csv_file)

if source_conversations:
    logging.info(f"Creating SFT dataset with target examples: {args.min_examples}-{args.max_examples}")
    fine_tuning_dataset = generator.create_sft_dataset(
        source_conversations,
        args.min_examples,
        args.max_examples
    )

    logging.info("\n--- Dataset Creation Complete ---")
    logging.info(f"Final number of examples in dataset: {len(fine_tuning_dataset)}")

    if len(fine_tuning_dataset) > 0:
        # Persist before doing anything else: label generation is a long,
        # paid LLM run, and the preview below indexes into row contents, so a
        # single malformed row must not be able to discard the whole result.
        generator.save_dataset_to_disk(fine_tuning_dataset, args.output_path)

        logging.info("\nSample of the first few examples:")
        for i in range(min(3, len(fine_tuning_dataset))):
            example = fine_tuning_dataset[i]
            turns = example['prompt']
            if turns:
                last_turn = turns[-1]
                prompt_summary = (f"History with {len(turns)} turns. "
                                  f"Last turn by '{last_turn['role']}': "
                                  f"{last_turn['content'][:60]}...")
            else:
                # An empty history has no last turn to quote.
                prompt_summary = "History with 0 turns."
            # Flatten newlines so the preview stays on one log line.
            label_preview = example['label'][:200].replace("\n", " ")
            label_summary = f"{label_preview}..."
            logging.info(f"\nExample {i+1}:")
            logging.info(f"  SFT Prompt Summary (model input): {prompt_summary}")
            logging.info(f"  SFT Label Preview (model target): {label_summary}")
    else:
        logging.info("No data in the final SFT dataset. Nothing to save.")

else:
    logging.info("No conversations found in the CSV file. Nothing to process.")

logging.info("Script finished.")

2025-05-21 10:07:58,106 - INFO - OpenAI client initialized for model models/gemini-2.5-flash-preview-04-17-thinking at base URL https://generativelanguage.googleapis.com/v1beta/openai/.
2025-05-21 10:07:58,106 - INFO - Loading conversations from: ../../data/processed/doctor_patient_conversations.csv
2025-05-21 10:07:58,192 - INFO - Found 270 unique conversation_ids in ../../data/processed/doctor_patient_conversations.csv.
Loading Conversations from CSV: 100%|██████████| 270/270 [00:00<00:00, 1446.84it/s]
2025-05-21 10:07:58,382 - INFO - Successfully loaded 270 valid conversations from ../../data/processed/doctor_patient_conversations.csv
2025-05-21 10:07:58,402 - INFO - Creating SFT dataset with target examples: 5000-5000
2025-05-21 10:07:58,419 - INFO - Extracted 13144 potential SFT examples from conversations.
2025-05-21 10:07:58,423 - INFO - Attempting to generate ~5000 SFT examples.
Generating SFT Dataset Labels:   0%|          | 0/5000 [00:00<?, ?it/s]2025-05-21 10:08:04,947 - INF

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]

2025-05-21 23:49:51,009 - INFO - Dataset successfully saved to ../../data/training/sft_dataset_5000.json
2025-05-21 23:49:51,019 - INFO - Script finished.


In [34]:
import random

from datasets import load_from_disk

# Re-point args.output_path at the dataset saved by the generation cell and
# reload it from disk, so the filtering that follows starts from the persisted
# copy rather than the in-memory result.
args.output_path = f"../../data/training/sft_dataset_{args.max_examples}.json"

dataset = load_from_disk(args.output_path)

def filter_error_rows(example):
    """Return True when a generated SFT example has a usable label.

    Written as a row predicate for ``dataset.filter``.

    Args:
        example: mapping with a ``'label'`` entry holding the generated
            target text.

    Returns:
        ``False`` if the label is not a string, contains one of the
        LLM-failure sentinel strings, or does not contain a ``<think>`` tag
        followed later by a closing ``</think>`` tag; ``True`` otherwise.

    Side effects:
        Labels rejected for a missing/malformed think block (or for not being
        a string) are printed, followed by a ``--`` separator, so they can be
        inspected. Rows rejected for a failure sentinel are dropped silently.
    """
    label = example['label']

    # A non-string label (e.g. None) cannot be searched; reject it visibly
    # instead of letting a TypeError abort the whole filter pass.
    if not isinstance(label, str):
        print(f"Example {label}")
        print("--")
        return False

    failure_markers = ('LLM_CALL_FAILED_NON_RATELIMIT', 'LLM_CALL_FAILED_ALL_RETRIES')
    if any(marker in label for marker in failure_markers):
        return False

    # Require the closing tag to come *after* an opening tag; a bare
    # membership test would also accept "</think> ... <think>".
    open_tag = "<think>"
    open_at = label.find(open_tag)
    close_at = label.find("</think>", open_at + len(open_tag)) if open_at != -1 else -1
    if close_at == -1:
        print(f"Example {label}")
        print("--")
        return False
    return True
# Destination for the cleaned dataset; the name is keyed on the same example
# count as the unfiltered dataset's path.
output_path = f"../../data/training/sft_filtered_dataset__{args.max_examples}.json"

# Keep only the rows accepted by filter_error_rows.
filtered_dataset = dataset.filter(filter_error_rows)

# Persist through the generator's save helper, as the generation cell does.
generator.save_dataset_to_disk(filtered_dataset, output_path)

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

Example <think
--
Example <think>
Ok let us see. Here is a summary of the case so far: The patient reports nasal and facial congestion starting 5
--
Example <think>
Ok let us see. Here is a summary of the case so far: The patient reports nasal obstruction and clear nasal discharge for one week. The obstruction is primarily at night and is relieved by changing position. They deny fever, pain, headache, changes in smell/taste, eye/ear symptoms, sore
--
Example <think>
Ok let us see. Here is a summary of the case so far: 5-year-old boy, Dave, with a history of asthma and dermatitis, presenting with a few days of significantly worse wheezing, difficulty sleeping and eating, irritability, and breathlessness with exertion. Rescue inhaler use is increased but ineffective; he also uses a daily steroid inhaler. He has a cough that started 2 days ago but no fever, sputum, or chest pain. Possible triggers include cold/windy weather and exercise. He has no other significant symptoms or sick contac

Saving the dataset (0/1 shards):   0%|          | 0/2178 [00:00<?, ? examples/s]

2025-05-22 03:36:04,978 - INFO - Dataset successfully saved to ../../data/training/sft_filtered_dataset__5000.json


In [17]:
import re  # not used by the active code in this cell

# Print every label containing the stock phrase and report how many there are.
# The search runs over the whole label, including anything inside
# <think>...</think>.
# NOTE(review): if only the visible answer should be searched, the think block
# would have to be stripped first — confirm the intent.
cnt = 0
for row in filtered_dataset:
    label_text = row['label']
    if "We've covered a fair bit of ground" not in label_text:
        continue
    print(label_text)
    cnt += 1
print(cnt)

0


In [31]:
# Display the filtered dataset's summary repr (feature names and row count).
filtered_dataset

Dataset({
    features: ['prompt', 'label'],
    num_rows: 2178
})

In [None]:
# Spot-check up to 20 random examples by printing the last two turns of each
# prompt. random.sample draws distinct indices and tolerates a dataset with
# fewer than 20 rows (including an empty one, where randint(0, -1) would raise).
sample_size = min(20, len(filtered_dataset))
for random_index in random.sample(range(len(filtered_dataset)), sample_size):
    # Row-first access reads a single example; indexing the 'prompt' column
    # first can load every prompt just to pick one (depends on the installed
    # `datasets` version).
    print(filtered_dataset[random_index]['prompt'][-2:])
    print("---")

In [20]:
import tiktoken

# Count the tokens in the label of the example at index 1, using tiktoken with
# the encoding registered for gpt-3.5-turbo.
# NOTE(review): the dataset is configured to be generated with a Gemini model;
# a GPT tokenizer gives only an approximate count for other model families —
# confirm which tokenizer is actually relevant.
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Row-first access fetches just this example, matching how the generation
# cell's preview loop reads rows, instead of indexing the whole 'label' column.
tokens = encoding.encode(fine_tuning_dataset[1]['label'])
print(f"Number of tokens in fine_tuning_dataset['label'][1]: {len(tokens)}")


Number of tokens in fine_tuning_dataset['label'][1]: 42
