In [None]:
import pickle
from typing import List

from openai_client import get_chat_completion
from models import Complaint

COMPLAINT_SYSTEM_MESSAGE = """
You are SynthMedGPT, a bot that generates synthetic medical data. Your task today is to generate a list of possible complaints for a patient, as well as one ore more highly relevant procedures a doctor might initiate to address their complaint. Your outputs should be in the following format:

Complaint: [complaint] Procedures: [procedure1], [procedure2], [procedure3]
Complaint: [complaint] Procedures: [procedure1], [procedure2], [procedure3]
"""

COMPLAINT_PROMPT = """
Generate for me a list of 50 complaint-procedure pairs. Do not number the new lines.
"""

def parse_complaint(complaint: str) -> Complaint:
    """Parse one line of model output into a ``Complaint``.

    The expected line format is the one COMPLAINT_SYSTEM_MESSAGE asks for::

        Complaint: [complaint] Procedures: [procedure1], [procedure2], [procedure3]

    A leading "Complaint: " label is optional; surrounding whitespace is ignored.

    Args:
        complaint: A single line of model output.

    Returns:
        A ``Complaint`` with ``complaint`` set to the complaint text and
        ``related_procedures`` set to the list of comma-separated procedure names.

    Raises:
        ValueError: If the line does not contain exactly one " Procedures: "
            separator, or if the complaint text or the procedure list is empty.
    """
    line = complaint.strip()

    label = "Complaint: "
    # Strip only a leading label; str.replace() would also delete the phrase
    # wherever it happened to appear inside the complaint text itself.
    if line.startswith(label):
        line = line[len(label):]

    parts = line.split(" Procedures: ")
    if len(parts) != 2:
        raise ValueError(
            f"expected exactly one ' Procedures: ' separator, got {len(parts) - 1}: {line!r}"
        )
    complaint_text, procedures_text = parts

    complaint_text = complaint_text.strip()
    if not complaint_text:
        raise ValueError(f"empty complaint in line: {line!r}")

    # Split on bare commas and trim, so "a,b" and "a, b" parse the same way;
    # drop empty entries left behind by stray or trailing commas.
    procedures = [name.strip() for name in procedures_text.split(",") if name.strip()]
    if not procedures:
        raise ValueError(f"no procedures in line: {line!r}")

    return Complaint(complaint=complaint_text, related_procedures=procedures)

def generate_complaint_strings() -> List[str]:
    """Request one batch of complaint/procedure lines from the chat model.

    Sends COMPLAINT_SYSTEM_MESSAGE (system) and COMPLAINT_PROMPT (user) through
    get_chat_completion and splits the first choice's message content into lines.

    Returns:
        The non-blank lines of the reply, in order. Each is expected (not
        guaranteed) to be in the format parse_complaint() accepts.
    """
    response = get_chat_completion(
        messages=[
            {"role": "system", "content": COMPLAINT_SYSTEM_MESSAGE},
            {"role": "user", "content": COMPLAINT_PROMPT},
        ]
    )
    content = response.choices[0]["message"]["content"]
    # splitlines() also handles "\r\n" endings. Blank lines (leading/trailing
    # newlines, spacing between items) carry no data, so drop them here rather
    # than handing them to callers as strings that can only fail to parse.
    return [line for line in content.splitlines() if line.strip()]


In [None]:
# Accumulator for parsed Complaint objects. It is initialised in its own cell,
# so re-running the generation loop appends to it instead of resetting it.
complaints = list()

In [None]:
# Call the model N_GENERATION_ROUNDS times and parse every returned line into
# `complaints`. Re-running this cell appends more rows; it does not reset the list.
N_GENERATION_ROUNDS = 20

for _ in range(N_GENERATION_ROUNDS):
    complaint_strings = generate_complaint_strings()
    for complaint_string in complaint_strings:
        if not complaint_string.strip():
            continue  # blank separator lines are not parse failures
        try:
            parsed_complaint = parse_complaint(complaint_string)
        except Exception as exc:  # best-effort: one malformed line must not abort the run
            # Report the raw line. `parsed_complaint` would be stale here (the
            # previous line's result) or undefined if the first line fails.
            print(f'Failed to parse string for complaint: {complaint_string!r} ({exc})')
            continue
        print(f'Successfully parsed string for complaint: {parsed_complaint.complaint}')
        complaints.append(parsed_complaint)

In [None]:
# Persist the parsed complaints. pickle stores classes by reference, so loading
# this file needs the same Complaint class importable; only load it from a
# trusted source, since unpickling can execute arbitrary code.
COMPLAINTS_PICKLE_PATH = 'complaints.pkl'

# Don't clobber a previous good output file with an empty result
# (e.g. every API call or every parse failed).
if not complaints:
    raise RuntimeError(f'no complaints were parsed; refusing to write {COMPLAINTS_PICKLE_PATH}')

with open(COMPLAINTS_PICKLE_PATH, 'wb') as f:
    pickle.dump(complaints, f)