In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import bz2
import csv
import io
import re
import time
import json
import random
import requests
from tqdm import tqdm
import multiprocessing
import concurrent.futures
import pickle as pkl
import numpy as np
import networkx as nx
from functools import partial, reduce
from collections import Counter
from pathlib import Path
from pprint import pprint
from typing import List, Dict
import matplotlib.pyplot as plt
import lsde2021.csv as csvutil
import lsde2021.utils as utils
import lsde2021.topics as topics
from lsde2021.lang import singularize, pluralize
import lsde2021.download as dl
from pyspark.sql import SparkSession, DataFrame
import pyspark.sql.types as T
import pyspark.sql.functions as F

In [None]:
# all_page_ids: presumably a list of (running_index, page_id) pairs — that is the
# shape find_topics_worker unpacks. TODO confirm against the producing notebook.
# This is a local, trusted artifact; never pickle.load untrusted files.
page_ids_file = Path("../nvme/en_topics/all_page_ids.pkl")
with page_ids_file.open("rb") as fh:
    all_page_ids = pkl.load(fh)
print(len(all_page_ids))

In [None]:
def find_topics_worker(page_ids, depth_limit, max_categories,
                       graph_path="../nvme/en-category-tree-without-hidden.pkl",
                       savepoint_dir="../nvme/en_topics/savepoints",
                       save_every=5_000):
    """Compute topics for one chunk of pages; intended to run in a worker process.

    Loads the pickled category graph, calls ``topics.find_topics`` for every page,
    and periodically pickles the accumulated results as a savepoint.

    Args:
        page_ids: sequence of ``(i, page_id)`` pairs. ``i`` drives the savepoint
            schedule and file names (presumably a global running index — TODO confirm).
        depth_limit: forwarded to ``topics.find_topics``.
        max_categories: forwarded to ``topics.find_topics``.
        graph_path: path of the pickled networkx category graph.
        savepoint_dir: directory for ``page_topics_{start}_{i}.pkl`` savepoints
            (created if missing).
        save_every: write a savepoint whenever ``i`` is a positive multiple of this.

    Returns:
        dict mapping page_id -> result of ``topics.find_topics``. Pages whose lookup
        raised are omitted; the error is printed.
    """
    if len(page_ids) == 0:
        # Nothing to do. Returning early also avoids an IndexError on page_ids[0]
        # and skips loading the graph for an empty chunk.
        return {}

    # nx.read_gpickle was removed in networkx 3.0; it was a thin wrapper around
    # pickle.load, so loading directly reads the same file on any networkx version.
    # Local trusted artifact only — never unpickle untrusted data.
    with open(graph_path, "rb") as f:
        graph = pkl.load(f)
    start = page_ids[0][0]

    bk_dir = Path(savepoint_dir)
    bk_dir.mkdir(parents=True, exist_ok=True)

    results = dict()
    for i, page_id in page_ids:
        try:
            results[page_id] = topics.find_topics(page_id, g=graph, depth_limit=depth_limit, max_categories=max_categories)
        except Exception as e:
            # Best-effort: one bad page must not abort a multi-hour run, but report which page failed.
            print(f"page {page_id}: {type(e).__name__}: {e}")
        if i >= save_every and i % save_every == 0:
            # Each savepoint contains everything this worker has computed so far.
            with open(bk_dir / f"page_topics_{start}_{i}.pkl", 'wb') as f:
                pkl.dump(results, f, protocol=pkl.HIGHEST_PROTOCOL)
    return results

In [None]:
# Split all_page_ids into contiguous chunks and compute topics in parallel processes.
n_parallel = 2
depth_limit = 4       # depths 1..4 end up as the topics1..topics4 columns
max_categories = 5
chunk_size = int(np.ceil(len(all_page_ids) / n_parallel))
worker_fn = partial(find_topics_worker, depth_limit=depth_limit, max_categories=max_categories)
tasks = []  # (worker_id, future) pairs, so logs name the right worker even if some are skipped
start = time.time()

with concurrent.futures.ProcessPoolExecutor(max_workers=n_parallel) as executor:
    for worker_id in range(n_parallel):
        worker_page_ids = all_page_ids[worker_id * chunk_size: (worker_id + 1) * chunk_size]
        if len(worker_page_ids) == 0:
            # Fewer pages than workers: nothing to submit for this slot.
            continue
        print("worker %d got assigned %d page ids, first: %s"
              % (worker_id, len(worker_page_ids), worker_page_ids[:5]))
        tasks.append((worker_id, executor.submit(worker_fn, worker_page_ids)))

results = dict()
for worker_id, task in tasks:
    results.update(task.result())
    print("worker %d done" % worker_id)

print("took %.2f hours" % ((time.time() - start) / (60 * 60)))
print(len(results))
if results:
    # next(iter(...)) peeks at one entry without materializing the whole dict as a list.
    pprint(next(iter(results.items())))

In [None]:
# Persist the merged page_id -> topics dict produced by the parallel run.
final_topics_path = Path("../nvme/en_topics/topics_final.pkl")
with final_topics_path.open("wb") as out_file:
    pkl.dump(results, out_file, protocol=pkl.HIGHEST_PROTOCOL)

In [None]:
MAX_MEMORY = "30G"

# Spark settings, applied to the builder in this exact order.
spark_conf = {
    "spark.executor.memory": MAX_MEMORY,
    "spark.driver.memory": MAX_MEMORY,
    'spark.driver.maxResultSize': MAX_MEMORY,
    'spark.ui.showConsoleProgress': 'false',
}

builder = SparkSession.builder.appName("parse-wikipedia-sql-dumps")
for conf_key, conf_value in spark_conf.items():
    builder = builder.config(conf_key, conf_value)
spark = builder.getOrCreate()
sc = spark.sparkContext

In [None]:
# One array column per topic depth; 1..4 matches depth_limit=4 passed to find_topics_worker.
topic_depths = range(1, 5)

schema = T.StructType(
    [T.StructField('page_id', T.IntegerType(), False)]
    + [T.StructField(f'topics{d}', T.ArrayType(T.StringType()), False) for d in topic_depths]
)

# The loop variable is `page_topics` (not `topics`) so it does not shadow the imported
# lsde2021.topics module. Rows are tuples in schema field order; depths missing from a
# page's result become empty arrays.
df = spark.createDataFrame([
    (page_id, *[page_topics.get(d, []) for d in topic_depths])
    for page_id, page_topics in results.items()
], schema)
df.show()

In [None]:
# Persist the topic table as parquet, replacing any previous output at this path.
df.write.parquet(f"../nvme/en_topics/topics_final.parquet", mode="overwrite")

In [None]:
# Round-trip check instead of re-showing the in-memory frame: read the written parquet
# back and confirm it holds one row per computed page.
topics_written = spark.read.parquet("../nvme/en_topics/topics_final.parquet")
assert topics_written.count() == len(results), "row count mismatch after parquet write"
topics_written.show()