# Extract graph

In [1]:
from dotenv import load_dotenv
import json
from time import sleep
load_dotenv("../.env")

True

In [2]:
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from langchain.chains import LLMChain
from langchain.llms import BaseLLM
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate


In [3]:
# Chat model instance; it is handed to KGChain.from_llm below to build the extractor.
chat = ChatOpenAI(model="gpt-3.5-turbo-16k")

Build a chain for knowledge extraction

In [4]:
class KGChain(LLMChain):
    """LLMChain preconfigured with the knowledge-graph extraction prompt."""

    @classmethod
    def from_llm(cls, llm: BaseLLM):
        """Build the chain around ``llm``.

        The prompt takes two inputs: ``paper_text`` (the text to mine) and
        ``knowledge_graph`` (the graph built so far, which the model is asked
        to enhance). The chain is created with ``verbose=True``.
        """
        # Prompt wording is part of the runtime behaviour: keep it verbatim.
        instructions = """
    Please extract knowledge from the following piece of text from a paper.:
    {paper_text}, current knowledge graph structure is:{knowledge_graph},
    please enhance the knowledge graph with the extracted knowledge if necessary.
    The node types include, but not limited to:
    * disease
    * symptom
    * diagnosis
    * stage
    * test
    * treatment
    * drug
    * gene
    * protein
    The edge types include, but not limited to:
    * has_symptom
    * associate_with
    * diagnose
    * has_stage
    * cure
    if it's references and background research ignore it.
    Please return only the json structure between '###'
    """
        extraction_prompt = PromptTemplate(
            input_variables=["paper_text", "knowledge_graph"],
            template=instructions,
        )
        return cls(llm=llm, prompt=extraction_prompt, verbose=True)

In [5]:
# Seed graph handed to the model with the first chunk of every paper. Despite
# the name it is not empty: it carries one disease node, one symptom node and
# the edge between them, which also shows the model the expected JSON layout.
EMPTY_KG_DATA = {
    "nodes": [
        {"ntype": "disease", "name": "Esophageal Carcinoma", "id": 1},
        {"ntype": "symptom", "name": "Dysphagia", "id": 2},
    ],
    "edges": [
        {"etype": "has_symptom", "source": 1, "target": 2},
    ],
}

# Same '###' delimiters the prompt asks the model to use around its answer.
EMPTY_KG = "\n".join(["###", json.dumps(EMPTY_KG_DATA), "###"])

In [6]:
# Shared extraction chain; extract_one_paper calls extractor.run once per text chunk.
extractor = KGChain.from_llm(llm=chat)

## Data
To get more data, please refer to the [crawler](crawler.ipynb) notebook.

In [7]:
from pathlib import Path

In [8]:
# Paths are relative to the notebook's working directory.
DATA = Path("./data/")
# Folder that is globbed for the *.txt paper files.
TEXT = DATA / "text"

In [9]:
# Every *.txt file directly under TEXT is treated as one paper.
paper_list = list(TEXT.glob("*.txt"))
paper_list

[PosixPath('data/text/PMC9713002.txt'),
 PosixPath('data/text/PMC9709273.txt'),
 PosixPath('data/text/PMC9713855.txt'),
 PosixPath('data/text/PMC9708733.txt'),
 PosixPath('data/text/PMC9722938.txt'),
 PosixPath('data/text/PMC9712805.txt'),
 PosixPath('data/text/PMC9713848.txt'),
 PosixPath('data/text/PMC9712015.txt'),
 PosixPath('data/text/PMC9708886.txt'),
 PosixPath('data/text/PMC9714501.txt'),
 PosixPath('data/text/PMC9709130.txt'),
 PosixPath('data/text/PMC9711964.txt'),
 PosixPath('data/text/PMC9713810.txt')]

In [10]:
from tqdm.auto import tqdm
from typing import Iterator

In [32]:
def get_json_string(text, json_end: str = "###") -> str:
    """Pull the JSON payload out of a model reply delimited by ``json_end``.

    Args:
        text: Raw reply from the model.
        json_end: Delimiter the model was asked to put around the JSON.

    Returns:
        The stripped payload. When the delimiter is absent the whole reply is
        returned stripped. When the segment right after the first delimiter is
        empty (e.g. the model only emitted a closing ``###``), the first
        non-empty segment is returned instead of an empty string.
    """
    if json_end not in text:
        return text.strip()

    segments = [segment.strip() for segment in text.split(json_end)]
    # Usual shape "### {json} ###": the payload follows the first delimiter.
    if segments[1]:
        return segments[1]
    # Replies such as '{json}\n###' leave segments[1] empty; do not throw the
    # payload away, take the first segment that actually has content.
    return next((segment for segment in segments if segment), "")


def text_slicer(text_path, char_size: int = 3000) -> Iterator[str]:
    """Yield consecutive, non-overlapping slices of a text file.

    Args:
        text_path: Path of the text file to read.
        char_size: Maximum number of characters per slice; the last slice may
            be shorter.

    Yields:
        Successive ``char_size``-character pieces of the file, in order,
        wrapped in a transient tqdm progress bar.
    """
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale default.
    with open(text_path, encoding="utf-8") as f:
        text = f.read()
    # A range already has a length, so tqdm can show a total without it being
    # materialised into a list first.
    for start in tqdm(range(0, len(text), char_size), leave=False):
        yield text[start:start + char_size]


