### 1. Get All Domains in Wikidata

In [1]:
import requests
import os
from tqdm import tqdm
from json import JSONDecodeError
from time import sleep
import re
from urllib.parse import unquote, urlparse
import openai
import asyncio

url = "https://query.wikidata.org/sparql"

In [2]:
def get_domain_entites(qid, limit=10000):
    query = f"""
        SELECT ?entity ?entityLabel
        WHERE {{
            ?entity wdt:P31 wd:{qid}.
            ?entity rdfs:label ?entityLabel filter (lang(?entityLabel) = "en")
        }}
        LIMIT {limit}
    """
    r = requests.get(url, params={"format": "json", "query": query})
    try:
        data = r.json()
    except JSONDecodeError:
        sleep(5)
        r = requests.get(url, params={"format": "json", "query": query})
        data = r.json()

    entities = dict()

    for e in data["results"]["bindings"]:
        entity_id = e["entity"]["value"].split("/")[-1]
        entity_label = e["entityLabel"]["value"]
        entities[entity_id] = entity_label

    return entities


def get_entity_attributes(qid, retries=0):
    query = f"""
        SELECT ?property ?propertyLabel ?value ?valueLabel
        WHERE {{
            wd:{qid} ?p ?value .
            ?property wikibase:directClaim ?p .
            SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en". }}
        }}
    """
    try:
        r = requests.get(url, params={"format": "json", "query": query})
        data = r.json()
    except JSONDecodeError:
        asyncio.sleep(5)
        r = requests.get(url, params={"format": "json", "query": query})
        data = r.json()
    except requests.ConnectTimeout:
        if retries > 5:
            raise KeyError
        asyncio.sleep(5)
        return get_entity_attributes(qid, retries + 1)

    properties = dict()

    for p in data["results"]["bindings"]:
        property_id = p["property"]["value"].split("/")[-1]
        property_label = p["propertyLabel"]["value"]
        value_label = p["valueLabel"]["value"]

        if property_id.startswith("P"):
            properties[property_id] = (property_label, value_label)

    return properties

In [3]:
def get_wikipedia_content(url):
    # Extracting the title from the URL
    parsed_url = urlparse(url)
    title = unquote(parsed_url.path.split("/")[-1])

    # API request setup
    api_url = f"https://{parsed_url.netloc}/w/api.php"
    params = {
        "action": "query",
        "format": "json",
        "titles": title,
        "prop": "extracts",
        "explaintext": True,
    }

    # Making the request
    try:
        response = requests.get(api_url, params=params)
    except requests.exceptions.InvalidURL:
        print("Invalid URL")
        print(parsed_url.netloc)
        raise

    data = response.json()

    # Extracting the page content
    page = next(iter(data["query"]["pages"].values()))
    if "extract" in page:
        return page["extract"]
    else:
        return "Article content not found."


def get_wikipedia_sentences(
    content, entity_name, word_upper_limit=20, word_lower_limit=5
):
    # Splitting the content into sentences by ". ", ".\n" or ".\t"
    sentences = [s.strip() for s in re.split(r"\.|\;|\,|\n|\!|\?", content)]

    # Filtering the sentences that mention the entity
    entity_sentences = [s for s in sentences if entity_name in s]

    # Filtering the sentences by word count
    entity_sentences = [
        s
        for s in entity_sentences
        if len(s.split()) <= word_upper_limit and len(s.split()) >= word_lower_limit
    ]

    return entity_sentences


def get_wikipedia_url(wikidata_id, language="en"):
    # Constructing the URL to call the API
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "ids": wikidata_id,
        "format": "json",
        "props": "sitelinks",
    }

    response = requests.get(url, params=params)
    data = response.json()

    # Accessing the sitelink for the specified language
    try:
        wikipedia_title = data["entities"][wikidata_id]["sitelinks"][f"{language}wiki"][
            "title"
        ]
        wikipedia_url = (
            f"https://{language}.wikipedia.org/wiki/{wikipedia_title.replace(' ', '_')}"
        )
        return wikipedia_url
    except KeyError:
        return "No Wikipedia article found for this language."

