In [2]:
from urllib.request import urlopen 
from bs4 import BeautifulSoup

def remove_content_between_chars(string: str, start="[", end="]"):
    """Strip every ``start``...``end`` span (delimiters included) from ``string``.

    Characters are dropped from the first ``start`` seen until the next
    ``end``. An unterminated opening delimiter drops the rest of the string.

    Args:
        string: text to clean.
        start: single character opening a span to drop.
        end: single character closing a span to drop.

    Returns:
        The text with the delimited spans removed.

    Raises:
        AssertionError: if ``end`` is met while no span is open.
    """
    kept = []
    inside = False
    for ch in string:
        # an opening delimiter switches dropping on (and is itself dropped)
        inside = inside or ch == start
        if not inside:
            kept.append(ch)
        if ch != end:
            continue
        assert inside, f"Closing character found without opening in {''.join(kept)}"
        inside = False
    return "".join(kept)


In [3]:
import csv
import json
import re
from typing import List, Set

from spacy.tokens.doc import Doc
from spacy.tokens import Span
from tqdm import tqdm
from unidecode import unidecode

def get_paragraphs(url: str):
    """Download a page and return its paragraphs tagged with their section header.

    The page title (``<h1>``) is used as the header of every paragraph that
    precedes the first section heading inside the article body.

    Args:
        url: address of the page to download.

    Returns:
        A list of dicts with keys ``"text"`` (cleaned paragraph text),
        ``"header_name"`` (text of the closest preceding heading) and
        ``"header_level"`` (last character of the heading tag name, e.g. ``"2"``).

    Raises:
        AssertionError: if the response status is not 200, or the page has no
            ``<h1>`` or no ``mw-content-ltr mw-parser-output`` container.
    """
    # Close the HTTP response as soon as the document has been parsed.
    with urlopen(url) as page:
        assert page.getcode() == 200, f"{url} is not valid"
        soup = BeautifulSoup(page, "html.parser")

    title = soup.find("h1")
    div = soup.find("div", class_="mw-content-ltr mw-parser-output")
    assert title is not None, f"{url} has no <h1> title"
    assert div is not None, f"{url} has no article content container"

    # BeautifulSoup applies a compiled pattern with search(), so the pattern
    # must be anchored: an unanchored "(h.)|p" also selects <hr>, <header>,
    # <span>, ... and an <hr> would then be mistaken for a section header.
    header_or_paragraph = re.compile(r"^(h[1-6]|p)$")
    res = [title] + div.find_all(header_or_paragraph)

    paragraphs = []
    current_header = title
    for tag in res:
        if tag.name.startswith("h"):
            current_header = tag
        elif tag.name == "p" and tag.get("style") is None:
            # collapse whitespace runs and trim both ends
            p = " ".join(tag.text.split())
            # remove content between '[' and ']'
            p = remove_content_between_chars(p)
            if len(p) > 0:
                paragraphs.append({
                    "text": p,
                    "header_name": current_header.text,
                    "header_level": current_header.name[-1]
                })

    return paragraphs
        
        
def check_exact_match(paragraph: str, answer: str):
    """Tell whether ``answer`` occurs in ``paragraph`` as a whole word/phrase.

    The comparison is case-insensitive and the occurrence must be delimited by
    the string boundaries or by non-word characters.

    Args:
        paragraph: text to search in.
        answer: literal text to look for (it is NOT interpreted as a regex).

    Returns:
        True if a delimited occurrence exists, False otherwise.
    """
    # Escape the answer so that characters such as '.', '(', '+' are matched
    # literally instead of being interpreted as regex syntax. str() keeps
    # accepting non-string values, which used to be formatted into the pattern.
    literal = re.escape(str(answer))
    pattern = rf"(^|\W){literal}($|\W)"
    return bool(re.search(pattern, paragraph, flags=re.IGNORECASE))