def extract_one_paper(text_path: Path, char_size: int = 2000, max_chunks: int = 5) -> str:
    """Build a knowledge graph for one paper by feeding it to the chain chunk by chunk.

    The graph starts from ``EMPTY_KG``; after each chunk the model's reply is
    parsed and, when it is valid JSON, becomes the graph passed with the next
    chunk.

    Args:
        text_path: Text file of the paper.
        char_size: Number of characters per chunk sent to the model.
        max_chunks: Only the first ``max_chunks`` chunks are sent to the model.
            Pass ``None`` to process every chunk.

    Returns:
        The knowledge graph as a JSON string, or ``EMPTY_KG`` unchanged when no
        reply could be parsed.
    """
    kg = EMPTY_KG
    for i, text in enumerate(text_slicer(text_path, char_size), start=1):
        # Stop before calling the model: chunks beyond the cap used to be sent
        # to the API anyway and their replies discarded.
        if max_chunks is not None and i > max_chunks:
            break
        res = extractor.run(
            paper_text=text,
            knowledge_graph=kg,
        )
        # Fixed pause between consecutive API calls.
        sleep(2)

        try:
            kg = json.dumps(json.loads(get_json_string(res)))
        except json.JSONDecodeError:
            # Keep the previous graph and carry on with the next chunk.
            print(f"🥊 Error at {text_path} -({i})")
    return kg
    

In [33]:
from tqdm.auto import tqdm
from openai.error import RateLimitError

In [25]:
# Accumulators for the extraction loop. Re-running this cell resets both.
# results: one dict(kg=..., file_path=...) per processed paper.
results = []
# done: str(path) of papers already processed, so the loop can skip them.
done = []

In [35]:
for paper in tqdm(paper_list):
    # Papers recorded in `done` are skipped, so this cell can be re-run to pick
    # up the ones that were dropped by a rate-limit error.
    if str(paper) in done:
        continue
    try:
        kg_string = extract_one_paper(paper)
    except RateLimitError:
        # The paper is not retried here: back off, then go on to the next one.
        print(f"rate limit error:{paper}")
        sleep(20.1)
    else:
        row = {"kg": kg_string, "file_path": str(paper)}
        results.append(row)
        done.append(str(paper))

  0%|          | 0/13 [00:00<?, ?it/s]

In [37]:
import pandas as pd

In [38]:
# One row per entry in results (columns: kg, file_path).
kg_df = pd.DataFrame(results)
kg_df

Unnamed: 0,kg,file_path
0,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9713002.txt
1,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Barre...",data/text/PMC9709273.txt
2,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9713855.txt
3,"{""nodes"": [{""ntype"": ""treatment"", ""name"": ""Tis...",data/text/PMC9708733.txt
4,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9722938.txt
5,"{""nodes"": [{""ntype"": ""gene"", ""name"": ""IL-15"", ...",data/text/PMC9712805.txt
6,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9713848.txt
7,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9712015.txt
8,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9708886.txt
9,"{""nodes"": [{""ntype"": ""disease"", ""name"": ""Esoph...",data/text/PMC9714501.txt


In [26]:
# kg_string is whatever the most recent extract_one_paper call in the loop cell assigned.
json.loads(kg_string)

{'nodes': [{'ntype': 'diagnosis',
   'name': 'Lymphovascular Invasion (LVI)',
   'id': 3},
  {'ntype': 'diagnosis', 'name': 'Lymph Node Metastasis (LNM)', 'id': 4},
  {'ntype': 'diagnosis',
   'name': 'Superficial Esophageal Squamous Cell Carcinoma (SESCC)',
   'id': 5},
  {'ntype': 'stage', 'name': 'Tumor Size', 'id': 6},
  {'ntype': 'stage', 'name': 'Circumferential Extension', 'id': 7},
  {'ntype': 'stage', 'name': 'Location within Esophagus', 'id': 8},
  {'ntype': 'stage', 'name': 'Depth of Invasion', 'id': 9},
  {'ntype': 'stage', 'name': 'Tumor Differentiation', 'id': 10},
  {'ntype': 'stage', 'name': 'Macroscopic Type', 'id': 11},
  {'ntype': 'stage', 'name': 'Multiple Lesions', 'id': 12},
  {'ntype': 'test', 'name': 'SPSS', 'id': 13},
  {'ntype': 'test', 'name': 'R', 'id': 14},
  {'ntype': 'treatment', 'name': 'Esophagectomy', 'id': 15},
  {'ntype': 'treatment', 'name': 'Lymph Node Dissection', 'id': 16}],
 'edges': [{'etype': 'associate_with', 'source': 3, 'target': 5},
  {'et

In [34]:
from jinja2 import Template

In [39]:
# Read vis-kg.html (relative to the working directory) and compile it as a Jinja2 template.
with open("vis-kg.html") as f:
    nodes_template = Template(f.read())