In [4]:
# import openai
from openai import AsyncOpenAI


async def filter_attribute(attribute, retries=0):
    client = AsyncOpenAI(
        # This is the default and can be omitted
        api_key="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"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": "Classify the following attribute into 'common knowledge' or 'specialized'.\nExamples of common knowledge attribute: Cause of death, Nationality, category for people buried here, Country of origin\nExamples of specialized attribute: image of grave, Canadiana Name Authority ID, Diamond Catalog ID for persons and organisations\n\n\n"
                    + "Attribute"
                    + attribute
                    + "\n\nClass: ",
                }
            ],
            model="gpt-3.5-turbo",
        )
    except openai.APIConnectionError:
        if retries >= 5:
            print(f"Max retries reached. Skipping attribute {attribute}")
            return "specialized"

        print("GPT-3.5 Timeout. Sleep for 5 seconds")
        asyncio.sleep(5)
        return filter_attribute(attribute, retries + 1)

    # OpenAI API Key setup
    content = chat_completion.choices[0].message.content
    return content


async def filter_attributes(attributes):
    filtered_attributes = dict()

    # Prepare a list of tasks for asynchronous execution
    tasks = [
        asyncio.create_task(filter_attribute(attributes[attribute][0]))
        for attribute in attributes
    ]

    # Use asyncio.gather to run tasks concurrently and wait for all to complete
    results = await asyncio.gather(*tasks)

    # After tasks complete, map results back to attribute keys
    for attribute, result in zip(attributes, results):
        filtered_attributes[attribute] = result

    is_common_knowledge = dict()
    for attribute, pred in filtered_attributes.items():
        try:
            is_common_knowledge[attribute] = "specialized" not in pred.lower()
        except AttributeError:
            is_common_knowledge[attribute] = False

    filtered_attributes = {
        k: attributes[k]
        for k, v in filtered_attributes.items()
        if is_common_knowledge[k]
    }

    filtered_attributes
    return filtered_attributes

In [5]:
import random


async def get_attribute_prompt(entity_name, context, attribute, label, retries=0):
    client = AsyncOpenAI(
        # This is the default and can be omitted
        api_key="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"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": f"Given a context sentence about {entity_name}, write a prompt continuing on the context and query the attribute which the model should predict the label in th next token. Following the rules: 1) use co-reference as much as possible to refer to the target entity. 2) Be clear and specific about the target attribute.",
                },
                # 1-st example
                {
                    "role": "user",
                    "content": "Entity: Abraham Lincoln\n\nContext: Abraham Lincoln was the 16th president of the United States,\n\nAttribute: cause of death\n\nLabel: shot to the head\n\n",
                },
                {
                    "role": "assistant",
                    "content": "Prompt: ' who's cause of death was '",
                },
                # 2-nd example
                {
                    "role": "user",
                    "content": "Entity: A Gang Story\n\nContext: A Gang Story (French: Les Lyonnais) is a 2011 French drama film \n\nAttribute: screenwriter\n\nLabel: Edgar Marie\n\n",
                },
                {
                    "role": "assistant",
                    "content": "Prompt: ' which was brought to life by the screenwriter '",
                },
                # 3-rd example
                {
                    "role": "user",
                    "content": "Entity: Palo Alto\n\nContext: Palo Alto is a charter city in the northwestern corner of Santa Clara County\n\nAttribute: country\n\nLabel: United States of America\n\n",
                },
                {
                    "role": "assistant",
                    "content": "Prompt: '; the country that this city belongs to is '",
                },
                {
                    "role": "user",
                    "content": f"Entity: {entity_name}\n\nContext: {context}\n\nAttribute: {attribute}\n\nLabel: {label}\n\n",
                },
            ],
            model="gpt-4o-mini",
        )
    except openai.APIConnectionError:
        if retries >= 5:
            print(f"Max retries reached. Skipping attribute {attribute}")
            return None

        print("GPT-4o Timeout. Sleep for 5 seconds")
        asyncio.sleep(5)
        return get_attribute_prompt(entity_name, context, attribute, label, retries + 1)

    # OpenAI API Key setup
    content = chat_completion.choices[0].message.content
    return content


