# Assignment 1


from typing import Dict, List, Set, Tuple, Union
import itertools

import spacy

import utils as U

# Shared spaCy pipeline, loaded once when this module/cell is executed.
nlp = spacy.load("en_core_web_sm")

# spaCy's default stop words extended by a few punctuation marks.
# NOTE(review): presumably consumed by the term extraction code (not visible here).
stopwords = set(nlp.Defaults.stop_words) | {",", ".", "?", ":", ";"}


def load_examples(
    path: str = "../data/raw/examples.txt",
) -> List[Dict[str, Union[str, List[str]]]]:
    """Quick and dirty parser to turn examples.txt into a machine-readable form.

    The file content is split at every ``#``. The text before the first ``#``
    is discarded; every remaining chunk is treated as one example. Within a
    chunk, the first line starting with ``"Question: "`` (``"Context: "``)
    provides the question (context), and every line starting with ``(`` that
    also contains a ``)`` is an answer choice whose text is everything after
    the *first* ``)`` -- so choices may themselves contain parentheses.

    Parameters
    ----------
    path : str, optional
        location of the examples file, by default "../data/raw/examples.txt"

    Returns
    -------
    list[dict[str, Union[str, list[str]]]]
        list of examples. Each example is of the form
            {
                "question": QUESTION,
                "context": CONTEXT or "",
                "choices": [CHOICE1, CHOICE2, ...]
            }
    """

    def first_field(lines: List[str], label: str) -> str:
        # Value of the first "<label> ..." line, or "" if there is none.
        marker = label + " "
        for line in lines:
            if line.startswith(marker):
                return line[len(label):].strip()
        return ""

    # Explicit encoding so parsing does not depend on the platform locale.
    # NOTE(review): assumes the examples file is UTF-8 -- confirm.
    with open(path, encoding="utf-8") as fp:
        data = fp.read()

    parsed = []

    # Everything before the first "#" is not an example.
    for chunk in data.split("#")[1:]:
        lines = chunk.split("\n")

        choices = [
            # maxsplit=1 keeps any further ")" inside the choice text.
            line.split(")", 1)[1].strip()
            for line in lines
            if line.startswith("(") and ")" in line
        ]

        parsed.append(
            {
                "question": first_field(lines, "Question:"),
                "context": first_field(lines, "Context:"),
                "choices": choices,
            }
        )

    return parsed

def extract_terms_from_example(example: dict) -> Tuple[Set[str], Set[str]]:
    """Collect the terms of one example by applying `extract_terms` to its parts.

    Parameters
    ----------
    example : dict
        example as returned by `load_examples`, i.e. a dict with the keys
        "question", "context" and "choices".

    Returns
    -------
    set[str]
        all terms appearing in the question or the context
    set[str]
        all terms appearing in at least one of the answer choices
    """

    # Run the extraction on both texts first, then merge the results.
    question_terms = extract_terms(example["question"])
    context_terms = extract_terms(example["context"])
    question_context = {*question_terms, *context_terms}

    # Union over the terms of every individual answer choice.
    choices = set()
    for choice in example["choices"]:
        choices.update(extract_terms(choice))

    return question_context, choices

In [None]:
from dataclasses import dataclass
from functools import reduce
import operator
from typing import Dict, Iterable, List, NamedTuple, Set, Tuple

import nltk
from nltk.stem import WordNetLemmatizer
import joblib

# Module-level side effect: request the NLTK resources when this code is run.
# NOTE(review): "omw-1.4" is presumably required by the WordNet reader of
# newer NLTK releases -- confirm against the installed version.
nltk.download("wordnet")
nltk.download('omw-1.4')

# Single lemmatizer instance shared by normalize_conceptnet and normalize_input.
__lemmatizer = WordNetLemmatizer()


class EdgeDescriptor(NamedTuple):
    """Immutable record describing a single ConceptNet edge.

    Attributes
    ----------
    label_idx:
        position of the edge's label in `labels_idx2name`
    weight:
        weight of the edge
    row_idx:
        row of the edge in "en_edges.csv", usable to look up further
        information about the edge if necessary
    """

    label_idx: int
    weight: float
    row_idx: int


@dataclass
class ConceptNet:
    """Filtered and processed representation of the ConceptNet knowledge graph.

    Attributes
    ----------

    nodes_idx2name:
        maps a node index to the normalized node name
    nodes_name2idx:
        maps a node name to its index

    labels_idx2name:
        maps an edge label index to the label
    labels_name2idx:
        maps an edge label to its index

    adjacency_lists:
        for every node (by index) the set of its neighbors (by index). The
        graph is represented as undirected here.
    edge_descriptors:
        maps a pair of node indices to all direct ConceptNet edges between
        these two nodes, each described by an `EdgeDescriptor`. Edges are
        treated as *directed* here: if an edge is present in
        `adjacency_lists` but not in `edge_descriptors`, it was not part of
        the original ConceptNet data and the reverse edge has to be looked up
        in `edge_descriptors` instead.
    """

    nodes_idx2name: List[str]
    nodes_name2idx: Dict[str, int]

    labels_idx2name: List[str]
    labels_name2idx: Dict[str, int]

    adjacency_lists: Dict[int, Set[int]]
    edge_descriptors: Dict[Tuple[int, int], Set[EdgeDescriptor]]

