## Summary

----

## Imports

In [1]:
import concurrent.futures
import itertools
import os
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
import proteinsolver
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from kmbio import PDB
from scipy import stats

In [2]:
DEBUG = "CI" not in os.environ

In [3]:
if DEBUG:
    %load_ext autoreload
    %autoreload 2

In [4]:
%matplotlib inline

try:
    inline_rc
except NameError:
    inline_rc = mpl.rcParams.copy()
    
mpl.rcParams.update({"font.size": 12})

## Parameters

In [5]:
UNIQUE_ID = "191f05de"  # No attention
# UNIQUE_ID = "0007604c"  # 5-layer graph-conv with attention, batch_size=1
# UNIQUE_ID = "91fc9ab9"  # 4-layer graph-conv with attention, batch_size=4

In [6]:
BEST_STATE_FILES = {
    #
    "191f05de": "protein_train/191f05de/e53-s1952148-d93703104.state"
}

In [7]:
NOTEBOOK_NAME = "06_global_analysis_of_protein_folding"
NOTEBOOK_PATH = Path(NOTEBOOK_NAME)
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

PosixPath('06_global_analysis_of_protein_folding')

In [8]:
INPUT_PATH = Path(os.getenv("DATAPKG_INPUT_DIR"))
INPUT_PATH

PosixPath('/home/kimlab1/database_data/datapkg_input_dir')

In [9]:
DATAPKG_DATA_DIR = Path(f"~/datapkg_data_dir").expanduser().resolve()
DATAPKG_DATA_DIR

PosixPath('/home/kimlab1/database_data/datapkg_output_dir')

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [11]:
proteinsolver.settings.data_url = DATAPKG_DATA_DIR.as_posix()
proteinsolver.settings.data_url

'/home/kimlab1/database_data/datapkg_output_dir'

## Load data

In [12]:
!ls {INPUT_PATH}/global_analysis_of_protein_folding

aan0693_designed-PDB-files	   hhpred
aan0693_designed-PDB-files.gz.zip  hhpred2
aan0693_SI_datasets		   mutation_structures_for_rosetta
aan0693_SI_datasets.tar.gz.zip	   swissmodel


In [13]:
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_designed-PDB-files

aan0693_designed-PDB-files.gz  nmr  other  other2  rd1	rd2  rd3  rd4


In [14]:
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_SI_datasets

counts_and_ec50s.tar.gz		  protein_and_dna_sequences.tar.gz
design_scripts.tgz		  stability_scores
design_structural_metrics	  stability_scores.tar.gz
design_structural_metrics.tar.gz  unfolded_state_model_params
fig1_thermodynamic_data.csv


### aan0693_SI_datasets

In [15]:
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_SI_datasets/stability_scores

rd1_stability_scores  rd3_stability_scores  ssm2_stability_scores
rd2_stability_scores  rd4_stability_scores


In [16]:
def remove_controls(df):
    """Return the rows of ``df`` whose ``name`` does not mark a control design.

    A row is dropped when its ``name`` ends with ``_hp``, ``_random`` or
    ``_buryD``. The input frame is not modified; the original index labels of
    the surviving rows are kept.
    """
    names = df["name"]
    keep = ~names.str.endswith("_hp")
    for suffix in ("_random", "_buryD"):
        keep = keep & ~names.str.endswith(suffix)
    return df[keep]

In [17]:
def load_stability_scores(key):
    """Load the stability-score table for one library and attach Rosetta energies.

    Args:
        key: Library identifier used to build the file names
            (``{key}_stability_scores`` and ``{key}_relax_scored_*.sc``).

    Returns:
        A DataFrame with control rows removed (see ``remove_controls``), one
        ``<energy_function>_score`` column per Rosetta score file that exists
        on disk, and a ``library_name`` column set to ``key``.
    """
    si_dir = INPUT_PATH / "global_analysis_of_protein_folding" / "aan0693_SI_datasets"

    scores = remove_controls(
        pd.read_csv(si_dir / "stability_scores" / f"{key}_stability_scores", sep="\t")
    )

    for energy_function in ("talaris2013", "betanov15"):
        score_column = f"{energy_function}_score"
        energies_path = (
            si_dir
            / "design_structural_metrics"
            / f"{key}_relax_scored_{'filtered_' if energy_function == 'betanov15' else ''}{energy_function}.sc"
        )

        if energies_path.is_file():
            # The betanov15 file is tab-separated; the talaris2013 file is split
            # on runs of spaces (regex separator, hence the python engine).
            separator = "\t" if energy_function == "betanov15" else " +"
            energies = pd.read_csv(energies_path, sep=separator, engine="python")
            energies = energies.rename(columns={"description": "name", "total_score": score_column})
            # The merge is an outer join, so names that appear only in the score
            # file add rows; the row count is therefore not guaranteed to stay
            # the same (a length assertion was disabled here for that reason).
            scores = scores.merge(energies[["name", score_column]], on="name", how="outer")
        else:
            print(f"Not loading Rosetta energies for {energy_function}!")

    scores["library_name"] = key
    return scores