def check_full_match(paragraph: str, attribute: str, answer: str):
    """Tell whether ``paragraph`` mentions both ``attribute`` and ``answer``.

    A full match requires both the attribute (e.g. "President") and the answer
    to occur as exact matches.

    Args:
        paragraph: text to search in.
        attribute: attribute name, or None when the question has no attribute.
        answer: expected answer.

    Returns:
        True only if both values are found. Always False when either value is
        None: a missing value cannot be matched, and formatting None into the
        search would look for the literal word "None" (and hit "none").
    """
    if attribute is None or answer is None:
        return False
    return all(check_exact_match(paragraph, to_find) for to_find in (attribute, answer))


def write_roman(num: int):
    """Convert an integer to its Roman numeral representation.

    Args:
        num: number to convert.

    Returns:
        The Roman numeral as a string (empty string for 0).
    """
    # value/symbol pairs, largest first, including the subtractive forms
    numerals = (
        (1000, "M"),
        (900, "CM"),
        (500, "D"),
        (400, "CD"),
        (100, "C"),
        (90, "XC"),
        (50, "L"),
        (40, "XL"),
        (10, "X"),
        (9, "IX"),
        (5, "V"),
        (4, "IV"),
        (1, "I"),
    )

    pieces = []
    remaining = num
    for value, symbol in numerals:
        count = remaining // value
        pieces.append(symbol * count)
        remaining -= value * count
        if remaining <= 0:
            break
    return "".join(pieces)


def remove_additional_bits(string: str, additional_bits: List[str]):
    """Delete every match of the given regex patterns from ``string``.

    Args:
        string: text to simplify.
        additional_bits: regular expressions whose matches are removed, applied
            one after the other in list order.

    Returns:
        The remaining text with whitespace runs collapsed to single spaces and
        no leading/trailing whitespace.
    """
    cleaned = string
    for pattern in additional_bits:
        cleaned = re.compile(pattern).sub("", cleaned)
    # split/join normalises the gaps left behind by the removals
    tokens = cleaned.split()
    return " ".join(tokens)


def find_main_chunk(doc: Doc):
    """Return the noun chunk whose root dominates the other candidates.

    Starting from the first noun chunk, the candidate is replaced whenever a
    later chunk's root is an ancestor (in the dependency tree) of the current
    candidate's root.

    Args:
        doc: parsed spaCy document.

    Returns:
        The selected noun chunk (a span), or None if ``doc`` has no noun chunks.
    """
    main_chunk = None
    for chunk in doc.noun_chunks:
        # Keep the chunk itself, not chunk.root: a bare token has no ``.root``,
        # so storing it broke the next comparison and callers reading
        # ``.root.text`` on the result.
        if main_chunk is None or chunk.root.is_ancestor(main_chunk.root):
            main_chunk = chunk
    return main_chunk


def is_monarch(span: Span, monarch_nums: Set[str]):
    """Tell whether any whitespace-separated token of ``span`` is a regnal number.

    Args:
        span: span whose ``text`` is inspected.
        monarch_nums: set of strings counted as regnal numbers.

    Returns:
        True if at least one token of the span text is in ``monarch_nums``.
    """
    return any(piece in monarch_nums for piece in span.text.split())

In [9]:
import spacy
from spacy.tokenizer import Tokenizer

# Roman numerals for 1..99, used to recognise regnal numbers in names.
MONARCH_NUMS = set(map(write_roman, range(1, 100)))

# Regex patterns stripped from an answer to obtain its simplified form.
# Raw strings are required: "\w" and "\." are invalid escape sequences in a
# regular string literal (the resulting pattern text is unchanged).
ADDITIONAL_BITS = [
    r"[\w]?F(\.)?C(\.)?[\w]?",  # FC, F.C., AFC, ... for Football
    r"[\w]?C(\.)?F(\.)?[\w]?",  # CF, ... for Football
    r"[\w]?F(\.)?K(\.)?[\w]?",  # FK, ... for Football
    r"[\w]?A(\.)?S(\.)?[\w]?",  # AS, ... for Football
    r"[\w]?S(\.)?V(\.)?[\w]?",  # SV, ... for Football
    r"[\w]?B(\.)?C(\.)?[\w]?",  # BC, ... for Basketball
    r"[\w](\.)[\w](\.)",  # General regex to remove two-letter acronyms written with dots (e.g. "A.B.")
    "football",
    "(t|T)eam",
    "association",
    "men's",
    "basketball",
    "F1",
    "(S|s)cuderia",
    "(R|r)acing",
]



