In [1]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from openai import OpenAI
from typing import List
from sklearn.metrics import silhouette_score
from tqdm import tqdm
import argparse
import json
import os
import random
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import concurrent.futures
import pickle

In [11]:
def construct_args(argv=None):
    """Build the configuration namespace shared by the cells of this notebook.

    Args:
        argv: Optional list of command-line style tokens (for example
            ``["--dataset", "WN18RR"]``) used to override the defaults.
            When omitted, an empty list is parsed, so every option takes its
            default and the Jupyter kernel's own ``sys.argv`` is never read.

    Returns:
        argparse.Namespace with ``output_dir``, ``data_dir``, ``dataset``,
        ``dimensions``, ``num_threads``, ``max_entities`` and the derived
        ``log_dir`` (``<output_dir>/<dataset>/logs``).

    Raises:
        SystemExit: raised by argparse when ``argv`` contains an unknown or
            malformed option.
    """
    parser = argparse.ArgumentParser(description='Cluster entities using hierarchical clustering and refine the clusters using LLM.')
    parser.add_argument('--output_dir', type=str, default="/data/pj20/lamake_data")
    parser.add_argument('--data_dir', type=str, default="/home/pj20/server-03/lamake/data")
    # The value is used as a directory name (see log_dir below), not as a file path.
    parser.add_argument('--dataset', type=str, default="FB15K-237", help='Name of the dataset; used as a sub-directory name under output_dir.')
    parser.add_argument('--dimensions', type=int, default=1024, help='Dimensionality of the embeddings. Default: 1024.')
    parser.add_argument('--num_threads', type=int, default=10, help='Number of threads to use for multi-threaded processes. Default: 10.')
    parser.add_argument('--max_entities', type=int, default=100, help='Maximum number of entities to include in an LLM request. Default: 100.')

    # Parse an explicit token list: inside a notebook sys.argv belongs to the kernel.
    args = parser.parse_args(args=[] if argv is None else list(argv))
    args.log_dir = f"{args.output_dir}/{args.dataset}/logs"

    return args

In [12]:
# Parse the configuration once; `args` is handed to the project helpers in later cells.
args = construct_args()

In [23]:
from cluster import  read_entities, create_entity_info_emb_dict, generate_embeddings, build_hierarchy

In [2]:
import pickle

# Derive the location from the shared configuration instead of repeating the
# absolute path; with the default args this resolves to
# /data/pj20/lamake_data/FB15K-237/clustering/clustering_0.52.pkl.
clustering_file = f"{args.output_dir}/{args.dataset}/clustering/clustering_0.52.pkl"

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only load clustering files that were produced by this project.
with open(clustering_file, "rb") as f:
    clustering = pickle.load(f)

In [15]:
# Read the entity list for the configured dataset. The path is built from
# `args` rather than hard-coded; with the defaults it resolves to
# /home/pj20/server-03/lamake/data/FB15K-237/entities.dict.
entities = read_entities(f"{args.data_dir}/{args.dataset}/entities.dict")
entity_info, entity_embeddings = create_entity_info_emb_dict(args, entities)

# Parallel lists aligned with `entities`: position i describes entities[i].
entities_text = [entity_info[entity]["text_label"] for entity in entities]
original_descriptions = [entity_info[entity]["original_description"] for entity in entities]

print("Start Generating Embeddings...")
# generate_embeddings returns refreshed copies of entity_info / entity_embeddings,
# so both names are rebound here.
embeddings, entity_info, entity_embeddings = generate_embeddings(args, entity_info=entity_info, entity_embeddings=entity_embeddings, dim=args.dimensions)

Start Generating Embeddings...
Loading existing entity embeddings from /data/pj20/lamake_data/FB15K-237/entity_embeddings.json...
Done.
All entities have valid embeddings. Skipping embedding generation.
Loading existing entity info from /data/pj20/lamake_data/FB15K-237/entity_info.json...
Done.


