# Prep chip order

### Boilerplate

In [1]:
%load_ext lab_black
# python internal
import collections
import copy
import gc
from glob import glob
import h5py
import itertools
import os
import random
import re
import socket
import shutil
import subprocess
import sys

# conda/pip
import dask
import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy
import seaborn as sns
from tqdm import tqdm

# special packages on the DIGS
import py3Dmol
import pymol
import pyrosetta

# notebook magic
%matplotlib inline
%load_ext autoreload
%autoreload 2

# record where (directory and DIGS node) this notebook is running
print(os.getcwd(), socket.gethostname(), sep="\n")

/mnt/home/pleung/projects/peptide_binders/r0/peptide_binders
dig66


### Load selected designs

In [2]:
# designs selected for ordering by the AF2 step
order_dir = os.path.join(os.getcwd(), "04_run_af2_short")
input_json = os.path.join(order_dir, "to_order.json")
to_order_df = pd.read_json(input_json)

In [3]:
def split_into_subpools(
    df: pd.DataFrame, max_size: int = 1000, by: str = "chA_length"
) -> dict:
    """
    Group designs by `by` and cut each group into near-equal chunks of at most
    `max_size` rows.

    Groups are visited in sorted key order so subpool numbering is reproducible.
    Iterating a `set` of keys, as before, gives no ordering guarantee.
    Chunk sizes follow `np.array_split`, so chunks within a group differ by at
    most one row.

    Args:
        df: designs to split.
        max_size: upper bound on rows per subpool.
        by: column whose values define the groups.

    Returns:
        dict mapping consecutive ints (0, 1, ...) to DataFrame chunks.
    """
    pools = {}
    for _, group in df.groupby(by, sort=True):
        # len // max_size + 1 chunks keeps every chunk at or below max_size
        # (a group smaller than max_size stays whole)
        n_chunks = len(group) // max_size + 1
        # Split row positions rather than the DataFrame itself. np.array_split on a
        # DataFrame goes through the deprecated DataFrame.swapaxes in recent pandas.
        for positions in np.array_split(np.arange(len(group)), n_chunks):
            pools[len(pools)] = group.iloc[positions]
    return pools


subpools = split_into_subpools(to_order_df)

In [6]:
# subpool id and number of designs in it
for key, value in subpools.items():
    print(f"{key} {len(value)}")

0 494
1 907
2 907
3 907
4 907
5 907
6 907
7 907
8 907
9 907
10 782
11 781
12 781
13 946
14 946
15 945
16 945
17 945
18 945
19 945


### Make functions to domesticate sequences using Ryan's `domesticator`

In [4]:
def row2dna(row) -> str:
    """
    Reverse-translate one design's chain A protein sequence into DNA with Ryan's
    `domesticator`, run as ``./domesticator.py`` from the current directory.

    The flags passed are listed in ``args`` below: restriction sites and sequence
    patterns to avoid, k-mer avoidance settings, and ``--species s_cerevisiae``.

    Args:
        row: mapping with a "chA_seq" entry holding the protein sequence, e.g. a
            DataFrame row from ``df.apply(row2dna, axis=1)``.

    Returns:
        The DNA sequence: every stdout line after the ``>unknown_seq1`` header,
        concatenated.

    Raises:
        RuntimeError: if no sequence could be parsed from the domesticator output.
            Previously this silently returned "", which would have put an empty
            construct into the order.
    """
    # imported here so the function stays self-contained when shipped to dask workers
    import subprocess

    protein = row["chA_seq"]
    # Argument list instead of a shell string: the sequence is never parsed by a
    # shell, and stderr stays available for error reporting.
    args = [
        "./domesticator.py",
        protein,
        "--avoid_restriction_sites",
        "BsaI",
        "XhoI",
        "NdeI",
        "--avoid_patterns",
        "AGGAGG",
        "GCTGGTGG",
        "ATCTGTT",
        "GGRGGT",
        "GGATCC",
        "GCTAGC",
        "AAAAAAAA",
        "GGGGG",
        "TTTTTTTT",
        "CCCCCCCC",
        "CACCTGC",
        "--avoid_kmers",
        "8",
        "--avoid_kmers_boost",
        "10",
        "--species",
        "s_cerevisiae",
    ]
    proc = subprocess.run(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )

    # keep every line after the ">unknown_seq1" header (the header itself is dropped)
    sequence_lines = []
    in_sequence = False
    for line in proc.stdout.splitlines():
        if in_sequence:
            sequence_lines.append(line)
        if ">unknown_seq1" in line:
            in_sequence = True
    dna = "".join(sequence_lines)
    if not dna:
        raise RuntimeError(
            f"domesticator produced no sequence for {protein!r} "
            f"(exit code {proc.returncode}): {proc.stderr.strip()}"
        )
    return dna