# Load the spaCy pipeline used to find the head noun chunk of an answer.
nlp = spacy.load("en_core_web_trf")
nlp.tokenizer = Tokenizer(nlp.vocab)  # Whitespace tokenization

# Result structure: passages[category][item] (and [attribute] when the row has
# one) -> {"matches": ..., "page_url": ..., "no_match": ...}.
passages = {}
# First pass over the file only counts the lines, to size the progress bar.
with open("wikipedia_pages.csv", "r") as f:
    lines = f.readlines()

with open("wikipedia_pages.csv", "r") as f:
    reader = csv.DictReader(f, delimiter="\t")

    # len(lines)-1: the first line is consumed by DictReader as the header.
    for row in tqdm(reader, total=len(lines)-1):
        category = row["category"]
        item = row["item"]
        page_url = row["page_url"]
        attribute = row["attribute"]
        # an empty attribute column means the row has no attribute
        attribute = attribute if len(attribute) > 0 else None
        
        if category not in passages:
            passages[category] = {}
        
        if item not in passages[category]:
            passages[category][item] = {}

        if attribute is not None:
            if attribute not in passages[category][item]:
                passages[category][item][attribute] = {}

        # remove unicode characters from url
        url = unidecode(page_url)
        # NOTE(review): the page is downloaded once per row, so an item with
        # several attributes appears to be fetched several times; consider caching.
        paragraphs = get_paragraphs(url)

        # Paragraphs grouped by how they matched:
        #   "full"       - attribute and answer both found
        #   "em"         - answer found as is
        #   "simplified" - answer found after remove_additional_bits
        #   "head"       - main noun chunk (or its root word) of the answer found
        matches = {
            "full": [],
            "em": [],
            "simplified": [],
            "head": [],
        }

        no_match = True
        for paragraph in tqdm(paragraphs, desc=f"paragraphs for {item}"):
            # reset for every paragraph: the fallbacks below overwrite `answer`
            answer = row["answer"]
            append_to = None
            matched = None

            # Collapse verbose attribute names to a keyword; when several
            # keywords occur, the last matching `if` wins.
            attr = attribute
            if attribute is not None:
                if "prime minister" in attr.lower():
                    attr = "prime minister"
                if "president" in attr.lower():
                    attr = "president"
                if "king" in attr.lower():
                    attr = "king"
                if "monarch" in attr.lower():
                    attr = "monarch"
                if "supreme leader" in attr.lower():
                    attr = "supreme leader"
                if "premier" in attr.lower():
                    attr = "premier"
            # NOTE(review): attr is None for rows without an attribute and is
            # still passed to check_full_match here - confirm that is intended.
            if check_full_match(paragraph["text"], attr, answer):
                append_to = "full"
                matched = (attr, answer)
                no_match = False
            elif check_exact_match(paragraph["text"], answer):
                append_to = "em"
                matched = answer
                no_match = False
            else:
                # only this category gets its answer simplified
                if category in ["athletes_byPayment"]:
                    answer = remove_additional_bits(answer, ADDITIONAL_BITS)

                # For other categories `answer` is unchanged here, so this
                # repeats the exact-match check that has just failed.
                if check_exact_match(paragraph["text"], answer):
                    append_to = "simplified"
                    matched = answer
                    no_match = False
                elif len(answer.split()) > 1:
                    # NOTE(review): the pipeline runs again for every
                    # non-matching paragraph of the same row.
                    doc = nlp(answer)
                    main_chunk = find_main_chunk(doc)

                    if main_chunk is not None:
                        # keep the whole chunk when it holds a roman numeral,
                        # otherwise fall back to the chunk's root word
                        if is_monarch(main_chunk, MONARCH_NUMS):
                            answer = main_chunk.text
                        else:
                            answer = main_chunk.root.text

                    if check_exact_match(paragraph["text"], answer):
                        append_to = "head"
                        matched = answer
                        no_match = False

            
            if append_to:
                matches[append_to].append({
                    "paragraph": paragraph,
                    "matched": matched
                })

        # Store the result under the attribute when there is one, otherwise
        # directly under the item.
        if attribute is not None:
            if attribute not in passages[category][item]:
                passages[category][item][attribute] = {}

            passages[category][item][attribute] = {
                "matches": matches,
                "page_url": url,
                "no_match": no_match
            }
        else:
            passages[category][item] = {
                "matches": matches,
                "page_url": url,
                "no_match": no_match
            }