In [150]:
# Group the entity text labels by flat cluster label. Keys are 1-based
# ("Cluster_1", "Cluster_2", ...) while clustering.labels_ is 0-based.
clusters = {
    f"Cluster_{cluster_idx+1}": [
        entities_text[position]
        for position in np.where(clustering.labels_ == cluster_idx)[0]
    ]
    for cluster_idx in range(clustering.n_clusters_)
}

In [54]:
# Re-key the clusters by their 0-based position: {0: [...], 1: [...], ...}.
clusters_ = dict(enumerate(clusters.values()))

In [21]:
# Sanity check: number of entities read from the entities file.
len(entities)

14541

In [18]:
# Inspect the merge tree of the loaded clustering: each row holds the two node
# ids joined by one merge. Presumably this is scikit-learn's
# AgglomerativeClustering.children_, where ids below n_samples are leaves and
# larger ids refer to earlier merges — confirm against how the pickle was made.
clustering.children_

array([[ 1422, 12516],
       [ 1275,  4592],
       [12324, 13070],
       ...,
       [29074, 29076],
       [ 5241, 29077],
       [29078, 29079]])

In [151]:
# Build the nested hierarchy from the merge tree with the project helper
# `build_hierarchy` (imported from cluster.py; its implementation is not shown
# here). Later cells treat the result as nested dicts whose non-dict leaves are
# entity text labels.
initial_hier = build_hierarchy(clustering.children_, len(entities), entity_labels=entities_text, clustering=clustering)

In [152]:
# `defaultdict` is otherwise only imported in a later cell, so without this
# import the cell raises NameError on a fresh "Restart & Run All".
from collections import defaultdict

# Clusters keyed by their 0-based position: {0: [...], 1: [...], ...}.
clusters_ = {int(i): entities for i, entities in enumerate(clusters.values())}

# Reverse lookup: entity text label -> index of the cluster that contains it.
# If the same label occurs in several clusters, the last one wins.
entity2clusterid = {}

for i, cluster in enumerate(clusters_.values()):
    for entity in cluster:
        entity2clusterid[entity] = i

# Per-cluster counter consumed by label_ when numbering the leaves.
clusterid2count = defaultdict(int)

In [153]:
# Inspect the entity-label -> cluster-index mapping (one entry per distinct label).
entity2clusterid

