# Multi-state design

### Boilerplate

In [1]:
%load_ext lab_black
# python internal
import collections
import copy
import gc
from glob import glob
import h5py
import itertools
import os
import random
import re
import socket
import shutil
import subprocess
import sys

# conda/pip
import dask
import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy
import seaborn as sns
from tqdm import tqdm

# special packages on the DIGS
import py3Dmol
import pymol
import pyrosetta

# notebook magic
%matplotlib inline
%load_ext autoreload
%autoreload 2

# show the working directory and the host name this notebook is running on
print(os.getcwd())
print(socket.gethostname())

/mnt/home/pleung/projects/bistable_bundle/r4/helix_binders
dig96


### Original approach for multistate design of helix binders
`/home/flop/switch/5thround/DHRs/allo3/msd1/msd_scripts/msd_fnr.py` is nonpolar interface  
`/home/flop/switch/5thround/DHRs/allo3/msd3_pol/msd_scripts/msd_fnr.py` allows polars, while  
`/home/flop/switch/5thround/DHRs/allo3/msd4_pol/msd_scripts/msd_fnr.py` uses constraints to force in polars.

I will use the serialization build of PyRosetta to enable recording user-defined info about the designs.
This enables downstream inline filtering and data analysis, as well as clustering by lineage.

Fix disulfides and chain labels, and optionally trim state X if it has a trailing loop  
TODO make sure it is ok to trim the way it is now  
TODO make a function that trims the smaller pose of X,Y safely (align by DSSP?). For now, the bases are largely covered by trimming the C-terminus of state X.

### Make functions for multi-state design
no TRP constraint

In [2]:
from pyrosetta.distributed.packed_pose.core import PackedPose
from typing import *


