# Find unbound states to pair with the bound states

### Imports

In [1]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells
%load_ext autoreload
%autoreload 2
# Report the working directory and the host this kernel is running on.
cwd_at_start = os.getcwd()
node_name = socket.gethostname()
print(f"running in directory: {cwd_at_start}")
print(f"running on node: {node_name}")

running in directory: /mnt/projects/crispy_shifty/projects/crispy_shifty_junctions
running on node: dig116


### Set working directory to the root of the crispy_shifty repo

In [2]:
# Later cells assemble project paths from os.getcwd(), so switch to the repo root first.
os.chdir("/projects/crispy_shifty")

### Fix the paths
This is necessary because the designs were copied over from Perlmutter, so the paths in the list still carry the `/pscratch` prefix.

In [3]:
from crispy_shifty.utils.io import fix_path_prefixes

# Absolute path to the list of folded bound states from step 03.
folded = os.path.join(
    os.getcwd(),
    "projects/crispy_shifty_junctions/03_fold_bound_states/folded_states.list",
)

# Prefix to look for in the list, and the local prefix to substitute for it.
prefix_swap = dict(
    find="/pscratch/sd/p/pleung",
    replace="/mnt/projects/crispy_shifty/projects/crispy_shifty_junctions",
)
# overwrite=True is passed through; presumably the list file is rewritten in
# place -- confirm against fix_path_prefixes.
new_folded = fix_path_prefixes(file=folded, overwrite=True, **prefix_swap)

  from distributed.utils import tmpfile


### Pair the designed bound states

In [None]:
from crispy_shifty.utils.io import gen_array_tasks

# Step name; reused for the output directory and passed on to gen_array_tasks.
simulation_name = "04_pair_bound_states"

# Input list of folded bound states and the directory for this step's outputs.
design_list_file = os.path.join(
    os.getcwd(),
    "projects/crispy_shifty_junctions/03_fold_bound_states/folded_states.list",
)
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifty_junctions/{simulation_name}"
)

# Flags handed to gen_array_tasks as one space-separated `options` string.
rosetta_flags = (
    "out:level 200",
    "corrections::beta_nov16 true",
    "indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
)
options = " ".join(rosetta_flags)

# Additional keyword arguments passed along via `extra_kwargs`.
extra_kwargs = dict(
    reference_csv="/mnt/projects/crispy_shifty/scaffolds/02_make_free_states/free_state_0s.csv",
)

gen_array_tasks(
    distribute_func="crispy_shifty.protocols.states.pair_bound_state",
    design_list_file=design_list_file,
    output_path=output_path,
    queue="short",
    memory="4G",
    nstruct=1,
    nstruct_per_task=10,
    options=options,
    extra_kwargs=extra_kwargs,
    simulation_name=simulation_name,
)

### Collect scorefiles of the paired bound states and concatenate

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import collect_score_file

simulation_name = "04_pair_bound_states"
# Collect from the directory the array tasks for this step were pointed at
# (projects/crispy_shifty_junctions/<simulation_name>); the previous
# projects/crispy_shifties path belongs to a different project directory.
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifty_junctions/{simulation_name}"
)

# Only concatenate when scores.json has not been produced yet.
if not os.path.exists(os.path.join(output_path, "scores.json")):
    collect_score_file(output_path, "scores")

### Load resulting concatenated scorefile

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import parse_scorefile_linear

# Same directory the array tasks for this step were generated with
# (projects/crispy_shifty_junctions/<simulation_name>), rather than
# projects/crispy_shifties.
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifty_junctions/{simulation_name}"
)

# scores_df is only built here when no cached scores.csv exists; the following
# cell reloads it from the CSV either way.
if not os.path.exists(os.path.join(output_path, "scores.csv")):
    scores_df = parse_scorefile_linear(os.path.join(output_path, "scores.json"))

### Dump scores_df as a CSV and then reload, for performance reasons

In [None]:
# Write the CSV cache once, then always continue from the copy read back from disk.
scores_csv = os.path.join(output_path, "scores.csv")
if not os.path.exists(scores_csv):
    scores_df.to_csv(scores_csv)

