In [1]:
import shlex
import time
from pathlib import Path

import roach
from relbench.base import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task, get_task_names

In [2]:
# Roach project name; the GPU and CPU queue names are derived from it so the
# three identifiers cannot drift apart.
project = "relbench/2024-07-29"
queue_gpu = f"{project}_gpu"
queue_cpu = f"{project}_cpu"

In [3]:
# Task key -> task-file path. `join` looks paths up here; nothing in the
# cells shown populates it, so it is presumably filled in elsewhere — confirm.
task_files: dict = dict()

In [4]:
def join(task_keys, files=None):
    """Build a shell condition that checks every given task's ``/done/`` file.

    Each key is looked up in the task-file mapping. Its path has ``/ready/``
    rewritten to ``/done/`` and is checked with ``test -f``. All checks are
    chained with ``&&``. (Presumably roach moves a finished task's file from
    ``ready/`` to ``done/`` — confirm against the roach layout.)

    Args:
        task_keys: Iterable of keys into the task-file mapping.
        files: Optional mapping of task key -> task-file path. ``None`` (the
            default) uses the module-level ``task_files`` dict, so existing
            ``join(keys)`` calls behave as before.

    Returns:
        The combined condition, e.g.
        ``"test -f /q/done/a && test -f /q/done/b"``, or ``""`` when
        ``task_keys`` is empty.

    Raises:
        KeyError: If a key is not present in the mapping.
    """
    if files is None:
        files = task_files
    tests = []
    for task_key in task_keys:
        done_file = files[task_key].replace("/ready/", "/done/")
        # Quote so a path with spaces/metacharacters stays a single shell
        # argument; shlex.quote returns ordinary paths unchanged.
        tests.append(f"test -f {shlex.quote(done_file)}")
    return " && ".join(tests)

# Training runs: node-level tasks only (link prediction is skipped)

In [5]:
# Sweep settings.
NUM_SEEDS = 5
OMP_THREADS = 8  # exported to every job as OMP_NUM_THREADS


def make_cmd(script, dataset, task, seed, *extra_args):
    """Return the shell command that runs ``{script}.py`` for one (dataset, task, seed).

    Args:
        script: Training script name without the ``.py`` suffix.
        dataset: RelBench dataset name.
        task: Task name within ``dataset``.
        seed: Seed forwarded via ``--seed``.
        *extra_args: Additional command-line tokens, appended verbatim.

    Returns:
        The space-joined command string; also tags the run with the global
        ``project`` via ``--roach_project``.
    """
    parts = [
        f"OMP_NUM_THREADS={OMP_THREADS}",
        f"python {script}.py",
        f"--dataset {dataset}",
        f"--task {task}",
        f"--seed {seed}",
        f"--roach_project {project}",
        *extra_args,
    ]
    return " ".join(parts)


# Enumerate all non-link-prediction tasks once: the list is independent of the
# seed, and any failure here happens before a single job has been submitted.
node_tasks = [
    (dataset, task)
    for dataset in get_dataset_names()
    for task in get_task_names(dataset)
    if get_task(dataset, task).task_type.value != TaskType.LINK_PREDICTION.value
]

for seed in range(NUM_SEEDS):
    for dataset, task in node_tasks:
        # gnn_node.py on the GPU queue, once per --include_label_tables setting.
        for include_label_tables in ["all", "task_only", "none"]:
            cmd = make_cmd(
                "gnn_node",
                dataset,
                task,
                seed,
                f"--include_label_tables {include_label_tables}",
            )
            roach.submit(queue_gpu, cmd)

        # lightgbm_node.py on the CPU queue, with and without --use_ar_label.
        for use_ar_label_flag in ["--use_ar_label", "--no-use_ar_label"]:
            cmd = make_cmd("lightgbm_node", dataset, task, seed, use_ar_label_flag)
            roach.submit(queue_cpu, cmd)