# Persist everything that was collected.
with open("passages.json", "w") as f:
    json.dump(passages, f, indent=4)


paragraphs for Spain: 100%|██████████| 150/150 [00:02<00:00, 57.52it/s]
paragraphs for Stephen Curry: 100%|██████████| 100/100 [00:01<00:00, 62.79it/s]
paragraphs for Toyota: 100%|██████████| 161/161 [00:02<00:00, 62.01it/s]
paragraphs for Canada: 100%|██████████| 116/116 [00:01<00:00, 61.37it/s]
paragraphs for Japan: 100%|██████████| 97/97 [00:01<00:00, 60.21it/s]
paragraphs for Kevin Durant: 100%|██████████| 64/64 [00:00<00:00, 67.91it/s]
paragraphs for CVS Health: 100%|██████████| 54/54 [00:00<00:00, 62.18it/s]
paragraphs for Gazprom: 100%|██████████| 96/96 [00:01<00:00, 63.24it/s]
paragraphs for Netherlands: 100%|██████████| 150/150 [00:02<00:00, 62.20it/s]
paragraphs for Belgium: 100%|██████████| 101/101 [00:01<00:00, 58.17it/s]
paragraphs for Singapore: 100%|██████████| 113/113 [00:01<00:00, 56.89it/s]
paragraphs for Egypt: 100%|██████████| 205/205 [00:03<00:00, 58.14it/s]
paragraphs for Karim Benzema: 100%|██████████| 85/85 [00:01<00:00, 60.50it/s]
paragraphs for Harry Kane: 100

In [26]:
import os
from pathlib import Path
from copy import deepcopy

with open("passages.json", "r") as f:
    passages = json.load(f)

# NOTE(review): machine-specific absolute path; point this at the local copy of
# the editing datasets before running on another machine.
EDITING_DATASETS_DIR = Path("/home/simone/papers/ACL/knowledge-editing/models_editing/editing_datasets")

# outdated_questions[domain][element] -> {attribute: {}} for every sample found
# in the editing datasets (the inner dict is empty when attribute is None).
outdated_questions = {}
for folder in EDITING_DATASETS_DIR.iterdir():
    if not folder.is_dir():
        # stray files next to the dataset folders cannot contain a dataset
        continue
    dataset_path = folder / "editing_dataset.json"
    with open(dataset_path, "r") as f:
        editing_dataset = json.load(f)
    for sample in editing_dataset:
        domain = sample["domain"]
        element = sample["element"]
        attribute = sample["attribute"]
        attributes = outdated_questions.setdefault(domain, {}).setdefault(element, {})
        # The membership test must look at the attribute; testing the element
        # name against the attribute dict was always true.
        if attribute is not None and attribute not in attributes:
            attributes[attribute] = {}

def count_questions(passages):
    """Count the questions stored in a ``passages`` mapping.

    Elements of the "countries_byGDP" and "organizations" domains contribute
    one question per attribute; elements of any other domain contribute one.

    Args:
        passages: mapping domain -> element -> content.

    Returns:
        Total number of questions.
    """
    per_attribute_domains = ["countries_byGDP", "organizations"]
    total = 0
    for domain, elements in passages.items():
        if domain in per_attribute_domains:
            total += sum(len(attributes) for attributes in elements.values())
        else:
            total += len(elements)
    return total

print("BEFORE: ", count_questions(passages))
# Iterate over a snapshot so that entries can be deleted from `passages`.
passages_copy = deepcopy(passages)

