# Design the paired states with sequence symmetry, sampling backbones for MPNN while maintaining sequence realism

### Imports

In [1]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells
%load_ext autoreload
%autoreload 2
# Report the working directory and the host this kernel is running on.
working_dir = os.getcwd()
node_name = socket.gethostname()
print(f"running in directory: {working_dir}")
print(f"running on node: {node_name}")

running in directory: /mnt/home/pleung/projects/crispy_shifty/projects/crispy_shifties
running on node: dig135


### Set working directory to the root of the crispy_shifty repo
TODO set to projects dir

In [2]:
# Change into the repository root: later cells build their input/output paths
# as os.getcwd() + "projects/crispy_shifties/...".
# NOTE(review): this is a hard-coded, user-specific absolute path; the
# commented-out line below is an alternative location that is not in use.
os.chdir("/home/pleung/projects/crispy_shifty")
# os.chdir("/projects/crispy_shifty")

### Design the paired states
TODO

In [None]:
from crispy_shifty.utils.io import gen_array_tasks

# Name of this step; it is reused as the name of the output directory.
simulation_name = "05_design_paired_states"

# Both paths are anchored at the current working directory.
cwd = os.getcwd()
design_list_file = os.path.join(
    cwd, "projects/crispy_shifties/04_pair_bound_states/paired_states.list"
)
output_path = os.path.join(cwd, f"projects/crispy_shifties/{simulation_name}")

# Flags handed to gen_array_tasks as one space-separated string.
# Currently disabled flag, kept for reference:
#   "indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5"
options = " ".join(("out:level 200", "corrections::beta_nov16 true"))

# Collect the keyword arguments first, then make a single call.
task_settings = dict(
    distribute_func="crispy_shifty.protocols.msd.two_state_design_paired_state",
    design_list_file=design_list_file,
    output_path=output_path,
    queue="short",
    memory="6G",
    nstruct=1,
    nstruct_per_task=1,
    options=options,
    simulation_name=simulation_name,
)
gen_array_tasks(**task_settings)

In [4]:
# !sbatch -a 1-$(cat /mnt/home/pleung/projects/crispy_shifty/projects/crispy_shifties/05_design_paired_states/tasks.cmds | wc -l) /mnt/home/pleung/projects/crispy_shifty/projects/crispy_shifties/05_design_paired_states/run.sh

### Collect scorefiles of designed paired states and concatenate
TODO change to projects dir

In [5]:
# sys.path.insert(0, "~/projects/crispy_shifty")  # TODO
# from crispy_shifty.utils.io import collect_score_file

# simulation_name = "05_design_paired_states"
# output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# if not os.path.exists(os.path.join(output_path, "scores.json")):
#     collect_score_file(output_path, "scores")

### Load resulting concatenated scorefile
TODO change to projects dir

In [6]:
# sys.path.insert(0, "~/projects/crispy_shifty")  # TODO
# from crispy_shifty.utils.io import parse_scorefile_linear

# output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# scores_df = parse_scorefile_linear(os.path.join(output_path, "scores.json"))
# scores_df = scores_df.convert_dtypes()

### Setup for plotting

In [7]:
# sns.set(
#     context="talk",
#     font_scale=1,  # make the font larger; default is pretty small
#     style="ticks",  # make the background white with black lines
#     palette="colorblind",  # a color palette that is colorblind friendly!
# )

### Data exploration
Remove the Rosetta score function (sfxn) score terms for now

In [8]:
# from crispy_shifty.protocols.design import beta_nov16_terms

# scores_df = scores_df[
#     [term for term in scores_df.columns if term not in beta_nov16_terms]
# ]
# print(len(scores_df))

In [9]:
# print(list(scores_df.columns))

In [10]:
# from crispy_shifty.plotting.utils import histplot_df, pairplot_df

# to_plot = [
#     "bb_clash_delta_x",
#     "score_per_res_x",
#     "wnm_all_x",
# ]
# print(to_plot)

### Filter extreme outliers and change some dtypes

In [11]:
# scores_df["score_per_res_x"] = scores_df["score_per_res_x"].astype(float)

