In [1]:
from ogb.linkproppred import PygLinkPropPredDataset
import pandas as pd
from tqdm import trange, tqdm
import numpy as np
import torch
import sys
from collections import defaultdict
sys.path.append('/mnt/nfs/zhangtl/utils/')
from util import myout
import pickle as pkl
import json
import dgl.function as fn

import dgl



## load

In [2]:
dataset = PygLinkPropPredDataset(name='ogbl-citation2', root='../../../dataset/ogbl-citation2')
raw_graph = dataset[0]
raw_graph

Data(num_nodes=2927963, edge_index=[2, 30387995], x=[2927963, 128], node_year=[2927963, 1])

In [3]:
raw_graph.node_year.unique()

tensor([1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1912,
        1913, 1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924,
        1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936,
        1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948,
        1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960,
        1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972,
        1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984,
        1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996,
        1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
        2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019])

In [4]:
dgl_graph = dgl.graph((raw_graph.edge_index[0, :], raw_graph.edge_index[1, :]), num_nodes=int(raw_graph.num_nodes))
dgl_graph.ndata['feat'] = raw_graph.x
dgl_graph.ndata['raw_nid'] = torch.arange(len(raw_graph.x))

dgl_graph.edata['ts'] = raw_graph.node_year[raw_graph.edge_index[0,:]].squeeze(1)
dgl_graph

Graph(num_nodes=2927963, num_edges=30387995,
      ndata_schemes={'feat': Scheme(shape=(128,), dtype=torch.float32), 'raw_nid': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'ts': Scheme(shape=(), dtype=torch.int64)})

## to csv

In [5]:
start_year, end_year = 2000, 2020
ts_eids = dgl_graph.filter_edges(lambda x: (x.data['ts']>=start_year) & (x.data['ts']<end_year))
ts_graph = dgl.edge_subgraph(dgl_graph, ts_eids)
ts_graph

Graph(num_nodes=2469122, num_edges=24822197,
      ndata_schemes={'feat': Scheme(shape=(128,), dtype=torch.float32), 'raw_nid': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'ts': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)})

In [6]:
df = pd.DataFrame({'source': ts_graph.edges()[0].tolist(), 'target': ts_graph.edges()[1].tolist(), \
    'ts': ts_graph.edata['ts'].tolist()})
df = df.sort_values(by=['ts', 'source'])
df

Unnamed: 0,source,target,ts
1806,99,1850904,2000
1807,99,1850916,2000
1808,99,1850917,2000
1809,99,1850918,2000
1810,99,1850919,2000
...,...,...,...
24822171,1850556,1850568,2019
24822172,1850556,1850557,2019
24822173,1850557,1850568,2019
24822183,1850567,1850556,2019


## build graph

In [7]:
def update_idx(idx, dic, cnt, feats, no_emb):
    if idx not in dic:
        dic[idx] = cnt
        cnt += 1
        feats.append(ts_graph.ndata['feat'][idx, :].to(torch.float32))
    return dic, cnt, feats, no_emb

In [8]:
id2nid, cnt, no_emb = {}, 0, 0
lst, feats = [], []

for ii in trange(len(df)):
    year = int(df['ts'].iloc[ii])
    source = int(df['source'].iloc[ii])
    target = int(df['target'].iloc[ii])
    
    id2nid, cnt, feats, no_emb = update_idx(source, id2nid, cnt, feats, no_emb)
    id2nid, cnt, feats, no_emb = update_idx(target, id2nid, cnt, feats, no_emb)
    
    lst.append((id2nid[source], id2nid[target], year))
    
feat = torch.stack(feats)
src = torch.tensor([item[0] for item in lst])
tgt = torch.tensor([item[1] for item in lst])
tsp = torch.tensor([item[2] for item in lst])
# tsp = torch.tensor([item[3] for item in lst])

myout(feat, src, tgt, tsp, id2nid, no_emb)

100%|██████████| 24822197/24822197 [09:50<00:00, 42012.98it/s]


