In [None]:
import os
import re

import numpy as np
import stanza
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from stanza.server import CoreNLPClient
from transformers import T5Tokenizer, T5ForConditionalGeneration

class KnowledgeGraphPipeline:
    """Build a Neo4j knowledge graph from free text and answer queries over it.

    Stages: OpenIE triplet extraction (Stanford CoreNLP via stanza), rule-based
    normalization, storage in Neo4j with sentence-transformer embeddings on
    nodes and relationships, and T5-based summarization of the relations that
    are relevant to a user query.
    """

    # Lower-cased subject strings that are replaced by a fixed entity name.
    # NOTE(review): tuned to the bundled sample text only -- every "it" becomes
    # "Tesla Inc.", which also captures sentences about Apple or GPT-4. Pass
    # `pronoun_map` to normalize_triplets() for other inputs. Treat as read-only.
    DEFAULT_PRONOUN_MAP = {
        "it": "Tesla Inc.",
        "the company": "Tesla Inc.",
        "he": "Jeff Bezos",
        "she": "Tim Cook",
    }

    # Matches "<something> in <4-digit year>" inside a triplet object.
    _YEAR_SUFFIX = re.compile(r"(.+?) in (\d{4})")
    # Prepositions normalize_triplets() tries to strip from relation names.
    _PREPOSITIONS = re.compile(r"\b(in|on|to|by|as|from)\b")

    def __init__(self, neo4j_url=None, neo4j_user=None, neo4j_password=None,
                 embedding_model_name='all-MiniLM-L6-v2',
                 summarization_model_name='t5-large'):
        """Load the embedding and summarization models and open the Neo4j driver.

        Connection settings use the explicit arguments first, then the
        NEO4J_URL / NEO4J_USER / NEO4J_PASSWORD environment variables, then the
        historical hard-coded defaults.
        """
        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.neo4j_url = neo4j_url or os.environ.get("NEO4J_URL", "bolt://localhost:7687")
        self.neo4j_user = neo4j_user or os.environ.get("NEO4J_USER", "neo4j")
        # NOTE(review): the literal fallback keeps existing setups working but puts
        # a credential in source control; set NEO4J_PASSWORD and remove it.
        self.neo4j_password = neo4j_password or os.environ.get("NEO4J_PASSWORD", "Author123")
        self.driver = GraphDatabase.driver(
            self.neo4j_url,
            auth=(self.neo4j_user, self.neo4j_password)
        )
        self.summarization_model_name = summarization_model_name
        self.tokenizer = T5Tokenizer.from_pretrained(self.summarization_model_name)
        self.summarization_model = T5ForConditionalGeneration.from_pretrained(self.summarization_model_name)

    def extract_triplets(self, text, endpoint='http://localhost:9500', memory='6G', timeout=30000):
        """Run CoreNLP OpenIE over `text`.

        The CoreNLP client is used as a context manager, so the server it
        launches only lives for the duration of this call.

        Returns:
            list of (subject, relation, object, confidence) tuples, in sentence order.
        """
        with CoreNLPClient(
            annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner', 'depparse', 'openie'],
            be_quiet=True,
            endpoint=endpoint,
            memory=memory,
            timeout=timeout
        ) as client:
            ann = client.annotate(text)
            return [
                (triple.subject, triple.relation, triple.object, triple.confidence)
                for sentence in ann.sentence
                for triple in sentence.openieTriple
            ]

    def normalize_triplets(self, raw_triplets, confidence_threshold=0.85, pronoun_map=None):
        """Clean raw OpenIE triplets into a de-duplicated list of (subj, rel, obj).

        Args:
            raw_triplets: iterable of (subj, rel, obj, conf) or (subj, rel, obj)
                tuples; 3-tuples are treated as confidence 1.0.
            confidence_threshold: triplets with a lower confidence are dropped.
            pronoun_map: mapping from a lower-cased subject (e.g. "it") to the
                entity name that replaces it; defaults to DEFAULT_PRONOUN_MAP.

        Returns:
            list of (subject, relation_type, object) tuples in arbitrary order.
        """
        if pronoun_map is None:
            pronoun_map = self.DEFAULT_PRONOUN_MAP

        cleaned = set()
        for triplet in raw_triplets:
            if len(triplet) == 4:
                subj, rel, obj, conf = triplet
            else:
                subj, rel, obj = triplet
                conf = 1.0

            if conf < confidence_threshold:
                continue

            # Only upper-case the first character: str.title() mangled names such as
            # "SpaceX" -> "Spacex" and "Tesla 's" -> "Tesla 'S", and since objects are
            # not title-cased the same entity ended up as two different nodes.
            subj = subj.strip()
            subj = subj[:1].upper() + subj[1:]
            rel = rel.strip().lower()
            obj = obj.strip()

            subj = pronoun_map.get(subj.lower(), subj)

            if rel in ["was", "is", "are"] or obj.lower() in ["awarded", "born", "founded"]:
                continue

            rel = rel.replace(" ", "_")
            # NOTE(review): "_" is a word character, so \b rarely fires inside an
            # underscored relation -- in practice this only strips prepositions
            # delimited by non-word characters (e.g. a relation that is just "in").
            # The canonical checks below rely on "born_in", "founded_by", ... surviving.
            rel = self._PREPOSITIONS.sub("", rel).strip("_")
            if not rel:
                # A bare preposition would otherwise be stored as an empty relation type.
                continue

            # NOTE(review): canonical relations are added *in addition to* the raw
            # relation added further down (e.g. both headquartered_in and is_in);
            # confirm whether the raw copy is wanted.
            if "was_awarded_to" in rel:
                cleaned.add((subj, "awarded_to", obj))
            elif "was_awarded_in" in rel:
                cleaned.add((subj, "award_year", obj))
            elif "was_awarded" in rel:
                match = self._YEAR_SUFFIX.search(obj)
                if match:
                    cleaned.add((subj, "awarded_to", match.group(1).strip()))
                    cleaned.add((subj, "award_year", match.group(2).strip()))
                else:
                    cleaned.add((subj, "awarded_to", obj))
            elif "born_in" in rel:
                cleaned.add((subj, "born_in", obj))
            elif "born_on" in rel:
                cleaned.add((subj, "birth_date", obj))
            elif "founded_by" in rel:
                cleaned.add((subj, "founded_by", obj))
            elif "founded_in" in rel:
                cleaned.add((subj, "founded_in", obj))
            elif "headquartered_in" in rel or "located_in" in rel or "based_in" in rel or "is_in" in rel:
                cleaned.add((subj, "headquartered_in", obj))

            # Split a trailing "... in <year>" out of the object into a *_year relation.
            year_match = self._YEAR_SUFFIX.search(obj)
            if year_match:
                cleaned.add((subj, rel, year_match.group(1).strip()))
                cleaned.add((subj, f"{rel}_year", year_match.group(2).strip()))
            else:
                cleaned.add((subj, rel, obj))

            # Hand-written rules for specific facts in the sample text.
            if subj == "Tim Cook" and "succeeded" in rel and "Steve Jobs" in obj:
                if "as" in obj:
                    cleaned.add((subj, "succeeded_steve_jobs_as", "CEO"))
                elif "in" in obj:
                    year = re.findall(r"\d{4}", obj)
                    if year:
                        cleaned.add((subj, "succeeded_steve_jobs_in", year[0]))
                    cleaned.add((subj, "succeeded", "Steve Jobs"))
                else:
                    cleaned.add((subj, "succeeded", "Steve Jobs"))

            if subj == "Jeff Bezos" and "stepped" in rel:
                # `rel` is underscored at this point, so match the underscored forms;
                # the old "down as"/"down in" checks could never fire.
                if "down_as" in rel:
                    cleaned.add((subj, "stepped_down_as", obj))
                elif "down_in" in rel:
                    cleaned.add((subj, "stepped_down_in", obj))
                elif rel.endswith("_as"):
                    cleaned.add((subj, "stepped_as", obj))

        return list(cleaned)

    def store_graph_embeddings(self, relations):
        """Replace the graph contents with `relations` and attach embeddings.

        WARNING: deletes every node and relationship in the target database first.

        Args:
            relations: iterable of (subject, relation_type, object) triples.
                Subjects and objects become :Entity nodes, and each triple becomes
                a :RELATIONSHIP edge whose `type` property holds the relation.
        """
        rows = [list(rel) for rel in relations]
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")

            session.run("""UNWIND $relations AS rel
                        MERGE (source:Entity {name: rel[0]})
                        MERGE (target:Entity {name: rel[2]})
                        MERGE (source)-[r:RELATIONSHIP {type: rel[1]}]->(target)
                        """, relations=rows)

            # Encode all entity names in one batch and write them back in a single
            # query, instead of one model call plus one round trip per node.
            names = [record['name'] for record in
                     session.run("MATCH (e:Entity) RETURN e.name as name").data()]
            if names:
                vectors = self.embedding_model.encode(names)
                session.run("""
                    UNWIND $rows AS row
                    MATCH (e:Entity {name: row.name})
                    SET e.embedding = row.embedding
                    """, rows=[{'name': name, 'embedding': vector.tolist()}
                               for name, vector in zip(names, vectors)])

            rels = session.run("""
                MATCH (s)-[r:RELATIONSHIP]->(t)
                RETURN s.name AS source, r.type AS type, t.name AS target
                """).data()
            if rels:
                # Relationships are embedded as the sentence "<source> <type> <target>".
                texts = [f"{rel['source']} {rel['type']} {rel['target']}" for rel in rels]
                vectors = self.embedding_model.encode(texts)
                session.run("""
                    UNWIND $rows AS row
                    MATCH (s:Entity {name: row.source})-[r:RELATIONSHIP {type: row.type}]->(t:Entity {name: row.target})
                    SET r.embedding = row.embedding
                    """, rows=[dict(rel, embedding=vector.tolist())
                               for rel, vector in zip(rels, vectors)])

    def create_vector_indexes(self, dimensions=None, similarity_function='cosine'):
        """Create vector indexes on entity and relationship embeddings if missing.

        Args:
            dimensions: embedding size. Defaults to the loaded embedding model's
                output size (384 for all-MiniLM-L6-v2), falling back to 384.
            similarity_function: 'cosine' or 'euclidean'.

        Because of IF NOT EXISTS, an already existing index keeps its old settings.
        """
        if dimensions is None:
            dimensions = self.embedding_model.get_sentence_embedding_dimension() or 384
        dimensions = int(dimensions)
        if similarity_function not in ('cosine', 'euclidean'):
            raise ValueError(f"Unsupported similarity function: {similarity_function!r}")

        # Values are formatted in (not passed as parameters) because they sit in
        # the index OPTIONS map; both are validated above.
        index_template = """
                CREATE VECTOR INDEX {name} IF NOT EXISTS
                FOR {target} ON {prop}
                OPTIONS {{indexConfig: {{
                  `vector.dimensions`: {dimensions},
                  `vector.similarity_function`: '{similarity}'
                }}}}
            """
        with self.driver.session() as session:
            session.run(index_template.format(
                name='entity_embeddings', target='(e:Entity)', prop='e.embedding',
                dimensions=dimensions, similarity=similarity_function))
            session.run(index_template.format(
                name='relationship_embeddings', target='()-[r:RELATIONSHIP]-()', prop='r.embedding',
                dimensions=dimensions, similarity=similarity_function))

    def semantic_graph_search(self, query, top_k=10, display_k=5):
        """Rank stored entities and relationships by cosine similarity to `query`.

        Prints the `display_k` best entities and relationships. Similarities are
        computed client-side with scikit-learn over all stored embeddings.

        Returns:
            up to `top_k` (triple_text, score, rel_record) tuples for
            relationships, best first, where rel_record is the Neo4j row with
            'source', 'type', 'target' and 'embedding'. Returns [] when no
            relationship has an embedding.
        """
        query_embedding = self.embedding_model.encode(query).reshape(1, -1)

        with self.driver.session() as session:
            nodes = session.run("MATCH (e:Entity) RETURN e.name as name, e.embedding as embedding").data()
            # Rows without an embedding would make cosine_similarity fail.
            nodes = [n for n in nodes if n['embedding'] is not None]
            if nodes:
                names = [n['name'] for n in nodes]
                similarities = cosine_similarity(query_embedding, [n['embedding'] for n in nodes])[0]
                print("Top matching entities:")
                for idx in np.argsort(similarities)[::-1][:display_k]:
                    print(f"- {names[idx]} (score: {similarities[idx]:.3f})")

            rels = session.run("""
                MATCH (s)-[r:RELATIONSHIP]->(t)
                RETURN s.name as source, r.type as type, t.name as target, r.embedding as embedding
                """).data()
            rels = [r for r in rels if r['embedding'] is not None]
            if rels:
                triples = [f"{r['source']} {r['type']} {r['target']}" for r in rels]
                similarities = cosine_similarity(query_embedding, [r['embedding'] for r in rels])[0]
                ranking = np.argsort(similarities)[::-1]

                print("\nTop matching relationships:")
                for idx in ranking[:display_k]:
                    print(f"- {triples[idx]} (score: {similarities[idx]:.3f})")
                return [(triples[idx], similarities[idx], rels[idx]) for idx in ranking[:top_k]]

        return []

    @staticmethod
    def _facts_from_relations(relations_data):
        """Turn relation records into de-duplicated fact sentences grouped by source entity.

        Accepts (text, score, rel_dict) tuples -- falling back to splitting `text`
        on whitespace when the third item is not a dict -- or dicts with
        'source', 'type' and 'target' keys. Returns (facts, seen) where `seen`
        holds the lower-cased fact strings.
        """
        if relations_data and isinstance(relations_data[0], tuple) and len(relations_data[0]) >= 3:
            records = []
            for rel in relations_data:
                if isinstance(rel[2], dict):
                    records.append((rel[2]['source'], rel[2]['type'], rel[2]['target']))
                else:
                    # NOTE(review): this fallback assumes single-word source/type.
                    rel_parts = rel[0].split()
                    if len(rel_parts) >= 3:
                        records.append((rel_parts[0], rel_parts[1], ' '.join(rel_parts[2:])))
        else:
            records = [(rel['source'], rel['type'], rel['target']) for rel in relations_data]

        entity_context = {}
        seen_content = set()
        for source, rel_type, target in records:
            key = f"{source} {rel_type.replace('_', ' ')} {target}".lower()
            if key in seen_content:
                continue
            seen_content.add(key)
            entity_context.setdefault(source, []).append((rel_type, target))

        facts = [
            f"{entity} {rel_type.replace('_', ' ')} {target}"
            for entity, pairs in entity_context.items()
            for rel_type, target in pairs
        ]
        return facts, seen_content

    @staticmethod
    def _clean_summary(summary):
        """Strip prompt echoes and obvious repetition from raw T5 output."""
        summary = summary.replace("Summary:", "").strip()

        if summary:
            summary = summary[0].upper() + summary[1:]

        # T5 sometimes repeats the instruction; drop it up to its first period.
        instruction_patterns = [
            "focus on key details",
            "avoid repetition",
            "present the information",
            "coherent narrative",
            "write a natural",
            "summarize:",
            "create a coherent"
        ]
        for pattern in instruction_patterns:
            if summary.lower().startswith(pattern):
                first_period = summary.find('.')
                if first_period > 0:
                    summary = summary[first_period + 1:].strip()

        # Collapse output that is the same text generated twice.
        if len(summary) > 50:
            half = len(summary) // 2
            first_half = summary[:half]
            if first_half.strip() == summary[half:].strip():
                summary = first_half

        return summary

    def generate_summary(self, num_key_relations=10, query=None, relations=None):
        """Summarize graph relations into a paragraph with the T5 model.

        Relation source, in priority order:
          * `relations` if non-empty -- (text, score, rel_dict) tuples such as
            those from retrieve_query_relevant_relations(), or dicts with
            'source', 'type' and 'target' keys;
          * otherwise, with a `query`, retrieve_query_relevant_relations(query);
          * otherwise the first `num_key_relations` relations in the graph.

        With a query and fewer than 5 facts, relations whose endpoint names or
        type contain the query string are added. Returns the cleaned summary, or
        a fixed message when nothing relevant is found.
        """
        if query and not relations:
            relations_data = self.retrieve_query_relevant_relations(query)
            if not relations_data:
                return "No relevant information found for this query."
        elif not relations:
            with self.driver.session() as session:
                relations_data = session.run("""
                    MATCH (s)-[r:RELATIONSHIP]->(t)
                    RETURN s.name as source, r.type as type, t.name as target
                    LIMIT $limit
                    """, limit=num_key_relations).data()
        else:
            relations_data = relations

        # Safe on an empty graph: the old code indexed relations_data[0] unconditionally.
        facts, seen_content = self._facts_from_relations(relations_data)

        if len(facts) < 5 and query:
            with self.driver.session() as session:
                additional_rels = session.run("""
                    MATCH (s)-[r:RELATIONSHIP]->(t)
                    WHERE toLower(s.name) CONTAINS toLower($query) OR 
                          toLower(t.name) CONTAINS toLower($query) OR
                          toLower(r.type) CONTAINS toLower($query)
                    RETURN s.name as source, r.type as type, t.name as target
                    LIMIT 10
                """, query=query).data()

            for rel in additional_rels:
                fact = f"{rel['source']} {rel['type'].replace('_', ' ')} {rel['target']}"
                if fact.lower() not in seen_content:
                    facts.append(fact)
                    seen_content.add(fact.lower())

        if not facts:
            topic = query if query else "this topic"
            return f"I don't have enough information about {topic} in my knowledge graph."

        facts_text = "\n".join([f"- {fact}" for fact in facts])

        if query:
            input_text = f"""
            summarize: Create a coherent paragraph about {query} based on these facts:
            
            {facts_text}
            
            Write a natural-sounding paragraph that summarizes the key information. 
            """
        else:
            input_text = f"""
            summarize: Create a coherent paragraph based on these facts:
            
            {facts_text}
            
            Write a natural-sounding paragraph that summarizes the key information. 
            """

        inputs = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)

        outputs = self.summarization_model.generate(
            inputs,
            max_length=200,
            min_length=50,
            num_beams=8,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3
        )

        summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._clean_summary(summary)

    def retrieve_query_relevant_relations(self, query, threshold=0.5, max_relations=15,
                                          fallback_threshold=0.4, min_results=5,
                                          num_anchor_entities=2, per_entity_limit=5):
        """
        Retrieve relations from the knowledge graph that are relevant to the query.

        Semantic matches scoring above `threshold` are kept; if fewer than
        `min_results` survive, `fallback_threshold` is used instead. Relations
        touching the `num_anchor_entities` entities most similar to the query
        (at most `per_entity_limit` each) are then appended with a fixed score
        of 0.51. Returns at most `max_relations` (text, score, rel_record) tuples.
        """
        relevant_relations = self.semantic_graph_search(query)

        filtered_relations = [rel for rel in relevant_relations if rel[1] > threshold]
        if len(filtered_relations) < min_results and relevant_relations:
            filtered_relations = [rel for rel in relevant_relations if rel[1] > fallback_threshold]
        seen_texts = {rel[0] for rel in filtered_relations}

        query_embedding = self.embedding_model.encode(query).reshape(1, -1)

        with self.driver.session() as session:
            entities = session.run("MATCH (e:Entity) RETURN e.name as name, e.embedding as embedding").data()
            entities = [e for e in entities if e['embedding'] is not None]

            if entities:
                names = [e['name'] for e in entities]
                similarities = cosine_similarity(query_embedding, [e['embedding'] for e in entities])[0]
                top_entities = [names[idx] for idx in np.argsort(similarities)[::-1][:num_anchor_entities]]

                for entity in top_entities:
                    # One query covering both directions so LIMIT bounds the whole
                    # result; with UNION it only applied to the second branch.
                    entity_rels = session.run("""
                        MATCH (s:Entity)-[r:RELATIONSHIP]->(t:Entity)
                        WHERE s.name = $name OR t.name = $name
                        RETURN DISTINCT s.name as source, r.type as type, t.name as target
                        LIMIT $limit
                    """, name=entity, limit=per_entity_limit).data()

                    for rel in entity_rels:
                        rel_text = f"{rel['source']} {rel['type']} {rel['target']}"
                        if rel_text not in seen_texts:
                            # 0.51: just above the default threshold, below real matches.
                            filtered_relations.append((rel_text, 0.51, rel))
                            seen_texts.add(rel_text)

        return filtered_relations[:max_relations]

    def run_pipeline(self, text):
        """Extract, normalize and store triplets from `text`, then ensure vector indexes exist.

        Replaces the current graph contents (store_graph_embeddings clears the
        database). Returns None; summaries are produced per query by
        query_knowledge_graph().
        """
        print("Extracting triplets...")
        raw_triplets = self.extract_triplets(text)

        print("Normalizing triplets...")
        normalized_triplets = self.normalize_triplets(raw_triplets)

        print("Storing in knowledge graph...")
        self.store_graph_embeddings(normalized_triplets)
        self.create_vector_indexes()

    def fix_entity_references(self, text, main_entity=None):
        """Replace sentence-initial references to the main entity with its name.

        `text` is split naively on '.'. Segments after the first one that do not
        already mention `main_entity` and start with "it ", "its ", "this " or
        "the company" are rewritten to start with `main_entity`. The sentences
        are re-joined with '. ' and a trailing '.' is ensured. Returns `text`
        unchanged when `main_entity` is falsy.
        """
        if not main_entity:
            return text

        fixed_sentences = []
        for i, sentence in enumerate(text.split('.')):
            sentence = sentence.strip()
            if not sentence:
                continue

            lower_sentence = sentence.lower()
            if main_entity.lower() in lower_sentence:
                fixed_sentences.append(sentence)
                continue

            # `i` counts raw split segments, so the first segment is never rewritten.
            if i > 0:
                if lower_sentence.startswith('it '):
                    sentence = main_entity + sentence[2:]
                elif lower_sentence.startswith('its '):
                    sentence = main_entity + "'s" + sentence[3:]
                elif lower_sentence.startswith('this '):
                    sentence = main_entity + sentence[4:]
                elif lower_sentence.startswith('the company'):
                    # Case-insensitive replace that keeps the casing of the rest of
                    # the sentence (the old code lower-cased the whole sentence).
                    sentence = re.sub('the company', lambda _match: main_entity,
                                      sentence, flags=re.IGNORECASE)
                    sentence = sentence[0].upper() + sentence[1:]

            fixed_sentences.append(sentence)

        fixed_text = '. '.join(fixed_sentences)
        if not fixed_text.endswith('.'):
            fixed_text += '.'

        return fixed_text

    def query_knowledge_graph(self, user_query):
        """Answer `user_query` with a summary of matching graph relations.

        Entities whose full name occurs in the query, or that share any
        whitespace-separated word with it (case-insensitive), contribute all
        their outgoing and incoming relations with score 1.0. These are merged
        with retrieve_query_relevant_relations(), de-duplicated by relation
        text, ranked by score and summarized via generate_summary().
        """
        print(f"Processing query: {user_query}")
        lowered_query = user_query.lower()
        query_words = {word.lower() for word in user_query.split()}

        with self.driver.session() as session:
            entity_names = [e['name'] for e in session.run("MATCH (e:Entity) RETURN e.name as name").data()]

        # NOTE(review): word overlap also fires on stop-words ("the", "of"), so
        # short queries can pull in unrelated entities.
        mentioned_entities = [
            entity for entity in entity_names
            if entity.lower() in lowered_query
            or any(part.lower() in query_words for part in entity.split())
        ]

        entity_relations = []
        if mentioned_entities:
            print(f"Detected entities in query: {', '.join(mentioned_entities)}")
            with self.driver.session() as session:
                for entity in mentioned_entities:
                    subj_rels = session.run("""
                        MATCH (s:Entity {name: $name})-[r:RELATIONSHIP]->(t)
                        RETURN s.name AS source, r.type AS type, t.name AS target
                    """, name=entity).data()

                    obj_rels = session.run("""
                        MATCH (s:Entity)-[r:RELATIONSHIP]->(t:Entity {name: $name})
                        RETURN s.name AS source, r.type AS type, t.name AS target
                    """, name=entity).data()

                    entity_relations.extend(
                        (f"{rel['source']} {rel['type']} {rel['target']}", 1.0, rel)
                        for rel in subj_rels + obj_rels
                    )

        semantic_relations = self.retrieve_query_relevant_relations(user_query)
        if not entity_relations and not semantic_relations:
            return "I couldn't find any relevant information about that query in the knowledge graph."

        all_relations = []
        seen_relation_texts = set()
        for rel_text, score, rel_dict in entity_relations + semantic_relations:
            if rel_text not in seen_relation_texts:
                all_relations.append((rel_text, score, rel_dict))
                seen_relation_texts.add(rel_text)

        all_relations.sort(key=lambda item: item[1], reverse=True)

        # Print the merged ranking once (it used to be printed twice under two headers).
        print(f"\nFound {len(all_relations)} relevant relationships:")
        for rel_text, score, _ in all_relations[:10]:
            print(f"- {rel_text} (score: {score:.3f})")

        summary = self.generate_summary(query=user_query, relations=all_relations)

        main_entity = mentioned_entities[0] if mentioned_entities else None
        summary = self.fix_entity_references(summary, main_entity)

        print(f"\nQuery-specific summary: {summary}")
        return summary


