In [None]:
import os

# `!export VAR=...` runs in a short-lived child shell, so the variable never
# reaches this kernel's process environment (nor any worker processes started
# from it later). Set the variables on os.environ instead, and do it *before*
# importing the RAPIDS-based benchmark helpers, in case those (or the libraries
# they pull in) read the variables at import time.
os.environ.update({
    "RAPIDS_NO_INITIALIZE": "1",
    "CUDF_SPILL": "1",
    "LIBCUDF_CUFILE_POLICY": "OFF",
})

from cugraph_bulk_sampling import start_dask_client, benchmark_cugraph_bulk_sampling, load_disk_dataset, construct_graph
from cugraph_bulk_sampling import sample_graph

# Setup Cluster

In [None]:
# Comma-separated GPU device ids (0 through 7) handed to start_dask_client.
dask_worker_devices = ','.join(str(device_id) for device_id in range(8))

In [None]:
# Start the Dask cluster (and a client connected to it) on the devices above.
# NOTE(review): rmm_pool_size is presumably a per-device pool size in bytes
# (28e9 ~= 28 GB) — confirm against start_dask_client.
cluster_options = dict(
    dask_worker_devices=dask_worker_devices,
    jit_unspill=False,
    rmm_pool_size=28e9,
    rmm_async=True,
)
client, cluster = start_dask_client(**cluster_options)

# Setup Benchmark

In [None]:
# --- Dataset selection and I/O locations ---
dataset = 'ogbn_papers100M'
dataset_root = "."
output_root = "."
dataset_dir = dataset_root   # passed to load_disk_dataset
output_path = output_root    # root under which the output directory is created

# --- Graph loading options ---
reverse_edges = True
add_edge_types = False
replication_factor = 2
persist = False

# --- Sampling parameters ---
batch_size = 512
seeds_per_call = 2**19       # 524,288
fanout = [25, 25]
seed = 123


In [None]:
# Load the edge list and label frame from disk, then build the graph from the
# edge list. `persist` comes from the config cell (it was previously hard-coded
# to False here), so toggling it there applies to loading as well as sampling.
dask_edgelist_df, dask_label_df, node_offsets, edge_offsets, total_num_nodes = \
    load_disk_dataset(
        dataset,
        dataset_dir=dataset_dir,
        reverse_edges=reverse_edges,
        replication_factor=replication_factor,
        persist=persist,
        add_edge_types=add_edge_types,
    )

num_input_edges = len(dask_edgelist_df)
print(f"Number of input edges = {num_input_edges:,}")

G = construct_graph(dask_edgelist_df)
# Drop this cell's reference to the raw edge list so it can be reclaimed once
# nothing else holds it; the next cell reads edges back via G.edgelist.
del dask_edgelist_df
print('constructed graph')

In [None]:
# Report the memory footprint of the graph's edge list.
input_memory = G.edgelist.edgelist_df.memory_usage().sum().compute()
print(f'input memory: {input_memory}')

# Output layout: <output_path>/<dataset>[<replication>]_b<batch>_f<fanout>/samples
output_subdir = os.path.join(output_path, f'{dataset}[{replication_factor}]_b{batch_size}_f{fanout}')
output_sample_path = os.path.join(output_subdir, 'samples')
# makedirs creates missing parents too, so this also creates output_subdir.
os.makedirs(output_sample_path, exist_ok=True)

# NOTE(review): presumably the target number of seeds written per output
# partition — confirm against sample_graph.
target_seeds_per_partition = 200_000
# Clamp to at least one batch: plain integer division yields 0 whenever
# batch_size > target_seeds_per_partition.
batches_per_partition = max(1, target_seeds_per_partition // batch_size)

# Benchmark Graph Sampling

In [None]:
# Run the sampling benchmark once. To time repeated runs instead, put
# `%%timeit -n10 -r1` on the first line of this cell (cell magics must come first).
# NOTE(review): output is presumably written under output_sample_path — confirm.
sampling_options = dict(
    seed=seed,
    batch_size=batch_size,
    seeds_per_call=seeds_per_call,
    batches_per_partition=batches_per_partition,
    fanout=fanout,
    persist=persist,
)
execution_time, allocation_counts = sample_graph(
    G, dask_label_df, output_sample_path, **sampling_options
)
