In [4]:
import ollama

from trying_graph_rag.graph_rag.types import Entity, Relationship

In [1]:
# Load the extraction prompt template once; it is filled in later via PROMPT.format(...).
# The encoding is pinned so the template decodes identically on every platform
# instead of depending on the locale's default encoding.
with open("../trying_graph_rag/graph_rag/prompts/entities_and_relationships_extraction.txt", encoding="utf-8") as file:
    PROMPT = file.read()

In [52]:
def parse_output(output: str, tuple_delimiter: str = r'<>', record_delimiter: str = '\n', completion_delimiter: str = '###END###') -> "tuple[list[Entity], list[Relationship]]":
    """Parse raw extraction output into entities and relationships.

    The text is split into records on ``record_delimiter`` and each record into
    fields on ``tuple_delimiter``. Surrounding whitespace and quote characters
    are removed from every field. The first field selects the record kind:

    * ``entity`` with 4 fields       -> ``Entity(name, type, description)``
    * ``relationship`` with 5 fields -> ``Relationship(source_entity,
      target_entity, description, strength)`` where ``strength`` is parsed
      with ``int``.

    Args:
        output: Raw text to parse.
        tuple_delimiter: Separator between the fields of one record.
        record_delimiter: Separator between records.
        completion_delimiter: End-of-output marker; every occurrence is removed
            before parsing.

    Returns:
        A ``(entities, relationships)`` tuple of lists, in order of appearance.

    Raises:
        ValueError: If a record's first field is neither ``entity`` nor
            ``relationship``, or if a relationship strength is not an integer.

    Notes:
        Records that are empty, or that contain any empty field, are skipped
        silently. Records of a known kind with an unexpected number of fields
        are reported with ``print`` and skipped.
    """
    output = output.replace(completion_delimiter, '')
    records = output.strip().split(record_delimiter)

    entities: list[Entity] = []
    relationships: list[Relationship] = []

    for record in records:
        # A leading "(" can only be the record wrapper, because the first field
        # must be the record type, so dropping all of them is safe.
        record = record.strip().lstrip('(')
        # Drop only *unbalanced* trailing ")" characters (the wrapper). A plain
        # rstrip(')') would also eat a parenthesis that closes text inside the
        # last field, e.g. "A fox (Vulpes vulpes)".
        while record.endswith(')') and record.count(')') > record.count('('):
            record = record[:-1]

        # skip empty records
        if not record:
            continue

        record_content = [record_field.strip().strip("'\"") for record_field in record.split(tuple_delimiter)]

        # skip records with any empty field (split always yields >= 1 field)
        if not all(record_content):
            continue

        record_type = record_content[0]
        if record_type not in ['entity', 'relationship']:
            raise ValueError(f"Invalid record type: {record_content}")

        if record_type == 'entity' and len(record_content) == 4:
            entity = Entity(
                name=record_content[1],
                type=record_content[2],
                description=record_content[3]
            )
            entities.append(entity)
        elif record_type == 'relationship' and len(record_content) == 5:
            try:
                relationship_strength = int(record_content[4])
            except ValueError as err:
                # Chain explicitly so the offending value stays visible in the traceback.
                raise ValueError("Invalid relationship strength") from err

            relationship = Relationship(
                source_entity=record_content[1],
                target_entity=record_content[2],
                description=record_content[3],
                strength=relationship_strength
            )
            relationships.append(relationship)
        else:
            # Known record type but wrong number of fields: report and move on.
            print(f"Invalid record format: {record}")

    return entities, relationships

In [53]:
def extract_entities_and_relationships(document: str, entity_types: list[str], tuple_delimiter: str = r'<>', record_delimiter: str = '\n', completion_delimiter: str = '###END###', model: str = 'gemma2:2b') -> "tuple[list[Entity], list[Relationship]]":
    """Ask an Ollama model to extract entities and relationships from a document.

    The module-level ``PROMPT`` template is filled in with the document, the
    entity types and the three delimiters, sent to ``ollama.generate``, and the
    ``'response'`` text is handed to ``parse_output`` with the same delimiters.

    Args:
        document: Text to extract from (substituted as ``input_text``).
        entity_types: Entity type names to look for.
        tuple_delimiter: Field separator the model is told to use.
        record_delimiter: Record separator the model is told to use.
        completion_delimiter: End-of-output marker the model is told to use.
        model: Name of the Ollama model to query.

    Returns:
        The ``(entities, relationships)`` tuple produced by ``parse_output``.

    Raises:
        ValueError: Propagated from ``parse_output`` when the model's output
            contains an unknown record type or a non-integer strength.
    """
    prompt = PROMPT.format(
        input_text=document,
        # str(list)[1:-1] drops the surrounding brackets, leaving the quoted,
        # comma-separated items, e.g. "'animal', 'color'".
        entity_types=str(entity_types)[1:-1],
        tuple_delimiter=tuple_delimiter,
        record_delimiter=record_delimiter,
        completion_delimiter=completion_delimiter,
    )
    # temperature 0 — presumably chosen to make the output as repeatable as possible.
    ollama_response = ollama.generate(model=model, prompt=prompt, options={"temperature": 0})

    content = ollama_response['response']
    return parse_output(content, tuple_delimiter, record_delimiter, completion_delimiter)

In [54]:
# Smoke test: run the extraction on a single-sentence document.
# The call is the cell's last expression so its result is displayed.
document = "The quick brown fox jumps over the lazy dog."
entity_types = ['animal', 'color']
extract_entities_and_relationships(document=document, entity_types=entity_types)

"entity"<>"fox"<>"animal"<>"A small, agile animal known for its reddish-brown fur and swiftness."
['entity', 'fox', 'animal', 'A small, agile animal known for its reddish-brown fur and swiftness.']
"entity"<>"dog"<>"animal"<>"A domesticated canine with a thick coat of fur and a wagging tail."
['entity', 'dog', 'animal', 'A domesticated canine with a thick coat of fur and a wagging tail.']
"""


([Entity(entity_name='fox', entity_type='animal', entity_description='A small, agile animal known for its reddish-brown fur and swiftness.'),
  Entity(entity_name='dog', entity_type='animal', entity_description='A domesticated canine with a thick coat of fur and a wagging tail.')],
 [])