In [1]:
!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
!unzip ml-1m.zip

--2021-02-16 09:12:08--  http://files.grouplens.org/datasets/movielens/ml-1m.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5917549 (5.6M) [application/zip]
Saving to: ‘ml-1m.zip’


2021-02-16 09:12:09 (7.00 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]

Archive:  ml-1m.zip
   creating: ml-1m/
  inflating: ml-1m/movies.dat        
  inflating: ml-1m/ratings.dat       
  inflating: ml-1m/README            
  inflating: ml-1m/users.dat         


In [2]:
import pandas as pd
import numpy as np
import dgl
import torch as th
import time

Using backend: pytorch


In [3]:
data = pd.read_csv('ml-1m/ratings.dat', sep='::', engine='python',
                   names=['src', 'dst', 'rating', 'ts'])

In [4]:
src = np.array(data['src'])
dst = np.array(data['dst'])
max_src_id = np.max(src)
dst = dst + max_src_id + 1
g = dgl.graph((src, dst))
g = dgl.to_bidirected(g)

In [5]:
num_walks = 1000
num_seeds = 1000
seeds = np.random.choice(max_src_id + 1, num_seeds, replace=False)
seeds = np.repeat(seeds, num_walks)

In [6]:
start = time.time()
paths = dgl.sampling.random_walk(g, seeds, length=6)[0]
paths = paths.reshape(num_seeds, num_walks, -1)
flat_paths = paths[:,:,1:].reshape(num_seeds, -1)
print('random walk: {:.4f}s'.format(time.time() - start))

k = 10
topk_eles = []
for row in flat_paths:
    uniq_eles, cnts = np.unique(row, return_counts=True)
    idx = np.argsort(cnts)
    uniq_eles, cnts = uniq_eles[idx], cnts[idx]
    topk_eles.append((uniq_eles[-k:], cnts[-k:]))
print('get topK: {:.4f}s'.format(time.time() - start))

random walk: 1.0646s
get topK: 1.5455s


In [8]:
print('seeds:', seeds)
for eles in topk_eles:
    print('topK elements:', eles[0])
    print('frequency:', eles[1])

seeds: [ 561  561  561 ... 3836 3836 3836]
topK elements: [9231 9664 8803 9522 8751 9938 9449 9657 9314 6689]
frequency: [33 34 35 35 35 35 37 40 41 42]
topK elements: [7271 6042 6960 7311 6649 9038 8052 6337 7658 7237]
frequency: [14 14 14 15 15 15 15 16 17 18]
topK elements: [6145 6048 8474 8437 9796 9196 7526 7638 7434 6397]
frequency: [28 29 29 29 30 30 31 33 37 44]
topK elements: [9792 8439 9937 9795 7687 9834 9230 9952 9449 6568]
frequency: [23 24 24 25 26 28 29 29 30 31]
topK elements: [7254 9155 7340 9459 8741 8757 7306 8845 7120 7770]
frequency: [45 47 47 48 48 49 49 52 54 58]
topK elements: [7520 7825 1680 7649 7135 8803 8899 8437 7658 6337]
frequency: [10 10 10 10 11 11 12 12 12 15]
topK elements: [8897 8437 6649 8838 9462 7237 7345 9072 8899 7266]
frequency: [27 27 27 28 29 29 29 31 31 32]
topK elements: [6976 6941 6359 8958 9569 6531 8121 6940 9240 9590]
frequency: [48 49 50 52 52 52 57 59 61 63]
topK elements: [6992 8436 6964 6951 6994 6949 7334 7306 7237 9108]
frequency:

frequency: [38 39 39 39 41 43 45 45 46 49]
topK elements: [8053 9134 9046 6634 6953 6960 6649 6152 6630 8899]
frequency: [16 16 16 16 17 17 17 19 20 25]
topK elements: [7432 7255 8957 6483 7239 2909 8041 7251 6521 8669]
frequency: [ 9  9  9  9  9  9 10 10 11 12]
topK elements: [7621 6337 9596 6418 7077 7261 6047 6634 7332 7428]
frequency: [12 12 12 14 14 14 14 15 16 16]
topK elements: [8069 9188 6580 6151 6264 8612 7168 7964 8899 9427]
frequency: [15 15 16 16 17 17 17 17 18 19]
topK elements: [9989 9785 9449 8724 9862 9605 9612 9796 9038 9834]
frequency: [20 20 21 21 21 21 21 23 26 31]
topK elements: [6151 6042 6945 9200 7076 8561 8437 7322 6521 6899]
frequency: [37 37 38 39 39 39 39 42 42 46]
topK elements: [6960 8757 7189 7237 7239 6945 6521 7345 8899 8957]
frequency: [14 14 14 14 14 15 15 15 16 17]
topK elements: [6301 8913 7255 9540 8658 7138 7256 6151 7319 8565]
frequency: [17 17 17 18 18 19 19 19 20 22]
topK elements: [9038 6953 7189 7019 7317 8436 7993 7245 8899 7234]
frequency: