In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import bz2
import csv
import io
import json
import time
import urllib
import random
import requests
import concurrent.futures
import pickle as pkl
import numpy as np
from pathlib import Path
from pprint import pprint
from typing import List, Dict
import lsde2021.csv as csvutil
import lsde2021.utils as utils
import lsde2021.download as dl
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import pyspark.sql.functions as F

In [None]:
MAX_MEMORY = "60G"

# Single Spark session for the notebook; executor memory, driver memory and the
# driver's max result size all share the MAX_MEMORY budget.
# NOTE: static settings such as driver memory only take effect if no Spark
# session exists yet in this kernel (getOrCreate reuses an existing one).
spark = (
    SparkSession.builder
    .appName("parse-wikipedia-sql-dumps")
    .config("spark.executor.memory", MAX_MEMORY)
    .config("spark.driver.memory", MAX_MEMORY)
    .config("spark.driver.maxResultSize", MAX_MEMORY)
    .config("spark.dynamicAllocation.maxExecutors", 4)
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)
sc = spark.sparkContext

In [None]:
wiki = "enwiki"
# Page table of the 2021-10-01 dump, stored as parquet (presumably converted
# from the SQL dump by an earlier step — the file name suggests so).
pages = spark.read.load(
    f"../nvme/wikipedia_sql_dumps/{wiki}/20211001/{wiki}-20211001-page.sql.parquet",
    format="parquet",
    inferSchema='True',
)
pages.limit(10).show()

In [None]:
# Distinct ids of pages in namespace 0 (the article namespace), cast to int and
# collected to the driver in ascending order as a plain Python list.
en_page_ids = [
    row[0]
    for row in (
        pages
        .where(F.col("page_id").isNotNull() & (F.col("page_namespace") == 0))
        .select("page_id")
        .distinct()
        .withColumn("page_id", F.col("page_id").cast(T.IntegerType()))
        .orderBy(F.col("page_id").asc())
        .collect()
    )
]

In [None]:
print(len(en_page_ids))
pprint(en_page_ids[:20])
# Spark must have handed back plain Python ints, not Row objects.
assert isinstance(en_page_ids[0], int)

ores_topics_dir = Path("../nvme/ores_topics")
ores_topics_dir.mkdir(parents=True, exist_ok=True)

# Persist the sorted id list next to the ORES outputs.
(ores_topics_dir / "en_page_ids_sorted.pkl").write_bytes(
    pkl.dumps(en_page_ids, protocol=pkl.HIGHEST_PROTOCOL)
)