### stability_scores

In [18]:
# stability_scores = {}

In [19]:
# for key in ["rd1", "rd2", "rd3", "rd4", "ssm2"]:
#     stability_scores[key] = load_stability_scores(key)

In [20]:
# stability_scores["fig1"] = pd.read_csv(
#     INPUT_PATH / "global_analysis_of_protein_folding" / "aan0693_SI_datasets" / "fig1_thermodynamic_data.csv"
# ).assign(library_name="fig1")

In [21]:
stability_scores = torch.load(NOTEBOOK_PATH.joinpath("stability_scores.torch"))

## Load model

In [22]:
%run protein_train/{UNIQUE_ID}/model.py

In [23]:
batch_size = 1
num_features = 20
adj_input_size = 2
hidden_size = 128
frac_present = 0.5
frac_present_valid = frac_present
info_size= 1024

In [24]:
state_file = BEST_STATE_FILES[UNIQUE_ID]
state_file

'protein_train/191f05de/e53-s1952148-d93703104.state'

In [25]:
net = Net(
    x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
)
net.load_state_dict(torch.load(state_file, map_location=device))
net.eval()
net = net.to(device)

## Mutation probabilities

### Test network

In [26]:
structure_file = INPUT_PATH.joinpath(
    "global_analysis_of_protein_folding", "aan0693_designed-PDB-files", "rd1", "EEHEE_rd1_0004.pdb"
)

In [27]:
structure = PDB.load(structure_file)

In [28]:
pdata = proteinsolver.utils.extract_seq_and_adj(structure, 'A')

In [29]:
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
data.to(device)

proteinsolver.utils.scan_with_mask(
    net, data.x, data.edge_index, data.edge_attr, 20
)

tensor([ 4.2337e+00,  1.4978e+00,  7.5491e+00,  9.7959e-02, -1.7546e+00,
         2.8181e+00,  5.4245e+00,  1.3950e+00,  2.0784e+00,  3.1730e+00,
         1.3765e+00,  6.6040e+00, -9.9698e-04,  4.2144e+00,  1.3128e+00,
         6.9296e-02,  5.1866e+00,  6.3043e+00,  1.9462e+00,  1.1357e+00,
         4.4304e+00,  6.7671e+00,  2.5967e+00,  1.5917e+00,  7.2876e-01,
         5.6794e+00,  4.6740e+00,  2.3055e-01,  2.4577e-01,  7.3126e+00,
         2.6852e+00,  5.3051e-01,  1.5237e+00,  4.0610e+00,  4.3447e-02,
         3.3018e+00,  4.3885e+00,  3.5008e+00,  4.2183e+00,  2.1397e+00,
        -2.2222e-01])

### mutation_proba_ssm2

