# CSKG embeddings

This notebook computes similarity between nodes in CSKG and performs grounding of questions/answers to CSKG.
We work with two families of embeddings: graph embeddings and text embeddings.

## Graph embeddings 

The graph embeddings were computed with the KGTK `graph-embeddings` command as follows:<br>
`kgtk graph-embeddings -i < input_file.tsv > -o < output_file.tsv >`<br>
See https://kgtk.readthedocs.io/en/latest/analysis/graph_embeddings/ for more details.

## Text embeddings
The text embeddings were computed by using the KGTK `text-embedding` command as follows:<br>
`kgtk text-embedding < input_file.tsv > < output_file.tsv >`<br>
See https://kgtk.readthedocs.io/en/latest/analysis/text_embedding/ for more details.


# Setup for working with embeddings

## Parameters for invoking the notebook

- `cskg_path`: a folder containing the necessary files and all the analysis products.
- `graph_emb`: the name of the graph embedding output file, located in `cskg_path`.
- `text_emb`: the name of the text embedding output file, located in `cskg_path`.
- `distance`: the metric used to build the search index, either `cosine` or `l2`.


Tip: Generating graph and text embeddings takes a long time, so we have prepared the `graph_emb` and `text_emb` files in advance. You can download them from https://drive.google.com/drive/u/1/folders/16347KHSloJJZIbgC9V5gH7_pRx0CzjPQ.


In [1]:
# Parameters
cskg_path = "../output" 
graph_emb = "trans_log_dot_0.1.tsv.gz"
text_emb = "bert-nli-large-embeddings.tsv.gz"
distance='cosine'

In [2]:
# Imports
import os
from pathlib import Path
import gzip
import pickle as pkl
import faiss
import numpy as np
from typing import Callable, List, Tuple
import json
import hashlib
import logging
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
os.environ['CSKG'] = cskg_path
os.environ['GE'] = "{}/{}".format(cskg_path, graph_emb)
os.environ['TE'] = "{}/{}".format(cskg_path, text_emb)
graph_emb_path = os.environ['GE']
text_emb_path = os.environ['TE']

In [4]:
!echo $CSKG
!echo $GE
!echo $TE

../output
../output/trans_log_dot_0.1.tsv.gz
../output/bert-nli-large-embeddings.tsv.gz


# Utilities

In [5]:
class Vocab:
    def __init__(self, words) -> None:
        self.idx_to_word = words
        self.word_to_idx = {word: idx for idx, word in enumerate(words)}

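def read_embedding_file(embedding_file: Path, dim: int, emb_col=1) -> Tuple[Vocab, np.ndarray]:
    """Read a gzipped embedding file into a vocabulary and a float32 matrix.

    emb_col=1: each line is `node v1 v2 ...` (whitespace-separated); no line is skipped.
    Otherwise: the first line is a header, followed by `node property v1,v2,...` rows;
    rows whose property is `embedding_sentence` carry no vector and are skipped.
    """
    logger.debug(f'Reading embeddings from {embedding_file}')

    # Upper bound on the number of rows; the header and skipped rows are trimmed at the end.
    shape = (count_lines(embedding_file), dim)
    embeddings = np.zeros(shape, dtype=np.float32)
    words = []
    i = 0

    with gzip.open(embedding_file, 'rt', encoding='utf-8') as f:
        if emb_col != 1:
            next(f)  # skip the header line
        for line in tqdm(f, total=shape[0]):
            if emb_col == 1:
                node, *embedding = line.split()
            else:
                line_data = line.split()
                if line_data[1] == 'embedding_sentence':
                    continue
                node = line_data[0]
                embedding = line_data[2].split(',')
            words.append(node)
            embeddings[i] = np.array([float(x) for x in embedding])
            i += 1

    # Drop the unused all-zero rows left over from the header and skipped lines.
    return Vocab(words), embeddings[:i]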

def count_lines(embedding_file: Path):
    with gzip.open(embedding_file, 'r') as f:
        i=0
        for line in f:
            i+=1
    return i

def build_index(metric: str, embeddings: np.ndarray):

    logger.debug('Building search index')

    if metric == 'cosine':
        # IndexFlatIP ranks by inner product. This equals cosine similarity only
        # if the rows of `embeddings` have unit length; no normalization is done here.
        index = faiss.IndexFlatIP(embeddings.shape[-1])
    elif metric == 'l2':
        index = faiss.IndexFlatL2(embeddings.shape[-1])
    else:
        raise ValueError(f'Bad metric: {metric}')

    index.add(embeddings)

    return index

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
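
To make the two file layouts handled by `read_embedding_file` concrete, here is a minimal sketch that parses one line of each kind. The helper `parse_embedding_line`, the sample lines, and the `text_embedding` property name are hypothetical and only illustrate the layouts implied by the code above; real rows carry 100 or 1024 values.

```python
from typing import List, Optional, Tuple

def parse_embedding_line(line: str, emb_col: int = 1) -> Optional[Tuple[str, List[float]]]:
    """Parse one embedding row; return None for rows that carry no vector."""
    fields = line.split()
    if emb_col == 1:
        # Layout 1: node followed by whitespace-separated values.
        node, *values = fields
    else:
        # Layout 2: node, property, comma-separated values.
        if fields[1] == 'embedding_sentence':
            return None
        node, values = fields[0], fields[2].split(',')
    return node, [float(x) for x in values]

# Hypothetical sample rows (three dimensions for brevity).
graph_line = '/c/en/turtle\t0.1\t-0.2\t0.3\n'
text_line = '/c/en/turtle\ttext_embedding\t0.5,0.25,-1.0\n'
sentence_line = '/c/en/turtle\tembedding_sentence\tturtle is an animal\n'

print(parse_embedding_line(graph_line))
print(parse_embedding_line(text_line, emb_col=2))
print(parse_embedding_line(sentence_line, emb_col=2))
```

The third call returns `None`, which mirrors the `continue` in `read_embedding_file`: sentence rows are counted by `count_lines` but never stored.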

## I. Load embeddings

### Load graph embeddings

In [6]:
graph_dim = 100 # Dimension of the graph embeddings for our example's file
graph_vocab, graph_embeddings = read_embedding_file(graph_emb_path,graph_dim)

100%|██████████| 2160968/2160968 [01:17<00:00, 27865.38it/s]


In [7]:
graph_index = build_index(distance, graph_embeddings)
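
`build_index` maps `cosine` to `faiss.IndexFlatIP`, which ranks by inner product. Inner product coincides with cosine similarity only when the vectors have unit length. The sketch below uses plain Python and made-up 2-d vectors to show the difference. Whether the precomputed embedding files are already normalized is not stated in this notebook, so treat this as a check worth running rather than a description of the data.

```python
import math
from typing import List

def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(v: List[float]) -> List[float]:
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

def cosine(u: List[float], v: List[float]) -> float:
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

# Made-up vectors: `long_vector` points in a different direction but has a large norm.
query = [1.0, 1.0]
long_vector = [10.0, 0.0]

# Raw inner product scores the long vector higher than the query scores itself.
print(dot(query, query), dot(query, long_vector))

# After L2 normalization, inner product and cosine similarity agree.
print(dot(l2_normalize(query), l2_normalize(long_vector)))
print(cosine(query, long_vector))
```

If the rows turn out not to be unit-length, one option is to normalize them before `index.add`, for example with `faiss.normalize_L2(embeddings)`, which modifies the array in place.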

### Load text embeddings

In [8]:
text_dim=1024 # Dimension of the text embeddings for our example's file
text_vocab, text_embeddings = read_embedding_file(text_emb_path, text_dim, emb_col=2)

100%|█████████▉| 2161048/2161049 [10:44<00:00, 3355.01it/s]


In [9]:
text_index = build_index(distance, text_embeddings)

## II. Most similar nodes in CSKG

In [10]:
query_nodes=['/c/en/turtle', '/c/en/happy', '/c/en/turtle/n/wn/animal', 
             'at:personx_abandons_____altogether', '/c/en/caffeine']

In [11]:
num_neighbors=5

### According to graph embeddings

In [12]:
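ids=[graph_vocab.word_to_idx[n] for n in query_nodes]
distances, neighbors = graph_index.search(graph_embeddings[ids], num_neighbors+1)
for query, node_nbrs in zip(query_nodes, neighbors):
    # One extra hit is requested because the query node is in the index too.
    # Drop it by name rather than assuming it is always ranked first.
    neighboring_nodes=[graph_vocab.idx_to_word[n] for n in node_nbrs]
    neighboring_nodes=[n for n in neighboring_nodes if n != query][:num_neighbors]

    print('Nearest neighbors to *%s*' % query)
    print(neighboring_nodes)
    print()
    print()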

Nearest neighbors to */c/en/turtle*
['/c/en/tortoise', '/c/en/turtling/v/wikt/en_1', '/c/en/turtles/v', '/c/en/turtled/v', '/c/en/turtles/n']


Nearest neighbors to */c/en/happy*
['/c/en/pleased', '/c/en/excited', '/c/en/content', '/c/en/satisfied', '/c/en/joyful']


Nearest neighbors to */c/en/turtle/n/wn/animal*
['/c/en/chelonian/n/wn/animal', '/c/en/sea_turtle/n/wn/animal', '/c/en/pseudemys/n/wn/animal', '/c/en/mud_turtle/n/wn/animal', '/c/en/cooter/n/wn/animal']


Nearest neighbors to *at:personx_abandons_____altogether*
['at:turns_over_a_new_leaf', 'at:to_get_permission_from_his_parents', 'at:plows_the_field', 'at:was_just_city', 'at:to_search_for_a_new_job']


Nearest neighbors to */c/en/caffeine*
['/c/en/caffiene/n', '/c/en/caffeines/n', '/c/en/caffein/n', '/c/en/noncaffeine', '/c/en/caffeinelike']
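
The search above asks for `num_neighbors+1` hits because each query node is itself in the index and typically comes back as its own best match. Here is a brute-force sketch of the same idea, using a hypothetical four-node toy vocabulary with made-up 2-d vectors and cosine scoring in place of the FAISS index:

```python
import math
from typing import List

# Hypothetical toy vocabulary; the vectors are made up for illustration.
toy_vectors = {
    'turtle': [1.0, 0.1],
    'tortoise': [0.9, 0.2],
    'dolphin': [0.5, 0.8],
    'coffee': [-0.7, 0.6],
}

def cosine(u: List[float], v: List[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nearest(query: str, k: int) -> List[str]:
    """Return the k nodes most similar to `query`, excluding the query itself."""
    ranked = sorted(toy_vectors,
                    key=lambda name: cosine(toy_vectors[query], toy_vectors[name]),
                    reverse=True)
    hits = ranked[:k + 1]  # one extra, since the query matches itself
    return [name for name in hits if name != query][:k]

print(nearest('turtle', 2))  # ['tortoise', 'dolphin']
```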




### According to text embeddings

In [13]:
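ids=[text_vocab.word_to_idx[n] for n in query_nodes]
distances, neighbors = text_index.search(text_embeddings[ids], num_neighbors+1)

for query, node_nbrs in zip(query_nodes, neighbors):
    # One extra hit is requested because the query node is in the index too.
    # Drop it by name rather than assuming it is always ranked first.
    neighboring_nodes=[text_vocab.idx_to_word[n] for n in node_nbrs]
    neighboring_nodes=[n for n in neighboring_nodes if n != query][:num_neighbors]

    print('Nearest neighbors to *%s*' % query)
    print(neighboring_nodes)
    print()
    print()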

Nearest neighbors to */c/en/turtle*
['/c/en/shrimp_and_turtle', '/c/en/turtles', '/c/en/dolphin', '/c/en/loon/n/wn/animal', '/c/en/ducks']


Nearest neighbors to */c/en/happy*
['/c/en/bring_happiness', 'at:happy_that_they_went_to_the_party', 'at:like_a_party_is_a_good_way_to_express_their_jubilation', 'at:if_for_a_party,_happy', "/c/en/encouraging_person's_talent"]


Nearest neighbors to */c/en/turtle/n/wn/animal*
['/c/en/glyptemys/n', '/c/en/chelidae/n', '/c/en/pelocomastes/n', '/c/en/parahydraspis/n', '/c/en/sternotherus/n']


Nearest neighbors to *at:personx_abandons_____altogether*
['at:personx_is_sent_home', "at:personx_loses_personx's_position", 'at:personx_loses_persony_opportunity', 'at:personx_loses_a_bet', 'at:personx_is_promptly_fired']


Nearest neighbors to */c/en/caffeine*
['/c/en/people_drink_coffee_because', '/c/en/cup_of_coffee', '/c/en/brewing_coffee/n', 'at:put_coffee_grounds_in_caffee_maker', '/c/en/hot_chocolate']




## III. Calculate similarity between two nodes

In [14]:
node_pairs=[['/c/en/woman', '/c/en/man'], 
            ['/c/en/pencil', 'Q614304'], 
            ['/c/en/ash', 'rg:en_ash-gray'], 
            ['/c/en/spiritual', '/c/en/religion'], 
            ['/c/en/monkey', '/c/en/gorilla'], 
            ['/c/en/monkey', '/c/en/tea']]

### According to graph embeddings

In [15]:
for nodes in node_pairs:
    ids=[graph_vocab.word_to_idx[n] for n in nodes]
    ge=graph_embeddings[ids]
    print(' '.join(nodes), cosine_similarity([ge[0]], [ge[1]])[0][0])

/c/en/woman /c/en/man 0.62742233
/c/en/pencil Q614304 0.028264966
/c/en/ash rg:en_ash-gray 0.30954424
/c/en/spiritual /c/en/religion 0.4183224
/c/en/monkey /c/en/gorilla 0.46226677
/c/en/monkey /c/en/tea 0.24436088
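
For reference, the scores printed above are the standard cosine of the angle between the two embedding vectors. For vectors $u$ and $v$ of dimension $d$:

$$
\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert_2 \, \lVert v \rVert_2}
           = \frac{\sum_{i=1}^{d} u_i v_i}{\sqrt{\sum_{i=1}^{d} u_i^2}\,\sqrt{\sum_{i=1}^{d} v_i^2}}
$$

The value lies in $[-1, 1]$ and does not depend on the length of either vector, only on their directions.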


### According to text embeddings

In [16]:
for nodes in node_pairs:
    ids=[text_vocab.word_to_idx[n] for n in nodes]
    te=text_embeddings[ids]
    print(' '.join(nodes), cosine_similarity([te[0]], [te[1]])[0][0])

/c/en/woman /c/en/man 0.32196832
/c/en/pencil Q614304 0.43661886
/c/en/ash rg:en_ash-gray 0.5951916
/c/en/spiritual /c/en/religion 0.54427576
/c/en/monkey /c/en/gorilla 0.75078595
/c/en/monkey /c/en/tea 0.48101428
