In [None]:
import os
from neo4j import GraphDatabase

class Neo4jApp:
    """Query and update the topic/cluster/publication graph stored in Neo4j.

    Public ``get_*``/``update_*`` methods each open a session and run one
    managed transaction; the matching static methods hold the Cypher and turn
    the result records into plain dictionaries.

    The instance can be used as a context manager so the driver is closed
    when the ``with`` block exits.
    """

    def __init__(self):
        """Create the driver from the ``uri``, ``user`` and ``password`` env vars.

        Each variable falls back to the local default given below when unset.
        """
        uri = os.getenv("uri", "neo4j://0.0.0.0:7687")
        user = os.getenv("user", "neo4j")
        password = os.getenv("password", "neo4j-connect")
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def __enter__(self):
        """Return the app itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on leaving the ``with`` block; never suppress errors."""
        self.close()

    def get_clusters_by_topic(self, topic_name):
        """Return the clusters directly attached to a topic.

        Args:
            topic_name: Topic label, compared case-insensitively.

        Returns:
            A list of ``{"ClusterID", "Tag", "Level"}`` dicts ordered by
            publication count, largest first.
        """
        with self.driver.session() as session:
            return session.execute_read(self._retrieve_clusters_of_topic, topic_name)

    def get_subclusters_by_cluster_id(self, cluster_id):
        """Return the direct subclusters of the cluster with ``cluster_id``.

        Returns:
            A list of ``{"ClusterID", "Tag", "Level"}`` dicts (``ClusterID``
            holds the subcluster's id) ordered by publication count,
            largest first.
        """
        with self.driver.session() as session:
            return session.execute_read(self._retrieve_subclusters, cluster_id)

    def get_publications_by_cluster_id(self, cluster_id):
        """Return the publications attached directly to one cluster.

        Returns:
            A list of ``{"Publication ID": <id>}`` dicts.
        """
        with self.driver.session() as session:
            return session.execute_read(self._retrieve_publications, cluster_id)

    def get_publications_by_cluster_and_subclusters(self, cluster_id):
        """Return publications of a cluster and of everything nested below it.

        Returns:
            A list of ``{"PublicationID", "Title", "Abstract"}`` dicts.
        """
        with self.driver.session() as session:
            return session.execute_read(
                self._retrieve_publications_and_subclusters, cluster_id
            )

    def get_all_cluster_and_subcluster_ids_and_tags_by_topic(self, topic_name):
        """Return id and tag of every cluster and nested subcluster of a topic.

        Args:
            topic_name: Topic label, compared case-insensitively.

        Returns:
            A list of ``{"ClusterID", "ClusterTag"}`` dicts ordered by id.
        """
        with self.driver.session() as session:
            return session.execute_read(
                self._retrieve_all_cluster_and_subcluster_ids_and_tags_by_topic,
                topic_name,
            )

    def update_zephyr_cluster_tag(self, cluster_id, new_tag_value):
        """Set ``zephyr_cluster_tag`` on a cluster inside a write transaction.

        Returns:
            The stored tag value, or ``None`` when no cluster has that id.
        """
        with self.driver.session() as session:
            return session.execute_write(
                self.add_zephyr_cluster_tag, cluster_id, new_tag_value
            )

    @staticmethod
    def _retrieve_clusters_of_topic(tx, topic_name):
        """Run the topic -> cluster query and return one dict per cluster."""
        query = """
        MATCH (topic:FieldOfStudy)-[:HAS_CLUSTER]->(cluster)
        WHERE toUpper(topic.label) = toUpper($topic_name)
        OPTIONAL MATCH (cluster)-[:HAS_PUBLICATION]->(pub:Publication)
        WITH cluster, COUNT(DISTINCT pub) AS publicationsCount
        RETURN cluster.id AS ClusterID, cluster.tag AS Tag, cluster.level AS Level, publicationsCount
        ORDER BY publicationsCount DESC
        """
        result = tx.run(query, topic_name=topic_name)
        # publicationsCount is only used for ordering and is not returned.
        return [
            {"ClusterID": record["ClusterID"], "Tag": record["Tag"], "Level": record["Level"]}
            for record in result
        ]

    @staticmethod
    def _retrieve_subclusters(tx, cluster_id):
        """Run the cluster -> subcluster query and return one dict per subcluster."""
        query = """
        MATCH (parent:Cluster {id: $cluster_id})-[:HAS_SUBCLUSTER]->(subcluster)
        OPTIONAL MATCH (subcluster)-[:HAS_PUBLICATION]->(pub:Publication)
        WITH subcluster, COUNT(DISTINCT pub) AS publicationsCount
        RETURN subcluster.id AS SubclusterID, subcluster.tag AS Tag, subcluster.level AS Level, publicationsCount
        ORDER BY publicationsCount DESC
        """
        result = tx.run(query, cluster_id=cluster_id)
        # The subcluster id is exposed under "ClusterID" so callers can treat
        # clusters and subclusters uniformly.
        return [
            {"ClusterID": record["SubclusterID"], "Tag": record["Tag"], "Level": record["Level"]}
            for record in result
        ]

    @staticmethod
    def _retrieve_publications(tx, cluster_id):
        """Return the Neo4j node ids of publications attached to one cluster."""
        query = """
        MATCH (cluster:Cluster {id: $cluster_id})-[:HAS_PUBLICATION]->(pub:Publication)
        RETURN id(pub) AS PublicationID
        """
        result = tx.run(query, cluster_id=cluster_id)
        # NOTE: the output key is "Publication ID" (with a space), unlike the
        # "PublicationID" key used by _retrieve_publications_and_subclusters;
        # kept as-is for existing callers.
        return [{"Publication ID": record["PublicationID"]} for record in result]

    @staticmethod
    def _retrieve_publications_and_subclusters(tx, cluster_id):
        """Return id, title and abstract of publications at or below a cluster."""
        # HAS_SUBCLUSTER*0.. makes the path start at the cluster itself, so
        # its own publications are included along with all descendants'.
        query = """
        MATCH (cluster:Cluster {id: $cluster_id})
        CALL {
            WITH cluster
            MATCH (cluster)-[:HAS_SUBCLUSTER*0..]->(subcluster)-[:HAS_PUBLICATION]->(pub:Publication)
            RETURN pub
        }
        RETURN id(pub) AS PublicationID, pub.publicationTitle AS Title, pub.publicationAbstract AS Abstract
        """
        result = tx.run(query, cluster_id=cluster_id)
        return [
            {
                "PublicationID": record["PublicationID"],
                "Title": record["Title"],
                "Abstract": record["Abstract"],
            }
            for record in result
        ]

    @staticmethod
    def _retrieve_all_cluster_and_subcluster_ids_and_tags_by_topic(tx, topic_name):
        """Return id/tag pairs for all clusters and nested subclusters of a topic.

        Rows whose ``ClusterID`` is null are dropped.
        """
        # The topic label is compared with toUpper on both sides so the lookup
        # is case-insensitive, matching _retrieve_clusters_of_topic.
        query = """
        MATCH (topic:FieldOfStudy)-[:HAS_CLUSTER]->(cluster)
        WHERE toUpper(topic.label) = toUpper($topic_name)
        WITH DISTINCT cluster
        OPTIONAL MATCH (cluster)-[:HAS_SUBCLUSTER*0..]->(subcluster)
        WITH cluster, COLLECT(DISTINCT subcluster) AS subclusters
        UNWIND ([cluster] + subclusters) AS allClusters
        WITH DISTINCT allClusters.id AS ClusterID, allClusters.tag AS ClusterTag
        ORDER BY ClusterID
        RETURN ClusterID, ClusterTag
        """
        result = tx.run(query, topic_name=topic_name)
        return [
            {"ClusterID": record["ClusterID"], "ClusterTag": record["ClusterTag"]}
            for record in result
            if record["ClusterID"] is not None
        ]

    @staticmethod
    def add_zephyr_cluster_tag(tx, cluster_id, new_tag_value):
        """Write ``zephyr_cluster_tag`` on the cluster and return the stored value.

        Args:
            tx: Transaction to run the statement in.
            cluster_id: Value of the cluster's ``id`` property.
            new_tag_value: Tag to store.

        Returns:
            The tag now stored on the node, or ``None`` (after printing a
            message) when no cluster with ``cluster_id`` exists.

        Driver errors are not caught here so the surrounding managed
        transaction can roll back or retry.
        """
        query = """
        MATCH (cluster:Cluster {id: $cluster_id})
        SET cluster.zephyr_cluster_tag = $new_tag_value
        RETURN cluster.id AS ClusterID, cluster.zephyr_cluster_tag AS NewZephyrClusterTag
        """
        result = tx.run(query, cluster_id=cluster_id, new_tag_value=new_tag_value)
        record = result.single()
        if record is None:
            # MATCH found nothing, so nothing was written.
            print(f"An error occurred: no cluster found with id {cluster_id!r}")
            return None
        return record["NewZephyrClusterTag"]


