# Make bound and free states from the prepped inputs

### Imports

In [None]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells
%load_ext autoreload
%autoreload 2
# Report where this notebook kernel is running (working directory and host).
_cwd = os.getcwd()
_node = socket.gethostname()
print(f"running in directory: {_cwd}")  # where are we?
print(f"running on node: {_node}")  # what node are we on?

### Set working directory to the root of the crispy_shifty repo

In [2]:
# Absolute path of the crispy_shifty checkout; later cells build their paths
# from os.getcwd(), so they all resolve relative to this directory.
CRISPY_SHIFTY_ROOT = "/home/pleung/projects/crispy_shifty"  # TODO projects dir
os.chdir(CRISPY_SHIFTY_ROOT)

### Load the scorefile containing the cleaned input scaffolds
These scaffolds had their disulfides removed and have some standard Rosetta metrics, AF2 scores and various metadata written to them.  
We also define the task generator here.

In [3]:
sys.path.insert(0, os.getcwd())
from crispy_shifty.utils.io import parse_scorefile_linear


def create_tasks(scaffolds, options, scaffold_type="drhicks1_JHR", clash_cutoff=5000):
    """Yield one PyRosettaCluster task dict per candidate helix-break position.

    The helix count of each scaffold is ``len(row["topo"])`` (presumably one
    character per helix -- confirm against the scorefile schema). An even
    count ``n`` yields a single task with ``pre_break_helix == n // 2``; an odd
    count yields two tasks, with ``pre_break_helix`` equal to ``n // 2`` and
    ``n // 2 + 1``.

    Args:
        scaffolds: table indexed by decoy path (used as ``pdb_path``) that has
            at least the columns ``topo``, ``scaffold_type`` and ``pdb``.
        options: extra Rosetta options, passed through unchanged as
            ``extra_options`` in every task.
        scaffold_type: only rows whose ``scaffold_type`` equals this value are
            used. ``None`` disables the filter. The default reproduces the
            previously hard-coded (TODO) ``"drhicks1_JHR"`` filter.
        clash_cutoff: value stored under ``clash_cutoff`` in every task.

    Yields:
        A new dict per task with the keys ``clash_cutoff``, ``extra_options``,
        ``name``, ``pdb_path`` and ``pre_break_helix``.
    """
    for scaffold in scaffolds.index:
        row = scaffolds.loc[scaffold]  # single lookup per scaffold
        # Filter before touching "topo" so skipped rows may have missing topo.
        if scaffold_type is not None and row["scaffold_type"] != scaffold_type:
            continue
        num_helices = len(row["topo"])
        half = num_helices // 2  # rounds down
        if num_helices % 2 == 0:  # split even numbers in half
            pre_break_helices = [half]
        else:  # get middle two for odd numbers
            pre_break_helices = [half, half + 1]
        # basename of the source pdb with the first ".pdb" removed
        name = row["pdb"].split("/")[-1].replace(".pdb", "", 1)
        for pre_break_helix in pre_break_helices:
            # a fresh dict per task so yielded tasks never share state
            yield {
                "clash_cutoff": clash_cutoff,
                "extra_options": options,
                "name": name,
                "pdb_path": scaffold,
                "pre_break_helix": pre_break_helix,
            }


# Scores written by the 01_prep_inputs step. create_tasks reads the result via
# .index/.loc, so it is presumably a pandas DataFrame indexed by decoy path.
scorefile_path = os.path.join(os.getcwd(), "scaffolds/01_prep_inputs/scores.json")
scaffolds = parse_scorefile_linear(scorefile_path)

15503it [01:07, 228.03it/s]


### Make helix-bound states from the scaffolds

In [None]:
# Python standard library
import os
import pwd
import socket
import sys

# 3rd party library imports
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Rosetta library imports
from pyrosetta.distributed.cluster.core import PyRosettaCluster

# Custom library imports
sys.path.insert(0, os.getcwd())
from crispy_shifty.protocols.states import make_bound_states  # the functions we will distribute


# Tell the user how to reach the dashboard served on port 8787 of this node.
_user = pwd.getpwuid(os.getuid()).pw_name
_node = socket.gethostname()
for _message in (
    "run the following from your local terminal to port forward the dashboard to localhost",
    f"ssh -L 8000:localhost:8787 {_user}@{_node}",
    "dashboard is now visible at localhost:8000",
    f"can also view dashboard at {_node}:8787 without port forwarding",
):
    print(_message)

# Rosetta options forwarded to every task via create_tasks.
options = {"-out:level": "200"}  # warning outputs only

# All PyRosettaCluster output (and scratch) goes here; created if missing.
output_path = os.path.join(os.getcwd(), "scaffolds/02_make_states")
os.makedirs(output_path, exist_ok=True)