In [30]:
def get_mutation_proba_ssm2(net, filename, sequence):
    """Score a (mutated) sequence against the structure of its wild-type design.

    Args:
        net: Trained network passed through to ``proteinsolver.utils.scan_with_mask``.
        filename: Name of the wild-type PDB file; ``.pdb`` is appended when missing.
        sequence: Amino-acid sequence to score on that structure.

    Returns:
        The sum of the per-position values returned by ``scan_with_mask`` as a
        Python float, or ``np.nan`` when the structure is on the skip list, the
        PDB file does not exist, or fewer than 80% of the compared positions
        match the sequence stored in the PDB.
    """
    if not filename.endswith(".pdb"):
        filename = filename + ".pdb"

    # Checked before any file I/O so the structure is not loaded and parsed
    # only to be thrown away.
    if filename in ["EHEE_rd2_0005.pdb", "EHEE_rd3_0015.pdb"]:
        # Skip because the number of amino acids in the PDB does not match the number of amino acids in the mutated sequence
        return np.nan

    # The second "_"-separated token selects the sub-directory; anything that
    # is not a known design round lives under "other".
    name_parts = filename.split("_")
    if len(name_parts) > 1 and name_parts[1] in ("rd1", "rd2", "rd3", "rd4"):
        library_name = name_parts[1]
    else:
        library_name = "other"

    filepath = (
        INPUT_PATH / "global_analysis_of_protein_folding" / "aan0693_designed-PDB-files" / library_name / filename
    )
    if not filepath.is_file():
        return np.nan

    structure = PDB.load(filepath)
    # Structures under "other" are read from chain "X", designed rounds from chain "A".
    chain = "X" if library_name == "other" else "A"
    pdata = proteinsolver.utils.extract_seq_and_adj(structure, chain)

    # zip() stops at the shorter sequence, so only the overlapping prefix is compared.
    num_identical = sum(a == b for a, b in zip(sequence, pdata.sequence))
    if num_identical < (len(sequence) * 0.8):
        print(f"Warning, sequence does not match for protien {filename}")
        return np.nan

    # Keep the adjacency extracted from the structure but score the requested sequence.
    pdata = pdata._replace(sequence=sequence)

    data = proteinsolver.datasets.protein.row_to_data(pdata)
    data = proteinsolver.datasets.protein.transform_edge_attr(data)
    data = data.to(device)

    log_prob = proteinsolver.utils.scan_with_mask(
        net, data.x, data.edge_index, data.edge_attr, 20
    )

    return log_prob.sum().item()

In [31]:
stability_scores["ssm2"]["network_score"] = [
    get_mutation_proba_ssm2(net, filename, sequence)
    for filename, sequence in stability_scores["ssm2"][["my_wt", "sequence"]].values
]

### mutation_proba_fig1

In [32]:
def get_mutation_proba_fig1(net, filename, sequence):
    """Score a sequence from the ``fig1`` table against its wild-type structure.

    Args:
        net: Trained network passed through to ``proteinsolver.utils.scan_with_mask``.
        filename: Name of the wild-type PDB file; ``.pdb`` is appended when missing.
        sequence: Amino-acid sequence to score on that structure.

    Returns:
        The mean of the per-position values returned by ``scan_with_mask`` as a
        Python float, or ``np.nan`` when the PDB file does not exist, the
        sequence lengths differ (villin/BBL excepted), or fewer than 60% of the
        compared positions match the sequence stored in the PDB.
    """
    filename = filename if filename.endswith(".pdb") else filename + ".pdb"

    # The second "_"-separated token selects the sub-directory; anything that
    # is not a known design round lives under "other2".
    name_parts = filename.split("_")
    if len(name_parts) > 1 and name_parts[1] in ("rd1", "rd2", "rd3", "rd4"):
        library_name = name_parts[1]
    else:
        library_name = "other2"

    filepath = INPUT_PATH.joinpath(
        "global_analysis_of_protein_folding", "aan0693_designed-PDB-files", library_name, filename
    )
    if not filepath.is_file():
        return np.nan

    structure = PDB.load(filepath)
    # Structures under "other2" are read from chain "X", designed rounds from chain "A".
    pdata = proteinsolver.utils.extract_seq_and_adj(structure, "A" if library_name != "other2" else "X")

    # NOTE(review): the notebook cell that builds `my_wt` for fig1 produces names
    # such as "<prefix>_wt.pdb" and "BBL_H142W.pdb", so these exact-name matches
    # may never fire - confirm against the actual file names.
    tail = {"villin.pdb": "SSSGSSGS", "BBL.pdb": "N"}.get(filename)
    if tail is not None:
        pdata = pdata._replace(sequence=pdata.sequence + tail)

    # Length mismatches are tolerated only for villin and BBL
    # (checked and those two are OK to truncate).
    if len(pdata.sequence) != len(sequence) and not filename.startswith(("villin", "BBL")):
        print(
            f"Warning, sequence lengths do not match for protien {filename} ({len(pdata.sequence)}, {len(sequence)})"
        )
        return np.nan

    # zip() stops at the shorter sequence, so only the overlapping prefix is compared.
    num_identical = sum(a == b for a, b in zip(sequence, pdata.sequence))
    if num_identical < (len(sequence) * 0.6):
        print(f"Warning, sequences do not match for protien {filename}")
        return np.nan

    # Keep the adjacency extracted from the structure but score the requested sequence.
    pdata = pdata._replace(sequence=sequence)

    data = proteinsolver.datasets.protein.row_to_data(pdata)
    data = proteinsolver.datasets.protein.transform_edge_attr(data)
    data = data.to(device)

    node_outputs = proteinsolver.utils.scan_with_mask(
        net, data.x, data.edge_index, data.edge_attr, 20
    )
    return node_outputs.mean().item()