scores_df = pd.read_csv(scores_csv, index_col="Unnamed: 0")

### Setup for plotting

In [None]:
# Global seaborn theme applied to the plots below.
theme = dict(
    context="talk",  # use the "talk" plotting context
    font_scale=1,  # scale factor of 1: fonts stay at the context's own size
    style="ticks",  # white background with black lines
    palette="colorblind",  # a color palette that is colorblind friendly
)
sns.set(**theme)

### Data exploration
Remove the Rosetta score function (sfxn) score terms for now.

In [None]:
from crispy_shifty.protocols.design import beta_nov16_terms

# Keep every column whose name is not listed in beta_nov16_terms.
kept_columns = [
    column for column in scores_df.columns if column not in beta_nov16_terms
]
scores_df = scores_df[kept_columns]
# Row count of the score table.
print(len(scores_df))

In [None]:
# Show the column names that are left.
print([*scores_df.columns])

In [None]:
from crispy_shifty.utils.plotting import histplot_df, pairplot_df

# Score columns of interest, echoed for reference.
to_plot = ["bb_clash_delta_x", "score_per_res_x", "wnm_all_x"]
print(to_plot)

### Filter extreme outliers and change some dtypes

In [None]:
# Cast the three score columns to float, one column at a time.
for column in ("bb_clash_delta_x", "score_per_res_x", "wnm_all_x"):
    scores_df[column] = scores_df[column].astype(float)

# Loose cutoffs on all three scores; the resulting subset is what gets plotted.
query = "bb_clash_delta_x < 500 and score_per_res_x < 0 and wnm_all_x < 5"
sample_df = scores_df.query(query)

### Plot loop scores

In [None]:
# Columns to plot for the outlier-filtered sample, colored by scaffold type.
cols = ["bb_clash_delta_x", "score_per_res_x", "wnm_all_x"]
the_fig = histplot_df(df=sample_df, cols=cols, bins=10, hue="scaffold_type")

# Save the current matplotlib figure next to the scores.
loop_scores_png = os.path.join(output_path, "loop_scores.png")
plt.savefig(loop_scores_png)

In [None]:
# Pairwise plot of the same columns (`cols` comes from the histogram cell).
the_fig = pairplot_df(df=sample_df, cols=cols, hue="scaffold_type")

# Save the current matplotlib figure next to the scores.
paired_png = os.path.join(output_path, "loop_scores_paired.png")
plt.savefig(paired_png)

### Filter out obviously bad decoys

In [None]:
# Tighter cutoffs, applied to the full score table rather than the plotted sample.
query = "bb_clash_delta_x < 200 and score_per_res_x < -1.5 and wnm_all_x < 1.2"
filtered = scores_df.query(query)
# Number of rows that pass.
filtered.shape[0]

In [None]:
# Count the distinct parent names that contain "DHR".
len({parent for parent in filtered.parent.values if "DHR" in parent})

### Save a list of outputs

In [None]:
simulation_name = "04_pair_bound_states"
# Write next to this step's outputs under projects/crispy_shifty_junctions,
# the directory the array tasks were generated with, rather than
# projects/crispy_shifties.
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifty_junctions/{simulation_name}"
)

# One index entry of `filtered` per line.
with open(os.path.join(output_path, "paired_states.list"), "w") as f:
    for path in tqdm(filtered.index):
        print(path, file=f)

### Pair the best designed bound states

In [None]:
from crispy_shifty.utils.io import gen_array_tasks

# Step name; reused for the output directory and passed on to gen_array_tasks.
simulation_name = "04_pair_bound_states_best"

# NOTE(review): these paths live under projects/crispy_shifties, whereas the
# first task-generation cell uses projects/crispy_shifty_junctions -- confirm
# which project directory is intended here.
design_list_file = os.path.join(
    os.getcwd(),
    "projects/crispy_shifties/03_fold_bound_states/best_folded_states.list",
)
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifties/{simulation_name}"
)

# Flags handed to gen_array_tasks as one space-separated `options` string.
rosetta_flags = (
    "out:level 200",
    "corrections::beta_nov16 true",
    "indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
)
options = " ".join(rosetta_flags)

