# Make bound and free states from the prepped inputs

### Imports

In [1]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas (adds tqdm-backed progress_* methods to pandas objects)
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells, so edits to
# crispy_shifty/*.py are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2
# Record where and on which node the kernel runs; the dashboard/ssh hints in
# later cells are built from the same hostname.
print(f"running in directory: {os.getcwd()}")  # where are we?
print(f"running on node: {socket.gethostname()}")  # what node are we on?

running in directory: /mnt/projects/crispy_shifty/notebooks
running on node: dig72


### Set working directory to the root of the crispy_shifty repo

In [2]:
# Every later relative path ("scaffolds/...", "crispy_shifty/...") is resolved
# against this directory.
# NOTE(review): hard-coded repo root; the setup cell reported the notebook
# running from /mnt/projects/crispy_shifty/notebooks, so /projects presumably
# resolves to the same tree on this cluster — confirm before running elsewhere.
os.chdir("/projects/crispy_shifty")

### Load the scorefile containing the cleaned input scaffolds
These scaffolds had their disulfides removed and have some standard Rosetta metrics, AF2 scores and various metadata written to them.  
We will also define the task generator here.

In [3]:
sys.path.insert(0, os.getcwd())
from crispy_shifty.utils.io import parse_scorefile_linear


def create_tasks(scaffolds, options, clash_cutoff=5000, int_cutoff=0.33):
    """Yield one PyRosettaCluster task dict per candidate break point per scaffold.

    The number of helices in a scaffold is the length of its ``"topo"`` entry.
    Scaffolds with an even helix count are split in half (one task). Scaffolds
    with an odd helix count yield two tasks, one breaking after each of the
    two helices that flank the middle.

    Args:
        scaffolds: DataFrame indexed by the path of each prepped scaffold, with
            at least the ``"topo"`` and ``"pdb"`` columns.
        options: Rosetta command-line options, passed through unchanged as
            ``"extra_options"`` on every task.
        clash_cutoff: value stored as ``"clash_cutoff"`` on every task
            (defaults to 5000, the previously hard-coded value).
        int_cutoff: value stored as ``"int_cutoff"`` on every task
            (defaults to 0.33, the previously hard-coded value).

    Yields:
        A fresh dict for each (scaffold, pre_break_helix) pair, with the keys
        ``clash_cutoff``, ``extra_options``, ``int_cutoff``, ``name``,
        ``pdb_path`` and ``pre_break_helix``.
    """
    for scaffold in scaffolds.index:
        row = scaffolds.loc[scaffold]
        # determine where to split the scaffold by counting the number of helices
        num_helices = len(row["topo"])
        half = num_helices // 2  # integer division, rounds down
        if num_helices % 2 == 0:
            # split even numbers in half
            pre_break_helices = [half]
        else:
            # get middle two for odd numbers
            pre_break_helices = [half, half + 1]
        # get the name of the original design; note that replace() removes only
        # the first ".pdb" occurrence in the file name
        name = row["pdb"].split("/")[-1].replace(".pdb", "", 1)
        for pre_break_helix in pre_break_helices:
            yield {
                "clash_cutoff": clash_cutoff,
                "extra_options": options,
                # default 0.33: interfaces must be a ratio of 1:3 or 3:1 between
                # the n and c term halves and the bound helix
                "int_cutoff": int_cutoff,
                "name": name,
                "pdb_path": scaffold,
                "pre_break_helix": pre_break_helix,
            }


# Scorefile written by the 01_prep_inputs step. Its index is used as each
# task's pdb_path by create_tasks.
prep_scorefile = os.path.join(os.getcwd(), "scaffolds/01_prep_inputs/scores.json")
scaffolds = parse_scorefile_linear(prep_scorefile)

  from distributed.utils import tmpfile


  0%|          | 0/153 [00:00<?, ?it/s]

### Make helix-bound states from the scaffolds

In [4]:
# Python standard library
import os
import pwd
import socket
import sys
import time

# 3rd party library imports
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Rosetta library imports
from pyrosetta.distributed.cluster.core import PyRosettaCluster

# Custom library imports
sys.path.insert(0, os.getcwd())
from crispy_shifty.protocols.states import (
    make_bound_states,
)  # the functions we will distribute


