In [31]:
import jupyter_black

# Activate the jupyter_black extension for this kernel session (presumably it
# reformats code cells with Black when they are run — confirm in its docs).
jupyter_black.load()

In [32]:
import pickle
import torch
import h5py
from pathlib import Path
import numpy as np
import pandas as pd

# --- File locations, all relative to the notebook's working directory ---

# Source table that provides the identifiers and the full / mature sequences.
csv_path = "../data/online_version/internal_data.csv"

# Every QUEEN-related artefact sits in one folder, so derive those paths from it.
queen_dir = "../data/queen_multimer"

# FASTA files written by this notebook (one per sequence variant).
fasta_full_path = f"{queen_dir}/full.fasta"
fasta_mature_path = f"{queen_dir}/mature.fasta"

# Pickled QUEEN model.
path_queen_model = f"{queen_dir}/QUEEN_MLPmodel_final.pkl"

# HDF5 files holding one embedding per identifier. Judging by the file names
# these are ESM2 (650M) embeddings — NOTE(review): confirm how they are produced.
hdf_mature_path = f"{queen_dir}/mature_esm2_650m.h5"
hdf_full_path = f"{queen_dir}/full_esm2_650m.h5"

# Destination of the final table with both prediction columns.
csv_out = f"{queen_dir}/queen_pred.csv"

In [34]:
# Read the source table. Later cells rely on its "identifier", "full_seq" and
# "mature_seq" columns (FASTA export) and on the metadata columns kept in the
# final CSV.
df_online = pd.read_csv(csv_path)


def save_fasta(
    fasta_path: str, id_col: str, seq_col: str, df: "pd.DataFrame | None" = None
) -> None:
    """Write one FASTA record per row that has a sequence in ``seq_col``.

    Each record is a ``>{id}`` header line followed by the sequence on a
    single line. Rows whose ``seq_col`` value is missing are skipped.

    Parameters
    ----------
    fasta_path : str
        Destination file; it is created or overwritten.
    id_col : str
        Column whose value becomes the FASTA header.
    seq_col : str
        Column holding the sequence.
    df : pd.DataFrame, optional
        Table to export. Defaults to the notebook-level ``df_online`` so the
        existing three-argument calls keep working.

    Raises
    ------
    KeyError
        If ``id_col`` or ``seq_col`` is not a column of the table.
    """
    if df is None:
        # Backward-compatible fallback to the table loaded earlier in the notebook.
        df = df_online

    # Resolve the columns before opening the file, so a wrong column name
    # cannot leave behind a truncated FASTA file.
    with_sequence = df.loc[df[seq_col].notna()]
    identifiers = with_sequence[id_col]
    sequences = with_sequence[seq_col]

    with open(fasta_path, "w") as handle:
        # Iterating the two columns directly avoids the per-row Series that
        # ``iterrows`` builds.
        for identifier, sequence in zip(identifiers, sequences):
            handle.write(f">{identifier}\n")
            handle.write(f"{sequence}\n")


# Export both sequence variants; each FASTA file uses the same identifier column.
for _fasta_target, _sequence_column in (
    (fasta_full_path, "full_seq"),
    (fasta_mature_path, "mature_seq"),
):
    save_fasta(
        fasta_path=_fasta_target, id_col="identifier", seq_col=_sequence_column
    )

In [35]:
# Load the pickled QUEEN model.
# The path comes from `path_queen_model` in the configuration cell instead of a
# second copy of the same literal, so the two cannot drift apart.
# `model_location` is kept as an alias in case other cells refer to it.
#
# SECURITY: pickle.load can execute arbitrary code contained in the file —
# only load model files obtained from a trusted source.
# NOTE(review): the scikit-learn persistence link printed under this cell
# suggests the pickle is also tied to the library version it was saved with.
model_location = path_queen_model
with open(model_location, "rb") as f:
    queen_model = pickle.load(f)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [36]:
