# Batch Run Example
This notebook shows how to load a batch of ligand structures, submit a GROMACS + MM-PBSA run for each of them so they execute in parallel, monitor their progress, and download the results.

In [None]:
import json
import os
import sys
import tarfile
import asyncio
from glob import glob
import re
from collections import Counter

from pdbtools import *
import requests
from datetime import datetime
from pathlib import Path

import tengu

### 0) Setup

In [None]:
# Read credentials from the environment instead of pasting them into the
# notebook: a hardcoded token is saved (and shared) along with the .ipynb file.
# Export TENGU_TOKEN (and optionally TENGU_URL) in the shell that launches
# Jupyter, or load the token interactively, e.g. with getpass.getpass().
TOKEN = os.getenv("TENGU_TOKEN")
URL = os.getenv("TENGU_URL")

# Warn early rather than silently handing access_token=None to the client.
if not TOKEN:
    print("TENGU_TOKEN is not set; export it before creating the client.", file=sys.stderr)

In [None]:
# Project metadata: a description plus the tags attached to this batch of runs
DESCRIPTION = "tengu-py batch notebook"
TAGS = ["qdx", "tengu-py-v2", "demo", "batch", "nogpu"]

# Local workspace layout; the runs/ output folder is created up front
WORK_DIR = Path.home().joinpath("qdx", "tengu-py-batch-demo")
OUT_DIR = WORK_DIR.joinpath("runs")
OUT_DIR.mkdir(exist_ok=True, parents=True)
MODULE_LOCK = WORK_DIR.joinpath("lock.json")  # pinned module paths

# Inputs: one protein structure and a folder of ligands named [name].pdb
PROTEIN_PDB_PATH = WORK_DIR.joinpath("protein.pdb")
LIGAND_PDB_FOLDER_PATH = WORK_DIR.joinpath("ligands")

# Compute target for every submitted run, and GPUs requested for GROMACS
TARGET = "GADI"
NUM_GPUS = 0

## Initialize our tengu client and fetch available module paths

In [None]:
# Build the tengu client used below both to submit module runs and to query
# the API (status, logs, downloads).
client = tengu.Provider(
    access_token=TOKEN,
    url=URL,
    workspace=WORK_DIR,
    batch_tags=TAGS,
)

In [None]:
# Get our latest modules as a dict[module_name, module_path]
# If a lock file exists, load it so that the run is reproducable
if MODULE_LOCK.exists():
    modules = client.load_module_paths(MODULE_LOCK)
else: 
    modules = await client.get_latest_module_paths()
    client.save_module_paths(modules, MODULE_LOCK)

In [None]:
# Fetch the module functions available to this client.
# NOTE(review): `fns` is not used directly in the cells shown; presumably this
# call is what makes module calls such as client.gmx_pdb / client.gmx_mmpbsa
# available — confirm against the tengu Provider docs before removing it.
fns = await client.get_module_functions()

## For each ligand, submit a `gmx_pdb` run and a `gmx_mmpbsa` run on its output

In [None]:
ligands = map(lambda x: Path(x), glob(str(LIGAND_PDB_FOLDER_PATH /"*.pdb")))

gmx_config = {
    "param_overrides": {
        "md": [("nsteps", "5000")],
        "em": [("nsteps", "1000")],
        "nvt": [("nsteps", "1000")],
        "npt": [("nsteps", "1000")],
        "ions": [],
    },
    "num_gpus": NUM_GPUS,
    "num_replicas": 1,
    "ligand_charge": None,
    "frame_sel": {
       "begin_time": 1,
       "delta_time": 1,
        "end_time": 2
    },  
}
gmx_resources = tengu.Resources(gpus= NUM_GPUS, cpus= 48, storage= 2, storage_units= "GB", walltime= 60)

mmpbsa_config = {
    "start_frame": 1,
    "end_frame": 10,
    "num_cpus": 8,  # cannot be greater than number of frames
}

mmpbsa_resources = tengu.Resources(storage= 2, storage_units= "GB", walltime= 600)

gmx_outputs = []
mmpbsa_outputs = []
for ligand_path in ligands:
    name = ligand_path.stem
    (gmx_output_tar, wet, dry, gmx_ligand_gro_tar, extra) = await client.gmx_pdb(
        PROTEIN_PDB_PATH,
        ligand_path,
        gmx_config,
        target=TARGET,
        resources=gmx_resources,
        tags= [ name ],
        restore = True
    )
    gmx_outputs.append((name, gmx_output_tar, wet, dry, gmx_ligand_gro_tar))
    
    (mmpbsa_output_tar,) = await client.gmx_mmpbsa(
        gmx_output_tar,
        mmpbsa_config,
        target=TARGET,
        resources=mmpbsa_resources,
        tags=[ name ],
        restore=True
    )
    mmpbsa_outputs.append((name, mmpbsa_output_tar))
    print(f"{datetime.now().time()} | Running GROMACS MM-PBSA calculation!")

17:34:49.768125 | Running GROMACS MM-PBSA calculation!


## Report progress
This shows the status of all of your runs, grouped by module.

In [None]:
status = await client.status(group_by="path")
print(f"{'Module':<20} | {'Status':<20} | Count")
print("-" * 50)
for module, (status, path, count) in status.items():
    print(f"{path:<20} | {status:<20} | {count}")

Module               | Status               | Count
--------------------------------------------------
gmx_mmpbsa           | RESOLVING            | 2
gmx_pdb              | RESOLVING            | 1
gmx_pdb              | RUNNING              | 1
gmx_mmpbsa           | COMPLETED            | 1
gmx_pdb              | COMPLETED            | 3


## Download Results
This downloads the MM-PBSA output archive (`mmpbsa_<ligand>.tar.gz`) for each of your completed module instances.

In [None]:
await asyncio.gather(*[output[1].download(filename=f"mmpbsa_{output[0]}.tar.gz") for output in mmpbsa_outputs])

## Check failures
This prints the stderr logs of every failed run in your workspace history.

In [None]:
for (instance_id, (status,name,count)) in (await client.status()).items():
    if status.value == "FAILED":
        async for log_page in client.logs(instance_id, "stderr"):
            for log in log_page:
                print(log)