### Reverse translate designs to order
TODO

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from glob import glob
import logging
import pickle  # was missing: pickle.dump below raised NameError after all work was done
import pwd

print("run the following from your local terminal:")
print(
    f"ssh -L 8000:localhost:8787 {pwd.getpwuid(os.getuid()).pw_name}@{socket.gethostname()}"
)

subpool_futures = {}
translated_subpools = {}

if __name__ == "__main__":
    # configure SLURM cluster as a context manager
    with SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="1GB",
        queue="long",
        walltime="23:30:00",
        death_timeout=120,
        local_directory="$TMPDIR/dask",
        log_directory="/mnt/home/pleung/logs/slurm_logs",
        extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
    ) as cluster:
        print(cluster.job_script())
        # scale between 1-20 workers
        cluster.adapt(
            minimum=1,
            maximum=20,
            wait_count=400,  # Number of consecutive times that a worker should be suggested for removal it is removed
            interval="5s",  # Time between checks
        )
        # setup a client to interact with the cluster as a context manager
        with Client(cluster) as client:
            print(client)
            # one task per subpool: the worker runs row2dna over every row of it
            for subpool, df in subpools.items():
                subpool_futures[subpool] = client.submit(df.apply, row2dna, axis=1)
            # block until each subpool finishes; .result() re-raises worker errors
            for subpool, pending in subpool_futures.items():
                translated_subpools[subpool] = pending.result()
            output_dir = os.path.join(os.getcwd(), "05_to_order")
            os.makedirs(output_dir, exist_ok=True)
            with open(os.path.join(output_dir, "subpools.pkl"), "wb") as handle:
                pickle.dump(translated_subpools, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [5]:
# prefix = "atactacggtctcaagga"  # for scarless cloning
# irbs = "AGCGGCGGCAGCTAGTAAAGAAGGAGATATCATATGAGCGGCGGCAGC"  # for bicstronic expression and normalization
# suffix = "GGttcccgagaccgtaatgc"  # for scarless cloning
# with open(os.path.join(os.getcwd(), "07_to_order", "cs_dna.list"), "w+") as bicis:
#     for index, row in tqdm(to_order_df.iterrows()):
#         chA = row["chA_seq"]
#         chB = row["chB_seq"]
#         len_chB_dna = 3 * len(chB)
#         fully_domesticated = capture_1shot_domesticator(
#             cmd_no_stderr(
#                 f"./domesticator.py {chB+chA} --avoid_restriction_sites BsaI XhoI NdeI --avoid_patterns AGGAGG GCTGGTGG ATCTGTT GGRGGT --avoid_kmers 8 --avoid_kmers_boost 10 --species e_coli"
#             )
#         )
#         chB_dna = fully_domesticated[:len_chB_dna]
#         chA_dna = fully_domesticated[len_chB_dna:]
#         bicis.write(index + "\t" + prefix + chB_dna + irbs + chA_dna + suffix + "\n")
# with open(os.path.join(os.getcwd(), "07_to_order", "jhb_dna.list"), "w+") as bicis:
#     for index, row in tqdm(to_order_2.iterrows()):
#         chA = row["chA_seq"]
#         chB = row["chB_seq"]
#         len_chB_dna = 3 * len(chB)
#         fully_domesticated = capture_1shot_domesticator(
#             cmd_no_stderr(
#                 f"./domesticator.py {chB+chA} --avoid_restriction_sites BsaI XhoI NdeI --avoid_patterns AGGAGG GCTGGTGG ATCTGTT GGRGGT --avoid_kmers 8 --avoid_kmers_boost 10 --species e_coli"
#             )
#         )
#         chB_dna = fully_domesticated[:len_chB_dna]
#         chA_dna = fully_domesticated[len_chB_dna:]
#         bicis.write(index + "\t" + prefix + chB_dna + irbs + chA_dna + suffix + "\n")

# TODO: assemble the final order constructs per subpool. The loop below was left
# without a body, which made this whole cell a SyntaxError, so it stays commented
# out until it is implemented.
# for pool, df in subpools.items():
#     ...

108it [08:44,  4.85s/it]
12it [01:00,  5.07s/it]


In [35]:
# one FASTA per subpool: header is the design's index, body its chain A sequence
for i, subpool in subpools.items():
    fasta_path = f"test_{i}.fasta"
    with open(fasta_path, "w") as f:
        for j, row in subpool.iterrows():
            f.write(f">{j}\n{row['chA_seq']}\n")

In [3]:
from typing import *


def run_af2(
    prefix="",  # prefix for saving pdbs, can include path components; "" means use `query`
    query="",  # relative or abspath to pdb, pdb.gz, or pdb.bz2
    num_recycle=3,  # set this to 10 if plddts are low - might help models converge
    random_seed=0,  # try changing seed if you need to sample more
    num_models=5,  # it will run [4, 3, 5, 1, 2][:num_models] these models, 4 is used for compiling jax params
    index_gap=200,  # decrease under 32 if you have a prior that chains need to interact
    save_pdbs=True,  # if false will save pdbstring in output dict instead
) -> Dict:
    """
    Predict the structure of the (possibly multi-chain) pose in `query` with
    AlphaFold2 (*_ptm weights) and compare each prediction to the input.

    No real MSA or templates are used:
    - each chain becomes its own one-sequence MSA with all other chains gapped out;
    - a blank mock template is supplied;
    - chain breaks are encoded by offsetting residue_index by `index_gap`.

    For every model:
    - the prediction is rebuilt chain by chain;
    - its Ca-RMSD to the input pose is recorded;
    - sidechains are relaxed (FastRelax, backbone fixed);
    - the result is dumped to `{prefix}_relaxed_model_{r}.pdb`, with per-residue
      pLDDT written into every atom's B-factor.

    Returns:
        dict keyed "model_{r}" -> {"average_plddts", "plddt", "pae", "ptm",
        "rmsd_to_input", "pdb_path"}. When `save_pdbs` is False each entry also
        gets "pdb_string" and the pdb file on disk is deleted.

    Raises:
        RuntimeError: if `query` is neither a path containing ".pdb" without
        ".bz2" nor a ".pdb.bz2" path.

    Side effects:
    - sets XLA_PYTHON_CLIENT_* env vars;
    - prepends the alphafold checkout to sys.path;
    - calls pyrosetta.init;
    - deletes all live XLA buffers on the default backend after prediction.
    """
    import bz2
    import os

    # set before jax is imported below (JAX reads these when its GPU backend
    # initializes): no up-front preallocation, allocate/free on demand
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    from string import ascii_uppercase
    import sys

    # make the local alphafold checkout importable
    sys.path.insert(0, "/projects/ml/alphafold/alphafold_git/")
    from typing import Dict
    import jax
    from jax.lib import xla_bridge
    import matplotlib.pyplot as plt

    import numpy as np
    from alphafold.common import protein
    from alphafold.data import pipeline
    from alphafold.data import templates
    from alphafold.model import data
    from alphafold.model import config
    from alphafold.model import model
    from alphafold.relax import relax
    from alphafold.relax import utils
    import pyrosetta
    import pyrosetta.distributed.io as io
    from pyrosetta.distributed.tasks.rosetta_scripts import (
        SingleoutputRosettaScriptsTask,
    )
    from pyrosetta.rosetta.core.pose import Pose

    def set_bfactor(pose: Pose, lddt_array: list) -> None:
        """In place: set the B-factor of every atom in residue r (1-based) to lddt_array[r - 1]."""
        for resid, residue in enumerate(pose.residues, start=1):
            for i, atom in enumerate(residue.atoms(), start=1):
                pose.pdb_info().bfactor(resid, i, lddt_array[resid - 1])
        return

    def mk_mock_template(query_sequence: str) -> Dict:
        """
        Make a mock template dict from a query sequence.
        Since alphafold's model requires a template input,
        we create a blank example w/ zero input, confidence -1
        (all-gap template sequence, zero coordinates and masks, one template).
        @minkbaek @aivan
        """
        ln = len(query_sequence)
        output_templates_sequence = "-" * ln
        output_confidence_scores = np.full(ln, -1)
        templates_all_atom_positions = np.zeros(
            (ln, templates.residue_constants.atom_type_num, 3)
        )
        templates_all_atom_masks = np.zeros(
            (ln, templates.residue_constants.atom_type_num)
        )
        templates_aatype = templates.residue_constants.sequence_to_onehot(
            output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID
        )
        # [None] adds the leading "number of templates" axis (= 1)
        template_features = {
            "template_all_atom_positions": templates_all_atom_positions[None],
            "template_all_atom_masks": templates_all_atom_masks[None],
            "template_sequence": [f"none".encode()],
            "template_aatype": np.array(templates_aatype)[None],
            "template_confidence_scores": output_confidence_scores[None],
            "template_domain_names": [f"none".encode()],
            "template_release_date": [f"none".encode()],
        }
        return template_features

    def get_rmsd(design: Pose, prediction: Pose) -> float:
        """Calculate Ca-RMSD of prediction to design, after superimposing them"""
        rmsd_calc = pyrosetta.rosetta.core.simple_metrics.metrics.RMSDMetric()
        rmsd_calc.set_rmsd_type(pyrosetta.rosetta.core.scoring.rmsd_atoms(3))
        rmsd_calc.set_run_superimpose(True)
        rmsd_calc.set_comparison_pose(design)
        rmsd = float(rmsd_calc.calculate(prediction))
        return rmsd

    def DAN(pdb: str) -> np.array:
        """
        Score `pdb` with DeepAccNet using a separate tensorflow python and return
        the "lddt" array from its .npz output. The .npz is deleted afterwards and
        the tool's combined stdout+stderr is printed.

        NOTE(review): currently unused - the calls in the main body are commented out.
        """
        import os, subprocess

        def cmd(command, wait=True):
            """@nrbennet @bcov"""
            the_command = subprocess.Popen(
                command,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
            if not wait:
                return
            the_stuff = the_command.communicate()
            return str(the_stuff[0]) + str(the_stuff[1])

        pythonpath = "/software/conda/envs/tensorflow/bin/python"
        script = "/net/software/DeepAccNet/DeepAccNet.py"
        npz = pdb.replace(".pdb", ".npz")
        to_send = f"""{pythonpath} {script} -r -v --pdb {pdb} {npz} """
        print(cmd(to_send))
        x = np.load(npz)
        os.remove(npz)
        lddt = x["lddt"]
        return lddt

    def predict_structure(
        prefix="",
        feature_dict={},
        Ls=[],
        model_params={},
        use_model={},
        random_seed=0,
        index_gap=200,
        save_pdbs=True,
    ) -> Dict:
        """
        Predicts structure using AlphaFold for the given feature dict, then
        post-processes each model:
        - rebuild the prediction chain by chain;
        - compute Ca-RMSD to the input;
        - relax sidechains;
        - dump the pdb with pLDDT stored in the B-factors.

        Ls: number of residues in each chain, in order. The residue_index array
        in `feature_dict` is shifted in place by `index_gap` after each chain.

        Reads `model_runner_4`, `pose` and `num_models` from the enclosing run_af2
        scope.
        NOTE(review): `save_pdbs` is accepted but not used here.
        """
        # Minkyung"s code adds big enough number to residue index to indicate chain breaks
        idx_res = feature_dict["residue_index"]
        L_prev = 0
        # Ls: number of residues in each chain
        for L_i in Ls[:-1]:
            idx_res[L_prev + L_i :] += index_gap
            L_prev += L_i
        feature_dict["residue_index"] = idx_res
        # Run the models.
        plddts, paes, ptms, rmsds = [], [], [], []
        poses = []
        for model_name, params in model_params.items():
            if model_name in use_model:
                model_runner = model_runner_4  # run_af2's compiled model_4 runner (closure, not a module global)
                model_runner.params = params  # swap in this model's weights
                processed_feature_dict = model_runner.process_features(
                    feature_dict, random_seed=random_seed
                )
                prediction_result = model_runner.predict(processed_feature_dict)
                unrelaxed_protein = protein.from_prediction(
                    processed_feature_dict, prediction_result
                )
                plddts.append(prediction_result["plddt"])
                paes.append(prediction_result["predicted_aligned_error"])
                ptms.append(prediction_result["ptm"])
                # add termini after each chain
                # rebuild the prediction one chain at a time (chain lengths from Ls)
                # and append each chunk into cleaned_pose; the `True` argument
                # presumably starts a new chain - confirm against pyrosetta docs
                unsafe_pose = io.to_pose(
                    io.pose_from_pdbstring(protein.to_pdb(unrelaxed_protein))
                )
                cleaned_pose = Pose()
                total = 0
                chunks = []
                mylist = list(unsafe_pose.residues)
                for j in range(len(Ls)):
                    chunk_mylist = mylist[total : total + Ls[j]]
                    chunks.append(chunk_mylist)
                    total += Ls[j]
                    temp_pose = Pose()
                    for k in chunk_mylist:
                        temp_pose.append_residue_by_bond(k)
                    pyrosetta.rosetta.core.pose.append_pose_to_pose(
                        cleaned_pose, temp_pose, True
                    )
                # chain order "12...n" keeps the order unchanged.
                # NOTE(review): one character per chain, so this is likely wrong past 9 chains.
                sc = pyrosetta.rosetta.protocols.simple_moves.SwitchChainOrderMover()
                sc.chain_order("".join([str(i) for i in range(1, len(Ls) + 1)]))
                sc.apply(cleaned_pose)
                # `pose` is the input pose from the enclosing run_af2 scope
                rmsds.append(get_rmsd(pose, cleaned_pose))
                # relax sidechains to prevent distracting clashes in output
                # (MoveMap: bb/jump frozen, chi free)
                xml = """
                <ROSETTASCRIPTS>
                    <SCOREFXNS>
                        <ScoreFunction name="sfxn" weights="beta_nov16" />
                    </SCOREFXNS>
                    <RESIDUE_SELECTORS>
                    </RESIDUE_SELECTORS>
                    <TASKOPERATIONS>
                    </TASKOPERATIONS>
                    <TASKOPERATIONS>
                        <IncludeCurrent name="current" />
                    </TASKOPERATIONS>
                    <MOVERS>
                        <FastRelax name="relax" scorefxn="sfxn" repeats="1" bondangle="false" bondlength="false" task_operations="current" >
                            <MoveMap name="MM" bb="false" chi="true" jump="false" />
                        </FastRelax>
                    </MOVERS>
                    <FILTERS>
                    </FILTERS>
                    <APPLY_TO_POSE>
                    </APPLY_TO_POSE>
                    <PROTOCOLS>
                        <Add mover="relax" />
                    </PROTOCOLS>
                    <OUTPUT />
                </ROSETTASCRIPTS>
                """
                relaxer = SingleoutputRosettaScriptsTask(xml)
                relaxed_ppose = relaxer(cleaned_pose.clone())
                poses.append(io.to_pose(relaxed_ppose))
                # cleanup some memory
                del processed_feature_dict, prediction_result

        # same order as model_params was filled in run_af2, so index n lines up
        # with plddts[n] / paes[n] / ptms[n] / rmsds[n] / poses[n]
        model_idx = [4, 3, 5, 1, 2]
        model_idx = model_idx[:num_models]  # num_models from the enclosing run_af2
        out = {}
        # save output pdbs and metadata
        for n, r in enumerate(model_idx):
            # create the directory part of prefix (relative to cwd unless absolute)
            os.makedirs(
                os.path.join(os.getcwd(), "/".join(prefix.split("/")[:-1])),
                exist_ok=True,
            )
            relaxed_pdb_path = f"{prefix}_relaxed_model_{r}.pdb"
            set_bfactor(poses[n], list(plddts[n]))
            poses[n].dump_pdb(relaxed_pdb_path)
            average_plddts = float(plddts[n].mean())

            out[f"model_{r}"] = {
                "average_plddts": average_plddts,
                "plddt": plddts[n].tolist(),
                "pae": paes[n].tolist(),
                "ptm": ptms[n].tolist(),
                "rmsd_to_input": rmsds[n],
                "pdb_path": os.path.abspath(relaxed_pdb_path),
            }
            print(f"model_{r}: average plddt {average_plddts}")
        return out

    # begin main method
    pyrosetta.init("-run:constant_seed 1 -mute all -corrections::beta_nov16 true")
    # read in pdbs, do bz2 check, if query does not contain .pdb throw exception
    # (.pdb / .pdb.gz go straight to pose_from_file; .pdb.bz2 is decompressed here)
    if ".pdb" in query and ".bz2" not in query:
        pose = pyrosetta.io.pose_from_file(query)
    elif ".pdb.bz2" in query:
        with open(query, "rb") as f:
            ppose = io.pose_from_pdbstring(bz2.decompress(f.read()).decode())
        pose = io.to_pose(ppose)
    else:
        raise RuntimeError("query must be a pdb, pdb.gz, or pdb.bz2")
    n_chains = pose.num_chains()
    seqs = [chain.sequence() for chain in pose.split_by_chain()]
    full_sequence = "".join(seqs)
    # prepare models
    use_model = {}
    # NOTE(review): dir() here lists only run_af2's locals, so this is always true:
    # weights are reloaded (and model_4 recompiled below) on every call.
    if "model_params" not in dir():
        model_params = {}
    for model_name in ["model_4", "model_3", "model_5", "model_1", "model_2"][
        :num_models
    ]:
        use_model[model_name] = True
        if model_name not in model_params:
            model_params[model_name] = data.get_model_haiku_params(
                model_name=model_name + "_ptm",
                data_dir="/projects/ml/alphafold/alphafold_git/",
            )
            if (
                model_name == "model_4"
            ):  # compile only model 4 and later load weights for other models
                model_config = config.model_config(model_name + "_ptm")
                model_config.data.common.max_extra_msa = 1
                model_config.data.eval.max_msa_clusters = n_chains  # one MSA row per chain (see msas below)
                model_config.data.eval.num_ensemble = 1
                model_config.data.common.num_recycle = num_recycle
                model_runner_4 = model.RunModel(model_config, model_params[model_name])
    # prepare input data
    template_features = mk_mock_template(full_sequence)  # make mock template
    deletion_matrix = [[0] * len(full_sequence)]  # make mock deletion matrix
    msas = []
    deletion_matrices = []
    for i in range(n_chains):
        # make a sequence of length full_sequence where everything but the i-th chain is "-"
        msa = [
            "".join(["-" * len(seq) if i != j else seq for j, seq in enumerate(seqs)])
        ]
        msas.append(msa)
        deletion_matrices.append(deletion_matrix)
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=full_sequence,
            description="none",
            num_res=len(full_sequence),
        ),
        **pipeline.make_msa_features(msas=msas, deletion_matrices=deletion_matrices),
        **template_features,
    }
    # predict structure
    if prefix == "":
        prefix = query
    else:
        pass
    out = predict_structure(
        prefix=prefix,
        feature_dict=feature_dict,
        Ls=[len(l) for l in seqs],
        model_params=model_params,
        use_model=use_model,
        random_seed=random_seed,
        index_gap=index_gap,
        save_pdbs=save_pdbs,
    )
    # deallocate backend memory to make room for DAN
    # TODO delete runners/config?
    del model_params
    device = xla_bridge.get_backend().platform
    backend = xla_bridge.get_backend(device)
    for buffer in backend.live_buffers():
        buffer.delete()
    # run DAN
    # (DAN scoring is commented out; this loop currently only handles save_pdbs=False)
    for model, result in out.items():
        pdb_path = result["pdb_path"]
        # DAN_plddt = DAN(pdb_path)
        # result["average_DAN_plddts"] = float(DAN_plddt.mean())
        # result["DAN_plddt"] = DAN_plddt.tolist()
        # if not save, write pdbstrings to output dict
        if not save_pdbs:
            result["pdb_string"] = io.to_pdbstring(io.pose_from_file(pdb_path))
            os.remove(pdb_path)
    return out

