In [4]:
import json
import os
import re

from openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser

# Read the key from the environment instead of hard-coding it in the notebook:
# hard-coded keys leak through version control and shared/exported notebooks.
# If the variable is unset this is None, and the OpenAI client will report a
# missing key when it is constructed.
openai_api_key = os.environ.get("OPENAI_API_KEY")

In [5]:
class GenePhenotypeExtractor:
    """Extract phenotype and gene identifiers from a free-text patient request.

    The request is embedded in a prompt that carries LangChain
    ``StructuredOutputParser`` format instructions, sent to an OpenAI
    chat-completions model, and the JSON object in the reply is parsed into
    two lists: ``(phenotypes, genes)``.
    """

    DEFAULT_MODEL = "gpt-4-turbo-preview"

    # A fenced code block, optionally tagged ``json``. Whitespace around the
    # payload is optional, so a missing newline before the closing fence or
    # ``\r\n`` line endings do not defeat the match.
    _FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)

    def __init__(self, openai_api_key, model=DEFAULT_MODEL, max_tokens=512, temperature=0):
        """Create the OpenAI client and the prompt / format-instruction objects.

        Args:
            openai_api_key: Key passed to ``OpenAI(api_key=...)``.
            model: Chat model name used for every request.
            max_tokens: Completion budget. The reply must contain the whole
                JSON object; a small budget (the old hard-coded 150) can cut
                off requests listing many identifiers, which then parse to
                empty lists.
            temperature: Sampling temperature; 0 keeps extraction as
                repeatable as the API allows.
        """
        self.client = OpenAI(api_key=openai_api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature

        # ``type`` is rendered into the format instructions; the default
        # ("string") contradicted the "JSON list" wording of the description.
        self.phenotypes_field = ResponseSchema(
            name="phenotypes",
            description="Extract the list of phenotypes present in the patient from the request. Format the output as a JSON list.",
            type="List[string]",
        )

        self.genes_field = ResponseSchema(
            name="genes",
            description="Extract the list of genes mentioned in the patient request. Format the output as a JSON list.",
            type="List[string]",
        )

        # Combine the schemas into a structured output parser and render its
        # format instructions once; they are identical for every request.
        self.request_metadata_output_schema_parser = StructuredOutputParser.from_response_schemas(
            [self.phenotypes_field, self.genes_field]
        )
        self.request_metadata_output_schema = self.request_metadata_output_schema_parser.get_format_instructions()

        # Built without a triple-quoted literal so the method's indentation is
        # not sent to the model on every line.
        self.request_metadata_prompt_template_str = (
            "Given the following patient request, extract the following metadata "
            "according to the format instructions below.\n"
            "<< FORMATTING >>\n"
            "{format_instructions}\n"
            "<< INPUT >>\n"
            "{user_request}\n"
            "<< OUTPUT (remember to include the ```json)>>"
        )
        self.request_metadata_prompt_template = PromptTemplate.from_template(
            template=self.request_metadata_prompt_template_str)

    def extract_metadata(self, user_request):
        """Ask the model for the phenotypes and genes mentioned in ``user_request``.

        Args:
            user_request: Free-text patient request.

        Returns:
            ``(phenotypes, genes)`` as two lists. Both are empty (and a
            diagnostic is printed) when no JSON object can be recovered from
            the reply. Errors raised by the OpenAI client are not caught.
        """
        prompt = self.request_metadata_prompt_template.format(
            format_instructions=self.request_metadata_output_schema,
            user_request=user_request,
        )

        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that extracts gene and phenotype information from patient requests."},
                {"role": "user", "content": prompt},
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        choice = response.choices[0]
        # A "length" stop means the JSON was probably cut off; say so rather
        # than letting it surface only as empty lists.
        if choice.finish_reason == "length":
            print(f"Response truncated at max_tokens={self.max_tokens}; JSON may be incomplete.")
        return self._parse_response(choice.message.content)

    @classmethod
    def _parse_response(cls, content):
        """Parse a model reply into ``(phenotypes, genes)``.

        Accepts a fenced (```json or bare ```) block or, failing that, an
        unfenced JSON object. Returns ``([], [])`` after printing a diagnostic
        when the reply is empty, is not valid JSON, or is not a JSON object.
        """
        if not content:
            print("No JSON found in the response.")
            return [], []

        match = cls._FENCED_JSON_RE.search(content)
        # Models sometimes drop the fence and reply with bare JSON; try that too.
        json_str = match.group(1) if match else content.strip()
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError:
            print("Error decoding JSON from the response." if match else "No JSON found in the response.")
            return [], []

        if not isinstance(data, dict):
            print("Expected a JSON object in the response.")
            return [], []
        return cls._as_list(data.get("phenotypes")), cls._as_list(data.get("genes"))

    @staticmethod
    def _as_list(value):
        """Coerce one schema field to a list so callers always get lists back."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, str):
            # Not split on commas: phenotype names themselves may contain commas.
            return [value] if value.strip() else []
        return [value]



In [6]:
# One extractor (and one OpenAI client) reused for every request below.
extractor = GenePhenotypeExtractor(openai_api_key)

In [7]:
# Example query mixing HPO-style phenotype IDs and an Ensembl-style gene ID.
user_request = (
    "Given a patient with these phenotypes: HP:0001249, HP:0001254, HP:0000712 "
    "and these genes: ENSG00000146085, what could be the disease?"
)

In [8]:
# Run the extraction on the example request and show both lists.
phenotypes, genes = extractor.extract_metadata(user_request)
for label, values in (("Phenotypes:", phenotypes), ("Genes:", genes)):
    print(label, values)

Phenotypes: ['HP:0001249', 'HP:0001254', 'HP:0000712']
Genes: ['ENSG00000146085']