# query = "bb_clash_delta_x < 500 and score_per_res_x < 0 and wnm_all_x < 5"
# sample_df = scores_df.query(query)

### Plot loop scores

In [12]:
# cols = [
#     "bb_clash_delta_x",
#     "score_per_res_x",
#     "wnm_all_x",
# ]
# the_fig = histplot_df(
#     df=sample_df,
#     cols=cols,
#     bins=10,
#     hue="scaffold_type",
# )
# plt.savefig(os.path.join(output_path, "loop_scores.png"))

### Filter out obviously bad decoys

In [13]:
# query = "bb_clash_delta_x < 200 and score_per_res_x < -2 and wnm_all_x < 1"
# filtered = sample_df.query(query)
# len(filtered)

### Save a list of outputs

In [14]:
# simulation_name = "05_design_paired_states"
# output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# with open(os.path.join(output_path, "designed_paired_states.list"), "w") as f:
#     for path in tqdm(filtered.index):
#         print(path, file=f)

### Prototyping blocks

Test `two_state_design_paired_state`

In [15]:
# %%time
# import pyrosetta

# pyrosetta.init(
#     "-corrections::beta_nov16 \
#     -indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5"
# )

# sys.path.insert(0, "~/projects/crispy_shifty/") # TODO projects
# from crispy_shifty.protocols.msd import two_state_design_paired_state

# t = two_state_design_paired_state(
#         None,
#         **{
#             'pdb_path': '/mnt/home/pleung/projects/crispy_shifty/projects/crispy_shifties/04_pair_bound_states/decoys/0001/04_pair_bound_states_331c8b82841d4de280ff3db199bb973f.pdb.bz2'
#         }
# )
# for i, tppose in enumerate(t):
#     tppose.pose.dump_pdb(f"{i}.pdb")

In [16]:
# d = dict(tppose.pose.scores)
# from crispy_shifty.protocols.design import beta_nov16_terms

# d = {k: v for k, v in d.items() if k not in beta_nov16_terms}

In [17]:
# d

test `almost_linkres`

In [None]:
from pyrosetta.rosetta.core.select.residue_selector import (
    OrResidueSelector,
    ResidueIndexSelector,
)

from crispy_shifty.protocols.design import (
    gen_movemap,
    gen_std_layer_design,
    gen_task_factory,
)
from crispy_shifty.protocols.msd import almost_linkres

# Start PyRosetta first; everything below is created after initialization.
pyrosetta.init(
    "-corrections::beta_nov16 true \
    -packing:precompute_ig true \
    "
)

# Design helpers from crispy_shifty plus a score function loaded from
# beta_nov16.wts.
mm = gen_movemap()
sfxn = pyrosetta.create_score_function("beta_nov16.wts")
ld = gen_std_layer_design()

# Load the test structure and print the last residue index of chains 1-3.
pose = pyrosetta.io.pose_from_file("foo.pdb")
for chain_number in (1, 2, 3):
    print(pose.chain_end(chain_number))

# make some silly selectors: three five-residue windows given as index strings
sel1, sel2, sel3 = (
    ResidueIndexSelector(indices)
    for indices in (
        "108,109,110,111,112",
        "362,363,364,365,366",
        "221,222,223,224,225",
    )
)

# Union of all three windows, built by nesting two OrResidueSelectors.
pre_design_sel = OrResidueSelector(sel1, sel2)
design_sel = OrResidueSelector(pre_design_sel, sel3)

tf = gen_task_factory(design_sel=design_sel, layer_design=ld)

# Only sel1 and sel2 are passed as residue_selectors; sel3 enters solely
# through design_sel in the task factory.
linkres_kwargs = dict(
    pose=pose,
    movemap=mm,
    residue_selectors=[sel1, sel2],
    scorefxn=sfxn,
    task_factory=tf,
    repeats=1,
)
tm = almost_linkres(**linkres_kwargs)

In [19]:
# Write the current `pose` to baz.pdb (path is relative to the working directory).
pose.dump_pdb("baz.pdb")

True