In [1]:
import random

import torch

import pandas as pd

from ast import literal_eval

from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.example_selectors.base import BaseExampleSelector

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import pipeline

from src.dataset import Dataset

from src.link_prediction.evaluation import Evaluator

from src.utils import load_model

from src.utils import format_paths
from src.utils import read_json

from src.utils import set_seeds

from src import DATA_PATH, LP_CONFIGS_PATH

from src import DB50K, DB100K, YAGO4_20

In [2]:
set_seeds(42)

In [3]:
def init_pipeline():
    model_id = "Meta-Llama-3.1-8B-Instruct"
    quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
    # model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    pipe = pipeline(
        "text-generation", 
        model=model, 
        tokenizer=tokenizer,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        batch_size=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        max_length=8192,
        truncation=True
    )
    pipe.tokenizer.pad_token_id = model.config.eos_token_id[0]
    pipe = HuggingFacePipeline(pipeline=pipe)

    return pipe

In [None]:
pipe = init_pipeline()

In [82]:
prefix = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful, respectful and honest assistant.
Your response should be crisp, short and not repetitive.
Discard any preamble, explanation, greeting, or final consideration.
An RDF triple is a statement (subject, predicate, object).
The subject and the object are entities, and the predicate is a relation between the subject and the object.
Perform a Link Prediction (LP) task, specifically, given an incomplete RDF triple (subject, predicate, ?), predict the missing object that completes the triple and makes it a true statement.
Strict requirement: output solely the name of a single object entity, discard any explanations or other text. 
Correct format: Elizabeth_of_Bohemia
Incorrect format: The object entity is Elizabeth_of_Bohemia.
{ranking}
"""

suffix = """
<|eot_id|><|start_header_id|>user<|end_header_id|>
({subject}, {predicate}, ?)
{explanation}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

explanation_hook = """
In addition to the incomplete triple, an explanation is provided.
An explanation is a set of RDF triples that provide context for the prediction.
Given an incomplete triple (s, p, ?), all the RDF triples in the explanation feature s as subject or as object and connect it with another entity.
Explanation:
"""

qhook = """
This explanation is also provided in the form of quotient triples to provide a more abstract view of the explanation.
A quotient triple is an RDF triple featuring the subject entity of the incomplete triple as subject or as object and connects it with a concept name rather than an entity.
Considering an example explanation with the following triples:
(s, p, e)
(s, p, f)
(s, p, g)

The corresponding quotient triples are:
(s, p, C1)
(s, p, C2)

where C1, C2 are concept names, e and f are instances of C1, and g is an instance of C2.

Your answer still must be a single entity name even if the quotient triples contain concept names.

Finally, the quotient triples are:
"""

ranking_hook = """
The entity name that you provide must be in the following list of entity names:
"""

example_prompt_str_template = """
<|eot_id|><|start_header_id|>user<|end_header_id|>
({subject}, {predicate}, ?)
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{object}
"""

In [83]:
class CustomExampleSelector(BaseExampleSelector):
    def __init__(self, dataset, df_output): 
        self.dataset = dataset

        self.df_output = df_output

    def add_example(example):
        return

    def select_examples(self, input_variables):
        predicate = input_variables["predicate"]

        examples = self.df_output[self.df_output["p"] == predicate].copy()
        examples = examples[examples["o_rank"] == 1]
        examples = examples.drop(columns=["s_rank", "o_rank"])
        examples = examples.rename(columns={"s": "subject", "p": "predicate", "o": "object"})

        examples = examples.to_dict("records")

        return random.sample(examples, 10) if len(examples) > 10 else examples

In [84]:
def build_explanation(explained_pred, summarization, e_sem):
    if summarization == "no":
        explanation = explained_pred["explanation"]
        explanation = [f"({s}, {p}, {o})\n" for (s, p, o) in explanation]
        explanation = "\n".join(explanation)
        explanation = f"{explanation_hook}\n{explanation}"

        return explanation
    elif summarization == "simulation" or summarization == "bisimulation":
        triples = explained_pred["explanation"]
        triples_string = [f"({s}, {p}, {o})" for (s, p, o) in triples]
        triples_string = "\n".join(triples_string)

        pred_s = explained_pred["pred"][0]

        qtriples = explained_pred["quotient_explanation"]
        qtriples_string = []
        for qtriple in qtriples:
            s, p, o = qtriple
            if len(s) == 1 and s[0] == pred_s:
                s_ = s[0]
            else:
                if len(e_sem[s[0]]) == 0:
                    e_sem[s[0]] = "Thing"
                s_ = f"[{e_sem[s[0]]}]"
            if len(o) == 1 and o[0] == pred_s:
                o_ = o[0]
            else:
                if len(e_sem[o[0]]) == 0:
                    e_sem[o[0]] = "Thing"
                o_ = f"[{e_sem[o[0]]}]"

            if s_ == o_:
                s_ = pred_s
            qtriples_string.append(f"({s_}, {p}, {o_})")
        qtriples_string = "\n".join(qtriples_string)

        types = [f"{s} is an instance of {e_sem[s]}" for s, _, _ in triples if s != pred_s]
        types += [f"{o} is an instance of {e_sem[o]}" for _, _, o in triples if o != pred_s]

        types = "\n".join(types)

        explanation = f"{explanation_hook}\n{triples_string}\n{qhook}\n{qtriples_string}\n{types}"

        return explanation