# Additional keyword arguments passed along via `extra_kwargs`.
extra_kwargs = dict(
    reference_csv="/mnt/projects/crispy_shifty/scaffolds/02_make_free_states/free_state_0s.csv",
)

gen_array_tasks(
    distribute_func="crispy_shifty.protocols.states.pair_bound_state",
    design_list_file=design_list_file,
    output_path=output_path,
    queue="short",
    memory="3G",
    nstruct=1,
    nstruct_per_task=1,
    options=options,
    extra_kwargs=extra_kwargs,
    simulation_name=simulation_name,
)

### Collect scorefiles of the paired bound states and concatenate

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import collect_score_file

simulation_name = "04_pair_bound_states_best"
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifties/{simulation_name}"
)

# Skip collection when the concatenated scores.json is already present.
scores_json = os.path.join(output_path, "scores.json")
if not os.path.exists(scores_json):
    collect_score_file(output_path, "scores")

### Load resulting concatenated scorefile

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import parse_scorefile_linear

output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifties/{simulation_name}"
)

# Parse scores.json only when no scores.csv has been written yet.
csv_missing = not os.path.exists(os.path.join(output_path, "scores.csv"))
if csv_missing:
    scores_df = parse_scorefile_linear(os.path.join(output_path, "scores.json"))

### Dump scores_df as a CSV and then reload, for performance reasons

In [None]:
# Write the CSV cache once, then always continue from the copy read back from disk.
scores_csv = os.path.join(output_path, "scores.csv")
if not os.path.exists(scores_csv):
    scores_df.to_csv(scores_csv)

scores_df = pd.read_csv(scores_csv, index_col="Unnamed: 0")

### Filter out obviously bad decoys

In [None]:
# Cutoffs on the three scores; rows that fail any of them are dropped.
query = "bb_clash_delta_x < 200 and score_per_res_x < -1.5 and wnm_all_x < 1.2"
filtered = scores_df.query(query)
# Number of rows that pass.
filtered.shape[0]

In [None]:
# Count the distinct parent names that contain "DHR".
len({parent for parent in filtered.parent.values if "DHR" in parent})

### Save a list of outputs

In [None]:
simulation_name = "04_pair_bound_states_best"
output_path = os.path.join(
    os.getcwd(), f"projects/crispy_shifties/{simulation_name}"
)

# One index entry of `filtered` per line.
list_path = os.path.join(output_path, "best_paired_states.list")
with open(list_path, "w") as f:
    for path in tqdm(filtered.index):
        f.write(str(path) + "\n")

### Prototyping blocks

Test `pair_bound_state` on a single decoy.

In [None]:
%%time 
import pyrosetta

# Start PyRosetta with the beta_nov16 corrections and the fragment store flag.
init_flags = " ".join(
    [
        "-corrections::beta_nov16",
        "-indexed_structure_store:fragment_store /net/databases/VALL_clustered/connect_chains/ss_grouped_vall_helix_shortLoop.h5",
    ]
)
pyrosetta.init(init_flags)

sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.states import pair_bound_state

# Keyword arguments for one test decoy; None is passed as the first positional argument.
test_kwargs = {
    "pdb_path": "/mnt/projects/crispy_shifty/projects/crispy_shifties/03_fold_bound_states/decoys/0046/03_fold_bound_states_eb3d67f280a349e28cdbb86f4a1c4678.pdb.bz2",
    "reference_csv": "/mnt/projects/crispy_shifty/scaffolds/02_make_free_states/free_state_0s.csv",
}
t = pair_bound_state(None, **test_kwargs)

# Dump each returned item as <index>.pdb in the current working directory.
for i, tppose in enumerate(t):
    tppose.pose.dump_pdb(f"{i}.pdb")

In [None]:
from crispy_shifty.protocols.design import beta_nov16_terms

# Scores of the last pose from the loop above, minus the keys in beta_nov16_terms.
d = {
    term: value
    for term, value in dict(tppose.pose.scores).items()
    if term not in beta_nov16_terms
}

In [None]:
# Display the score dict.
d