# 5. Entity Linking

In [5]:
%run __init__.py

In [6]:
from bokeh.io import output_notebook

output_notebook()

## Defining the entity linking class

In [7]:
import json
import requests


WIKIDATA_BASE = "https://www.wikidata.org/w"

class WikidataEntityLinker():
    def __init__(self, user, passwd):
        pass
    
    def fit(self, X, y, *args):
        return self
    
    def transform(self, X, y, *args):
        return [self.link_entity(entity) 
                for doc in X
                for entity in doc]
    
    def link_entity(self, entity_label):
        url = f"{WIKIDATA_BASE}/api.php?action=wbsearchentities&search=" + \
            f"{entity_label}&language=en&format=json"
        response = requests.get(url)
        if response.status_code != 200:
            raise Error()
        content = json.loads(response.text)
        search_results = content['search']
        if len(search_results) == 0:
            return (entity_label, None)
        return (entity_label, search_results[0]['concepturi'])


In [8]:
entity_linker = WikidataEntityLinker("", "")
res = entity_linker.link_entity('agroforestry')
res

('agroforestry', 'http://www.wikidata.org/entity/Q397350')

## Linking each topic's term to Wikidata

In [9]:
import dill as pickle

# see https://stackoverflow.com/questions/42960637/python-3-5-dill-pickling-unpickling-on-different-servers-keyerror-classtype
pickle._dill._reverse_typemap['ClassType'] = type

def load_object(output_path):
    with open(output_path, 'rb') as file:
        res = pickle.load(file)
    return res

In [10]:
NOTEBOOK_RESULTS_DIR = os.path.join(RESULTS_DIR, '3_topic_modeling')
lda_agriculture_pipe_filename = "agriculture_lda_model.pkl"
dtm_tf_filename = "agriculture_dtm_tf.pkl"

lda_pipe = load_object(os.path.join(NOTEBOOK_RESULTS_DIR, lda_agriculture_pipe_filename))
dtm_tf = load_object(os.path.join(NOTEBOOK_RESULTS_DIR, dtm_tf_filename))

In [11]:
from src.utils import get_topic_terms_by_relevance

def link_topic_terms(entity_linker, model, vectorizer,
                     dtm_tf, n_top_words, lambda_=0.6):
    res = []
    topic_terms = get_topic_terms_by_relevance(model, vectorizer, dtm_tf,
                                               n_top_words, lambda_)
    return [[entity_linker.link_entity(entity) for entity in topic]
            for topic in topic_terms]


In [12]:
linked_terms = link_topic_terms(entity_linker, lda_pipe.named_steps['model'],
                                lda_pipe.named_steps['vectorizer'], dtm_tf, 
                                n_top_words=10, lambda_=0.75)
linked_terms[2]

In [46]:
linked_terms[1]

[('plant', 'http://www.wikidata.org/entity/Q756'),
 ('expression', 'http://www.wikidata.org/entity/Q11024'),
 ('gene', 'http://www.wikidata.org/entity/Q7187'),
 ('adaptation', 'http://www.wikidata.org/entity/Q3331189'),
 ('root', 'http://www.wikidata.org/entity/Q111029'),
 ('shoot', 'http://www.wikidata.org/entity/Q220869'),
 ('diversification', 'http://www.wikidata.org/entity/Q731453'),
 ('site', 'http://www.wikidata.org/entity/Q35127'),
 ('Arabidopsis', 'http://www.wikidata.org/entity/Q157892'),
 ('heat', 'http://www.wikidata.org/entity/Q44432')]

## Topic labelling

In [47]:
import functools
import pdb

from dataclasses import dataclass

import networkx as nx


WIKIDATA_PROPS_EXPAND = ['P31', 'P279', 'P301', 'P910', 'P2579']