In [33]:
stability_scores["fig1"]["my_wt"] = stability_scores["fig1"]["name"].apply(
    lambda name: {"BBL": "BBL_H142W.pdb"}.get(name.split("_")[0], name.split("_")[0] + "_wt.pdb")
)

In [34]:
stability_scores["fig1"]["network_score"] = [
    get_mutation_proba_fig1(net, filename, sequence)
    for filename, sequence in stability_scores["fig1"][["my_wt", "sequence"]].values
]

In [35]:
torch.save(stability_scores, NOTEBOOK_PATH.joinpath("stability_scores_for_mutations.raw.scan.torch"))

### Add deltas

In [36]:
def add_diff(df, colname):
    """Add ``<colname>_wt`` and ``<colname>_change`` columns to ``df``.

    The wild-type row of each group is the row whose ``name`` (with ``.pdb``
    appended when the first row's name lacks it) equals its own ``my_wt``.
    Its ``colname`` value is broadcast to every row sharing that ``my_wt`` as
    ``<colname>_wt``, and ``<colname>_change`` is ``colname - <colname>_wt``.
    If ``<colname>_wt`` is already present it is reused as is.

    Args:
        df: Frame with at least ``name``, ``my_wt`` and ``colname`` columns.
        colname: Name of the column to compute the change for.

    Returns:
        A new DataFrame; the input frame is never modified.

    Raises:
        AssertionError: If some ``my_wt`` has no wild-type row, or if attaching
            the wild-type values changes the number of rows (e.g. duplicated
            wild-type rows).
    """
    wt_col = f"{colname}_wt"
    if wt_col not in df:
        num_rows = len(df)
        # The first row decides whether names carry the ".pdb" suffix; an empty
        # frame has no first row, so nothing needs to be appended.
        ext = "" if df.empty or df.iloc[0]["name"].endswith(".pdb") else ".pdb"
        row_is_wt = (df["name"] + ext) == df["my_wt"]
        assert set(df[row_is_wt]["my_wt"]) == set(df["my_wt"]), "some my_wt values have no wild-type row"
        wt_values = df.loc[row_is_wt, ["my_wt", colname]].rename(columns={colname: wt_col})
        df = df.merge(wt_values, on="my_wt")
        assert num_rows == len(df), (num_rows, len(df))
    # assign() returns a new frame, so the caller's frame is left untouched even
    # when the "_wt" column already existed and no merge took place.
    return df.assign(**{f"{colname}_change": df[colname] - df[wt_col]})

In [37]:
stability_scores["ssm2"].columns

Index(['name', 'sequence', 'my_wt', 'pos', 'mut', 'wt_aa', 'ec50_t',
       'delta_ec50_t', 'ec50_95ci_lbound_t', 'ec50_95ci_ubound_t',
       'ec50_95ci_t', 'ec50_pred_t', 'delta_pred_vs_wt_t', 'ec50_rise_t',
       'stabilityscore_t', 'ec50_c', 'delta_ec50_c', 'ec50_95ci_lbound_c',
       'ec50_95ci_ubound_c', 'ec50_95ci_c', 'ec50_pred_c',
       'delta_pred_vs_wt_c', 'ec50_rise_c', 'stabilityscore_c',
       'ec50_rise_c_adj', 'stabilityscore_c_adj', 'consensus_ec50_rise',
       'consensus_stability_score', 'library_name', 'network_score'],
      dtype='object')

In [38]:
stability_scores["ssm2"]["consensus_stability_score2"] = (
    stability_scores["ssm2"]["stabilityscore_t"] + stability_scores["ssm2"]["stabilityscore_c"]
) / 2

for column in [
    "network_score",
    "stabilityscore_t",
    "stabilityscore_c",
    "consensus_stability_score",
    "consensus_stability_score2",
]:
    stability_scores["ssm2"] = add_diff(stability_scores["ssm2"], column)

In [39]:
stability_scores_wpreds = stability_scores

In [40]:
# torch.save(stability_scores, NOTEBOOK_PATH.joinpath("stability_scores_form_mutations.logprob.incremental.torch"))

## For Carles

