# CSKG embeddings

This notebook computes similarity between nodes in CSKG and performs grounding of questions/answers to CSKG.

We will play with two different families of embeddings: graph and text embeddings.

## Graph embeddings 

The graph embeddings have been computed by the command:

`python embeddings/embedding_click.py -i input/kgtk_framenet.tsv -o output/kgtk_framenet`

using the `embeddings/embedding_click.py` script in this repository. This command invokes the PyTorch-BigGraph (PBG) library from Facebook Research and computes graph embeddings with the ComplEx algorithm.
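
For orientation, ComplEx scores a triple (head, relation, tail) as the real part of a trilinear product of complex-valued vectors, with the tail conjugated. The snippet below is a minimal sketch of that scoring function in plain Python. The vectors are made-up toy values, and the code does not reproduce how PBG implements or trains the model.

```python
def complex_score(head, relation, tail):
    """ComplEx score: Re(sum_k head_k * relation_k * conj(tail_k))."""
    return sum(h * r * t.conjugate() for h, r, t in zip(head, relation, tail)).real


# Toy 2-dimensional complex embeddings (illustrative values only)
head = [1 + 2j, 0.5 - 1j]
relation = [0 + 1j, 1 + 0j]
tail = [2 - 1j, 1 + 1j]

forward = complex_score(head, relation, tail)   # head -> tail
backward = complex_score(tail, relation, head)  # tail -> head
print(forward, backward)
```

Because the tail is conjugated, swapping head and tail can change the score when the relation vector has a non-zero imaginary part, which is what lets the model represent directed relations.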

We are currently integrating this function into the KGTK package, to make it more accessible to the AI community.

## Text embeddings
The text embeddings were computed by using the KGTK `text-embedding` command as follows:
```
kgtk text_embedding \
    --embedding-projector-metadata-path none \
    --label-properties "label" \
    --isa-properties "/r/IsA" \
    --description-properties "/r/DefinedAs" \
    --property-value "/r/Causes" "/r/UsedFor" "/r/PartOf" "/r/AtLocation" "/r/CapableOf" \
    "/r/CausesDesire" "/r/SymbolOf" "/r/MadeOf" "/r/LocatedNear" "/r/Desires" "/r/HasProperty" "/r/HasFirstSubevent" \
    "/r/HasLastSubevent" "at:xAttr" "at:xEffect" "at:xIntent" "at:xNeed" "at:xReact" "at:xWant" \
    --has-properties "" \
    -f kgtk_format \
    --output-data-format kgtk_format \
    --model bert-large-nli-cls-token \
    --save-embedding-sentence \
    -i sorted.tsv.gz \
    -p sorted.tsv.gz \
    > cskg_embedings.txt
```

# Setup for grounding

```
conda create -n mowgli-env python=3.6 
conda activate mowgli-env

git clone https://github.com/ucinlp/mowgli-uci

mv mowgli-uci grounding

cd grounding

pip install -r requirements.txt
conda install --yes faiss-cpu -c pytorch -n mowgli-env
python -m spacy download en_core_web_lg

cd ..
```

# Setup for working with embeddings

`conda install -c conda-forge python-annoy`


## I. Load embeddings

In [29]:
from pathlib import Path
import gzip
import pickle as pkl
import faiss
import numpy as np
from typing import Callable, List, Tuple
import json
import hashlib
import logging
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


In [3]:
## GENERAL CONFIG
redo=True
CACHE_DIR = Path('.cache/')

## SEARCH INDEX SETUP
distance='cosine'  # metric passed to build_index(): 'cosine' or 'l2'

## GRAPH EMBEDDINGS SETUP
graph_dim=400 # Dimension of the graph embeddings - choose one of 100, 300, 400
graph_trees=20
graph_emb_path='output/embeddings/entity_embedding_%d.tsv.gz' % graph_dim
graph_index_path='tmp/complex_%d.ann' % graph_dim
graph_node2id_path='tmp/graph_node2id.pkl'
graph_id2node_path='tmp/graph_id2node.pkl'
graph_emb_col=1
graph_emb_del=' '


## TEXT EMBEDDINGS SETUP
text_dim=1024
text_trees=10
text_emb_path='output/embeddings/cskg_embeddings_bert_nli_large.txt.gz'
text_index_path='tmp/bert_large.ann'
text_node2id_path='tmp/text_node2id.pkl'
text_id2node_path='tmp/text_id2node.pkl'
text_emb_col=2
text_emb_del=','

In [4]:
def init_cache():
    if not CACHE_DIR.exists():
        logger.debug(f'Creating cache dir at: {CACHE_DIR}')
        CACHE_DIR.mkdir(parents=True)

In [5]:
def _cache_path(fn, args, kwargs):
    fn_name = fn.__name__
    args_string = ','.join(str(arg) for arg in args)
    kwargs_string = json.dumps(kwargs)
    byte_string = fn_name + args_string + kwargs_string
    hash_object = hashlib.sha1(byte_string.encode())
    return CACHE_DIR / hash_object.hexdigest()


def cache():
    def decorator(fn):
        def load_cached_if_available(*args, **kwargs):
            path = _cache_path(fn, args, kwargs)
            if path.exists():
                logger.debug(f'Loading `{fn.__name__}` output from cache')
                with open(path, 'rb') as f:
                    return pkl.load(f)
            output = fn(*args, **kwargs)
            with open(path, 'wb') as f:
                pkl.dump(output, f, protocol=4)
            return output
        return load_cached_if_available
    return decorator
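
The decorator above keys the cache on the function name and the string form of its arguments, so a changed input file at the same path is not detected, and the `redo` flag from the config cell is never consulted. The following is a minimal sketch of a variant that accepts a hypothetical `redo` argument to force recomputation. `cache_with_redo`, `cache_dir`, and `slow_square` are illustrative names, not part of the notebook.

```python
import hashlib
import json
import pickle as pkl
import tempfile
from pathlib import Path


def cache_with_redo(cache_dir: Path, redo: bool = False):
    """Pickle-backed cache; redo=True ignores any cached file (hypothetical option)."""
    def decorator(fn):
        def wrapper(*args, **kwargs):
            # Key scheme follows the notebook; sort_keys is added so that
            # keyword-argument order does not change the key.
            key = fn.__name__ + ','.join(str(a) for a in args) + json.dumps(kwargs, sort_keys=True)
            path = cache_dir / hashlib.sha1(key.encode()).hexdigest()
            if path.exists() and not redo:
                with open(path, 'rb') as f:
                    return pkl.load(f)
            output = fn(*args, **kwargs)
            cache_dir.mkdir(parents=True, exist_ok=True)
            with open(path, 'wb') as f:
                pkl.dump(output, f, protocol=4)
            return output
        return wrapper
    return decorator


calls = []
tmp_dir = Path(tempfile.mkdtemp())


@cache_with_redo(tmp_dir)
def slow_square(x):
    calls.append(x)  # records each real execution of the function body
    return x * x


first = slow_square(4)   # computed and written to the cache
second = slow_square(4)  # read back from the cache; the body does not run
```

Passing `redo=True` to the decorator would recompute and overwrite the cached file, which is one way the unused `redo` setting could be wired in.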

In [6]:
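def build_index(metric: str, embeddings: np.ndarray):
    """Build an exact (brute-force) faiss index over the rows of `embeddings`."""

    logger.debug('Building search index')

    if metric == 'cosine':
        # IndexFlatIP ranks by raw inner product, which equals cosine
        # similarity only when the rows (and queries) are L2-normalized.
        index = faiss.IndexFlatIP(embeddings.shape[-1])
    elif metric == 'l2':
        index = faiss.IndexFlatL2(embeddings.shape[-1])
    else:
        raise ValueError(f'Bad metric: {metric}')

    index.add(embeddings)

    return index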
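
Inner-product search reproduces cosine similarity only when the vectors have unit length. The sketch below illustrates the difference, with plain Python standing in for NumPy and faiss so that it has no dependencies. `l2_normalize`, `dot`, and the toy vectors are assumptions made for illustration.

```python
import math


def l2_normalize(vec):
    """Return a unit-length copy of vec (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return list(vec) if norm == 0 else [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


query = [1.0, 0.0]
aligned_short = [2.0, 0.0]  # same direction as the query, small norm
off_axis_long = [6.0, 8.0]  # different direction, large norm

# The raw inner product favours the long vector...
raw_scores = (dot(query, aligned_short), dot(query, off_axis_long))
# ...while the inner product of normalized vectors is the cosine similarity.
cos_scores = (dot(l2_normalize(query), l2_normalize(aligned_short)),
              dot(l2_normalize(query), l2_normalize(off_axis_long)))
print(raw_scores, cos_scores)
```

With faiss, a common way to get the same effect is to call `faiss.normalize_L2` on the float32 embedding matrix (and on the queries) before adding them to an `IndexFlatIP`. Note that this call normalizes the array in place.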

In [7]:
class Vocab:
    def __init__(self, words) -> None:
        self.idx_to_word = words
        self.word_to_idx = {word: idx for idx, word in enumerate(words)}

In [8]:
def count_lines(embedding_file: Path):
    # Counts every line, including any header and `embedding_sentence` rows,
    # so the result is an upper bound on the number of embedding vectors.
    with gzip.open(embedding_file, 'rb') as f:
        return sum(1 for _ in f)

In [9]:
@cache()
def read_embedding_file(embedding_file: Path, dim: int, emb_col=1) -> Tuple[Vocab, np.ndarray]:

    logger.debug(f'Reading embeddings from {embedding_file}')

    # Upper bound on the number of vectors; the array is trimmed after reading.
    shape = (count_lines(embedding_file), dim)

    with gzip.open(embedding_file, 'r') as f:

        embeddings = np.zeros(shape, dtype=np.float32)

        if emb_col!=1:
            next(f)  # skip the header row
        i=0
        words = []
        for line in tqdm(f, total=shape[0]):
            line=line.decode()
            if emb_col==1:
                # Graph format: node followed by whitespace-separated floats
                node, *embedding = line.split()
            else:
                # Text format: node, property, comma-separated floats
                line_data=line.split()
                if line_data[1]=='embedding_sentence': continue
                node=line_data[0]
                embedding=line_data[2].split(',')
            embedding = np.array([float(x) for x in embedding])
            words.append(node)
            embeddings[i] = embedding
            i+=1

    vocab = Vocab(words)

    # Drop the unused all-zero rows left over from the header and skipped lines,
    # so that every row of the matrix corresponds to an entry in the vocabulary.
    return vocab, embeddings[:i]
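
To make the two file layouts concrete, here is a small sketch that parses one line of each kind. The sample lines are invented for illustration and only mimic the shapes the loader expects: a node followed by whitespace-separated floats, or a node, a property name, and a comma-separated vector. `parse_embedding_line` is a hypothetical helper, not part of the notebook.

```python
from typing import List, Optional, Tuple


def parse_embedding_line(line: str, emb_col: int = 1) -> Optional[Tuple[str, List[float]]]:
    """Parse one line; return None for rows that carry no vector."""
    fields = line.split()
    if emb_col == 1:
        node, *values = fields
    else:
        if fields[1] == 'embedding_sentence':
            return None  # sentence rows hold text, not a vector
        node, values = fields[0], fields[emb_col].split(',')
    return node, [float(v) for v in values]


# Invented sample lines (the property name `text_embedding` is an assumption)
graph_line = '/c/en/turtle 0.1 -0.2 0.3'
text_line = '/c/en/turtle\ttext_embedding\t0.1,-0.2,0.3'
sentence_line = '/c/en/turtle\tembedding_sentence\tturtle is an animal'

print(parse_embedding_line(graph_line))
print(parse_embedding_line(text_line, emb_col=2))
print(parse_embedding_line(sentence_line, emb_col=2))
```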

In [10]:
init_cache()

### Load graph embeddings

In [11]:
graph_vocab, graph_embeddings = read_embedding_file(graph_emb_path, graph_dim)

100%|██████████| 2160968/2160968 [05:08<00:00, 7003.03it/s]


In [12]:
graph_index = build_index(distance, graph_embeddings)

### Load text embeddings

In [13]:
text_vocab, text_embeddings = read_embedding_file(text_emb_path, text_dim, emb_col=2)

100%|█████████▉| 4322096/4322097 [12:14<00:00, 5882.91it/s]


In [14]:
text_index = build_index(distance, text_embeddings)

## II. Most similar nodes in CSKG

In [18]:
query_nodes=['/c/en/turtle', '/c/en/happy', '/c/en/turtle/n/wn/animal', 'at:personx_abandons_____altogether', '/c/en/caffeine']

In [19]:
num_neighbors=5

### According to graph embeddings

In [20]:
ids=[graph_vocab.word_to_idx[n] for n in query_nodes]
distances, neighbors = graph_index.search(graph_embeddings[ids], num_neighbors+1)
for node_nbrs in neighbors:
    neighboring_nodes=[graph_vocab.idx_to_word[n] for n in node_nbrs]

    print('Nearest neighbors to *%s*' % neighboring_nodes[0])
    print(neighboring_nodes[1:])
    print()
    print()

Nearest neighbors to */c/en/turtle*
['/c/en/tortoise', '/c/en/animal', '/c/en/carapace', '/c/en/bill/n/wikt/en_2', '/c/en/turtled/v']


Nearest neighbors to */c/en/happy*
['/c/en/joyful', '/c/en/excited', '/c/en/pleased', '/c/en/glad', '/c/en/elated']


Nearest neighbors to */c/en/turtle/n/wn/animal*
['/c/en/luger/n/wn/person', '/c/en/mud_turtle/n/wn/animal', '/c/en/carapace/n/wn/animal', '/c/en/chelonian/n/wn/animal', '/c/en/testudinidae/n/wn/animal']


Nearest neighbors to *at:personx_abandons_____altogether*
['at:to_start_fresh', '/c/en/sad', '/c/en/impatient', 'at:to_find_a_new_job', '/c/en/authoritative']


Nearest neighbors to */c/en/caffeine*
['/c/en/coffee', '/c/en/caffeinated/a', '/c/en/caffeine_free', '/c/en/tea', '/c/en/drug/n']
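
The loops in this section request `num_neighbors+1` hits and treat the first one as the query node itself. That holds when each node is its own best match, which is not guaranteed for every metric or index. The sketch below shows one way to exclude the query by id instead of by position, using a brute-force search over a toy vocabulary. `top_k_excluding_self` and the vectors are illustrative assumptions.

```python
import math


def cosine(a, b):
    num = sum(x * y for x, y in zip(a, b))
    den = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return num / den


def top_k_excluding_self(query_id, vectors, k):
    """Ids of the k rows most cosine-similar to vectors[query_id], skipping the query."""
    scored = [(cosine(vectors[query_id], vec), idx)
              for idx, vec in enumerate(vectors) if idx != query_id]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [idx for _, idx in scored[:k]]


# Toy vocabulary and 2-d vectors (made up for illustration)
words = ['turtle', 'tortoise', 'coffee', 'tea']
vectors = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.9]]

neighbors = [words[i] for i in top_k_excluding_self(0, vectors, k=2)]
print(neighbors)
```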




### According to text embeddings

In [22]:
ids=[text_vocab.word_to_idx[n] for n in query_nodes]

distances, neighbors = text_index.search(text_embeddings[ids], num_neighbors+1)

for node_nbrs in neighbors:
    neighboring_nodes=[text_vocab.idx_to_word[n] for n in node_nbrs]

    print('Nearest neighbors to *%s*' % neighboring_nodes[0])
    print(neighboring_nodes[1:])
    print()
    print()

Nearest neighbors to */c/en/turtle*
['/c/en/turtles', '/c/en/large_sea_turtle/n', '/c/en/large_freshwater_turtle/n', '/c/en/shrimp_and_turtle', '/c/en/sea_turtle/n']


Nearest neighbors to */c/en/happy*
['at:to_tell_personx_they_are_happy_to_see_that', 'at:to_tell_personx_they_are_happy_to_see_them', 'at:tell_people_how_happy_they_are', 'at:to_let_person_x_know_how_happy_they_are', '/c/en/bring_happiness']


Nearest neighbors to */c/en/turtle/n/wn/animal*
['Q1705322', '/c/en/sea_turtle/n/wn/animal', '/c/en/freshwater_turtle/n', '/c/en/sea_turtle/n', '/c/en/ridley_sea_turtle']


Nearest neighbors to *at:personx_abandons_____altogether*
["at:personx_loses_personx's_ability", 'at:personx_chases_persony_away', 'at:personx_goes_too_far', 'at:personx_feels_hopeless', 'at:personx_loses_persony_opportunity']


Nearest neighbors to */c/en/caffeine*
['at:wakes_up_from_caffeine', 'at:becomes_more_awake_from_the_caffeine', '/c/en/amphetamines', 'at:to_be_energized_with_caffeine', '/c/en/caffeine_w

## III. Compute similarity between two nodes

In [67]:
node_pairs=[['/c/en/woman', '/c/en/man'], ['/c/en/pencil', 'Q614304'], ['/c/en/ash', 'rg:en_ash-gray'], ['/c/en/spiritual', '/c/en/religion'], ['/c/en/monkey', '/c/en/gorilla'], ['/c/en/monkey', '/c/en/tea']]

### According to graph embeddings

In [68]:
for nodes in node_pairs:
    ids=[graph_vocab.word_to_idx[n] for n in nodes]
    ge=graph_embeddings[ids]
    print(' '.join(nodes), cosine_similarity([ge[0]], [ge[1]])[0][0])

/c/en/woman /c/en/man 0.35791242
/c/en/pencil Q614304 -0.083074115
/c/en/ash rg:en_ash-gray 0.11169712
/c/en/spiritual /c/en/religion 0.2867519
/c/en/monkey /c/en/gorilla 0.32852727
/c/en/monkey /c/en/tea -0.004493271


### According to text embeddings

In [69]:
for nodes in node_pairs:
    ids=[text_vocab.word_to_idx[n] for n in nodes]
    te=text_embeddings[ids]
    print(' '.join(nodes), cosine_similarity([te[0]], [te[1]])[0][0])

/c/en/woman /c/en/man 0.66569746
/c/en/pencil Q614304 0.36885476
/c/en/ash rg:en_ash-gray 0.35797378
/c/en/spiritual /c/en/religion 0.5916587
/c/en/monkey /c/en/gorilla 0.6783559
/c/en/monkey /c/en/tea 0.6468366


## IV. Parsing questions and answers

In [None]:
from grounding.graphify import parse


In [None]:
sentences=[
    'Max looked for the onions so that he could  make a stew.',
    'To get the bathroom counters dry after washing your face, take a small hand lotion and wipe away the extra water around the sink.',
    'To get the bathroom counters dry after washing your face, take a small hand towel and wipe away the extra water around the sink.'
]

In [None]:
parse_trees=parse.graphify_dataset(sentences)

In [None]:
for sent_data in parse_trees:
    print('Sentence:', sent_data['sentence'])
    print('Tokenized sentence', sent_data['tokenized_sentence'])
    
    nodes={}
    for n_id, n_data in sent_data['nodes'].items():
        nodes[n_id]=n_data['phrase']
    
    for e_id, e_data in sent_data['edges'].items():
        print('NODE1:', ' '.join(nodes[e_data['head_node_id']]), 'RELATION', e_data['edge_name'], 'NODE2', ' '.join(nodes[e_data['tail_node_id']]) )
    print()
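
The loop above relies on a particular shape of each parse result. As a hedged illustration, the sketch below runs the same traversal over a hand-written dictionary that mirrors only the keys accessed in the loop (`sentence`, `tokenized_sentence`, `nodes` with `phrase`, and `edges` with `head_node_id`, `edge_name`, `tail_node_id`). The ids, phrases, and edge names are invented, phrases are assumed to be lists of tokens because the loop joins them with spaces, and the real output of `graphify_dataset` may contain additional fields.

```python
# Hand-written stand-in for one parse result (all values are invented)
sent_data = {
    'sentence': 'Max made a stew.',
    'tokenized_sentence': ['Max', 'made', 'a', 'stew', '.'],
    'nodes': {
        'n0': {'phrase': ['Max']},
        'n1': {'phrase': ['made']},
        'n2': {'phrase': ['a', 'stew']},
    },
    'edges': {
        'e0': {'head_node_id': 'n1', 'edge_name': 'subject', 'tail_node_id': 'n0'},
        'e1': {'head_node_id': 'n1', 'edge_name': 'object', 'tail_node_id': 'n2'},
    },
}


def edges_as_triples(sent_data):
    """Turn the edges of one parse result into (head phrase, relation, tail phrase) tuples."""
    phrases = {n_id: ' '.join(n['phrase']) for n_id, n in sent_data['nodes'].items()}
    return [(phrases[e['head_node_id']], e['edge_name'], phrases[e['tail_node_id']])
            for e in sent_data['edges'].values()]


triples = edges_as_triples(sent_data)
for head, relation, tail in triples:
    print('NODE1:', head, 'RELATION', relation, 'NODE2', tail)
```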

## V. Grounding questions and answers to ConceptNet

In [None]:
from grounding.graphify import link

In [None]:
linked_data=link.link(parse_trees, embedding_file='grounding/numberbatch-en-19.08.txt')

In [None]:
for sent_data in linked_data:
    print('Sentence:', sent_data['sentence'])
    for n_id, n_data in sent_data['nodes'].items():
        print('Node phrase:', n_data['phrase'])
        for c in reversed(n_data['candidates']):
            print(c)
        print()
        
    print()