if __name__ == "__main__":

    text = """
    In 2022, Tesla Inc., led by CEO Elon Musk, became the world's most valuable car manufacturer, surpassing Toyota in market capitalization. The company, headquartered in Palo Alto, California, is known for its electric vehicles such as the Model S, Model 3, and Model X. Tesla's Gigafactory in Nevada plays a critical role in battery production and employs over 7,000 people.

    Meanwhile, SpaceX, another company founded by Elon Musk, launched the Crew Dragon spacecraft to the International Space Station (ISS) under a partnership with NASA. The launch was conducted from the Kennedy Space Center in Florida. The ISS is jointly operated by NASA, Roscosmos, ESA, JAXA, and CSA, and orbits the Earth approximately every 90 minutes.

    In the tech industry, Google, a subsidiary of Alphabet Inc., acquired Fitbit in 2021 to expand its presence in the wearable technology market. Apple Inc., on the other hand, continues to dominate the smartphone industry with the iPhone, which it manufactures in collaboration with Foxconn. Tim Cook succeeded Steve Jobs as CEO of Apple in 2011.

    Across the globe, Amazon operates hundreds of fulfillment centers to ensure quick delivery of products. It uses machine learning algorithms to optimize its supply chain and predict consumer demand. Jeff Bezos founded Amazon in 1994 and stepped down as CEO in 2021, appointing Andy Jassy as his successor.

    In the field of AI research, OpenAI developed GPT-4, a state-of-the-art language model capable of understanding and generating human-like text. It has been integrated into several applications, including Microsoft's Copilot and ChatGPT. Microsoft invested heavily in OpenAI, incorporating its technology into the Azure cloud platform.

    In climate science, the Intergovernmental Panel on Climate Change (IPCC) published a report highlighting the urgent need to reduce greenhouse gas emissions. The Paris Agreement, signed by 196 countries in 2015, set out goals to limit global warming to below 2 degrees Celsius.
    """

    # Build the graph from the sample text, then answer queries interactively.
    pipeline = KnowledgeGraphPipeline()
    try:
        pipeline.run_pipeline(text)

        while True:
            try:
                user_query = input("\nEnter your query (or 'exit' to quit): ").strip()
            except EOFError:
                # stdin closed (e.g. piped input exhausted): stop as if "exit" was typed.
                break
            if user_query.lower() == 'exit':
                break
            if not user_query:
                continue

            summary = pipeline.query_knowledge_graph(user_query)
            print(f"\nSummary: {summary}")
    finally:
        # Release the Neo4j connection pool even on errors or Ctrl-C.
        pipeline.driver.close()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
