# Linkres cleanup and filtering

### Boilerplate

In [1]:
%load_ext lab_black
# python internal
import collections
import copy
import gc
from glob import glob
import h5py
import itertools
import os
import random
import re
import socket
import shutil
import subprocess
import sys

# conda/pip
import dask
import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy
import seaborn as sns
from tqdm import tqdm

# special packages on the DIGS
import py3Dmol
import pymol
import pyrosetta

# notebook magic
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Record where (working directory, then host) this notebook is running.
for value in (os.getcwd(), socket.gethostname()):
    print(value)

/mnt/home/pleung/projects/bistable_bundle/r4/helix_binders
dig99


### Generate state X and state Y, with and without disulfides, using `KeepSequenceSymmetry` and fixed-backbone (fixbb) design
I will use the serialization build of PyRosetta to record user-defined info about the designs.  
This enables downstream inline filtering and data analysis, as well as clustering by lineage.

Record delta SAP and delta total score, and calculate SAP and `wnm_all` correctly this time.

In [2]:
from pyrosetta.distributed.packed_pose.core import PackedPose


def almost_linkres(packed_pose_in=None, **kwargs) -> PackedPose:
    """
    Design a two-state (X/Y) pose with linked sequences and dump per-state models.

    The input is read from the bz2-compressed PDB given by ``kwargs["-s"]``.
    Chains A,B form state X and chains C,D form state Y; the two states are
    linked with ``SetupForSequenceSymmetryMover`` so they keep one sequence.
    The residues named in the input's ``disulfide_at_X`` / ``disulfide_at_Y``
    scores are redesigned with FastDesign (sidechains only, no CYS/GLY/PRO).
    Each such residue number is also redesigned at that number plus
    ``total_length``.

    Four PDBs are written to ``out_path``:
      * ``<name>_X.pdb`` / ``<name>_Y.pdb``: a single state with its disulfide
        re-introduced (MutateResidue -> ForceDisulfides -> repack).
      * ``<name>_X_nocys.pdb`` / ``<name>_Y_nocys.pdb``: a single state without
        the disulfide, after FastRelax. The RMSD to the unrelaxed state is
        stored as ``rmsd_final_X`` / ``rmsd_final_Y``.

    Args:
        packed_pose_in: Unsupported; must be None (the input comes from "-s").
        **kwargs: ``"-s"`` is the path to the input ``.pdb.bz2``.
            ``"out_path"`` is the output directory, relative to the cwd or
            absolute. If it is missing or None, the cwd is used.

    Returns:
        The designed two-state (XY) pose without disulfides, packed. It carries
        all input scores plus the new metrics as extra scores. Returns None if
        the input does not define 3 or 4 unique cysteine positions.

    Raises:
        RuntimeError: if ``packed_pose_in`` is not None.
    """
    import bz2
    import itertools
    import os

    import pyrosetta
    import pyrosetta.distributed.io as io
    from pyrosetta.distributed import cluster
    from pyrosetta.distributed.tasks.rosetta_scripts import (
        SingleoutputRosettaScriptsTask,
    )
    from pyrosetta.rosetta.core.pack.task.operation import (
        IncludeCurrent,
        InitializeFromCommandline,
        ExtraRotamersGeneric,
    )
    from pyrosetta.rosetta.core.pose import Pose
    from pyrosetta.rosetta.protocols.denovo_design.movers import FastDesign
    from pyrosetta.rosetta.protocols.rosetta_scripts import XmlObjects
    from pyrosetta.rosetta.protocols.task_operations import LimitAromaChi2Operation

    # ---- load input -------------------------------------------------------
    if packed_pose_in is not None:
        raise RuntimeError(
            "packed_pose_in is not supported; supply the input file via '-s'"
        )
    pdb_file = kwargs["-s"]
    with open(pdb_file, "rb") as f:
        ppose = io.pose_from_pdbstring(bz2.decompress(f.read()).decode())
    scores = cluster.get_scores_dict(pdb_file)["scores"]
    pose = io.to_pose(ppose)

    out_path = kwargs.get("out_path")
    if out_path is None:
        out_path = os.getcwd()
    else:
        os.makedirs(os.path.join(os.getcwd(), out_path), exist_ok=True)
    out_dir = os.path.join(os.getcwd(), out_path)

    # ---- disulfide positions ----------------------------------------------
    cys_X = scores["disulfide_at_X"].split(",")
    cys_Y = scores["disulfide_at_Y"].split(",")
    unique_cys = set(cys_X + cys_Y)
    # Two disulfides (two positions each) give 4 unique positions, or 3 if
    # they share one. Anything else means an upstream step went wrong.
    if len(unique_cys) not in (3, 4):
        print("There should be at least 3 unique CYS, but no more than 4")
        return None
    # sorted() makes the Index selector string reproducible between runs.
    design_resis = sorted(int(x) for x in unique_cys)
    cys1_X, cys2_X = int(cys_X[0]), int(cys_X[1])
    cys1_Y, cys2_Y = int(cys_Y[0]), int(cys_Y[1])
    length = scores["total_length"]
    # Each design position and the same position offset by `total_length`.
    target_resis = ",".join(
        str(resi) for x in design_resis for resi in (x, int(x + length))
    )
    xml_string = """
    <ROSETTASCRIPTS>
        <SCOREFXNS>
            <ScoreFunction name="sfxn" weights="beta_nov16" />
            <ScoreFunction name="sfxn_design" weights="beta_nov16" >
                <Set use_hb_env_dep="true" />
                <Reweight scoretype="approximate_buried_unsat_penalty" weight="17" />
                <Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
                <Set approximate_buried_unsat_penalty_hbond_energy_threshold="-1.0" />
                <Set approximate_buried_unsat_penalty_natural_corrections1="true" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_cross_chain="-7" />
                <Set approximate_buried_unsat_penalty_hbond_bonus_ser_to_helix_bb="1"/>                    
            </ScoreFunction>
        </SCOREFXNS>
        <RESIDUE_SELECTORS>
            <Chain name="chA" chains="A" />
            <Chain name="chB" chains="B" />
            <Chain name="chC" chains="C" />
            <Chain name="chD" chains="D" />
            <Chain name="AB" chains="A,B" />
            <Chain name="CD" chains="C,D" />
            <Neighborhood name="twosidedY" selector="chD" distance="8.0"/>
            <And name="onesidedY" selectors="chC,twosidedY"/>
            <Index name="designable" resnums="{target_resis}" />
            <Neighborhood name="packable" selector="designable" distance="6.0" include_focus_in_subset="true" />
            <Not name="not_designable" selector="designable" />
            <Not name="not_packable" selector="packable" />
        </RESIDUE_SELECTORS>
        <TASKOPERATIONS>
            <OperateOnResidueSubset name="design" selector="designable"> # no CYS; GLY; PRO;
                <RestrictAbsentCanonicalAASRLT aas="ADEFHIKLMNQRSTVWY" />
            </OperateOnResidueSubset>
            <OperateOnResidueSubset name="pack" selector="not_designable">
                <RestrictToRepackingRLT/>
            </OperateOnResidueSubset>
            <OperateOnResidueSubset name="lock" selector="not_packable">
                <PreventRepackingRLT/>
            </OperateOnResidueSubset>
            <RestrictToRepacking name="rtrp" />
            <KeepSequenceSymmetry name="2state" setting="true"/>
        </TASKOPERATIONS>
        <SIMPLE_METRICS>
            <SapScoreMetric name="sap_X" score_selector="AB" />
            <SapScoreMetric name="sap_Y" score_selector="CD" />
            <SapScoreMetric name="sap_A" score_selector="chA" />
            <SapScoreMetric name="sap_B" score_selector="chB" />
            <SapScoreMetric name="sap_C" score_selector="chC" />
            <SapScoreMetric name="sap_D" score_selector="chD" />
        </SIMPLE_METRICS>
        <MOVERS>
            <SetupForSequenceSymmetryMover name="almost_linkres" sequence_symmetry_behaviour="2state" >
                <SequenceSymmetry residue_selectors="AB,CD" />
            </SetupForSequenceSymmetryMover>
            <SwitchChainOrder name="delete_Y" chain_order="12"/>
            <SwitchChainOrder name="delete_X" chain_order="34"/>
            <SwitchChainOrder name="A_only" chain_order="1"/>
            <SwitchChainOrder name="B_only" chain_order="2"/>
            <SwitchChainOrder name="C_only" chain_order="3"/>
            <SwitchChainOrder name="D_only" chain_order="4"/>
            <MutateResidue name="cys1_X" target="{cys1_X}" new_res="CYS" />
            <MutateResidue name="cys2_X" target="{cys2_X}" new_res="CYS" />
            <MutateResidue name="cys1_Y" target="{cys1_Y}" new_res="CYS" />
            <MutateResidue name="cys2_Y" target="{cys2_Y}" new_res="CYS" />
            <ForceDisulfides name="restore_X" 
                scorefxn="sfxn_design"
                disulfides="{cys1_X}:{cys2_X}"
                remove_existing="false"
                repack="true" />
            <ForceDisulfides name="restore_Y" 
                scorefxn="sfxn_design"
                disulfides="{cys1_Y}:{cys2_Y}"
                remove_existing="false"
                repack="true" />
            <FastRelax name="relax" 
                scorefxn="sfxn_design"
                repeats="1" 
                relaxscript="MonomerRelax2019"
                />
            <PackRotamersMover name="repack" 
                scorefxn="sfxn_design"
                task_operations="rtrp"
                />
            <RunSimpleMetrics name="run_metrics" metrics="sap_X,sap_Y,sap_A,sap_B,sap_C,sap_D" override="true" />
        </MOVERS>
        <FILTERS>
            <worst9mer name="pre_wnm_all" rmsd_lookup_threshold="0.4" confidence="0" />
            <MoveBeforeFilter name="wnm_all_X" mover="A_only" filter="pre_wnm_all" confidence="0" />
            <MoveBeforeFilter name="wnm_all_Y" mover="C_only" filter="pre_wnm_all" confidence="0" />
            <ScoreType name="total_score_pose" scorefxn="sfxn" score_type="total_score" threshold="0" confidence="0" />
            <MoveBeforeFilter name="tot_score_X" mover="delete_Y" filter="total_score_pose" confidence="0" />
            <MoveBeforeFilter name="tot_score_Y" mover="delete_X" filter="total_score_pose" confidence="0" />
            <MoveBeforeFilter name="tot_score_A" mover="A_only" filter="total_score_pose" confidence="0" />
            <MoveBeforeFilter name="tot_score_B" mover="B_only" filter="total_score_pose" confidence="0" />
            <MoveBeforeFilter name="tot_score_C" mover="C_only" filter="total_score_pose" confidence="0" />
            <MoveBeforeFilter name="tot_score_D" mover="D_only" filter="total_score_pose" confidence="0" />
        </FILTERS>
        <APPLY_TO_POSE>
        </APPLY_TO_POSE>
        <PROTOCOLS>
            <Add filter="wnm_all_X" />
            <Add filter="wnm_all_Y" />
            <Add filter="tot_score_X" />
            <Add filter="tot_score_Y" />
            <Add filter="tot_score_A" />
            <Add filter="tot_score_B" />
            <Add filter="tot_score_C" />
            <Add filter="tot_score_D" />
            <Add metrics="sap_X,sap_Y,sap_A,sap_B,sap_C,sap_D" labels="sap_X,sap_Y,sap_A,sap_B,sap_C,sap_D"/>
        </PROTOCOLS>
    </ROSETTASCRIPTS>
    """.format(
        target_resis=target_resis,
        cys1_X=cys1_X,
        cys2_X=cys2_X,
        cys1_Y=cys1_Y,
        cys2_Y=cys2_Y,
    )
    # Parse the XML once. Movers, selectors and taskops are pulled from it below.
    xml_obj = XmlObjects.create_from_string(xml_string)
    # The PROTOCOLS section runs as a separate scoring pass on the final pose.
    filters = SingleoutputRosettaScriptsTask(xml_string)

    # ---- FastDesign setup -------------------------------------------------
    task_factory = pyrosetta.rosetta.core.pack.task.TaskFactory()
    task_factory.push_back(IncludeCurrent())
    task_factory.push_back(InitializeFromCommandline())
    arochi = LimitAromaChi2Operation()
    arochi.include_trp(True)
    task_factory.push_back(arochi)
    ex1_ex2 = ExtraRotamersGeneric()
    ex1_ex2.ex1(True)
    ex1_ex2.ex2(True)
    task_factory.push_back(ex1_ex2)
    for taskop_name in ("design", "pack", "lock", "2state"):
        task_factory.push_back(xml_obj.get_task_operation(taskop_name))
    # Sidechain-only movemap: backbone and jumps stay fixed.
    mm = pyrosetta.rosetta.core.kinematics.MoveMap()
    mm.set_bb(False)
    mm.set_chi(True)
    mm.set_jump(False)
    sfxn_design = xml_obj.get_score_function("sfxn_design")
    fast_design = FastDesign(scorefxn_in=sfxn_design, standard_repeats=1)
    fast_design.cartesian(False)
    fast_design.set_task_factory(task_factory)
    fast_design.set_movemap(mm)
    fast_design.minimize_bond_angles(False)
    fast_design.minimize_bond_lengths(False)
    fast_design.min_type("lbfgs_armijo_nonmonotone")
    fast_design.ramp_down_constraints(False)

    basename_no_ext = os.path.basename(pdb_file).replace(".pdb.bz2", "", 1)
    out_paths = {
        suffix: os.path.join(out_dir, f"{basename_no_ext}_{suffix}.pdb")
        for suffix in ("X", "Y", "X_nocys", "Y_nocys")
    }

    # ---- design both states together -------------------------------------
    pose = pose.clone()
    xml_obj.get_mover("almost_linkres").apply(pose)
    fast_design.apply(pose)
    ref_pose = pose.clone()

    # Both states reuse the same relax/repack mover instances.
    relax = xml_obj.get_mover("relax")
    repack = xml_obj.get_mover("repack")

    def build_state(state: str, other: str) -> Pose:
        """
        Extract `state` ("X" or "Y") from the designed pose by deleting `other`.

        Dumps the state with its disulfide (repacked) and, separately, the
        relaxed state without the disulfide. Records ``<state>_seq`` and
        ``rmsd_final_<state>`` in `scores`. Returns the relaxed no-CYS pose.
        """
        state_pose = ref_pose.clone()
        xml_obj.get_mover(f"delete_{other}").apply(state_pose)
        unrelaxed = state_pose.clone()
        # With disulfide: mutate both positions to CYS, force the bond, repack.
        xml_obj.get_mover(f"cys1_{state}").apply(state_pose)
        xml_obj.get_mover(f"cys2_{state}").apply(state_pose)
        xml_obj.get_mover(f"restore_{state}").apply(state_pose)
        repack.apply(state_pose)
        state_pose.dump_pdb(out_paths[state])
        scores[f"{state}_seq"] = state_pose.sequence()
        # Without disulfide: relax, then measure drift from the unrelaxed state.
        relaxed = unrelaxed.clone()
        relax.apply(relaxed)
        relaxed.dump_pdb(out_paths[f"{state}_nocys"])
        rmsd_calc = pyrosetta.rosetta.core.simple_metrics.metrics.RMSDMetric()
        # rmsd_atoms enum value 3 (presumably backbone CA only; confirm).
        rmsd_calc.set_rmsd_type(pyrosetta.rosetta.core.scoring.rmsd_atoms(3))
        rmsd_calc.set_run_superimpose(True)
        rmsd_calc.set_comparison_pose(unrelaxed)
        scores[f"rmsd_final_{state}"] = rmsd_calc.calculate(relaxed)
        return relaxed

    relax_X = build_state("X", other="Y")
    relax_Y = build_state("Y", other="X")
    # KeepSequenceSymmetry should have kept both states on one sequence.
    assert relax_X.sequence() == relax_Y.sequence()
    scores["final_sequence"] = relax_X.sequence()
    for suffix, path in out_paths.items():
        scores[f"{suffix}_path"] = path

    # ---- selector masks over the state-Y (chains C,D) residues -------------
    Y_resis = list(xml_obj.get_residue_selector("CD").apply(ref_pose))

    def sele2str(sele: str) -> str:
        """
        Return the flags of selector `sele` on ref_pose, kept only at the
        chain C/D positions, joined by commas.

        The result is a per-residue flag mask aligned to the Y residues, not
        a list of residue numbers.
        @cdemakis @pleung
        """
        selection = list(xml_obj.get_residue_selector(sele).apply(ref_pose))
        return ",".join(map(str, itertools.compress(selection, Y_resis)))

    selector_strings = {
        "chC_resis": sele2str("chC"),
        "chD_resis": sele2str("chD"),
        "twosided_Y_resis": sele2str("twosidedY"),
        "onesided_Y_resis": sele2str("onesidedY"),
    }

    # ---- final metrics on the designed XY pose -----------------------------
    final_scored = filters(ref_pose.clone())
    final_keys = (
        "wnm_all_X",
        "wnm_all_Y",
        "tot_score_X",
        "tot_score_Y",
        "tot_score_A",
        "tot_score_B",
        "tot_score_C",
        "tot_score_D",
        "sap_X",
        "sap_Y",
        "sap_A",
        "sap_B",
        "sap_C",
        "sap_D",
    )
    final_update = {key: final_scored.pose.scores[key] for key in final_keys}
    scores.update({**final_update, **selector_strings})
    for key, value in scores.items():
        pyrosetta.rosetta.core.pose.setPoseExtraScore(ref_pose, key, value)
    return io.to_packed(ref_pose)