def queen_prediction(hdf_path: str, queen_model, col_name: str) -> pd.DataFrame:
    """Predict one value per embedding stored in an HDF5 file.

    Parameters
    ----------
    hdf_path : str
        HDF5 file whose top-level entries map an identifier to its embedding.
    queen_model
        Fitted model exposing ``predict``. It receives all embeddings stacked
        along a new first axis and must return one class index per embedding.
    col_name : str
        Name of the prediction column in the returned frame.

    Returns
    -------
    pd.DataFrame
        Columns ``identifier`` and ``col_name``, one row per entry of the
        file, in the order the file yields its entries.

    Raises
    ------
    ValueError
        If the file contains no entries, or (from ``np.stack``) if the
        embeddings do not all share one shape.
    KeyError
        If the model returns a class index missing from the lookup table.
    """
    # Model class index -> value reported in the output column.
    # NOTE(review): presumably the number of subunits per QUEEN class —
    # confirm against the label encoding the model was trained with.
    inv_map = {
        0: 1,
        1: 2,
        2: 3,
        3: 4,
        4: 5,
        5: 6,
        6: 7,
        7: 8,
        8: 10,
        9: 12,
        10: 14,
        11: 24,
    }

    # extract embeddings from h5
    embs, uids = [], []
    with h5py.File(hdf_path, "r") as hdf:
        for uid, emb in hdf.items():
            # Copy into memory now; the entries are read from the open file.
            embs.append(np.array(emb))
            uids.append(uid)

    if not embs:
        # Fail with a clear message instead of an opaque shape error from the model.
        raise ValueError(f"No embeddings found in {hdf_path!r}")

    # np.stack refuses mismatched shapes instead of silently building a ragged array.
    embs = np.stack(embs)

    # make predictions
    y_test = queen_model.predict(embs)
    y_test_transformed = np.array([inv_map[x] for x in y_test])
    df = pd.DataFrame(zip(uids, y_test_transformed), columns=["identifier", col_name])
    return df


# Run the same model on both embedding sets, one prediction column each.
df_full = queen_prediction(hdf_full_path, queen_model, "full_multimere_pred")
df_mature = queen_prediction(hdf_mature_path, queen_model, "mature_multimere_pred")

# Outer join: identifiers present in only one of the two files are kept,
# with NaN in the column they are missing from.
df_pred = df_full.merge(df_mature, how="outer", on="identifier")

# Display the rows that have a full-sequence prediction which is not equal to
# the mature-sequence one. A missing mature prediction also counts as
# "not equal", because NaN compares unequal to every value.
full_pred = df_pred["full_multimere_pred"]
mature_pred = df_pred["mature_multimere_pred"]
df_pred[full_pred.notna() & full_pred.ne(mature_pred)]

Unnamed: 0,identifier,full_multimere_pred,mature_multimere_pred
0,NCBI|BOITR_GGUG01000006.1_1_318|Boiga_trigonata,2.0,1
3,NCBI|Boiga_dendrophila_dendrophila_GGUB0100000...,2.0,1
4,NCBI|Boiga_dendrophila_dendrophila_GGUB0100000...,2.0,1
5,NCBI|Boiga_dendrophila_dendrophila_GGUB0100002...,2.0,1
6,NCBI|Boiga_dendrophila_dendrophila_GGUB0100002...,1.0,2
...,...,...,...
756,TR|R4G7H5|Furina_ornata,2.0,1
759,TR|R4G7K5|Pseudonaja_modesta,2.0,1
760,TR|R4G7K6|Pseudonaja_modesta,2.0,1
762,TR|S6CB76|Pseudechis_rossignolii,1.0,2


In [40]:
# Attach both prediction columns to the source table. The outer join keeps
# identifiers that occur on only one side.
df = df_online.merge(df_pred, how="outer", on="identifier")

# Columns written to the result file: metadata from the source table first,
# followed by the two prediction columns.
metadata_cols = [
    "identifier",
    "group",
    "major_group",
    "sub_group",
    "membran_prediction",
    "seq_start",
    "number_cysteines",
    "Dimeric",
]
prediction_cols = ["full_multimere_pred", "mature_multimere_pred"]
cols2keep = [*metadata_cols, *prediction_cols]

df.loc[:, cols2keep].to_csv(csv_out, index=False)