### Set up dask, set command-line options, make tasks, and submit them to the client again to clean up disulfides

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from glob import glob
import logging
import pwd
from pyrosetta.distributed.cluster.core import PyRosettaCluster


# dask dashboard (port 8787 on this node) -> localhost:8000 via an ssh tunnel
print("run the following from your local terminal:")
ssh_user = pwd.getpwuid(os.getuid()).pw_name
print(f"ssh -L 8000:localhost:8787 {ssh_user}@{socket.gethostname()}")


def create_tasks(to_clean, extra_options=None):
    """
    Yield one PyRosettaCluster task dict per input path listed in `to_clean`.

    Args:
        to_clean: path to a text file with one input structure path per line.
            Blank or whitespace-only lines are skipped, so they no longer produce
            tasks with an empty "-s".
        extra_options: dict of extra Rosetta command-line options attached to
            every task. Defaults to the notebook-level `options` dict.

    Yields:
        dict with "options", "extra_options" and "-s" (the input path, trailing
        whitespace removed).
    """
    if extra_options is None:
        extra_options = options  # notebook-level dict, looked up when iteration starts
    with open(to_clean, "r") as f:
        for line in f:
            path = line.rstrip()
            if not path:
                continue
            yield {
                "options": "-corrections::beta_nov16 true",
                "extra_options": extra_options,
                "-s": path,
            }