2025-05-10 08:50:42 INFO: Writing properties to tmp file: corenlp_server-80300291fff34cf4.props
2025-05-10 08:50:42 INFO: Starting server with command: java -Xmx6G -cp /Users/abhinavkrishna/stanza_corenlp/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9500 -timeout 30000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-80300291fff34cf4.props -annotators tokenize,ssplit,pos,lemma,ner,depparse,openie -preload -outputFormat serialized


Extracting triplets...
Normalizing triplets...
Storing in knowledge graph...
Processing query: Tesla
Detected entities in query: Model X. Tesla 'S Gigafactory, Model X. Tesla, Tesla Inc.


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Top matching entities:
- Tesla Inc. (score: 0.802)
- Model X. Tesla (score: 0.729)
- Model X. Tesla 'S Gigafactory (score: 0.574)
- its electric vehicles (score: 0.561)
- world 's car manufacturer (score: 0.534)

Top matching relationships:
- Tesla Inc. became world 's car manufacturer (score: 0.684)
- Tesla Inc. became_in 2022 (score: 0.678)
- Tesla Inc. became world 's valuable car manufacturer (score: 0.678)
- Tesla Inc. became world 's most valuable car manufacturer (score: 0.670)
- Tesla Inc. has has integrated into several applications (score: 0.621)


  ret = a @ b
  ret = a @ b
  ret = a @ b



Top matching relationships:
- Model X. Tesla 'S Gigafactory headquartered_in Nevada (score: 1.000)
- Model X. Tesla 'S Gigafactory is_in Nevada (score: 1.000)
- Model X. Tesla has Gigafactory in Nevada (score: 1.000)
- Tesla Inc. became world 's most valuable car manufacturer (score: 1.000)
- Tesla Inc. has has integrated into several applications including Microsoft 's Copilot (score: 1.000)
- Tesla Inc. manufactures_in collaboration with Foxconn (score: 1.000)
- Tesla Inc. has has integrated into applications including Microsoft 's Copilot (score: 1.000)
- Tesla Inc. has has integrated into applications (score: 1.000)
- Tesla Inc. became world 's valuable car manufacturer (score: 1.000)
- Tesla Inc. has has integrated (score: 1.000)

Found 15 relevant relationships:
- Model X. Tesla 'S Gigafactory headquartered_in Nevada (score: 1.000)
- Model X. Tesla 'S Gigafactory is_in Nevada (score: 1.000)
- Model X. Tesla has Gigafactory in Nevada (score: 1.000)
- Tesla Inc. became world 's mo