### Set up dask, set command-line options, make tasks, and submit them to the client for cleanup

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
import logging
import pwd
from pyrosetta.distributed.cluster.core import PyRosettaCluster


logging.basicConfig(level=logging.INFO)

# Input list for create_tasks: one PDB path per line.
selected = os.path.join(os.getcwd(), "05_filter", "good.list")

# Rosetta options handed to every task as "extra_options".
options = dict(
    [
        ("-out:level", "300"),
        ("-detect_disulf", "false"),
        ("-precompute_ig", "true"),
        (
            "-holes:dalphaball",
            "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
        ),
        (
            "-indexed_structure_store:fragment_store",
            "/net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
        ),
    ]
)


def create_tasks(selected, options, chunk_size=250):
    """
    Yield one PyRosettaCluster task dict per PDB path listed in `selected`.

    Args:
        selected: path to a text file with one input PDB path per line.
            Trailing whitespace is stripped and blank lines are skipped.
        options: extra Rosetta options given to every task. Each task gets
            its own shallow copy, so tasks never alias one dict.
        chunk_size: number of consecutive tasks that share one
            ``06_states/<n>`` output subdirectory.

    Yields:
        dict with keys "options", "extra_options", "-s" and "out_path".
    """
    with open(selected, "r") as f:
        stripped = (line.rstrip() for line in f)
        # Blank lines would otherwise become tasks with an empty "-s" path.
        for i, pdb_path in enumerate(path for path in stripped if path):
            yield {
                "options": "-corrections::beta_nov16 true",
                "extra_options": dict(options),
                "-s": pdb_path,
                "out_path": f"06_states/{i // chunk_size}",
            }