logging.basicConfig(level=logging.INFO)

# extra Rosetta command-line flags attached to every cleanup task by create_tasks
options = dict(
    [
        ("-out:level", "300"),
        ("-precompute_ig", "true"),
        ("-detect_disulf", "false"),
        (
            "-holes:dalphaball",
            "/home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc",
        ),
        (
            "-indexed_structure_store:fragment_store",
            "/net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
        ),
    ]
)

### Clean up disulfides
Process the short designs first.

In [None]:
to_clean = os.path.join(os.getcwd(), "03_filter/cleanup_short.list")
output_path = os.path.join(os.getcwd(), "04_cleanup_short")

# one single-core, 6GB SLURM job per dask worker on the "medium" queue
slurm_settings = dict(
    cores=1,
    processes=1,
    job_cpu=1,
    memory="6GB",
    queue="medium",
    walltime="23:30:00",
    death_timeout=120,
    local_directory="$TMPDIR/dask",
    log_directory="/mnt/home/pleung/logs/slurm_logs",
    extra=["--lifetime", "23h", "--lifetime-stagger", "4m"],
)
# Adaptive scaling between 1 and 510 workers. A worker is removed only after 400
# consecutive removal suggestions, checked every 5 s.
adapt_settings = dict(minimum=1, maximum=510, wait_count=400, interval="5s")