def build_ranking(ranking):
    ranking = [o for o in ranking]
    ranking = "\n".join(ranking)
    ranking = f"{ranking_hook}{ranking}"

    return ranking

def parse(response):
    substring = "<|start_header_id|>assistant<|end_header_id|>\n"
    pos = response.rfind(substring) + len(substring)
    response = response[pos:]
    response = response.replace("\n", "")
    response = response.replace(" ", "_")

    return response

In [106]:
method = "kelpie++"
mode = "necessary"
model = "TransE"
dataset = "DB50K"
summarization = "simulation"
include_ranking = True
examples = False

paths = format_paths(method, mode, model, dataset, summarization)

lp_config_path = LP_CONFIGS_PATH / f"{model}_{dataset}.json"
lp_config = read_json(lp_config_path)

In [107]:
print(f"Loading dataset {dataset}...")
dataset = Dataset(dataset)
print(f"Loading model {model}...")
model = load_model(lp_config, dataset)
model.eval()
evaluator = Evaluator(model=model)
ranks = evaluator.evaluate(triples=dataset.training_triples)
df_output = evaluator.get_df_output(triples=dataset.training_triples, ranks=ranks)
example_selector = CustomExampleSelector(dataset, df_output)
example_prompt_template = PromptTemplate(
    template=example_prompt_str_template, input_variables=["subject", "predicate", "object"]
)

Loading dataset DB50K...


You're trying to map triples with 11247 entities and 22 relations that are not in the training set. These triples will be excluded from the mapping.
In total 8871 from 10969 triples were filtered out
You're trying to map triples with 332 entities and 1 relations that are not in the training set. These triples will be excluded from the mapping.
In total 276 from 399 triples were filtered out


Loading model TransE...


                                                                                                                                                        

In [108]:
few_shot_template = FewShotPromptTemplate(
    example_selector=example_selector if examples else None,
    examples=None if examples else [],
    example_prompt=example_prompt_template,
    input_variables=["subject", "predicate", "explanation", "ranking"],
    prefix=prefix,
    suffix=suffix,
    example_separator=""
)

In [109]:
chain = few_shot_template | pipe | parse

In [110]:
rankings = read_json(paths["rankings"])
rankings = {tuple(ranking["pred"]): ranking["ranking"][:100] for ranking in rankings}
explained_preds = read_json(paths["exps"])

explained_preds = [
    {
        "pred": explained_pred["pred"],
        "explanation": explained_pred["explanation"],
        "quotient_explanation": explained_pred.get("quotient_explanation", None),
        "label": explained_pred.get("label", None),
        "ranking": rankings[tuple(explained_pred["pred"])]
    }
    for explained_pred in explained_preds
]

# explained_preds = [ep for ep in explained_preds if ep["label"] == 1]

# random.shuffle(explained_preds)
explained_preds = explained_preds[:128]

In [111]:
queries = [
    {
        "subject": explained_pred["pred"][0],
        "predicate": explained_pred["pred"][1],
        "explanation": "",
        "ranking": build_ranking(explained_pred["ranking"]) if include_ranking else "",
    }
    for explained_pred in explained_preds
]

In [112]:
gts = [ep["pred"][2] for ep in explained_preds]

In [113]:
simulations = chain.batch(queries)

In [114]:
e_sem = None

if dataset.name in [DB50K, DB100K, YAGO4_20]:
    e_sem = pd.read_csv(
        DATA_PATH / dataset.name / "reasoned" / "entities.csv",
        converters={"classes": literal_eval},
    )
    e_sem["classes"] = e_sem["classes"].map(sorted)
    e_sem["classes"] = e_sem["classes"].map(lambda x: [c.split("/")[-1] for c in x])
    e_sem["classes"] = e_sem["classes"].map(", ".join)
    e_sem = e_sem.to_dict("records")
    e_sem = {e["entity"]: e["classes"] for e in e_sem}

explained_queries = [
    {
        "subject": explained_pred["pred"][0],
        "predicate": explained_pred["pred"][1],
        "explanation": build_explanation(explained_pred, summarization, e_sem),
        "ranking": build_ranking(explained_pred["ranking"]) if include_ranking else "",
    }
    for explained_pred in explained_preds
]

In [115]:
print(few_shot_template.invoke(explained_queries[0]).text)


<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful, respectful and honest assistant.
Your response should be crisp, short and not repetitive.
Discard any preamble, explanation, greeting, or final consideration.
An RDF triple is a statement (subject, predicate, object).
The subject and the object are entities, and the predicate is a relation between the subject and the object.
Perform a Link Prediction (LP) task, specifically, given an incomplete RDF triple (subject, predicate, ?), predict the missing object that completes the triple and makes it a true statement.
Strict requirement: output solely the name of a single object entity, discard any explanations or other text. 
Correct format: Elizabeth_of_Bohemia
Incorrect format: The object entity is Elizabeth_of_Bohemia.