# Distribute make_bound_states over an adaptively scaled SLURM-backed dask
# cluster: one PyRosettaCluster task per dict yielded by create_tasks.
if __name__ == "__main__":
    # configure SLURM cluster as a context manager
    # (each SLURM job runs one single-core worker with 8GB for up to 3h)
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="8GB",
        queue="short",
        walltime="3:00:00",
        death_timeout=120,
        local_directory="$TMPDIR",  # spill worker litter on local node temp storage
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        # extra dask-worker arguments: retire each worker after ~3h, staggered by
        # up to 5m -- presumably so workers exit before the 3h SLURM walltime.
        extra=["--lifetime", "3h", "--lifetime-stagger", "5m"],
    ) as cluster:
        print(cluster.job_script())  # show the generated SLURM submission script
        # adaptively scale between 1 and 150 workers
        cluster.adapt(
            minimum=1,
            maximum=150,
            wait_count=999,  # Number of consecutive times a worker must be suggested for removal before it is removed
            interval="5s",  # Time between checks
            target_duration="60s",  # desired computation time; governs how aggressively to scale up
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            client.upload_file(
                os.path.join(os.getcwd(), "crispy_shifty/protocols/states.py")
            )  # upload the script that contains the functions to distribute
            # tasks is the create_tasks generator; output and scratch share output_path
            PyRosettaCluster(
                client=client,
                logging_level="WARNING",
                output_path=output_path,
                project_name="crispy_shifty",
                scratch_dir=output_path,
                simulation_name="notebooks_02_make_states",
                tasks=create_tasks(scaffolds, options),
            ).distribute(protocols=[make_bound_states])
            # NOTE(review): redundant -- leaving the `with Client(...)` block
            # closes the client as well.
            client.close()
        # Release all SLURM jobs, then close the cluster.
        # NOTE(review): the close() is redundant with the `with SLURMCluster(...)`
        # block's exit.
        cluster.scale(0)
        cluster.close()
    print("distributed run complete")

### Unused blocks

In [None]:
%%time 
import pyrosetta

pyrosetta.init()


# sys.path entries are used verbatim: "~" is NOT expanded, so the raw
# "~/projects/..." string would name a relative directory literally called "~".
sys.path.insert(0, os.path.expanduser("~/projects/crispy_shifty/"))  # TODO projects
from crispy_shifty.protocols.states import make_bound_states


# Smoke-test make_bound_states locally on one hand-written task and keep the
# first result it yields. The first argument is None (no input pose); the
# protocol presumably loads the structure from pdb_path -- confirm in
# crispy_shifty/protocols/states.py.
t = next(
    make_bound_states(
        None,
        pdb_path="/mnt/projects/crispy_shifty/scaffolds/01_prep_inputs/decoys/0000/notebooks_01_prep_inputs_5f660bfadf6b46fa8bdfc410a4f2dc2b.pdb.bz2",
        name="JHR_bd5_04863",
        pre_break_helix=3,
    )
)

In [14]:
# Write the pose produced by the smoke test above to test.pdb for manual inspection.
t.pose.dump_pdb("test.pdb")

True

In [15]:
# Display the scores dict attached to the smoke-test result.
t.scores

{'best_average_DAN_plddts': '0.9248046875',
 'best_average_plddts': '95.1726622713',
 'best_model': '5',
 'best_ptm': '0.8553902834',
 'best_rmsd_to_input': '1.2009233396',
 'buns_parent': '0.0',
 'docked_helix': '3',
 'dslf_fa13': '0.0',
 'exposed_hydrophobics_parent': '734.278564453125',
 'fa_atr': '-851.7773147936174',
 'fa_dun_dev': '22.375049926361513',
 'fa_dun_rot': '89.51412508526414',
 'fa_dun_semi': '150.17635028193519',
 'fa_elec': '-275.7626123403799',
 'fa_intra_atr_xover4': '-50.247539441568634',
 'fa_intra_elec': '-17.843802392344905',
 'fa_intra_rep_xover4': '51.75833742586744',
 'fa_intra_sol_xover4': '36.48246339198184',
 'fa_rep': '192.4940382661695',
 'fa_sol': '639.1173028594893',
 'geometry_parent': '1.0',
 'hbond_bb_sc': '-7.789377755467804',
 'hbond_lr_bb': '-5.9531405191324',
 'hbond_sc': '-20.2549378213548',
 'hbond_sr_bb': '-113.80052920118844',
 'holes_all_parent': '-1.4814857244491577',
 'holes_core_parent': '-1.3556149005889893',
 'hxl_tors': '0.6040193966