if __name__ == "__main__":
    with SLURMCluster(**slurm_settings) as cluster:
        print(cluster.job_script())
        cluster.adapt(**adapt_settings)
        with Client(cluster) as client:
            print(client)
            runner = PyRosettaCluster(
                tasks=create_tasks(to_clean),
                client=client,
                scratch_dir=output_path,
                output_path=output_path,
            )
            # NOTE(review): `finalize_design` is not defined anywhere in this notebook;
            # it must come from an earlier (deleted?) cell or a module - confirm before running.
            runner.distribute(protocols=[finalize_design])

### Look at scores
Hacky function to load a line-delimited JSON scorefile (one JSON object per line) into a single DataFrame

In [None]:
def read_scorefile(scores):
    """
    Load a line-delimited JSON scorefile into a single DataFrame.

    Each non-blank line must be a JSON object of the form
    ``{record_name: {field: value, ...}}``. That is the nested shape
    ``pd.read_json``'s default orient accepts; TODO confirm against the writer.
    Each line is parsed and transposed so that every record becomes one row
    indexed by its name.

    Args:
        scores: path to the scorefile.

    Returns:
        pd.DataFrame with one row per record, or an empty DataFrame if the file
        has no records (pd.concat([]) would raise).
    """
    import io

    import pandas as pd

    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is optional

        def tqdm(iterable, **_):
            return iterable

    dfs = []
    with open(scores, "r") as f:
        for line in tqdm(f):
            if not line.strip():
                continue  # skip blank lines
            # wrap in StringIO: passing literal JSON strings to read_json is deprecated
            dfs.append(pd.read_json(io.StringIO(line)).T)
    if not dfs:
        return pd.DataFrame()
    tabulated_scores = pd.concat(dfs)
    return tabulated_scores

