# Prep filtered scaffold sets for distributed design

### Imports

In [1]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells
%load_ext autoreload
%autoreload 2
# Record where this kernel is running (working directory and host node).
_cwd, _node = os.getcwd(), socket.gethostname()
print(f"running in directory: {_cwd}")  # where are we?
print(f"running on node: {_node}")  # what node are we on?

running in directory: /mnt/home/pleung/projects/crispy_shifty/notebooks
running on node: dig27


### Set working directory to the root of the crispy_shifty repo
TODO: change this to the `/projects/crispy_shifty` directory.

In [2]:
# Work from the repo root so the os.getcwd()-relative paths in later cells resolve.
# TODO: switch to the /projects checkout: "/projects/crispy_shifty"
_repo_root = "/home/pleung/projects/crispy_shifty"
os.chdir(_repo_root)

### Load a dataframe of filtered scaffolds and associated metadata
TODO: switch from the test input list to all inputs.  
These scaffolds were run through AF2. For each scaffold's best prediction among the 5 AF2 pTM models, pLDDT is > 92 and the RMSD to the design is < 1.5.  
This cell also defines the task generator.

In [3]:
def create_tasks(scaffolds, options, *, metadata_to_keep=None):
    """Yield one PyRosettaCluster task dict per row of ``scaffolds``.

    Args:
        scaffolds: DataFrame of filtered scaffolds. Must contain a "pdb"
            column; its value becomes each task's "pdb_path".
        options: Rosetta flags passed through unchanged as each task's
            "extra_options" (the same object is shared by every task).
        metadata_to_keep: column names copied into each task's "metadata".
            Columns absent from ``scaffolds`` are silently skipped. Defaults
            to the AF2-filtering metadata used by this notebook.

    Yields:
        dict with keys "extra_options", "metadata" and "pdb_path".
    """
    if metadata_to_keep is None:
        metadata_to_keep = (
            "pdb",
            "topo",
            "best_model",
            "best_average_plddts",
            "best_ptm",
            "best_rmsd_to_input",
            "best_average_DAN_plddts",
            "scaffold_type",
        )
    # Built once, outside the row loop, for O(1) membership tests.
    keep = frozenset(metadata_to_keep)
    for _, row in scaffolds.iterrows():
        # Preserve the DataFrame's column order in the metadata dict.
        metadata = {k: v for k, v in dict(row).items() if k in keep}
        yield {
            "extra_options": options,
            "metadata": metadata,
            # Read from the row so a custom metadata_to_keep may omit "pdb".
            "pdb_path": row["pdb"],
        }


# Filtered scaffold table (one row per scaffold, incl. a "pdb" path column).
# TODO: switch from the small test subset to the full filtered scaffold set.
_scaffolds_csv = os.path.join(
    os.getcwd(), "scaffolds/00_filter_scaffold_sets/test_filtered.csv"
)
scaffolds = pd.read_csv(_scaffolds_csv)

### Domesticate the scaffolds: trim leading and trailing loops, design away disulfides, and add metadata to the output pdb.bz2 files.
TODO  
`"-holes:dalphaball": "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",` can be replaced by `"-holes:dalphaball": "/software/rosetta/DAlphaBall.gcc",`
and `"-indexed_structure_store:fragment_store": "/home/bcov/sc/scaffold_comparison/data/ss_grouped_vall_all.h5",` isn't needed.  
I forgot to change these before the production run.

In [4]:
# Python standard library
import os
import pwd
import socket
import sys

# 3rd party library imports
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Rosetta library imports
from pyrosetta.distributed.cluster.core import PyRosettaCluster

# Custom library imports
sys.path.insert(0, os.getcwd())
from crispy_shifty.protocols.cleaning import (
    remove_terminal_loops,
    redesign_disulfides,
)  # the functions we will distribute


# The ssh command below forwards local port 8000 to port 8787 on this node,
# where the dask dashboard is reachable.
print(
    "run the following from your local terminal to port forward the dashboard to localhost"
)
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)
print("dashboard is now visible at localhost:8000")
print(f"can also view dashboard at {socket.gethostname()}:8787 without port forwarding")

# Rosetta flags forwarded to every task as "extra_options".
# NOTE: "-holes:dalphaball" could point at /software/rosetta/DAlphaBall.gcc and the
# fragment-store flag is not needed; both were left unchanged so this cell matches
# the flags used in the production run.
options = {
    "-out:level": "200",  # warning outputs only
    "-corrections::beta_nov16": "true",
    "-detect_disulf": "false",
    "-holes:dalphaball": "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
    "-indexed_structure_store:fragment_store": "/home/bcov/sc/scaffold_comparison/data/ss_grouped_vall_all.h5",
}
output_path = os.path.join(os.getcwd(), "scaffolds/01_prep_inputs")
os.makedirs(output_path, exist_ok=True)