def msd(packed_pose_in: PackedPose, **kwargs) -> PackedPose:
    """
    Assumes middle split, allowing backwards selection
    """

    import bz2
    from copy import deepcopy
    import pyrosetta
    from pyrosetta.rosetta.core.pose import Pose
    import pyrosetta.distributed.io as io
    from pyrosetta.distributed.tasks.rosetta_scripts import (
        SingleoutputRosettaScriptsTask,
    )
    from pyrosetta.rosetta.core.pack.task.operation import (
        OperateOnResidueSubset,
        PreventRepackingRLT,
        RestrictAbsentCanonicalAASExceptNativeRLT,
        RestrictToRepackingRLT,
    )
    from pyrosetta.rosetta.core.select import get_residues_from_subset
    from pyrosetta.rosetta.core.select.residue_selector import (
        AndResidueSelector,
        ChainSelector,
        FalseResidueSelector,
        InterGroupInterfaceByVectorSelector,
        LayerSelector,
        NeighborhoodResidueSelector,
        NotResidueSelector,
        OrResidueSelector,
        ResidueIndexSelector,
        TrueResidueSelector,
    )

    def yeet_pose_xyz(pose, xyz=(1, 0, 0)):
        """
        Return a copy of a pose that has been translated by 100 * xyz.

        The input pose itself is left untouched; the move is applied to a
        clone, and every residue of that clone is moved together. No rotation
        is applied.
        @pleung @bcov @flop
        Args:
            pose (Pose): The pose to move.
            xyz (tuple): The cartesian 3D unit vector giving the direction.

        Returns:
            pose (Pose): The moved copy of the pose.
        """
        from pyrosetta.rosetta.core.select.residue_selector import TrueResidueSelector
        from pyrosetta.rosetta.protocols.toolbox.pose_manipulation import (
            rigid_body_move,
        )

        assert len(xyz) == 3
        moved = pose.clone()
        # select every residue so the whole pose moves as one rigid body
        everything = TrueResidueSelector().apply(moved)
        make_vector = pyrosetta.rosetta.numeric.xyzVector_double_t
        direction = make_vector(*xyz)
        # translation is the direction scaled up by a factor of 100
        displacement = make_vector(*(100 * component for component in xyz))
        # the 0 is passed between axis and translation; per the docstring
        # contract above no rotation is intended
        rigid_body_move(direction, 0, displacement, moved, everything)
        return moved

    def combined_pose_maker(poses=None) -> Pose:
        """
        Combine a list of poses into one multichain pose.

        The first pose is cloned and used as the base. Every following pose is
        moved away along its own axis direction with yeet_pose_xyz and then
        appended to the base by a jump. Only six directions are defined, so at
        most seven poses (the first one plus six more) can be combined.

        Args:
            poses (list): Poses to combine, in order. The list is not modified.

        Returns:
            Pose: A new pose holding copies of all input poses.

        Raises:
            RuntimeError: If no poses are passed.
            IndexError: If more than seven poses are passed.
        """
        if not poses:
            raise RuntimeError("Empty list of poses passed")
        # unpack instead of pop(0) so the caller's list is left untouched
        first, *rest = poses
        # unit vectors, one direction per appended pose
        xyzs = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (-1, 0, 0), (0, -1, 0), (0, 0, -1)]
        if len(rest) > len(xyzs):
            # fail before building anything rather than halfway through
            raise IndexError(
                "Can combine at most {} poses, got {}".format(
                    len(xyzs) + 1, len(rest) + 1
                )
            )
        new_pose = first.clone()
        # go through rest of poses and add them into the first one
        for pose, xyz in zip(rest, xyzs):
            # yeet_pose_xyz works on its own clone, so no extra clone is needed
            to_append = yeet_pose_xyz(pose, xyz)
            # NOTE(review): the PyRosetta Pose API documents the second argument
            # of append_pose_by_jump as the anchor residue, not a jump index;
            # confirm that num_jump() + 1 is the intended anchor
            new_pose.append_pose_by_jump(
                to_append,
                new_pose.num_jump() + 1,
            )
        return new_pose

    poses = []
    # load state Y
    if packed_pose_in == None:
        file = kwargs["-s"]
        with open(file, "rb") as f:
            packed_Y = io.pose_from_pdbstring(bz2.decompress(f.read()).decode())
        scores_Y = pyrosetta.distributed.cluster.get_scores_dict(file)["scores"]
        pose = io.to_pose(packed_Y)
        for key, value in scores_Y.items():
            pyrosetta.rosetta.core.pose.setPoseExtraScore(pose, key, value)
        poses.append(pose)
    else:
        raise RuntimeError("Need to supply an input for state Y")
    # load state X
    if kwargs["-x"] != None:
        file = kwargs["-x"]
        with open(file, "rb") as f:
            packed_X = io.pose_from_pdbstring(bz2.decompress(f.read()).decode())
        scores_X = pyrosetta.distributed.cluster.get_scores_dict(file)["scores"]
        pose = io.to_pose(packed_X)
        for key, value in scores_X.items():
            pyrosetta.rosetta.core.pose.setPoseExtraScore(pose, key, value)
        poses.insert(0, pose)
    else:
        raise RuntimeError("Need to supply an input for state X")

    state_X, state_Y = poses[0], poses[1]

    # check to see if trimming state X to the same length as state Y is needed
    if len(state_X.residues) > int(state_Y.chain_end(1)):
        pose_holder = pyrosetta.rosetta.core.pose.Pose()
        for i in range(1, int(state_Y.chain_end(1)) + 1):
            pose_holder.append_residue_by_bond(state_X.residue(i))
        state_X = pose_holder.clone()
        sw = pyrosetta.rosetta.protocols.simple_moves.SwitchChainOrderMover()
        sw.chain_order("1")
        sw.apply(state_X)
    elif len(state_X.residues) < int(state_Y.chain_end(1)):  # this is unexpected
        return
    else:
        pass

    # paste in helix from state Y into X after yeeting in a safe direction
    chB_alone = state_Y.clone()
    sw = pyrosetta.rosetta.protocols.simple_moves.SwitchChainOrderMover()
    sw.chain_order("2")
    sw.apply(chB_alone)
    chB_alone = yeet_pose_xyz(chB_alone, xyz=(-1, 0, 0))  # yeet in opposite direction
    state_X.append_pose_by_jump(chB_alone, state_X.num_jump() + 1)
    sw = pyrosetta.rosetta.protocols.simple_moves.SwitchChainOrderMover()
    sw.chain_order("12")
    sw.apply(state_X)
    for key, value in scores_X.items():
        pyrosetta.rosetta.core.pose.setPoseExtraScore(state_X, key, value)

    # make sure there isn't the same disulfide between the states for some reason
    if state_X.scores["disulfide_at"] == state_Y.scores["disulfide_at"]:
        return None
    else:
        pass

    if kwargs["ala_pen"] == None:
        ala_penalty = 1
    else:
        ala_penalty = kwargs["ala_pen"]
    if kwargs["np_pen"] == None:
        np_penalty = 3
    else:
        np_penalty = kwargs["np_pen"]
    og_np_penalty = deepcopy(np_penalty)
    og_ala_penalty = deepcopy(ala_penalty)
    scores = deepcopy(scores_Y)
    new_loop_resis = scores["new_loop_resis"]
    parent_sequence = state_X.sequence()
    # heavily penalize buried unsats, unset lk_ball since it isn't worth using
    # setup res_type_constraints for FNR, setup aa_comp, setup SAP constraint
    sfxn_obj = pyrosetta.rosetta.protocols.rosetta_scripts.XmlObjects.create_from_string(
        """
        <SCOREFXNS>
            <ScoreFunction name="sfxn" weights="beta_nov16" />
            <ScoreFunction name="sfxn_design" weights="beta_nov16" >
                <Set use_hb_env_dep="true" />
                <Reweight scoretype="approximate_buried_unsat_penalty" weight="17" />
                <Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
                <Set approximate_buried_unsat_penalty_hbond_energy_threshold="-1.0" />
                <Set approximate_buried_unsat_penalty_natural_corrections1="true" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_cross_chain="-7" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_ser_to_helix_bb="1"/>
                <Reweight scoretype="lk_ball" weight="0" />
                <Reweight scoretype="lk_ball_iso" weight="0" />
                <Reweight scoretype="lk_ball_bridge" weight="0" />
                <Reweight scoretype="lk_ball_bridge_uncpl" weight="0" />
                <Reweight scoretype="res_type_constraint" weight="2.0" />
                <Reweight scoretype="aa_composition" weight="1.0" />
                <Reweight scoretype="sap_constraint" weight="1.0" />
            </ScoreFunction>
        </SCOREFXNS>
        """
    )

    sfxn = sfxn_obj.get_score_function("sfxn_design")
    sfxn_clean = sfxn_obj.get_score_function("sfxn")
    res = scores["total_length"]
    score_per_res_X, score_per_res_Y = (
        sfxn_clean(state_X) / res,
        sfxn_clean(state_Y) / res,
    )

    def msd_fnr(
        despose,
        refpose,
        weight=0.0,
        strict_layers=False,
        neighbors=False,
        design_sel=None,
        design_helix=False,
    ):
        """
        Perform one pass of multi state design (MSD) using FavorNativeResidue (FNR)

        Residues of refpose are first copied onto despose wherever the two
        sequences disagree (cysteines are handled separately so disulfides in
        despose survive). despose is then redesigned in place with FastDesign
        under layer-based amino acid restrictions, a SAP constraint and an
        alanine composition penalty. All constraints are cleared at the end.

        Besides its arguments, this closure reads pyrosetta, state_X, state_Y,
        sfxn, np_penalty and ala_penalty from the enclosing scope at call time.

        Args:
            despose (Pose): Pose to design; modified in place.
            refpose (Pose): Pose whose residues are threaded onto despose.
            weight (float): Weight handed to FavorNativeResidue.
            strict_layers (bool): If True, positions whose layer differs between
                the two states are assigned to a core/boundary/surface group;
                otherwise only positions that are core (or surface) in both
                states count as core (or surface) and the rest are boundary.
            neighbors (bool): If True and design_sel is None, also design the
                neighborhood (distance 6) of the positions that differed.
            design_sel: Residue indices to design; when given it replaces the
                selection derived from the differing positions.
            design_helix (bool): If True, chain B interface residues are added
                to the list of differing positions.

        Returns:
            None
        """
        vector1 = pyrosetta.rosetta.utility.vector1_unsigned_long

        def index_selector_or_false(resis):
            """Select the given residue indices, or nothing if there are none."""
            if resis is not None and len(resis) > 0:
                return ResidueIndexSelector(resis)
            return FalseResidueSelector()

        def restrict_to(aas, selector):
            """Limit the selected residues to aas (plus their current identity)."""
            restriction = RestrictAbsentCanonicalAASExceptNativeRLT()
            restriction.aas_to_keep(aas)
            return OperateOnResidueSubset(restriction, selector, False)

        allres = get_residues_from_subset(TrueResidueSelector().apply(despose))
        diff = vector1()

        # interface between chains A and B, evaluated on despose before threading
        sel_chA = ChainSelector("A")
        sel_chB = ChainSelector("B")
        intsel = InterGroupInterfaceByVectorSelector(sel_chA, sel_chB)
        intB = AndResidueSelector(intsel, sel_chB)
        hlx_int = get_residues_from_subset(intB.apply(despose))
        # set for O(1) membership tests in the per-residue layer loop below
        hh_int = set(get_residues_from_subset(intsel.apply(despose)))

        #  for the case where a disulfide reuses the same residue in both states we want to break the bond
        des_dslfs = [int(i) for i in despose.scores["disulfide_at"].split(",")]
        ref_dslfs = [int(i) for i in refpose.scores["disulfide_at"].split(",")]
        # check each position for seq disagreement # TODO see if this fixes sequence length disagreement at c_term
        for i in allres:
            if despose.sequence(i, i) == "C":  # maintain disulfides in despose
                continue
            elif (
                refpose.sequence(i, i) == "C"
            ):  # safely replace despose residue with CYS (not CYD)
                mut = pyrosetta.rosetta.protocols.simple_moves.MutateResidue()
                mut.set_target(i)
                mut.set_res_name(pyrosetta.rosetta.core.chemical.AA(2))  # 2 is CYS
                mut.apply(despose)
            elif despose.sequence(i, i) != refpose.sequence(i, i):
                diff.append(i)
                despose.replace_residue(i, refpose.residue(i), 1)
        # optionally allow helixes to be designed
        if design_helix:
            for i in hlx_int:
                diff.append(i)
        # for the case where a disulfide reuses the same residue in both states we want to break the bond in refpose on despose
        # disulfide positions that are listed for refpose but not for despose
        lone_dslfs = set(ref_dslfs) - set(des_dslfs)
        for i in lone_dslfs:
            for j in des_dslfs:
                if pyrosetta.rosetta.core.conformation.is_disulfide_bond(
                    despose.conformation(), i, j
                ):
                    pyrosetta.rosetta.core.conformation.break_disulfide(
                        despose.conformation(), i, j
                    )

        # empty index lists fall back to FalseResidueSelector, matching the
        # guard already used for the layer groups below
        if design_sel is not None:
            designable = index_selector_or_false(design_sel)
        elif neighbors:  # design neighbors too
            designable = NeighborhoodResidueSelector(
                index_selector_or_false(diff),
                6,
                True,
            )
        else:  # design only diff
            designable = index_selector_or_false(diff)
        packable = NeighborhoodResidueSelector(designable, 6, True)
        # the trailing True flips the subset: everything outside designable is
        # repack-only, everything outside packable is frozen
        pack = OperateOnResidueSubset(RestrictToRepackingRLT(), designable, True)
        lock = OperateOnResidueSubset(PreventRepackingRLT(), packable, True)
        # add standard task operations
        arochi = pyrosetta.rosetta.protocols.task_operations.LimitAromaChi2Operation()
        arochi.chi2max(110)
        arochi.chi2min(70)
        arochi.include_trp(True)
        ifcl = pyrosetta.rosetta.core.pack.task.operation.InitializeFromCommandline()
        # setup custom layer design, layers are taken from both states
        # (state_X and state_Y from the enclosing scope, after threading)
        surf_sel = LayerSelector()
        surf_sel.set_layers(0, 0, 1)
        surf_sel.set_use_sc_neighbors(0)
        surf_sel.set_cutoffs(20, 50)
        surf1 = set(get_residues_from_subset(surf_sel.apply(state_X)))
        surf2 = set(get_residues_from_subset(surf_sel.apply(state_Y)))
        core_sel = LayerSelector()
        core_sel.set_layers(1, 0, 0)
        core_sel.set_use_sc_neighbors(0)
        core1 = set(get_residues_from_subset(core_sel.apply(state_X)))
        core2 = set(get_residues_from_subset(core_sel.apply(state_Y)))

        intf_layr = vector1()
        core_both = vector1()
        surf_both = vector1()
        bdry_core = vector1()
        bdry_surf = vector1()
        surf_core = vector1()
        bdry_both = vector1()
        # enumerate all 9 possible combinations + hh_int
        for i in allres:
            if i in hh_int:
                intf_layr.append(i)
            elif i in core1:
                if i in core2:
                    core_both.append(i)
                elif i in surf2:
                    surf_core.append(i)
                else:
                    bdry_core.append(i)
            elif i in surf1:
                if i in surf2:
                    surf_both.append(i)
                elif i in core2:
                    surf_core.append(i)
                else:
                    bdry_surf.append(i)
            else:
                if i in core2:
                    bdry_core.append(i)
                elif i in surf2:
                    bdry_surf.append(i)
                else:
                    bdry_both.append(i)

        sel_intf_layr = index_selector_or_false(intf_layr)
        sel_core_both = index_selector_or_false(core_both)
        sel_surf_both = index_selector_or_false(surf_both)
        sel_bdry_core = index_selector_or_false(bdry_core)
        sel_bdry_surf = index_selector_or_false(bdry_surf)
        sel_surf_core = index_selector_or_false(surf_core)
        sel_bdry_both = index_selector_or_false(bdry_both)
        if strict_layers:
            sel_c = OrResidueSelector(sel_core_both, sel_bdry_core)
            sel_b = OrResidueSelector(sel_bdry_both, sel_surf_core)
            sel_s = OrResidueSelector(sel_surf_both, sel_bdry_surf)
        else:
            sel_c = sel_core_both
            sel_s = sel_surf_both
            sel_c_or_s = OrResidueSelector(sel_core_both, sel_surf_both)
            sel_b = NotResidueSelector(sel_c_or_s)

        objs_sel = pyrosetta.rosetta.protocols.rosetta_scripts.XmlObjects.create_from_string(
            """
            <RESIDUE_SELECTORS>
                <SecondaryStructure name="sheet" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="E"/>
                <SecondaryStructure name="entire_loop" overlap="0" minH="3" minE="2" include_terminal_loops="true" use_dssp="true" ss="L"/>
                <SecondaryStructure name="entire_helix" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="H"/>
                <And name="helix_cap" selectors="entire_loop">
                    <PrimarySequenceNeighborhood lower="1" upper="0" selector="entire_helix"/>
                </And>
                <And name="helix_start" selectors="entire_helix">
                    <PrimarySequenceNeighborhood lower="0" upper="1" selector="helix_cap"/>
                </And>
                <And name="helix" selectors="entire_helix">
                    <Not selector="helix_start"/>
                </And>
                <And name="loop" selectors="entire_loop">
                    <Not selector="helix_cap"/>
                </And>
            </RESIDUE_SELECTORS>
            """
        )
        helix_sel = objs_sel.get_residue_selector("helix")
        loop_sel = objs_sel.get_residue_selector("loop")
        helix_cap_sel = objs_sel.get_residue_selector("helix_cap")

        # layer design task ops, allows the current residue at a given position if it is not included
        layer_ops = [
            restrict_to("AEFHIKLNQRSTVWYM", sel_intf_layr),
            restrict_to("DNSTP", helix_cap_sel),
            restrict_to("AFILVW", AndResidueSelector(sel_c, helix_sel)),
            restrict_to("AEHIKLNQRSTVWYM", AndResidueSelector(sel_b, helix_sel)),
            restrict_to("EHKQR", AndResidueSelector(sel_s, helix_sel)),
            restrict_to("AFGILPVW", AndResidueSelector(sel_c, loop_sel)),
            restrict_to("ADEFGHIKLNPQRSTVWY", AndResidueSelector(sel_b, loop_sel)),
            restrict_to("DEGHKNPQRST", AndResidueSelector(sel_s, loop_sel)),
        ]

        # push back all task ops, assumes no sheets
        task_factory = pyrosetta.rosetta.core.pack.task.TaskFactory()
        for task_op in [pack, lock, arochi, ifcl] + layer_ops:
            task_factory.push_back(task_op)

        # add design movers
        objs = pyrosetta.rosetta.protocols.rosetta_scripts.XmlObjects.create_from_string(
            """
            <MOVERS>
            <FastDesign name="fastdesign" repeats="1" relaxscript="InterfaceDesign2019"
                cartesian="false" dualspace="false" ramp_down_constraints="false"
                bondangle="false" bondlength="false" min_type="lbfgs_armijo_nonmonotone">
            </FastDesign>
            <AddSapConstraintMover name="add_sap" speed="lightning" sap_goal="0" penalty_per_sap="{np_penalty}" />
            <AddCompositionConstraintMover name="ala_pen" >
                <Comp entry="PENALTY_DEFINITION;TYPE ALA;ABSOLUTE 0;PENALTIES 0 {ala_penalty};DELTA_START 0;DELTA_END 1;BEFORE_FUNCTION CONSTANT;AFTER_FUNCTION LINEAR;END_PENALTY_DEFINITION;" />
            </AddCompositionConstraintMover>
            </MOVERS>
            """.format(
                np_penalty=np_penalty, ala_penalty=ala_penalty
            )
        )
        # SAP and alanine constraints are added whether or not design runs;
        # they are removed again by the ClearConstraintsMover below
        objs.get_mover("add_sap").apply(despose)
        objs.get_mover("ala_pen").apply(despose)
        fast_design = objs.get_mover("fastdesign")
        fast_design.set_scorefxn(sfxn)
        fast_design.set_task_factory(task_factory)
        # skip design if sequences have already converged
        # NOTE(review): this gate looks only at diff, so a pass with an explicit
        # design_sel is also skipped when nothing differs - confirm intended
        if len(diff) > 0:
            pyrosetta.rosetta.protocols.protein_interface_design.FavorNativeResidue(
                despose, weight
            )
            fast_design.apply(despose)
        # remove constraints
        clear_constraints = (
            pyrosetta.rosetta.protocols.constraint_movers.ClearConstraintsMover()
        )
        clear_constraints.apply(despose)
        return

    # recover original interfacial residues and combine those from each state, assumes middle split
    objs_sse = pyrosetta.rosetta.protocols.rosetta_scripts.XmlObjects.create_from_string(
        """
        <RESIDUE_SELECTORS>
            <SSElement name="part1" selection="n_term" to_selection="{pre},H,E" chain="A" reassign_short_terminal_loop="2" />
            <SSElement name="part2" selection="-{post},H,S" to_selection="c_term" chain="A" reassign_short_terminal_loop="2" />
        </RESIDUE_SELECTORS>
        """.format(
            pre=int(scores["pre_break_helix"]),
            post=int(scores["pre_break_helix"]),
        )
    )
    part1 = objs_sse.get_residue_selector("part1")
    part2 = objs_sse.get_residue_selector("part2")
    intsel = InterGroupInterfaceByVectorSelector(part1, part2)
    intdes = get_residues_from_subset(intsel.apply(state_Y))
    intref = get_residues_from_subset(intsel.apply(state_X))
    intall = pyrosetta.rosetta.utility.vector1_unsigned_long()
    # add all residues in either interface to be designed
    for i in intdes:
        intall.append(i)
    for i in intref:
        intall.append(i)
    # one round msd with no weight, lenient layers, no neighbors on all residues that are interface in either state
    msd_fnr(
        despose=state_Y,
        refpose=state_X,
        weight=0,
        strict_layers=False,
        neighbors=False,
        design_sel=intall,
        design_helix=False,
    )
    # one round msd with no weight, strict layers, and neighbors on all residues that are different between states
    msd_fnr(
        despose=state_X,
        refpose=state_Y,
        weight=0,
        strict_layers=True,
        neighbors=True,
        design_helix=False,
    )
    # one round msd with no weight, strict layers, no neighbors on all residues that are different between states
    msd_fnr(
        despose=state_Y,
        refpose=state_X,
        weight=0,
        strict_layers=True,
        design_helix=True,
    )
    # ramp weight over three values (one X pass and one Y pass each) with strict layers, no neighbors on all residues that are different between states
    for wt in [0.2, 0.5, 1.0]:
        msd_fnr(
            despose=state_X,
            refpose=state_Y,
            weight=wt,
            strict_layers=True,
            design_helix=False,
        )
        msd_fnr(
            despose=state_Y,
            refpose=state_X,
            weight=wt,
            strict_layers=True,
            design_helix=True,
        )
    # two rounds, ramp weight with lenient layers, no neighbors on all residues that are different between states
    for wt in [1.5, 2.0]:
        msd_fnr(
            despose=state_X,
            refpose=state_Y,
            weight=wt,
            strict_layers=False,
            design_helix=False,
        )
        msd_fnr(
            despose=state_Y,
            refpose=state_X,
            weight=wt,
            strict_layers=False,
            design_helix=True,
        )
    # set SAP penalty to 1 and alanine penalty to 0 for the last rounds
    np_penalty = 1
    ala_penalty = 0
    wt = 10
    # three final passes (X and Y at weight 10, then X at weight 100) with lenient layers, no neighbors on all residues that are different between states
    msd_fnr(
        despose=state_X,
        refpose=state_Y,
        weight=wt,
        strict_layers=False,
        design_helix=False,
    )
    msd_fnr(
        despose=state_Y,
        refpose=state_X,
        weight=wt,
        strict_layers=False,
        design_helix=False,
    )
    wt = 100  # force convergence
    msd_fnr(
        despose=state_X,
        refpose=state_Y,
        weight=wt,
        strict_layers=False,
        design_helix=False,
    )

    # if sequences fail to converge, report failure and do not yield combined pose
    try:
        assert state_X.sequence() == state_Y.sequence()
    except AssertionError:
        print("Convergence failure with the following sequences:")
        print("X:", state_X.sequence())
        print("Y:", state_Y.sequence())
        return
    combined_scores = {}
    combined_scores["closure_type_X"] = scores_X["closure_type"]
    combined_scores["closure_type_Y"] = scores_Y["closure_type"]
    combined_scores["disulfide_at_X"] = scores_X["disulfide_at"]
    combined_scores["disulfide_at_Y"] = scores_Y["disulfide_at"]
    combined_scores["dslf_fa13_cart_X"] = scores_X["dslf_fa13_cart"]
    combined_scores["dslf_fa13_cart_Y"] = scores_Y["dslf_fa13_cart"]
    combined_scores["rmsd_cart_X"] = scores_X["rmsd_cart"]
    combined_scores["rmsd_cart_Y"] = scores_Y["rmsd_cart"]
    combined_scores["parent_sequence"] = parent_sequence
    combined_scores["ala_penalty"] = og_ala_penalty
    combined_scores["np_penalty"] = og_np_penalty
    common_keys = [
        "new_loop_resis",
        "parent",
        "bb_clash",
        "pivot_helix",
        "pre_break_helix",
        "shift",
        "total_length",
    ]
    for common_key in common_keys:
        combined_scores[common_key] = scores[common_key]
    # make base scoring xml
    xml_base = """
    <ROSETTASCRIPTS>
        <SCOREFXNS>
            <ScoreFunction name="sfxn" weights="beta_nov16" />
            <ScoreFunction name="sfxn_design" weights="beta_nov16" >
                <Set use_hb_env_dep="true" />
                <Reweight scoretype="approximate_buried_unsat_penalty" weight="17" />
                <Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
                <Set approximate_buried_unsat_penalty_hbond_energy_threshold="-1.0" />
                <Set approximate_buried_unsat_penalty_natural_corrections1="true" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_cross_chain="-7" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_ser_to_helix_bb="1"/>                    
            </ScoreFunction>
        </SCOREFXNS>
        <RESIDUE_SELECTORS>
            <Chain name="chA" chains="A"/>
            <Chain name="chB" chains="B"/>
            <Index name="new_loop_resis" resnums="{new_loop_resis}" />
            <Neighborhood name="around_new_loop" selector="new_loop_resis" distance="8.0" />
            <SSElement name="part1" selection="n_term" to_selection="{pre},H,E" chain="A" reassign_short_terminal_loop="2" />
            <SSElement name="mid" selection="{pre},H,E" to_selection="-{post},H,S" chain="A" reassign_short_terminal_loop="2" />
            <SSElement name="part2" selection="-{post},H,S" to_selection="c_term" chain="A" reassign_short_terminal_loop="2" />
            <Or name="chB_OR_mid" selectors="chB,mid" />
        </RESIDUE_SELECTORS>
        <TASKOPERATIONS>
            <IncludeCurrent name="current" />
            <LimitAromaChi2 name="arochi" chi2max="110" chi2min="70" include_trp="True" />
            <ExtraRotamersGeneric name="ex1_ex2" ex1="1" ex2="1" />
            <InitializeFromCommandline name="ifcl"/>
            <ProteinInterfaceDesign name="pack_long" design_chain1="0" design_chain2="0" jump="1" interface_distance_cutoff="15"/>
        </TASKOPERATIONS>
        <MOVERS>
            <SavePoseMover name="save_before_relax" restore_pose="0" reference_name="before_relax"/>
            <DeleteRegionMover name="cutn" residue_selector="part1" rechain="true" />
            <DeleteRegionMover name="cutc" residue_selector="part2" rechain="true" />
            <DeleteRegionMover name="cutB" residue_selector="chB_OR_mid" rechain="true" />
            <FastRelax name="relax" scorefxn="sfxn_design" repeats="1" batch="false" ramp_down_constraints="false"
                cartesian="false" bondangle="false" bondlength="false" min_type="dfpmin_armijo_nonmonotone"
                task_operations="ifcl,current,arochi,ex1_ex2" >
            </FastRelax>
            <TaskAwareMinMover name="min" scorefxn="sfxn" bb="0" chi="1" task_operations="pack_long" />
        </MOVERS>
        <FILTERS>
            <BuriedUnsatHbonds name="vbuns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true" 
                scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0" 
                use_ddG_style="false" dalphaball_sasa="true" probe_radius="1.1" atomic_depth_selection="5.5" 
                burial_cutoff="1000" burial_cutoff_apo="0.2" />
            <BuriedUnsatHbonds name="sbuns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true"
                scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0"
                use_ddG_style="false" burial_cutoff="0.01" dalphaball_sasa="true" probe_radius="1.1" 
                atomic_depth_selection="5.5" atomic_depth_deeper_than="false" />
            <BuriedUnsatHbonds name="buns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true" 
                scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0" 
                use_ddG_style="false" burial_cutoff="0.01" dalphaball_sasa="true" probe_radius="1.1"
                max_hbond_energy="1.5" burial_cutoff_apo="0.2" />
            <SSShapeComplementarity name="sc" verbose="1" loops="1" helices="1" />
            <TaskAwareScoreType name="tot_score" scorefxn="sfxn" score_type="total_score" threshold="0" mode="total"  confidence="0" />
            <MoveBeforeFilter name="score_nc" mover="cutB" filter="tot_score" confidence="0" />
            <MoveBeforeFilter name="score_nB" mover="cutc" filter="tot_score" confidence="0" />
            <MoveBeforeFilter name="score_Bc" mover="cutn" filter="tot_score" confidence="0" />
            <ExposedHydrophobics name="exposed_hydrophobics" />
            <Geometry name="geometry"
                confidence="0"
                count_bad_residues="true" />
            <Geometry name="geometry_loop" 
                residue_selector="around_new_loop" 
                confidence="0"
                count_bad_residues="true" />
            <SSPrediction name="mismatch_probability" confidence="0" 
                cmd="/software/psipred4/runpsipred_single" use_probability="1" 
                mismatch_probability="1" use_svm="1" />
            <Rmsd name="rmsd_final" reference_name="before_relax" chains="AB" superimpose="1" threshold="5" by_aln="0" confidence="0" />
            <ScoreType name="total_score_pose" scorefxn="sfxn" score_type="total_score" threshold="0" confidence="0" />
            <ResidueCount name="count" />
            <CalculatorFilter name="score_per_res" equation="total_score_full / res" threshold="-2.0" confidence="0">
                <Var name="total_score_full" filter="total_score_pose"/>
                <Var name="res" filter="count"/>
            </CalculatorFilter>        
            <worst9mer name="wnm_all" rmsd_lookup_threshold="0.4" confidence="0" />
            <worst9mer name="wnm_hlx" rmsd_lookup_threshold="0.4" confidence="0" only_helices="true" />
    """.format(
        new_loop_resis=new_loop_resis,
        pre=int(scores["pre_break_helix"]),
        post=int(scores["pre_break_helix"]),
    )
    # score state X
    xml_X = (
        xml_base
        + """
            </FILTERS>
            <SIMPLE_METRICS>
                <SapScoreMetric name="sap_score" />
            </SIMPLE_METRICS>
            <APPLY_TO_POSE>
            </APPLY_TO_POSE>
            <PROTOCOLS>
                <Add mover_name="save_before_relax" />
                <Add mover_name="relax"/>
                <Add filter_name="buns" />
                <Add filter_name="sbuns" />
                <Add filter_name="vbuns" />
                <Add filter_name="exposed_hydrophobics" />
                <Add filter_name="geometry"/>
                <Add filter_name="geometry_loop"/>
                <Add filter_name="mismatch_probability" />
                <Add filter_name="rmsd_final" />
                <Add metrics="sap_score" />
                <Add filter_name="sc" />
                <Add filter_name="score_nc" />
                <Add filter_name="score_nB" />
                <Add filter_name="score_Bc" />
                <Add filter_name="score_per_res" />
                <Add filter_name="wnm_all" />
                <Add filter_name="wnm_hlx" />
            </PROTOCOLS>
            <OUTPUT scorefxn="sfxn" />
        </ROSETTASCRIPTS>
        """
    )
    score_X = SingleoutputRosettaScriptsTask(xml_X)
    scored_X = score_X(state_X.clone())
    scores_X = scored_X.pose.scores
    scores_X = {f"{key}_X": value for key, value in scores_X.items()}
    # score state Y
    xml_Y = (
        xml_base
        + """
                <ContactMolecularSurface name="cms" target_selector="chA" binder_selector="chB" confidence="0" />
                <ContactMolecularSurface name="cms_nc" target_selector="part1" binder_selector="part2" confidence="0" />
                <ContactMolecularSurface name="cms_nB" target_selector="part1" binder_selector="chB" confidence="0" />
                <ContactMolecularSurface name="cms_Bc" target_selector="part2" binder_selector="chB" confidence="0" />
                <Ddg name="ddg" threshold="-10" jump="1" repeats="5" repack="1" relax_mover="min" confidence="0" scorefxn="sfxn" />
                <Sasa name="sasa" confidence="0" />
                <MoveBeforeFilter name="sasa_nc" mover="cutB" filter="sasa" confidence="0" />
                <MoveBeforeFilter name="sasa_nB" mover="cutc" filter="sasa" confidence="0" />
                <MoveBeforeFilter name="sasa_Bc" mover="cutn" filter="sasa" confidence="0" />
                <ShapeComplementarity name="sc_int" verbose="0" min_sc="0.55" write_int_area="1" write_median_dist="1" jump="1" confidence="0"/>
                <MoveBeforeFilter name="sc_int_nc" mover="cutB" filter="sc_int" confidence="0" />
                <MoveBeforeFilter name="sc_int_nB" mover="cutc" filter="sc_int" confidence="0" />
                <MoveBeforeFilter name="sc_int_Bc" mover="cutn" filter="sc_int" confidence="0" />

            </FILTERS>
            <SIMPLE_METRICS>
                <SapScoreMetric name="sap_score" />
            </SIMPLE_METRICS>
            <APPLY_TO_POSE>
            </APPLY_TO_POSE>
            <PROTOCOLS>
                <Add mover_name="save_before_relax" />
                <Add mover_name="relax"/>
                <Add filter_name="buns" />
                <Add filter_name="sbuns" />
                <Add filter_name="vbuns" />
                <Add filter_name="cms" />
                <Add filter_name="cms_nc" />
                <Add filter_name="cms_nB" />
                <Add filter_name="cms_Bc" />
                <Add filter_name="ddg" />
                <Add filter_name="sasa" />
                Add filter_name="sasa_nc" />
                Add filter_name="sasa_nB" />
                Add filter_name="sasa_Bc" />
                <Add filter_name="tot_score" />
                <Add filter_name="score_nc" />
                <Add filter_name="score_nB" />
                <Add filter_name="score_Bc" />
                <Add filter_name="sc" />
                <Add filter_name="sc_int" />
                Add filter_name="sc_int_nc" />
                Add filter_name="sc_int_nB" />
                Add filter_name="sc_int_Bc" />
                <Add filter_name="exposed_hydrophobics" />
                <Add filter_name="geometry"/>
                <Add filter_name="geometry_loop"/>
                <Add filter_name="mismatch_probability" />
                <Add filter_name="rmsd_final" />
                <Add metrics="sap_score" />
                <Add filter_name="score_per_res" />
                <Add filter_name="wnm_all" />
                <Add filter_name="wnm_hlx" />
            </PROTOCOLS>
            <OUTPUT scorefxn="sfxn" />
        </ROSETTASCRIPTS>
        """
    )

    score_Y = SingleoutputRosettaScriptsTask(xml_Y)
    scored_Y = score_Y(state_Y.clone())
    scores_Y = scored_Y.pose.scores
    scores_Y = {f"{key}_Y": value for key, value in scores_Y.items()}

    combined_XY = combined_pose_maker([state_X, state_Y])
    sw = pyrosetta.rosetta.protocols.simple_moves.SwitchChainOrderMover()
    sw.chain_order("1234")
    sw.apply(combined_XY)
    combined_scores.update({**scores_X, **scores_Y})
    # clear scores and update
    pyrosetta.rosetta.core.pose.clearPoseExtraScores(combined_XY)
    for key, value in combined_scores.items():
        pyrosetta.rosetta.core.pose.setPoseExtraScore(combined_XY, key, value)
    scored_ppose = io.to_pose(combined_XY)
    return scored_ppose

