In [1]:
import json
import gc
import random
import networkx as nx
from networkx.readwrite import json_graph
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import Dict, List
import scml
import mylib

In [2]:
# --- tunable parameters ---
# context size: maximum number of random-walk steps taken from a centre node,
# separately in the left (predecessor) and right (successor) directions
m = 8
# number of full passes over the node set during the random walk
iters = 20
# if > 0, only the first `limit` nodes are processed; 0 processes every node
limit = 0
# edges whose "weight" attribute is below this value are ignored
weight_min = 7
# upper bound on the candidate edges kept per node for the random walk
choices_max = 1000
# event-type name -> integer code stored in the *_type columns
type2id = {name: code for code, name in enumerate(("click", "cart", "order"))}

In [3]:
# Start the timer first; it is stopped and reported in the last cell.
tim = scml.Timer()
tim.start()

# Percentile grid passed to DataFrame.describe() further down.
percentiles = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

# pandas options: treat +/-inf as missing, and lift the display limits.
for _option, _value in (
    ("use_inf_as_na", True),
    ("max_info_columns", 9999),
    ("display.max_columns", 9999),
    ("display.max_rows", 9999),
    ("max_colwidth", 9999),
):
    pd.set_option(_option, _value)

# Register tqdm's progress_* methods on pandas objects.
tqdm.pandas()
# NOTE(review): presumably seeds the RNGs used by the random walk — confirm in scml.
scml.seed_everything()

In [4]:
%%time
# Load the vocabulary. The file is read as a JSON sequence of 2-item entries;
# only the first item (the label) is kept. The second item is ignored here
# (presumably a count — verify against whatever writes vocab3.json).
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default encoding (which is not UTF-8 on every OS).
with open("input/vocab3.json", encoding="utf-8") as f:
    id2label = [label for label, _count in json.load(f)]
# label -> position in the vocabulary; that position is the integer word id
# used for the center_word / outside_word columns later on.
label2id = {label: idx for idx, label in enumerate(id2label)}
print(f"len(label2id)={len(label2id):,}")

len(label2id)=1,855,603
Wall time: 2.47 s


In [5]:
# id2label was only needed to build label2id. Drop it and force a collection
# now, before the (much larger) graph is loaded in the next cell.
del id2label
# Last expression of the cell: the number returned by gc.collect() is displayed.
gc.collect()

41

In [6]:
%%time
# Rebuild the graph from its JSON adjacency-format dump.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default encoding (which is not UTF-8 on every OS).
with open("input/graph.json", encoding="utf-8") as f:
    g = nx.adjacency_graph(json.load(f))
# Prints a one-line summary of the graph (type, node count, edge count).
print(g)

DiGraph with 3764159 nodes and 100953434 edges
Wall time: 9min 2s


# get neighbours

In [7]:
def _append_pairs(center_word, center_type, edges, nbr_pos, out):
    """Append one (center, outside) row to ``out`` for every qualifying edge.

    Parameters
    ----------
    center_word, center_type : str
        The two parts of the centre node's ``"<word>_<type>"`` key.
    edges : iterable
        ``(u, v, data)`` edge tuples touching the centre node.
    nbr_pos : int
        Index of the neighbour inside each edge tuple: 0 (``u``) for
        incoming edges, 1 (``v``) for outgoing edges.
    out : list
        Receives the row dicts.

    Edges with ``data["weight"] < weight_min`` are skipped, and so are
    neighbours whose word part equals ``center_word``.

    Raises
    ------
    ValueError
        If a neighbour key contains no ``"_"``.
    KeyError
        If a word or type is missing from ``label2id`` / ``type2id``.
    """
    for edge in edges:
        if edge[2]["weight"] < weight_min:
            continue
        # rsplit: the type is always the last "_"-separated segment, so a
        # label that itself contains "_" still unpacks into exactly two parts.
        nbr_word, nbr_type = edge[nbr_pos].rsplit("_", 1)
        if nbr_word == center_word:
            continue
        out.append({
            "center_word": label2id[center_word],
            "center_type": type2id[center_type],
            "outside_word": label2id[nbr_word],
            "outside_type": type2id[nbr_type],
        })


rows = []
_nodes = g.nodes
if limit > 0:
    # debugging aid: only look at the first `limit` nodes
    _nodes = list(_nodes)[:limit]
for center in tqdm(_nodes):
    center_word, center_type = center.rsplit("_", 1)
    # going left: predecessor nodes (the neighbour is the edge source)
    _append_pairs(center_word, center_type, g.in_edges(center, data=True), 0, rows)
    # going right: successor nodes (the neighbour is the edge target)
    _append_pairs(center_word, center_type, g.out_edges(center, data=True), 1, rows)

100%|███████████████████████| 3764159/3764159 [02:40<00:00, 23437.99it/s]


# random walk

In [8]:
def _build_candidates(node, edges_of, nbr_pos):
    """Collect the ``(neighbour, weight)`` candidates for one node.

    ``edges_of`` is ``g.in_edges`` or ``g.out_edges``; ``nbr_pos`` is the
    index of the neighbour inside each ``(u, v, data)`` tuple (0 or 1).
    Edges lighter than ``weight_min`` are skipped, and collection stops as
    soon as ``choices_max`` candidates have been gathered.
    """
    candidates = []
    for edge in edges_of(node, data=True):
        if len(candidates) >= choices_max:
            break
        weight = edge[2]["weight"]
        if weight < weight_min:
            continue
        candidates.append((edge[nbr_pos], weight))
    return candidates