if __name__ == "__main__":
    # Configure the SLURM cluster as a context manager; leaving the `with`
    # closes the cluster (stopping adaptive scaling and its worker jobs), so no
    # explicit scale(0)/close() is needed.
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="10GB",
        queue="medium",
        walltime="23:00:00",
        death_timeout=120,
        local_directory="$TMPDIR",  # spill worker litter on local node temp storage
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        extra=["--lifetime", "23h", "--lifetime-stagger", "5m"],
    ) as cluster:
        print(cluster.job_script())
        # Scale adaptively between 1 and 500 workers.
        cluster.adapt(
            minimum=1,
            maximum=500,
            wait_count=999,  # consecutive removal suggestions before a worker is actually removed
            interval="5s",  # time between adaptive checks
        )
        # The client context manager closes the client on exit.
        with Client(cluster) as client:
            print(client)
            # Ship the module holding the distributed protocols to every worker.
            client.upload_file(
                os.path.join(os.getcwd(), "crispy_shifty/protocols/cleaning.py")
            )
            PyRosettaCluster(
                client=client,
                logging_level="WARNING",
                output_path=output_path,
                project_name="crispy_shifty",
                scratch_dir=output_path,
                simulation_name="notebooks_01_prep_inputs",
                tasks=create_tasks(scaffolds, options),
            ).distribute(protocols=[remove_terminal_loops, redesign_disulfides])
    print("distributed run complete")

run the following from your local terminal to port forward the dashboard to localhost
ssh -L 8000:localhost:8787 pleung@dig27
dashboard is now visible at localhost:8000
can also view dashboard at dig27:8787 without port forwarding
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e /mnt/home/pleung/logs/slurm_logs/dask-worker-%J.err
#SBATCH -o /mnt/home/pleung/logs/slurm_logs/dask-worker-%J.out
#SBATCH -p medium
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=10G
#SBATCH -t 23:00:00

/home/pleung/.conda/envs/crispy/bin/python -m distributed.cli.dask_worker tcp://172.16.131.57:38721 --nthreads 1 --memory-limit 9.31GiB --name dummy-name --nanny --death-timeout 120 --local-directory $TMPDIR --lifetime 23h --lifetime-stagger 5m --protocol tcp://

<Client: 'tcp://172.16.131.57:38721' processes=0 threads=0, memory=0 B>
distributed run complete


### Load resulting scorefile

In [5]:
# Make the repo (the current working directory) importable, then parse the
# scorefile written by the distributed run into a DataFrame.
sys.path.insert(0, os.getcwd())
from crispy_shifty.utils.io import parse_scorefile_linear

_scorefile = os.path.join(os.getcwd(), "scaffolds/01_prep_inputs/scores.json")
scores_df = parse_scorefile_linear(_scorefile)

20it [00:00, 198.26it/s]Task was destroyed but it is pending!
task: <Task pending name='Task-126115' coro=<_needs_document_lock.<locals>._needs_document_lock_wrapper() running at /home/pleung/.conda/envs/crispy/lib/python3.8/site-packages/bokeh/server/session.py:61> wait_for=<Future finished result=<tornado.lock...x7f8a7ee60c10>> cb=[multi_future.<locals>.callback() at /home/pleung/.conda/envs/crispy/lib/python3.8/site-packages/tornado/gen.py:520]>
113it [00:00, 200.89it/s]


### Save a list of outputs

In [6]:
# Write each entry of the scorefile's index on its own line; presumably these
# are the output pdb paths — confirm against parse_scorefile_linear.
_list_path = os.path.join(os.getcwd(), "scaffolds/01_prep_inputs/prepped_inputs.list")
with open(_list_path, "w") as handle:
    handle.writelines(f"{entry!s}\n" for entry in scores_df.index)

### Unused blocks

In [None]:
%%time 
import pyrosetta

# Rosetta flags for running the cleaning protocols interactively. Joining a list
# avoids backslash-newline continuations inside a string literal, which embedded
# the next line's indentation in the flag string and hid the disabled flag behind
# a trailing "# \". pyrosetta.init splits the options string on whitespace.
_debug_flags = [
    "-corrections::beta_nov16 true",
    "-detect_disulf false",
    "-holes:dalphaball /software/rosetta/DAlphaBall.gcc",
    # "-indexed_structure_store:fragment_store /home/bcov/sc/scaffold_comparison/data/ss_grouped_vall_all.h5",
]
pyrosetta.init(" ".join(_debug_flags))


sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.cleaning import remove_terminal_loops, redesign_disulfides


# Run both cleaning steps on one scaffold, outside the cluster, for debugging.
_example_pdb = "/net/shared/scaffolds/pre_scaffold_DB/tj_junctions/DHR82_DHR79_l3_t1_t2_9_v4c.pdb"
t = next(
    remove_terminal_loops(
        None,
        pdb_path=_example_pdb,
        metadata={
            "pdb": _example_pdb,
            "topo": "HHHHHHHH",
            "best_model": 2,
            "best_average_plddts": 96.0650150399,
            "best_ptm": 0.8458813818,
            "best_rmsd_to_input": 1.2088154485,
            "best_average_DAN_plddts": 0.947265625,
            "scaffold_type": "tj_junctions",
        },
    )
)

t2 = next(redesign_disulfides(t))