for domain in passages_copy:
    # A domain without any outdated question has nothing to keep; .get avoids
    # a KeyError and lets all of its elements be dropped.
    outdated_elements = outdated_questions.get(domain, {})
    for element in passages_copy[domain]:
        if element not in outdated_elements:
            del passages[domain][element]
        elif domain in ["countries_byGDP", "organizations"]:
            # these domains are filtered per attribute
            for attribute in passages_copy[domain][element]:
                if attribute not in outdated_elements[element]:
                    del passages[domain][element][attribute]


print("AFTER: ", count_questions(passages))


with open("editing_passages.json", "w") as f:
    json.dump(passages, f, indent=4)

BEFORE:  130
AFTER:  79


In [2]:
import json
import random

with open("editing_passages.json", "r") as f:
    editing_passages = json.load(f)

# Domains whose elements are keyed by attribute (one question per attribute);
# every other domain holds a single question per element.
PER_ATTRIBUTE_DOMAINS = ["countries_byGDP", "organizations"]


def _iter_questions(editing_passages):
    """Yield ``(domain, label, question_passages)`` for every question, in file order."""
    for domain, elements in editing_passages.items():
        for element, content in elements.items():
            if domain in PER_ATTRIBUTE_DOMAINS:
                for attribute, question_passages in content.items():
                    yield domain, f"{domain} -- {element} -- {attribute}", question_passages
            else:
                yield domain, f"{domain} -- {element}", content


def _single_match(question_passages, label):
    """Return the only match entry of a question, asserting that there is exactly one."""
    matches = [m for matches_per_category in question_passages["matches"].values() for m in matches_per_category]
    assert len(matches) == 1, f"You should have only 1 passage for each question but you have {len(matches)} for {label}"
    return matches[0]


# Clean passage text of every question, grouped by domain. This is collected
# in full before any noise is added so that noise is always drawn from clean text.
passages_per_domain = {domain: [] for domain in editing_passages}
for domain, label, question_passages in _iter_questions(editing_passages):
    passages_per_domain[domain].append(_single_match(question_passages, label)["paragraph"]["text"])

# For each domain, the candidate noise: passages that do not belong to it.
noisy_passages_per_domain = {
    domain: [p for d, ps in passages_per_domain.items() for p in ps if d != domain]
    for domain in editing_passages
}

# set the seed once at the beginning
random.seed(42)
for domain, label, question_passages in _iter_questions(editing_passages):
    match = _single_match(question_passages, label)
    noisy_passages = noisy_passages_per_domain[domain]
    assert len(noisy_passages) > 0, f"No passage from another domain is available as noise for {label}"
    # prepend a random out-of-domain passage to the matching one
    noisy_p = random.choice(noisy_passages)
    match["paragraph"]["text"] = "\n".join([noisy_p, match["paragraph"]["text"]])


with open("noisy_editing_passages.json", "w") as f:
    json.dump(editing_passages, f, indent=4)

In [3]:
import json

with open("editing_passages.json", "r") as f:
    editing_passages = json.load(f)

# passages_per_domain[domain] -> text of the single matching passage of every
# question of that domain.
passages_per_domain = {}
for domain, elements in editing_passages.items():
    contexts = passages_per_domain.setdefault(domain, [])
    has_attributes = domain in ["countries_byGDP", "organizations"]
    for element, content in elements.items():
        # Normalise both layouts to (label, question) pairs: one question per
        # attribute for attribute-keyed domains, one per element otherwise.
        if has_attributes:
            questions = [(f"{domain} -- {element} -- {attribute}", question) for attribute, question in content.items()]
        else:
            questions = [(f"{domain} -- {element}", content)]

        for label, question_passages in questions:
            matches = [
                m["paragraph"]["text"]
                for matches_per_category in question_passages["matches"].values()
                for m in matches_per_category
            ]
            assert len(matches) == 1, f"You should have only 1 passage for each question but you have {len(matches)} for {label}"
            contexts.append(matches.pop())

In [10]:
import random

# Flatten the per-domain passages into a single list.
all_passages = []
for domain_passages in passages_per_domain.values():
    all_passages.extend(domain_passages)

# fixed seed so that the shuffled order is reproducible
random.seed(42)
random.shuffle(all_passages)

# one passage per line
with open("langchain/data/passages.txt", "w") as f:
    f.writelines(p + "\n" for p in all_passages)