In [None]:
import random

import requests

# Shared Neo4j accessor used by the cells that follow.
n4j = Neo4jApp()

# Address of the local cluster-name generation endpoint.
url = 'http://127.0.0.1:6000/generate_cluster_name'

In [None]:
# For every cluster (and nested subcluster) of each topic: sample up to five
# paper titles, ask the naming service for a cluster name, and store that name
# on the cluster node as its zephyr_cluster_tag.
topic_names = ["Text Style Transfer", "Paraphrasing", "Data-to-Text Generation", "Summarization"]

# Upper bound (seconds) on a single naming request, so an unresponsive
# service cannot hang the whole run.
request_timeout_seconds = 120

for topic in topic_names:
    all_clusters_inside_topic = n4j.get_all_cluster_and_subcluster_ids_and_tags_by_topic(topic)

    print("Working on topic:", topic, "with", len(all_clusters_inside_topic), "clusters.")

    for i, cluster in enumerate(all_clusters_inside_topic):
        cluster_id = cluster['ClusterID']
        cluster_tag = cluster['ClusterTag']

        publications = n4j.get_publications_by_cluster_and_subclusters(cluster_id)

        # random.sample raises if asked for more items than exist, so cap at
        # the number of publications available.
        num_samples = min(len(publications), 5)
        randomly_selected_titles = [publication['Title'] for publication in random.sample(publications, num_samples)]

        # Request payload expected by the naming service.
        data = {
            "tfidf_cluster_name": cluster_tag,
            "paper_titles": randomly_selected_titles,
        }

        # A failure on one cluster is reported and skipped so the remaining
        # clusters are still processed.
        try:
            response = requests.post(url, json=data, timeout=request_timeout_seconds)
        except requests.RequestException as exc:
            print("Error:", exc)
            continue

        if response.status_code != 200:
            print("Error:", response.text)
            continue

        try:
            response_data = response.json()
        except ValueError as exc:
            # Body was not valid JSON.
            print("Error:", exc)
            continue

        new_cluster_tag = response_data.get('cluster_name') if isinstance(response_data, dict) else None
        print("Cluster Name:", new_cluster_tag)

        if not new_cluster_tag:
            # Writing a null value would remove the property instead of
            # setting it, so skip the update.
            print("Error: response did not contain a cluster name for cluster ID", cluster_id)
            continue

        stored_tag = n4j.update_zephyr_cluster_tag(cluster_id, new_cluster_tag)
        if stored_tag is None:
            # update_zephyr_cluster_tag returns None when nothing was written.
            print("Error: could not update cluster tag for cluster ID", cluster_id)
            continue

        print("Updated cluster tag for cluster ID", cluster_id, "to", new_cluster_tag)
        print(f"""{i} from {len(all_clusters_inside_topic) - 1} completed.""")