In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import bz2
import csv
import io
import re
import time
import json
import random
import requests
from tqdm import tqdm
import multiprocessing
import concurrent.futures
import pickle as pkl
import numpy as np
import networkx as nx
from functools import partial, reduce
from collections import Counter
from pathlib import Path
from pprint import pprint
from typing import List, Dict
import matplotlib.pyplot as plt
import lsde2021.csv as csvutil
import lsde2021.utils as utils
from lsde2021.lang import singularize, pluralize
import lsde2021.download as dl
from pyspark.sql import SparkSession, DataFrame
import pyspark.sql.types as T
import pyspark.sql.functions as F

In [2]:
MAX_MEMORY = "10G"

spark = SparkSession \
    .builder \
    .appName("parse-wikipedia-sql-dumps") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config('spark.driver.maxResultSize', MAX_MEMORY) \
    .config('spark.ui.showConsoleProgress', 'false') \
    .getOrCreate()
sc = spark.sparkContext

csv_loader = spark.read.format("csv").options(header='True', inferSchema='True')
parquet_reader = spark.read.format("parquet").options(inferSchema='True')

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/11/05 17:27:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# join categories with english wiki page table
wiki = "enwiki"
raw_pages = parquet_reader.load(str(f"../nvme/wikipedia_sql_dumps/{wiki}/20211001/{wiki}-20211001-page-category-count.sql.parquet"))

In [4]:
# first have a closer look at some of the categories and how they look like so we can split them eventually
example_categories = raw_pages.select("category_name").limit(1_000).rdd.flatMap(lambda x: x).collect()

In [5]:
pprint(example_categories[0:100])