async def generate_prompts(entity_name, entity_contexts, attributes, labels):
    prompt_dict = dict()

    for context in entity_contexts:
        prompt_dict[context] = dict()

        for attribute, label in zip(attributes, labels):
            prompt = await get_attribute_prompt(entity_name, context, attribute, label)

            if prompt is None:
                continue
            prompt = prompt.replace("Prompt: ", "")[1:-1]

            prompt_dict[context][attribute] = (prompt, label)

    return prompt_dict

### 3. Get Domain Entities

In [6]:
from datasets import Dataset


async def generate_domain_dataset(
    qid,
    num_of_context=1,
    n_property=5,
    limit=1000,
    save_per_entity=25,
    save_dir=None,
    resuming_from=None,
):
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    domain_dataset = []

    idxs_resuming_from = None

    if resuming_from is not None:
        domain_dataset = Dataset.load_from_disk(resuming_from)
        domain_dataset = domain_dataset.to_list()

        idxs_resuming_from = len(set([d["entity"] for d in domain_dataset]))

    all_entities = get_domain_entites(qid, limit)

    print(f"Number of entities: {len(all_entities)}")

    for i, (entity_qid, entity_name) in enumerate(tqdm(all_entities.items())):
        if idxs_resuming_from is not None:
            if i < idxs_resuming_from:
                print(f"Skipping entity {i}-th entity")
                continue

        if save_per_entity is not None and len(all_entities) > save_per_entity:
            if i % save_per_entity == 0 and i != 0:
                if save_dir is not None:
                    temp = Dataset.from_list(domain_dataset)
                    temp.save_to_disk(os.path.join(save_dir, f"{qid}"))

        attributes = get_entity_attributes(entity_qid)
        filtered_attributes = await filter_attributes(attributes)

        try:
            wikipedia_url = get_wikipedia_url(entity_qid)
        except KeyError:
            continue
        except requests.ConnectTimeout:
            asyncio.sleep(5)
            wikipedia_url = get_wikipedia_url(entity_qid)
            if wikipedia_url is None or wikipedia_url == "":
                continue
        try:
            wikipedia_content = get_wikipedia_content(wikipedia_url)
        except KeyError:
            continue
        except requests.exceptions.InvalidURL:
            print(wikipedia_url)
            raise
        except requests.ConnectTimeout:
            asyncio.sleep(5)
            wikipedia_content = get_wikipedia_content(wikipedia_url)

        entity_contexts = get_wikipedia_sentences(wikipedia_content, entity_name)
        if len(entity_contexts) == 0 or len(filtered_attributes) == 0:
            continue

        if num_of_context is None:
            entity_contexts = entity_contexts
        elif num_of_context == 1:
            entity_contexts = [entity_contexts[0]]
        elif len(entity_contexts) > num_of_context:
            entity_contexts = random.sample(entity_contexts, num_of_context)

        if len(filtered_attributes) > n_property:
            selected_keys = random.sample(filtered_attributes.keys(), n_property)
            filtered_attributes = {k: filtered_attributes[k] for k in selected_keys}

        attributes = []
        labels = []

        for attribute_qid in filtered_attributes.keys():
            attribute, value = filtered_attributes[attribute_qid]
            attributes.append(attribute)
            labels.append(value)

        prompt_dict = await generate_prompts(
            entity_name, entity_contexts, attributes, labels
        )

        for context in prompt_dict:
            for attribute in prompt_dict[context].keys():
                prompt, label = prompt_dict[context][attribute]

                domain_dataset.append(
                    {
                        "entity": entity_name,
                        "context": context,
                        "attribute": attribute,
                        "prompt": prompt,
                        "label": label,
                    }
                )

    if save_dir is not None:
        temp = Dataset.from_list(domain_dataset)
        temp.save_to_disk(os.path.join(save_dir, f"{qid}_{len(all_entities)}"))

    return Dataset.from_list(domain_dataset)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
