# GMX MD + MMPBSA
## First, import dependencies

In [None]:
import json
import os
import sys
import tarfile

from datetime import datetime
from pathlib import Path

import tengu

## Set our token

In [None]:
# API token for tengu, read from the TENGU_TOKEN environment variable
# (export it in your shell before starting the notebook). Alternatively,
# replace the lookup below with your token as a string literal.
# TOKEN is None when the variable is not set.
TOKEN = os.environ.get("TENGU_TOKEN")

## Set up our working directory and shared configuration

In [None]:
# Project information. TAGS are attached to every module run submitted below;
# DESCRIPTION is not referenced elsewhere in this notebook.
DESCRIPTION = "tengu-py demo notebook"
TAGS = ["md_only", "tengu-py", "demo"]

# Local working directory. Run records and downloaded outputs are written to
# OUT_DIR, which is created (with any missing parents) if it does not exist.
WORK_DIR = Path.home() / "qdx" / "tengu-py-md-demo"
OUT_DIR = WORK_DIR / "runs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Set our inputs - ensure the ligand residue is named UNL.
SYSTEM_PDB_PATH = WORK_DIR / "test.pdb"  # protein-ligand complex
PROTEIN_PDB_PATH = WORK_DIR / "test_P.pdb"  # protein only
LIGAND_PDB_PATH = WORK_DIR / "test_L.pdb"  # ligand only (residue name UNL)

# GPUs requested for the GROMACS run (passed both in its config and as a
# resource request); 0 presumably means a CPU-only run - confirm with tengu docs.
NUM_GPUS = 0

In [None]:
# Fetch PDB entry 3HTB and derive the three input structures from it.
# These helpers come from the pdb-tools package (`pip install pdb-tools`).
# The original cell used them without importing them, which raised NameError.
from pdbtools import (
    pdb_delhetatm,
    pdb_fetch,
    pdb_keepcoord,
    pdb_rplresname,
    pdb_selchain,
    pdb_selresname,
)


def _write_pdb(path, records):
    """Write an iterable of PDB record lines to *path*, overwriting it."""
    with open(path, "w") as handle:
        handle.writelines(str(record) for record in records)


# Materialise the complex once, because both selections below re-read it.
# The name avoids shadowing the built-in `complex`.
complex_records = list(
    pdb_keepcoord.keep_coordinates(pdb_fetch.fetch_structure("3HTB"))
)
# Chain A with all HETATM records removed.
# Selections take a *set* of IDs: with a bare string, the residue-name
# membership test would match substrings of "JZ4".
protein = pdb_delhetatm.remove_hetatm(
    pdb_selchain.select_chain(complex_records, {"A"})
)
# The JZ4 residue, renamed to UNL as required by the input settings above.
ligand = pdb_rplresname.rename_residues(
    pdb_selresname.filter_residue_by_name(complex_records, {"JZ4"}), "JZ4", "UNL"
)

_write_pdb(SYSTEM_PDB_PATH, complex_records)
_write_pdb(PROTEIN_PDB_PATH, protein)
_write_pdb(LIGAND_PDB_PATH, ligand)

## Initialize our tengu client and fetch available module paths

In [None]:
# Build the tengu API client, authenticated with TOKEN. Every module
# submission, status poll and download in this notebook goes through it.
client = tengu.Provider(
    access_token=TOKEN,
)

In [None]:
# Look up the newest published version of every module. The result is
# indexed by module name below (e.g. "gmx_tengu_pdb") to obtain the
# module path passed to client.run2.
modules = (
    client.get_latest_module_paths()
)

## Run GROMACS on Gadi

In [None]:
# Settings for the GROMACS module. "param_overrides" maps each stage to a
# list of (parameter, value) string pairs; "frame_sel" chooses which frames
# are exported (begin/delta/end times).
gmx_config = dict(
    param_overrides=dict(
        md=[("nsteps", "5000")],
        em=[("nsteps", "1000")],
        nvt=[("nsteps", "1000")],
        npt=[("nsteps", "1000")],
        ions=[],
    ),
    num_gpus=NUM_GPUS,
    num_replicas=1,
    ligand_charge=None,
    frame_sel=dict(begin_time=1, delta_time=1, end_time=2),
)

# Submit the run on Gadi: protein PDB, ligand PDB and the config, in order.
gmx_result = client.run2(
    modules["gmx_tengu_pdb"],
    [PROTEIN_PDB_PATH, LIGAND_PDB_PATH, gmx_config],
    target="GADI",
    resources={"gpus": NUM_GPUS, "storage": 1_024_000_000, "walltime": 60},
    tags=TAGS,
)