In [41]:
mutations_for_rosetta = stability_scores_wpreds["ssm2"][
    stability_scores_wpreds["ssm2"]["network_score_change"].notnull()
][["my_wt", "pos", "wt_aa", "mut", "sequence"]]

In [42]:
len(mutations_for_rosetta["my_wt"].unique())

13

In [43]:
# Build the destination from INPUT_PATH instead of a machine-specific absolute
# path, so the export follows the DATAPKG_INPUT_DIR configuration.
mutations_for_rosetta.to_csv(
    INPUT_PATH
    / "global_analysis_of_protein_folding"
    / "mutation_structures_for_rosetta"
    / "gapf_mutations_for_rosetta.csv",
    index=False,
)

## Figure

In [44]:
stability_scores["fig1"] = stability_scores["fig1"].rename(
    columns={"deltaGunf thermal": "dg_thermal", "deltaGunf chemical": "dg_chemical"}
)

stability_scores["fig1"]["consensus_stability_score"] = (
    stability_scores["fig1"]["stabilityscore_t"] + stability_scores["fig1"]["stabilityscore_c"]
) / 2

stability_scores["fig1"]["consensus_dg"] = (
    stability_scores["fig1"]["dg_thermal"] + stability_scores["fig1"]["dg_chemical"]
) / 2

for column in [
    "network_score",
    "stabilityscore_t",
    "stabilityscore_c",
    "consensus_stability_score",
    "dg_thermal",
    "dg_chemical",
    "consensus_dg",
    "Tm",
]:
    stability_scores["fig1"] = add_diff(stability_scores["fig1"], column)

### Generate correlations

In [45]:
def get_conf_interval(r, num, z_crit=1.96):
    """Confidence-interval half-widths for a correlation via the Fisher z-transform.

    The interval is computed as ``tanh(atanh(r) -/+ z_crit / sqrt(num - 3))``.

    Args:
        r: Correlation coefficient.
        num: Number of observations the coefficient was computed from.
        z_crit: Critical value of the standard normal distribution; the default
            1.96 is the value the function used before it became a parameter.

    Returns:
        ``(r - lower, upper - r)``: the distances from ``r`` down to the lower
        bound and up to the upper bound (not the bounds themselves).
        ``(nan, nan)`` when ``num <= 3`` (the standard error is undefined), and
        ``(0.0, 0.0)`` when ``|r| == 1`` (the interval collapses onto ``r``).
    """
    import math

    if num <= 3:
        # 1 / sqrt(num - 3) would divide by zero or take the root of a negative number.
        return math.nan, math.nan
    if abs(r) == 1:
        # atanh(+/-1) is infinite; in the limit both bounds equal r.
        return 0.0, 0.0

    stderr = 1.0 / math.sqrt(num - 3)
    delta = z_crit * stderr
    z = math.atanh(r)
    lower = math.tanh(z - delta)
    upper = math.tanh(z + delta)
    return r - lower, upper - r

In [46]:
x_col = "network_score_change"
y_col = "consensus_stability_score_change"
library = "ssm2"

# The Fisher-transform confidence interval needs more than 3 observations
# (its standard error is 1 / sqrt(n - 3)), so smaller groups are skipped
# together with the empty ones.
MIN_POINTS_FOR_CI = 4
CORRELATION_COLUMNS = ["my_wt", "corr", "pvalue", "corr_lower_bound", "corr_upper_bound"]

correlations = []

for my_wt, gp in stability_scores[library].groupby("my_wt"):
    x = gp[[x_col, y_col]].dropna()
    if len(x) < MIN_POINTS_FOR_CI:
        print(f"Skipping {my_wt}!")
        continue
    corr = stats.spearmanr(x[x_col], x[y_col])
    # corr_conf holds distances from the coefficient to the interval bounds,
    # not the bounds themselves.
    corr_conf = get_conf_interval(corr[0], len(x))
    correlations.append(
        {"my_wt": my_wt, "corr": corr[0], "pvalue": corr[1], "corr_lower_bound": corr_conf[0], "corr_upper_bound": corr_conf[1]}
    )
# Explicit column names keep this working (as an empty frame) when every group was skipped.
correlations_df = pd.DataFrame(correlations, columns=CORRELATION_COLUMNS)
#     plt.title(f"Spearman R: {corr[0]:.4f} ({corr[1]:.3f})")
#     plt.plot(x[x_col], x[y_col], "r.")
#     plt.suptitle(my_wt)
#     plt.show()