### Setup dask, set command line options, make tasks and submit to client again to test msd
At some point I should maybe try using `client.wait_for_workers(n_workers=1, timeout=None)`  
Using `ala_pen = 2`, `np_pen = 1`

In [3]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from glob import glob
import logging
import pwd
from pyrosetta.distributed.cluster.core import PyRosettaCluster


# Print an ssh command that forwards local port 8000 to port 8787 on this host.
# NOTE(review): 8787 is presumably the dask dashboard port — confirm.
print("run the following from your local terminal:")
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)

# Module-level penalty values; create_tasks() copies them into every task.
ala_pen = 2
np_pen = 1


def create_tasks(selected, options):
    """
    Yield one task dictionary per non-blank line of every pair file.

    Each line is split on whitespace; the first field is stored under "-s"
    and the second under "-x".

    Args:
        selected: iterable of paths to pair files.
        options: object stored under "extra_options" in every task.

    Yields:
        dict with the keys "options", "extra_options", "-s", "-x", "ala_pen"
        and "np_pen" (in that order). The two penalties are read from the
        module-level names `ala_pen` and `np_pen`.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
    """
    for pair in selected:
        with open(pair, "r") as f:
            for line in f:
                # split() without arguments also discards the trailing newline
                paths = line.split()
                if not paths:
                    # a blank or whitespace-only line carries no pair; skip it
                    # rather than failing on paths[0]
                    continue
                yield {
                    "options": "-corrections::beta_nov16 true",
                    "extra_options": options,
                    "-s": paths[0],
                    "-x": paths[1],
                    "ala_pen": ala_pen,
                    "np_pen": np_pen,
                }