The entity name that you provide must be in the following list of entity names:
Swamp_Ophelia
Rites_of_Passage_(Indigo_Girls_album)
Hanna-Barbera's_All-Star_Comedy_Ice_Revue
CBBC
Pepe_Smith
Br

In [116]:
post_exp_simulations = chain.batch(explained_queries)

In [117]:
post_exp_simulations

['Swamp_Ophelia',
 'Cutfather',
 'Plant',
 'Plant',
 'Plant',
 'Mushroom_Records',
 'Moscow,_Idaho',
 'Fleetwood_Mac',
 'Vatican_City',
 'Electronic',
 'Orchidaceae',
 'Universal_Studios',
 'Tucson,_Arizona',
 'Karen_Clark_Sheard',
 'Airbus_Military',
 'New_Iberia,_Louisiana',
 'Bun_B',
 'South_Korea',
 'Cicadinae',
 'Michael_Szloszer',
 'Megaloptera',
 'Plant',
 'Plant',
 'Mind_on_the_Moon',
 'Sharon_Lowe',
 'Q19088',
 'Bee',
 'Robyn_Is_Here',
 'Thriller_(genre)',
 'Uganda',
 'Miki_Garrod',
 'Single',
 'Zavalaz',
 'Vancouver',
 'Allopterigeron',
 'Persim_Maros',
 'Mindy_McCready',
 'Australia',
 'London',
 'Q729',
 'K-pop',
 'Central_Coast_AVA',
 'Phnom_Penh_Crown_FC',
 'Augustus_George_Vernon_Harcourt',
 'Liu_Shaoqi',
 'Aston_Martin_V8',
 'Moth',
 'Plantae',
 'Iran',
 'Abiy_Ahmed',
 'Germany',
 'Plant',
 'Warrant_(American_band)',
 'Kenneth_Bulmer',
 'Q134556',
 'Windsor,_Connecticut',
 'Rufus_(band)',
 'Steven_Brill_(scriptwriter)',
 'Târnava_Mică_River',
 'Stalino',
 'Arctiinae',
 

In [118]:
predictability_pre = [1 if o == gt else 0 for o, gt in zip(simulations, gts)]
predictability_post = [1 if o == gt else 0 for o, gt in zip(post_exp_simulations, gts)]

In [119]:
explanation_labels = [post - pre for post, pre in zip(predictability_post, predictability_pre)]
gt_explanation_labels = [x["label"] for x in explained_preds]

In [120]:
from sklearn.metrics import classification_report, confusion_matrix, cohen_kappa_score

print(confusion_matrix(gt_explanation_labels, explanation_labels))
print(classification_report(gt_explanation_labels, explanation_labels, zero_division=0))
print(cohen_kappa_score(gt_explanation_labels, explanation_labels))

ValueError: Classification metrics can't handle a mix of unknown and multiclass targets

In [105]:
avg_pre = sum(predictability_pre) / len(predictability_pre)
avg_post = sum(predictability_post) / len(predictability_post)

print(f"{avg_post} - {avg_pre} = {avg_post - avg_pre}")

0.28 - 0.04 = 0.24000000000000002


In [61]:
# from scipy.stats import spearmanr

# correlation, p_value = spearmanr(gt_explanation_labels, explanation_labels)

# print(f"Spearman Rank Correlation: {correlation}")
# print(f"P-value: {p_value}")

In [62]:
# idxs = [i for i, ep in enumerate(explained_preds) if ep["label"] == 1]

# print(idxs)

# for i in idxs:
#     print(f"{predictability_post[i]} - {predictability_pre[i]}")


In [63]:
# output = []
# for i, explained_pred in enumerate(explained_preds):
#     output_ = {
#         "s": explained_pred["pred"][0],
#         "p": explained_pred["pred"][1],
#         "o": explained_pred["pred"][2],
#         "simulation": simulations[i],
#         "post_exp_simulation": post_exp_simulations[i],
#         "explanation": explained_pred["explanation"],
#     } 

#     output.append(output_)

# df = pd.DataFrame(output, index=None)
# df.to_csv("test.csv")


In [64]:
#   method mode      model  dataset summarization
# 1 kelpie necessary TransE DB50K   no
#
# 1 plain                             
# 2 plain + ranking                   
# 3 plain + esempi                    
# 4 plain + ranking + esempi          
#
# 2 kelpie necessary TransE DB100K  no
#
# 1 plain                             
# 2 plain + ranking                   
# 3 plain + esempi                    
# 4 plain + ranking + esempi          
#
# 3 kelpie necessary TransE DB100K  simulation
#
# 1 plain                             
# 2 plain + ranking                   
# 3 plain + esempi           
# 4 plain + ranking + esempi   