def removeprefix(s: str, prefix: str) -> str:
    """Return `s` without the leading `prefix`; `s` is returned unchanged if it
    does not start with `prefix`."""
    return s[len(prefix):] if s.startswith(prefix) else s

def normalize_conceptnet(s: str) -> str:
    """Bring a ConceptNet node identifier into the normalized form that is used
    for matching example words against ConceptNet nodes.

    Parameters
    ----------
    s : str
        node identifier to normalize

    Returns
    -------
    str
        normalized node name
    """

    name = removeprefix(s, "/c/en/")
    # Keep only the part before the first "/", which drops an optionally
    # appended suffix (/n, /v, ...).
    name, _, _ = name.partition("/")
    name = name.replace("_", " ").casefold()

    return __lemmatizer.lemmatize(name)

def normalize_input(s: str) -> str:
    """Bring an input token into the normalized form that is used for matching
    example words against ConceptNet nodes.

    Parameters
    ----------
    s : str
        token to normalize

    Returns
    -------
    str
        case-folded and lemmatized token
    """
    # TODO switch to Spacy lemmatization

    return __lemmatizer.lemmatize(s.casefold())

def prod(vals: Iterable[float]) -> float:
    """Multiply all values in `vals`; the product of no values is 1."""
    result = 1
    for factor in vals:
        result = result * factor
    return result

def load_conceptnet(load_compressed: bool = False) -> ConceptNet:
    """Load the preprocessed graph representation from disk via joblib.

    Parameters
    ----------
    load_compressed : bool, optional
        if True, read "graph_representation_compressed.joblib" instead of
        "graph_representation.joblib", by default False

    Returns
    -------
    ConceptNet
        whatever object is stored in the selected file
    """
    # NOTE(review): joblib.load is pickle-based, so it should only be pointed
    # at trusted files.
    path = (
        "../data/processed/graph_representation_compressed.joblib"
        if load_compressed
        else "../data/processed/graph_representation.joblib"
    )
    return joblib.load(path)

In [None]:
from typing import Any, Dict, Iterable, List, Union
import logging

from utils import ConceptNet, normalize_input
from renderer import render_path_brief


def search_shortest_path(
    start_idx: int,
    end_idx: int,
    adjacency_lists: Dict[int, Iterable[int]],
    max_path_len: int,
) -> List[int]:
    """Breadth-first search for a shortest path between two nodes.

    This function is completely agnostic about ConceptNet, it just works with
    adjacency lists of integers.

    Parameters
    ----------
    start_idx : int
        node the search starts from
    end_idx : int
        node to reach
    adjacency_lists : dict[int, Iterable[int]]
        neighbours of each node; indexed with every node that gets expanded
    max_path_len : int
        maximal number of nodes a returned path may consist of

    Returns
    -------
    list[int]
        node indices from `start_idx` to `end_idx` (both inclusive), or an
        empty list if `end_idx` cannot be reached within `max_path_len` nodes
    """
    # Local import so the function does not rely on the surrounding imports.
    from collections import deque

    # Nodes to be processed (tuple of node and path length to start node).
    # A deque gives O(1) pops from the front.
    queue = deque([(start_idx, 0)])
    # Visited nodes, mapping each node to the idx of its predecessor
    # (-1 marks the start node).
    predecessor_idx = {start_idx: -1}

    while queue:
        node, path_len = queue.popleft()

        if node == end_idx:
            # Walk the predecessor chain back to the start, then flip it.
            path = [node]
            pred = predecessor_idx[node]
            while pred != -1:
                path.append(pred)
                pred = predecessor_idx[pred]
            path.reverse()
            return path

        # Neighbours of a node at the depth limit could never be enqueued, so
        # scanning them cannot contribute to any result.
        if path_len + 1 >= max_path_len:
            continue

        for neighbour in adjacency_lists[node]:
            if neighbour in predecessor_idx:
                continue

            predecessor_idx[neighbour] = node
            queue.append((neighbour, path_len + 1))

    return []


def find_word_path(
    start_term: str,
    end_term: str,
    graph: ConceptNet,
    max_path_len: int = 3,
    renderer=render_path_brief,
) -> Union[str, List[int]]:
    """Search the shortest path from `start_term` to `end_term` in `graph` and
    return it, rendered as text by default.

    Both terms are normalized with `normalize_input` before they are looked up
    in the graph.

    Parameters
    ----------
    start_term : str
        term the path starts at
    end_term : str
        term the path ends at
    graph : ConceptNet
        ConceptNet instance to search in
    max_path_len : int, optional
        maximal number of nodes in a path, by default 3
    renderer, optional
        callable taking the raw path and the graph and producing the
        visualization, by default render_path_brief. If None (or otherwise
        falsy), the raw path (a list of int's) is returned.

    Returns
    -------
    str
        path visualization as produced by `renderer`
    list[int]
        raw path if no renderer is given; an empty list is also returned,
        without calling the renderer, when one of the terms is not in the graph
    """

    source_name = normalize_input(start_term)
    target_name = normalize_input(end_term)

    name2idx = graph.nodes_name2idx

    # A term that is unknown to the graph cannot be part of any path.
    if source_name not in name2idx:
        return []
    source_idx = name2idx[source_name]

    if target_name not in name2idx:
        return []
    target_idx = name2idx[target_name]

    node_path = search_shortest_path(
        source_idx, target_idx, graph.adjacency_lists, max_path_len=max_path_len
    )

    if not renderer:
        return node_path
    return renderer(node_path, graph)