gmx_run_id = gmx_result["module_instance_id"]
gmx_output_ids = gmx_result["output_ids"]
# Output 0 is downloaded after the run and also feeds the MM-PBSA module;
# output 3 is downloaded as the ligand .gro file.
gmx_output_id = gmx_output_ids[0]
gmx_ligand_gro_id = gmx_output_ids[3]
print(f"{datetime.now().time()} | Running GROMACS simulation!")

## Save run details

In [None]:
# Keep a local JSON copy of the submission response; values JSON cannot
# encode natively are converted with str().
(OUT_DIR / f"02-gmx-{gmx_run_id}.json").write_text(
    json.dumps(gmx_result, indent=2, default=str)
)

## Wait for the module to complete and download the results

In [None]:
import io

# Wait for the GROMACS run: up to 60 polls, 60 apart (presumably seconds,
# matching the 60-minute walltime requested above - confirm).
done = client.poll_module_instance(gmx_run_id, n_retries=60, poll_rate=60)
if done["status"] != "COMPLETED":
    print(f"Module instance failed! (status: {done['status']})")
    # Dump the module's stdout, then its stderr, one line at a time.
    for stream in ("stdout", "stderr"):
        for node in done[stream]["nodes"]:
            for line in node["content"]:
                print(line)
else:
    # The output is opened as a tar archive below, so save it under a
    # matching name. It used to be saved as .zip and then opened as .tar.gz,
    # which raised FileNotFoundError.
    gmx_archive_path = OUT_DIR / "02-gmx-output.tar.gz"
    client.download_object(gmx_output_id, gmx_archive_path)
    # Get the "dry" (i.e. non-solvated) frames we asked for. Each member is
    # read while the archive is still open: handles from extractfile() are
    # unreadable once the tarfile is closed, so keep in-memory file objects.
    with tarfile.open(gmx_archive_path, "r") as tf:
        selected_frame_pdbs = [
            io.BytesIO(tf.extractfile(member).read())
            for member in sorted(tf, key=lambda m: m.name)
            # extractfile() returns None for non-regular members.
            if member.isfile() and "dry" in member.name and "pdb" in member.name
        ]
    client.download_object(gmx_ligand_gro_id, OUT_DIR / "02-gmx-ligand.gro")
    print(f"{datetime.now().time()} | Downloaded GROMACS output!")

## Run MMPBSA on Gadi, using GMX outputs

In [None]:
# Positional MM-PBSA arguments, in order:
#   start frame, end frame,
#   optional override for raw GROMACS parameters (None = no override),
#   num_cpus
# NOTE(review): frames 401-901 assume the MD run above produced at least that
# many frames; confirm against the md nsteps / frame settings.
mmpbsa_config = [401, 901, None, 12]

# Submit on Gadi, feeding the primary GROMACS output in as the first input.
mmpbsa_result = client.run2(
    modules["gmx_mmpbsa_tengu"],
    [gmx_output_id] + mmpbsa_config,
    target="GADI",
    resources={"storage": 1_024_000_000, "walltime": 600},
    tags=TAGS,
)
mmpbsa_run_id, mmpbsa_output_id = (
    mmpbsa_result["module_instance_id"],
    mmpbsa_result["output_ids"][0],
)
print(f"{datetime.now().time()} | Running GROMACS MM-PBSA calculation!")

## Save run details

In [None]:
# Keep a local JSON copy of the MM-PBSA submission response; values JSON
# cannot encode natively are converted with str().
(OUT_DIR / f"03-mmpbsa-{mmpbsa_run_id}.json").write_text(
    json.dumps(mmpbsa_result, indent=2, default=str)
)

## Wait for the module to complete and download the results

In [None]:
# Wait for the MM-PBSA run. Download its results only if it actually
# completed; previously the poll result was ignored, so failed or unfinished
# runs were downloaded and reported as success.
# NOTE(review): this uses the client's default polling settings while the run
# requests a 600-minute walltime; consider passing n_retries/poll_rate
# explicitly so polling does not give up before the run ends.
mmpbsa_done = client.poll_module_instance(mmpbsa_run_id)
if mmpbsa_done["status"] != "COMPLETED":
    print(f"Module instance failed! (status: {mmpbsa_done['status']})")
    # Dump the module's stdout, then its stderr, one line at a time.
    for stream in ("stdout", "stderr"):
        for node in mmpbsa_done[stream]["nodes"]:
            for line in node["content"]:
                print(line)
else:
    client.download_object(mmpbsa_output_id, OUT_DIR / "03-mmpbsa-output.tar.gz")
    print(f"{datetime.now().time()} | Downloaded MM-PBSA results!")