<a href="https://colab.research.google.com/github/suzheng/ThermoMPNN-D/blob/main/ThermoMPNN-D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <center>**This is the Colab implementation of ThermoMPNN-D**</center>


<center><img src='https://drive.google.com/uc?export=view&id=1qXMpih7MLeZfRDZF9-iYSlL6SXEY3FdS'></center>

---

ThermoMPNN-D is an updated version of ThermoMPNN for predicting double point mutations. It was trained on an augmented version of the Megascale double mutant dataset. It is state-of-the-art at predicting stabilizing double mutations.

For convenience, we also provide a single-mutant ThermoMPNN model and an "additive" model that finds mutation pairs in a naive fashion by ignoring epistatic interactions. For details, see the [ThermoMPNN-D paper](https://doi.org/10.1002/pro.70003).

### **COLAB TIPS:**
- The cells of this notebook are meant to be executed *in order*, so users should start from the top and work their way down.
- Executable cells can be run by clicking the PLAY button (>) that appears when you hover over each cell, or by using **Shift+Enter**.
- Make sure GPU is enabled by checking `Runtime` -> `Change Runtime Type`
  - Make sure that `Runtime type` is set to `Python 3`
  - Make sure that `Hardware accelerator` is set to `GPU`
  - Click `Save` to confirm

- If the notebook freezes up or otherwise crashes, go to `Runtime` -> `Restart Runtime` and try again.
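
The GPU setting above can also be checked from a code cell. The following is a minimal sketch, not part of the original notebook: it assumes that the `nvidia-smi` tool is on the `PATH` whenever a GPU runtime is attached, and the helper name `gpu_runtime_attached` is hypothetical.

```python
import shutil
import subprocess


def gpu_runtime_attached():
    """Return True if `nvidia-smi -L` runs cleanly (assumed proxy for a GPU runtime)."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # Tool not found: most likely a CPU-only runtime.
        return False
    try:
        result = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


print("GPU runtime detected:", gpu_runtime_attached())
```

If this prints `False`, revisit the `Hardware accelerator` setting before running the model cells.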


In [1]:
%%capture

#@title # 1. Set up **ThermoMPNN environment**
#@markdown Import ThermoMPNN and its dependencies to this session. This may take a minute or two.

#@markdown You only need to do this once *per session*. To re-run ThermoMPNN on a new protein, you may start on Step 3.

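# clean out any data left over from a previous run
# (%cd changes the notebook's working directory; `!cd` only affects a throwaway subshell)
%cd /content
!rm -rf /content/ThermoMPNN-D
!rm -rf /content/sample_data
!rm -f /content/*.pdb
!rm -f /content/*.csv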

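# clone the ThermoMPNN-D GitHub repo (skipped if it is already present)
import os
if not os.path.exists("/content/ThermoMPNN-D"):
  !git clone https://github.com/suzheng/ThermoMPNN-D.git
# always move into the repo, whether or not it was just cloned
%cd /content/ThermoMPNN-D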

# downloading various dependencies - add more if needed later
! pip install omegaconf wandb pytorch-lightning biopython


In [2]:
%%capture
#@title # **2. Set up ThermoMPNN imports and functions**

import os
import sys
from urllib import request
from urllib.error import HTTPError

import numpy as np  # used by drop_cysteines below

from google.colab._message import MessageError
from google.colab import files


tMPNN_path = '/content/ThermoMPNN-D'
if tMPNN_path not in sys.path:
  sys.path.append(tMPNN_path)


def download_pdb(pdbcode, datadir, downloadurl="https://files.rcsb.org/download/"):
    """
    Downloads a PDB file from the Internet and saves it in a data directory.
    :param pdbcode: The standard PDB ID e.g. '3ICB' or '3icb'
    :param datadir: The directory where the downloaded file will be saved
    :param downloadurl: The base PDB download URL, cf.
        `https://www.rcsb.org/pages/download/http#structures` for details
    :return: the full path to the downloaded PDB file or None if something went wrong
    """

    pdbfn = pdbcode + ".pdb"
    url = downloadurl + pdbfn
    outfnm = os.path.join(datadir, pdbfn)
    try:
        request.urlretrieve(url, outfnm)
        return outfnm
    except Exception as err:
        print(str(err), file=sys.stderr)
        return None

def drop_cysteines(df, mode):
  """Drop any mutations to Cys"""

  if mode.lower() == 'single':
    aatype_to = df['Mutation'].str[-1].values
    is_cys = aatype_to == "C"
    df = df.loc[~is_cys].reset_index(drop=True)

  elif mode.lower() == 'additive' or mode.lower() == 'epistatic':
    muts = df['Mutation'].str.split(':', n=2, expand=True).values # [N, 2]
    is_cys = []
    for m in muts:
      mut1, mut2 = m
      is_cys.append(mut1.endswith("C") or mut2.endswith("C"))

    is_cys = np.array(is_cys)
    df = df.loc[~is_cys].reset_index(drop=True)
  else:
    raise ValueError(f"Invalid mode {mode} selected!")
  return df
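
The cysteine filter above relies on the mutation naming convention used throughout this notebook: wild-type residue, 1-based position, mutant residue (for example `A12C`), with double mutants joined by a colon. The sketch below shows the same check on plain strings, without pandas. The helper name `targets_cysteine` and the example mutations are hypothetical.

```python
def targets_cysteine(mutation):
    """True if any colon-separated point mutation in `mutation` mutates TO Cys."""
    return any(part.strip().endswith("C") for part in mutation.split(":"))


# Hypothetical mutation names, for illustration only.
examples = ["A12C", "G45W", "A12V:G45C", "C7A:L9F"]

# Mutations away from Cys (such as C7A) are kept; only mutations to Cys are dropped.
kept = [m for m in examples if not targets_cysteine(m)]
print(kept)  # ['G45W', 'C7A:L9F']
```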


In [3]:
# %%capture
#@title # **3. Upload or Fetch Input Data**

#@markdown ## You may either specify a PDB code to fetch or upload a custom PDB file.<br><br>

# -------- Collecting Settings for ThermoMPNN run --------- #

!rm /content/*.pdb &> /dev/null

#@markdown PDB code (example: 1PGA):
PDB = "6vxx" #@param {type: "string"}

#@markdown -------

#@markdown Upload Custom PDB?
Custom = False #@param {type: "boolean"}

#@markdown NOTE: If enabled, a `Choose files` button will appear at the bottom of this cell once this cell is run.

#@markdown -----

#@markdown Chain(s) of Interest (example: A,B,C):
Chains = "A,B,C" #@param {type:"string"}

#@markdown If left empty, all chains will be used.

# try to upload the PDB file to Colab servers
if Custom:
  try:
    uploaded_pdb = files.upload()
    for fn in uploaded_pdb.keys():
      PDB = os.path.basename(fn)
      if not PDB.endswith('.pdb'):
        raise ValueError(f"Uploaded file {PDB} does not end in '.pdb'. Please check and rename file as needed.")
      os.rename(fn, os.path.join("/content/", PDB))
      pdb_file = os.path.join("/content/", PDB)
  except (MessageError, FileNotFoundError):
    print('\n', '*' * 100, '\n')
    print('Sorry, your input file failed to upload. Please try the backup upload procedure (next cell).')

else:
  # download_pdb handles download errors itself and returns None on failure
  fn = download_pdb(PDB, "/content/")
  if fn is None:
    raise ValueError(f"Failed to fetch PDB {PDB} from RCSB. Please double-check the PDB code and try again.")
  pdb_file = fn
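
A malformed PDB code can be caught before any network request is made. The sketch below builds the download URL the same way `download_pdb` does (base URL + code + `.pdb`) and first validates the code. It assumes the classic four-character PDB ID format (a digit followed by three alphanumeric characters); the helper name `build_pdb_url` is hypothetical.

```python
import re

RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/"  # same default as download_pdb


def build_pdb_url(pdbcode, base=RCSB_DOWNLOAD_URL):
    """Validate a classic 4-character PDB ID and return its download URL."""
    code = pdbcode.strip()
    # Assumption: classic IDs are one digit followed by three alphanumerics.
    if not re.fullmatch(r"[0-9][A-Za-z0-9]{3}", code):
        raise ValueError(f"'{pdbcode}' does not look like a 4-character PDB ID")
    return base + code + ".pdb"


print(build_pdb_url("6vxx"))  # https://files.rcsb.org/download/6vxx.pdb
```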


In [None]:
#@title # **3. Backup Data Upload (ONLY needed if initial upload failed)**

#@markdown ## Colab automatic file uploads are not very reliable. If your file failed to upload automatically, you can do so manually by following these steps.<br><br>

#@markdown #### 1. Click the "Files" icon on the left toolbar. This will open the Colab server file folder.

#@markdown #### 2. The only item in this folder should be the "ThermoMPNN-D" directory. If any other files are in here, delete them.

#@markdown #### 3. Click the "Upload to session storage" button under the "Files" header. Choose your file for upload.

#@markdown #### 4. Run this cell. ThermoMPNN will find your file in session storage and use it.

#@markdown ------

#@markdown Chain(s) of Interest (example: A,B,C):
Chains = "" #@param {type:"string"}
#@markdown If left empty, all chains will be used.

PDB = ""

# use a separate name so the `files` module imported from google.colab is not shadowed
pdb_files = sorted(os.listdir('/content/'))
pdb_files = [f for f in pdb_files if f.endswith('.pdb')]

if len(pdb_files) < 1:
  raise ValueError('No PDB file found. Please upload your file before running this cell. Make sure it has a .pdb suffix.')
elif len(pdb_files) > 1:
  raise ValueError('Too many PDB files found. Please clear out any other PDBs before running this cell.')
else:
  pdb_file = os.path.join("/content/", pdb_files[0])
  PDB = pdb_files[0].removesuffix('.pdb')
  print('Successfully uploaded PDB file %s' % (pdb_files[0]))

Successfully uploaded PDB file 1bvc.pdb


In [4]:
#@markdown # **4. Run Model**

#@markdown Stability model to use:
Model = "Single" #@param ["Epistatic", "Additive", "Single"]

#@markdown ##### Model descriptions:
#@markdown * Single: Single mutation SSM sweep. Very fast and accurate.
#@markdown * Additive: Naive double mutation SSM sweep. Ignores non-additive coupling. Very fast but less accurate than Epistatic model for picking stabilizing mutations.
#@markdown * Epistatic: Full double mutation SSM sweep. Slower than Additive model, but more accurate for picking stabilizing mutations.

#@markdown ---------------

#@markdown Allow mutations to cysteine? (Not recommended)
Include = False #@param {type: "boolean"}
#@markdown Due to assay artifacts surrounding disulfide formation, model predictions for cysteine mutations may be overly favorable.

#@markdown ---------------

#@markdown Explicitly penalize disulfide breakage? (Recommended)
Penalize = True #@param {type: "boolean"}

#@markdown ThermoMPNN can usually detect disulfide breakage and penalize accordingly, but you may wish to explicitly forbid disulfide breakage to be safe. This option applies a flat penalty to make sure that breaking disulfides is always disfavored.

#@markdown --------------

#@markdown Batch size for model inference. (Recommended: 256 for the Single/Additive models, 2048 for the Epistatic model)
BatchSize = 256 #@param {type: "integer"}
#@markdown If you hit a memory error, try lowering the BatchSize by factors of 2 to reduce memory usage.

#@markdown --------------

#@markdown Threshold for detecting stabilizing mutations. (Recommended: -1.0)
Threshold = -1.0 #@param {type: "number"}
#@markdown Only mutations with predicted ddG below this value will be kept for analysis. Higher thresholds will result in retaining more mutations.

#@markdown --------------

#@markdown Pairwise distance constraint for double mutants. (Recommended: 5.0)
Distance = 5.0 #@param {type: "number"}
#@markdown Only mutation pairs within this distance (in Angstrom) will be kept for analysis. Higher cutoffs will result in slower runtime and retaining more mutations.


# use input_chain_list to grab correct protein chain
chain_list = [c.strip() for c in Chains.strip().split(',')]
if len(chain_list) == 1 and chain_list[0] == '':
  chain_list = []
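
To make the `Chains` and `Threshold` settings concrete, here is a minimal sketch in plain Python. `parse_chains` is a slight variant of the chain parsing above (it drops every empty entry, not just a lone empty string), and `keep_stabilizing` applies the same `ddG <= threshold` rule used by the formatting functions later in this notebook. Both helper names and all ddG values are hypothetical.

```python
def parse_chains(chains):
    """Split a comma-separated chain string; an empty list means "use all chains"."""
    return [c.strip() for c in chains.split(",") if c.strip()]


def keep_stabilizing(predictions, threshold=-1.0):
    """Keep entries whose predicted ddG (kcal/mol) is at or below the threshold."""
    return {mut: ddg for mut, ddg in predictions.items() if ddg <= threshold}


# Hypothetical ddG values, for illustration only.
toy_predictions = {"A12V": -1.4, "G45W": -0.3, "L9F": -1.0, "K20E": 0.8}

print(parse_chains("A, B,C"))             # ['A', 'B', 'C']
print(parse_chains(""))                   # []
print(keep_stabilizing(toy_predictions))  # {'A12V': -1.4, 'L9F': -1.0}
# A higher (less negative) threshold retains more mutations:
print(len(keep_stabilizing(toy_predictions, threshold=-0.2)))  # 3
```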

In [34]:
import argparse
import os
import time
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
from thermompnn.datasets.dataset_utils import Mutation
from thermompnn.datasets.v2_datasets import tied_featurize_mut
from thermompnn.model.v2_model import _dist, batched_index_select
from thermompnn.ssm_utils import (
    distance_filter,
    disulfide_penalty,
    get_config,
    get_dmat,
    get_model,
    load_pdb,
    renumber_pdb,
)
from torch.utils.data import DataLoader
from tqdm import tqdm


def get_ssm_mutations_double(pdb, dthresh):
    # make mutation list for SSM run
    ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
    MUT_POS, MUT_WT = [], []
    for seq_pos in range(len(pdb["seq"])):
        wtAA = pdb["seq"][seq_pos]
        # check for missing residues
        if wtAA != "-":
            MUT_POS.append(seq_pos)
            MUT_WT.append(wtAA)
        else:
            MUT_WT.append("-")

    # Use distance filter BEFORE data setup / inference for speedup
    # (get_dmat is already imported at the top of this cell)
    dmat = np.triu(get_dmat(pdb))  # [L, L]
    mask = (dmat < dthresh) & (dmat > 0.0)
    pos1, pos2 = np.where(mask)
    pos_combos = [(p1, p2) for p1, p2 in zip(pos1, pos2)]
    pos_combos = np.array(pos_combos)  # [combos, 2]
    wtAA = np.zeros_like(pos_combos)
    # fill in wtAA for each pos combo
    for p_idx in range(pos_combos.shape[0]):
        wtAA[p_idx, 0] = ALPHABET.index(MUT_WT[pos_combos[p_idx, 0]])
        wtAA[p_idx, 1] = ALPHABET.index(MUT_WT[pos_combos[p_idx, 1]])

    # make default mutAA bundle for broadcasting
    one = np.arange(20).repeat(20)
    two = np.tile(np.arange(20), 20)
    mutAA = np.stack([one, two]).T  # [400, 2]
    n_comb = pos_combos.shape[0]
    mutAA = np.tile(mutAA, (n_comb, 1))

    # the problem is 2nd wtAA/pos_combos and 2nd mutAA are correlated so they always show up together
    # repeat these 20x20 times
    wtAA = np.repeat(wtAA, 400, axis=0)
    pos_combos = np.repeat(pos_combos, 400, axis=0)

    # filter out self-mutations and single-mutations
    mask = np.sum(mutAA == wtAA, -1).astype(bool)
    pos_combos = pos_combos[~mask, :]
    mutAA = mutAA[~mask, :]
    wtAA = wtAA[~mask, :]

    # drop any pair with pos1 > pos2 so each position pair is counted once
    # (a safeguard: np.triu above already restricts pairs to pos1 < pos2)
    mask = pos_combos[:, 0] > pos_combos[:, 1]
    pos_combos = pos_combos[~mask, :]
    mutAA = mutAA[~mask, :]
    wtAA = wtAA[~mask, :]

    return torch.tensor(pos_combos), torch.tensor(wtAA), torch.tensor(mutAA)


def run_double(
    all_mpnn_hid, mpnn_embed, cfg, loader, batch_size, model, X, mask, mpnn_edges_raw
):
    """Batched mutation processing using shared protein embeddings and only stability prediction module head"""
    device = "cuda"
    all_mpnn_hid = torch.cat(all_mpnn_hid[: cfg.model.num_final_layers], -1)
    all_mpnn_hid = all_mpnn_hid.repeat(batch_size, 1, 1)
    mpnn_embed = mpnn_embed.repeat(batch_size, 1, 1)
    mpnn_edges_raw = mpnn_edges_raw.repeat(batch_size, 1, 1, 1)
    # get edges between the two mutated residues
    D_n, E_idx = _dist(X[:, :, 1, :], mask)
    E_idx = E_idx.repeat(batch_size, 1, 1)

    preds = []
    for b in tqdm(loader):
        pos, wtAA, mutAA = b
        pos = pos.to(device)
        wtAA = wtAA.to(device)
        mutAA = mutAA.to(device)
        mut_mutant_AAs = mutAA
        mut_positions = pos
        REAL_batch_size = mutAA.shape[0]

        # get sequence embedding for mutant aa
        mut_embed_list = []
        for m in range(mut_mutant_AAs.shape[-1]):
            mut_embed_list.append(model.prot_mpnn.W_s(mut_mutant_AAs[:, m]))
        mut_embed = torch.cat(
            [m.unsqueeze(-1) for m in mut_embed_list], -1
        )  # shape: (Batch, Embed, N_muts)

        n_mutations = [0, 1]
        edges = []
        for n_current in n_mutations:  # iterate over N-order mutations
            # select the edges at the current mutated positions
            if (
                REAL_batch_size != mpnn_edges_raw.shape[0]
            ):  # last batch will throw error if not corrected
                mpnn_edges_raw = mpnn_edges_raw[:REAL_batch_size, ...]
                E_idx = E_idx[:REAL_batch_size, ...]
                all_mpnn_hid = all_mpnn_hid[:REAL_batch_size, ...]
                mpnn_embed = mpnn_embed[:REAL_batch_size, ...]

            mpnn_edges_tmp = torch.squeeze(
                batched_index_select(
                    mpnn_edges_raw, 1, mut_positions[:, n_current : n_current + 1]
                ),
                1,
            )
            E_idx_tmp = torch.squeeze(
                batched_index_select(
                    E_idx, 1, mut_positions[:, n_current : n_current + 1]
                ),
                1,
            )

            n_other = [a for a in n_mutations if a != n_current]
            mp_other = mut_positions[:, n_other]  # [B, 1]
            # E_idx_tmp [B, K]
            mp_other = mp_other[..., None].repeat(
                1, 1, E_idx_tmp.shape[-1]
            )  # [B, 1, 48]
            idx = torch.where(
                E_idx_tmp[:, None, :] == mp_other
            )  # get indices where the neighbor list matches the mutations we want
            a, b, c = idx
            # start w/empty edges and fill in as you go, then set remaining edges to 0
            edge = torch.full(
                [REAL_batch_size, mpnn_edges_tmp.shape[-1]],
                torch.nan,
                device=E_idx.device,
            )  # [B, 128]
            # idx is (a, b, c) tuple of tensors
            # a has indices of batch members; b is all 0s; c has indices of actual neighbors for edge grabbing
            edge[a, :] = mpnn_edges_tmp[a, c, :]
            edge = torch.nan_to_num(edge, nan=0)
            edges.append(edge)

        mpnn_edges = torch.stack(
            edges, dim=-1
        )  # this should get two edges per set of doubles (one for each)

        # gather final representation from seq and structure embeddings
        final_embed = []
        for i in range(mut_mutant_AAs.shape[-1]):
            # gather embedding for a specific position
            current_positions = mut_positions[:, i : i + 1]  # [B, 1]
            g_struct_embed = torch.gather(
                all_mpnn_hid,
                1,
                current_positions.unsqueeze(-1).expand(
                    current_positions.size(0),
                    current_positions.size(1),
                    all_mpnn_hid.size(2),
                ),
            )
            g_struct_embed = torch.squeeze(g_struct_embed, 1)  # [B, E * nfl]
            # add specific mutant embedding to gathered embed based on which mutation is being gathered
            g_seq_embed = torch.gather(
                mpnn_embed,
                1,
                current_positions.unsqueeze(-1).expand(
                    current_positions.size(0),
                    current_positions.size(1),
                    mpnn_embed.size(2),
                ),
            )
            g_seq_embed = torch.squeeze(g_seq_embed, 1)  # [B, E]
            # if mut embed enabled, subtract it from the wt embed directly to keep dims low
            if cfg.model.mutant_embedding:
                if REAL_batch_size != mut_embed.shape[0]:
                    mut_embed = mut_embed[:REAL_batch_size, ...]
                g_seq_embed = g_seq_embed - mut_embed[:, :, i]  # [B, E]
            g_embed = torch.cat([g_struct_embed, g_seq_embed], -1)  # [B, E * (nfl + 1)]

            # if edges enabled, concatenate them onto the end of the embedding
            if cfg.model.edges:
                g_edge_embed = mpnn_edges[:, :, i]
                g_embed = torch.cat([g_embed, g_edge_embed], -1)  # [B, E * (nfl + 2)]
            final_embed.append(
                g_embed
            )  # list with length N_mutations - used to make permutations
        final_embed = torch.stack(final_embed, dim=0)  # [2, B, E x (nfl + 1)]

        # do initial dim reduction
        final_embed = model.light_attention(final_embed)  # [2, B, E]

        # if batch is only single mutations, pad it out with a "zero" mutation
        if final_embed.shape[0] == 1:
            zero_embed = torch.zeros(
                final_embed.shape, dtype=torch.float32, device=E_idx.device
            )
            final_embed = torch.cat([final_embed, zero_embed], dim=0)

        # make two copies, one with AB order and other with BA order of mutation
        embedAB = torch.cat((final_embed[0, :, :], final_embed[1, :, :]), dim=-1)
        embedBA = torch.cat((final_embed[1, :, :], final_embed[0, :, :]), dim=-1)

        ddG_A = model.ddg_out(embedAB)  # [B, 1]
        ddG_B = model.ddg_out(embedBA)  # [B, 1]

        ddg = (ddG_A + ddG_B) / 2.0
        preds += list(torch.squeeze(ddg, dim=-1).detach().cpu().numpy())
    return np.squeeze(preds)


class SSMDataset(torch.utils.data.Dataset):
    def __init__(self, POS, WTAA, MUTAA):
        self.POS = POS
        self.WTAA = WTAA
        self.MUTAA = MUTAA

    def __len__(self):
        return self.POS.shape[0]

    def __getitem__(self, index):
        return self.POS[index, :], self.WTAA[index, :], self.MUTAA[index, :]


def run_single_ssm(pdb, cfg, model):
    """Runs single-mutant SSM sweep with ThermoMPNN v2"""

    model.eval()
    model.cuda()
    stime = time.time()

    # placeholder mutation to keep featurization from throwing error
    pdb["mutation"] = Mutation([0], ["A"], ["A"], [0.0], "")

    # featurize input
    device = "cuda"
    batch = tied_featurize_mut([pdb])
    (
        X,
        S,
        mask,
        lengths,
        chain_M,
        chain_encoding_all,
        residue_idx,
        mut_positions,
        mut_wildtype_AAs,
        mut_mutant_AAs,
        mut_ddGs,
        atom_mask,
    ) = batch

    X = X.to(device)
    S = S.to(device)
    mask = mask.to(device)
    lengths = torch.Tensor(lengths).to(device)
    chain_M = chain_M.to(device)
    chain_encoding_all = chain_encoding_all.to(device)
    residue_idx = residue_idx.to(device)
    mut_ddGs = mut_ddGs.to(device)

    # do single pass through thermompnn
    X = torch.nan_to_num(X, nan=0.0)
    all_mpnn_hid, mpnn_embed, _, mpnn_edges = model.prot_mpnn(
        X, S, mask, chain_M, residue_idx, chain_encoding_all
    )

    all_mpnn_hid = torch.cat(all_mpnn_hid[: cfg.model.num_final_layers], -1)
    all_mpnn_hid = torch.squeeze(torch.cat([all_mpnn_hid, mpnn_embed], -1), 0)  # [L, E]

    all_mpnn_hid = model.light_attention(torch.unsqueeze(all_mpnn_hid, -1))

    ddg = model.ddg_out(all_mpnn_hid)  # [L, 21]

    # subtract wildtype ddgs to normalize
    S = torch.squeeze(S)  # [L, ]

    wt_ddg = batched_index_select(ddg, dim=-1, index=S)  # [L, 1]
    ddg = ddg - wt_ddg.expand(-1, 21)  # [L, 21]
    etime = time.time()
    elapsed = etime - stime
    length = ddg.shape[0]
    print(
        f"ThermoMPNN single mutant predictions generated for protein of length {length} in {round(elapsed, 2)} seconds."
    )
    return ddg, S

def expand_additive(ddg):
    """Uses torch broadcasting to add all possible single mutants to each other in a vectorized operation."""
    # ddg [L, 21]
    dims = ddg.shape
    ddgA = ddg.reshape(dims[0], dims[1], 1, 1)  # [L, 21, 1, 1]
    ddgB = ddg.reshape(1, 1, dims[0], dims[1])  # [1, 1, L, 21]
    ddg = ddgA + ddgB  # L, 21, L, 21

    # mask out diagonal representing two mutations at the same position - this is invalid
    for i in range(dims[0]):
        ddg[i, :, i, :] = torch.nan

    return ddg


def format_output_single(ddg, S, threshold=-0.5):
    """Converts raw SSM predictions into nice format for analysis"""
    ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
    ddg = ddg.cpu().detach().numpy()
    ddg = ddg[:, :20]

    keep_L, keep_AA = np.where(ddg <= threshold)
    ddg = ddg[ddg <= threshold]  # [N, ]

    mutlist = []
    for L_idx, AA_idx in tqdm(zip(keep_L, keep_AA)):
        wtAA = ALPHABET[S[L_idx]]
        mutAA = ALPHABET[AA_idx]
        mutlist.append(wtAA + str(L_idx + 1) + mutAA)

    return ddg, mutlist


def format_output_double(ddg, S, threshold, pdb, distance):
    """Converts raw SSM predictions into nice format for analysis"""
    stime = time.time()
    ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
    ddg = ddg.cpu().detach().numpy()  # [L, 21]
    L, AA = ddg.shape

    ddg = expand_additive(ddg)  # [L, 21, L, 21]
    ddg = ddg[:, :20, :, :20]  # drop X predictions

    # Pre-mask matrix with distance constraints for speedup
    dmat = get_dmat(pdb)
    assert ddg.shape[0] == dmat.shape[0]
    valid_mask = (
        (ddg <= threshold)
        * (dmat < distance)[:, None, :, None]
        * (dmat != 0.0)[:, None, :, None]
    )
    p1s, a1s, p2s, a2s = np.where(valid_mask)

    cond = p1s < p2s  # filter to keep only upper triangle
    p1s, a1s, p2s, a2s = p1s[cond], a1s[cond], p2s[cond], a2s[cond]
    wt_seq = [ALPHABET[S[ppp]] for ppp in np.arange(L)]

    mutlist, ddglist = [], []
    for p1, a1, p2, a2 in tqdm(zip(p1s, a1s, p2s, a2s)):
        wt1, wt2 = wt_seq[p1], wt_seq[p2]
        mut1, mut2 = ALPHABET[a1], ALPHABET[a2]

        if (wt1 != mut1) and (wt2 != mut2):  # drop self-mutations
            mutation = f"{wt1}{p1 + 1}{mut1}:{wt2}{p2 + 1}{mut2}"
            mutlist.append(mutation)
            ddglist.append(ddg[p1, a1, p2, a2])

    etime = time.time()
    elapsed = etime - stime
    print(
        f"ThermoMPNN double mutant additive model predictions calculated in {round(elapsed, 2)} seconds."
    )
    return ddglist, mutlist


def format_output_epistatic(ddg, S, pos, wtAA, mutAA, threshold=-0.5):
    "Converts raw SSM predictions into nice format for analysis."
    stime = time.time()
    ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
    S = torch.squeeze(S)

    # filter out ddgs that miss the threshold
    mask = ddg <= threshold
    ddg = ddg[mask]
    wtAA = wtAA[mask, :]
    mutAA = mutAA[mask, :]
    pos = pos[mask, :]
    mut_list = []
    # a bunch of repeats in here?!
    for b in tqdm(range(ddg.shape[0])):
        w1 = ALPHABET[wtAA[b, 0]]
        w2 = ALPHABET[wtAA[b, 1]]
        m1 = ALPHABET[mutAA[b, 0]]
        m2 = ALPHABET[mutAA[b, 1]]
        mut_name = f"{w1}{pos[b, 0] + 1}{m1}:{w2}{pos[b, 1] + 1}{m2}"
        mut_list.append(mut_name)
    etime = time.time()
    elapsed = etime - stime
    print(
        f"ThermoMPNN double mutant epistatic model predictions sorted and filtered in {round(elapsed, 2)} seconds."
    )
    return ddg, mut_list


def run_epistatic_ssm(pdb, cfg, model, distance, threshold, batch_size):
    """Run epistatic model on double mutations"""

    model.eval()
    model.cuda()
    stime = time.time()

    # placeholder mutation to keep featurization from throwing error
    pdb["mutation"] = Mutation([0], ["A"], ["A"], [0.0], "")

    # featurize input
    device = "cuda"
    batch = tied_featurize_mut([pdb])
    (
        X,
        S,
        mask,
        lengths,
        chain_M,
        chain_encoding_all,
        residue_idx,
        mut_positions,
        mut_wildtype_AAs,
        mut_mutant_AAs,
        mut_ddGs,
        atom_mask,
    ) = batch

    X = X.to(device)
    S = S.to(device)
    mask = mask.to(device)
    lengths = torch.Tensor(lengths).to(device)
    chain_M = chain_M.to(device)
    chain_encoding_all = chain_encoding_all.to(device)
    residue_idx = residue_idx.to(device)
    mut_ddGs = mut_ddGs.to(device)

    # do single pass through thermompnn
    X = torch.nan_to_num(X, nan=0.0)
    all_mpnn_hid, mpnn_embed, _, mpnn_edges = model.prot_mpnn(
        X, S, mask, chain_M, residue_idx, chain_encoding_all
    )

    # grab double mutation inputs
    MUT_POS, MUT_WT_AA, MUT_MUT_AA = get_ssm_mutations_double(pdb, distance)
    dataset = SSMDataset(MUT_POS, MUT_WT_AA, MUT_MUT_AA)
    loader = DataLoader(dataset, shuffle=False, batch_size=batch_size, num_workers=8)

    preds = run_double(
        all_mpnn_hid, mpnn_embed, cfg, loader, batch_size, model, X, mask, mpnn_edges
    )
    ddg, mutations = format_output_epistatic(
        preds, S, MUT_POS, MUT_WT_AA, MUT_MUT_AA, threshold
    )

    etime = time.time()
    elapsed = etime - stime
    print(
        f"ThermoMPNN double mutant epistatic model predictions generated in {round(elapsed, 2)} seconds."
    )
    return ddg, mutations


def check_df_size(size):
    if size == 0:
        raise ValueError("No valid mutations passed your distance and ddG filters. Please increase one or both of these parameters and try again.")
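
The Additive model defined above (`expand_additive` and `format_output_double`) scores a double mutant as the sum of its two single-mutant predictions and skips pairs at the same position. The sketch below shows that idea in plain Python on a tiny dictionary. It omits the distance filter, and the helper name `additive_pairs` and all ddG values are hypothetical.

```python
from itertools import combinations


def additive_pairs(single_ddg, threshold):
    """Sum single-mutant ddGs for every pair at two different positions.

    single_ddg maps (position, mutant_aa) -> predicted single-mutant ddG.
    Only pairs whose summed ddG is at or below `threshold` are returned.
    """
    pairs = {}
    for (k1, d1), (k2, d2) in combinations(sorted(single_ddg.items()), 2):
        if k1[0] == k2[0]:
            continue  # two mutations at the same position are invalid
        total = d1 + d2
        if total <= threshold:
            pairs[(k1, k2)] = total
    return pairs


# Hypothetical single-mutant predictions (kcal/mol), for illustration only.
toy_single = {(0, "V"): -0.8, (0, "L"): -0.2, (3, "W"): -0.5, (7, "F"): 0.4}

for pair, total in additive_pairs(toy_single, threshold=-1.0).items():
    print(pair, round(total, 2))
```

Because the sum ignores coupling between the two sites, this is the naive baseline that the Epistatic model is meant to improve on.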





In [40]:
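import os
# Synchronous kernel launches give more accurate stack traces for CUDA errors.
# NOTE: this only takes effect if it is set before CUDA is initialized, so after
# changing it, restart the runtime and run this line before any CUDA work.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'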

def run_single_ssm_synchronized(pdb_data, cfg, model, chains=['A', 'B', 'C']):
    """Run single SSM with synchronized mutations across chains"""
    # Set model to evaluation mode and move to appropriate device
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f"Model moved to {device}")

    stime = time.time()

    # Placeholder mutation to prevent featurization errors
    # (list-valued fields, matching the placeholder used in run_single_ssm)
    pdb_data["mutation"] = Mutation([0], ["A"], ["A"], [0.0], pdb_data["name"])

    # Featurize input
    batch = tied_featurize_mut([pdb_data])
    if batch is None:
        raise ValueError("Batch featurization failed")

    (X, S, mask, lengths, chain_M, chain_encoding_all, residue_idx,
     mut_positions, mut_wildtype_AAs, mut_mutant_AAs, mut_ddGs, atom_mask) = batch

    # Debugging outputs
    print(f"X shape: {X.shape}")
    print(f"S shape: {S.shape}")
    print(f"mask shape: {mask.shape}")

    try:
        # Move tensors to device
        X = X.to(device)
        S = S.to(device)
        mask = mask.to(device)
        lengths = torch.Tensor(lengths).to(device)
        chain_M = chain_M.to(device)
        chain_encoding_all = chain_encoding_all.to(device)
        residue_idx = residue_idx.to(device)
        mut_ddGs = mut_ddGs.to(device)

        # Forward pass through prot_mpnn
        X = torch.nan_to_num(X, nan=0.0)
        with torch.no_grad():
            print("Starting prot_mpnn forward pass")
            all_mpnn_hid, mpnn_embed, _, mpnn_edges = model.prot_mpnn(
                X, S, mask, chain_M, residue_idx, chain_encoding_all
            )
            print("prot_mpnn forward pass complete")

            # Concatenate hidden layers
            all_mpnn_hid = torch.cat(all_mpnn_hid[: cfg.model.num_final_layers], dim=-1)
            all_mpnn_hid = torch.cat([all_mpnn_hid, mpnn_embed], dim=-1).squeeze(0)
            all_mpnn_hid = model.light_attention(all_mpnn_hid.unsqueeze(-1))
            ddg = model.ddg_out(all_mpnn_hid)

        # Normalize predictions
        S = torch.squeeze(S)
        wt_ddg = batched_index_select(ddg, dim=-1, index=S)
        ddg = ddg - wt_ddg.expand(-1, 21)

        # Determine chain length and number of chains
        chain_length = len(pdb_data[f'seq_chain_{chains[0]}'])
        n_chains = len(chains)

        print(f"ddg shape: {ddg.shape}")
        print(f"chain_length: {chain_length}, n_chains: {n_chains}")

        # Reshape ddg to [n_chains, chain_length, 21]
        ddg = ddg.view(n_chains, chain_length, 21)

        # Sum ddG across chains to represent the total stability change
        ddg_total = ddg.sum(dim=0)  # [chain_length, 21]

        # Create mutation list with separate Mutation instances per chain
        mutations_list = []
        ALPHABET = "ACDEFGHIKLMNPQRSTVWY-"

        for pos in range(chain_length):
            wt_aa = pdb_data[f'seq_chain_{chains[0]}'][pos]
            if wt_aa == '-':
                continue

            for mut_idx, mut_aa in enumerate(ALPHABET):
                if mut_aa != wt_aa and mut_aa != '-':
                    for chain in chains:
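                        # position/wildtype/mutation are wrapped in lists so that
                        # format_synchronized_output can zip over them
                        mut = Mutation(
                            position=[pos],
                            wildtype=[wt_aa],
                            mutation=[mut_aa],
                            ddG=ddg_total[pos, mut_idx].item(),
                            pdb=pdb_data['name']
                        )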
                        mutations_list.append(mut)

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        raise

    etime = time.time()
    print(f"Predictions generated in {round(etime-stime, 2)} seconds.")

    return ddg_total, mutations_list


# Format output
def format_synchronized_output(ddg, mutations, threshold=-1.0):
    """Format the synchronized mutations output"""
    formatted_muts = []
    for mut in mutations:
        if mut.ddG < threshold:
            mut_str = ':'.join([f"{w}{p+1}{m}" for w, p, m in
                              zip(mut.wildtype, mut.position, mut.mutation)])
            formatted_muts.append({
                'ddG (kcal/mol)': mut.ddG,
                'Mutation': mut_str
            })

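    # name the columns explicitly so sorting also works when no mutation passes the threshold
    df = pd.DataFrame(formatted_muts, columns=['ddG (kcal/mol)', 'Mutation'])
    return df.sort_values('ddG (kcal/mol)')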

Error in main execution: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.





In [41]:
#@title # **Run SSM Inference**

import pandas as pd
import numpy as np

from thermompnn.ssm_utils import (
    distance_filter,
    disulfide_penalty,
    get_config,
    get_dmat,
    get_model,
    load_pdb,
    renumber_pdb,
)
# from v2_ssm import (
#     run_single_ssm,
#     run_epistatic_ssm,
#     format_output_single,
#     format_output_double,
#     check_df_size,
#     run_single_ssm_synchronized
# )

# ------------ MAIN INFERENCE ROUTINE -------------- #

mode = Model.lower()
pdb = pdb_file
chains = chain_list
threshold = Threshold
distance = Distance
batch_size = BatchSize
ss_penalty = Penalize

cfg = get_config(mode)
cfg.platform.thermompnn_dir = '/content/ThermoMPNN-D'
model = get_model(mode, cfg)
pdb_data = load_pdb(pdb, chains)
pdbname = os.path.basename(pdb)
print(f"Loaded PDB {pdbname}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [37]:
pdb_data.keys()

dict_keys(['resn_list_A', 'seq_chain_A', 'coords_chain_A', 'resn_list_B', 'seq_chain_B', 'coords_chain_B', 'resn_list_C', 'seq_chain_C', 'coords_chain_C', 'name', 'num_of_chains', 'seq', 'mutation'])

In [31]:
pdb_data['seq_chain_C']

'AYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIH----------DNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGV---------------------NCTFEYVS-------------FKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSALEPLVDLPIGINITRFQTLLALH-----------------AAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSK--GNYNYLYR-------KPFERDI--------------------YFPLQSYGFQPTN-VGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQDVNCTEV--------------------NVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYQT------------SQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVT--------------------------KFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDPPEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRASANLA
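
The chain C sequence above contains runs of `-`. A likely reading is that these mark residues missing from the structure, though the notebook does not say so. The sketch below locates such runs so they can be inspected before inference. `find_gap_runs` is a hypothetical helper, not part of ThermoMPNN-D, and the short example sequence is made up for illustration.

```python
import re

def find_gap_runs(seq, gap_char="-"):
    """Return (start, length) for each run of gap_char in seq (0-based start)."""
    pattern = re.escape(gap_char) + "+"
    return [(m.start(), m.end() - m.start()) for m in re.finditer(pattern, seq)]

# Illustrative sequence only; in the notebook you would pass pdb_data['seq_chain_C'].
example = "AYTN----DNPV--GV"
print(find_gap_runs(example))  # [(4, 4), (12, 2)]
```

Comparing the reported runs across chains A, B, and C is one way to check whether the chains are aligned position by position, which a synchronized multi-chain scan may assume.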

In [32]:
ddgs, mutations = run_single_ssm_synchronized(pdb_data, cfg, model, chains=['A', 'B', 'C'])
df = format_synchronized_output(ddgs, mutations, threshold=-1.0)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
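
The same `RuntimeError` recurs across the cells above. A device-side assert leaves the CUDA context unusable, so later CUDA calls in the same session keep failing until the runtime is restarted (`Runtime` -> `Restart Runtime`). The error text suggests `CUDA_LAUNCH_BLOCKING=1`; a minimal sketch of setting it from Python follows. It assumes the cell runs first in a freshly restarted runtime, before any GPU work, because the variable is read when CUDA initializes.

```python
import os

# Make CUDA kernel launches synchronous so the traceback points at the failing call.
# Assumption: this runs in a fresh runtime, before torch touches the GPU.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print(os.environ["CUDA_LAUNCH_BLOCKING"])  # 1
```

A common cause of device-side asserts is an out-of-range index passed to an embedding or gather operation. Whether that applies here is not established by the output above. Running the failing step on CPU usually turns such a failure into an ordinary Python exception with a readable traceback.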


In [None]:


if (mode == "single") or (mode == "additive"):
  ddg, S = run_single_ssm(pdb_data, cfg, model)

  if mode == "single":
    ddg, mutations = format_output_single(ddg, S, threshold)
  else:
    ddg, mutations = format_output_double(
      ddg, S, threshold, pdb_data, distance
    )

elif mode == "epistatic":
  ddg, mutations = run_epistatic_ssm(
    pdb_data, cfg, model, distance, threshold, batch_size
  )

else:
  raise ValueError("Invalid mode selected!")

df = pd.DataFrame({"ddG (kcal/mol)": ddg, "Mutation": mutations})

check_df_size(df.shape[0])

if mode != "single":
  df = distance_filter(df, pdb_data, distance)

if ss_penalty:
  df = disulfide_penalty(df, pdb_data, mode)

if not Include:
  df = drop_cysteines(df, mode)

df = df.dropna(subset=["ddG (kcal/mol)"])
if threshold <= 0.0:
  df = df.sort_values(by=["ddG (kcal/mol)"])

if mode != "single":  # sort to have neat output order
  df[["mut1", "mut2"]] = df["Mutation"].str.split(":", n=2, expand=True)
  df["pos1"] = df["mut1"].str[1:-1].astype(int) + 1
  df["pos2"] = df["mut2"].str[1:-1].astype(int) + 1

  df = df.sort_values(by=["pos1", "pos2"])
  df = df[["ddG (kcal/mol)", "Mutation", "CA-CA Distance"]].reset_index(drop=True)

check_df_size(df.shape[0])

try:
  df = renumber_pdb(df, pdb_data, mode)

except (KeyError, IndexError):
  print(
    "PDB renumbering failed. You can still use the raw position data, or you can renumber your PDB, fill any gaps, and try again."
  )




Loading model %s /content/ThermoMPNN-D/vanilla_model_weights/v_48_020.pt
setting ProteinMPNN dropout: 0.0
MLP HIDDEN SIZES: [384, 64, 32, 21]
Loaded PDB 1vii.pdb


  checkpoint = torch.load(checkpoint_path, map_location='cpu')


ThermoMPNN single mutant predictions generated for protein of length 36 in 0.92 seconds.


2it [00:00, 4038.81it/s]

ThermoMPNN predictions renumbered.





In [None]:
#@title **Visualize data in an interactive table**
from google.colab import data_table

data_table.enable_dataframe_formatter()
data_table.DataTable(df, include_index=True, num_rows_per_page=10)

Unnamed: 0,ddG (kcal/mol),Mutation
0,-1.531715,KA70W
1,-1.311499,KA70Y


In [None]:
#@title # **Save Output as CSV**

# ---------- Collect output into DF and save as CSV ---------- #
from google.colab import files

#@markdown Specify prefix for file saving (e.g., MyProtein). Leave blank to use input PDB code.
PREFIX = "MyProtein" #@param {type:"string"}

#@markdown If you wish to retrieve your files manually, you may do so in the **Files** tab in the leftmost toolbar.

#@markdown NOTE: Make sure you click "Allow" if your browser asks to permit downloads at this step.

#@markdown -------------

#@markdown Save verbose output? (Recommended: True)
VERBOSE = True #@param {type: "boolean"}
#@markdown If enabled, more detailed mutation information will be saved.

df['ddG (kcal/mol)'] = df['ddG (kcal/mol)'].round(4)

if len(PREFIX) < 1:
  PREFIX = os.path.splitext(pdb_file)[0]  # strip only the final extension; dots elsewhere in the path are kept
else:
  PREFIX = os.path.join('/content/', PREFIX)

full_fname = PREFIX + '.csv'

if VERBOSE:
  if Model.lower() == 'single':  # case-insensitive, matching `mode = Model.lower()` in the inference cell
    df['Wildtype AA'] = df['Mutation'].str[0]
    df['Mutant AA'] = df['Mutation'].str[-1]
    df['Position'] = df['Mutation'].str[2:-1]
    df['Chain'] = df['Mutation'].str[1]

  else:
    df[['Mutation 1', 'Mutation 2']] = df['Mutation'].str.split(':', n=2, expand=True)
    df['Wildtype AA 1'], df['Wildtype AA 2'] = df['Mutation 1'].str[0], df['Mutation 2'].str[0]
    df['Mutant AA 1'], df['Mutant AA 2'] = df['Mutation 1'].str[-1], df['Mutation 2'].str[-1]
    df['Position 1'], df['Position 2'] = df['Mutation 1'].str[2:-1], df['Mutation 2'].str[2:-1]
    df['Chain 1'], df['Chain 2'] = df['Mutation 1'].str[1], df['Mutation 2'].str[1]

df.to_csv(full_fname, index=True)
files.download(full_fname)
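
The verbose columns above are derived by fixed-position slicing, which assumes a one-character chain ID at index 1 of every mutation string (as in `KA70W` from the table output). A regular expression makes that assumption explicit and fails loudly on unexpected input. `parse_mutation` below is a hypothetical helper for illustration, not part of ThermoMPNN-D.

```python
import re

# wildtype AA, one-character chain ID, integer position, mutant AA (e.g. "KA70W")
MUTATION_RE = re.compile(r"^([A-Z])([A-Za-z0-9])(-?\d+)([A-Z])$")

def parse_mutation(label):
    """Split a label such as 'KA70W' into (wildtype, chain, position, mutant)."""
    m = MUTATION_RE.match(label)
    if m is None:
        raise ValueError(f"Unrecognized mutation label: {label!r}")
    wt, chain, pos, mut = m.groups()
    return wt, chain, int(pos), mut

print(parse_mutation("KA70W"))  # ('K', 'A', 70, 'W')
```

For double mutants, the same helper can be applied to each half of the label after splitting on `:`.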



# APPENDIX

## License

The source code for ThermoMPNN-D, including license information, can be found [here](https://github.com/Kuhlman-Lab/ThermoMPNN-D).

---

## Citation Information

If you use ThermoMPNN or ThermoMPNN-D in your research, please cite the following paper(s):

### Epistatic or Additive model:
Dieckhaus, H., Kuhlman, B. *Protein stability models fail to capture epistatic interactions of double point mutations.* Protein Sci **2025**, 34(1): e70003, doi: https://doi.org/10.1002/pro.70003.

### Single mutant model:
Dieckhaus, H., Brocidiacono, M., Randolph, N., Kuhlman, B. *Transfer learning to leverage larger datasets for improved prediction of protein stability changes.* Proc Natl Acad Sci **2024**, 121(6): e2314853121, doi: https://doi.org/10.1073/pnas.2314853121.

---

## Contact Information

Please contact Henry Dieckhaus at dieckhau@unc.edu to report any bugs or issues with this notebook. You may also submit issues on the ThermoMPNN-D GitHub page [here](https://github.com/Kuhlman-Lab/ThermoMPNN-D/issues).