### Load scores for the short designs first

In [None]:
%%time
output_path = os.path.join(os.getcwd(), "04_cleanup_short")
scores = os.path.join(output_path, "scores.json")
# The tabulated DataFrame is written back to scores.json (the next cell reloads it
# from there). That used to overwrite the original line-delimited scorefile, so
# re-running this cell parsed garbage. Keep a one-time copy of the raw file and
# always parse that copy.
# NOTE: if this cell already ran under the old code, scores.json is already
# tabulated and the "raw" copy will be too - restore it from the run output.
raw_scores = os.path.join(output_path, "scores_raw.json")
if not os.path.exists(raw_scores):
    shutil.copyfile(scores, raw_scores)
scores_df = read_scorefile(raw_scores)
scores_df.to_json(scores)

### Reload as JSON

In [None]:
# table written by the tabulation cell above
scores_json = os.path.join(os.getcwd(), "04_cleanup_short", "scores.json")
scores_df = pd.read_json(scores_json)

### Plotting functions and config

In [None]:
# seaborn theme for every figure below (sns.set is an alias of sns.set_theme)
sns.set_theme(
    context="talk",
    style="ticks",  # white background with black lines
    palette="colorblind",  # a colorblind-friendly palette
    font_scale=1.5,  # the default font is pretty small
)