{'comedy-drama': 0,
 'romance film': 0,
 'sketch comedy': 0,
 'television comedy': 0,
 'comedian': 0,
 'romantic comedy': 0,
 'black comedy': 0,
 'comedy film': 0,
 'Comedy': 0,
 'comedy of manners': 0,
 'stand-up comedy': 0,
 'sex comedy': 0,
 'screwball comedy film': 0,
 'situation comedy': 0,
 'Paddington': 1,
 'Hammersmith': 1,
 'Kensington': 1,
 'Wandsworth': 1,
 'Chiswick': 1,
 'Hampstead': 1,
 'Marylebone': 1,
 'Ealing': 1,
 'Kannada': 2,
 'Mangalore': 2,
 'Mysore': 2,
 'Karnataka': 2,
 'Metropolis Records': 3,
 'Tooth & Nail Records': 3,
 'Southern Lord Records': 3,
 'Revelation Records': 3,
 'Roadrunner Records': 3,
 'Century Media Records': 3,
 'Nuclear Blast': 3,
 'Equal Vision Records': 3,
 'The End Records': 3,
 'Relapse Records': 3,
 'Cleopatra Records': 3,
 'Chrysalis Records': 3,
 'Victory Records': 3,
 'Frontiers Records': 3,
 'Candlelight Records': 3,
 'Napalm Records': 3,
 'Metal Blade Records': 3,
 'Season of Mist': 3,
 'Earache Records': 3,
 'Spinefarm Records': 3,

In [28]:
# Persist the unrefined hierarchy so it can be inspected outside the notebook.
with open('./initial_hier.json', 'w') as out_file:
    json.dump(initial_hier, out_file, indent=4)

In [113]:
# Reverse lookup: entity text label -> index of the cluster that contains it.
# Later clusters overwrite earlier ones when a label is repeated.
entity2clusterid = {
    member: cluster_pos
    for cluster_pos, members in enumerate(clusters_.values())
    for member in members
}

In [154]:
from collections import defaultdict

# Running tally of how many leaves of each cluster have been labelled so far.
# It is re-created whenever this cell runs, so the numbering restarts at zero.
clusterid2count = defaultdict(int)


def label_(d, leaf_keys=None, leaf_values=None):
    """Replace every leaf of a nested dict with ``[cluster_id, ordinal]``.

    A leaf is any value that is not a dict. Its cluster id is looked up in the
    module-level ``entity2clusterid`` mapping, and ``ordinal`` is the number of
    leaves of that same cluster labelled before it, tracked in the
    module-level ``clusterid2count`` counter.

    Args:
        d: Nested dict to label. It is modified in place.
        leaf_keys: Unused accumulator; only forwarded to recursive calls.
        leaf_values: Unused accumulator; only forwarded to recursive calls.

    Returns:
        The same ``d`` object, after in-place labelling.

    Raises:
        KeyError: if a leaf value has no entry in ``entity2clusterid``.
    """
    leaf_keys = [] if leaf_keys is None else leaf_keys
    leaf_values = [] if leaf_values is None else leaf_values

    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for node_key in d:
        child = d[node_key]
        if isinstance(child, dict):
            label_(child, leaf_keys, leaf_values)
            continue
        cid = entity2clusterid[child]
        seen_so_far = clusterid2count[cid]
        d[node_key] = [cid, seen_so_far]
        clusterid2count[cid] = seen_so_far + 1
    return d


In [155]:
# label_ rewrites its argument in place and returns it, so after this line
# `hierarchy` and `initial_hier` refer to the same (now numeric) dict.
hierarchy = label_(initial_hier)

In [132]:
# Save the hierarchy whose leaves are now [cluster_id, ordinal] pairs.
with open('./initial_hier_numeric.json', 'w') as numeric_file:
    json.dump(hierarchy, numeric_file, indent=4)

In [135]:
def refine_1(d, clusters_, leaf_keys=None, leaf_values=None):
    """Collapse the labelled leaves of each cluster into one leaf, in place.

    Leaves are expected to be ``[cluster_id, ordinal]`` pairs. A leaf with
    ordinal 0 is replaced by ``clusters_[cluster_id]``; every leaf with a
    positive ordinal is removed. Dict values are processed recursively.

    Args:
        d: Nested dict with ``[cluster_id, ordinal]`` leaves. Modified in place.
        clusters_: Mapping from cluster id to the replacement value for that
            cluster's first leaf.
        leaf_keys: Unused accumulator; only forwarded to recursive calls.
        leaf_values: Unused accumulator; only forwarded to recursive calls.

    Returns:
        The same ``d`` object after refinement.
    """
    leaf_keys = [] if leaf_keys is None else leaf_keys
    leaf_values = [] if leaf_values is None else leaf_values

    # Classify the entries first and mutate afterwards, so this level is left
    # untouched if a malformed leaf raises part-way through.
    redundant_keys = []
    replacements = {}
    for key, value in tuple(d.items()):
        if isinstance(value, dict):
            refine_1(value, clusters_, leaf_keys, leaf_values)
        elif value[1] > 0:
            redundant_keys.append(key)
        else:
            replacements[key] = clusters_[value[0]]

    for key in redundant_keys:
        d.pop(key)

    # Every key in `replacements` already exists, so key order is preserved.
    d.update(replacements)

    return d


In [143]:
# Step 1: keep one leaf per cluster (replaced by the cluster's member list)
# and drop the remaining leaves. refine_1 works in place.
hierarchy = refine_1(hierarchy, clusters_)
with open('./refined_hier.json', 'w') as refined_file:
    json.dump(hierarchy, refined_file, indent=4)

In [145]:
def refine_2(d):
    """Return a copy of ``d`` with empty dicts dropped and trivial levels collapsed.

    For every dict-valued entry, at any depth:
      * an empty dict is removed;
      * a dict that, after its own refinement, holds exactly one entry whose
        value is a list is replaced by that list;
      * any other dict is kept in its refined form. A dict that becomes empty
        only because its children were removed is kept as ``{}``.
    Non-dict values are carried over unchanged and are not copied.

    The input is never modified. The previous implementation shallow-copied
    only the top level, so nested dicts of the caller's object were rewritten
    as a side effect.

    Args:
        d: Nested dict to refine.

    Returns:
        A new nested dict. The top level itself is never collapsed.
    """
    def process_dict(sub_dict):
        refined = {}
        for key, value in sub_dict.items():
            if not isinstance(value, dict):
                refined[key] = value
                continue
            if not value:
                # Empty dicts carry no information; drop them.
                continue
            result = process_dict(value)
            if len(result) == 1:
                (only_child,) = result.values()
                if isinstance(only_child, list):
                    # A single list child makes the intermediate level redundant.
                    refined[key] = only_child
                    continue
            refined[key] = result
        return refined

    return process_dict(d)

In [146]:
# Step 2: drop empty dicts and collapse single-list levels, then overwrite the
# JSON written by the previous step.
hierarchy = refine_2(hierarchy)
with open('./refined_hier.json', 'w') as refined_file:
    json.dump(hierarchy, refined_file, indent=4)

In [148]:
def refine_3(d):
    """Return a copy of ``d`` without dict entries that end up empty.

    A dict-valued entry is pruned recursively and kept only if something
    remains in it. Non-dict values are always kept as they are, even when
    falsy. The input dict is not modified.

    Args:
        d: Nested dict to prune.

    Returns:
        A new nested dict containing no empty dicts below the top level.
    """
    def prune(node):
        kept = {}
        for name, child in node.items():
            if not isinstance(child, dict):
                kept[name] = child
                continue
            pruned_child = prune(child)
            if pruned_child:
                kept[name] = pruned_child
        return kept

    return prune(d)

In [149]:
# Step 3: remove dicts that became empty during the earlier steps and
# overwrite the saved JSON once more.
hierarchy = refine_3(hierarchy)
with open('./refined_hier.json', 'w') as refined_file:
    json.dump(hierarchy, refined_file, indent=4)

In [1]:
from utils import refine_4
import json

# Load the seed clusters from disk.
with open('/data/pj20/lamake_data/FB15K-237/seed_clusters.json', 'r') as seed_file:
    seed_clusters = json.load(seed_file)

# refine_4 is a project helper from utils.py (not shown in this notebook).
hierarchy = refine_4(seed_clusters)

# Note: this overwrites the refined_hier.json produced by the earlier cells.
with open('./refined_hier.json', 'w') as refined_file:
    json.dump(hierarchy, refined_file, indent=4)

In [5]:
from openai import OpenAI

# Read the API key from a local file so it is never embedded in the notebook.
with open('./openai_api.key', 'r') as key_file:
    api_key = key_file.read().strip()
client = OpenAI(api_key=api_key)

def gpt_chat_return_response(model, prompt, seed=44, max_tokens=200, temperature=0):
    """Send a single-turn chat request and return the raw API response.

    Uses the module-level ``client``. Log-probabilities are always requested.

    Args:
        model: Name of the chat model to query.
        prompt: Text sent as the only user message.
        seed: Sampling seed forwarded to the API.
        max_tokens: Upper bound on generated tokens (previously fixed at 200).
        temperature: Sampling temperature (previously fixed at 0).

    Returns:
        The response object returned by ``client.chat.completions.create``.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt}
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        seed=seed,
        logprobs=True
    )
    return response