In [1]:
import json
import gc
import random
import networkx as nx
from networkx.readwrite import json_graph
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import Dict, List
import scml
import mylib

In [2]:
# Random-walk settings.
#   m     - window length: the maximum number of steps taken on each side
#           (left context and right context) of a center node.
#   iters - number of full passes made over all nodes of the graph.
m, iters = 2, 1

In [3]:
# Start the notebook-wide timer; the last cell stops it and prints the total.
tim = scml.Timer()
tim.start()

# Percentile grid handed to DataFrame.describe() further down.
percentiles = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

# pandas options, applied in one call as (pattern, value) pairs.
# NOTE(review): "use_inf_as_na" is deprecated in recent pandas releases (2.1+)
# -- confirm against the installed version before upgrading.
pd.set_option(
    "use_inf_as_na", True,
    "max_info_columns", 9999,
    "display.max_columns", 9999,
    "display.max_rows", 9999,
    'max_colwidth', 9999,
)

# Register the progress_apply helpers on pandas objects.
tqdm.pandas()

# Fixed seed so the run is repeatable.
# NOTE(review): which RNGs scml.seed_everything covers is not visible here.
scml.seed_everything(31)

In [4]:
%%time
# Load the vocabulary and build a label -> integer id lookup.
# The file is parsed as a JSON array of two-element entries; only the first
# element (the label) is kept and its position in the array becomes its id.
# The second element is ignored here (presumably a count -- TODO confirm).
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (which is not UTF-8 on every OS).
with open("input/vocab2.json", encoding="utf-8") as f:
    id2label = [label for label, _ in json.load(f)]
# If a label occurs more than once, its last position wins.
label2id = {label: idx for idx, label in enumerate(id2label)}
print(f"len(label2id)={len(label2id):,}")

len(label2id)=3,748,278
Wall time: 4.39 s


In [5]:
# The id -> label list is only needed to build label2id; drop it and force a
# collection pass so its memory can be reclaimed before the graph is loaded.
del id2label
# The return value (number of unreachable objects found) is shown as cell output.
gc.collect()

21

In [6]:
%%time
# Rebuild the graph from its adjacency-format JSON dump.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (which is not UTF-8 on every OS).
with open("input/graph.json", encoding="utf-8") as f:
    g = nx.adjacency_graph(json.load(f))
# Prints the graph summary (type, node count, edge count).
print(g)

DiGraph with 3747723 nodes and 99004731 edges
Wall time: 8min 50s


In [9]:
def _walk(graph: "nx.DiGraph", start, steps: int, backward: bool) -> List:
    """Take a weighted random walk of at most ``steps`` hops from ``start``.

    Args:
        graph: Graph whose edges carry a ``"weight"`` attribute.
        start: Node the walk begins at (not included in the result).
        steps: Maximum number of hops.
        backward: If True, follow incoming edges (move to the edge's source);
            otherwise follow outgoing edges (move to the edge's target).

    Returns:
        The nodes visited, in order. Shorter than ``steps`` when the walk
        reaches a node with no edge in the requested direction.

    Raises:
        KeyError: If an edge has no ``"weight"`` attribute.
    """
    visited = []
    curr = start
    for _ in range(steps):
        if backward:
            choices = [
                (src, attrs["weight"])
                for src, _dst, attrs in graph.in_edges(curr, data=True)
            ]
        else:
            choices = [
                (dst, attrs["weight"])
                for _src, dst, attrs in graph.out_edges(curr, data=True)
            ]
        if not choices:
            # Dead end in this direction: stop early.
            break
        # presumably samples one node in proportion to its weight -- confirm
        # against mylib.weighted_choice
        curr = mylib.weighted_choice(choices)
        visited.append(curr)
    return visited


# Build (center, outside) id pairs: for every node, walk up to `m` hops along
# incoming edges, then up to `m` hops along outgoing edges, pairing the center
# with each node reached. The whole pass is repeated `iters` times.
rows = []
for i in range(iters):
    for center in tqdm(g.nodes, desc=f"i={i}"):
        # Incoming direction first, then outgoing, so the order of rows and of
        # random draws is the same as walking the two directions back to back.
        for backward in (True, False):
            for outside in _walk(g, center, m, backward):
                rows.append({
                    "center": label2id[center],
                    "outside": label2id[outside]
                })

i=0: 100%|██████████████████| 3747723/3747723 [2:36:04<00:00, 400.19it/s]


In [14]:
# Materialise the sampled pairs, keep the first occurrence of each distinct
# (center, outside) pair with a fresh 0..n-1 index, and store both id columns
# as 32-bit integers.
cols = ["center", "outside"]
df = (
    pd.DataFrame.from_records(rows)
    .drop_duplicates(keep="first", ignore_index=True)
    .astype({name: np.int32 for name in cols})
)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14461968 entries, 0 to 14461967
Data columns (total 2 columns):
 #   Column   Dtype
---  ------   -----
 0   center   int32
 1   outside  int32
dtypes: int32(2)
memory usage: 110.3 MB


In [15]:
# Summary statistics of both id columns at the percentile grid defined in the
# setup cell; the result is displayed as the cell output.
df.describe(percentiles=percentiles)

Unnamed: 0,center,outside
count,14461970.0,14461970.0
mean,1846924.0,717386.1
std,1075523.0,856472.5
min,0.0,0.0
1%,36463.0,205.0
5%,182536.0,3050.0
10%,365494.0,11123.0
20%,732046.4,43434.0
30%,1099470.0,103357.0
40%,1467838.0,201465.0


In [16]:
%%time
# Refuse to persist the pairs if any cell of the frame is missing.
assert not df.isna().any(axis=None)
# Write the pairs without the index column.
df.to_parquet("output/pairs.parquet", index=False)

Wall time: 421 ms


In [17]:
# Stop the timer started in the setup cell and report the total run time.
# NOTE(review): the saved output of this cell is "RuntimeError: Not started";
# presumably the timer was stopped or re-created by a cell run out of order --
# confirm with a fresh Restart & Run All.
tim.stop()
print(f"Total time taken {tim.elapsed!s}")

RuntimeError: Not started