['Surgical_specialties',
 'Articles_with_GND_identifiers',
 'Articles_with_MA_identifiers',
 'Articles_with_LCCN_identifiers',
 'All_articles_needing_additional_references',
 'Articles_needing_additional_references_from_October_2013',
 'General_surgery',
 'All_articles_with_unsourced_statements',
 'Wikipedia_articles_needing_page_number_citations_from_May_2015',
 'Articles_with_unsourced_statements_from_October_2020',
 'CS1_German-language_sources_(de)',
 'Noam_Chomsky',
 'Formal_languages',
 'Wikipedia_articles_needing_page_number_citations_from_November_2014',
 'Articles_with_incomplete_citations_from_September_2013',
 'Lists_of_aquarium_life',
 'Articles_with_short_description',
 'Short_description_is_different_from_Wikidata',
 'All_articles_with_incomplete_citations',
 'Lists_of_fishes',
 'Commons_category_link_is_on_Wikidata',
 'All_articles_lacking_in-text_citations',
 'Articles_with_short_description',
 'Lists_of_musical_instruments',
 'Indian_inventions',
 'Short_description_is

In [6]:
graph = nx.read_gpickle(f"../nvme/en-category-tree.pkl")

In [7]:
# filter out all hidden categories by removing nodes that have an edge to the hidden category
hidden_category = raw_pages \
    .filter((F.col("category_name") == "Hidden_categories") & (F.col("page_namespace") == 14))
hidden_category.limit(100).show()

hidden_category = hidden_category \
    .groupBy("category_page_id") \
    .count()
hidden_category.show()

+--------+-----------------+--------------+--------------------+----------------+-----+
| page_id|    category_name|page_namespace|          page_title|category_page_id|count|
+--------+-----------------+--------------+--------------------+----------------+-----+
|21565229|Hidden_categories|            14|Articles_needing_...|        15961454|    7|
|21565352|Hidden_categories|            14|Articles_needing_...|        15961454|    7|
|21616012|Hidden_categories|            14|Articles_with_wea...|        15961454|    7|
|22853578|Hidden_categories|            14|Articles_with_dea...|        15961454|    7|
|22934874|Hidden_categories|            14|Articles_to_be_ex...|        15961454|    7|
|27162873|Hidden_categories|            14|Articles_slanted_...|        15961454|    7|
|27900429|Hidden_categories|            14|Wikipedia_introdu...|        15961454|    7|
|28658797|Hidden_categories|            14|Wikipedia_article...|        15961454|    7|
|29609568|Hidden_categories|    

In [8]:
hidden_category_node = 15961454

hidden_sub_categories = list(nx.bfs_tree(graph, reverse=True, source=hidden_category_node, depth_limit=1))

# 30259 for depth_limit=1
# 6_838_612 for depth_limit=2
print(len(hidden_sub_categories))
pprint([graph.nodes[n] for n in hidden_sub_categories[-25:]])

30259
[{'is_category': True,
  'node_count': 1,
  'title': 'Pages_with_a_Wikidata_coding_problem'},
 {'is_category': True,
  'node_count': 0,
  'title': 'Interlanguage_link_template_forcing_interwiki_links'},
 {'is_category': True,
  'node_count': 0,
  'title': 'Articles_needing_audio_and_or_video'},
 {'is_category': True,
  'node_count': 1,
  'title': 'PD-logo_candidates_for_review'},
 {'is_category': True,
  'node_count': 0,
  'title': 'World_Geographical_Scheme_for_Recording_Plant_Distributions'},
 {'is_category': True, 'node_count': 1, 'title': 'Upcoming_MMA_fights'},
 {'is_category': True,
  'node_count': 0,
  'title': 'Chembox_and_Drugbox_articles_with_a_broken_CheMoBot_template'},
 {'is_category': True,
  'node_count': 0,
  'title': 'Pages_using_certification_Table_Entry_with_streaming-only_figures'},
 {'is_category': True,
  'node_count': 1,
  'title': 'Files_with_no_machine-readable_patent'},
 {'is_category': True, 'node_count': 1, 'title': 'Test-false'},
 {'is_category': True

In [9]:
# remove hidden topics and their edges from the graph
print("edges before: %d | nodes before: %d" % (len(graph.edges), len(graph.nodes)))
graph.remove_nodes_from(hidden_sub_categories)
print("edges after: %d | nodes after: %d" % (len(graph.edges), len(graph.nodes)))

# edges before: 74832061 | nodes before: 7987708
# edges after: 38634343 | nodes after: 7957449

edges before: 74832061 | nodes before: 7987708
edges after: 38634343 | nodes after: 7957449


In [16]:
print("edges reduced", 38634343/74832061)
print("nodes reduced", 7957449/7987708)

edges reduced 0.5162806220184153
nodes reduced 0.9962118044375182


In [10]:
# save the graph for reuse
# nx.write_gpickle(graph, f"../nvme/en-category-tree-without-hidden.pkl")

In [11]:
# Example case: find the COVID19 wikipedia article
covid_article = raw_pages.filter(F.col("page_title") == "COVID-19").limit(100)
covid_article.show()

+--------+--------------------+--------------+----------+----------------+-----+
| page_id|       category_name|page_namespace|page_title|category_page_id|count|
+--------+--------------------+--------------+----------+----------------+-----+
|63508804|Wikipedia_categor...|            14|  COVID-19|        65947847|    8|
|63508804|          SARS-CoV-2|            14|  COVID-19|        66148304|    8|
|63508804|Viral_respiratory...|            14|  COVID-19|        39756458|    8|
|63508804|Coronavirus-assoc...|            14|  COVID-19|        67944185|    8|
|63508804|Interlanguage_lin...|            14|  COVID-19|        44647562|    8|
|63508804|Commons_category_...|            14|  COVID-19|        59055145|    8|
|63508804|   Airborne_diseases|            14|  COVID-19|        64024573|    8|
|63508804| Atypical_pneumonias|            14|  COVID-19|        62902866|    8|
|63030231|CS1_maint:_DOI_in...|             0|  COVID-19|        67549059|   38|
|63030231|       Public_heal

In [12]:
# find the content category
root_category = raw_pages.filter((F.col("category_name") == "Content") & (F.col("page_namespace") == 14)).limit(100)
root_category.show()

+-------+-------------+--------------+----------+----------------+-----+
|page_id|category_name|page_namespace|page_title|category_page_id|count|
+-------+-------------+--------------+----------+----------------+-----+
+-------+-------------+--------------+----------+----------------+-----+



In [13]:
# find sinks in the graph (there should only be one)
sinks = [node for node in graph.nodes if graph.out_degree(node) == 0 and graph.in_degree(node) > 0]
print(len(sinks))

4625


In [14]:
pprint([graph.nodes[n]["title"] for n in sinks[:20]])

['UK_MPs_1955–1959',
 'Recipients_of_the_Order_of_the_White_Lion',
 'UK_MPs_1935–1945',
 'UK_MPs_1906–1910',
 'UK_MPs_1910–1918',
 'UK_MPs_1959–1964',
 'UK_MPs_1951–1955',
 'UK_MPs_1924–1929',
 'UK_MPs_1945–1950',
 'UK_MPs_1929–1931',
 'UK_MPs_1918–1922',
 'UK_MPs_1910',
 'UK_MPs_1931–1935',
 'UK_MPs_1950–1951',
 'People_from_Hollywood',
 'Gore_Vidal',
 'Prince_(musician)',
 'UK_MPs_1983–1987',
 'UK_MPs_1974–1979',
 'Chemists_as_head_of_government']


In [21]:
# get the average node degree of the graph
leafs = [node for node in graph.nodes if graph.in_degree(node) < 1]
inners = [node for node in graph.nodes if graph.in_degree(node) > 0]

def avg_out_degree(nodes):
    sum_of_edges = sum([graph.out_degree(node) for node in nodes])
    return sum_of_edges / len(nodes)

print("average node degree of leafs", avg_out_degree(leafs))
print("average node degree of inners", avg_out_degree(inners))

average node degree of leafs 5.376505102465392
average node degree of inners 2.9483960075027578


In [24]:
numeric = re.compile(r'^([\s\d]+)$')

patterns = [
    (re.compile(r"^\d+th-century_(\w+)_in_the_(\w+)$"), []),
    (re.compile(r"^\d+th-century_(\w+)_in_(\w+)$"), []),
    
    (re.compile(r"^\d+s_in_the_(\w+)$"), []),
    (re.compile(r"^\d+s_in_(\w+)$"), []),
    (re.compile(r"^\d+_in_the_(\w+)$"), []),
    (re.compile(r"^\d+_in_(\w+)$"), []),
    
    (re.compile(r"^(\w+)_based_in_(\w+)_by_subject$"), []),
    
    (re.compile(r"^(\w+)_established_in_the_(\w+)$"), []),
    (re.compile(r"^(\w+)_established_in_(\w+)$"), []),
    
    (re.compile(r"^(\w+)_in_the_(\w+)$"), []),
    (re.compile(r"^(\w+)_in_(\w+)$"), []),
    
    (re.compile(r"^(\w+)_and_the_(\w+)$"), []),
    (re.compile(r"^(\w+)_and_(\w+)$"), []),
    
    (re.compile(r"^(\w+)_of_the_(\w+)_by_country$"), []),
    (re.compile(r"^(\w+)_of_(\w+)_by_country$"), []),
    (re.compile(r"^(\w+)_of_the_(\w+)$"), []),
    (re.compile(r"^(\w+)_of_(\w+)$"), []),
    
    (re.compile(r"^(\w+)_by_country$"), []),
    (re.compile(r"^(\w+)_by_region$"), []),
    (re.compile(r"^(\w+)_by_location$"), []),
    (re.compile(r"^(\w+)_by_field$"), []),
    (re.compile(r"^(\w+)_by_location$"), []),
    (re.compile(r"^(\w+)_by_type$"), []),
    
    (re.compile(r"^\d+_(\w+)_by_legal_status$"), []),
    (re.compile(r"^\d+_(\w+)_by_year$"), []),
    (re.compile(r"^\d+_(\w+)_by_date$"), []),
    (re.compile(r"^\d+_(\w+)_by_year_and_country$"), []),
    (re.compile(r"^\d+_(\w+)_by_country_and_year$"), []),
    (re.compile(r"^\d+_(\w+)_by_country$"), []),
    (re.compile(r"^\d+_(\w+)_by_continent$"), []),
    (re.compile(r"^\d+_(\w+)_by_decade$"), []),
    (re.compile(r"^\d+_(\w+)_by_date$"), []),
    (re.compile(r"^\d+_(\w+)_by_(\w+)$"), []),
    
    
    (re.compile(r"^(\w+)_by_legal_status$"), []),
    (re.compile(r"^(\w+)_by_year$"), []),
    (re.compile(r"^(\w+)_by_date$"), []),
    (re.compile(r"^(\w+)_by_year_and_country$"), []),
    (re.compile(r"^(\w+)_by_country_and_year$"), []),
    (re.compile(r"^(\w+)_by_country$"), []),
    (re.compile(r"^(\w+)_by_continent$"), []),
    (re.compile(r"^(\w+)_by_decade$"), []),
    (re.compile(r"^(\w+)_by_date$"), []),
    (re.compile(r"^(\w+)_by_(\w+)$"), []),
    
    (re.compile(r"^\d+_(\w+)$"), []),
]
print(len(patterns))

test_str = 'Companies_by_date'
for pattern, extra_words in patterns:
    match = pattern.fullmatch(test_str)
    if match:
        print(list(match.groups()))
        break

44
['Companies']


In [23]:
stopwords = ['a', 'about', 'above', 'across', 'after', 'afterwards']
stopwords += ['again', 'against', 'all', 'almost', 'alone', 'along']
stopwords += ['already', 'also', 'although', 'always', 'am', 'among']
stopwords += ['amongst', 'amoungst', 'amount', 'an', 'and', 'another']
stopwords += ['any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere']
stopwords += ['are', 'around', 'as', 'at', 'back', 'be', 'became']
stopwords += ['because', 'become', 'becomes', 'becoming', 'been']
stopwords += ['before', 'beforehand', 'behind', 'being', 'below']
stopwords += ['beside', 'besides', 'between', 'beyond', 'bill', 'both']
stopwords += ['bottom', 'but', 'by', 'call', 'can', 'cannot', 'cant']
stopwords += ['co', 'computer', 'con', 'could', 'couldnt', 'cry', 'de']
stopwords += ['describe', 'detail', 'did', 'do', 'done', 'down', 'due']
stopwords += ['during', 'each', 'eg', 'eight', 'either', 'eleven', 'else']
stopwords += ['elsewhere', 'empty', 'enough', 'etc', 'even', 'ever']
stopwords += ['every', 'everyone', 'everything', 'everywhere', 'except']
stopwords += ['few', 'fifteen', 'fifty', 'fill', 'find', 'fire', 'first']
stopwords += ['five', 'for', 'former', 'formerly', 'forty', 'found']
stopwords += ['four', 'from', 'front', 'full', 'further', 'get', 'give']
stopwords += ['go', 'had', 'has', 'hasnt', 'have', 'he', 'hence', 'her']
stopwords += ['here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers']
stopwords += ['herself', 'him', 'himself', 'his', 'how', 'however']
stopwords += ['hundred', 'i', 'ie', 'if', 'in', 'inc', 'indeed']
stopwords += ['interest', 'into', 'is', 'it', 'its', 'itself', 'keep']
stopwords += ['last', 'latter', 'latterly', 'least', 'less', 'ltd', 'made']
stopwords += ['many', 'may', 'me', 'meanwhile', 'might', 'mill', 'mine']
stopwords += ['more', 'moreover', 'most', 'mostly', 'move', 'much']
stopwords += ['must', 'my', 'myself', 'name', 'namely', 'neither', 'never']
stopwords += ['nevertheless', 'next', 'nine', 'no', 'nobody', 'none']
stopwords += ['noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'of']
stopwords += ['off', 'often', 'on','once', 'one', 'only', 'onto', 'or']
stopwords += ['other', 'others', 'otherwise', 'our', 'ours', 'ourselves']
stopwords += ['out', 'over', 'own', 'part', 'per', 'perhaps', 'please']
stopwords += ['put', 'rather', 're', 's', 'same', 'see', 'seem', 'seemed']
stopwords += ['seeming', 'seems', 'serious', 'several', 'she', 'should']
stopwords += ['show', 'side', 'since', 'sincere', 'six', 'sixty', 'so']
stopwords += ['some', 'somehow', 'someone', 'something', 'sometime']
stopwords += ['sometimes', 'somewhere', 'still', 'such', 'system', 'take']
stopwords += ['ten', 'than', 'that', 'the', 'their', 'them', 'themselves']
stopwords += ['then', 'thence', 'there', 'thereafter', 'thereby']
stopwords += ['therefore', 'therein', 'thereupon', 'these', 'they']
stopwords += ['thick', 'thin', 'third', 'this', 'those', 'though', 'three']
stopwords += ['three', 'through', 'throughout', 'thru', 'thus', 'to']
stopwords += ['together', 'too', 'top', 'toward', 'towards', 'twelve']
stopwords += ['twenty', 'two', 'un', 'under', 'until', 'up', 'upon']
stopwords += ['us', 'very', 'via', 'was', 'we', 'well', 'were', 'what']
stopwords += ['whatever', 'when', 'whence', 'whenever', 'where']
stopwords += ['whereafter', 'whereas', 'whereby', 'wherein', 'whereupon']
stopwords += ['wherever', 'whether', 'which', 'while', 'whither', 'who']
stopwords += ['whoever', 'whole', 'whom', 'whose', 'why', 'will', 'with']
stopwords += ['within', 'without', 'would', 'yet', 'you', 'your']
stopwords += ['yours', 'yourself', 'yourselves']

EXCLUDE = set(stopwords).union({"by","or","and","with","the","of","in","without","a","on"})
print(len(EXCLUDE))

321


In [25]:
def flatten(l):
    return [item for sublist in l for item in sublist]

def unique(l, key):
    seen = set()
    seen_add = seen.add
    return [x for x in l if not (key(x) in seen or seen_add(key(x)))]

def union(dfs):
    return reduce(DataFrame.unionAll, dfs)

def is_uppercase(s: str):
    return s[0].isupper()

def split_by_pattern(s: str) -> List[str]:
    for pattern, extra_words in patterns:
        match = re.fullmatch(pattern, s)
        if match:
            return list(match.groups()), True
    return [s], False

def split(s: str, split_unmatched=False, singularize=False, pluralize=False, recursive=False):
    # first, test for common patterns
    splitted, matched = split_by_pattern(s)
    # print(splitted)
    
    # split recursively
    rec_splitted = flatten([split_by_pattern(ss)[0] for ss in splitted])
    # print(rec_splitted)
    while recursive and set(splitted) != set(rec_splitted):
        splitted = rec_splitted[:]
        # print(splitted)
        rec_splitted = flatten([split_by_pattern(ss)[0] for ss in splitted])
            
    if not matched:
        if split_unmatched:
            # if no pattern is found, split and remove stopwords
            splitted += re.split(' |,|_', s)
        else:
            splitted = [s]
    
    splitted = set([sp.replace("_", " ") for sp in splitted if numeric.match(sp) is None])
    # print(s, splitted, matched)
    
    if singularize and pluralize:
        splitted = set([singularize(sp) for sp in splitted]).union(set([pluralize(sp) for sp in splitted]))
    elif singularize:
        splitted = set([singularize(sp) for sp in splitted])
    elif pluralize:
        splitted = set([pluralize(sp) for sp in splitted])
    splitted = splitted - EXCLUDE
    return splitted

def split_all(s: str):
    return split(s, split_unmatched=True)

def bfs_tree(g, node, depth_limit=None):
    ans = []
    visited = set()
    level = [(node, 0)]
    while len(level) > 0:
        for v, depth in level:
            ans.append((v, depth))
            visited.add(v)
        next_level = set()
        for v, depth in level:
            for w in g.neighbors(v):
                if w not in visited:
                    next_level.add((w, depth + 1))
        level = next_level
    return ans

def freq_bfs_tree(g, node, depth_limit=None):
    ans = []
    counts = dict()
    visited = set()
    level = [(node, 0)]
    while len(level) > 0:
        for v, depth in level:
            ans.append((v, depth))
            visited.add(v)
            counts[v] = 1
        next_level = set()
        for v, depth in level:
            for w in g.neighbors(v):
                if w in visited:
                    counts[v] += 1
                elif depth_limit is None or depth + 1 <= depth_limit:
                    next_level.add((w, depth + 1))
        level = next_level
    
    levels = dict()
    for n, depth in ans:
        if depth not in levels:
            levels[depth] = []
        levels[depth].append((n, counts[n]))
    
    levels = {depth: sorted(nodes, key=lambda x: x[1], reverse=True) for depth, nodes in levels.items()}
    return levels
    # return [(n, depth, counts[n]) for n, depth in ans]

In [26]:
def find_topics(node, g, depth_limit: int = 4, max_categories: int = 5) -> Dict[int, List[int]]:
    categories = freq_bfs_tree(g, node, depth_limit=depth_limit)
    if False:
        pprint({
            depth: [(g.nodes[n]["title"], n, count) for n, count in nodes]
            for depth, nodes in categories.items() if depth > 0
        })

    return {
        depth: unique(flatten([
            [w.capitalize() for w in split(g.nodes[n]["title"], recursive=True)]
            for n, count in nodes
        ]), key=lambda x: x[0])[:max_categories] for depth, nodes in categories.items() if depth > 0
    }

In [27]:
%%time
depth_limit = 4
n_categories = 5
page_id = 63030231 # covid 19
# page_id = 11867 # germany
# page_id = 24365 # porsche

pprint(find_topics(page_id, g=graph, depth_limit=4, max_categories=5))

{1: ['Covid-19',
     'Occupational safety',
     'Health',
     'Zoonoses',
     'Viral respiratory tract infections'],
 2: ['Health policy',
     'Airborne diseases',
     'Infectious diseases',
     'Safety',
     'Mode'],
 3: ['Animal health',
     'Humanities',
     'Social sciences',
     'Public policy',
     'Diseases'],
 4: ['Social policy',
     'Medicine',
     'Epidemiology',
     'Natural sciences',
     'Anthropology']}
CPU times: user 28.8 ms, sys: 27 µs, total: 28.8 ms
Wall time: 27.8 ms


In [None]:
all_page_ids =  list(enumerate(sorted(raw_pages.filter(F.col("page_namespace") == 0).select("page_id").distinct().rdd.flatMap(lambda x: x).collect())))
print(len(all_page_ids))
with open("../nvme/en_topics/all_page_ids.pkl", 'wb') as f:
    pkl.dump(all_page_ids, f, protocol=pkl.HIGHEST_PROTOCOL)

In [None]:
results = dict()
savepoint = dict()
save_every = 5_000

bk_dir = Path("../nvme/en_topics/savepoints")
bk_dir.mkdir(parents=True, exist_ok=True)

start = time.time()
for i, page_id in tqdm(all_page_ids, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}'):
    results[page_id] = find_topics(page_id, g=graph, depth_limit=4, max_categories=5)
    
    savepoint[page_id] = results[page_id] # .copy()
    if i >= save_every and i % save_every == 0:
        with open(bk_dir / f"page_topics_{i-save_every}_{i}.pkl", 'wb') as f:
            pkl.dump(savepoint, f, protocol=pkl.HIGHEST_PROTOCOL)
        savepoint = dict()

print(len(results))
print("took %.2f hours" % ((time.time() - start)/(60*60)))

In [None]:
%%time
n_parallel = multiprocessing.cpu_count()
n_parallel = 1

test_page_ids = all_page_ids[:1_000_000]

# results = dict()
results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor:
    for i, test enumerate(executor.map(partial(find_topics, g=graph, depth_limit=4, max_categories=5), test_page_ids)):
        results.append(test)
    
        
        # results[page_id] = topics
    #for page_id, topics in zip(
    #    test_page_ids,
    #    executor.map(partial(find_topics, g=graph, depth_limit=4, max_categories=5), test_page_ids)
    #):
    #    results[page_id] = topics

print(len(results))
# pprint(list(results.items())[0])

In [None]:
def find_topics_worker(page_ids, g, depth_limit, max_categories):
    return {
        page_id: find_topics(page_id, g=g, depth_limit=depth_limit, max_categories=max_categories)
        for page_id in page_ids
    }

In [None]:
n_parallel = 1 # multiprocessing.cpu_count()
chunk_size = int(np.ceil(len(all_page_ids[:100]) / n_parallel))
tasks = []
start = time.time()

# initializer=set_global, initargs=(graph,)
with concurrent.futures.ProcessPoolExecutor(max_workers=n_parallel) as executor:
    # executor.map(partial(find_topics, g=graph, depth_limit=4, max_categories=5), worker_page_ids)
    for worker_id in range(n_parallel):
        worker_page_ids = all_page_ids[worker_id * chunk_size: (worker_id + 1) * chunk_size]
        print(worker_page_ids[:5])
        print("worker %d got assigned %d page ids" % (worker_id, len(worker_page_ids)))
        tasks.append(executor.submit(partial(find_topics_worker, g=graph, depth_limit=4, max_categories=5), worker_page_ids))

# collect the results
results = dict()
for worker_id, task in enumerate(tasks):
    results.update(task.result())
    # print(job)
    #for page_id, topics in zip(*job):
    #    results[page_id] = topics   
    # print(r)
    # cur_revids, cur_topics = proc.result()
    # all_revids.update(cur_revids)
    # all_topics.update(cur_topics)
    print("worker %d done" % worker_id)

print("took %.2f hours" % ((time.time() - start)/(60*60)))
print(len(results))
pprint(list(results.items())[0])

In [None]:
# todo build a udf for looking up topics and join with the page table
# find_topics_udf = F.udf(partial(find_topics, g=graph, depth_limit=4, max_categories=5), T.ArrayType(T.StringType()))

In [None]:
raw_pages \
    .filter(F.col("page_title") == "COVID-19") \
    .withColumn("ores_topics", find_topics_udf(raw_pages['page_id'])) \
    .limit(10) \
    .show()

In [None]:
# deprecated
def find_topics_old(g, node: int, max_level: int = 4, topic_count = 5) -> Dict[int, List[int]]:
    start = time.time()
    distances = range(0, max_level+1)
    topics = {dist: set() for dist in distances}
    primary_topic_nodes = {
        dist: [n for n in nx.descendants_at_distance(g, node, distance=dist)]
        for dist in distances
    }
    primary_topics = {
        dist: flatten([split(graph.nodes[n]["title"]) for n in nodes])
        # dist: ([graph.nodes[n]["title"] for n in nodes])
        for dist, nodes in primary_topic_nodes.items()
    }
    # pprint(primary_topics)
    
    # we make the following assumptions:
    # - single word low level topics are usually the most fitting
    # - splitting multi word low level topics and filtering those that are also high level categories is a good high level category
    # - among the high level categories, we can choose the ones with the highest frequency that are also reachable from other pages
    
    # pprint(primary_topics)
    low_level_topics = [title for title in primary_topics[1]]
    flattened_low_level_topics = set(flatten(low_level_topics))
    
    common_parent_topics = {
        dist: {
            graph.nodes[node]["title"]: set(flatten([split(graph.nodes[n]["title"]) for n in nx.bfs_tree(g, node, depth_limit=2)]))
            for node in nodes
        }
        for dist, nodes in primary_topic_nodes.items()
    }
    
    common_parent_topics_counts = {
        dist: Counter(flatten([
            [tt.lower() for tt in t]
            for t in parent_topics.values()
        ]))
        for dist, parent_topics in common_parent_topics.items()
    }
    
    common_parent_topics_scores = {
        dist: sorted([
            (title, np.mean([0.0] + [{**common_parent_topics_counts[0], **common_parent_topics_counts[1]}.get(tt.lower(), 0) for tt in t]))
            for title, t in parent_topics.items()
        ], key=lambda x: x[1], reverse=True)
        for dist, parent_topics in common_parent_topics.items()
    }
    
    # choose single words with uppercase
    # for the others, join pluralize and singularize and check if we find the high level topics
    
    # pprint(low_level_topics)
    # pprint(flattened_low_level_topics)
    pprint(primary_topics[1])
    pprint(common_parent_topics_scores[1])
    
    # add single word low level categories
    # single_word_topics = [list(topics)[0] for topics in low_level_topics if len(topics) == 1]
    # and len(split(topic)) == 1
    single_word_topics = set([topic for topic in flattened_low_level_topics if is_uppercase(topic)])
    topics[1] = topics[1].union(single_word_topics)
    
    for split_part_topic in flattened_low_level_topics - single_word_topics:
        print(split_part_topic.lower())
        for dist in distances:
            if dist < 2:
                continue
            # print([t.lower() for t in primary_topics[dist]])
            for t in primary_topics[dist]:
                if t.lower() == split_part_topic.lower():
                    topics[dist].add(t) # use t as it might be uppercased
            # if split_part_topic.lower() in [t.lower() for t in primary_topics[dist]]:
            #     topics[dist].add(split_part_topic)
    
    # [:int(topic_count/2)]
    # check 
    
    print("took %.2f seconds" % (time.time() - start))
    # pprint(topics)
    return topics

# pprint(find_topics(graph, 63030231)) # covid 19
# pprint(find_topics(graph, 62304)) # pfizer
pprint(find_topics(graph, 1092923)) # google

In [None]:
# deprecated
# todo: find other pages from the splitted words
    # covid_article = raw_pages.filter(F.col("page_title") == "COVID-19").limit(100)
    similar_articles = []
    for low_level_topic in flattened_low_level_topics:
        similar_articles.append(raw_pages.filter(F.lower("page_title").contains(low_level_topic.lower())).limit(100))
    similar_articles = union(similar_articles)
    
    # similar_articles = raw_pages.filter(F.lower("page_title").isin(flattened_low_level_topics))
    
    print(similar_articles.count())
    # similar_articles.show()
    
    # find the categories for similar pages
    other_page_topics = {dist: dict() for dist in distances}
    for similar_page_id in similar_articles.select("page_id").distinct().rdd.flatMap(lambda x: x).collect(): # .rdd.toLocalIterator():
        # print(similar_page_id)
        # similar_page_id = similar_page_row["page_id"]
        if similar_page_id not in g.nodes:
            continue
        for dist in distances:
            for similar_page_topic in [
                (n, graph.nodes[n]["title"]) for n in nx.descendants_at_distance(g, similar_page_id, distance=dist)
            ]:
                # print(similar_page_topic)
                if similar_page_topic not in other_page_topics[dist]:
                    other_page_topics[dist][similar_page_topic] = 0
                other_page_topics[dist][similar_page_topic] += 1
    
    other_page_topic_counts = {
        dist: sorted([(n, title, count) for (n, title), count in other_page_topics[dist].items()], key=lambda x: x[2], reverse=True)
        for dist in distances
    }
    top5_other_page_topic_counts = {
        dist: topics[:10]
        for dist, topics in other_page_topic_counts.items()
    }
    pprint(top5_other_page_topic_counts)
    # for dist in distances:
    #     (n, title, count) for (n, title), count in other_page_topics[dist].items()
    # pprint(other_page_topics)
    # other_page = {
    #     dist: [graph.nodes[n] for n in nx.descendants_at_distance(g, node, distance=dist)]
    #     for dist in range(1, max_level+1)
    # }
    
    # pprint(primary_topics)
    print("took %.2f seconds" % (time.time() - start))
    return topics
    # for dist in range(1, 5):
    #     high_level_categories = [graph.nodes[n] for n in nx.descendants_at_distance(graph, 63030231, distance=dist)]
    #     print("==== d %d" % dist)
    #     pprint(high_level_categories)

In [None]:
# deprecated
depth_limit = 4
n_categories = 5
page_id = 63030231 # covid 19
page_id = 11867 # germany
page_id = 24365 # porsche


all_covid_categories1 = list(nx.dfs_tree(graph, page_id, depth_limit=depth_limit))
# all_covid_categories2 = list(bfs_tree(graph, 63030231, depth_limit=depth_limit))
# print(len(all_covid_categories1), len(all_covid_categories2))
# assert len(all_covid_categories1) == len(all_covid_categories2)
all_covid_categories2 = bfs_tree2(graph, page_id, depth_limit=depth_limit)

# print("highest level root categories")
# pprint([(graph.nodes[n]["title"], depth, count) for n, depth, count in all_covid_categories2 if graph.out_degree(n) == 0])

# max_depth = np.max([depth for _, depth, count in all_covid_categories2])
# print("max depth", max_depth)

# print("highest level root categories")
# pprint([(graph.nodes[n]["title"], depth, count) for n, depth, count in all_covid_categories2 if 30 <= depth <= 60])

# print("lower level categories")
# pprint([(graph.nodes[n]["title"], depth, count) for n, depth, count in all_covid_categories2[:100]])
# print(covid_page_node) # node count really refers to how many categories does a page have assigned to it

# test = {depth: [] for depth in range(1, depth_limit)}
# for n, depth, count in all_covid_categories2:
#     if depth < 1:
#         continue
#     test[depth].append((n, count))

# pprint({depth: [(graph.nodes[n]["title"], n, count) for n, count in nodes] for depth, nodes in all_covid_categories2.items() if depth > 0})

pprint({
    depth: flatten([
        [(w.capitalize(), n) for w in split(graph.nodes[n]["title"])]
        for n, count in nodes
    ])[:n_categories] for depth, nodes in all_covid_categories2.items() if depth > 0
})
# pprint({depth: (graph.nodes[n]["title"], depth, count) for n, depth, count in all_covid_categories2[:100]])