Skipping EHEE_0882.pdb!
Skipping EHEE_rd2_0005.pdb!
Skipping EHEE_rd3_0015.pdb!
Skipping HHH_0142.pdb!


In [47]:
correlations_df

Unnamed: 0,my_wt,corr,pvalue,corr_lower_bound,corr_upper_bound
0,EEHEE_rd3_0037.pdb,0.394147,3.3060639999999996e-30,0.061183,0.057878
1,EEHEE_rd3_1498.pdb,0.393333,4.440973e-30,0.061226,0.057925
2,EEHEE_rd3_1702.pdb,0.149497,2.932858e-05,0.069584,0.068134
3,EEHEE_rd3_1716.pdb,0.434199,5.645419e-37,0.058951,0.055452
4,HEEH_rd2_0779.pdb,0.480017,6.493175e-46,0.056094,0.052426
5,HEEH_rd3_0223.pdb,0.393417,4.307661e-30,0.061221,0.05792
6,HEEH_rd3_0726.pdb,0.256082,4.534633e-13,0.067015,0.064641
7,HEEH_rd3_0872.pdb,0.237012,2.34207e-11,0.067597,0.065378
8,HHH_rd2_0134.pdb,0.295338,4.5963010000000005e-17,0.065648,0.062973
9,HHH_rd3_0138.pdb,0.429652,3.688339e-36,0.059216,0.055738


In [48]:
# NOTE(review): this cell fails on a fresh kernel - `category_names` is not
# defined anywhere in the notebook and `cmap` is only assigned in a later cell
# (the recorded output is a NameError). `category_colors` is referenced only by
# a commented-out `add_yticks` call, so this looks like a leftover - confirm
# and remove.
category_colors = {
    **{c: cmap(0) for c in category_names if c.startswith("mae") or c.startswith("rmsd")},
    **{c: cmap(1) for c in category_names if not c.startswith("mae") and not c.startswith("rmsd")},
}

NameError: name 'category_names' is not defined

In [None]:
def make_spider_plot(title, row, color, ax, ylim=(0, 0.7)):
    """Draw a radar ("spider") plot of correlations with error bars on a polar axis.

    Args:
        title: Title of the plot; nothing is set when it is empty.
        row: DataFrame indexed by category name (one spoke per index entry) with
            the columns ``corr``, ``corr_lower_bound`` and ``corr_upper_bound``.
            The two bound columns are used as downward / upward error-bar
            lengths respectively.
        color: Matplotlib colour used for the line, the error bars and the fill.
        ax: Polar matplotlib axis to draw on.
        ylim: Radial axis limits; defaults to the previously hard-coded ``(0, 0.7)``.
    """
    tick_color = "k"
    yticks = np.array([0.2, 0.4, 0.6, 0.8, 1.0])

    categories = row.index.values.tolist()

    # One angle per category plus a final one (2*pi) that closes the polygon.
    angles = np.linspace(0, 2 * np.pi, len(categories) + 1)

    # Put the first category on top and go clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)

    # Break each label before its last "_"-separated token to keep it narrow.
    labels = []
    for category in categories:
        parts = category.rsplit("_", 1)
        labels.append(category if len(parts) == 1 else f"{parts[0]}\n_{parts[1]}")

    # Configure X axis
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, size=10)
    for xtick in ax.get_xticklabels():
        xtick.set_color(tick_color)

    # Configure Y axis
    ax.set_rlabel_position(0)
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{v:.1f}" for v in yticks], color=tick_color)
    ax.set_ylim(*ylim)

    # Configure grid
    ax.grid(color="k", linestyle=":", linewidth=0.5)

    # Read the data from the `row` argument (not from a notebook-level global)
    # and repeat the first point at the end so the outline is closed.
    values = row["corr"].values
    closed_values = np.r_[values, values[:1]]
    # Matplotlib expects yerr of shape (2, N) with the lower errors in the first
    # row and the upper errors in the second.
    errors = row[["corr_lower_bound", "corr_upper_bound"]].values.T
    closed_errors = np.c_[errors, errors[:, :1]]

    # Plot values
    ax.plot(angles, closed_values, color=color, linewidth=2, linestyle="solid")
    ax.errorbar(angles, closed_values, color=color, yerr=closed_errors)
    ax.fill(angles, closed_values, color=color, alpha=0.4)

    # Add a title
    if title:
        ax.set_title(title, size=12, pad=20)


