In [None]:
%load_ext autoreload
%autoreload 2
import bz2
import csv
import io
import json
import pickle
import random
from pathlib import Path
from pprint import pprint
from typing import List, Dict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pyspark.sql.functions as F
import requests
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType, IntegerType

import lsde2021.csv as csvutil
import lsde2021.utils as utils
import lsde2021.download as dl

In [None]:
MAX_MEMORY = "60G"

# Spark session for parsing the Wikipedia SQL dumps. Executor memory, driver
# memory and the driver result-size cap all use MAX_MEMORY; the console
# progress bar is switched off so it does not flood the notebook output.
spark_settings = {
    "spark.executor.memory": MAX_MEMORY,
    "spark.driver.memory": MAX_MEMORY,
    "spark.driver.maxResultSize": MAX_MEMORY,
    "spark.ui.showConsoleProgress": "false",
}
builder = SparkSession.builder.appName("parse-wikipedia-sql-dumps")
for setting, value in spark_settings.items():
    builder = builder.config(setting, value)
spark = builder.getOrCreate()
sc = spark.sparkContext

# Reusable readers: CSV with a header row and inferred column types, and Parquet.
# NOTE(review): inferSchema is a CSV/JSON option; Parquet files carry their own
# schema, so it is most likely ignored for parquet_reader.
csv_loader = spark.read.format("csv").options(header='True', inferSchema='True')
parquet_reader = spark.read.format("parquet").options(inferSchema='True')

In [None]:
# Load the English Wikipedia page table and category-link table (parquet
# conversions of the 2021-10-01 SQL dumps); they are joined further below.
wiki = "enwiki"
DUMP_DATE = "20211001"
DUMP_DIR = f"../nvme/wikipedia_sql_dumps/{wiki}/{DUMP_DATE}"

raw_pages = parquet_reader.load(f"{DUMP_DIR}/{wiki}-{DUMP_DATE}-page.sql.parquet")
raw_categorylinks = parquet_reader.load(f"{DUMP_DIR}/{wiki}-{DUMP_DATE}-categorylinks.sql.parquet")

In [None]:
# Quick look at the first 10 rows of each raw table.
for raw_table in (raw_pages, raw_categorylinks):
    raw_table.limit(10).show()

In [None]:
# Keep non-redirect pages in namespaces 0 (presumably articles) and 14
# (category pages), with only the columns needed downstream.
pages = (
    raw_pages
    .filter((F.col("page_is_redirect") == 0) & F.col("page_namespace").isin(0, 14))
    .select("page_id", "page_namespace", "page_title")
)

# Only the linking page id and the category title are needed from the link table.
categorylinks = raw_categorylinks.select("page_id", "category_name")

# Category pages, renamed so they can be joined to the links on the category title.
category_pages = (
    pages
    .where(F.col("page_namespace") == 14)
    .select(
        F.col("page_id").alias("category_page_id"),
        F.col("page_title").alias("category_name"),
    )
)

print(pages.count())

In [None]:
# Attach every category a page is linked to (inner join on page_id), then look
# up the page id of each category's own page. The left join keeps links whose
# category has no matching category page; their category_page_id is null.
# For quicker experiments, add `.limit(100_000)` right after `pages`.
page_cats = (
    pages
    .join(categorylinks, on="page_id", how="inner")
    .join(category_pages, on="category_name", how="left")
)

page_cats.limit(10).show()

In [None]:
# Annotate every (page, category) row with the number of rows that page has,
# i.e. how many category links it has. Sort so that the pages with the most
# links come first.
# A "count" column left by a previous run of this cell is dropped first. This
# makes the cell re-runnable; without it the join would produce two ambiguous
# "count" columns. drop() is a no-op when the column is absent.
page_cats_base = page_cats.drop("count")

duplicate_counts = (
    page_cats_base
    .groupby("page_id")
    .count()
)

page_cats = (
    page_cats_base
    .join(duplicate_counts, on="page_id", how="inner")
    .sort("count", ascending=False)
)

page_cats.limit(10).show()

In [None]:
# Persist the page/category table, including per-page counts, as parquet so
# later steps can reload it. Any previous output is overwritten.
page_cats_path = f"../nvme/wikipedia_sql_dumps/{wiki}/20211001/{wiki}-20211001-page-category-count.sql.parquet"
page_cats.write.mode("overwrite").parquet(page_cats_path)

In [None]:
%%time
# Build a directed page -> category graph from the rows of `page_cats`.
# Each row:
#   - adds (or refreshes the attributes of) a node for the page;
#   - adds a node for its category when the category has its own page
#     (category_page_id not null);
#   - links the two with an edge page_id -> category_page_id.
graph = nx.DiGraph()

max_size = None  # e.g. 100_000 to build the graph from only the first rows
# Report progress roughly 10 times. The integer step (at least 1) avoids float modulo.
progress_every = max(1, (max_size or 20_000_000) // 10)

for i, row in enumerate(page_cats.rdd.toLocalIterator()):
    # Stop before processing row `max_size`, so exactly `max_size` rows are used.
    if max_size is not None and i >= max_size:
        break
    if i % progress_every == 0:
        print("row", i)

    node = row["page_id"]
    category_node = row["category_page_id"]

    # Namespace 14 is the namespace filtered for category pages when building `pages`.
    try:
        is_category = int(row["page_namespace"]) == 14
    except (TypeError, ValueError):
        # Missing or non-numeric namespace: treat the page as a non-category.
        is_category = False

    # pd.isna covers both None (Spark null) and any float NaN.
    valid_node = not pd.isna(node)
    valid_category_node = not pd.isna(category_node)

    # Add the page node. add_node on an existing node merges the new attributes,
    # so a node first created below as a bare category target (node_count=0)
    # gets its own title/is_category/node_count once its own row is seen.
    if valid_node:
        graph.add_node(
            node,
            is_category=is_category,
            title=row["page_title"],
            node_count=row["count"],
        )

    # Add the category node only if it is unseen, so this never overwrites
    # attributes already set from the category's own page row.
    if valid_category_node and category_node not in graph:
        graph.add_node(category_node, is_category=True, title=row["category_name"], node_count=0)

    if valid_node and valid_category_node:
        graph.add_edge(node, category_node)

In [None]:
# Save the graph for reuse. nx.write_gpickle was removed in NetworkX 3.0.
# Pickling directly writes the same file it produced (highest protocol).
# Only reload this file from a trusted location: unpickling can execute code.
with open("../nvme/en-category-tree.pkl", "wb") as graph_file:
    pickle.dump(graph, graph_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
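# Alternative export (not used): GraphML would be a portable format, but
# writing it was too slow for a graph of this size, so the pickle export is used instead.
# nx.write_graphml_lxml(graph, f"../nvme/en-category-tree.graphml")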

In [None]:
# Take a closer look at some category names and what they look like, so we
# can decide how to split them later. Up to 1,000 names are collected;
# the first 100 are printed.
example_categories = [
    row["category_name"]
    for row in page_cats.select("category_name").limit(1_000).collect()
]
pprint(example_categories[:100])

In [None]:
# Draw (a bounded part of) the category graph. Category nodes are light blue,
# other pages orange, and each node is labelled with its title.
PLOT_MAX_NODES = 500  # spring_layout/drawing become impractical on very large graphs
plot_graph = graph.subgraph(list(graph.nodes)[:PLOT_MAX_NODES])

labels = nx.get_node_attributes(plot_graph, "title")
# Build the colours in the graph's own node order, which is the order
# draw_networkx_nodes uses. The default keeps the list aligned even for nodes
# without the attribute.
colors = [
    "lightblue" if is_cat else "orange"
    for _, is_cat in plot_graph.nodes(data="is_category", default=False)
]

fig, ax = plt.subplots(figsize=(12, 12))
pos = nx.spring_layout(plot_graph, seed=42)  # fixed seed -> reproducible layout
nx.draw_networkx_edges(plot_graph, pos, ax=ax, alpha=0.2)
nx.draw_networkx_nodes(plot_graph, pos, ax=ax, node_size=1000, node_color=colors)
# Titles must go to draw_networkx_labels. The `label=` argument of
# draw_networkx_nodes is only a legend entry.
nx.draw_networkx_labels(plot_graph, pos, labels=labels, ax=ax)
ax.set_title("Page -> category graph (sample)")
plt.show()