In [None]:
def get_page_rev_ids(page_ids: List[int],
                     timeout: float = 30.0,
                     api_url: str = "https://en.wikipedia.org/w/api.php") -> Dict[int, int]:
    """Look up the latest revision id of each page via the MediaWiki query API.

    Args:
        page_ids: page ids to look up in a single request. The MediaWiki API
            limits ``pageids`` to 50 values per request for regular clients,
            so callers should batch accordingly.
        timeout: seconds to wait for the HTTP response before ``requests``
            raises, so a stalled connection cannot hang a worker forever.
        api_url: MediaWiki ``api.php`` endpoint (defaults to English Wikipedia).

    Returns:
        Mapping ``page_id -> revid`` for every page the API returned a revision
        for. Pages without a revision in the response (e.g. missing pages) are
        simply absent.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    if not page_ids:
        return {}

    params = {
        "action": "query",
        "prop": "revisions",
        "rvprop": "ids",  # only the revision id is needed
        "pageids": "|".join(map(str, page_ids)),
        "format": "json",
    }
    with requests.get(api_url, params=params, timeout=timeout) as r:
        r.raise_for_status()
        content = r.json()

    revids = dict()
    # With multiple pageids, the API returns only the latest revision per page.
    for page_id, metadata in content.get("query", {}).get("pages", {}).items():
        revisions = metadata.get("revisions") or []
        if revisions and revisions[0].get("revid") is not None:
            revids[int(page_id)] = int(revisions[0]["revid"])
    return revids

pprint(get_page_rev_ids([604727, 604728]))

In [None]:
class ORESException(ValueError):
    """Raised when the ORES response body carries a request-level ``error``."""

    def __init__(self, message):
        super().__init__(message)

def get_ores_articletopics(context: str, models: List[str], rev_ids: List[int],
                           timeout: float = 30.0) -> List[Dict]:
    """Fetch ORES model scores for a batch of revisions.

    NOTE(review): Wikimedia has been migrating ORES models to Lift Wing;
    confirm this endpoint is still served before relying on it.

    Args:
        context: ORES context, i.e. the wiki database name (e.g. ``"enwiki"``).
        models: model names to score with (e.g. ``["articletopic"]``).
        rev_ids: revision ids to score; ints or numeric strings both work
            since they are stringified.
        timeout: seconds to wait for the HTTP response.

    Returns:
        One score dict per entry of ``rev_ids``, in the same order, taken from
        ``content[context]["scores"][str(rev_id)]``. An empty ``rev_ids``
        returns ``[]`` without contacting ORES.

    Raises:
        requests.HTTPError: on a non-2xx response.
        ORESException: if the response body contains an ``error`` key.
        KeyError: if the response lacks ``context`` or a requested revid.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` is loaded.
    from urllib.parse import quote

    if not rev_ids:
        return []

    url = "https://ores.wikimedia.org/v3/scores/{0}/".format(quote(context))
    # requests URL-encodes params itself; quoting model names here as well
    # would double-encode any special characters.
    params = {'revids': "|".join(str(rid) for rid in rev_ids),
              'models': "|".join(models)}

    headers = {"User-Agent": random.choice(dl.USER_AGENTS)}
    with requests.get(url, params=params, headers=headers, timeout=timeout) as r:
        r.raise_for_status()
        content = r.json()

    if 'error' in content:
        raise ORESException(content['error'])
    for warning in content.get('warnings', []):
        print(warning)

    scores = content[context]['scores']
    return [scores[str(rev_id)] for rev_id in rev_ids]
    
get_ores_articletopics(context="enwiki", models=["articletopic"], rev_ids=["1050929646"])

In [None]:
min_prob = 0.6    # keep only topics whose ORES probability exceeds this
max_topics = 5    # keep at most this many topics per page
retries = 10      # attempts per batch before giving up on it


def _select_topics(score, min_prob, max_topics):
    """Extract ranked articletopic labels from one per-revision ORES score.

    Returns the labels with probability > ``min_prob``, sorted by decreasing
    probability and truncated to ``max_topics``. Returns ``None`` when the
    score carries no usable ``articletopic`` probability. An ORES error
    entry, or a response without a ``score``, is printed as "bad response".
    """
    response = score.get("articletopic")
    if response is None:
        return None
    if "error" in response or "score" not in response:
        print("bad response", response)
        return None
    probabilities = response["score"].get("probability")
    if probabilities is None:
        return None
    ranked = sorted(probabilities.items(), key=lambda t: t[1], reverse=True)
    return [topic for topic, prob in ranked if prob > min_prob][:max_topics]


def get_ores_articletopics_for_page_ids(page_ids, batch_size=50):
    """Resolve page ids to their latest revision and ORES articletopic labels.

    Page ids are processed in batches of ``batch_size``. The default of 50
    matches the MediaWiki API's per-request ``pageids`` limit for regular
    clients. A batch is retried up to the module-level ``retries`` times,
    with exponential backoff, only when fetching or scoring it raises. A
    batch whose pages all come back as ORES errors counts as done and is not
    re-requested. Topic filtering uses the module-level ``min_prob`` and
    ``max_topics``.

    Args:
        page_ids: sequence of page ids (ints).
        batch_size: number of page ids per API request.

    Returns:
        ``(rev_ids, topics)``:
        - ``rev_ids`` maps page_id -> revid for every scored page.
        - ``topics`` maps page_id -> list of topic labels, for pages with a
          usable score.
    """
    all_rev_ids = dict()
    all_topics = dict()
    for offset in range(0, len(page_ids), batch_size):
        batch = page_ids[offset:offset + batch_size]
        for attempt in range(1, retries + 1):
            try:
                rev_ids = list(get_page_rev_ids(batch).items())
                if rev_ids:
                    scores = get_ores_articletopics(context="enwiki", models=["articletopic"],
                                                    rev_ids=[rid for _, rid in rev_ids])
                else:
                    # No page in this batch has a revision: nothing to score.
                    scores = []
                for (page_id, rev_id), score in zip(rev_ids, scores):
                    all_rev_ids[page_id] = rev_id
                    topics = _select_topics(score, min_prob, max_topics)
                    if topics is not None:
                        all_topics[page_id] = topics
                break  # batch fully processed; per-page ORES errors are not retryable
            except Exception as e:
                print("error", e)
                if attempt < retries:
                    # Back off so transient rate limiting / outages can recover.
                    time.sleep(min(2 ** attempt, 60))
        else:
            print("giving up on page ids %d to %d after %d attempts"
                  % (batch[0], batch[-1], retries))
    return all_rev_ids, all_topics

In [None]:
n_parallel = 2          # number of worker processes
start = 0 * 100_000     # offset into en_page_ids where this run begins
count = 100_000         # number of page ids this run covers
all_revids = dict()
all_topics = dict()
procs = []

# Cut out exactly [start, start + count) before splitting between workers, so
# the last worker cannot spill into the next run's range when
# count % n_parallel != 0 (and shrinks cleanly at the end of the list).
run_page_ids = en_page_ids[start:start + count]
chunk_size = int(np.ceil(len(run_page_ids) / n_parallel))

# NOTE: submitting a function defined in the notebook to a process pool relies
# on the 'fork' start method (the Linux default); with 'spawn' the workers
# cannot import notebook-defined functions.
with concurrent.futures.ProcessPoolExecutor(max_workers=n_parallel) as executor:
    for worker_id in range(n_parallel):
        worker_page_ids = run_page_ids[worker_id * chunk_size:(worker_id + 1) * chunk_size]
        print(worker_page_ids[:10])
        print("worker %d got assigned %d page ids" % (worker_id, len(worker_page_ids)))
        procs.append(executor.submit(get_ores_articletopics_for_page_ids, worker_page_ids))

# collect the results
for i, proc in enumerate(procs):
    cur_revids, cur_topics = proc.result()
    all_revids.update(cur_revids)
    all_topics.update(cur_topics)
    print("worker %d done" % i)

# save result to pickle
revids_file = ores_topics_dir / ("revids_%d_to_%d.pkl" % (start, start+count))
topics_file = ores_topics_dir / ("topics_%d_to_%d.pkl" % (start, start+count))
print(revids_file)
print(topics_file)

with open(revids_file, 'wb') as f:
    pkl.dump(all_revids, f, protocol=pkl.HIGHEST_PROTOCOL)
with open(topics_file, 'wb') as f:
    pkl.dump(all_topics, f, protocol=pkl.HIGHEST_PROTOCOL)

print(len(all_revids))
# https://ores.wikimedia.org/v3/scores/enwiki?models=articletopic&revids=421063984

In [None]:
# Write the page table with ORES topics as parquet, overwriting any previous
# output, partitioned by the `ores_topic1` column (presumably the top-ranked topic).
# FIXME(review): `pages_with_topics` is not defined in any cell visible here, so
# this cell raises NameError on Restart & Run All. It presumably should be built
# from `pages` joined with the pickled topics above (adding an `ores_topic1`
# column) — confirm and add that cell.
pages_with_topics.write.format("parquet").mode("overwrite") \
    .partitionBy("ores_topic1").save(f"../nvme/wikipedia_sql_dumps/{wiki}/20211001/{wiki}-20211001-page-ores-topics.sql.parquet")