In [1]:
import numpy as np

def link_prediction_split(graph_file, files, portions):
    """
    Divide a normal graph into a train split and several test splits for link prediction use.
    Each test split contains half true and half false edges.

    Every edge of the graph file is randomly assigned to one output file with
    probability proportional to its portion. Lines that go to the train file are
    copied verbatim; lines that go to a test file are written as
    ``u<TAB>v<TAB>1``. For every true edge in a test file, one false edge
    ``u<TAB>v<TAB>0`` is appended, where u and v are distinct nodes that are not
    linked in either direction in the graph file.

    Lines of the graph file with fewer than two whitespace-separated fields
    (e.g. blank lines) are skipped.

    Note: this function reseeds NumPy's global random generator with 1024.

    Parameters:
        graph_file (str): graph file
        files (list of str): file names,
            the first file is treated as train file
        portions (list of float): split portions

    Raises:
        ValueError: if a test split needs false edges but every pair of
            distinct nodes is already linked, so none can be sampled
    """
    assert len(files) == len(portions), "files and portions must have the same length"
    print(f"splitting graph {graph_file} into {', '.join(files)}")
    np.random.seed(1024)

    nodes = set()
    edges = set()
    portions = np.cumsum(portions, dtype=np.float32) / np.sum(portions)
    last = len(files) - 1
    num_edges = [0] * len(files)
    handles = []
    try:
        # Open one by one so that a failure half-way still closes the earlier files.
        for name in files:
            handles.append(open(name, "w"))

        with open(graph_file, "r") as fin:
            for line in fin:
                fields = line.split()
                if len(fields) < 2:
                    # Blank or malformed line: nothing to split.
                    continue
                u, v = fields[:2]
                nodes.update([u, v])
                edges.add((u, v))
                # Clamp: rounding may leave the last cumulative portion slightly
                # below 1.0, which would otherwise yield an out-of-range index.
                i = min(int(np.searchsorted(portions, np.random.rand())), last)
                if i == 0:
                    handles[i].write(line)
                else:
                    handles[i].write(f"{u}\t{v}\t1\n")
                num_edges[i] += 1

        # Sort so that the sampled false edges do not depend on the per-process
        # hash ordering of a set of strings.
        nodes = sorted(nodes)

        if any(num_edges[1:]):
            # Count each linked unordered pair once (self-loops excluded) and make
            # sure at least one non-edge exists; otherwise the rejection sampling
            # below would loop forever.
            linked_pairs = sum(
                1 for a, b in edges
                if a != b and (a < b or (b, a) not in edges)
            )
            if linked_pairs >= len(nodes) * (len(nodes) - 1) // 2:
                raise ValueError(
                    f"cannot sample false edges from {graph_file}: "
                    "every pair of distinct nodes is already linked"
                )

        for file, num_edge in zip(handles[1:], num_edges[1:]):
            for _ in range(num_edge):
                valid = False
                while not valid:
                    u = nodes[int(np.random.rand() * len(nodes))]
                    v = nodes[int(np.random.rand() * len(nodes))]
                    valid = u != v and (u, v) not in edges and (v, u) not in edges
                file.write(f"{u}\t{v}\t0\n")
    finally:
        for handle in handles:
            handle.close()


In [2]:
# Split data/full.txt with portions 8:2 — each edge goes to train8.txt with
# probability 0.8, otherwise to test2.txt (which also receives sampled false edges).
link_prediction_split(
    graph_file='data/full.txt',
    files=['train8.txt', 'test2.txt'],
    portions=[8, 2],
)

splitting graph data/full.txt into train8.txt, test2.txt