#     add_yticks(ax, angles[len(angles) // 2], yticks, yticks, color=category_colors[categories[3]])

In [None]:
# Strip the literal ".pdb" text from the names; regex=False keeps "." from
# being treated as a regular-expression wildcard.
df = correlations_df.assign(
    my_wt=correlations_df["my_wt"].str.replace(".pdb", "", regex=False)
).set_index("my_wt")
# Rotate the rows so that the last entry becomes the first spoke of the plot.
df = pd.concat([df.iloc[-1:], df.iloc[:-1]])

cmap = plt.cm.get_cmap("Set1", 8)

fg, ax = plt.subplots(figsize=(4.5, 4.5), subplot_kw=dict(polar=True))
make_spider_plot("", df, cmap(0), ax=ax)
fg.tight_layout()
fg.savefig(NOTEBOOK_PATH.joinpath("correlations_with_mutations.svg"))
fg.savefig(NOTEBOOK_PATH.joinpath("correlations_with_mutations.pdf"))
# The third save previously rewrote the SVG with dpi=300; a raster PNG is the
# format for which the dpi setting matters.
fg.savefig(NOTEBOOK_PATH.joinpath("correlations_with_mutations.png"), dpi=300)

In [None]:
correlations_df

## Playground

In [None]:
# NOTE(review): the original line was not valid Python (unbalanced parenthesis
# and a truncated trailing attribute). Reconstructed as "correlations indexed by
# structure name, sorted by name in descending order" - confirm the intent.
(
    correlations_df.sort_values("my_wt", ascending=False)
    .assign(my_wt=lambda frame: frame["my_wt"].str.replace(".pdb", "", regex=False))
    .set_index("my_wt")["corr"]
    .index
)

In [None]:
x_col = "consensus_stability_score_change"
y_col = "consensus_dg_change"

x = stability_scores["fig1"][[x_col, y_col]].dropna()

corr = stats.spearmanr(x[x_col], x[y_col])

plt.title(f"Spearman R: {corr[0]:.4f} ({corr[1]:.3f})")
plt.plot(x[x_col], x[y_col], "r.")

In [None]:
if "consensus_stability_score_wt" not in stability_scores["ssm2"]:
    before_ = len(stability_scores["ssm2"])
    row_is_wt = stability_scores["ssm2"]["name"] == stability_scores["ssm2"]["my_wt"]
    df = stability_scores["ssm2"][row_is_wt][["my_wt", "consensus_stability_score"]].rename(
        columns={"consensus_stability_score": "consensus_stability_score_wt"}
    )
    stability_scores["ssm2"] = stability_scores["ssm2"].merge(df, on="my_wt")
    assert before_ == len(stability_scores["ssm2"])

In [None]:
stability_scores["ssm2"]["consensus_stability_score_change"] = (
    stability_scores["ssm2"]["consensus_stability_score"] - stability_scores["ssm2"]["consensus_stability_score_wt"]
)

In [None]:
x = stability_scores["ssm2"][["sum_log_prob_change", "consensus_stability_score_change"]].dropna()

corr = stats.spearmanr(
    x["sum_log_prob_change"],
    x["consensus_stability_score_change"],
)

fg, ax = plt.subplots()

ax.plot(
    x["sum_log_prob_change"],
    x["consensus_stability_score_change"],
    'r.',
    alpha=0.5
)
ax.set_title(f"Spearman R: {corr[0]:.4f} ({corr[1]:.3f})")
fg.tight_layout()

In [None]:
def myplot(x, y, color="r", ax=None):
    """Plot ``y`` against ``x`` and put the Spearman correlation in the title.

    Pairs in which either value is missing are dropped before plotting and
    before the correlation is computed.

    Args:
        x: Sequence of x values.
        y: Sequence of y values, same length as ``x``.
        color: Matplotlib colour for the plot. Defaults to ``"r"`` so that the
            notebook's ``myplot(x, y, ax=...)`` calls, which omit it, work.
        ax: Axis to draw on; the current pyplot axis is used when omitted.
    """
    if ax is None:
        ax = plt.gca()

    x, y = pd.DataFrame({"x": x, "y": y}).dropna().values.T

    corr = stats.spearmanr(x, y)
    ax.plot(x, y, color=color)
    ax.set_title(f"N: {len(x):,};   Spearman R: {corr[0]:.3f} ({corr[1]:.2f})")