# Log at INFO level and gather the .pair files of each scaffold family found
# under 04_pairs/paired in the working directory.
logging.basicConfig(level=logging.INFO)
pairs = glob(os.path.join(os.getcwd(), "04_pairs/paired/DHR*/*.pair"))
hDHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/hDHR*/*.pair"))
TH_DHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/TH_DHR*/*.pair"))
# keep only the TH_DHR pair files whose path contains none of "_S2", "_C2", "_C9"
TH_DHRS = [t for t in TH_DHRS if ("_S2" not in t and "_C2" not in t and "_C9" not in t)]
hTH_DHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/hTH_DHR*/*.pair"))
KH_DHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/KH*/*.pair"))
selected = pairs + hDHRS + TH_DHRS + hTH_DHRS + KH_DHRS

print(f"Desigining {len(selected)} pairs")

# Flags handed to create_tasks(), which stores them under "extra_options"
options = {
    "-out:level": "300",
    "-holes:dalphaball": "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
    "-indexed_structure_store:fragment_store": "/net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
}

# used as both scratch_dir and output_path of the PyRosettaCluster run
output_path = os.path.join(os.getcwd(), "05_msd_test")

if __name__ == "__main__":
    # configure SLURM cluster as a context manager
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="medium",
        walltime="23:30:00",
        death_timeout=120,
        local_directory="$TMPDIR/dask",
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
    ) as cluster:
        print(cluster.job_script())
        # scale adaptively between 1 and 1020 workers
        cluster.adapt(
            minimum=1,
            maximum=1020,
            wait_count=400,  # Number of consecutive times that a worker should be suggested for removal before it is removed
            interval="5s",  # Time between checks
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            # distribute the msd protocol over the tasks from create_tasks,
            # one structure (nstruct=1) per task
            PyRosettaCluster(
                tasks=create_tasks(selected, options),
                client=client,
                scratch_dir=output_path,
                output_path=output_path,
                nstruct=1,
            ).distribute(protocols=[msd])

run the following from your local terminal:
ssh -L 8000:localhost:8787 pleung@dig175
Desigining 535 pairs
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e /mnt/home/pleung/logs/slurm_logs/dask-worker-%J.err
#SBATCH -o /mnt/home/pleung/logs/slurm_logs/dask-worker-%J.out
#SBATCH -p medium
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=4G
#SBATCH -t 23:30:00

JOB_ID=${SLURM_JOB_ID%;*}

/home/pleung/.conda/envs/phil/bin/python -m distributed.cli.dask_worker tcp://172.16.131.12:39543 --nthreads 1 --memory-limit 3.73GiB --name name --nanny --death-timeout 120 --local-directory $TMPDIR/dask --lifetime 23h --lifetime-stagger 4m

<Client: 'tcp://172.16.131.12:39543' processes=0 threads=0, memory=0 B>


INFO:pyrosetta.distributed:maybe_init performing pyrosetta initialization: {'options': '-run:constant_seed 1 -multithreading:total_threads 1', 'extra_options': '-mute all', 'set_logging_handler': 'interactive', 'silent': True}
INFO:pyrosetta.rosetta:Found rosetta database at: /home/pleung/.conda/envs/phil/lib/python3.8/site-packages/pyrosetta/database; using it....
INFO:pyrosetta.rosetta:PyRosetta-4 2021 [Rosetta PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python38.Release 2021.12+release.ed6a5560506cfd327d4a6a3e2c9b0c9f6f4a6535 2021-03-26T16:09:25] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.


### We're going to do a huge run here, more than it makes sense to use dask for
Will submit ~75k CPU hours worth of stuff to `short` and `medium` and ~150k to `backfill`
Make `SLURM` array tasks

In [3]:
import os, stat, subprocess


def create_tasks(selected):
    """
    Yield command-line argument dictionaries for every pair / penalty combination.

    Each non-blank line of every pair file is split on whitespace; the first
    field is stored under "-s" and the second under "-x". Every pair is
    emitted four times, once per combination of "-ala_pen" in (1, 2) and
    "-np_pen" in (1, 2), with "-np_pen" varying fastest.

    Args:
        selected: iterable of paths to pair files.

    Yields:
        dict with the keys "-s", "-x", "-ala_pen" and "-np_pen", in that order.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
    """
    for pair in selected:
        with open(pair, "r") as f:
            for line in f:
                # parse the line once, not once per penalty combination
                paths = line.split()
                if not paths:
                    # blank or whitespace-only line: nothing to schedule
                    continue
                for ala_pen in 1, 2:
                    for np_pen in 1, 2:
                        yield {
                            "-s": paths[0],
                            "-x": paths[1],
                            "-ala_pen": ala_pen,
                            "-np_pen": np_pen,
                        }


# Gather the .pair files of the DHR, hDHR, TH_DHR and hTH_DHR families found
# under 04_pairs/paired in the working directory.
pairs = glob(os.path.join(os.getcwd(), "04_pairs/paired/DHR*/*.pair"))
hDHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/hDHR*/*.pair"))
TH_DHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/TH_DHR*/*.pair"))
# keep only the TH_DHR pair files whose path contains none of "_S2", "_C2", "_C9"
TH_DHRS = [t for t in TH_DHRS if ("_S2" not in t and "_C2" not in t and "_C9" not in t)]
hTH_DHRS = glob(os.path.join(os.getcwd(), "04_pairs/paired/hTH_DHR*/*.pair"))
selected = pairs + hDHRS + TH_DHRS + hTH_DHRS


# Script that every generated command line invokes
msd_py = os.path.join(os.getcwd(), "msd.py")

# Shell parameter expansions, kept out of the template so that str.format()
# does not try to interpret their braces; the template supplies the leading "$".
jid = "{SLURM_JOB_ID%;*}"
sid = "{SLURM_ARRAY_TASK_ID}p"

# The argument strings do not depend on the task list index or on nstruct, so
# read the pair files once here instead of once per (list, nstruct) combination.
task_args = [
    " ".join(f"{k} {v}" for k, v in tasks.items()) for tasks in create_tasks(selected)
]

# One (index, SLURM partition) entry per run script / task list to generate
for i, queue in [
    (0, "short"),
    (1, "short"),
    (2, "short"),
    (3, "short"),
    (4, "short"),
]:
    tasklist = f"05_msd_tasks_{i}.cmds"
    # sbatch script: picks line number $SLURM_ARRAY_TASK_ID out of the task
    # list with sed and pipes that command line to bash
    run_sh = """#!/usr/bin/env bash \n#SBATCH -J m1n10n \n#SBATCH -e /mnt/home/pleung/logs/slurm_logs/m1n10n-%J.err \n#SBATCH -o /mnt/home/pleung/logs/slurm_logs/m1n10n-%J.out \n#SBATCH -p {queue} \n#SBATCH --mem=4G \n\nJOB_ID=${jid} \nCMD=$(sed -n "${sid}" {tasklist}) \necho "${c}" | bash""".format(
        queue=queue, jid=jid, sid=sid, tasklist=tasklist, c="{CMD}"
    )
    shell = f"05_msd_run_{i}.sh"
    with open(shell, "w") as f:
        print(run_sh, file=f)
    # add the owner-executable bit to the run script
    st = os.stat(shell)
    os.chmod(shell, st.st_mode | stat.S_IEXEC)
    outpath = os.path.join(os.getcwd(), f"05_msd_runs_{i}")
    with open(tasklist, "w") as f:
        for nstruct in range(0, 25):
            # every repeat gets its own output directory, created by the
            # command itself when it runs
            full_outpath = os.path.join(outpath, f"{nstruct}")
            for args_ in task_args:
                cmd = f"mkdir -p {full_outpath}; cd {full_outpath}; {msd_py} {args_}"
                print(cmd, file=f)

# Show one sbatch array submission per generated task list; the array size is
# the number of lines in that task list.
print("Run the following commands")
for i in map(str, range(5)):
    print(
        "sbatch -a 1-$(cat 05_msd_tasks_{0}.cmds | wc -l) 05_msd_run_{0}.sh".format(i)
    )

Run the following commands
sbatch -a 1-$(cat 05_msd_tasks_0.cmds | wc -l) 05_msd_run_0.sh
sbatch -a 1-$(cat 05_msd_tasks_1.cmds | wc -l) 05_msd_run_1.sh
sbatch -a 1-$(cat 05_msd_tasks_2.cmds | wc -l) 05_msd_run_2.sh
sbatch -a 1-$(cat 05_msd_tasks_3.cmds | wc -l) 05_msd_run_3.sh
sbatch -a 1-$(cat 05_msd_tasks_4.cmds | wc -l) 05_msd_run_4.sh


### Get reference values for state X
for filtering purposes

In [None]:
from pyrosetta.distributed import requires_init
from pyrosetta.distributed.packed_pose.core import PackedPose


@requires_init
def score_ref(packed_pose_in: PackedPose, **kwargs) -> PackedPose:
    """
    Score the parents for purposes of comparison.

    Loads the structure at kwargs["-s"], runs the RosettaScripts protocol
    defined below on a clone of it (save a reference pose, FastRelax, then the
    listed filters and the SAP metric), and records the input file's base name
    without ".pdb" as the "parent" score.

    Args:
        packed_pose_in: must be None; the pose is always read from kwargs["-s"].
        **kwargs: must contain "-s" (path of the structure to score) and
            "pre_break_helix" (substituted into both SSElement selectors).

    Returns:
        The scored pose, packed, with the extra "parent" score set.

    Raises:
        RuntimeError: if packed_pose_in is not None.
    """
    import pyrosetta
    import pyrosetta.distributed.io as io
    from pyrosetta.distributed.tasks.rosetta_scripts import (
        SingleoutputRosettaScriptsTask,
    )

    # identity check: None is a singleton, so do not rely on ==
    if packed_pose_in is None:
        packed_pose_in = io.pose_from_file(kwargs["-s"])
    else:
        raise RuntimeError(
            "score_ref expects packed_pose_in to be None; "
            "the structure is read from kwargs['-s']"
        )
    parent = kwargs["-s"].split("/")[-1].replace(".pdb", "")
    pre_break_helix = kwargs["pre_break_helix"]
    sfxn = "beta_nov16"
    xml = """
        <ROSETTASCRIPTS>
            <SCOREFXNS>
                <ScoreFunction name="sfxn" weights="{sfxn}" />
                <ScoreFunction name="sfxn_design" weights="{sfxn}" >
                    <Set use_hb_env_dep="true" />
                    <Reweight scoretype="approximate_buried_unsat_penalty" weight="17" />
                    <Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
                    <Set approximate_buried_unsat_penalty_hbond_energy_threshold="-1.0" />
                    <Set approximate_buried_unsat_penalty_natural_corrections1="true" />
                    <Set approximate_buried_unsat_penalty_hbond_bonus_cross_chain="-7" />
                    <Set approximate_buried_unsat_penalty_hbond_bonus_ser_to_helix_bb="1"/>                    
                </ScoreFunction>
            </SCOREFXNS>
            <RESIDUE_SELECTORS>
                <SSElement name="part1" selection="n_term" to_selection="{pre},H,E" chain="A" reassign_short_terminal_loop="2" />
                <SSElement name="part2" selection="-{post},H,S" to_selection="c_term" chain="A" reassign_short_terminal_loop="2" />
            </RESIDUE_SELECTORS>
            <TASKOPERATIONS>
                <IncludeCurrent name="current" />
                <LimitAromaChi2 name="arochi" chi2max="110" chi2min="70" include_trp="True" />
                <ExtraRotamersGeneric name="ex1_ex2" ex1="1" ex2="1" />
                <InitializeFromCommandline name="ifcl"/>
            </TASKOPERATIONS>
            <MOVERS>
                <SavePoseMover name="save_before_relax" restore_pose="0" reference_name="before_relax"/>
            </MOVERS>
            <FILTERS>
                <BuriedUnsatHbonds name="vbuns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true" 
                    scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0" 
                    use_ddG_style="false" dalphaball_sasa="true" probe_radius="1.1" atomic_depth_selection="5.5" 
                    burial_cutoff="1000" burial_cutoff_apo="0.2" />
                <BuriedUnsatHbonds name="sbuns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true"
                    scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0"
                    use_ddG_style="false" burial_cutoff="0.01" dalphaball_sasa="true" probe_radius="1.1" 
                    atomic_depth_selection="5.5" atomic_depth_deeper_than="false" />
                <BuriedUnsatHbonds name="buns" use_reporter_behavior="true" report_all_heavy_atom_unsats="true" 
                    scorefxn="sfxn" ignore_surface_res="false" print_out_info_to_pdb="true" confidence="0" 
                    use_ddG_style="false" burial_cutoff="0.01" dalphaball_sasa="true" probe_radius="1.1"
                    max_hbond_energy="1.5" burial_cutoff_apo="0.2" />
                <ContactMolecularSurface name="cms" verbose="true" target_selector="part1" binder_selector="part2"/>
                <ExposedHydrophobics name="exposed_hydrophobics" />
                <Geometry name="geometry"
                    confidence="0"
                    count_bad_residues="true" />
                <SSPrediction name="mismatch_probability" confidence="0" 
                    cmd="/software/psipred4/runpsipred_single" use_probability="1" 
                    mismatch_probability="1" use_svm="1" />
                <Rmsd name="rmsd_final" reference_name="before_relax" chains="A" superimpose="1" threshold="5" by_aln="0" confidence="0" />
                <ScoreType name="total_score_pose" scorefxn="sfxn" score_type="total_score" threshold="0" confidence="0" />
                <ResidueCount name="count" />
                <CalculatorFilter name="score_per_res" equation="total_score_full / res" threshold="-2.0" confidence="0">
                    <Var name="total_score_full" filter="total_score_pose"/>
                    <Var name="res" filter="count"/>
                </CalculatorFilter>        
                <worst9mer name="wnm_all" rmsd_lookup_threshold="0.4" confidence="0" />
                <worst9mer name="wnm_hlx" rmsd_lookup_threshold="0.4" confidence="0" only_helices="true" />

            </FILTERS>
            <MOVERS>
                <FastRelax name="relax" scorefxn="sfxn_design" repeats="1" batch="false" ramp_down_constraints="false"
                    cartesian="false" bondangle="false" bondlength="false" min_type="dfpmin_armijo_nonmonotone"
                    task_operations="ifcl,current,arochi,ex1_ex2" >
                </FastRelax>
            </MOVERS>
            <SIMPLE_METRICS>
                <SapScoreMetric name="sap_score" />
            </SIMPLE_METRICS>
            <APPLY_TO_POSE>
            </APPLY_TO_POSE>
            <PROTOCOLS>
                <Add mover_name="save_before_relax" />
                <Add mover_name="relax"/>
                <Add filter_name="buns" />
                <Add filter_name="sbuns" />
                <Add filter_name="vbuns" />
                <Add filter_name="cms" />
                <Add filter_name="exposed_hydrophobics" />
                <Add filter_name="geometry"/>
                <Add filter_name="mismatch_probability" />
                <Add filter_name="rmsd_final" />
                <Add metrics="sap_score" />
                <Add filter_name="score_per_res" />
                <Add filter_name="wnm_all" />
                <Add filter_name="wnm_hlx" />
            </PROTOCOLS>
            <OUTPUT scorefxn="sfxn" />
        </ROSETTASCRIPTS>
    """.format(
        sfxn=sfxn,
        # both selectors are built from the same helix index
        pre=pre_break_helix,
        post=pre_break_helix,
    )
    scored = SingleoutputRosettaScriptsTask(xml)
    # run the protocol on a clone of the loaded pose
    scored_ppose = scored(packed_pose_in.pose.clone())
    # unpack to attach the parent label, then pack again for return
    pose = io.to_pose(scored_ppose)
    pyrosetta.rosetta.core.pose.setPoseExtraScore(pose, "parent", parent)
    scored_ppose = io.to_packed(pose)
    return scored_ppose

### Setup dask, set command line options, make tasks and submit to client for scoring ref
Run `nstruct` of 5 to cover bases for stochasticity

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from glob import glob
import logging
import pwd
from pyrosetta.distributed.cluster.core import PyRosettaCluster


# Print an ssh command that forwards local port 8000 to port 8787 on this host.
# NOTE(review): 8787 is presumably the dask dashboard port — confirm.
print("run the following from your local terminal:")
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)


