# Extract a Knowledge Graph with One-Shot Extraction

In [70]:
import json
import os
from pathlib import Path

from dotenv import load_dotenv
import tiktoken

import requests

In [10]:
# Load secrets from the project-level .env file into the process environment.
load_dotenv("../.env")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail here with a clear message instead of a TypeError on None later on.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in ../.env or in the environment"
    )
# Confirm a key was loaded without echoing any part of the secret into the
# notebook output.
bool(OPENAI_API_KEY)

True

In [13]:
# Location of the plain-text papers, relative to the current working directory.
DATA = Path("data")
PAPERS = DATA / "text"
# Sort the matches so that "the first paper" is the same file on every run;
# glob itself gives no ordering guarantee.
papers = sorted(PAPERS.glob("*.txt"))
if not papers:
    # Clearer than the IndexError that papers[0] would raise on an empty list.
    raise FileNotFoundError(f"no .txt papers found in {PAPERS}")
text = papers[0].read_text()

In [20]:
# Model identifier placed in the "model" field of every payload built by get_payload.
MODEL_NAME = "gpt-4-1106-preview"

In [30]:
# Derive the tokenizer from MODEL_NAME so token counts are measured with the
# encoding of the model that is actually called, and stay in sync if the
# model is changed.
tokenizer = tiktoken.encoding_for_model(MODEL_NAME)

In [80]:
def capping_tokens(
    text,
    max_tokens: int=30_000
) -> str:
    """
    Truncate text so that it is at most max_tokens tokens long.

    text: str
        The text to cap, tokenized with the module-level `tokenizer`
    max_tokens: int
        Maximum number of tokens to keep; float values such as 30e3 are
        still accepted and truncated to an int

    Returns the original text when it already fits, otherwise the decoded
    first max_tokens tokens.
    """
    limit = int(max_tokens)
    tokens = tokenizer.encode(text)
    if len(tokens) <= limit:
        # Already within budget: return the input as-is instead of paying
        # for a decode round trip.
        return text
    return tokenizer.decode(tokens[:limit])

## Role prompt

In [66]:
# System prompt passed as the system message by get_graphs: it asks for nodes
# and edges in JSON and embeds one worked example of the expected structure.
ROLE = """
Please extract knowledge from the following piece of text from a paper.
please enhance the knowledge graph with the extracted knowledge if necessary.
The node types include, but not limited to:
disease, symptom, diagnosis, stage, test, treatment, drug, gene, protein
Please use the following format to extract the knowledge, the edge's source and target could only be the node's integer id, return json only.
{
    "nodes": [
        {
            "id": 1,
            "type": "disease",
            "name": "COVID-19"
        },
        {
            "id": 2,
            "type": "symptom",
            "name": "fever"
        }
    ],
    "edges": [
        {
            "source": 1,
            "target": 2,
            "type": "has_symptom"
        }
    ]
}
"""

In [67]:
def get_payload(
    system_role: str,
    user_message: str
) -> dict:
    """
    Build the request body for a chat-completions call.

    system_role: str
        The content of the system message, i.e. the instructions for the model
    user_message: str
        The content of the user message

    Returns a dict holding the model name (MODEL_NAME), a "json_object"
    response format and the two messages, system message first.
    """
    messages = [
        {"role": "system", "content": system_role},
        {"role": "user", "content": user_message},
    ]
    return {
        "model": MODEL_NAME,
        "response_format": {"type": "json_object"},
        "messages": messages,
    }

In [77]:
def get_graphs(
    text: str,
    paper_max_token: int=60_000,
):
    """
    Extract a knowledge graph from a paper with one chat-completions request.

    text: str
        The paper's text
    paper_max_token: int
        Token cap applied to `text` (via capping_tokens) before it is sent.
        NOTE: this argument used to be ignored, so the effective cap was
        capping_tokens' own default; it is now honoured.

    Returns the JSON object parsed from the model's reply, or
    {"nodes": [], "edges": []} when the response contains no usable JSON.
    """
    empty_graph = {"nodes": [], "edges": []}

    # capping for max tokens
    text = capping_tokens(text, paper_max_token)
    payload = get_payload(ROLE, text)

    res = requests.post(
        "https://api.openai.com/v1/chat/completions",
        json=payload,
        headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
    )

    try:
        json_data = res.json()
    except ValueError:
        # The body was not JSON at all (e.g. a gateway error page).
        print(f"json parse error (HTTP {res.status_code})")
        return empty_graph

    try:
        return_text = json_data['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError):
        # No completion in the response (API error, empty choices, ...).
        # Previously this path crashed on an unbound `return_text`.
        print(f"unexpected API response (HTTP {res.status_code}): {json_data}")
        return empty_graph

    try:
        return json.loads(return_text)
    except (json.decoder.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) message content.
        print("json parse error")
        return empty_graph

In [78]:
# Run the extraction on the text of the first paper (one API request).
kg = get_graphs(text)

In [79]:
# Display the extracted graph as the cell's output.
kg

{'nodes': [{'id': 1,
   'type': 'disease',
   'name': 'superficial esophageal squamous cell carcinoma'},
  {'id': 2, 'type': 'disease', 'name': 'lymph node metastasis'},
  {'id': 3, 'type': 'diagnosis', 'name': 'lymphovascular invasion'},
  {'id': 4, 'type': 'test', 'name': 'nomogram'},
  {'id': 5, 'type': 'treatment', 'name': 'esophagectomy'},
  {'id': 6, 'type': 'treatment', 'name': 'endoscopic resection'},
  {'id': 7, 'type': 'test', 'name': 'endoscopic ultrasonography'},
  {'id': 8, 'type': 'test', 'name': 'contrast-enhanced computed tomography'},
  {'id': 9,
   'type': 'test',
   'name': 'fluorodeoxyglucose positron emission tomography'},
  {'id': 10,
   'type': 'symptom',
   'name': 'lymphatic or blood vessel tumor cells'},
  {'id': 11, 'type': 'test', 'name': 'endoscopic mucosal dissection'},
  {'id': 12, 'type': 'test', 'name': 'endoscopic mucosal resection'},
  {'id': 13, 'type': 'treatment', 'name': 'lymph node dissection'},
  {'id': 14, 'type': 'diagnosis', 'name': 'lymph no