# Summary

# Imports

In [None]:
import os
import runpy
import shutil
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import yaml

In [None]:
# Use the fully qualified option name: pandas resolves option names by pattern matching,
# and the bare "max_columns" pattern is ambiguous (OptionError) in pandas versions that
# also define "styler.render.max_columns".
pd.set_option("display.max_columns", 100)

# Parameters

In [None]:
# Relative path identifying this notebook; its `name` is the default output directory name.
NOTEBOOK_PATH = Path("train_network")
NOTEBOOK_PATH

In [None]:
# Results go to $OUTPUT_DIR when it is set, otherwise to a directory named after the
# notebook (relative to the current working directory); either way the path is made absolute.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
# Create the directory (and any missing parents) up front; an existing one is reused.
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Debug mode is on whenever the CI environment variable is absent.
DEBUG = not ("CI" in os.environ)
# Refuse to run in debug mode from inside a SLURM job (SLURM_JOB_ID is set there).
assert not DEBUG or "SLURM_JOB_ID" not in os.environ
DEBUG

# Train network

In [None]:
# Root directory under which the training dataset is looked up below.
# DATAPKG_OUTPUT_DIR is mandatory: a KeyError is raised here if it is not set.
DATAPKG_OUTPUT_PATH = Path(os.environ['DATAPKG_OUTPUT_DIR'])
DATAPKG_OUTPUT_PATH

In [None]:
def parse_slurm_runtime(sbatch_timelimit: str) -> float:
    """Convert a SLURM time-limit string into a number of seconds.

    The formats accepted by ``sbatch --time`` are supported: ``minutes``,
    ``minutes:seconds``, ``hours:minutes:seconds``, ``days-hours``,
    ``days-hours:minutes`` and ``days-hours:minutes:seconds``. Note that,
    following SLURM, a lone number is interpreted as minutes. Empty
    colon-separated fields count as zero and an empty string yields ``0.0``.

    Args:
        sbatch_timelimit: Time limit string, e.g. the value of the
            ``SBATCH_TIMELIMIT`` environment variable.

    Returns:
        The time limit in seconds, as a float.

    Raises:
        ValueError: If a field is not an integer, or if there are more than
            three colon-separated fields.
    """
    timelimit = sbatch_timelimit.strip()
    if not timelimit:
        return 0.0

    if "-" in timelimit:
        days, _, clock = timelimit.partition("-")
        seconds = int(days) * 86400
        fields = clock.split(":")
        # With a day prefix the fields are read from the left: hours come first.
        units = (3600, 60, 1)
    else:
        seconds = 0
        fields = timelimit.split(":")
        # Without a day prefix a lone number means minutes; two or three
        # fields are right-aligned so that the last one is seconds.
        units = (60,) if len(fields) == 1 else (3600, 60, 1)[-len(fields):]

    if len(fields) > 3:
        raise ValueError(f"Unrecognized SLURM time limit: {sbatch_timelimit!r}")

    for field, unit in zip(fields, units):
        if field:
            seconds += int(field) * unit
    return float(seconds)


parse_slurm_runtime("1:20:11")

In [None]:
# Time budget in seconds that is later assigned to `args.runtime`. That field is asserted
# to be a float before being overwritten, so both branches produce a float here.
if DEBUG:
    runtime = 60.0  # 1m
else:
    runtime = float(parse_slurm_runtime(os.environ['SBATCH_TIMELIMIT'])) - 1800  # Total time - 30m
runtime

In [None]:
from pagnn.training.dcn import Args, main

In [None]:
# Build the arguments for `main`; all results are rooted at OUTPUT_PATH and the network
# is named after the output directory.
args = Args(
    root_path=OUTPUT_PATH,
    training_data_path=(
        DATAPKG_OUTPUT_PATH
        / "adjacency-net-v2"
        / "master"
        / "training_dataset"
        / "adjacency_matrix.parquet"
    ),
    # Disabled: per-array training data cache.
#     training_data_cache=(
#         DATAPKG_OUTPUT_PATH
#         .joinpath("adjacency-net-v2", "master", "training_dataset", f"array_id_{Args().array_id}")
#     ),
    gpu=-1,
    verbosity=1,
    network_name=f"DCN_{OUTPUT_PATH.name}",
    num_negative_examples=63,
)

# Check that the `runtime` field is a float before replacing it with the budget
# computed earlier in the notebook.
assert isinstance(args.runtime, float)
args.runtime = runtime

args

In [None]:
# Store the arguments of this run as block-style YAML in the output directory.
args_file = OUTPUT_PATH / "args.yaml"

with args_file.open(mode="wt") as fout:
    yaml.dump(args.to_dict(), stream=fout, default_flow_style=False)

In [None]:
# Copy ../src/model.py into the output directory, renaming the `Custom` class and the
# line that attaches it to `pagnn.models.dcn` so that both use this run's network name.
model_file = OUTPUT_PATH / "model.py"

# Exact template line -> replacement line.
substitutions = {
    "class Custom(nn.Module):\n": f"class {args.network_name}(nn.Module):\n",
    "pagnn.models.dcn.Custom = Custom\n": (
        f"pagnn.models.dcn.{args.network_name} = {args.network_name}\n"
    ),
}
applied = set()

with Path("../src/model.py").open('rt') as fin, model_file.open("wt") as fout:
    for line in fin:
        if line in substitutions:
            applied.add(line)
            fout.write(substitutions[line])
        else:
            fout.write(line)
    # Both template lines must have been found, otherwise the template has changed.
    assert applied == set(substitutions)

# Execute the generated file, seeding its namespace with this notebook's globals.
runpy.run_path(model_file.as_posix(), globals())
None

In [None]:
# Invoke the `main` entry point imported from `pagnn.training.dcn` with the arguments
# assembled above.
main(args)