feat : shape=torch.Size([2469122, 128])
tensor([[-0.0484,  0.1094, -0.0697,  ..., -0.2037,  0.0151, -0.0495],
        [ 0.0445, -0.2518, -0.2032,  ...,  0.0807, -0.2814, -0.3711],
        [-0.1243, -0.1694, -0.0672,  ...,  0.1012,  0.0278, -0.0913],
        ...,
        [-0.2082, -0.0208, -0.2475,  ...,  0.1836, -0.0607, -0.2489],
        [-0.1309, -0.0508, -0.2470,  ...,  0.1475, -0.1363, -0.1227],
        [-0.1423, -0.2114, -0.1148,  ...,  0.1184,  0.0162, -0.3453]])
src : shape=torch.Size([24822197]), tensor([      0,       0,       0,  ..., 2469121, 2469119, 2469120])
tgt : shape=torch.Size([24822197]), tensor([      1,       2,       3,  ..., 2469120, 2469118, 2469118])
tsp : shape=torch.Size([24822197]), tensor([2000, 2000, 2000,  ..., 2019, 2019, 2019])
id2nid : len=2469122, dict([99: 0, 1850904: 1, 1850916: 2, 1850917: 3, 1850918: 4, 1850919: 5, ...])
no_emb = 0


In [9]:
graph = dgl.graph((src, tgt), num_nodes=len(feat))
graph.ndata['feat'] = feat

nid2id = {vv: kk for kk, vv in id2nid.items()}
graph.ndata['raw_nid'] = torch.arange(len(feat))

# graph.edata['rel'] = rel
graph.edata['ts'] = tsp
graph

Graph(num_nodes=2469122, num_edges=24822197,
      ndata_schemes={'feat': Scheme(shape=(128,), dtype=torch.float32), 'raw_nid': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'ts': Scheme(shape=(), dtype=torch.int64)})

In [10]:
dataset = 'ogb'
dgl.save_graphs(f'../data/{dataset}/graph.bin', [graph])

In [11]:
json.dump(id2nid, open(f'../data/{dataset}/id2nid.json', 'w'))

## gen cites

In [13]:
cites = {}
for year in range(start_year, end_year):
    cites[year] = defaultdict(int)

for ii in trange(len(df)):
    year = int(df['ts'].iloc[ii])
    target = int(df['target'].iloc[ii])
    cites[year][target] += 1
myout(cites[2005])

100%|██████████| 24822197/24822197 [05:53<00:00, 70215.87it/s]

 : len=313384, dict([801138: 1, 1850626: 5, 1850627: 3, 1850635: 1, 1850636: 1, 829492: 2, ...])





In [14]:
tsp = graph.edata['ts']
ts_vals, ts_cuts = np.unique(tsp.numpy(), return_index=True)
ts_cuts = list(ts_cuts) + [len(tsp.numpy())]

num_ts = len(ts_vals)
ts_infos = np.stack([ts_vals, ts_cuts[0:num_ts], ts_cuts[1:num_ts+1]]).transpose()
myout(ts_cuts, ts_vals, ts_infos)

ts_cuts : len=21, list([0, 376616, 814976, ..., 20964706, 22904816, 24822197])
ts_vals : shape=(20,), [2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013
 2014 2015 2016 2017 2018 2019]
ts_infos : shape=(20, 3)
[[    2000        0   376616]
 [    2001   376616   814976]
 [    2002   814976  1333195]
 [    2003  1333195  1938661]
 [    2004  1938661  2619256]
 [    2005  2619256  3396342]
 [    2006  3396342  4303030]
 [    2007  4303030  5303742]
 [    2008  5303742  6404886]
 [    2009  6404886  7593859]
 [    2010  7593859  8864377]
 [    2011  8864377 10274836]
 [    2012 10274836 11857173]
 [    2013 11857173 13563194]
 [    2014 13563194 15299791]
 [    2015 15299791 17116622]
 [    2016 17116622 19016770]
 [    2017 19016770 20964706]
 [    2018 20964706 22904816]
 [    2019 22904816 24822197]]


In [15]:
labels = {}
nid2id = {v:k for k,v in id2nid.items()}
for year in range(start_year, end_year):
    left, right = ts_infos[np.where(ts_infos[:, 0]==year)[0][0], 1:]
    nids = graph.edges()[0][left:right].unique().tolist()
    ids = [nid2id[nid] for nid in nids]
    
    pdf = pd.DataFrame({'id': ids, 'nid': nids})
    tbar = trange(year+1, end_year, desc=str(year))
    for yy in tbar:
        cdf = pd.DataFrame({'id': list(cites[yy].keys()), str(yy): list(cites[yy].values())})
        cdf[str(yy)] = cdf[str(yy)].astype('float32')
        
        pdf = pd.merge(pdf, cdf, how='left', on='id')
        tbar.set_postfix(year=year, pdf=len(pdf))
    pdf.fillna(0, inplace=True)
    labels[year] = pdf

