# Run MPNN multistate design on the paired states

### Imports

In [1]:
%load_ext lab_black
# Python standard library
from glob import glob
import os
import socket
import sys

# 3rd party library imports
import dask
import matplotlib.pyplot as plt
import pandas as pd
import pyrosetta
import numpy as np
import scipy
import seaborn as sns
from tqdm.auto import tqdm  # jupyter compatible progress bar

tqdm.pandas()  # link tqdm to pandas
# Notebook magic
# save plots in the notebook
%matplotlib inline
# reloads modules automatically before executing cells
%load_ext autoreload
%autoreload 2
here, node = os.getcwd(), socket.gethostname()
print(f"running in directory: {here}")  # where are we?
print(f"running on node: {node}")  # what node are we on?

running in directory: /mnt/projects/crispy_shifty/projects/crispy_shifties
running on node: dig68


### Set working directory to the root of the crispy_shifty repo

In [2]:
# Work from the repository root so the relative "projects/crispy_shifties/..."
# paths built in the following cells resolve against it.
repo_root = "/projects/crispy_shifty"
os.chdir(repo_root)

### Run MPNN on the paired states

In [None]:
# Generate the array tasks that run crispy_shifty.protocols.mpnn.mpnn_paired_state
# over the designed paired states listed in design_list_file.
# Make the repo importable explicitly, as most other cells that import
# crispy_shifty do, rather than relying on the kernel's working directory
# being on sys.path.
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import gen_array_tasks

simulation_name = "06_mpnn_paired_states"
# list of input structures (presumably one path per line, produced by step 05)
design_list_file = os.path.join(
    os.getcwd(),
    "projects/crispy_shifties/05_design_paired_states/designed_paired_states.list",
)
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Space-separated option string (presumably Rosetta/PyRosetta flags);
# append to the list to pass more options.
options = " ".join(
    [
        "out:level 200",
    ]
)

# Extra keyword arguments forwarded to the protocol; values are given as
# strings (presumably converted inside mpnn_paired_state — confirm).
extra_kwargs = {
    "num_sequences": "100",
    "mpnn_temperature": "0.2",
    "mpnn_design_area": "scan",
}

gen_array_tasks(
    distribute_func="crispy_shifty.protocols.mpnn.mpnn_paired_state",
    design_list_file=design_list_file,
    output_path=output_path,
    queue="short",
    memory="5G",
    nstruct=1,
    nstruct_per_task=1,
    options=options,
    extra_kwargs=extra_kwargs,
    simulation_name=simulation_name,
)

### Collect the scorefiles of the MPNN-designed paired states and concatenate them

In [None]:
# Concatenate this simulation's score files via collect_score_file, skipping
# the step when output_path already contains a scores.json.
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import collect_score_file

simulation_name = "06_mpnn_paired_states"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

already_collected = os.path.exists(os.path.join(output_path, "scores.json"))
if not already_collected:
    collect_score_file(output_path, "scores")

### Load resulting concatenated scorefile

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import parse_scorefile_linear

output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")
csv_cache = os.path.join(output_path, "scores.csv")

# Only parse scores.json when no CSV cache exists yet; otherwise the next
# cell loads the cached CSV instead.
if not os.path.exists(csv_cache):
    scores_df = parse_scorefile_linear(os.path.join(output_path, "scores.json"))

### Dump scores_df as a CSV and then reload, for performance reasons

In [None]:
# Write the parsed scores to scores.csv on the first run, then always load the
# table back from that CSV.
scores_csv = os.path.join(output_path, "scores.csv")
if not os.path.exists(scores_csv):
    scores_df.to_csv(scores_csv)

# Select the index column by position (the first column written by to_csv)
# instead of by pandas' auto-generated "Unnamed: 0" header, which only exists
# when the written index has no name.
scores_df = pd.read_csv(scores_csv, index_col=0)

### Data exploration

In [None]:
# Number of scored designs, then the names of all score columns.
print(scores_df.shape[0])
print(scores_df.columns.tolist())

In [None]:
# Distinct values of the MPNN design-area column (displayed as the cell output).
set(scores_df["mpnn_msd_design_area"].values)

### Rebalance the df

In [None]:
# Rebalance: keep at most `goal_representation` rows per `state`. States with
# that many rows or fewer are kept whole; larger ones are down-sampled with a
# fixed seed (random_state=0) so the selection is reproducible.
#
# Iterating `groupby("state")` (groups sorted by state) rather than a `set` of
# state values makes the row order of rebalanced_df reproducible: for string
# values, set iteration order changes between interpreter sessions because of
# hash randomization. It also avoids a full-table `.query` scan per state.
# Rows with a missing state are skipped (groupby drops NaN keys by default).
rebalanced = []
goal_representation = 32