# Print instructions for viewing the dask dashboard (port 8787 on this node)
# from a local machine via SSH port forwarding.
print(
    "run the following from your local terminal to port forward the dashboard to localhost"
)
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)
print("dashboard is now visible at localhost:8000")
print(f"can also view dashboard at {socket.gethostname()}:8787 without port forwarding")
# Rosetta command-line options forwarded to every task as "extra_options"
# via create_tasks.
options = {
    "-out:level": "200",  # warning outputs only
}
# PyRosettaCluster output/scratch and the SLURM worker logs for this step all
# live under output_path.
output_path = os.path.join(os.getcwd(), "scaffolds/02_make_bound_states")
os.makedirs(output_path, exist_ok=True)
logs_path = os.path.join(output_path, "slurm_logs")
os.makedirs(logs_path, exist_ok=True)

# Inside a notebook __name__ is "__main__", so this guard always runs here.
if __name__ == "__main__":
    # configure SLURM cluster as a context manager; each SLURM job runs one
    # single-core dask worker with 8GB on the "short" partition for up to 3h
    # (see the #SBATCH header printed by cluster.job_script())
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="8GB",
        queue="short",
        walltime="3:00:00",
        death_timeout=120,
        local_directory="$TMPDIR",  # spill worker litter on local node temp storage
        log_directory=logs_path,
        # appended to the dask-worker command line: retire each worker after
        # ~3h, with up to 5m of random stagger so workers do not all exit together
        extra=["--lifetime", "3h", "--lifetime-stagger", "5m"],
    ) as cluster:
        print(cluster.job_script())
        # adaptively scale between 1 and 150 workers
        cluster.adapt(
            minimum=1,
            maximum=150,
            wait_count=999,  # Number of consecutive times that a worker should be suggested for removal before it is removed
            interval="5s",  # Time between checks
            target_duration="60s",  # desired time for outstanding work; controls how aggressively to scale up
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            client.upload_file(
                os.path.join(os.getcwd(), "crispy_shifty/protocols/states.py")
            )  # upload the script that contains the functions to distribute
            # Run make_bound_states on every task yielded by create_tasks
            # (presumably blocking until all tasks finish).
            PyRosettaCluster(
                client=client,
                logging_level="WARNING",
                output_path=output_path,
                project_name="crispy_shifty",
                scratch_dir=output_path,
                simulation_name="notebooks_02_make_bound_states",
                tasks=create_tasks(scaffolds, options),
            ).distribute(protocols=[make_bound_states])
            # Explicit teardown. The with-blocks would also close the client and
            # cluster; the sleeps presumably give workers time to wind down.
            time.sleep(5)
            client.close()
        time.sleep(5)
        cluster.scale(0)
        time.sleep(5)
        cluster.close()
    print("distributed run complete")

run the following from your local terminal to port forward the dashboard to localhost
ssh -L 8000:localhost:8787 pleung@dig72
dashboard is now visible at localhost:8000
can also view dashboard at dig72:8787 without port forwarding
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e /mnt/projects/crispy_shifty/scaffolds/02_make_bound_states/slurm_logs/dask-worker-%J.err
#SBATCH -o /mnt/projects/crispy_shifty/scaffolds/02_make_bound_states/slurm_logs/dask-worker-%J.out
#SBATCH -p short
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=8G
#SBATCH -t 3:00:00

/projects/crispy_shifty/envs/crispy/bin/python -m distributed.cli.dask_worker tcp://172.16.131.102:41635 --nthreads 1 --memory-limit 7.45GiB --name dummy-name --nanny --death-timeout 120 --local-directory $TMPDIR --lifetime 3h --lifetime-stagger 5m --protocol tcp://

<Client: 'tcp://172.16.131.102:41635' processes=0 threads=0, memory=0 B>
distributed run complete


### Make free states from the scaffolds

In [7]:
# Python standard library
import os
import pwd
import socket
import sys
import time

# 3rd party library imports
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Rosetta library imports
from pyrosetta.distributed.cluster.core import PyRosettaCluster

# Custom library imports
sys.path.insert(0, os.getcwd())
from crispy_shifty.protocols.states import (
    make_free_states,
)  # the functions we will distribute


# Print instructions for viewing the dask dashboard (port 8787 on this node)
# from a local machine via SSH port forwarding.
print(
    "run the following from your local terminal to port forward the dashboard to localhost"
)
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)
print("dashboard is now visible at localhost:8000")
print(f"can also view dashboard at {socket.gethostname()}:8787 without port forwarding")
# Rosetta command-line options forwarded to every task as "extra_options"
# via create_tasks.
options = {
    "-out:level": "200",  # warning outputs only
}
# PyRosettaCluster output/scratch and the SLURM worker logs for this step all
# live under output_path.
output_path = os.path.join(os.getcwd(), "scaffolds/02_make_free_states")
os.makedirs(output_path, exist_ok=True)
logs_path = os.path.join(output_path, "slurm_logs")
os.makedirs(logs_path, exist_ok=True)