2000: 100%|██████████| 19/19 [00:04<00:00,  4.21it/s, pdf=51919, year=2000]
2001: 100%|██████████| 18/18 [00:04<00:00,  4.38it/s, pdf=54952, year=2001]
2002: 100%|██████████| 17/17 [00:03<00:00,  5.11it/s, pdf=61443, year=2002]
2003: 100%|██████████| 16/16 [00:03<00:00,  4.01it/s, pdf=67807, year=2003]
2004: 100%|██████████| 15/15 [00:03<00:00,  3.97it/s, pdf=73108, year=2004]
2005: 100%|██████████| 14/14 [00:03<00:00,  3.95it/s, pdf=79071, year=2005]
2006: 100%|██████████| 13/13 [00:03<00:00,  3.59it/s, pdf=87737, year=2006]
2007: 100%|██████████| 12/12 [00:03<00:00,  3.69it/s, pdf=92101, year=2007]
2008: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s, pdf=95860, year=2008]
2009: 100%|██████████| 10/10 [00:02<00:00,  3.88it/s, pdf=96028, year=2009]
2010: 100%|██████████| 9/9 [00:02<00:00,  3.45it/s, pdf=98054, year=2010]
2011: 100%|██████████| 8/8 [00:02<00:00,  3.42it/s, pdf=104023, year=2011]
2012: 100%|██████████| 7/7 [00:02<00:00,  3.11it/s, pdf=109187, year=2012]
2013: 100%|█████

In [16]:
labels[2005]

Unnamed: 0,id,nid,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019
0,11,698329,1.0,3.0,2.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0
1,1457402,698331,0.0,4.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,2.0,0.0,0.0,0.0,0.0
2,27,698332,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,48,698333,1.0,1.0,0.0,1.0,0.0,1.0,1.0,2.0,2.0,0.0,0.0,0.0,0.0,0.0
4,53,698335,0.0,1.0,2.0,0.0,2.0,1.0,4.0,3.0,1.0,1.0,1.0,2.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79066,1850488,806331,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
79067,1850563,806332,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
79068,1850565,806333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
79069,1850578,806334,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [17]:
pkl.dump(labels, open(f'../data/{dataset}/labels.pkl', 'wb'))

## cum log

In [18]:
def cumulative_log(df):
    colsn = list(df.columns)
    for i in range(3, len(colsn)):
        df[colsn[i]] += df[colsn[i-1]]
    df.iloc[:, 2:] = np.log(df.iloc[:, 2:] + 1)
    return df

labels_cum_log = {}
for year in range(start_year, end_year-1):
    labels_cum_log[year] = cumulative_log(labels[year])
labels_cum_log[end_year-2] = labels[end_year-2]
print(len(labels_cum_log))
labels_cum_log[2005]

19


Unnamed: 0,id,nid,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019
0,11,698329,0.693147,1.609438,1.945910,1.945910,2.079442,2.197225,2.197225,2.197225,2.197225,2.302585,2.397895,2.397895,2.397895,2.397895
1,1457402,698331,0.000000,1.609438,1.609438,1.609438,1.609438,1.609438,1.791759,1.945910,1.945910,2.197225,2.197225,2.197225,2.197225,2.197225
2,27,698332,0.000000,0.693147,0.693147,0.693147,0.693147,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612
3,48,698333,0.693147,1.098612,1.098612,1.386294,1.386294,1.609438,1.791759,2.079442,2.302585,2.302585,2.302585,2.302585,2.302585,2.302585
4,53,698335,0.000000,0.693147,1.386294,1.386294,1.791759,1.945910,2.397895,2.639057,2.708050,2.772589,2.833213,2.944439,2.995732,3.044523
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79066,1850488,806331,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
79067,1850563,806332,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
79068,1850565,806333,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
79069,1850578,806334,0.693147,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612,1.098612


In [None]:
pkl.dump(labels_cum_log, open(f'../data/{dataset}/labels_cum_log.pkl', 'wb'))