for state, subset_df in tqdm(scores_df.groupby("state")):
    if len(subset_df) <= goal_representation:
        # few enough rows already: take them all
        rebalanced.extend(subset_df.index)
    else:
        # otherwise draw goal_representation of them at random (seeded)
        sample = subset_df.sample(goal_representation, random_state=0)
        rebalanced.extend(sample.index)
rebalanced_df = scores_df.loc[rebalanced]
len(rebalanced_df)

### Save individual FASTA files

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import df_to_fastas

output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Save individual FASTA files for the rebalanced designs (sequence columns
# presumably selected by the "mpnn_seq" prefix — confirm in df_to_fastas);
# the returned DataFrame replaces rebalanced_df.
fasta_prefix = "mpnn_seq"
rebalanced_df = df_to_fastas(rebalanced_df, prefix=fasta_prefix)

### Save a list of outputs
Sorted by `looped_length`

In [None]:
simulation_name = "06_mpnn_paired_states"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Order the rebalanced designs by looped length; scores_df is rebound to this
# sorted subset, which the following cell also iterates.
scores_df = rebalanced_df.sort_values("looped_length")

# One index entry (design path) per line.
list_file = os.path.join(output_path, "mpnn_paired_states.list")
with open(list_file, "w") as f:
    f.writelines(f"{path}\n" for path in tqdm(scores_df.index))

### Pair each .pdb.bz2 path with its FASTA path (one pair per line), as input for Superfold

In [None]:
simulation_name = "06_mpnn_paired_states"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Each line is "<pdb path>____<fasta path>", where the FASTA path is derived
# from the PDB path by swapping "decoys" -> "fastas" and "pdb.bz2" -> "fa".
pair_file = os.path.join(output_path, "mpnn_paired_states.pair")
with open(pair_file, "w") as f:
    for pdb_path in tqdm(scores_df.index):
        fasta_path = pdb_path.replace("decoys", "fastas").replace("pdb.bz2", "fa")
        f.write(f"{pdb_path}____{fasta_path}\n")

### Run MPNN on the best paired states

In [None]:
# Generate the array tasks that run crispy_shifty.protocols.mpnn.mpnn_paired_state
# over the best designed paired states listed in design_list_file.
# Make the repo importable explicitly, as most other cells that import
# crispy_shifty do, rather than relying on the kernel's working directory
# being on sys.path.
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import gen_array_tasks

simulation_name = "06_mpnn_paired_states_best"
# list of input structures (presumably one path per line, produced by step 05)
design_list_file = os.path.join(
    os.getcwd(),
    "projects/crispy_shifties/05_design_paired_states_best/best_designed_paired_states.list",
)
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Space-separated option string (presumably Rosetta/PyRosetta flags);
# append to the list to pass more options.
options = " ".join(
    [
        "out:level 200",
    ]
)

# Extra keyword arguments forwarded to the protocol; values are given as
# strings (presumably converted inside mpnn_paired_state — confirm).
extra_kwargs = {
    "num_sequences": "100",
    "mpnn_temperature": "0.2",
    "mpnn_design_area": "scan",
}

gen_array_tasks(
    distribute_func="crispy_shifty.protocols.mpnn.mpnn_paired_state",
    design_list_file=design_list_file,
    output_path=output_path,
    queue="short",
    memory="5G",
    nstruct=1,
    nstruct_per_task=1,
    options=options,
    extra_kwargs=extra_kwargs,
    simulation_name=simulation_name,
)

### Collect the scorefiles of the MPNN-designed best paired states and concatenate them

In [None]:
# Concatenate this simulation's score files via collect_score_file, skipping
# the step when output_path already contains a scores.json.
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import collect_score_file

simulation_name = "06_mpnn_paired_states_best"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

already_collected = os.path.exists(os.path.join(output_path, "scores.json"))
if not already_collected:
    collect_score_file(output_path, "scores")

### Load resulting concatenated scorefile

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import parse_scorefile_linear

output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")
csv_cache = os.path.join(output_path, "scores.csv")

# Only parse scores.json when no CSV cache exists yet; otherwise the next
# cell loads the cached CSV instead.
if not os.path.exists(csv_cache):
    scores_df = parse_scorefile_linear(os.path.join(output_path, "scores.json"))

### Dump scores_df as a CSV and then reload, for performance reasons

In [None]:
# Write the parsed scores to scores.csv on the first run, then always load the
# table back from that CSV.
scores_csv = os.path.join(output_path, "scores.csv")
if not os.path.exists(scores_csv):
    scores_df.to_csv(scores_csv)

# Select the index column by position (the first column written by to_csv)
# instead of by pandas' auto-generated "Unnamed: 0" header, which only exists
# when the written index has no name.
scores_df = pd.read_csv(scores_csv, index_col=0)

### Rebalance the df

In [None]:
# Rebalance: keep at most `goal_representation` rows per `state`. States with
# that many rows or fewer are kept whole; larger ones are down-sampled with a
# fixed seed (random_state=0) so the selection is reproducible.
#
# Iterating `groupby("state")` (groups sorted by state) rather than a `set` of
# state values makes the row order of rebalanced_df reproducible: for string
# values, set iteration order changes between interpreter sessions because of
# hash randomization. It also avoids a full-table `.query` scan per state.
# Rows with a missing state are skipped (groupby drops NaN keys by default).
rebalanced = []
goal_representation = 32

for state, subset_df in tqdm(scores_df.groupby("state")):
    if len(subset_df) <= goal_representation:
        # few enough rows already: take them all
        rebalanced.extend(subset_df.index)
    else:
        # otherwise draw goal_representation of them at random (seeded)
        sample = subset_df.sample(goal_representation, random_state=0)
        rebalanced.extend(sample.index)
rebalanced_df = scores_df.loc[rebalanced]
len(rebalanced_df)

### Save individual FASTA files

In [None]:
sys.path.insert(0, "/projects/crispy_shifty")
from crispy_shifty.utils.io import df_to_fastas

output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Save individual FASTA files for the rebalanced designs (sequence columns
# presumably selected by the "mpnn_seq" prefix — confirm in df_to_fastas);
# the returned DataFrame replaces rebalanced_df.
fasta_prefix = "mpnn_seq"
rebalanced_df = df_to_fastas(rebalanced_df, prefix=fasta_prefix)

### Save a list of outputs
Sorted by `looped_length`

In [None]:
simulation_name = "06_mpnn_paired_states_best"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Order the rebalanced designs by looped length; scores_df is rebound to this
# sorted subset, which the following cell also iterates.
scores_df = rebalanced_df.sort_values("looped_length")

# One index entry (design path) per line.
list_file = os.path.join(output_path, "best_mpnn_paired_states.list")
with open(list_file, "w") as f:
    f.writelines(f"{path}\n" for path in tqdm(scores_df.index))

### Pair each .pdb.bz2 path with its FASTA path (one pair per line), as input for Superfold

In [None]:
simulation_name = "06_mpnn_paired_states_best"
output_path = os.path.join(os.getcwd(), f"projects/crispy_shifties/{simulation_name}")

# Each line is "<pdb path>____<fasta path>", where the FASTA path is derived
# from the PDB path by swapping "decoys" -> "fastas" and "pdb.bz2" -> "fa".
pair_file = os.path.join(output_path, "best_mpnn_paired_states.pair")
with open(pair_file, "w") as f:
    for pdb_path in tqdm(scores_df.index):
        fasta_path = pdb_path.replace("decoys", "fastas").replace("pdb.bz2", "fa")
        f.write(f"{pdb_path}____{fasta_path}\n")

### Prototyping blocks

Test `mpnn_paired_state` on a single design

In [None]:
%%time 
# Prototype run: call mpnn_paired_state on one designed paired state and dump
# the .pose of every object it yields to "<index>.pdb" in the working directory.
import pyrosetta

pyrosetta.init()

sys.path.insert(0, "/projects/crispy_shifty/")
from crispy_shifty.protocols.mpnn import mpnn_paired_state

prototype_kwargs = dict(
    pdb_path="/mnt/home/pleung/projects/crispy_shifty/projects/crispy_shifties/05_design_paired_states/decoys/0000/05_design_paired_states_c6be6ebc8a3146e2960cb45360a8a202.pdb.bz2",
    num_sequences=100,
)
# No input pose is passed (None); the protocol presumably loads the structure
# from pdb_path instead — confirm.
t = mpnn_paired_state(None, **prototype_kwargs)
for i, tppose in enumerate(t):
    tppose.pose.dump_pdb(f"{i}.pdb")

In [None]:
# Inspect the scores of `tppose`, the last object yielded by the prototyping
# loop in the previous cell (NameError if that loop never ran).
d = dict(tppose.pose.scores)
# bare expression: display the dict as the cell output
d