In [None]:
from pathlib import Path
import json
import pandas as pd
from ollama import Client
from tqdm import tqdm

In [None]:
# Read the parquet file at datasets/muc/muc.parquet into a DataFrame and preview its first rows.
df = pd.read_parquet("datasets/muc/muc.parquet")
df.head()

In [None]:
# Distinct values that occur in the "incident_type" column.
unique_incident_types = df["incident_type"].unique()
unique_incident_types

In [None]:
# Slot name -> short description of what the slot should contain.
# NOTE(review): no other cell visible here reads this dict; user_prompt spells out the
# same descriptions by hand, except that its Incident line omits the 'None' option.
labels = {
    "Incident": "One of 'Arson', 'Attack', 'Bombing', 'Kidnapping', 'Robbery' or 'None'",
    "Perpetrator": "An individual perpetrator",
    "Group Perpetrator": "A group or organizational perpetrator",
    "Victim": "Sentient victims of the incident",
    "Target": "Physical objects targeted by the incident",
    "Weapon": "Weapons employed by the perpetrators",
}

In [None]:
# Lower-case slot keys. They match the keys of the gold and prediction dicts and are
# the slots that transform_to_squad() and report() iterate over.
slots = ["incident", "perpetrator", "group perpetrator", "victim", "target", "weapon"]

In [None]:
# Client for an Ollama server expected at http://localhost:19290.
# NOTE(review): client.list() presumably returns the models available on that server,
# which doubles as a quick connectivity check — confirm against the ollama client docs.
client = Client(host='http://localhost:19290')
client.list()

In [None]:
# System message sent with every chat request in the evaluation loop (stripped before sending).
system_prompt = "You are a system to support the analysis of large amounts of text. You will assist the user by extracting the required information from the provided documents. You will always answer in the required format and use no other formatting than expected by the user!"

In [None]:
# User message template. The single `{}` placeholder is filled with the document text
# via str.format() in the evaluation loop, so any literal braces added to this template
# later must be doubled ("{{" / "}}").
# NOTE(review): the Incident line below lists no 'None' option although `labels` does,
# and the instructions do not say what to put in "Incident:" when the text has no incident.
user_prompt = """
I want you to extract the following information about incidents from the text below. The slots are:

Incident: One of 'Arson', 'Attack', 'Bombing', 'Kidnapping', 'Robbery'
Perpetrator: An individual perpetrator
Group Perpetrator: A group or organizational perpetrator
Victim: Sentient victims of the incident
Target: Physical objects targeted by the incident
Weapon: Weapons employed by the perpetrators

Please extract the information about the incidents (if any) from the following text:
{}

Respond in the following format:
Incident: <incident type>
Perpetrator: <perpetrator>
Group Perpetrator: <group perpetrator>
Victim: <victim>
Target: <target>
Weapon: <weapon>

e.g.
Incident: Arson
Perpetrator: John Doe
Group Perpetrator: None
Victim: None
Target: Building
Weapon: Matches

If there is no information about a certain slot in the provided text, leave it empty with "None".
Also, if there is no incident in the text, you have to leave the rest of the slots empty.

Remember, you MUST extract the information verbatim from the text, do not generate it!
"""

In [None]:
# Preview the first rows again (same call as in the loading cell) before building gold answers.
df.head()

In [None]:
from typing import Dict, List


def parse_response(response: str) -> Dict[str, List[str]]:
    """Parse a "Slot: answer" style model reply into a dict of slot -> answers.

    Every line is split at its first colon. The part before the colon is the
    slot name (matched case-insensitively, ignoring surrounding ``*``, ``<``
    and ``>`` characters); the part after it is the answer. Lines without a
    colon or with an unknown slot name are ignored.

    Args:
        response: Raw text returned by the model.

    Returns:
        A dict with the keys "incident", "perpetrator", "group perpetrator",
        "victim", "target" and "weapon". Each value is a list that holds the
        lower-cased answer for that slot, or an empty list when the slot is
        missing, empty, or answered with "None". If a slot occurs on several
        lines, the last usable occurrence wins.
    """
    result: Dict[str, List[str]] = {
        "incident": [],
        "perpetrator": [],
        "group perpetrator": [],
        "victim": [],
        "target": [],
        "weapon": [],
    }

    for line in response.splitlines():
        # Split at the first colon only, so answers that themselves contain
        # a colon (e.g. "killed at 10:30") are kept instead of being dropped.
        raw_slot, separator, raw_answer = line.partition(":")
        if not separator:
            continue

        # Remove markdown emphasis / placeholder brackets around the slot
        # name, e.g. "**Incident**", "<Incident>" or "<**Incident**>".
        slot = raw_slot.strip().strip("*<>").strip().lower()
        if slot not in result:
            continue

        # "**Incident:** Arson" leaves "** Arson" after the colon, so the
        # emphasis markers have to be removed from the answer as well.
        answer = raw_answer.strip().strip("*").strip().lower()
        if not answer or answer == "none":
            continue

        # Store a list (not a bare string): consumers index the value with
        # [0] and would otherwise receive only the first character.
        result[slot] = [answer]

    return result

In [None]:
# Number of leading rows to evaluate; every row costs one chat request to the model.
NUM_EVAL_SAMPLES = 10
eval_df = df[:NUM_EVAL_SAMPLES]

# Loop-invariant, so strip it once instead of on every request.
system_content = system_prompt.strip()

golds = []     # gold answers per sample, keyed by slot
preds = []     # parsed model answers per sample, keyed by slot
messages = []  # raw model replies, kept for inspection