dataset = await generate_domain_dataset(
    "Q5",
    num_of_context=1,
    n_property=5,
    limit=2000,
    save_per_entity=10,
    save_dir="./auto_ravel",
    resuming_from="./auto_ravel/Q5",
)

Number of entities: 2000


  0%|          | 0/2000 [00:00<?, ?it/s]

Skipping entity 0-th entity
Skipping entity 1-th entity
Skipping entity 2-th entity
Skipping entity 3-th entity
Skipping entity 4-th entity
Skipping entity 5-th entity
Skipping entity 6-th entity
Skipping entity 7-th entity
Skipping entity 8-th entity
Skipping entity 9-th entity
Skipping entity 10-th entity
Skipping entity 11-th entity
Skipping entity 12-th entity
Skipping entity 13-th entity
Skipping entity 14-th entity
Skipping entity 15-th entity
Skipping entity 16-th entity
Skipping entity 17-th entity
Skipping entity 18-th entity
Skipping entity 19-th entity
Skipping entity 20-th entity
Skipping entity 21-th entity
Skipping entity 22-th entity
Skipping entity 23-th entity
Skipping entity 24-th entity
Skipping entity 25-th entity
Skipping entity 26-th entity
Skipping entity 27-th entity
Skipping entity 28-th entity
Skipping entity 29-th entity
Skipping entity 30-th entity
Skipping entity 31-th entity
Skipping entity 32-th entity
Skipping entity 33-th entity
Skipping entity 34-th en

since Python 3.9 and will be removed in a subsequent version.
  selected_keys = random.sample(filtered_attributes.keys(), n_property)
Saving the dataset (1/1 shards): 100%|██████████| 4020/4020 [00:00<00:00, 341698.29 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4055/4055 [00:00<00:00, 272982.52 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4100/4100 [00:00<00:00, 225402.68 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4140/4140 [00:00<00:00, 260926.81 examples/s]
 26%|██▌       | 520/2000 [07:01<19:59,  1.23it/s]  

Invalid URL

No Wikipedia article found for this language.





InvalidURL: Invalid URL 'https:///w/api.php': No host supplied

In [9]:
from transformers import AutoTokenizer, LlamaForCausalLM
import torch

model = LlamaForCausalLM.from_pretrained(
    "/nlp/scr/sjd24/llama3-8b", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("/nlp/scr/sjd24/llama3-8b")

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.82it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
def get_model_prediction_labels(model, tokenizer, target_dataset, max_num_tokens=3):
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    def collate_fn(batch):
        input_texts = []

        for b in batch:
            prefix = b["context"]
            suffix = b["prompt"]

            input_text = f"{tokenizer.bos_token} {prefix}{suffix}"
            input_texts.append(input_text)

        inputs = tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True
        )
        # inputs["position_ids"] = torch.cumsum(inputs["attention_mask"], dim=1) * inputs["attention_mask"] - 1

        return inputs

    model_predictions = []

    dataloader = torch.utils.data.DataLoader(
        target_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False
    )

    model = model.to("cuda")
    for batch in tqdm(dataloader):
        batch = {k: v.to("cuda") for k, v in batch.items()}
        outputs = model.generate(**batch, max_new_tokens=max_num_tokens)

        outputs = outputs[:, batch["input_ids"].shape[1] :]

        predictions = [
            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
        ]
        model_predictions.extend(predictions)

    target_dataset = target_dataset.add_column("model_predictions", model_predictions)

    return target_dataset