# Inside a notebook __name__ is "__main__", so this guard always runs here.
if __name__ == "__main__":
    # configure SLURM cluster as a context manager; each SLURM job runs one
    # single-core dask worker with 8GB on the "short" partition for up to 3h
    # (see the #SBATCH header printed by cluster.job_script())
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="8GB",
        queue="short",
        walltime="3:00:00",
        death_timeout=120,
        local_directory="$TMPDIR",  # spill worker litter on local node temp storage
        log_directory=logs_path,
        # appended to the dask-worker command line: retire each worker after
        # ~3h, with up to 5m of random stagger so workers do not all exit together
        extra=["--lifetime", "3h", "--lifetime-stagger", "5m"],
    ) as cluster:
        print(cluster.job_script())
        # adaptively scale between 1 and 150 workers
        cluster.adapt(
            minimum=1,
            maximum=150,
            wait_count=999,  # Number of consecutive times that a worker should be suggested for removal before it is removed
            interval="5s",  # Time between checks
            target_duration="60s",  # desired time for outstanding work; controls how aggressively to scale up
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            client.upload_file(
                os.path.join(os.getcwd(), "crispy_shifty/protocols/states.py")
            )  # upload the script that contains the functions to distribute
            # Run make_free_states on every task yielded by create_tasks
            # (presumably blocking until all tasks finish).
            PyRosettaCluster(
                client=client,
                logging_level="WARNING",
                output_path=output_path,
                project_name="crispy_shifty",
                scratch_dir=output_path,
                simulation_name="notebooks_02_make_free_states",
                tasks=create_tasks(scaffolds, options),
            ).distribute(protocols=[make_free_states])
            # Explicit teardown. The with-blocks would also close the client and
            # cluster; the sleeps presumably give workers time to wind down.
            time.sleep(5)
            client.close()
        time.sleep(5)
        cluster.scale(0)
        time.sleep(5)
        cluster.close()
    print("distributed run complete")

run the following from your local terminal to port forward the dashboard to localhost
ssh -L 8000:localhost:8787 pleung@dig72
dashboard is now visible at localhost:8000
can also view dashboard at dig72:8787 without port forwarding
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e /mnt/projects/crispy_shifty/scaffolds/02_make_free_states/slurm_logs/dask-worker-%J.err
#SBATCH -o /mnt/projects/crispy_shifty/scaffolds/02_make_free_states/slurm_logs/dask-worker-%J.out
#SBATCH -p short
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=8G
#SBATCH -t 3:00:00

/projects/crispy_shifty/envs/crispy/bin/python -m distributed.cli.dask_worker tcp://172.16.131.102:34673 --nthreads 1 --memory-limit 7.45GiB --name dummy-name --nanny --death-timeout 120 --local-directory $TMPDIR --lifetime 3h --lifetime-stagger 5m --protocol tcp://

<Client: 'tcp://172.16.131.102:34673' processes=0 threads=0, memory=0 B>
distributed run complete


### Load resulting scorefiles of bound and free states

In [8]:
sys.path.insert(0, os.getcwd())
from crispy_shifty.utils.io import parse_scorefile_linear

# Parse the scorefiles written by the two distributed runs above: bound states
# first, then free states (the tuple is evaluated left to right).
bound_scores_df, free_scores_df = (
    parse_scorefile_linear(
        os.path.join(os.getcwd(), "scaffolds/02_make_bound_states/scores.json")
    ),
    parse_scorefile_linear(
        os.path.join(os.getcwd(), "scaffolds/02_make_free_states/scores.json")
    ),
)

  0%|          | 0/2591 [00:00<?, ?it/s]

  0%|          | 0/3082 [00:00<?, ?it/s]

### Save a list of outputs

In [9]:
# Write each scorefile's index (one entry per output) to a plain-text list,
# one entry per line; bound states first, then free states.
for scores_df, list_relpath in (
    (bound_scores_df, "scaffolds/02_make_bound_states/bound_states.list"),
    (free_scores_df, "scaffolds/02_make_free_states/free_states.list"),
):
    with open(os.path.join(os.getcwd(), list_relpath), "w") as list_file:
        list_file.writelines(str(path) + "\n" for path in tqdm(scores_df.index))

  0%|          | 0/2591 [00:00<?, ?it/s]

  0%|          | 0/3082 [00:00<?, ?it/s]

### Also save a CSV of just the free states that have 0 shift
We will need them later.

In [10]:
output_path = os.path.join(
    os.getcwd(), "scaffolds/02_make_free_states/free_state_0s.csv"
)
# Keep only the free states whose shift is the string "0" and whose pivot
# helix is the helix the scaffold was broken after.
is_unshifted = free_scores_df["shift"] == "0"
pivots_at_break = free_scores_df["pivot_helix"] == free_scores_df["pre_break_helix"]
free_state_0s = free_scores_df.loc[is_unshifted & pivots_at_break]
free_state_0s.to_csv(output_path)

### Prototyping blocks

test `make_bound_states`

In [None]:
%%time 
# Run make_bound_states locally on a single prepped scaffold to inspect its output.
import pyrosetta

pyrosetta.init()


sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.states import make_bound_states


# The first positional argument is None (presumably "no input pose", so the
# protocol loads pdb_path itself — confirm in states.py). The commented keys
# are optional overrides; create_tasks normally supplies clash_cutoff=5000 and
# int_cutoff=0.33. The result t is iterated in the next cell; each item
# exposes .pose and .scores.
t = make_bound_states(
        None,
        **{
            'pdb_path': '/mnt/projects/crispy_shifty/scaffolds/01_prep_inputs/decoys/0000/notebooks_01_prep_inputs_fa1b5ca9cef5486383f1054118203438.pdb.bz2',
            'name': 'DHR78_DHR71_l2_0_v2c',
            'pre_break_helix': 2,
#             'clash_cutoff': 5000,
#             'int_cutoff': 0.9,
#             'full_helix': True,
        }
)

In [None]:
# Dump every state produced by the prototype cell above to "<state>.pdb" in the
# current working directory. States that share the same scores["state"] value
# overwrite each other's file.
for tppose in t:
    tppose.pose.dump_pdb(f"{tppose.scores['state']}.pdb")

test `grow_terminal_helices`

In [None]:
# Prototype grow_terminal_helices on a single finished design.
import pyrosetta

sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.states import grow_terminal_helices


pyrosetta.init()
# NOTE(review): hard-coded path into a user's home directory; only resolvable
# from the original author's account.
tpose = pyrosetta.pose_from_file(
    "/home/pleung/projects/bistable_bundle/r4/helix_binders/08_analysis/pdbs/cs_088_Y.pdb"
)
# Grow chain 2 at both termini; 7 is presumably the number of residues added
# at each end — confirm against states.py.
tpose2 = grow_terminal_helices(
    pose=tpose,
    chain=2,
    extend_n_term=7,
    extend_c_term=7,
)

test `extend_helix_termini`

In [None]:
# Prototype extend_helix_termini on the same design used for grow_terminal_helices.
import pyrosetta

sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.states import extend_helix_termini


pyrosetta.init()
# NOTE(review): hard-coded path into a user's home directory; only resolvable
# from the original author's account.
tpose = pyrosetta.pose_from_file(
    "/home/pleung/projects/bistable_bundle/r4/helix_binders/08_analysis/pdbs/cs_088_Y.pdb"
)
# Extend chain 2 at both termini; 7 is presumably the number of residues added
# at each end — confirm against states.py.
tpose2 = extend_helix_termini(
    pose=tpose,
    chain=2,
    extend_n_term=7,
    extend_c_term=7,
)

In [None]:
# Writes whichever tpose2 the most recently executed prototyping cell produced
# (both the grow_terminal_helices and extend_helix_termini cells assign it).
tpose2.dump_pdb("test2.pdb")