def rho(x, y, ax=None, **kwargs):
    """Annotate the Pearson correlation of x and y in the top-left corner of a plot.

    Suitable as a seaborn ``map_*`` callback. Any extra ``kwargs`` are accepted
    and ignored.
    https://stackoverflow.com/questions/50832204/show-correlation-values-in-pairplot-using-seaborn-in-python/50835066

    Args:
        x, y: equal-length numeric sequences.
        ax: axes to annotate; defaults to the current pyplot axes.
    """
    # Import the submodule explicitly: on older SciPy, `import scipy` alone does
    # not make `scipy.stats` available as an attribute.
    from scipy import stats

    r, _ = stats.pearsonr(x, y)
    ax = ax or plt.gca()
    # Unicode for lowercase rho (ρ); named rho_char so it does not shadow this function
    rho_char = "\u03C1"
    ax.annotate(f"{rho_char} = {r:.2f}", xy=(0.1, 0.9), xycoords=ax.transAxes)


def plot_unity(xdata, ydata, ax=None, **kwargs):
    """Draw a dashed black y = x identity line across the combined range of xdata and ydata.

    The previous version interpolated x over [xmin, xmax] and y over [ymin, ymax]
    independently. That draws the diagonal of the data's bounding box, which is
    only y = x when both ranges coincide.
    https://stackoverflow.com/questions/48122019/how-can-i-plot-identity-lines-on-a-seaborn-pairplot

    Args:
        xdata, ydata: array-likes with .min()/.max() (numpy arrays, pandas Series).
        ax: axes to draw on; defaults to the current pyplot axes. Newer seaborn
            passes its target axes here when the callback declares an `ax` parameter.
        **kwargs: accepted (e.g. from seaborn) and ignored.
    """
    lo = min(xdata.min(), ydata.min())
    hi = max(xdata.max(), ydata.max())
    points = np.linspace(lo, hi, 100)
    ax = ax or plt.gca()
    ax.plot(points, points, color="k", marker=None, linestyle="--", linewidth=4.0)

