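## Summary

This notebook loads the per-design stability scores and Rosetta energies from the supplementary datasets of Rocklin *et al.*, "Global analysis of protein folding using massively parallel design, synthesis, and testing" (*Science*, 2017; article aan0693). It then loads a trained ProteinSolver network and benchmarks it on a few example structures, as a first step towards computing mutation probabilities for the designed proteins.

----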

## Imports

In [None]:
import concurrent.futures
import itertools
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import proteinsolver
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from kmbio import PDB
from scipy import stats
from tqdm.notebook import tqdm

In [None]:
# Interactive conveniences (autoreload) are only enabled when not running on CI.
DEBUG = not ("CI" in os.environ)

In [None]:
# Reload edited modules automatically before each cell runs; only during interactive use (skipped on CI).
if DEBUG:
    %load_ext autoreload
    %autoreload 2

In [None]:
%matplotlib inline

try:
    inline_rc
except NameError:
    inline_rc = mpl.rcParams.copy()
    
mpl.rcParams.update({"font.size": 12})

## Parameters

In [None]:
# Training run (directory under protein_train/) whose model definition and checkpoint are loaded below.
# NOTE(review): only "191f05de" has an entry in BEST_STATE_FILES; switching to one of the
# commented-out runs requires registering its state file there first.
UNIQUE_ID = "191f05de"  # No attention
# UNIQUE_ID = "0007604c"  # 5-layer graph-conv with attention, batch_size=1
# UNIQUE_ID = "91fc9ab9"  # 4-layer graph-conv with attention, batch_size=4

In [None]:
# Best checkpoint of each training run, keyed by the run's UNIQUE_ID.
BEST_STATE_FILES = {"191f05de": "protein_train/191f05de/e53-s1952148-d93703104.state"}

In [None]:
NOTEBOOK_NAME = "06_global_analysis_of_protein_folding"

# Output directory for this notebook's artefacts, created (non-recursively) under the current
# working directory.
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(parents=False, exist_ok=True)
NOTEBOOK_PATH

In [None]:
# Root directory of the raw input datasets. Indexing os.environ fails fast with a KeyError that
# names the missing variable, instead of the opaque TypeError `Path(None)` raises when it is unset.
INPUT_PATH = Path(os.environ["DATAPKG_INPUT_DIR"])
INPUT_PATH

In [None]:
# Local data directory handed to proteinsolver; a plain string suffices (the f-string had no placeholders).
DATAPKG_DATA_DIR = Path("~/datapkg_data_dir").expanduser().resolve()
DATAPKG_DATA_DIR

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Point proteinsolver's data location at the local data directory (as a POSIX-style path string).
proteinsolver.settings.data_url = DATAPKG_DATA_DIR.as_posix()
proteinsolver.settings.data_url

## Load data

In [None]:
# Top-level contents of the dataset directory.
!ls {INPUT_PATH}/global_analysis_of_protein_folding

In [None]:
# Designed-protein PDB files (judging by the directory name).
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_designed-PDB-files

In [None]:
# Supplementary datasets, including the stability scores and structural metrics loaded below.
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_SI_datasets

### aan0693_SI_datasets

In [None]:
# Per-library stability score files; `load_stability_scores` reads `<key>_stability_scores` from here.
!ls {INPUT_PATH}/global_analysis_of_protein_folding/aan0693_SI_datasets/stability_scores

In [None]:
def remove_controls(df, control_suffixes=("_hp", "_random", "_buryD")):
    """Drop control sequences from a table of designs.

    A row counts as a control when its ``name`` ends with any of ``control_suffixes``
    (by default the ``_hp``, ``_random`` and ``_buryD`` controls).

    Args:
        df: DataFrame with a string ``name`` column.
        control_suffixes: Collection of name suffixes marking control sequences. An empty
            collection keeps every row.

    Returns:
        The non-control rows of ``df``, with the original index preserved.
    """
    is_control = pd.Series(False, index=df.index)
    for suffix in control_suffixes:
        is_control = is_control | df["name"].str.endswith(suffix)
    return df[~is_control]

In [None]:
def load_stability_scores(key):
    """Load the stability scores of one design library and attach its Rosetta energies.

    Reads the tab-separated ``<key>_stability_scores`` file from the SI ``stability_scores``
    directory and drops control designs with ``remove_controls``. It then outer-merges, on the
    design ``name``, the ``total_score`` of each Rosetta relax score file (talaris2013,
    betanov15) that exists for this library.

    Args:
        key: Library identifier used as the file-name prefix (e.g. ``"rd1"``).

    Returns:
        DataFrame of stability scores with one ``<energy_function>_score`` column per score
        file found and a ``library_name`` column set to ``key``.
    """
    si_datasets_dir = INPUT_PATH / "global_analysis_of_protein_folding" / "aan0693_SI_datasets"

    stability_scores = pd.read_csv(si_datasets_dir / "stability_scores" / f"{key}_stability_scores", sep="\t")
    stability_scores = remove_controls(stability_scores)

    for energy_function in ["talaris2013", "betanov15"]:
        is_betanov15 = energy_function == "betanov15"
        # betanov15 scores come from a tab-separated "filtered_" file; talaris2013 scores are
        # separated by runs of spaces, hence the regex separator (which requires engine="python").
        rosetta_energies_file = (
            si_datasets_dir
            / "design_structural_metrics"
            / f"{key}_relax_scored_{'filtered_' if is_betanov15 else ''}{energy_function}.sc"
        )
        if not rosetta_energies_file.is_file():
            print(f"Not loading Rosetta energies for {energy_function}!")
            continue

        score_column = f"{energy_function}_score"
        relax_scored = pd.read_csv(
            rosetta_energies_file, sep="\t" if is_betanov15 else " +", engine="python"
        ).rename(columns={"description": "name", "total_score": score_column})
        # NOTE(review): an outer merge also keeps designs present only in the score file, so the
        # row count can grow and such rows bypass remove_controls — confirm this is intended.
        stability_scores = stability_scores.merge(relax_scored[["name", score_column]], on="name", how="outer")

    stability_scores["library_name"] = key
    return stability_scores

### stability_scores

In [None]:
# `stability_scores` maps a library name ("rd1"–"rd4", "ssm2", "fig1") to its DataFrame.
# It is cached next to the notebook and rebuilt from the raw SI files only when the cache file
# is missing, so the notebook also runs top-to-bottom on a fresh checkout.
STABILITY_SCORES_CACHE = NOTEBOOK_PATH.joinpath("stability_scores.torch")

if STABILITY_SCORES_CACHE.is_file():
    # NOTE(review): torch.load unpickles the file — only load caches written by this notebook.
    stability_scores = torch.load(STABILITY_SCORES_CACHE)
else:
    stability_scores = {key: load_stability_scores(key) for key in ["rd1", "rd2", "rd3", "rd4", "ssm2"]}
    stability_scores["fig1"] = pd.read_csv(
        INPUT_PATH / "global_analysis_of_protein_folding" / "aan0693_SI_datasets" / "fig1_thermodynamic_data.csv"
    ).assign(library_name="fig1")
    torch.save(stability_scores, STABILITY_SCORES_CACHE)

## Load model

In [None]:
# Execute the selected run's model definition in this namespace; presumably this is what provides
# the `Net` class instantiated below (it is not defined elsewhere in this notebook).
%run protein_train/{UNIQUE_ID}/model.py

In [None]:
# Sizes passed to `Net` below; they must match the checkpoint that is loaded into it.
batch_size, num_features, adj_input_size, hidden_size = 1, 20, 2, 128

# Fraction-present settings; the validation value mirrors the training one.
frac_present = 0.5
frac_present_valid = frac_present

info_size = 1024

In [None]:
# Checkpoint of the selected training run. Raise a KeyError with an actionable message when no
# checkpoint is registered for UNIQUE_ID (e.g. after switching to one of the commented-out runs).
if UNIQUE_ID not in BEST_STATE_FILES:
    raise KeyError(
        f"No state file registered for UNIQUE_ID={UNIQUE_ID!r}; add it to BEST_STATE_FILES "
        f"(known runs: {sorted(BEST_STATE_FILES)})."
    )
state_file = BEST_STATE_FILES[UNIQUE_ID]
state_file

In [None]:
# The input layer takes one more class than the network predicts — presumably an extra "masked"
# token; TODO confirm against model.py.
net = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
net.load_state_dict(torch.load(state_file, map_location=device))
# Move to the target device and switch to inference mode (Module.to/eval both return the module).
net = net.to(device).eval()

## Mutation probabilities

### Test network

In [None]:
# NOTE(review): stray path probe — this shows parents[2], while the structure lookup below uses
# NOTEBOOK_PATH.parent.parent (parents[1]); confirm which is intended or drop this cell.
NOTEBOOK_PATH.parent.parent.parent

In [None]:
STRUCTURE_IDS = ["5vli02", "1n5uA03", "4z8jA00", "4unuA00", "4beuA02"]


def load_structure_data(structure_file):
    """Parse a PDB file and convert it into a proteinsolver graph ``Data`` object.

    Uses the first child of ``structure[0]`` (presumably the first chain of the first model).
    """
    structure = PDB.load(structure_file)
    first_chain_id = next(iter(structure[0])).id
    pdata = proteinsolver.utils.extract_seq_and_adj(structure, first_chain_id)
    data = proteinsolver.datasets.protein.row_to_data(pdata)
    return proteinsolver.datasets.protein.transform_edge_attr(data)


# STRUCTURE_FILE overrides the whole example set with a single file. Previously it was read inside
# the loop, so setting it loaded that same file once per structure id.
if "STRUCTURE_FILE" in os.environ:
    structure_files = [Path(os.environ["STRUCTURE_FILE"]).resolve()]
else:
    structure_dir = NOTEBOOK_PATH.parent.parent / "proteinsolver" / "data" / "inputs"
    structure_files = [(structure_dir / f"{structure_id}.pdb").resolve() for structure_id in STRUCTURE_IDS]

dataset = [load_structure_data(structure_file) for structure_file in structure_files]

In [None]:
# Time two forward passes per structure: once as loaded and once with the first residue overwritten.
start_time = time.perf_counter()
with torch.no_grad():  # inference only: don't build an autograd graph during the benchmark
    for data in tqdm(dataset):
        data = data.to(device)
        out = net(data.x, data.edge_index, data.edge_attr)
        # Modify a copy: `data.to(device)` can return the same Data object held in `dataset`, so
        # writing into data.x would silently alter `dataset` for later cells and re-runs.
        # NOTE(review): confirm 0 is the intended replacement value (the mask index may be num_features).
        x_modified = data.x.clone()
        x_modified[0] = 0
        out = net(x_modified, data.edge_index, data.edge_attr)
print(f"Elapsed time: {time.perf_counter() - start_time}.")

In [None]:
# Time a full masked scan over each structure.
start_time = time.perf_counter()
with torch.no_grad():  # inference only: don't build an autograd graph during the benchmark
    for data in tqdm(dataset):
        data = data.to(device)
        # NOTE(review): 20 is presumably the mask token index (== num_features; the network takes
        # num_features + 1 input classes) — confirm against proteinsolver.utils.scan_with_mask.
        proteinsolver.utils.scan_with_mask(net, data.x, data.edge_index, data.edge_attr, 20)
print(f"Elapsed time: {time.perf_counter() - start_time}.")