In [2]:
import numpy as np
import random
import pickle
from collections import deque


# ===============================
# Global seed (CRITICAL)
# ===============================
def set_global_seed(seed):
    """Seed Python's ``random`` module and NumPy's legacy global RNG together.

    Parameters
    ----------
    seed : int
        Seed passed unchanged to both ``random.seed`` and ``np.random.seed``.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


# ===============================
# Ensure connectivity
# ===============================
def is_connected(A):
    """Return True if every node is reachable from node 0 in adjacency matrix ``A``.

    Performs a breadth-first search starting at node 0, following edges along
    row ``u`` of ``A``. Any nonzero entry ``A[u, v]`` counts as an edge. For a
    symmetric (undirected) matrix this is exactly a connectivity test.

    Parameters
    ----------
    A : np.ndarray of shape (n, n)
        Square adjacency matrix.

    Returns
    -------
    bool
        True when all nodes were reached. An empty graph (n == 0) is treated
        as vacuously connected; previously it raised IndexError.
    """
    n = A.shape[0]
    if n == 0:
        return True

    visited = np.zeros(n, dtype=bool)
    visited[0] = True
    queue = deque([0])

    while queue:
        u = queue.popleft()
        # All not-yet-visited neighbours of u, in ascending index order
        # (same discovery order as a v = 0..n-1 scan).
        new = np.flatnonzero((A[u] != 0) & ~visited)
        visited[new] = True
        queue.extend(new.tolist())

    return bool(visited.all())


# ===============================
# Random connected graph
# ===============================
def random_connected_graph(n, p=0.3):
    """Sample a connected random undirected graph by rejection sampling.

    Each node pair (i, j) with i < j gets an edge independently with
    probability ``p``, using Python's global ``random`` module. The whole
    graph is redrawn until ``is_connected`` accepts it.

    Parameters
    ----------
    n : int
        Number of nodes; must be >= 1.
    p : float, optional
        Edge probability. It must be > 0 when n >= 2; otherwise no connected
        graph can ever be drawn and the loop would never end.

    Returns
    -------
    A : np.ndarray, shape (n, n), dtype int8
        Symmetric 0/1 adjacency matrix with a zero diagonal.
    T : np.ndarray, shape (n,), dtype int8
        Array filled with -1. It is presumably a per-node placeholder for
        downstream code (TODO confirm its meaning with the consumer).

    Raises
    ------
    ValueError
        If n < 1, or if n >= 2 and p is not > 0 (including NaN).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # `not p > 0` also rejects NaN, for which `random() < p` is always False.
    if n >= 2 and not p > 0:
        raise ValueError(f"p must be > 0 to connect {n} nodes, got {p}")

    while True:
        A = np.zeros((n, n), dtype=np.int8)

        # Random edges. Keep this i < j draw order unchanged: the seeded
        # dataset is reproducible only while the random() call sequence is.
        for i in range(n):
            for j in range(i + 1, n):
                if random.random() < p:
                    A[i, j] = 1
                    A[j, i] = 1

        if is_connected(A):
            break

    T = np.full(n, -1, dtype=np.int8)
    return A, T


# ===============================
# Dataset generation
# ===============================
def generate_graph_dataset(
    seed=42,
    n_train=5,
    n_test=5,
    train_sizes=(8, 12),
    test_sizes=(10, 14),
    p=0.3,
):
    """Build a seeded train/test collection of random connected graphs.

    The defaults reproduce the original fixed dataset exactly. The RNG is
    seeded once, then all training graphs are drawn, then all test graphs.
    For each graph, the node count is drawn before the edges.

    Parameters
    ----------
    seed : int, optional
        Seed passed to ``set_global_seed`` before any sampling.
    n_train, n_test : int, optional
        Number of graphs in each split.
    train_sizes, test_sizes : tuple[int, int], optional
        Inclusive (low, high) node-count range for each split, sampled with
        ``random.randint``.
    p : float, optional
        Edge probability forwarded to ``random_connected_graph``.

    Returns
    -------
    dict
        ``{"train": [...], "test": [...]}``. Each list holds the ``(A, T)``
        tuples returned by ``random_connected_graph``.
    """
    set_global_seed(seed)

    def _sample_split(count, sizes):
        # Draw `count` graphs; randint for the size precedes each graph's edges.
        low, high = sizes
        graphs = []
        for _ in range(count):
            n = random.randint(low, high)
            graphs.append(random_connected_graph(n, p))
        return graphs

    # Train must be drawn before test to keep the seeded sequence stable.
    train_graphs = _sample_split(n_train, train_sizes)
    test_graphs = _sample_split(n_test, test_sizes)

    return {
        "train": train_graphs,
        "test": test_graphs,
    }


# ===============================
# Main
# ===============================
if __name__ == "__main__":
    graphs = generate_graph_dataset()

    # Persist the whole {"train": [...], "test": [...]} dict as one pickle.
    with open("fixed_graph_dataset.pkl", "wb") as fh:
        pickle.dump(graphs, fh)
    print("Saved fixed_graph_dataset.pkl")

    # Quick sanity check: node and undirected-edge counts for each graph.
    print("\nDataset summary:")
    for split_name in ("train", "test"):
        print(f"\n{split_name.upper()}:")
        for idx, (adj, _) in enumerate(graphs[split_name]):
            n_nodes = adj.shape[0]
            n_edges = adj.sum() // 2  # symmetric matrix stores each edge twice
            print(f"  Graph {idx}: n={n_nodes}, edges={n_edges}")


Saved fixed_graph_dataset.pkl

Dataset summary:

TRAIN:
  Graph 0: n=8, edges=13
  Graph 1: n=12, edges=20
  Graph 2: n=12, edges=21
  Graph 3: n=9, edges=13
  Graph 4: n=12, edges=21

TEST:
  Graph 0: n=12, edges=22
  Graph 1: n=10, edges=12
  Graph 2: n=10, edges=13
  Graph 3: n=11, edges=14
  Graph 4: n=13, edges=26