def create_tasks(selected, options):
    """
    Yield one scoring task per input structure path.

    Args:
        selected: iterable of structure paths (strings).
        options: object stored under "extra_options" in every task.

    Yields:
        dict with the keys "options", "extra_options", "-s" and
        "pre_break_helix"; the last is 6 when the path contains "THR" and 4
        otherwise.
    """
    for pdb_path in selected:
        helix = 6 if "THR" in pdb_path else 4
        yield {
            "options": "-corrections::beta_nov16 true",
            "extra_options": options,
            "-s": pdb_path,
            "pre_break_helix": helix,
        }


# Log at INFO level and collect every .pdb one directory below 00_inputs.
logging.basicConfig(level=logging.INFO)
pdbs = glob(os.path.join(os.getcwd(), "00_inputs/*/*.pdb"))

# Flags handed to create_tasks(), which stores them under "extra_options"
options = {
    "-out:level": "300",
    "-holes:dalphaball": "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
    "-indexed_structure_store:fragment_store": "/net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
}

# used as both scratch_dir and output_path of the PyRosettaCluster run
output_path = os.path.join(os.getcwd(), "05_score_ref")

if __name__ == "__main__":
    # configure SLURM cluster as a context manager
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="long",
        walltime="23:30:00",
        death_timeout=120,
        local_directory="$TMPDIR/dask",
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
    ) as cluster:
        print(cluster.job_script())
        # scale adaptively between 1 and 510 workers
        cluster.adapt(
            minimum=1,
            maximum=510,
            wait_count=400,  # Number of consecutive times that a worker should be suggested for removal before it is removed
            interval="5s",  # Time between checks
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            # distribute score_ref over the tasks from create_tasks,
            # five structures (nstruct=5) per task
            PyRosettaCluster(
                tasks=create_tasks(pdbs, options),
                client=client,
                scratch_dir=output_path,
                output_path=output_path,
                nstruct=5,
            ).distribute(protocols=[score_ref])

### Look at scores
Hacky function to load JSON-like data

In [None]:
def read_scorefile(scores):
    """
    Load a line-delimited JSON scorefile into a DataFrame.

    The file is read with pandas as one record per line and transposed, and
    the i-th diagonal cell of the transposed table is taken as the row data
    for the i-th index label.
    NOTE(review): presumably every line is {decoy_name: {score_name: value}},
    which is what puts each record on the diagonal — confirm.

    Args:
        scores: path of the scorefile.

    Returns:
        DataFrame built from the diagonal cells, indexed like the transposed
        table.
    """
    import pandas as pd

    by_label = pd.read_json(scores, orient="records", typ="frame", lines=True).T
    cells = by_label.values
    # cell (i, i) holds the record that belongs to the i-th label
    records = [cells[i, i] for i in range(cells.shape[0])]
    return pd.DataFrame(records, index=by_label.index)


# Path of the scorefile inside the 05_score_ref output directory
output_path = os.path.join(os.getcwd(), "05_score_ref")
scores = os.path.join(output_path, "scores.json")

### Get reference scores

In [None]:
# Tabulate the scorefile at `scores`; the bare name displays the table
ref_df = read_scorefile(scores)
ref_df

### Now need to retrieve the JSONs containing scores for the 4 msd runs

In [None]:
%%time
import json

# Every .json file two levels below the 05_msd_runs_* directories
output_paths = glob(os.path.join(os.getcwd(), "05_msd_runs_*/*/*.json"))

# Maps "<same path with a .pdb extension>" to the parsed content of each JSON
scores = {}

for test in tqdm(output_paths):
    # swap only the file extension; str.replace("json", "pdb") would also
    # rewrite any "json" that happens to occur in a directory or file name
    pdb = os.path.splitext(test)[0] + ".pdb"
    with open(test, "r") as f:
        scores[pdb] = json.load(f)

### Write them back to disk as a single scorefile

In [None]:
%%time
# Build a table from the `scores` dict and transpose it so that each key of
# `scores` becomes a row, then write it to 05_filter/scores.json.
scores_df = pd.DataFrame(scores)
scores_df = scores_df.T
output_path = os.path.join(os.getcwd(), "05_filter")
output_file = os.path.join(output_path, "scores.json")
# create the output directory if it is missing
os.makedirs(output_path, exist_ok=True)
scores_df.to_json(output_file)

In [None]:
# Read 05_filter/scores.json back into scores_df and show its first rows.
# Note that `scores` is rebound to the file path here.
output_path = os.path.join(os.getcwd(), "05_filter")
scores = os.path.join(output_path, "scores.json")
scores_df = pd.read_json(scores)
scores_df.head()

### Determine which scores covary

In [None]:
sns.set(
    context="talk",
    font_scale=1,  # make the font larger; default is pretty small
    style="ticks",  # make the background white with black lines
    palette="colorblind",  # a color palette that is colorblind friendly!
)

# SAP columns of both states plus the parent label used for coloring
sap_subset = scores_df[
    [
        "sap_score_X",
        "sap_score_Y",
        "parent",
    ]
]

# pair plot of a random 10% of the rows, colored by parent
ax = sns.pairplot(data=sap_subset.sample(frac=0.1), hue="parent", corner=True, height=8)
plt.suptitle("Correlation of SAP, split by parent")
sns.despine()
plt.savefig("figs/05_correlations_sap_split_by_parent.png")

In [None]:
def rho(x, y, ax=None, **kwargs):
    """Plot the correlation coefficient in the top left hand corner of a plot.
    https://stackoverflow.com/questions/50832204/show-correlation-values-in-pairplot-using-seaborn-in-python/50835066

    The Pearson correlation coefficient of x and y is written, with two
    decimals, at axes-fraction coordinates (0.1, 0.9).

    Args:
        x, y: data passed to scipy.stats.pearsonr.
        ax: axes to annotate; the current matplotlib axes when None.
        **kwargs: accepted and ignored, so the function can be used as a
            plotting callback that receives extra keyword arguments.
    """
    # import the function from its submodule: a bare `import scipy` does not
    # guarantee that `scipy.stats` is available as an attribute
    from scipy.stats import pearsonr

    r, _ = pearsonr(x, y)
    if ax is None:
        ax = plt.gca()
    # Unicode for lowercase rho (ρ)
    symbol = "\u03C1"
    ax.annotate(f"{symbol} = {r:.2f}", xy=(0.1, 0.9), xycoords=ax.transAxes)


# Pair plot of a random 10% of sap_subset; rho() is mapped over the lower
# panels to write the Pearson coefficient into each of them.
ax = sns.pairplot(data=sap_subset.sample(frac=0.1), corner=True, height=8)
ax.map_lower(rho)
plt.suptitle("Correlation of SAP, with pearson R")
sns.despine()
plt.savefig("figs/05_correlations_SAP_pearson.png")

In [None]:
# Rebinds sap_subset to the exposed-hydrophobics and SAP columns of both
# states plus the parent label.
sap_subset = scores_df[
    [
        "exposed_hydrophobics_X",
        "exposed_hydrophobics_Y",
        "sap_score_X",
        "sap_score_Y",
        "parent",
    ]
]

# first figure: random 10% of the rows, colored by parent
ax = sns.pairplot(data=sap_subset.sample(frac=0.1), hue="parent", corner=True, height=8)
plt.suptitle("Correlation of hydrophobicity, split by parent")
sns.despine()
plt.savefig("figs/05_correlations_hydrophobicity_split_by_parent.png")

plt.close()

# second figure: a fresh random 10%, with the Pearson coefficient written
# into the lower panels by rho()
ax = sns.pairplot(data=sap_subset.sample(frac=0.1), corner=True, height=8)
ax.map_lower(rho)
plt.suptitle("Correlation of hydrophobicity, with pearson R")
sns.despine()
plt.savefig("figs/05_correlations_hydrophobicity_pearson.png")

In [None]:
# Geometry, mismatch-probability and worst-9mer columns of both states plus
# the parent label.
frag_subset = scores_df[
    [
        "geometry_X",
        "geometry_Y",
        "mismatch_probability_X",
        "mismatch_probability_Y",
        "wnm_all_X",
        "wnm_all_Y",
        "wnm_hlx_X",
        "wnm_hlx_Y",
        "parent",
    ]
]

# first figure: random 10% of the rows, colored by parent
ax = sns.pairplot(
    data=frag_subset.sample(frac=0.1), hue="parent", corner=True, height=8
)
plt.suptitle("Correlation of fragment metrics, split by parent")
sns.despine()
plt.savefig("figs/05_correlations_frag_split_by_parent.png")

plt.close()

# second figure: a fresh random 10%, with the Pearson coefficient written
# into the lower panels by rho()
ax = sns.pairplot(data=frag_subset.sample(frac=0.1), corner=True, height=8)
ax.map_lower(rho)
plt.suptitle("Correlation of fragment metrics, with pearson R")
sns.despine()
plt.savefig("figs/05_correlations_frag_pearson.png")

In [None]:
def plot_unity(xdata, ydata, **kwargs):
    """Draw the identity line y = x on the current axes.

    https://stackoverflow.com/questions/48122019/how-can-i-plot-identity-lines-on-a-seaborn-pairplot

    The line spans from the smaller of the two minima to the larger of the two
    maxima and is drawn black, dashed, with no markers.

    Args:
        xdata, ydata: objects providing .min() and .max().
        **kwargs: accepted and ignored, so the function can be used as a
            plotting callback that receives extra keyword arguments.
    """
    # Use one shared range for both coordinates. Running x over
    # [xmin, xmax] and y over [ymin, ymax] separately draws the diagonal of
    # the data's bounding box, which is y = x only if the two ranges coincide.
    lower = min(xdata.min(), ydata.min())
    upper = max(xdata.max(), ydata.max())
    points = np.linspace(lower, upper, 100)
    plt.gca().plot(
        points, points, color="k", marker=None, linestyle="--", linewidth=1.0
    )


# Contact-surface, per-residue score, disulfide-term and SAP columns of both
# states plus the parent label.
pack_subset = scores_df[
    [
        "cms_X",
        "cms_Y",
        "score_per_res_X",
        "score_per_res_Y",
        "dslf_fa13_X",
        "dslf_fa13_Y",
        "sap_score_X",
        "sap_score_Y",
        "parent",
    ]
]

# first figure: random 10% of the rows, colored by parent
ax = sns.pairplot(
    data=pack_subset.sample(frac=0.1), hue="parent", corner=True, height=8
)
plt.suptitle("Correlation of packing and score metrics, split by parent")
sns.despine()
plt.savefig("figs/05_correlations_pack_split_by_parent.png")

plt.close()

# second figure: a fresh random 10%; rho() annotates the lower panels and
# plot_unity() adds its dashed line to the off-diagonal panels
ax = sns.pairplot(data=pack_subset.sample(frac=0.1), corner=True, height=8)
ax.map_lower(rho)
ax.map_offdiag(plot_unity)
plt.suptitle("Correlation of packing and score metrics, with pearson R")
sns.despine()
plt.savefig("figs/05_correlations_pack_pearson.png")

### Unfortunately forgot to record `np_penalty` and `ala_penalty`

In [None]:
# Columns of interest for the pairwise plots. The plots below read from
# scores_df directly; this subset is kept for interactive use.
for_plotting = scores_df[
    [
        "buns_X",
        "buns_Y",
        "cms_X",
        "cms_Y",
        "exposed_hydrophobics_X",
        "exposed_hydrophobics_Y",
        #         "np_penalty",
        "sap_score_X",
        "sap_score_Y",
        "sbuns_X",
        "sbuns_Y",
        "score_per_res_X",
        "score_per_res_Y",
        "vbuns_X",
        "vbuns_Y",
        "parent",
    ]
]

# for_plotting["np_penalty"] = for_plotting["np_penalty"].astype(int).astype(str)
# One three-panel figure per term: joint X/Y histogram, then the marginal
# histogram (with KDE) of state X and of state Y.
for term in [
    "buns",
    "sbuns",
    "vbuns",
    "cms",
    "exposed_hydrophobics",
    "sap_score",
    "score_per_res",
]:
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(30, 10), tight_layout=True)
    sns.histplot(
        ax=ax1, data=scores_df, x=f"{term}_X", y=f"{term}_Y", palette="colorblind"
    )
    sns.histplot(ax=ax2, data=scores_df, x=f"{term}_X", kde=True, palette="colorblind")
    sns.histplot(ax=ax3, data=scores_df, x=f"{term}_Y", kde=True, palette="colorblind")
    plt.suptitle(f"Pairwise analysis of {term}")
    sns.despine()
    plt.savefig(f"figs/05_pairwise_{term}.png")
    # actually call close: the bare attribute `plt.close` is a no-op and
    # leaves every figure open
    plt.close(fig)