### Analyze mlfold results

In [None]:
# per-prediction scores from the AF2 sample run
scores_csv = os.path.join(os.getcwd(), "04_af2_short_sample", "results", "scores.csv")
results_df = pd.read_csv(scores_csv)


def row2design(
    row,
    base_dir: str = "/home/pleung/projects/peptide_binders/r0/peptide_binders/04_af2_short_sample",
) -> str:
    """Return the input design's PDB path for a results row: ``{base_dir}/{row["Name"]}.pdb``.

    Args:
        row: mapping with a "Name" entry (e.g. a DataFrame row from ``apply(..., axis=1)``).
        base_dir: directory holding the design PDBs; the default is the path the
            notebook previously hard-coded.
    """
    return os.path.join(base_dir, row["Name"] + ".pdb")


def row2prediction(
    row,
    base_dir: str = "/home/pleung/projects/peptide_binders/r0/peptide_binders/04_af2_short_sample/results",
    model: str = "model_4",
) -> str:
    """Return the predicted PDB path for a results row: ``{base_dir}/{row["ID"]}_{model}.pdb``.

    Args:
        row: mapping with an "ID" entry (e.g. a DataFrame row from ``apply(..., axis=1)``).
        base_dir: directory holding the predictions; the default is the path the
            notebook previously hard-coded.
        model: model tag in the prediction filename.
    """
    return os.path.join(base_dir, f"{row['ID']}_{model}.pdb")


results_df["design"] = results_df.apply(row2design, axis=1)
results_df["prediction"] = results_df.apply(row2prediction, axis=1)
results_df = results_df.drop(["ID", "Name"], axis=1)

# scores.csv contains repeated header rows (Sequence == "Sequence"), presumably
# from concatenating per-job CSVs. Drop them with a boolean mask instead of a
# Python loop over iterrows(). .copy() makes later column assignments safe
# (no SettingWithCopyWarning).
is_data_row = results_df["Sequence"] != "Sequence"
results_df = results_df.loc[
    is_data_row, ["Sequence", "Model/Tag", "lDDT", "Time", "design", "prediction"]
].copy()
results_df.head()

### Check RMSD to input

In [None]:
import pyrosetta


def get_rmsd(row) -> float:
    """RMSD of a row's prediction against chain 1 of its design.

    Uses rmsd_atoms(3) with superposition; the nested get_rmsd inside run_af2
    documents that setting as Ca-RMSD.

    Args:
        row: mapping with "design" and "prediction" PDB paths.
    """
    design_pose = pyrosetta.io.pose_from_file(row["design"])
    reference = design_pose.split_by_chain(1)

    metric = pyrosetta.rosetta.core.simple_metrics.metrics.RMSDMetric()
    metric.set_rmsd_type(pyrosetta.rosetta.core.scoring.rmsd_atoms(3))
    metric.set_run_superimpose(True)
    metric.set_comparison_pose(reference)

    predicted_pose = pyrosetta.io.pose_from_file(row["prediction"])
    return float(metric.calculate(predicted_pose))


pyrosetta.init()
# one RMSD per prediction, compared against chain 1 of its design
results_df = results_df.assign(rmsd=results_df.apply(get_rmsd, axis=1))

### Unused blocks

In [None]:
%%time
import pyrosetta
from pyrosetta.distributed import cluster
import pyrosetta.distributed.io as io

flags = """
-out:level 300
-precompute_ig true
-detect_disulf false
-corrections::beta_nov16 true
-holes:dalphaball /home/bcov/ppi/tutorial_build/main/source/external/DAlpahBall/DAlphaBall.gcc
-indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5
"""
# pyrosetta.distributed.init(" ".join(flags.replace("\n\t", " ").split()))
# str.split() with no separator already splits on every kind of whitespace
# (newlines included), so this collapses `flags` into one space-separated string.
pyrosetta.init(" ".join(flags.split()))

# smoke-test the cleanup protocol on a single design
# NOTE(review): `finalize_design` is not defined in this notebook - confirm where it comes from.
test_design = "/mnt/home/pleung/projects/peptide_binders/r0/peptide_binders/03_detail_0/0081/47f6c6fcb6d1e9c788e002cfeb798d0bfbc3f514e73931c8.pdb.bz2"
# alternative input:
# test_design = "/mnt/home/pleung/projects/peptide_binders/r0/peptide_binders/03_detail_0/0018/8aeb40fd33d90cbae3429aac01b14b7f45a054268e56fea7.pdb.bz2"
t = finalize_design(None, **{"-s": test_design})