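# Reproducing the DDGemb method

Predict the stability change (ΔΔG) of point mutations in FireProtDB from the difference between ESM2 embeddings of the wild-type and mutant sequences, using the architecture implemented below (Conv1D → multi-head attention → feed-forward → average/max pooling → regressor).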

In [66]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, EsmModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import tqdm

In [2]:
# Load the raw FireProtDB export.
# NOTE(review): this read emits a pandas warning (its first line is cut off in the
# saved output); if it is a DtypeWarning, consider low_memory=False or explicit dtypes.
path = 'data/'
df_fireprot = pd.read_csv(f"{path}fireprotdb_results.csv")


  df_fireprot = pd.read_csv(path + 'fireprotdb_results.csv')


In [4]:
# Preview the first five rows of the raw table.
df_fireprot.iloc[:5]

Unnamed: 0,experiment_id,protein_name,uniprot_id,pdb_id,chain,position,wild_type,mutation,ddG,dTm,...,technique,technique_details,pH,tm,notes,publication_doi,publication_pubmed,hsw_job_id,datasets,sequence
0,LL000001,Haloalkane dehalogenase,P59336,1CQW,A,245,V,L,,2.1,...,,,,52.5,,,,xfyu58,,MSEIGTGFPFDPHYVEVLGERMHYVDVGPRDGTPVLFLHGNPTSSY...
1,LL000002,Haloalkane dehalogenase,P59336,1CQW,A,95,L,V,,-0.4,...,,,,50.0,,,,xfyu58,,MSEIGTGFPFDPHYVEVLGERMHYVDVGPRDGTPVLFLHGNPTSSY...
2,LL000004,Haloalkane dehalogenase,P59336,1CQW,A,176,C,F,,5.2,...,,,,55.6,,,,xfyu58,,MSEIGTGFPFDPHYVEVLGERMHYVDVGPRDGTPVLFLHGNPTSSY...
3,LL000005,Haloalkane dehalogenase,P59336,1CQW,A,171,G,Q,,3.1,...,,,,53.5,,,,xfyu58,,MSEIGTGFPFDPHYVEVLGERMHYVDVGPRDGTPVLFLHGNPTSSY...
4,LL000006,Haloalkane dehalogenase,P59336,1CQW,A,148,T,L,,1.1,...,,,,51.5,,,,xfyu58,,MSEIGTGFPFDPHYVEVLGERMHYVDVGPRDGTPVLFLHGNPTSSY...


In [44]:
# Rows per protein, most frequent first, as a two-column frame.
protein_counts_1 = (
    df_fireprot["protein_name"]
    .value_counts()
    .reset_index()
)
protein_counts_1.head()

Unnamed: 0,protein_name,count
0,Subtilisin-chymotrypsin inhibitor-2A,11160
1,Immunoglobulin G-binding protein G,2158
2,Tryptophan synthase alpha chain,1915
3,Thermonuclease,1857
4,10 kDa chaperonin,1764


In [14]:
# Keep only records that report a ddG value (the regression target), plus a
# slimmed-down view with the columns of interest.
columns_to_keep = ['experiment_id', 'protein_name', 'uniprot_id', 'pdb_id', 'chain',
       'position', 'wild_type', 'mutation', 'ddG', 'sequence']
# `.notna()` states the intent directly (unary `-` on a boolean mask is an unusual
# spelling of `~`). `.copy()` makes the filtered frame independent of the raw one,
# so the later column assignments on df_fireprot do not raise SettingWithCopyWarning.
df_fireprot = df_fireprot.loc[df_fireprot['ddG'].notna()].copy()
df_subset = df_fireprot[columns_to_keep]

In [18]:
# Distinct PDB ids among rows whose protein_name is missing.
df_subset.loc[df_subset["protein_name"].isna(), "pdb_id"].unique()

array(['2IMM', '1YYX'], dtype=object)

In [16]:
# Number of records left after dropping those without a ddG value.
print(f"The dataset contains {len(df_subset)} rows")

The dataset contains 39177 rows


In [46]:
# Create sequence_wildtype / sequence_mutant columns: the mutant is the wild-type
# sequence with the residue at the 1-based `position` replaced by `mutation`.
def _mutate_sequence(sequence, position, mutation):
    """Return `sequence` with the residue at 1-based `position` replaced by `mutation`.

    Raises:
        ValueError: if `position` lies outside the sequence. Position 0 or a negative
            position would otherwise silently wrap around to the C-terminus via
            negative indexing.
    """
    idx = int(position) - 1
    if not 0 <= idx < len(sequence):
        raise ValueError(
            f"position {position} is outside a sequence of length {len(sequence)}"
        )
    return sequence[:idx] + mutation + sequence[idx + 1:]


# Reassign instead of renaming in place so the cell is safe to re-run.
df_fireprot = df_fireprot.rename(columns={"sequence": "sequence_wildtype"})
df_fireprot["sequence_mutant"] = [
    _mutate_sequence(seq, pos, mut)
    for seq, pos, mut in zip(
        df_fireprot["sequence_wildtype"], df_fireprot["position"], df_fireprot["mutation"]
    )
]

# Sanity check: the listed wild-type residue should sit at `position` in the stored
# sequence. A mismatch means the generated mutant sequence is wrong for that row
# (e.g. if positions follow a different numbering than the sequence; unverified here).
df_fireprot["wt_residue_matches"] = [
    seq[int(pos) - 1] == wt
    for seq, pos, wt in zip(
        df_fireprot["sequence_wildtype"], df_fireprot["position"], df_fireprot["wild_type"]
    )
]
n_wt_mismatch = int((~df_fireprot["wt_residue_matches"]).sum())
print(f"{n_wt_mismatch} of {len(df_fireprot)} rows: wild_type does not match the residue at position")

In [54]:
# Features are the (wild-type, mutant) sequence pairs; the target is ddG.
# 80/10/10 split: hold out 20%, then halve the hold-out into test and validation.
# NOTE(review): this is a random row-level split. Many rows share a protein (see the
# per-protein counts above), so the same protein can appear in both train and test.
# A protein- or sequence-identity-grouped split would give a less optimistic test estimate.
X = df_fireprot.loc[:, ["sequence_wildtype", "sequence_mutant"]]
y = df_fireprot.loc[:, "ddG"]
x_train, x_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.2, random_state=42
)
x_test, x_val, y_test, y_val = train_test_split(
    x_holdout, y_holdout, test_size=0.5, random_state=42
)

In [None]:
class DDGPredictor(nn.Module):
    """Predict ddG from the difference of ESM2 embeddings of wild-type and mutant sequences.

    Forward pipeline:
        1. Frozen ESM per-residue embeddings of both sequences, then their difference.
        2. Conv1d, then multi-head self-attention (+ residual).
        3. Position-wise FFN (+ residual).
        4. Global average and max pooling over residue positions.
        5. MLP regressor, giving one scalar per sequence pair.

    Padding and special tokens are masked out of the attention keys and of both
    pooling operations. A prediction therefore does not depend on which other
    (longer) sequences share its batch.

    Args:
        esm_model_name: Hugging Face checkpoint loaded when ``esm_model`` /
            ``tokenizer`` are not supplied.
        embedding_dim: Hidden size of the encoder output. It must match the encoder;
            the defaults pair 320 with esm2_t6_8M_UR50D.
        conv_channels: Conv1d output channels; also the attention and FFN width.
        heads: Number of attention heads. ``nn.MultiheadAttention`` requires it to
            divide ``conv_channels``.
        ffn_dim: Hidden width of the position-wise feed-forward network.
        esm_model: Optional pre-built encoder. It is called as ``esm_model(**tokens)``
            and must return an object with ``last_hidden_state``.
        tokenizer: Optional pre-built tokenizer paired with ``esm_model``. It must
            return ``input_ids`` and ``attention_mask`` tensors.
    """

    def __init__(self,
                 esm_model_name="facebook/esm2_t6_8M_UR50D",
                 embedding_dim=320,
                 conv_channels=128,
                 heads=4,
                 ffn_dim=256,
                 esm_model=None,
                 tokenizer=None):
        super().__init__()

        # Pretrained ESM2 encoder and tokenizer (loaded from the checkpoint unless injected).
        self.esm_model = (esm_model if esm_model is not None
                          else EsmModel.from_pretrained(esm_model_name))
        self.tokenizer = (tokenizer if tokenizer is not None
                          else AutoTokenizer.from_pretrained(esm_model_name))

        # The encoder is used as a frozen feature extractor.
        for param in self.esm_model.parameters():
            param.requires_grad = False

        # 1D convolution over the residue axis (zero padding keeps the length).
        self.conv1d = nn.Conv1d(
            in_channels=embedding_dim,
            out_channels=conv_channels,
            kernel_size=3,
            padding=1,
        )

        # Multi-head self-attention over residues.
        self.attention = nn.MultiheadAttention(
            embed_dim=conv_channels,
            num_heads=heads,
            batch_first=True,
        )

        # Position-wise FFN.
        self.ffn = nn.Sequential(
            nn.Linear(conv_channels, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, conv_channels),
        )

        # Unmasked pooling layers, kept so existing attribute access still works;
        # forward() pools with the residue mask instead.
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # Dense layers regressing ddG from the concatenated [avg ; max] features.
        self.regressor = nn.Sequential(
            nn.Linear(conv_channels * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def _embed(self, seqs):
        """Tokenize ``seqs``, run the frozen encoder and build a residue mask.

        Returns:
            ``(embeddings, residue_mask)``:
                - embeddings: shape (batch, T - 2, embedding_dim);
                - residue_mask: bool, shape (batch, T - 2), True only at residue tokens.
            T is the padded token length.

        Assumes the tokenizer adds exactly one special token before and after each
        sequence; the ``1:-1`` column slice relies on this layout. For sequences
        shorter than the longest in the batch, the trailing token sits inside the
        sliced window. The mask, not the slice, is what excludes it and the padding.
        NOTE(review): with ``truncation=True``, residues beyond the tokenizer's maximum
        length are dropped, so a mutation there yields no embedding difference.
        """
        # Inputs must live on the same device as the model's weights.
        device = self.conv1d.weight.device
        tokens = self.tokenizer(seqs, return_tensors='pt', padding=True, truncation=True)
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

        with torch.no_grad():
            hidden = self.esm_model(**tokens).last_hidden_state

        residue_mask = tokens["attention_mask"].bool().clone()
        rows = torch.arange(residue_mask.size(0), device=device)
        last_real = residue_mask.sum(dim=1) - 1  # index of each row's last non-pad token
        residue_mask[:, 0] = False               # leading special token
        residue_mask[rows, last_real] = False    # trailing special token
        return hidden[:, 1:-1, :], residue_mask[:, 1:-1]

    def get_embeddings(self, seqs):
        """Return per-residue encoder embeddings, shape (batch, T - 2, embedding_dim).

        The first and last token columns are dropped. In a padded batch, shorter
        sequences still contain their trailing special token and their padding inside
        this window. Use ``_embed`` to also get the mask of true residue positions.
        """
        return self._embed(seqs)[0]

    def forward(self, wt_seqs, mut_seqs):
        """Return predicted ddG, shape (batch,), for paired wild-type/mutant sequences.

        ``wt_seqs`` and ``mut_seqs`` are equal-length lists of strings. Each pair must
        tokenize to the same length (true for point substitutions) so the embeddings
        can be subtracted position by position. Every sequence is assumed non-empty.
        """
        emb_wt, mask_wt = self._embed(wt_seqs)
        emb_mut, mask_mut = self._embed(mut_seqs)
        mask = mask_wt & mask_mut                     # (B, L), True at residues
        mask_f = mask.unsqueeze(-1).to(emb_wt.dtype)  # (B, L, 1)

        # Embedding difference with non-residue positions zeroed, so that padding
        # acts exactly like the convolution's own zero padding.
        d = (emb_wt - emb_mut) * mask_f  # (B, L, E)

        # Conv1d expects (B, C, L).
        c = self.conv1d(d.transpose(1, 2)).transpose(1, 2)  # (B, L, conv_channels)

        # Self-attention that ignores non-residue keys, with a residual connection.
        m, _ = self.attention(c, c, c, key_padding_mask=~mask)
        z = c + m

        # Position-wise feed-forward with a residual connection.
        f = z + self.ffn(z)

        # Masked global average and max pooling over residue positions only.
        mask_f = mask_f.to(f.dtype)
        gp = (f * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0)         # (B, C)
        gm = f.masked_fill(~mask.unsqueeze(-1), float("-inf")).max(dim=1).values  # (B, C)

        conc = torch.cat([gp, gm], dim=1)  # (B, 2C)

        # Final regression.
        return self.regressor(conc).squeeze(-1)

In [42]:
# Build the predictor on the default ESM2 checkpoint (weights loaded via from_pretrained).
model = DDGPredictor(esm_model_name="facebook/esm2_t6_8M_UR50D")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [58]:
# Map-style Dataset yielding (wild-type sequence, mutant sequence, ddG) triples.
class FireProtDataset(Dataset):
    """Index-addressable view over paired sequences and their ddG targets.

    Args:
        sequences: DataFrame with ``sequence_wildtype`` and ``sequence_mutant`` columns.
        targets: Series of ddG values, paired with ``sequences`` by row position.

    Both inputs are re-indexed 0..n-1, so ``dataset[i]`` is the i-th row regardless
    of the original index labels.
    """

    def __init__(self, sequences, targets):
        self.sequences = sequences.reset_index(drop=True)
        self.targets = targets.reset_index(drop=True)

    def __len__(self):
        return self.sequences.shape[0]

    def __getitem__(self, idx):
        """Return ``(wt, mt, ddg)``, where ``ddg`` is a float32 scalar tensor."""
        row = self.sequences.loc[idx]
        target = torch.tensor(self.targets.loc[idx], dtype=torch.float32)
        return row["sequence_wildtype"], row["sequence_mutant"], target


In [59]:
# Wrap each split in a Dataset. Only the training loader shuffles; validation is
# iterated in a fixed order. The test split has no loader yet.
train_dataset = FireProtDataset(x_train, y_train)
val_dataset   = FireProtDataset(x_val, y_val)
test_dataset  = FireProtDataset(x_test, y_test)

# Seed the shuffling generator so batch order is reproducible across runs,
# consistent with the fixed random_state used for the train/val/test split.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [64]:
def _run_epoch(model, loader, loss_fn, device, desc, optimizer=None):
    """Run one pass over `loader`; update weights only when `optimizer` is given.

    Returns the sample-weighted mean loss over the pass (NaN for an empty loader).
    """
    total_loss, n_samples = 0.0, 0
    for wt_seq, mt_seq, ddg in tqdm.tqdm(loader, desc=desc):
        ddg = ddg.to(device)

        # The model takes the raw string batches and tokenizes them itself.
        # Reshape to the target's shape so a batch of one is not collapsed to a
        # scalar, which squeeze() would do.
        pred = model(wt_seq, mt_seq).reshape(ddg.shape)
        loss = loss_fn(pred, ddg)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        batch_size = ddg.shape[0]
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else float("nan")


def train_ddg_model(model, train_loader, val_loader, num_epochs=10, lr=1e-4):
    """Train `model` with Adam on MSE loss, reporting train/validation loss per epoch.

    Args:
        model: Module called as ``model(wt_seqs, mut_seqs)`` on the raw string batches;
            it returns one prediction per pair.
        train_loader: Iterable of ``(wt_seqs, mut_seqs, ddg)`` training batches.
        val_loader: Iterable of ``(wt_seqs, mut_seqs, ddg)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        dict with ``"train_loss"`` and ``"val_loss"`` lists, holding one
        sample-weighted mean MSE per epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Parameters with requires_grad=False (e.g. a frozen encoder) never get gradients,
    # so only the trainable ones are handed to the optimizer.
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    loss_fn = nn.MSELoss()
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        avg_train_loss = _run_epoch(
            model, train_loader, loss_fn, device, f"Epoch {epoch+1} [Train]", optimizer
        )
        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        with torch.no_grad():
            avg_val_loss = _run_epoch(
                model, val_loader, loss_fn, device, f"Epoch {epoch+1} [Val]"
            )
        print(f"Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f}")

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)

    return history


In [None]:
# Train for 10 epochs with Adam (lr=1e-4), printing train/validation MSE each epoch.
train_ddg_model(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    lr=1e-4,
)