### Get the good decoys for each state after removing really bad stuff
`sap_score < 25`  
`score_per_res < -3`  
`vbuns < 3`  
`wnm_all < 0.8`  
`wnm_hlx < 0.15`  
Won't filter on TRP or AAA just yet, if I did it would be:  
`"AAA" < 3`  
`"W" == 1`

In [None]:
def row2state(row):
    """
    Build the state label "<parent>_p_<pivot_helix>_s_<shift>" for one row.

    Args:
        row: mapping-like object with "parent" (a string), "pivot_helix" and
            "shift"; the latter two are truncated to int before formatting.

    Returns:
        The label as a string.
    """
    pivot = str(int(row["pivot_helix"]))
    shift = str(int(row["shift"]))
    return "_".join([row["parent"], "p", pivot, "s", shift])


# Successively keep only the rows that pass every cutoff in both states.
strict_df = scores_df[scores_df["sap_score_X"] < 25]
strict_df = strict_df[strict_df["sap_score_Y"] < 25]
strict_df = strict_df[strict_df["score_per_res_X"] < -3]
strict_df = strict_df[strict_df["score_per_res_Y"] < -3]
# NOTE(review): the markdown above this cell lists `vbuns < 3`, but the
# cutoff applied here is 1 — confirm which one is intended.
strict_df = strict_df[strict_df["vbuns_X"] < 1]
strict_df = strict_df[strict_df["vbuns_Y"] < 1]
strict_df = strict_df[strict_df["wnm_all_X"] < 0.8]
strict_df = strict_df[strict_df["wnm_all_Y"] < 0.8]
strict_df = strict_df[strict_df["wnm_hlx_X"] < 0.15]
strict_df = strict_df[strict_df["wnm_hlx_Y"] < 0.15]


