# Batch Run Example
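This notebook shows how to load a batch of inputs (one protein and a folder of ligands), submit a GROMACS + MM-PBSA run for each ligand in parallel, and then track progress, download results, and inspect failures.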

In [None]:
import json
import os
import sys
import tarfile
from glob import glob
import re
from collections import Counter

from pdbtools import *
import requests
from datetime import datetime
from pathlib import Path

import tengu

### 0) Setup

In [None]:
# Credentials come from the environment so they never get saved into the notebook.
# Export TENGU_TOKEN (and optionally TENGU_URL) in your shell before starting Jupyter,
# or replace the lookups below with your own values.
TOKEN = os.environ.get("TENGU_TOKEN")
URL = os.environ.get("TENGU_URL")

In [None]:
# Fail fast with a clear message when no token was provided, instead of constructing
# a client without credentials. The tengu client itself is created once, in the
# "Initialize our tengu client" section, so it is not duplicated here.
if not TOKEN:
    raise RuntimeError("TENGU_TOKEN is not set: export it in your shell or assign TOKEN directly")

In [None]:
# Project metadata; TAGS are attached to every run and used later to query them
DESCRIPTION = "tengu-py batch notebook"
TAGS = ["qdx", "tengu-py", "demo", "batch", "nogpu"]

# Local working area: run metadata and downloaded results are written under OUT_DIR
WORK_DIR = Path.home().joinpath("qdx", "tengu-py-batch-demo")
OUT_DIR = WORK_DIR.joinpath("runs")
OUT_DIR.mkdir(exist_ok=True, parents=True)
MODULE_LOCK = WORK_DIR.joinpath("lock.json")

# Inputs: a single protein plus a folder of ligands named [name].pdb
PROTEIN_PDB_PATH = WORK_DIR.joinpath("protein.pdb")
LIGAND_PDB_FOLDER_PATH = WORK_DIR.joinpath("ligands")

# Where to run, and how many GPUs to request (0 matches the "nogpu" tag above)
TARGET = "GADI"
NUM_GPUS = 0

## Initialize our tengu client and fetch available module paths

In [None]:
# Build the tengu API client used for submitting modules and querying runs
client = tengu.Provider(url=URL, access_token=TOKEN)

In [None]:
# Resolve modules as a dict[module_name, module_path].
# Without a lock file, fetch the latest paths and record them so later reruns are
# reproducible; otherwise reuse the pinned paths from the lock file.
if not MODULE_LOCK.exists():
    modules = client.get_latest_module_paths()
    client.save_module_paths(modules, MODULE_LOCK)
else:
    modules = client.load_module_paths(MODULE_LOCK)

## For each ligand, start a gmx + gmx_mmpbsa run

In [None]:
# Collect ligand inputs in a stable (sorted) order so reruns submit them in the same sequence
ligands = sorted(map(Path, glob(str(LIGAND_PDB_FOLDER_PATH / "*.pdb"))))
if not ligands:
    print(f"No ligand .pdb files found in {LIGAND_PDB_FOLDER_PATH}")

# GROMACS configuration: per-stage parameter overrides, GPU count and frame selection
gmx_config = {
    "param_overrides": {
        "md": [("nsteps", "5000")],
        "em": [("nsteps", "1000")],
        "nvt": [("nsteps", "1000")],
        "npt": [("nsteps", "1000")],
        "ions": [],
    },
    "num_gpus": NUM_GPUS,
    "num_replicas": 1,
    "ligand_charge": None,
    "frame_sel": {
        "begin_time": 1,
        "delta_time": 1,
        "end_time": 2,
    },
}
gmx_resources = {"gpus": NUM_GPUS, "cpus": 48, "storage": 2, "storage_units": "GB", "walltime": 60}

mmpbsa_config = {
    "start_frame": 1,
    "end_frame": 10,
    "num_cpus": 8,  # cannot be greater than number of frames
}

mmpbsa_resources = {"storage": 2, "storage_units": "GB", "walltime": 600}

for ligand_path in ligands:
    name = ligand_path.stem
    # Submit the GROMACS run for this protein/ligand pair, tagged with the ligand name
    # NOTE(review): restore=True presumably reuses a matching earlier run — confirm in tengu docs
    gmx_result = client.run2(
        modules["gmx_tengu_pdb"],
        [PROTEIN_PDB_PATH, ligand_path, gmx_config],
        target=TARGET,
        resources=gmx_resources,
        tags=TAGS + [name],
        restore=True,
    )
    gmx_run_id = gmx_result["module_instance_id"]
    gmx_output_id = gmx_result["output_ids"][0]

    # Save the gmx run metadata so results can be traced back to this ligand later
    with open(OUT_DIR / f"02-gmx-{name}-{gmx_run_id}.json", "w") as f:
        json.dump(gmx_result, f, default=str, indent=2)

    # Chain MM-PBSA on the first gmx output
    mmpbsa_result = client.run2(
        modules["gmx_mmpbsa_tengu"],
        [gmx_output_id, mmpbsa_config],
        target=TARGET,
        resources=mmpbsa_resources,
        tags=TAGS + [name],
        restore=True,
    )
    mmpbsa_run_id = mmpbsa_result["module_instance_id"]
    # Output id of the most recent MM-PBSA submission (left available for interactive inspection)
    mmpbsa_output_id = mmpbsa_result["output_ids"][0]
    print(f"{datetime.now().time()} | Running GROMACS MM-PBSA calculation!")

    # Save the gmx_mmpbsa run metadata; the download step reads the output id back from this file
    with open(OUT_DIR / f"03-mmpbsa-{name}-{mmpbsa_run_id}.json", "w") as f:
        json.dump(mmpbsa_result, f, default=str, indent=2)

## Report progress
This shows the status of all runs tagged with `TAGS`, counted per module.

In [None]:
# Flatten the paginated listing of every run carrying our tags
instance_pages = client.module_instances(tags=TAGS)
instances = [instance for page in instance_pages for instance in page]

# Count runs per (module, status); the key order must match the unpacking below,
# otherwise the Module and Status columns come out swapped
status_counts = Counter(
    (instance["path"].split("#")[1], instance["status"]) for instance in instances
)

# Module names such as "gmx_mmpbsa_tengu" are longer than 10 chars, so use a wider column
print(f"{'Module':<20} | {'Status':<20} | Count")
print("-" * 50)
for (module, status), count in sorted(status_counts.items()):
    print(f"{module:<20} | {status:<20} | {count}")

## Download Results
This downloads the MM-PBSA results of your completed module instances.

In [None]:
completed_instance_pages = client.module_instances(tags=TAGS, status="COMPLETED")
completed_instances = [instance for page in completed_instance_pages for instance in page]

run_file_prefix = "03-mmpbsa-"
for instance in completed_instances:
    instance_id = instance["id"]
    print(f"{instance_id} completed!")
    if instance["path"].split("#")[1] != "gmx_mmpbsa_tengu":
        continue

    # Locate the metadata file written at submission time: 03-mmpbsa-{name}-{id}.json
    run_files = sorted(OUT_DIR.glob(f"{run_file_prefix}*-{instance_id}.json"))
    if not run_files:
        tags = instance["tags"]
        print(f"{datetime.now().time()} | Failed to find run file for mmpbsa results with tags {tags}!")
        continue
    run_file = run_files[0]

    # Slice off the known prefix/suffix instead of regex-matching, so ligand names
    # containing "-" are recovered intact
    name = run_file.stem[len(run_file_prefix):-len(f"-{instance_id}")]

    # Use THIS instance's output id (saved at submission), not a variable left over
    # from the submission loop, which would point at the last ligand's output.
    # NOTE(review): the JSON was written with default=str — assumes output ids are plain strings
    with open(run_file) as f:
        run_result = json.load(f)
    output_id = run_result["output_ids"][0]

    client.download_object(output_id, OUT_DIR / f"03-mmpbsa-output-{name}-{instance_id}.tar.gz")
    print(f"{datetime.now().time()} | Downloaded MM-PBSA results for {name}-{instance_id}!")

## Check failures
This retrieves the failed runs with your specified tags and prints their stderr and stdout logs.

In [None]:
failed_instance_pages = client.module_instances(tags=TAGS, status="FAILED")
failed_instances = [instance for page in failed_instance_pages for instance in page]
for instance in failed_instances:
    instance_id = instance["id"]
    module_name = instance["path"].split("#")[1]
    # Report the failing module (not a ligand name left over from an earlier cell)
    print(f"{module_name}: {instance_id}  failed!")

    # Fetch the instance details once and reuse them for both log streams
    details = client.module_instance(instance_id)
    # NOTE(review): assumes each node's "content" is a sequence of log lines — confirm
    stderr_logs = "\n".join(line for node in details["stderr"]["nodes"] for line in node["content"])
    print(stderr_logs)
    stdout_logs = "\n".join(line for node in details["stdout"]["nodes"] for line in node["content"])
    print(stdout_logs)