print("run the following from your local terminal:")
user = pwd.getpwuid(os.getuid()).pw_name
print(f"ssh -L 8000:localhost:8787 {user}@{socket.gethostname()}")

output_path = os.path.join(os.getcwd(), "06_cleanup")

if __name__ == "__main__":
    # One single-core, 4GB worker process per SLURM job.
    slurm_kwargs = dict(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="medium",
        walltime="23:30:00",
        death_timeout=120,
        local_directory="$TMPDIR/dask",
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        # dask-worker --lifetime / --lifetime-stagger flags (23h < 23:30 walltime)
        extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
    )
    # Adaptive scaling between 1 and 1020 workers.
    adapt_kwargs = dict(
        minimum=1,
        maximum=1020,
        wait_count=400,  # consecutive removal suggestions before a worker is removed
        interval="5s",  # time between scaling checks
    )
    # The cluster and client are context managers, so they are torn down
    # when the run finishes or fails.
    with SLURMCluster(**slurm_kwargs) as cluster:
        print(cluster.job_script())
        cluster.adapt(**adapt_kwargs)
        with Client(cluster) as client:
            print(client)
            runner = PyRosettaCluster(
                tasks=create_tasks(selected, options),
                client=client,
                scratch_dir=output_path,
                output_path=output_path,
            )
            runner.distribute(protocols=[almost_linkres])

In [3]:
%%time
import pyrosetta
from pyrosetta.distributed import cluster
import pyrosetta.distributed.io as io

flags = """
-out:level 300
-precompute_ig 1
-detect_disulf false
-corrections::beta_nov16 true
-holes:dalphaball /home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc
-indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5
"""
# Collapse the multi-line flag block into one space-separated option string.
init_args = " ".join(flags.split())
pyrosetta.distributed.init(init_args)

# Run the protocol once, locally, on a single input, writing to ./test.
test_input = "/mnt/home/pleung/projects/bistable_bundle/r4/helix_binders/05_msd_runs_2/4/2f65437ef963b1198d7c01eff28aa08518500dd5cc9bf2cf.pdb.bz2"
t = almost_linkres(None, out_path="test", **{"-s": test_input})

CPU times: user 9min 41s, sys: 6.21 s, total: 9min 47s
Wall time: 10min 34s