# label every surviving row with its parent / pivot helix / shift combination
strict_df["state"] = strict_df.apply(row2state, axis=1)

# number of distinct states, of rows and of distinct parents that survive
print(len(set(strict_df.state.values)))
print(len(strict_df))
print(len(set(strict_df.parent.values)))
# parents that have no surviving row at all
set(scores_df.parent.values) - set(strict_df.parent.values)

### Dump a list of pairs that pass strict cutoffs

In [None]:
# Write the index label of every row in strict_df to 05_filter/good.list,
# one per line.
with open(os.path.join(os.getcwd(), "05_filter", "good.list"), "w") as f:
    f.writelines(f"{label}\n" for label in strict_df.index)

### Cleanup stuff to save space

In [None]:
# Permanently remove every .json file two levels below the 05_msd_runs_*
# directories. This cannot be undone.
to_delete = glob(os.path.join(os.getcwd(), "05_msd_runs_*/*/*.json"))

for file in tqdm(to_delete):
    os.remove(file)

### Add average parent values to scores_df

In [None]:
# Score terms for which a per-parent reference value is computed
terms = [
    "buns",
    "cms",
    "mismatch_probability",
    "sap_score",
    "sbuns",
    "score_per_res",
    "wnm_all",
    "wnm_hlx",
    "vbuns",
]

# term -> {parent -> mean of that term over the parent's rows in ref_df}
ref_means = {
    term: dict(ref_df.groupby("parent")[term].mean()) for term in terms
}


def map_mean_parent_scores_to_row(row, term):
    """
    Return the reference mean of `term` for the parent of `row`.

    Reads the module-level `ref_means` mapping (term -> {parent: mean}) and
    indexes it with `term` and `row.parent`. A missing term or parent raises
    KeyError.
    """
    parent_id = row.parent
    means_by_parent = ref_means[term]
    return means_by_parent[parent_id]


# For every term, add a "<term>_parent" column holding the reference mean of
# that term for each row's parent, then persist the augmented table.
for term in tqdm(terms):
    scores_df[f"{term}_parent"] = scores_df.apply(
        map_mean_parent_scores_to_row,
        axis=1,
        args=(term,),  # forwarded to the lookup as its `term` argument
    )
scores_df.to_json("05_filter/scores_combined_with_parents.json")

