# Imports

In [None]:
import torch 
import random
import json
from collections import defaultdict
import numpy as np
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score
from sentence_transformers import SentenceTransformer, util
from mistralai import Mistral 
from dotenv import load_dotenv
import os
load_dotenv()

### LLMs

In [None]:
api_key = os.getenv("MISTRAL_API_KEY")
mistral = "mistral-large-latest"
client = Mistral(api_key=api_key)

### Embedding Model

In [None]:
inst_emb_model = SentenceTransformer("hkunlp/instructor-xl") 
minilm_emb_model = SentenceTransformer("all-MiniLM-L6-v2") 

### Reading Reverb45K Data

In [78]:
def load_dataset(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = [json.loads(line) for line in file]
    return data

In [67]:
dataset_path = "C:/projetsAlternance/projetMaster/reverb45k/reverb45k_test.txt"
dataset = load_dataset(dataset_path)

In [5]:
def extract_data(data):
    clusters = defaultdict(set)
    for item in data: 
        subject = item["triple_norm"][0]
        object = item["triple_norm"][2]
        subj_id = item["true_link"]["subject"]
        obj_id = item["true_link"]["object"]
        clusters[subj_id].add(subject)
        clusters[obj_id].add(object)
    return clusters

In [68]:
clusters = extract_data(dataset)

In [70]:
total_mentions = sum(len(mentions) for mentions in clusters.values())
print(f"Nombre total d'entités (mentions) dans les clusters : {total_mentions}")

Nombre total d'entités (mentions) dans les clusters : 26968


In [47]:
# Filtrer les clusters avec plus d'une mention
filtered_clusters = {eid: mentions for eid, mentions in clusters.items() if len(mentions) > 1}

# Affichage du nombre de clusters filtrés
print(f"Nombre de clusters avec plus d'une mention : {len(filtered_clusters)}")


Nombre de clusters avec plus d'une mention : 6785


In [48]:
total_mentions = sum(len(mentions) for mentions in filtered_clusters.values())
print(f"Nombre total d'entités (mentions) dans les clusters : {total_mentions}")

Nombre total d'entités (mentions) dans les clusters : 16616


# Evaluating Functions

In [None]:
def compute_nmi(true_labels, predicted_labels):
    nmi_score = normalized_mutual_info_score(true_labels, predicted_labels)
    return nmi_score

In [None]:
def compute_ari(true_labels, predicted_labels):
    ari_score = adjusted_rand_score(true_labels, predicted_labels)
    return ari_score

In [None]:
def get_cluster_mapping(clusters):
    """Crée un dictionnaire {élément : ID_cluster} pour un accès rapide."""
    element_to_cluster = {}
    for cluster_id, cluster in enumerate(clusters):
        for element in cluster:
            element_to_cluster[element] = cluster_id
    return element_to_cluster

In [5]:
def macro_precision_recall(gold_clusters, pred_clusters):
    """Calcule Macro Precision et Recall."""
    pred_to_gold = defaultdict(set)
    gold_to_pred = defaultdict(set)
    
    gold_map = get_cluster_mapping(gold_clusters)
    pred_map = get_cluster_mapping(pred_clusters)
    
    # Construction des mappings entre clusters pred et gold
    for element in gold_map:
        if element in pred_map:
            pred_to_gold[pred_map[element]].add(gold_map[element])
            gold_to_pred[gold_map[element]].add(pred_map[element])
    
    # Macro Precision : % de clusters pred où tous les éléments sont dans le même cluster gold
    macro_prec = sum(
        1 for pred_cluster in pred_to_gold if len(pred_to_gold[pred_cluster]) == 1
    ) / len(pred_to_gold) if pred_to_gold else 0
    
    # Macro Recall : % de clusters gold où tous les éléments sont dans le même cluster pred
    macro_rec = sum(
        1 for gold_cluster in gold_to_pred if len(gold_to_pred[gold_cluster]) == 1
    ) / len(gold_to_pred) if gold_to_pred else 0
    
    return macro_prec, macro_rec

In [6]:
def micro_precision_recall(gold_clusters, pred_clusters):
    """Calcule Micro Precision et Recall."""
    gold_map = get_cluster_mapping(gold_clusters)
    pred_map = get_cluster_mapping(pred_clusters)
    
    # Micro Precision
    pred_cluster_majority = {}
    for cluster in pred_clusters:
        gold_counts = defaultdict(int)
        for element in cluster:
            if element in gold_map:
                gold_counts[gold_map[element]] += 1
        majority_gold = max(gold_counts.items(), key=lambda x: x[1])[0] if gold_counts else None
        pred_cluster_majority[tuple(cluster)] = majority_gold  # Convertir en tuple
    
    micro_prec = sum(
        1 for cluster in pred_clusters
        for element in cluster
        if element in gold_map and gold_map[element] == pred_cluster_majority[tuple(cluster)]
    ) / sum(len(cluster) for cluster in pred_clusters) if pred_clusters else 0
    
    # Micro Recall
    gold_cluster_majority = {}
    for cluster in gold_clusters:
        pred_counts = defaultdict(int)
        for element in cluster:
            if element in pred_map:
                pred_counts[pred_map[element]] += 1
        majority_pred = max(pred_counts.items(), key=lambda x: x[1])[0] if pred_counts else None
        gold_cluster_majority[tuple(cluster)] = majority_pred  # Convertir en tuple
    
    micro_rec = sum(
        1 for cluster in gold_clusters
        for element in cluster
        if element in pred_map and pred_map[element] == gold_cluster_majority[tuple(cluster)]
    ) / sum(len(cluster) for cluster in gold_clusters) if gold_clusters else 0
    
    return micro_prec, micro_rec

In [7]:
def pairwise_precision_recall(gold_clusters, pred_clusters):
    """Version corrigée pour gérer les sets et les lists."""
    gold_pairs = set()
    pred_pairs = set()
    
    # Convertir les clusters gold (sets) en lists pour l'indexation
    gold_clusters = [list(cluster) if isinstance(cluster, set) else cluster for cluster in gold_clusters]
    
    # Génération des paires gold
    for cluster in gold_clusters:
        for i in range(len(cluster)):
            for j in range(i + 1, len(cluster)):
                gold_pairs.add((cluster[i], cluster[j]))
    
    # Génération des paires pred
    for cluster in pred_clusters:
        for i in range(len(cluster)):
            for j in range(i + 1, len(cluster)):
                pred_pairs.add((cluster[i], cluster[j]))
    
    # Calcul des métriques
    tp = len(pred_pairs & gold_pairs)  # True positives
    pairwise_prec = tp / len(pred_pairs) if pred_pairs else 0
    pairwise_rec = tp / len(gold_pairs) if gold_pairs else 0
    
    return pairwise_prec, pairwise_rec

In [8]:
def compute_f1(precision, recall):
    """Calcule le F1-score harmonique."""
    return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

In [None]:
def evaluate_canonicalization(gold_clusters, pred_clusters):
    """Calcule toutes les métriques demandées."""
    # Macro
    macro_prec, macro_rec = macro_precision_recall(gold_clusters, pred_clusters)
    macro_f1 = compute_f1(macro_prec, macro_rec)
    # Micro
    micro_prec, micro_rec = micro_precision_recall(gold_clusters, pred_clusters)
    micro_f1 = compute_f1(micro_prec, micro_rec)
    
    # Pairwise
    pairwise_prec, pairwise_rec = pairwise_precision_recall(gold_clusters, pred_clusters)
    pairwise_f1 = compute_f1(pairwise_prec, pairwise_rec)
    
    return {
        "Macro": {"Precision": round(macro_prec, 4), "Recall": round(macro_rec, 4), "F1": round(macro_f1, 4)},
        "Micro": {"Precision": round(micro_prec, 4), "Recall": round(micro_rec, 4), "F1": round(micro_f1, 4)},
        "Pairwise": {"Precision": round(pairwise_prec, 4), "Recall": round(pairwise_rec, 4), "F1": round(pairwise_f1, 4)},
    }

# KMeans with Raw Data 

In [None]:
#mentions_flat = [mention for mentions in filtered_clusters.values() for mention in mentions]
mentions_flat = [mention for mentions in clusters.values() for mention in mentions]
print(mentions_flat)

In [72]:
embeddings = minilm_emb_model.encode(mentions_flat, convert_to_tensor=True)
print(f"Nombre total de mentions encodées : {len(mentions_flat)}")

Nombre total de mentions encodées : 26968


In [73]:
embeddings = embeddings.cpu().numpy()

In [None]:
from sklearn.cluster import KMeans

num_clusters = 17137

kmeans = KMeans(n_clusters=num_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(embeddings)

In [74]:
# Formatage des clusters prédits
pred_clusters = defaultdict(list)
for mention, label in zip(mentions_flat, cluster_labels):
    pred_clusters[label].append(mention)
pred_clusters = list(pred_clusters.values())

In [75]:
gold_clusters = list(clusters.values()) 

In [77]:
results = evaluate_canonicalization(gold_clusters, pred_clusters)
print(results)

{'Macro': {'Precision': 0.756702681072429, 'Recall': 0.7087011349306431, 'F1': 0.7319157259243368}, 'Micro': {'Precision': 0.7243045579208567, 'Recall': 0.23264609908039158, 'F1': 0.3521741245770732}, 'Pairwise': {'Precision': 0.13391799004503438, 'Recall': 0.03728880675818374, 'F1': 0.058334623922358164}}


# Implementing Context View

## TransE Training with my Data (reverb45k_test)

In [None]:
import json
import csv
import random
from pathlib import Path

data = "C:/projetsAlternance/projetMasterGit/Clustering_and_LLMs/entity_canonicalization/data/reverb45k/reverb45k_test.txt"      
output_dir = Path("C:/projetsAlternance/projetMasterGit/Clustering_and_LLMs/entity_canonicalization/scripts/training_transe_data")
split_ratio = (0.9, 0.1)

with open(data, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

triplets = []
for item in data:
    if "triple_norm" in item and len(item["triple_norm"]) == 3:
        h, r, t = item["triple_norm"]
        triplets.append((h.strip(), r.strip(), t.strip()))


random.shuffle(triplets)
n = len(triplets)
train_triples = triplets

n_test = int(0.15 * n)
n_valid = int(0.1 * n)
test_triples = random.sample(train_triples, n_test) 
valid_triples = random.sample(train_triples, n_valid)
# 🔹 Fonction pour écrire les TSV
def write_triples(filename, triples):
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerows(triples)

write_triples(output_dir / 'train.tsv', train_triples)
write_triples(output_dir / 'valid.tsv', valid_triples)
write_triples(output_dir / 'test.tsv', test_triples)

print("✅ Triplets préparés et fichiers générés.")

✅ Triplets préparés et fichiers générés.


In [None]:
from pykeen.pipeline import pipeline

result = pipeline(
    model='TransE',
    training='training_transe_data/train.tsv',  
    validation='training_transe_data/valid.tsv',  
    testing='training_transe_data/test.tsv',  
    training_loop='slcwa', 
    epochs=50,  
    random_seed=42,
)
model = result.model

INFO:pykeen.pipeline.api:Using device: None
Training epochs on cpu: 100%|██████████| 50/50 [02:26<00:00,  2.94s/epoch, loss=0.0115, prev_loss=0.0128]
Evaluating on cpu: 100%|██████████| 5.52k/5.52k [00:33<00:00, 166triple/s]
INFO:pykeen.evaluation.evaluator:Evaluation took 33.42s seconds


In [132]:
from pykeen.triples import TriplesFactory
import pandas as pd


df = pd.read_csv("training_transe_data/train.tsv", sep='\t', header=None, names=['subject', 'predicate', 'object'])

relation_counts = df['predicate'].value_counts()
frequent_relations = relation_counts[relation_counts > 10].index
df_filtered = df[df['predicate'].isin(frequent_relations)]

print(len(df_filtered))
# Assuming your dataframe has columns 'subject', 'predicate', 'object'
triples_factory = TriplesFactory.from_labeled_triples(
    triples=df_filtered[['subject', 'predicate', 'object']].values,
)

# 🔹 Split : 80% train, 10% valid, 10% test (tu peux ajuster si besoin)
training, testing, validation = triples_factory.split([0.75, 0.1, 0.15])

d=training
id_to_entity={v: k for k, v in d.entity_to_id.items()}
id_to_relation={v: k for k, v in d.relation_to_id.items()}

# Display the first few triples
triples_factory.triples

INFO:pykeen.triples.splitting:done splitting triples to groups of sizes [855, 1241, 1862]


12411


array([['10base2', 'be also know as', 'thin ethernet'],
       ['13-cis-retinoic acid', 'be another name for', 'isotretinoin'],
       ['20th century fox', 'be a subsidiary of', 'news corporation'],
       ...,
       ['zyprexa', 'be manufacture by', 'eli lilly and company'],
       ['zyrtec', 'be a registered trademark of', 'pfizer'],
       ['zyrtec', 'be also call', 'warfarin']],
      shape=(12411, 3), dtype='<U49')

In [None]:
from pykeen.pipeline import pipeline

result = pipeline(
    model='ComplEx',
    loss="softplus",
    training=training,
    testing=testing,
    validation=validation,
    model_kwargs=dict(embedding_dim=3),  # Increase the embedding dimension
    optimizer_kwargs=dict(lr=0.1),  # Adjust the learning rate
    training_kwargs=dict(num_epochs=100, use_tqdm_batch=False),  # Increase the number of epochs
)

model = result.model

INFO:pykeen.pipeline.api:Using device: None
Training epochs on cpu: 100%|██████████| 100/100 [00:47<00:00,  2.10epoch/s, loss=0.767, prev_loss=0.777]
Evaluating on cpu: 100%|██████████| 1.24k/1.24k [00:00<00:00, 2.87ktriple/s]
INFO:pykeen.evaluation.evaluator:Evaluation took 0.47s seconds


In [None]:
from pykeen.evaluation import RankBasedEvaluator

# Create an evaluator
evaluator = RankBasedEvaluator()

# Evaluate the model
metrics = evaluator.evaluate(result.model, testing.mapped_triples, additional_filter_triples=[training.mapped_triples, validation.mapped_triples])

# Print the metrics
print(f"Hits@1: {metrics.get_metric('hits@1')}")
print(f"Hits@3: {metrics.get_metric('hits@3')}")
print(f"Hits@5: {metrics.get_metric('hits@5')}")
print(f"Hits@10: {metrics.get_metric('hits@10')}")
print(f"Mean Reciprocal Rank: {metrics.get_metric('mean_reciprocal_rank')}")

## TransE training with 'FB15k_237' DATA

In [195]:
from pykeen.pipeline import pipeline

result = pipeline(
    model='TransE',
    dataset='FB15k_237',
    training_kwargs=dict(num_epochs=100, use_tqdm_batch=False),
)

INFO:pykeen.datasets.utils:Caching preprocessed dataset to file:///C:/Users/melissa.merabet/.data/pykeen/datasets/fb15k237/cache/47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
INFO:pykeen.datasets.base:downloading data from https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip to C:\Users\melissa.merabet\.data\pykeen\datasets\fb15k237\FB15K-237.2.zip
INFO:pykeen.triples.triples_factory:Stored TriplesFactory(num_entities=14505, num_relations=237, create_inverse_triples=False, num_triples=272115, path=Release\train.txt) to file:///C:/Users/melissa.merabet/.data/pykeen/datasets/fb15k237/cache/47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM/training
INFO:pykeen.datasets.base:Stored training factory to file:///C:/Users/melissa.merabet/.data/pykeen/datasets/fb15k237/cache/47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM/training
INFO:pykeen.triples.triples_factory:Stored TriplesFactory(num_entities=14505, num_relations=237, create_inverse_triples=False, num_triples=20438, path=Release\test.t

In [None]:
transe_model = result.model
# Accéder aux représentations des entités et des relations
entity_representation_modules = model.entity_representations
relation_representation_modules = model.relation_representations

# Pour TransE, il n'y a qu'une seule représentation pour les entités et une pour les relations
entity_embeddings = entity_representation_modules[0]
relation_embeddings = relation_representation_modules[0]

In [None]:
result.metric_results.to_dict()

## TransR Training with 'FB15k_237' DATA

In [None]:
result = pipeline(
    model='TransR',
    dataset='FB15k_237',
    training_kwargs=dict(num_epochs=100, use_tqdm_batch=False),
)

In [None]:
result.metric_results.to_dict()

## ComplEx training with FB15K_237

In [None]:
result = pipeline(
    model='ComplEx',
    dataset='FB15k_237',
    training_kwargs=dict(num_epochs=100, use_tqdm_batch=False),
)

In [None]:
result.metric_results.to_dict()

# Multi-View Encoding 

In [79]:
dataset_path = "C:/projetsAlternance/projetMaster/reverb45k/reverb45k_valid.txt"
dataset = load_dataset(dataset_path)

In [88]:
from collections import defaultdict

def extract_data(data):
    clusters = defaultdict(set)
    for item in data:
        subject_mention = item["triple_norm"][0]
        relation = item["triple_norm"][1]
        object_mention = item["triple_norm"][2]

        subject_id = item["true_link"]["subject"]
        object_id = item["true_link"]["object"]

        context = " ".join(item["src_sentences"])

        # Stockage avec format: (mention, contexte) 
        clusters[subject_id].add((subject_mention, (subject_mention, relation, object_mention), context))
        clusters[object_id].add((object_mention, (subject_mention, relation, object_mention), context))
    return clusters

In [89]:
entity_clusters = extract_data(dataset)

In [90]:
# Exemple d'accès
for entity_id, mentions in entity_clusters.items():
    print(f"Entité {entity_id}:")
    for mention, fact_text, context in mentions:
        print(f"  - {mention} (fact_view: {fact_text})  (contexte: {context[:50]}...)")

Entité /m/09w_9:
  - frederick (fact_view: ('frederick', 'have reach', 'alessandria'))  (contexte: Frederick had reached Alessandria By late October,...)
  - barbarossa (fact_view: ('barbarossa', 'lose the battle of', 'legnano'))  (contexte: Barbarossa lost the battle of Legnano Barbarossa l...)
Entité /m/02bb_4:
  - alessandria (fact_view: ('frederick', 'have reach', 'alessandria'))  (contexte: Frederick had reached Alessandria By late October,...)
Entité /m/03q3p4:
  - legnano (fact_view: ('barbarossa', 'lose the battle of', 'legnano'))  (contexte: Barbarossa lost the battle of Legnano Barbarossa l...)
Entité /m/0f2y0:
  - darth vader (fact_view: ('darth vader', 'meet', 'luke skywalker'))  (contexte: Darth Vader met Luke Skywalker The Ultimate Ninja ...)
  - anakin skywalker (fact_view: ('darth vader', 'be once know as', 'anakin skywalker'))  (contexte: Darth Vader was once known as Anakin Skywalker Dar...)
  - darth vader (fact_view: ('anikin', 'eventually become', 'darth vader'))  

# Integrate LLM in canonicaling entities 

## Before Clustering 

### Entity Canonical Form

In [None]:
def canonicalize_entity(entity, sentences):
    prompt = """
    You are an expert in entity canonicalization. Your task is to analyze the provided entity mention along with its contextual appearances in different sentences. Based on these contexts,
    infer and return the most accurate and complete canonical name of the entity.

    Instructions:
    - Use the context of the sentences to disambiguate abbreviated or informal entity mentions.
    - Return only the canonical name of the entity, not a description or explanation.
    - If the entity is already in canonical form, return it as is.
    - Avoid hallucinations. If the context is insufficient to determine the canonical name, return "null".

    Here are 2 examples :

    Example 1:
    Entity : B. Obama
    Sentences:
    1. B. Obama was the 44th president of USA.
    2. The president elected in 2008 in the USA after George Bush is B. Obama. 
    Entity Canonicalisation: Barack Obama

    Example 2:
    Entity : NYC
    Sentences:
    1. I took a flight to NYC last summer and visited Times Square.
    2. The Statue of Liberty is one of the most iconic landmarks in NYC.
    Entity Canonicalization: New York City
    """
    prompt += f"Now, please process the following entity: : {entity}\n"
    prompt += f"Here are the sentences :\n"
    for i, sentence in enumerate(sentences, 1):
        prompt += f"{i}. {sentence}\n"

    prompt += """
    Entity Canonicalization:
    """

    print(f"Entity: {entity}")
    chat_response = client.chat.complete(
        model= mistral,
        messages = [
            {
                "role": "system", "content": prompt,
            },
        ]
    )
    return chat_response.choices[0].message.content

In [184]:
import time
INPUT_FILE = "reverb45k/reverb45k_test.txt"  
OUTPUT_FILE = "reverb45k_canonicalized_entities.json"

def process_data():
    with open(INPUT_FILE, "r") as infile, open(OUTPUT_FILE, "w") as outfile:
        for line in infile:
            data = json.loads(line)
            
            # Traitement pour subject/object
            for role in ["subject", "object"]:
                entity = data["triple_norm"][0 if role == "subject" else 2]
                context = data["src_sentences"]
                
                # Canonicalisation seulement si entity_linking est null
                if not data["entity_linking"].get(role):
                    canonical = canonicalize_entity(entity, context)
                    print(f"canonical : {canonical}")
                    print("-"*100)
                    time.sleep(5)
                    if canonical and canonical.lower() != "null":
                        data["entity_linking"][role] = canonical
            
            outfile.write(json.dumps(data) + "\n")

In [None]:
process_data()

### Multi View Embeddings

In [241]:
import json

all_inputs = []
all_labels = []

with open("reverb45k_canonicalized_entities.json", "r", encoding="utf-8") as f:
    for line in f:
        data = json.loads(line)
        subj_text = data['triple_norm'][0]
        obj_text = data['triple_norm'][2]
        sentences = " ".join(data.get("src_sentences", []))
        entity_linking = data.get("entity_linking") or {}
        subj_link = entity_linking.get("subject", "") or ""
        obj_link = entity_linking.get("object", "") or ""
        true_subj = data.get("true_link", {}).get("subject")
        true_obj = data.get("true_link", {}).get("object")
        data_id = data.get("_id")
        
        all_inputs.append((subj_text, sentences, subj_link))
        all_labels.append(true_subj)
        
        all_inputs.append((obj_text, sentences, obj_link))
        all_labels.append(true_obj)

In [None]:
from collections import Counter

# Compter les occurrences de chaque type de label
label_counter = Counter(all_labels)

# Filtrer les labels qui ont plus d'une occurrence
labels_to_keep = {label for label, count in label_counter.items() if count > 1}

# Filtrer les entrées de données
filtered_labels = []
filtered_inputs = []

for i, label in enumerate(all_labels):
    if label in labels_to_keep:
        filtered_labels.append(label)
        filtered_inputs.append(all_inputs[i])

In [247]:
texts = [x[0] for x in filtered_inputs]
sentences = [x[1] for x in filtered_inputs]
links = [x[2] for x in filtered_inputs]

In [248]:
print(len(texts))

5559


In [249]:
prompt_entity = "Represent this entity for entity canonicalization clustering : "
texts_to_emb = [prompt_entity + text for text in texts]
vectorized_texts = inst_emb_model.encode(texts_to_emb)

In [250]:
prompt_sentences = "Represent these sentences for entity canonicalization clustering : "
sentences_to_emb = [prompt_sentences + sentence for sentence in sentences]
vectorized_sentences = inst_emb_model.encode(sentences_to_emb)

In [251]:
prompt_entity_linking = "Represent this entity for entity canonicalization clustering : "
links_to_emb = [prompt_entity_linking + entity_linking for entity_linking in links]
vectorized_links = inst_emb_model.encode(links_to_emb)

In [None]:
np.savez_compressed("links_text_sentences_instructor.npz",
                    links=vectorized_links,
                    sentences=vectorized_sentences,
                    texts=vectorized_texts,
                    labels=np.array(filtered_labels))

In [252]:
print(type(vectorized_texts), vectorized_texts.shape)
print(type(vectorized_sentences), vectorized_sentences.shape)
print(type(vectorized_links), vectorized_links.shape)


<class 'numpy.ndarray'> (5559, 768)
<class 'numpy.ndarray'> (5559, 768)
<class 'numpy.ndarray'> (5559, 768)


### Concatenation

In [None]:
X_combined = np.hstack([vectorized_texts, vectorized_links])


In [254]:
print("Vecteurs shape :", X_combined.shape)
print("Exemple label :", all_labels[0])

Vecteurs shape : (5559, 2304)
Exemple label : /m/0cnn5


### CCA de sklearn

In [16]:
data = np.load("C:/projetsAlternance/projetMaster/links_text_sentences_instructor.npz")
vectorized_links = data["links"]
vectorized_sentences = data["sentences"]
vectorized_texts = data["texts"]
filtered_labels = data["labels"]

In [None]:
from sklearn.cross_decomposition import CCA
import numpy as np
from sklearn.preprocessing import StandardScaler


sentences = np.array(vectorized_sentences)
texts = np.array(vectorized_texts)
links = np.array(vectorized_links)


scaler = StandardScaler()
sentences = scaler.fit_transform(sentences)
texts = scaler.fit_transform(texts)
links = scaler.fit_transform(links)

cca = CCA(n_components=1) 

# Apprendre et transformer les vues
cca.fit(texts)
links_c, texts_c = cca.transform(texts)


combined_features = np.concatenate([texts_c, links_c], axis=1)


# Siamese Network

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Supposons que les données "sentences", "texts", "links" sont déjà vectorisées
sentences = np.array(vectorized_sentences)
texts = np.array(vectorized_texts)
links = np.array(vectorized_links)

# Définir l'architecture de base pour les trois vues
def base_model(input_shape):
    input_layer = layers.Input(shape=input_shape)
    x = layers.Dense(64, activation='relu')(input_layer)
    x = layers.Dense(32, activation='relu')(x)
    x = layers.Dense(16, activation='relu')(x)
    x = layers.Dense(8, activation='relu')(x)
    model = models.Model(inputs=input_layer, outputs=x)
    return model

input_shape_texts = (texts.shape[1],)  # Forme des données d'entrée pour texts
input_shape_links = (links.shape[1],)  # Forme des données d'entrée pour links

view2 = base_model(input_shape_texts)
view3 = base_model(input_shape_links)

# Fusionner les représentations des trois vues
def siamese_network( view2, view3):
    output2 = view2.output
    output3 = view3.output
    
    # Fusionner les sorties des trois vues (par exemple, concaténation)
    merged = layers.Concatenate()([output2, output3])
    
    # Ajouter une couche de sortie
    output = layers.Dense(1, activation='sigmoid')(merged)
    
    model = models.Model(inputs=[view2.input, view3.input], outputs=output)
    return model

# Créer le modèle Siamese
model = siamese_network(view2, view3)

# Compiler le modèle
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Afficher le résumé du modèle
model.summary()

# Entraîner le modèle (on peut ajuster les labels de clustering ici)
# En l'absence de labels spécifiques, utilisons des labels fictifs (par exemple, 0 ou 1 pour un exemple binaire)
labels = np.random.randint(0, 2, size=(sentences.shape[0], 1))

model.fit([texts, links], labels, epochs=10, batch_size=32)


texts_rep = view2.predict(texts)
links_rep = view3.predict(links)

# Fusionner les représentations des trois vues pour clustering
combined_rep = np.concatenate([texts_rep, links_rep], axis=1)

Epoch 1/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - accuracy: 0.4688 - loss: 0.6935
Epoch 2/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5053 - loss: 0.6932
Epoch 3/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.4983 - loss: 0.6934
Epoch 4/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.4970 - loss: 0.6934
Epoch 5/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5088 - loss: 0.6931
Epoch 6/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5005 - loss: 0.6931
Epoch 7/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5005 - loss: 0.6933
Epoch 8/10
[1m174/174[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5156 - loss: 0.6928
Epoch 9/10
[1m174/174[0m [32m━━━━━━━━

### KMeans

In [50]:
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=980, random_state=42)
pred_labels = kmeans.fit_predict(combined_rep)

In [None]:
nmi = compute_nmi(filtered_labels, pred_labels)
print(f"NMI : {nmi:.4f}")

In [45]:
from collections import defaultdict

def labels_to_clusters(labels):
    cluster_map = defaultdict(list)
    for idx, label in enumerate(labels):
        cluster_map[label].append(idx)
    return list(cluster_map.values())


In [None]:
gold_clusters = labels_to_clusters(filtered_labels)
pred_clusters = labels_to_clusters(pred_labels)

evaluate_canonicalization(gold_clusters, pred_clusters)

### Reformulate and process sentences 

In [None]:
def focus_on_entity(entity, sentence):
    prompt = f"""
You are an expert in entity-focused summarization. Your task is to rewrite a given sentence, emphasizing the specified entity. 

Instructions:
- Focus only on the parts of the sentence that are directly related to the entity.
- Remove or ignore any information that diverges from the entity or does not clarify its role or characteristics.
- Keep the result concise but informative.
- Do not hallucinate or invent information not present in the original sentence.
- If the sentence contains no useful information about the entity, return "null".

Examples:

Entity: The Guardian
Sentence: The Guardian at Kefka's Tower can be a difficult boss. It runs several "Battle Programs" which can deal massive damage.
→ Focused Rewrite: The Guardian is a difficult boss in Kefka's Tower, known for its damaging Battle Programs.

Entity: Franz Kafka
Sentence: Franz Kafka was born in Prague, Bohemia, July 3, 1883 and died June 3, 1924 of tuberculosis at the age of 40.
→ Focused Rewrite: Franz Kafka was born in Prague in 1883 and died of tuberculosis in 1924.

Entity: Doe
Sentence: Furthermore, Doe had a special affinity for Babangida, naming a branch of the University of Liberia which he established in honour of the Nigerian president.
→ Focused Rewrite: Doe honored Babangida by naming a branch of the University of Liberia after him.

Now process the following:

Entity: {entity}
Sentence: {sentence}

Focused Rewrite:
    """

    print(f"Entity: {entity}")
    chat_response = client.chat.complete(
        model= mistral,
        messages=[
            {
                "role": "system",
                "content": prompt
            }
        ]
    )
    return chat_response.choices[0].message.content.strip()


In [None]:
import json

# Charger le fichier JSON
input_file = 'C:/ProjetMaster/scripts/reformulated_sentences.json' 
output_file = 'C:/ProjetMaster/scripts/reformulated_treated_sentences.json'

with open(input_file, 'r', encoding="utf-8") as file:
    data = json.load(file)


def process_sentences(sentences, max_length=200):
    filtered_sentences = []
    for sentence in sentences:
        print(f"sentence: {sentence}")
        if len(sentence) <= max_length:
            filtered_sentences.append(sentence)
    unique_sentences = list(set(filtered_sentences))
    return unique_sentences

for item in data:
    item['sentences'] = process_sentences(item['reformulated_sentences'])
    print(f"sentences : {item['src_sentences']}")
    print(f"-"*100)

# Sauvegarder le nouveau fichier JSON
with open(output_file, 'w') as file:
    json.dump(data, file, indent=4)

print(f"Le fichier a été traité et sauvegardé sous {output_file}")

In [None]:
# Charger le fichier JSON
with open('C:/ProjetMaster/scripts/reformulated_sentences.json', 'r', encoding="utf-8") as f:
    data = json.load(f)

all_inputs = []
all_labels = []

for item in data:
    subj_text = item['triple_norm'][0]
    obj_text = item['triple_norm'][2]
    sentences = " ".join(item.get("sentences", []))
    entity_linking = item.get("entity_linking") or {}
    subj_link = entity_linking.get("subject", "") or ""
    obj_link = entity_linking.get("object", "") or ""
    true_subj = item.get("true_link", {}).get("subject")
    true_obj = item.get("true_link", {}).get("object")
    
    all_inputs.append((subj_text, sentences, subj_link))
    all_labels.append(true_subj)
    
    all_inputs.append((obj_text, sentences, obj_link))
    all_labels.append(true_obj)

In [None]:
from collections import Counter
import numpy as np

# Compter les occurrences de chaque type de label
label_counter = Counter(all_labels)

# Filtrer les labels qui ont plus d'une occurrence
labels_to_keep = {label for label, count in label_counter.items() if count > 1}

# Filtrer les entrées de données
filtered_labels = []
filtered_inputs = []

for i, label in enumerate(all_labels):
    if label in labels_to_keep:
        filtered_labels.append(label)
        filtered_inputs.append(all_inputs[i])

In [None]:
texts = [x[0] for x in filtered_inputs]
sentences = [x[1] for x in filtered_inputs]
links = [x[2] for x in filtered_inputs]

In [None]:
# Fusionner texte + lien en une seule string
inputs_for_embedding = [f"{text} {link}" for text, link in zip(texts, links)]

In [None]:
vectorized_sentences = minilm_emb_model.encode(sentences)
vectorized_texts = minilm_emb_model.encode(texts)
vectorized_links = minilm_emb_model.encode(links)
X_combined = np.hstack([vectorized_texts, vectorized_sentences])

In [None]:
kmeans = KMeans(n_clusters=546, random_state=42)
pred_labels = kmeans.fit_predict(X_combined)

In [None]:
from collections import defaultdict

def labels_to_clusters(labels):
    cluster_map = defaultdict(list)
    for idx, label in enumerate(labels):
        cluster_map[label].append(idx)
    return list(cluster_map.values())

In [None]:
gold_clusters = labels_to_clusters(filtered_labels)
pred_clusters = labels_to_clusters(pred_labels)

evaluate_canonicalization(gold_clusters, pred_clusters)

## During Clustering 

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def mmr(doc_embeddings, centroid, k=3, lambda_param=0.5):
    doc_embeddings = np.array(doc_embeddings)
    centroid = np.array(centroid).reshape(1, -1)

    similarities = cosine_similarity(doc_embeddings, centroid).flatten()
    selected = []
    candidates = list(range(len(doc_embeddings)))

    k = min(k, len(doc_embeddings)) 
    
    for _ in range(k):
        if not selected:
            selected_idx = np.argmax(similarities)
        else:
            selected_similarities = cosine_similarity(doc_embeddings[candidates], doc_embeddings[selected])
            max_sim = np.max(selected_similarities, axis=1)
            mmr_score = lambda_param * similarities[candidates] - (1 - lambda_param) * max_sim
            selected_idx = candidates[int(np.argmax(mmr_score))]
        selected.append(selected_idx)
        candidates.remove(selected_idx)

    return selected

In [None]:
def generate_summary_with_llm(documents):

    prompt = """You are an expert in entity canonicalization through clustering. Given multiple mentions of potentially similar entities across different contexts, 
    your goal is to summarize each cluster by identifying the canonical entity name and capturing the central theme or identity behind these mentions.
    Do not give any explanation, just return the answer without any other word.

    Below are examples of how to effectively summarize such clusters:
    Example 1
    Mention 1: "franz kafka"
    Mention 2: "kafka, franz"
    Mention 3: "Franz Kafka, the Prague-born writer"
    Summary: Franz Kafka

    Example 2
    Mention 1: "nyc"
    Mention 2: "new york city"
    Mention 3: "the Big Apple"
    Summary: New York City

    Now, based on the following mentions grouped in a cluster, please identify the canonical entity name :"""
        
    prompt += "\n".join(documents)

    # Tokenizer le prompt et le passer au modèle
    chat_response = client.chat.complete(
        model= mistral,
        messages = [
            {
                "role": "system", "content": prompt,
            },
        ],
        temperature=0.5
    )
    return chat_response.choices[0].message.content

In [None]:
def assign_clusters(embeddings, centroids):
    labels = []
    # Calcul de la similarité cosinus entre chaque embedding et les centroids
    for emb in embeddings:
        similarities = cosine_similarity([emb], centroids)  # cosine_similarity attend des entrées sous forme de matrices
        labels.append(np.argmax(similarities))  # On sélectionne le centroïde avec la plus grande similarité
    return labels

In [None]:
import time
def k_llmmeans(documents, n_clusters, update_iter, total_iter):
    embeddings = minilm_emb_model.encode(documents)
    
    kmeans = KMeans(n_clusters=n_clusters, init='k-means++')
    kmeans.fit(embeddings)
    centroids = kmeans.cluster_centers_
    labels = kmeans.labels_

    for iter in range(total_iter):
        print(f"iter = {iter}")
        if iter % update_iter == 0:
            print(">> Updating centroids via LLM summarization")
            # Mettre à jour les centroïdes avec LLM
            for j in range(n_clusters):
                # Extraire les indices des documents dans le cluster j
                cluster_indices = [i for i in range(len(documents)) if kmeans.labels_[i] == j]
                if not cluster_indices:
                    continue  

                cluster_embeds = [embeddings[i] for i in cluster_indices]
                cluster_texts = [documents[i] for i in cluster_indices]
                print(f"cluster_texts : {cluster_texts}")
                top_k = min(3, len(cluster_texts))  
                selected_idxs = mmr(cluster_embeds, centroids[j], k=top_k, lambda_param=0.5)
                print(f"selected_idxs : {selected_idxs}")
                cluster_docs = [cluster_texts[i] for i in selected_idxs]
                print(cluster_docs)
                time.sleep(5)
                # Générer un résumé avec le LLM
                summary = generate_summary_with_llm(cluster_docs)
                print(f"summary : {summary}")
                print("-"*200)
                # Re-calculer l'embedding du résumé (nouveau centroïde)
                summary_embedding = minilm_emb_model.encode([summary])[0]
                centroids[j] = summary_embedding
                
        else:
            print(">> Updating centroids via mean of cluster embeddings")
            # Mise à jour des centroïdes avec l'algorithme classique (moyenne des embeddings)
            for j in range(n_clusters):
                cluster_embeddings = [embeddings[i] for i in range(len(documents)) if kmeans.labels_[i] == j]
                if cluster_embeddings:
                    centroids[j] = np.mean(cluster_embeddings, axis=0)

        # Réassignation manuelle des documents aux clusters
        labels = assign_clusters(embeddings, centroids)
        
    return labels, centroids

In [None]:
n_clusters = 546
update_iter = 5
total_iter = 20
labels, centroids = k_llmmeans(inputs_for_embedding, n_clusters, update_iter, total_iter)

In [None]:
from collections import defaultdict

def labels_to_clusters(labels):
    cluster_map = defaultdict(list)
    for idx, label in enumerate(labels):
        cluster_map[label].append(idx)
    return list(cluster_map.values())

In [None]:
gold_clusters = labels_to_clusters(filtered_labels)
pred_clusters = labels_to_clusters(labels)

evaluate_canonicalization(gold_clusters, pred_clusters)

In [None]:
import json
import numpy as np

clustering_results = []
for i in range(len(inputs_for_embedding)):
    clustering_results.append({
        "document": inputs_for_embedding[i],
        "label": labels[i],
        "true_label": filtered_labels[i],
        "centroid": centroids[labels[i]]  # Le centroid correspondant au label
    })

# Sauvegarder dans un fichier JSON
with open('C:/ProjetMaster/scripts/kllmmeans_data_result.json', 'w') as f:
    json.dump(clustering_results, f, indent=4)

print("Clustering results saved")

## Outliers Assignment

In [None]:
import re

def get_outliers(vectors, centroids, labels, threshold=0.5):
    label_to_index = {label: idx for idx, label in enumerate(set(labels))}
    label_indices = [label_to_index[label] for label in labels]
    distances = np.linalg.norm(vectors - np.array([centroids[label_indices[i]] for i in range(len(labels))]), axis=1)
    outlier_indices = np.where(distances > threshold)[0]
    outlier_clusters = {idx: labels[idx] for idx in outlier_indices}
    return outlier_clusters
    

def get_top_k_nearest_clusters(doc_vector, centroids, k=5):
    dists = np.linalg.norm(centroids - doc_vector, axis=1)
    return np.argsort(dists)[:k]


def get_top_n_representative_docs(cluster_id, vectors, labels, centroids, n=3):
    cluster_indices = np.where(labels == cluster_id)[0]
    cluster_vectors = vectors[cluster_indices]
    dists = np.linalg.norm(cluster_vectors - centroids[cluster_id], axis=1)
    top_n_indices = cluster_indices[np.argsort(dists)[:n]]
    return top_n_indices


def classify_with_llm(outlier_doc, candidate_clusters, representatives, documents, llm_decision_function):
    prompt = """You are an expert in entity canonicalization and reassigning outliers.
    You will be given:
    - An entity to reassign
    - 5 clusters, each represented by 3 typical entities

    Your task is to:
    - Analyze the given entity
    - Compare it with the examples from each cluster
    - Determine the most appropriate cluster the entity should belong to

    Output format:
    - Return only the selected cluster number
    - Do not return any explanation or extra words

    Here are 2 examples:

    Example 1 : 
    Entity to reassign: Azure
    Cluster 1: 
        - Google Cloud
        - Google Cloud Platform
        - GCP
    Cluster 2: 
        - AWS
        - Amazon Web Services
        - Amazon Cloud
    Cluster 3: 
        - Microsoft Azure
        - Azure
        - MS Azure
    Cluster 4: 
        - IBM Cloud
        - IBM Public Cloud
        - IBM Cloud Services
    Cluster 5: 
        - Oracle Cloud
        - OCI
        - Oracle Cloud Infrastructure
    Answer: Cluster 3

    Example 2 : 
    Entity to reassign: The Coca-Cola Company
    Cluster 1: 
        - Coca-Cola
        - Coke
        - Coca Cola
    Cluster 2: 
        - Pepsi
        - PepsiCo
        - Pepsi-Cola
    Cluster 3: 
        - Sprite
        - Sprite Soda
        - Sprite Drink
    Cluster 4: 
        - Fanta
        - Fanta Orange
        - Fanta Soda
    Cluster 5: 
        - Dr Pepper
        - Dr. Pepper
        - DrPepper
    Answer: Cluster 1

    Now perform the same task on the next input.
    """

    prompt += f"Entity to reassign:\n{outlier_doc}...\n\n" 
    for i, cluster_id in enumerate(candidate_clusters):
        prompt += f"\nCluster {cluster_id}:\n"
        for doc_id in representatives[cluster_id]:
            prompt += f"- {documents[doc_id]}\n"
    print(f"prompt: {prompt}")
    return llm_decision_function(prompt)


def reassign_outliers(documents, vectors, centroids, labels, llm_decision_function, threshold=0.5):
    outliers = get_outliers(vectors, centroids, labels, threshold)
    new_assignments = {}
    for idx in outliers:
        current_label = labels[idx]
        print(f"idx: {idx}")
        print(f"Current label: {labels[idx]}")
        doc_vector = vectors[idx]
        candidate_clusters = get_top_k_nearest_clusters(doc_vector, centroids)  
        print(f"candidate_clusters : {candidate_clusters}")
        representatives = {
            cluster_id: get_top_n_representative_docs(cluster_id, vectors, labels, centroids)
            for cluster_id in candidate_clusters
        }
        time.sleep(5)
        response = classify_with_llm(
            documents[idx],
            candidate_clusters,
            representatives,
            documents,
            llm_decision_function
        )
        print(f"response = {response}")
        match = re.search(r'\d+', response)
        if match:
            new_label= int(match.group())
        else:
            print(f"Aucun chiffre trouvé dans la réponse : {response}")
            new_label = labels[idx]
        print(f"new_label = {new_label}")
        if new_label != current_label:
            new_assignments[idx] = new_label
        else: 
            new_assignments[idx] = current_label
        print(f"-"*100)
    return new_assignments


def ask_llm(prompt):
    chat_response = client.chat.complete(
        model= mistral,
        messages = [
            {
                "role": "system", "content": prompt,
            },
        ],
        temperature=0.3
    )
    return chat_response.choices[0].message.content

In [None]:
embeddings = minilm_emb_model.encode(inputs_for_embedding)

In [None]:
new_labels = reassign_outliers(inputs_for_embedding, embeddings, centroids, labels, ask_llm)

In [None]:
for idx, new_label in new_labels.items():
    labels[idx] = new_label

In [None]:
gold_clusters = labels_to_clusters(filtered_labels)
pred_clusters = labels_to_clusters(labels)

evaluate_canonicalization(gold_clusters, pred_clusters)

# Centroids generated by LLM

In [None]:
with open("C:/projetsAlternance/projetMaster/reformulated_sentences.json", "r", encoding="utf_8") as f:
    all_data = json.load(f)

In [None]:
import pandas as pd
formatted_data = []

for data in all_data:
    subject_raw = data["triple_norm"][0]
    object_raw = data["triple_norm"][2]
    
    subject_el = data["entity_linking"].get("subject", {})
    object_el = data["entity_linking"].get("object", {})
    
    subject_true = data["true_link"].get("subject", {})
    object_true = data["true_link"].get("object", {})
    
    # Ajouter les sujets et objets séparés dans le format demandé
    formatted_data.append({
        "entity_input": f"{subject_raw} ({subject_el})", 
        "true_linking": subject_true,
    })
    formatted_data.append({
        "entity_input": f"{object_raw} ({object_el})", 
        "true_linking": object_true,
    })

# Création du DataFrame
df = pd.DataFrame(formatted_data)

# Affichage du DataFrame
print(df)

In [None]:
filtered_df = df.groupby('true_linking').filter(lambda x: len(x) > 1)

# Afficher le résultat
print(len(filtered_df))

In [None]:
print(filtered_df["true_linking"].nunique())

In [None]:
def generate_centroid(entities):
    prompt = """You are an expert generating a representative canonicalized entity for a cluster of entity canonicalization.
    You will be given:
    - 4 entities which belong to one entity cluster.
    
    Your task is to analyze these entities and provide an entity will that represent all 4 entities given. 
    Do not give any explanation.
    Return only the representative canonicalized entity without any other word.
    
    Here is 3 examples : 
    Example 1 : 
        - entity 1 : Barack Obama
        - entity 2 : President Obama
        - entity 3 : Barack H. Obama
        - entity 4 : The 44th president of USA
    the representative canonicalized entity is : Barack Obama

    Example 2 : 
        - entity 1 : International Business Machines
        - entity 2 : IBM
        - entity 3 : I.B.M.
        - entity 4 : IBM Corp
    the representative canonicalized entity is : International Business Machines

    Example 3 :
        - entity 1 : Google Inc.
        - entity 2 : Google
        - entity 3 : Google LLC
        - entity 4 : google.com
    the representative canonicalized entity is : Google

    Do the same with : 
    """
    for entity in entities:
        prompt += f"entity : {entity}\n"
    chat_response = client.chat.complete(
        model= mistral,
        messages = [
            {
                "role": "system", "content": prompt,
            },
        ],
        temperature=0.5
    )
    return chat_response.choices[0].message.content

In [None]:
import time

def generate_all_centroids(df):
    centroids = {}
    for label, group in df.groupby('true_linking'):
        print(f"label = {label}")
        print(f"group = {group}")
        sample_size = min(4, len(group))
        entities_for_centroid = group['entity_input'].sample(sample_size, random_state=42).tolist()
        time.sleep(5)
        centroid = generate_centroid(entities_for_centroid)
        if centroid:
            centroids[label] = centroid
            print(f"centroid : {centroid}")
        else:
            print(f"Failed to generate centroid for intent: {label}")
        print(f"-"*300)
    return centroids

In [None]:
centroids = generate_all_centroids(filtered_df)

In [None]:
centroid_sentences = list(centroids.values())

In [None]:
filtered_df = filtered_df.reset_index(drop=True)

In [None]:
centroids_emb = minilm_emb_model.encode(centroid_sentences)
entities_emb = minilm_emb_model.encode(filtered_df["entity_input"])

In [None]:
# 6. KMeans avec init LLM
kmeans = KMeans(n_clusters=546, init=centroids_emb)
labels = kmeans.fit_predict(entities_emb)

In [None]:
gold_clusters = labels_to_clusters(filtered_df["true_linking"])
pred_clusters = labels_to_clusters(labels)

evaluate_canonicalization(gold_clusters, pred_clusters)