In [12]:
%load_ext autoreload
%autoreload 2

import itertools

import pandas as pd
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
examples = load_examples() # parse examples.txt into a list of question/context/choices dicts
conceptnet = load_conceptnet() # load the preprocessed ConceptNet graph (uncompressed joblib file)

## Extract Terms From Examples

In [3]:
# Question/context terms and answer-choice terms of the first example.
extract_terms_from_example(examples[0])

({'bag',
  'baggage',
  'checked',
  'drawstring',
  'drawstring bag',
  'heading',
  'only baggage',
  'wa',
  'woman'},
 {'airport',
  'garbage',
  'jewelry',
  'jewelry store',
  'military',
  'safe',
  'store'})

## Path Search 

In [10]:
# Uses the default max_path_len of 3, i.e. paths of at most three nodes.
find_word_path("safe", "baggage", conceptnet)

[]

In [6]:
# Same term pair, but allowing paths of up to four nodes
find_word_path("safe", "baggage", conceptnet, max_path_len=4)

'safe --RelatedTo--> heavy <--RelatedTo-- carry --Antonym--> baggage'

In [58]:
# Four-node search between a question term and an answer-choice term.
find_word_path("checked", "garbage", conceptnet, max_path_len=4)

'checked --HasContext--> north america --HasContext--> canada <--HasContext-- garbage'

In [64]:
# Four-node search for another term pair.
find_word_path("excited", "apply", conceptnet, max_path_len=4)

'excited <--HasSubevent-- score home run --HasPrerequisite--> play baseball --HasPrerequisite--> apply'

In [59]:
# Repeat of the safe/baggage query with up to four nodes.
find_word_path("safe", "baggage", conceptnet, max_path_len=4)

'safe --RelatedTo--> heavy <--RelatedTo-- carry --Antonym--> baggage'

## Extract Paths

In [54]:
# One record per (question/context term, answer-choice term) pair of every example.
result = []

# Examples are numbered starting at 1.
for idx, example in enumerate(tqdm(examples), start=1):
    question_context, choices = extract_terms_from_example(example)

    # Every question/context term is paired with every answer-choice term.
    for tq, tc in itertools.product(question_context, choices):
        # Default settings: at most three nodes per path, rendered as text.
        p = find_word_path(tq, tc, conceptnet)

        result.append(
            {
                "example": idx,
                "question_context": tq,
                "choices": tc,
                "path": p,
            }
        )


100%|██████████| 10/10 [07:08<00:00, 42.82s/it]


In [55]:
# Turn the collected records into a table with one row per term pair.
path_df = pd.DataFrame(result)

In [56]:
# Display the table.
path_df

Unnamed: 0,example,question_context,choices,path
0,1,drawstring,military,drawstring --PartOf--> drawstring bag --AtLoca...
1,1,drawstring,jewelry,[]
2,1,drawstring,jewelry store,drawstring --PartOf--> drawstring bag --AtLoca...
3,1,drawstring,safe,drawstring --PartOf--> drawstring bag --AtLoca...
4,1,drawstring,store,drawstring --PartOf--> drawstring bag --AtLoca...
...,...,...,...,...
804,10,water,lift,water --HasContext--> dialectal <--HasContext-...
805,10,water,bottle,water --AtLocation--> bottle
806,10,water,suction,[]
807,10,water,press,water <--RelatedTo-- cast <--MannerOf-- press


In [57]:
# Label the index so that its column gets the header "index" in the CSV file.
path_df.index.name = "index"
path_df.to_csv("../data/paths_examples.csv")

### Most Common Relations

In [5]:
# Load the raw edge table to inspect its relation labels.
df = pd.read_csv("../data/en_edges.csv")

In [7]:
# Number of edges per relation label, most frequent first.
df.label.value_counts()

/r/RelatedTo                    1703582
/r/FormOf                        378859
/r/DerivedFrom                   325374
/r/HasContext                    232935
/r/IsA                           230137
/r/Synonym                       222156
/r/UsedFor                        39790
/r/EtymologicallyRelatedTo        32075
/r/SimilarTo                      30280
/r/AtLocation                     27797
/r/HasSubevent                    25238
/r/HasPrerequisite                22710
/r/CapableOf                      22677
/r/Antonym                        19066
/r/Causes                         16801
/r/PartOf                         13077
/r/MannerOf                       12715
/r/MotivatedByGoal                 9489
/r/HasProperty                     8433
/r/ReceivesAction                  6037
/r/HasA                            5545
/r/CausesDesire                    4688
/r/dbpedia/genre                   3824
/r/HasFirstSubevent                3347
/r/DistinctFrom                    3315