In [None]:
# Per-parent distribution plots of every score term, for state X and state Y,
# before any selection. Each figure overlays a boxplot and strip plot of the
# design scores with the parent's reference mean (black points); this relies
# on the "<term>_parent" columns already being present in scores_df.
sns.set(
    context="talk",
    font_scale=2,  # make the font larger; default is pretty small
    style="ticks",  # make the background white with black lines
    palette="colorblind",  # a color palette that is colorblind friendly!
)
terms = [
    "buns",
    "cms",
    "mismatch_probability",
    "sap_score",
    "sbuns",
    "score_per_res",
    "wnm_all",
    "wnm_hlx",
    "vbuns",
]
order = sorted(set(scores_df.parent.values))
for term in terms:
    for state in ["X", "Y"]:
        state_column = f"{term}_{state}"
        # Draw ONE 20% subsample per figure and reuse it for all three layers.
        # Sampling separately for each layer made the box, strip and parent
        # overlays describe different random subsets of designs.
        plot_df = scores_df.sample(frac=0.2)
        fig = plt.figure(figsize=(30, 20), tight_layout=True)
        plt.xticks(rotation=90)
        sns.boxplot(
            x="parent",
            y=state_column,
            data=plot_df,
            showfliers=False,
            order=order,
        )
        sns.stripplot(
            x="parent",
            y=state_column,
            data=plot_df,
            order=order,
        )
        # parent reference means, drawn in black on top of the design scores
        sns.stripplot(
            x="parent",
            y=term + "_parent",
            data=plot_df,
            color="black",
            order=order,
        )
        sns.despine()
        plt.title(state_column)
        plt.show()
        fig.savefig(f"figs/05_before_selection_all_parents_vs_{term}_{state}.png")
        plt.close()

### Blocks I didn't use

In [None]:
from pyrosetta.distributed import cluster

# PyRosetta init options, one flag (with its value) per entry; they are joined
# with single spaces into the option string passed to init.
# NOTE(review): the "DAlpahBall" directory spelling is reproduced verbatim from
# the original flag string - confirm against the path on disk before changing.
init_flags = " ".join(
    [
        "-out:level 500",
        "-corrections::beta_nov16 true",
        "-holes:dalphaball /home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
        "-indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
        "-dunbrack_prob_buried 0.5",
        "-dunbrack_prob_nonburied 0.5",
        "-dunbrack_prob_buried_semi 0.5",
        "-dunbrack_prob_nonburied_semi 0.5",
    ]
)
pyrosetta.distributed.init(init_flags)

# Keyword arguments for a single manual call of msd on one pair of decoys.
kwargs_ = {
    "-s": "/mnt/home/pleung/projects/bistable_bundle/r4/helix_binders/04_pairs/decoys/0000/2021.06.16.09.05.57.348631_6b13ba8b8f04450a885e127d5e66bcdf.pdb.bz2",
    "-x": "/mnt/home/pleung/projects/bistable_bundle/r4/helix_binders/04_staple_ref/decoys/0000/2021.06.16.00.40.13.771912_aec12a83734b4f66a196eace6dcd0910.pdb.bz2",
    "ala_pen": 0,
    "np_pen": 0.01,
}

tpose = msd(None, **kwargs_)

### Get sequences with hacky `.sh` script
Append all sequences to a FASTA:  
`while read p; do ./pdb2fasta.sh "$p" ; done < 05_filter/good.list >> temp.fasta`  
Get only the chain A sequences (chain A and chain B have the same sequence) and flatten the FASTA:  
`cat temp.fasta | sed ':a;N;$!ba;s/chain A\n/chainA /g' | grep chainA | sed 's/ chainA / /' > temp2.fasta; rm temp.fasta`

In [None]:
!while read p; do ./pdb2fasta.sh "$p" ; done < 05_filter/good.list >> temp.fasta

In [None]:
!cat temp.fasta | sed ':a;N;$!ba;s/chain A\n/chainA /g' | grep chainA | sed 's/ chainA / /' > temp2.fasta; rm temp.fasta

In [None]:
# Map each FASTA header (file name of the index key with ".pdb" removed)
# back to the strict_df index key it came from.
hash2key = {key.split("/")[-1].replace(".pdb", ""): key for key in strict_df.index}

# Start every design with an empty sequence, then fill in from the flattened
# FASTA, where each line is "<header> <sequence>".
strict_df["sequence"] = ""

with open("temp2.fasta", "r") as f:
    for line in tqdm(f):
        header, seq = line.split()
        key = hash2key[header.replace(">", "")]
        strict_df.loc[key, "sequence"] = seq

### Keep only sequences with exactly 1 TRP and fewer than 2 `AAA` runs

In [None]:
# Keep designs whose sequence contains exactly one "W" and fewer than two
# (non-overlapping) "AAA" runs.
idces = [
    i
    for i, sequence in strict_df["sequence"].items()
    if sequence.count("W") == 1 and sequence.count("AAA") < 2
]

one_trp = strict_df.loc[idces]
print(len(set(one_trp.state.values)))  # distinct states that survive
print(len(one_trp))  # designs that survive
# cell output: parents with no surviving design
set(scores_df.parent.values) - set(one_trp.parent.values)

### Dump a list of pairs that pass strict cutoffs and have 1 TRP

In [None]:
# Write the index label of every design in one_trp, one per line,
# to 05_filter/1trp_good.list under the current working directory.
one_trp_list_path = os.path.join(os.getcwd(), "05_filter", "1trp_good.list")
with open(one_trp_list_path, "w") as f:
    f.writelines(str(i) + "\n" for i in one_trp.index)

In [None]:
# Persist the one-TRP selection as 05_filter/1trp_good.json under the cwd.
one_trp_json_path = os.path.join(os.getcwd(), "05_filter", "1trp_good.json")
one_trp.to_json(one_trp_json_path)
!rm temp2.fasta

In [None]:
# Reload the saved one-TRP selection and attach the per-parent reference mean
# of every term as a "<term>_parent" column.
one_trp = pd.read_json(os.path.join(os.getcwd(), "05_filter", "1trp_good.json"))
for term in tqdm(terms):
    one_trp[f"{term}_parent"] = one_trp.apply(
        map_mean_parent_scores_to_row,
        axis=1,
        args=(term,),  # forwarded to the lookup as its `term` argument
    )

# One figure per (term, state): boxplot + strip plot of the selected designs
# per parent, with the parent reference mean overlaid in black.
order = sorted(set(one_trp.parent.values))
for term, state in itertools.product(terms, ["X", "Y"]):
    state_column = f"{term}_{state}"
    shared = dict(x="parent", data=one_trp, order=order)
    fig = plt.figure(figsize=(30, 20), tight_layout=True)
    plt.xticks(rotation=90)
    sns.boxplot(y=state_column, showfliers=False, **shared)
    sns.stripplot(y=state_column, **shared)
    sns.stripplot(y=f"{term}_parent", color="black", **shared)
    sns.despine()
    plt.title(state_column)
    plt.show()
    fig.savefig(f"figs/05_after_selection_all_parents_vs_{term}_{state}.png")
    plt.close()

### Featurize the designs

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from glob import glob
import logging
import pwd


def featurize(abspath):
    """
    Convert one PDB to FASTA and run `predict_.py` on it through the shell.

    The shell command changes into the directory of `abspath`, redirects the
    output of `pdb2fasta.sh <pdb>` into a FASTA file in that directory, then
    runs `predict_.py` with `-i <fasta>` and `-o "/net/scratch/pleung/"`.
    `pdb2fasta.sh` and `predict_.py` are looked up in the current working
    directory of the process executing this function.

    abspath: path of the PDB file to featurize.
    Returns None; the combined stdout/stderr of the command is discarded.
    """
    # imports stay inside the function, as in the original
    import os
    import shlex
    import subprocess

    def cmd(command, wait=True):
        """@nrbennet @bcov"""
        the_command = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if not wait:
            return
        the_stuff = the_command.communicate()
        return str(the_stuff[0]) + str(the_stuff[1])

    working_dir = os.path.dirname(abspath)
    pythonpath = "/software/conda/envs/tensorflow/bin/python"
    script = os.path.join(os.getcwd(), "predict_.py")
    pdb2fasta = os.path.join(os.getcwd(), "pdb2fasta.sh")
    pdb = os.path.basename(abspath)
    # Swap only the LAST "pdb" in the file name for "fasta"; replacing every
    # occurrence mangled names that contain "pdb" elsewhere. If there is no
    # "pdb" at all, append ".fasta" so the redirect below can never point at
    # (and truncate) the input file itself.
    stem, found, tail = pdb.rpartition("pdb")
    handle = f"{stem}fasta{tail}" if found else f"{pdb}.fasta"
    # Quote every interpolated path so spaces or shell metacharacters in a
    # path cannot split or alter the command.
    quote = shlex.quote
    to_send = f"""cd {quote(working_dir)}; {quote(pdb2fasta)} {quote(pdb)} > {quote(handle)} ;"""
    to_send += f""" {pythonpath} {quote(script)} -i {quote(handle)} -o "/net/scratch/pleung/" """
    print("sending: ", to_send)
    cmd(to_send)
    return


# Print an ssh port-forwarding command (local 8000 -> remote 8787) for the
# current user and host. Presumably this is for viewing the dask dashboard
# from a local machine - confirm the port.
print("run the following from your local terminal:")
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)

# One dask future per submitted PDB; filled inside the client context below.
futures = []

if __name__ == "__main__":
    # configure SLURM cluster as a context manager
    # NOTE: this rebinds the name `cluster`, which an earlier cell imports
    # from pyrosetta.distributed.
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="long",
        walltime="23:30:00",
        death_timeout=120,
        local_directory="$TMPDIR/dask",
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
    ) as cluster:
        print(cluster.job_script())
        # scale adaptively between 1 and 510 workers
        cluster.adapt(
            minimum=1,
            maximum=510,
            wait_count=400,  # Number of consecutive times a worker must be suggested for removal before it is removed
            interval="5s",  # Time between checks
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            # submit one featurize task per line of 05_filter/1trp_good.list
            with open(
                os.path.join(os.getcwd(), "05_filter", "1trp_good.list"), "r"
            ) as f:
                for pdb in f:
                    # strip the trailing newline before using the line as a path
                    future = client.submit(featurize, pdb.rstrip())
                    futures.append(future)
            # collect every task's result while the client is still open;
            # featurize returns None, so this mainly waits for completion
            results = [pending.result() for pending in futures]

In [None]:
# Glob patterns (relative to the current working directory) of the files to
# remove; matches are collected pattern by pattern, in this order.
cleanup_patterns = (
    "05_msd_runs_*/*/*.json",
    "05_msd_runs_*/*/*.scwrl4.pdb",
    "05_msd_runs_*/*/*_0.pdb",
    "05_msd_runs_*/*/*_1.pdb",
    "05_msd_runs_*/*/*.fasta",
)
to_delete = [
    match
    for pattern in cleanup_patterns
    for match in glob(os.path.join(os.getcwd(), pattern))
]

for stale_path in tqdm(to_delete):
    os.remove(stale_path)