# iterrows() is a generator without a length, so tqdm needs an explicit total
# to show a progress bar and ETA instead of a bare counter.
for idx, sample in tqdm(eval_df.iterrows(), total=len(eval_df), desc="Evaluating"):
    # NOTE(review): assumes the slot columns hold array-like values that expose
    # .tolist() (e.g. numpy arrays read from parquet) — confirm against the data.
    gold = {
        "incident": [sample["incident_type"]],
        "perpetrator": sample["PerpInd"].tolist(),
        "group perpetrator": sample["PerpOrg"].tolist(),
        "victim": sample["Victim"].tolist(),
        "target": sample["Target"].tolist(),
        "weapon": sample["Weapon"].tolist()
    }
    document = sample["doctext"]

    response = client.chat(model='gemma2', messages=[
        {
            'role': 'system',
            'content': system_content,
        },
        {
            'role': 'user',
            'content': user_prompt.format(document).strip(),
        },
    ])
    message = response["message"]["content"]
    pred = parse_response(message)

    golds.append(gold)
    preds.append(pred)
    messages.append(message)

In [None]:
# Gold individual-perpetrator answers of the third evaluated sample.
golds[2]["perpetrator"]

In [None]:
# Parsed model answers for the third evaluated sample.
preds[2]

In [None]:
# Document text to compare against golds[2] / preds[2].
# NOTE(review): this looks up index label 2, which is the third row only if df has
# the default 0..n-1 index — confirm, or use df["doctext"].iloc[2].
df["doctext"][2]

In [None]:
import evaluate

# Load the "squad_v2" metric through the evaluate library so its input format can be
# inspected below; report() loads its own copy when it runs.
squad_v2_metric = evaluate.load("squad_v2")

In [None]:
# Show the metric's own description of the inputs it expects.
print(squad_v2_metric.inputs_description)

In [None]:
def _as_answer_list(value) -> List[str]:
    """Normalise a slot value (None, a single string, or a sequence) to a list."""
    if value is None:
        return []
    if isinstance(value, str):
        # A bare string is one answer, not a sequence of characters.
        return [value] if value else []
    return list(value)


def transform_to_squad(slots: List[str], data: List[Dict[str, List[str]]], is_gold: bool) -> Dict[str, List[Dict[str, str]]]:
    """Convert per-sample slot dicts into per-slot record lists in SQuAD v2 layout.

    Args:
        slots: Slot keys to convert; every datapoint must contain each of them.
        data: One dict per sample mapping slot key -> answers. A value may be a
            list of strings, a single string, or None.
        is_gold: True to build reference records, False to build prediction
            records.

    Returns:
        A dict mapping each slot to a list with one record per sample, in the
        order of ``data``. The record id is the sample's position as a string.
        Reference records look like
        ``{"answers": {"answer_start": [...], "text": [...]}, "id": ...}``;
        prediction records look like
        ``{"prediction_text": ..., "id": ..., "no_answer_probability": ...}``.
        Only the first answer of a prediction is used; a prediction without
        answers gets an empty text and a no-answer probability of 1.0.

    Raises:
        AssertionError: If a datapoint lacks one of the requested slots.
    """
    transformed = {slot: [] for slot in slots}

    for idx, datapoint in enumerate(data):
        example_id = str(idx)
        for slot in slots:
            assert slot in datapoint, f"slot {slot!r} missing in datapoint {idx}"

            answers = _as_answer_list(datapoint[slot])

            if is_gold:
                transformed[slot].append({
                    # One start offset per answer text keeps both lists aligned;
                    # the offsets are placeholders, hence always 0.
                    'answers': {"answer_start": [0] * len(answers), "text": answers},
                    "id": example_id,
                })
            else:
                has_answer = len(answers) > 0
                transformed[slot].append({
                    'prediction_text': answers[0] if has_answer else '',
                    'id': example_id,
                    'no_answer_probability': 0.0 if has_answer else 1.0,
                })

    return transformed

In [None]:
# Convert gold answers and predictions into per-slot record lists so they can be
# inspected by hand in the following cells.
gold_transformed = transform_to_squad(slots, golds, is_gold=True)
pred_transformed = transform_to_squad(slots, preds, is_gold=False)

In [None]:
# Reference records for the "incident" slot.
gold_transformed["incident"]

In [None]:
# Prediction records for the "incident" slot.
pred_transformed["incident"]

In [None]:
# Reference records for the "perpetrator" slot.
gold_transformed["perpetrator"]

In [None]:
# Prediction records for the "perpetrator" slot.
pred_transformed["perpetrator"]

In [None]:
def report(slots, golds, preds):
    """Print the squad_v2 metric result for every slot.

    Args:
        slots: Slot keys to score.
        golds: Gold answers, one dict per sample keyed by slot.
        preds: Predicted answers, one dict per sample keyed by slot.

    The metric is loaded on every call. For each slot a "Slot: <name>" line is
    printed, followed by whatever the metric's compute() returns for that slot.
    """
    metric = evaluate.load("squad_v2")

    references_by_slot = transform_to_squad(slots, golds, is_gold=True)
    predictions_by_slot = transform_to_squad(slots, preds, is_gold=False)

    for slot in slots:
        references = references_by_slot[slot]
        predictions = predictions_by_slot[slot]
        # Both lists are built from the same number of samples; guard against drift.
        assert len(references) == len(predictions)

        print(f"Slot: {slot}")
        scores = metric.compute(references=references, predictions=predictions)
        print(scores)

In [None]:
# Print the per-slot metric results for the evaluated samples.
report(slots=slots, golds=golds, preds=preds)