def _random_walk(center, center_word, center_type, cache, edges_of, nbr_pos):
    """Walk up to ``m`` steps away from ``center`` and record the nodes visited.

    Parameters
    ----------
    center : str
        Start node key, ``"<word>_<type>"``.
    center_word, center_type : str
        The two parts of ``center``, already split by the caller.
    cache : dict
        node -> candidate list; filled lazily and reused across passes.
    edges_of, nbr_pos
        Forwarded to ``_build_candidates``; they select the walk direction.

    The walk stops early when the current node has no candidates or when it
    returns to ``center``. Each visited node whose word differs from
    ``center_word`` appends one row to the module-level ``rows`` list.
    """
    curr = center
    for _step in range(m):
        choices = cache.get(curr)
        if choices is None:
            choices = cache[curr] = _build_candidates(curr, edges_of, nbr_pos)
        if not choices:
            break
        # NOTE(review): assumes mylib.weighted_choice returns the first element
        # of one (item, weight) pair, picked according to weight — confirm in mylib.
        curr = mylib.weighted_choice(choices)
        if curr == center:
            break
        # rsplit: the type is always the last "_"-separated segment.
        curr_word, curr_type = curr.rsplit("_", 1)
        if curr_word != center_word:
            rows.append({
                "center_word": label2id[center_word],
                "center_type": type2id[center_type],
                "outside_word": label2id[curr_word],
                "outside_type": type2id[curr_type],
            })


# Per-direction caches of candidate edges, shared by all passes.
out_edges = {}
in_edges = {}
_edges_in, _edges_out = g.in_edges, g.out_edges
for i in range(iters):
    for center in tqdm(_nodes, desc=f"i={i}"):
        center_word, center_type = center.rsplit("_", 1)
        # Left walk first, then right walk: keeping this order keeps the
        # sequence of random draws the same as a plain inline implementation.
        # going left: predecessor nodes
        _random_walk(center, center_word, center_type, in_edges, _edges_in, 0)
        # going right: successor nodes
        _random_walk(center, center_word, center_type, out_edges, _edges_out, 1)

i=0: 100%|██████████████████| 3764159/3764159 [04:36<00:00, 13631.13it/s]
i=1: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65500.52it/s]
i=2: 100%|██████████████████| 3764159/3764159 [01:14<00:00, 50540.56it/s]
i=3: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65035.14it/s]
i=4: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65234.42it/s]
i=5: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65198.24it/s]
i=6: 100%|██████████████████| 3764159/3764159 [01:14<00:00, 50269.73it/s]
i=7: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65229.05it/s]
i=8: 100%|██████████████████| 3764159/3764159 [00:57<00:00, 65338.04it/s]
i=9: 100%|██████████████████| 3764159/3764159 [01:15<00:00, 50137.35it/s]
i=10: 100%|█████████████████| 3764159/3764159 [00:57<00:00, 65155.16it/s]
i=11: 100%|█████████████████| 3764159/3764159 [00:58<00:00, 64729.32it/s]
i=12: 100%|█████████████████| 3764159/3764159 [00:58<00:00, 64645.44it/s]
i=13: 100%|█████████████████| 3764159/

In [9]:
# Materialise every collected pair, then keep only the first occurrence of
# each distinct (center_word, center_type, outside_word, outside_type) row.
df = pd.DataFrame.from_records(rows)
n_before = len(df)
df = df.drop_duplicates(keep="first", ignore_index=True)
print(f"{n_before - len(df):,} rows dropped: duplicates")

60,717,060 rows dropped: duplicates


In [10]:
# Downcast to compact integer types: 32-bit for word ids, 8-bit for type codes.
for cols, dtype in (
    (["center_word", "outside_word"], np.int32),
    (["center_type", "outside_type"], np.int8),
):
    df[cols] = df[cols].astype(dtype)
df.info()
# Invariant: no pair relates a word to itself.
assert (df["center_word"] != df["outside_word"]).all()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40948896 entries, 0 to 40948895
Data columns (total 4 columns):
 #   Column        Dtype
---  ------        -----
 0   center_word   int32
 1   center_type   int8 
 2   outside_word  int32
 3   outside_type  int8 
dtypes: int32(2), int8(2)
memory usage: 390.5 MB


In [11]:
# Summary statistics per column, using the percentile grid from the setup cell.
df.describe(percentiles=percentiles)

Unnamed: 0,center_word,center_type,outside_word,outside_type
count,40948900.0,40948900.0,40948900.0,40948900.0
mean,130645.2,0.4481594,50668.08,0.1488616
std,140962.9,0.6004859,84514.85,0.3961119
min,0.0,0.0,0.0,0.0
1%,588.0,0.0,32.0,0.0
5%,4074.0,0.0,305.0,0.0
10%,9586.0,0.0,918.0,0.0
20%,23585.0,0.0,3030.0,0.0
30%,41148.0,0.0,6492.0,0.0
40%,62366.0,0.0,11706.0,0.0


In [12]:
# Eyeball the first 30 (center, outside) pairs as a sanity check.
df.head(30)

Unnamed: 0,center_word,center_type,outside_word,outside_type
0,1,0,35,0
1,1,0,35,1
2,1,0,7,0
3,1,0,55,0
4,1,0,55,1
5,1,0,4,0
6,1,0,20,0
7,1,0,8,0
8,1,0,251,0
9,1,0,251,1


In [13]:
%%time
# Guard: refuse to write the file if any value in the frame is missing.
assert not df.isna().any(axis=None)
# Persist the pairs; the index carries no information, so it is not written.
df.to_parquet("output/pairs.parquet", index=False)

Wall time: 2.89 s


In [14]:
# Stop the timer started in the setup cell and report the total runtime.
tim.stop()
print(f"Total time taken {tim.elapsed!s}")

Total time taken 0:38:41.824343