def empty_if_keyerror(function):
    """
    A decorator that wraps the passed in function and
    returns an empty string if a key error is raised.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except KeyError:
            return ""
    return wrapper

def _build_uri(entity_id):
    return f"http://www.wikidata.org/entity/{entity_id}"

@empty_if_keyerror
def _get_aliases(entity_info, lang='en'):
    return [alias['value'] 
            for alias in entity_info['aliases'][lang]]

@empty_if_keyerror
def _get_desc(entity_info, lang='en'):
    return entity_info['descriptions'][lang]['value']

@empty_if_keyerror
def _get_labels(entity_info, lang='en'):
    return entity_info['labels'][lang]['value']


@dataclass
class WikidataNode():
    label: str
    uri: str
    desc: str
    alias: str
        
    def __hash__(self):
        return hash(self.uri)
    
    def to_dict(self):
        return {
            'alias':self.alias,
            'desc': self.desc,
            'label': self.label,
            'uri': self.uri
        }


class WikidataGraphBuilder():
    def __init__(self, max_hops=2, additional_props=None):
        self.max_hops = max_hops
        self.props_to_expand = WIKIDATA_PROPS_EXPAND
        if additional_props:
            self.props_to_expand += additional_props
    
    def build_graph(self, topic):
        G = nx.Graph()
        for term in topic:
            term_uri = term[1]
            term_id = term_uri.split('/')[-1]
            self._add_wd_node_info(G, term_id, None, 0)
        return G
    
    def _add_wd_node_info(self, graph, term_id, prev_node, curr_hop):
        print(f"Visiting entity '{term_id}' - Curr hop: {curr_hop}")
        if curr_hop > self.max_hops:
            return
        
        # call wikidata API for uri
        endpoint = f"{WIKIDATA_BASE}/api.php?action=wbgetentities&ids={term_id}&languages=en&format=json"
        res = requests.get(endpoint)
        if res.status_code != 200:
            raise Error()
        
        content = json.loads(res.text)
        entity_info = content['entities'][term_id]
        
        if term_id not in graph.nodes:
            graph.add_node(term_id)
            #graph.nodes[term_id]['alias'] = _get_aliases(entity_info)
            graph.nodes[term_id]['desc'] = _get_desc(entity_info)
            graph.nodes[term_id]['label'] = _get_labels(entity_info)
            graph.nodes[term_id]['n'] = curr_hop

        if prev_node is not None and not graph.has_edge(prev_node, term_id):
            graph.add_edge(prev_node, term_id)
        
        for claim_key, claim_values in entity_info['claims'].items():
            if claim_key not in self.props_to_expand:
                continue
            
            for value in claim_values:
                snaktype = value['mainsnak']['snaktype']
                if snaktype in ['novalue', 'somevalue']:
                    continue
                
                new_node_id = value['mainsnak']['datavalue']['value']['id']
                self._add_wd_node_info(graph, new_node_id, term_id, curr_hop + 1)


In [None]:
graph_builder = WikidataGraphBuilder(max_hops=2)
topic_graphs = [graph_builder.build_graph(topic) for topic in linked_terms]

Visiting entity 'Q2095' - Curr hop: 0
Visiting entity 'Q2424752' - Curr hop: 1
Visiting entity 'Q28877' - Curr hop: 2
Visiting entity 'Q29028649' - Curr hop: 3
Visiting entity 'Q64513524' - Curr hop: 3
Visiting entity 'Q5672864' - Curr hop: 3
Visiting entity 'Q8134' - Curr hop: 3
Visiting entity 'Q337060' - Curr hop: 3
Visiting entity 'Q8205328' - Curr hop: 2
Visiting entity 'Q223557' - Curr hop: 3
Visiting entity 'Q16686448' - Curr hop: 3
Visiting entity 'Q26991679' - Curr hop: 3
Visiting entity 'Q15401930' - Curr hop: 2
Visiting entity 'Q488383' - Curr hop: 3
Visiting entity 'Q7189878' - Curr hop: 2
Visiting entity 'Q2424752' - Curr hop: 3
Visiting entity 'Q2897903' - Curr hop: 3
Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q1194058' - Curr hop: 1
Visiting entity 'Q2424752' - Curr hop: 2
Visiting entity 'Q28877' - Curr hop: 3
Visiting entity 'Q8205328' - Curr hop: 3
Visiting entity 'Q15401930' - Curr hop: 3
Visiting entity 'Q7189878' - Curr hop: 3
Visiting entity 'Q16813

Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q1496967' - Curr hop: 3
Visiting entity 'Q101998' - Curr hop: 1
Visiting entity 'Q107425' - Curr hop: 2
Visiting entity 'Q1496967' - Curr hop: 3
Visiting entity 'Q7143080' - Curr hop: 3
Visiting entity 'Q30336093' - Curr hop: 2
Visiting entity 'Q58778' - Curr hop: 3
Visiting entity 'Q815297' - Curr hop: 3
Visiting entity 'Q7145637' - Curr hop: 3
Visiting entity 'Q96116695' - Curr hop: 3
Visiting entity 'Q7031839' - Curr hop: 2
Visiting entity 'Q101998' - Curr hop: 3
Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q2083910' - Curr hop: 1
Visiting entity 'Q7725551' - Curr hop: 2
Visiting entity 'Q2083910' - Curr hop: 3
Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q101998' - Curr hop: 2
Visiting entity 'Q107425' - Curr hop: 3
Visiting entity 'Q30336093' - Curr hop: 3
Visiting entity 'Q7031839' - Curr hop: 3
Visiting entity 'Q175208' - Curr hop: 2
Visiting entity 'Q8908997' - Curr hop: 3
Visiting entity 'Q4257

Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q25403900' - Curr hop: 3
Visiting entity 'Q50413986' - Curr hop: 2
Visiting entity 'Q50377228' - Curr hop: 3
Visiting entity 'Q7134776' - Curr hop: 1
Visiting entity 'Q11004' - Curr hop: 2
Visiting entity 'Q756' - Curr hop: 3
Visiting entity 'Q2095' - Curr hop: 3
Visiting entity 'Q25403900' - Curr hop: 3
Visiting entity 'Q7134776' - Curr hop: 3
Visiting entity 'Q173113' - Curr hop: 3
Visiting entity 'Q2375831' - Curr hop: 3
Visiting entity 'Q4167836' - Curr hop: 2
Visiting entity 'Q12139612' - Curr hop: 3
Visiting entity 'Q17442446' - Curr hop: 3
Visiting entity 'Q35252665' - Curr hop: 3
Visiting entity 'Q2944534' - Curr hop: 3
Visiting entity 'Q173113' - Curr hop: 1
Visiting entity 'Q8227263' - Curr hop: 2
Visiting entity 'Q173113' - Curr hop: 3
Visiting entity 'Q59541917' - Curr hop: 3
Visiting entity 'Q11862829' - Curr hop: 2
Visiting entity 'Q1047113' - Curr hop: 3
Visiting entity 'Q6642719' - Curr hop: 3
Visiting entity 'Q2

Visiting entity 'Q1458590' - Curr hop: 3
Visiting entity 'Q1914636' - Curr hop: 3
Visiting entity 'Q3919817' - Curr hop: 3
Visiting entity 'Q2996394' - Curr hop: 2
Visiting entity 'Q3249551' - Curr hop: 3
Visiting entity 'Q3249551' - Curr hop: 3
Visiting entity 'Q13878858' - Curr hop: 3
Visiting entity 'Q64732777' - Curr hop: 3
Visiting entity 'Q6226217' - Curr hop: 3
Visiting entity 'Q8971613' - Curr hop: 2
Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q921513' - Curr hop: 3
Visiting entity 'Q6031064' - Curr hop: 1
Visiting entity 'Q23009552' - Curr hop: 2
Visiting entity 'Q23009675' - Curr hop: 3
Visiting entity 'Q52948' - Curr hop: 2
Visiting entity 'Q8550126' - Curr hop: 3
Visiting entity 'Q4026292' - Curr hop: 3
Visiting entity 'Q97008347' - Curr hop: 3
Visiting entity 'Q1458498' - Curr hop: 1
Visiting entity 'Q11024' - Curr hop: 2
Visiting entity 'Q921513' - Curr hop: 3
Visiting entity 'Q6031064' - Curr hop: 3
Visiting entity 'Q1458498' - Curr hop: 3
Visiting entity '

Visiting entity 'Q1186952' - Curr hop: 3
Visiting entity 'Q4167836' - Curr hop: 3
Visiting entity 'Q1714118' - Curr hop: 1
Visiting entity 'Q21572908' - Curr hop: 2


In [None]:
import networkx as nx

from bokeh.io import output_file, show
from bokeh.layouts import gridplot
from bokeh.models import (BoxZoomTool, Circle, HoverTool,
                          MultiLine, Plot, Range1d, ResetTool,)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx

def build_graph_plot(G, title=""):
    plot = Plot(plot_width=400, plot_height=400,
                x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1))
    plot.title.text = title

    SAME_CLUB_COLOR, DIFFERENT_CLUB_COLOR = "black", "red"
    node_attrs = {}

    for node in G.nodes(data=True):
        node_color = Spectral4[node[1]['n']]
        node_attrs[node[0]] = node_color

    nx.set_node_attributes(G, node_attrs, "node_color")

    node_hover_tool = HoverTool(tooltips=[("Label", "@label"), ("n", "@n")])
    plot.add_tools(node_hover_tool, BoxZoomTool(), ResetTool())

    graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0, 0))

    graph_renderer.node_renderer.glyph = Circle(size=15, fill_color="node_color")
    graph_renderer.edge_renderer.glyph = MultiLine(line_alpha=0.8, line_width=1)
    plot.renderers.append(graph_renderer)
    return plot


plots = [build_graph_plot(g, f"Topic {idx}") 
         for idx, g in enumerate(topic_graphs)]
grid = gridplot(plots, ncols=2)
show(grid)

In [None]:
import networkx.algorithms as nxa

def get_largest_connected_subgraph(g):
    S = [g.subgraph(c).copy() for c in nxa.components.connected_components(g)]
    return max(S, key=len)


In [None]:
connected_topic_subgraphs = [get_largest_connected_subgraph(g) 
                             for g in topic_graphs]

In [None]:
plots = [build_graph_plot(g, f"Largest Connected subgraph for topic {idx}") 
         for idx, g in enumerate(connected_topic_subgraphs)]
grid = gridplot(plots, ncols=2)
show(grid)

In [None]:
def show_algorithm_results(topic_subgraphs, algorithm):
    pass


In [None]:
import networkx.algorithms as nxa

res = nxa.centrality.information_centrality(connected_topic_subgraphs[3])
res

In [None]:
import operator

sorted(res, key=res.get, reverse=True)[:3]

In [78]:
connected_topic_subgraphs[3].nodes['Q4167836']

{'desc': "use with 'instance of' (P31) for Wikimedia category",
 'label': 'Wikimedia category',
 'n': 3,
 'node_color': '#d7191c'}

The main funcionality from above will be implemented in a custom class that conforms to the sklearn API:

In [None]:

class TopicLabeller():
    def __init__(self, graph_builder,
                 r=nxa.centrality.information_centrality,
                 num_labels_per_topic=1):
        self.graph_builder = graph_builder
        self.r = r
        self.num_labels = num_labels_per_topic
    
    def transform(self, X, y, **kwargs):
        return [self.get_topics_labels(topic) for topic in X]
    
    def get_topic_labels(self, topic_graph):
        topic_neighbourhood = self.graph_builder.build_graph(topic)
        subgraph = get_largest_connected_subgraph(topic_neighbourhood)
        metrics = r(subgraph) 
        # TODO remove seed concepts from metrics
        best_qids = max(metrics.items(), key=operator.itemgetter(1))
        return [subgraph[qid], metrics[qid] for qid in best_qids



## Add labels to LDA model