In [None]:
for library_name in list(stability_scores)[:3]:
    df = stability_scores[library_name][
        (stability_scores[library_name]["sum_log_prob"].isnull())
        & ((stability_scores[library_name]["talaris2013_score"].notnull()))
    ]
    print(library_name, len(df))

In [None]:
for library_name in list(stability_scores)[:3]:
    df = stability_scores[library_name][
        (stability_scores[library_name]["sum_log_prob"].notnull())
        & ((stability_scores[library_name]["talaris2013_score"].isnull()))
    ]
    print(library_name, len(df))

In [None]:
for library_name in list(stability_scores)[:3]:
    df = stability_scores[library_name][
        ((stability_scores[library_name]["talaris2013_score"].isnull()))
    ]
    print(library_name, len(df))

In [None]:
stability_scores["rd1"].head()

In [None]:
results = []

for library in ["rd1", "rd2", "rd3", "rd4"]:
    df = stability_scores[library].dropna(subset=["sum_log_prob", "talaris2013_score", "betanov15_score"])
    df["domain"] = df["name"].str.split("_").str[0]
    for domain, gp in df.groupby("domain"):
        for score in ["sum_log_prob", "talaris2013_score", "betanov15_score"]:
            corr_t = stats.spearmanr(gp[score], gp["stabilityscore_t"])
            corr_c = stats.spearmanr(gp[score], gp["stabilityscore_c"])
            results.append((library, domain, score, corr_t[0], corr_t[1], corr_c[0], corr_c[1]))

In [None]:
x_col = "sum_log_prob_normed"
df = stability_scores[library_name].dropna(subset=["sum_log_prob", "talaris2013_score", "betanov15_score"])

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

myplot(df[x_col], df["stabilityscore_t"], ax=axs[0])
myplot(df[x_col], df["stabilityscore_c"], ax=axs[1])

In [None]:
x_col = "talaris2013_score_normed"
df = stability_scores[library_name].dropna(subset=["sum_log_prob", "talaris2013_score", "betanov15_score"])

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

myplot(df[x_col], df["stabilityscore_t"], ax=axs[0])
myplot(df[x_col], df["stabilityscore_c"], ax=axs[1])

In [None]:
x_col = "betanov15_score_normed"
df = stability_scores[library_name].dropna(subset=["sum_log_prob", "talaris2013_score", "betanov15_score"])

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

myplot(df[x_col], df["stabilityscore_t"], ax=axs[0])
myplot(df[x_col], df["stabilityscore_c"], ax=axs[1])

In [None]:
x_col = "talaris2013_score_normed"
df = stability_scores["rd1"]

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

myplot(stability_scores["rd1"][x_col], stability_scores["rd1"]["stabilityscore_t"], ax=axs[0])
myplot(stability_scores["rd1"][x_col], stability_scores["rd1"]["stabilityscore_c"], ax=axs[1])

In [None]:
df = stability_scores["rd1"]

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].plot(-df["talaris2013_score"] / df["sequence_lenght"], df["stabilityscore_t"], "r.")
axs[1].plot(-df["talaris2013_score"] / df["sequence_lenght"], df["stabilityscore_c"], "r.")

In [None]:
df = stability_scores["rd1"]

fg, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].plot(-df["betanov15_score"] / df["sequence_lenght"], df["stabilityscore_t"], "r.")
axs[1].plot(-df["betanov15_score"] / df["sequence_lenght"], df["stabilityscore_c"], "r.")

In [None]:
get_structure_proba(net, "rd1", "EEHEE_rd1_0001.pdb")

In [None]:
rd3_stability_scores["sequence_lenght"] = [
    get_sequence_length(net, rd3_structure_path / name) for name in rd3_stability_scores["name"]
]

In [None]:
print(len(rd3_stability_scores))
df = rd3_stability_scores.dropna(how='any').copy()
print(len(df))

stats.spearmanr(df["sum_log_proba"] / df['sequence_lenght'], df["stabilityscore_t"])

In [None]:
plt.plot(df["sum_log_proba"] / df['sequence_lenght'], df["stabilityscore_t"], 'r.')

In [None]:
df['type'] = df['name'].str.split("_").str[0]

In [None]:
for type, gp in df.groupby("type"):
    print(type)
    c = "ec50_rise_t"
    print(stats.spearmanr(gp["sum_log_proba"] / gp['sequence_lenght'], gp[c]))
    plt.plot(gp["sum_log_proba"] / gp['sequence_lenght'], gp[c], 'r.')
    plt.show()