In [1]:
import pandas as pd

# Read the FireProtDB results CSV into a DataFrame (adjust the path for your setup)
df = pd.read_csv(filepath_or_buffer="data/fireprotdb_results.csv")

# Print the first five rows as a quick sanity check
print(df.head(n=5))

  experiment_id             protein_name uniprot_id pdb_id chain  position  \
0      LL000001  Haloalkane dehalogenase     P59336   1CQW     A       245   
1      LL000002  Haloalkane dehalogenase     P59336   1CQW     A        95   
2      LL000004  Haloalkane dehalogenase     P59336   1CQW     A       176   
3      LL000005  Haloalkane dehalogenase     P59336   1CQW     A       171   
4      LL000006  Haloalkane dehalogenase     P59336   1CQW     A       148   

  wild_type mutation  ddG  dTm  ...  technique  technique_details  pH    tm  \
0         V        L  NaN  2.1  ...        NaN                NaN NaN  52.5   
1         L        V  NaN -0.4  ...        NaN                NaN NaN  50.0   
2         C        F  NaN  5.2  ...        NaN                NaN NaN  55.6   
3         G        Q  NaN  3.1  ...        NaN                NaN NaN  53.5   
4         T        L  NaN  1.1  ...        NaN                NaN NaN  51.5   

   notes  publication_doi  publication_pubmed  hsw_job_i

  df = pd.read_csv("data/fireprotdb_results.csv")


In [2]:
# Retain only the rows that report a ddG value
df_filtered = df.loc[df['ddG'].notna()]

In [None]:
# Keep only unique mutation experiments


In [4]:
# Columns carried forward into the graph / embedding pipeline
columns_to_keep = [
    # --- identifiers ---
    'experiment_id',             # de-duplication key
    'protein_name',
    'uniprot_id',                # used to fetch the AlphaFold structure
    'pdb_id',
    'chain',
    # --- mutation description ---
    'position',
    'wild_type',
    'mutation',
    # --- target ---
    'ddG',
    # --- inputs for node features ---
    'sequence',                  # wild-type sequence; supplies per-residue amino acids
    'is_in_catalytic_pocket',    # binary flag
    'is_essential'               # binary flag
]

# Select those columns, then discard every row with a missing value in any of them
df_subset = df_filtered.loc[:, columns_to_keep].dropna()

In [5]:
# One row per experiment: keep the first occurrence of each experiment_id
df_subset = df_subset.drop_duplicates(subset='experiment_id', keep='first')

In [6]:
# Display the (rows, columns) dimensions of the cleaned subset
df_subset.shape

(12090, 12)

In [7]:
import os
import torch
import pandas as pd
import requests
from Bio.PDB import PDBParser
from torch_geometric.data import Data
from scipy.spatial.distance import euclidean
import networkx as nx

# ============ Helper: Fetch AlphaFold Structure ============
def fetch_alphafold_pdb(uniprot_id, save_dir="wt_pdbs", timeout=30):
    """Return the local path of the AlphaFold model for ``uniprot_id``.

    The file is cached as ``<save_dir>/<uniprot_id>.pdb``; when that file
    already exists no network request is made.

    Args:
        uniprot_id: UniProt accession used to build the download URL and
            the cache file name.
        save_dir: Directory holding cached structures (created if missing).
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot block the pipeline indefinitely.

    Returns:
        Path to the cached PDB file.

    Raises:
        ValueError: If the server answers with a status other than 200.
        requests.exceptions.RequestException: On network failures/timeouts
            (propagated unchanged from ``requests``).
    """
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{uniprot_id}.pdb")
    if os.path.exists(save_path):
        return save_path

    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.pdb"
    r = requests.get(url, timeout=timeout)
    if r.status_code != 200:
        raise ValueError(f"❌ Could not fetch AlphaFold PDB for {uniprot_id}")

    # Write to a temporary name and rename afterwards: an interrupted write
    # must not leave a truncated file that later runs would treat as a
    # valid cache hit.
    tmp_path = save_path + ".part"
    with open(tmp_path, 'wb') as f:
        f.write(r.content)
    os.replace(tmp_path, save_path)
    return save_path

# ============ Structure → Graph Construction ============
def extract_residue_coords(pdb_file, chain_id):
    """Map residue numbers to C-alpha coordinates for one chain.

    Only the first model of the structure is read. Residues without an
    atom named ``'CA'`` are skipped.

    Args:
        pdb_file: Path to the PDB file to parse.
        chain_id: Identifier of the chain to read from the first model.

    Returns:
        dict of ``{residue_number: coordinate}`` where ``residue_number`` is
        the second element of the residue id.
    """
    structure = PDBParser(QUIET=True).get_structure('protein', pdb_file)
    chain = structure[0][chain_id]
    # NOTE(review): only the numeric part of the residue id is used as key,
    # so residues that share a number would overwrite each other — confirm
    # this is acceptable for the structures being processed.
    return {
        residue.get_id()[1]: residue['CA'].get_coord()
        for residue in chain
        if 'CA' in residue
    }

def build_protein_graph(coords, distance_threshold=8.0):
    """Build a residue contact graph from a ``{residue_id: coordinate}`` dict.

    Every residue becomes a node carrying its coordinate as ``pos``. Each
    unordered pair whose Euclidean distance is at most
    ``distance_threshold`` is joined by an edge storing that distance in
    the ``distance`` attribute.

    Args:
        coords: Mapping of residue id to coordinate.
        distance_threshold: Largest distance (inclusive) that still
            produces an edge.

    Returns:
        The populated ``networkx.Graph``.
    """
    graph = nx.Graph()
    for res_id, xyz in coords.items():
        graph.add_node(res_id, pos=xyz)

    ordered_ids = list(coords)
    total = len(ordered_ids)
    # Visit every unordered pair exactly once (upper triangle).
    for a in range(total):
        first = ordered_ids[a]
        for b in range(a + 1, total):
            second = ordered_ids[b]
            separation = euclidean(coords[first], coords[second])
            if separation <= distance_threshold:
                graph.add_edge(first, second, distance=separation)
    return graph

def construct_graph_from_pdb(pdb_file, chain_id, distance_threshold=8.0):
    """Parse one chain of a PDB file and return its residue contact graph.

    Thin composition of ``extract_residue_coords`` followed by
    ``build_protein_graph``.
    """
    residue_coords = extract_residue_coords(pdb_file, chain_id)
    contact_graph = build_protein_graph(residue_coords, distance_threshold)
    return contact_graph

# ============ Feature Engineering ============
def one_hot_encode_aa(aa):
    """One-hot encode an amino-acid letter over the 20-letter alphabet.

    The alphabet order is ``ACDEFGHIKLMNPQRSTVWY``. Any value that is not
    one of those letters produces an all-zero vector.

    Returns:
        list of 20 ints (0 or 1).
    """
    alphabet = "ACDEFGHIKLMNPQRSTVWY"
    encoding = []
    for letter in alphabet:
        encoding.append(1 if aa == letter else 0)
    return encoding

def enrich_graph_with_features(graph, row, sequence_length):
    """Attach per-residue features to the nodes of a residue graph (in place).

    Node ids are aligned to ``row['sequence']`` by subtracting the smallest
    node id, so the lowest-numbered residue maps to sequence index 0. Nodes
    whose aligned index falls outside the sequence receive no features.

    Features written on each aligned node:
        aa_type      -- one-hot vector from ``one_hot_encode_aa``
        is_catalytic -- ``int(row['is_in_catalytic_pocket'])``
        is_essential -- ``int(row['is_essential'])``
        relative_pos -- 1-based aligned sequence index / ``sequence_length``

    Args:
        graph: Graph whose ``nodes`` are residue numbers.
        row: Mapping providing ``sequence``, ``is_in_catalytic_pocket`` and
            ``is_essential``.
        sequence_length: Denominator used for ``relative_pos``.

    Returns:
        The same ``graph`` object. An empty graph is returned unchanged.
    """
    sequence = row['sequence']
    catalytic = row['is_in_catalytic_pocket']
    essential = row['is_essential']

    # Nothing to annotate; avoids min() failing on an empty node set.
    if len(graph.nodes) == 0:
        return graph

    offset = min(graph.nodes)

    # NOTE(review): the two row-level flags are written onto every node, so
    # they act as graph-wide constants rather than marking a specific
    # residue — confirm whether they should apply only to row['position'].
    for pos in graph.nodes:
        seq_index = pos - offset
        if 0 <= seq_index < len(sequence):
            node = graph.nodes[pos]
            node['aa_type'] = one_hot_encode_aa(sequence[seq_index])
            node['is_catalytic'] = int(catalytic)
            node['is_essential'] = int(essential)
            # Use the aligned index (1-based) rather than the raw residue
            # number so the value stays in (0, 1] even when numbering does
            # not start at 1; identical to pos / sequence_length otherwise.
            node['relative_pos'] = (seq_index + 1) / sequence_length
    return graph

# ============ Graph → PyG ============
def convert_nx_to_pyg(graph, ddG_value=0.0):
    """Turn an annotated residue graph into a ``Data`` object.

    Only nodes carrying all of ``aa_type``, ``is_catalytic``,
    ``is_essential`` and ``relative_pos`` are kept; they are renumbered
    consecutively in iteration order. Each surviving edge is emitted in
    both directions with its ``distance`` as the single edge attribute.

    Args:
        graph: Graph with per-node feature attributes and per-edge
            ``distance``.
        ddG_value: Value stored as the one-element target tensor ``y``.

    Returns:
        ``Data(x, edge_index, edge_attr, y)``.

    Raises:
        ValueError: If no node has the complete feature set, or if no edge
            connects two kept nodes.
    """
    required_keys = ('aa_type', 'is_catalytic', 'is_essential', 'relative_pos')

    index_of = {}       # original node id -> consecutive row index
    feature_rows = []
    for node in graph.nodes:
        attrs = graph.nodes[node]
        if any(key not in attrs for key in required_keys):
            continue
        index_of[node] = len(feature_rows)
        scalars = [attrs['is_catalytic'], attrs['is_essential'], attrs['relative_pos']]
        feature_rows.append(attrs['aa_type'] + scalars)

    if not feature_rows:
        raise ValueError("❌ No usable nodes with complete features in WT graph.")

    pairs = []
    distances = []
    for src, dst, edge_data in graph.edges(data=True):
        if src not in index_of or dst not in index_of:
            continue
        i, j = index_of[src], index_of[dst]
        # Store each undirected edge as two directed edges.
        pairs.append([i, j])
        pairs.append([j, i])
        distances.append([edge_data['distance']])
        distances.append([edge_data['distance']])

    if not pairs:
        raise ValueError("❌ No valid edges in WT graph.")

    return Data(
        x=torch.tensor(feature_rows, dtype=torch.float),
        edge_index=torch.tensor(pairs, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(distances, dtype=torch.float),
        y=torch.tensor([ddG_value], dtype=torch.float),
    )

# ============ Row Processor ============
def process_row_to_wt_embedding(row, model):
    """Compute the wild-type graph embedding for one dataset row.

    Steps: fetch the AlphaFold structure for ``row['uniprot_id']``, build
    the chain ``'A'`` contact graph, attach node features from ``row``,
    convert to a ``Data`` object and run ``model`` in eval mode without
    gradients.

    Args:
        row: Mapping with at least ``uniprot_id`` and ``sequence`` (plus
            whatever ``enrich_graph_with_features`` reads).
        model: Callable taking ``(x, edge_index, batch)``.

    Returns:
        The squeezed CPU output of ``model``, or ``None`` when any step
        raises (the error is printed, not re-raised).
    """
    try:
        accession = row['uniprot_id']
        n_residues = len(row['sequence'])

        structure_path = fetch_alphafold_pdb(accession)
        wt_graph = construct_graph_from_pdb(structure_path, 'A')
        wt_graph = enrich_graph_with_features(wt_graph, row, n_residues)

        pyg_data = convert_nx_to_pyg(wt_graph)
        # Single graph: every node belongs to batch element 0.
        pyg_data.batch = torch.zeros(pyg_data.num_nodes, dtype=torch.long)

        model.eval()
        with torch.no_grad():
            pooled = model(pyg_data.x, pyg_data.edge_index, pyg_data.batch)
            return pooled.squeeze().cpu()

    except Exception as e:
        print(f"❌ WT row error: {e}")
        return None

# ============ Full Processor ============
def process_wildtype_embeddings(df, model, save_dir="outputs_wt"):
    """Embed every unique (uniprot_id, position, mutation) row of ``df``.

    Rows are de-duplicated on those three columns, each is passed to
    ``process_row_to_wt_embedding``, and results are keyed as
    ``"<uniprot_id>_<position>_<mutation>"``.

    Side effects:
        * writes ``<save_dir>/wt_embeddings.pt`` (dict of key -> array)
        * writes ``<save_dir>/wt_embedding_log.csv`` with one line per
          processed row and an ``embedding_generated`` flag
        * prints one progress line per row and a final summary

    Returns:
        dict mapping key -> embedding converted with ``.numpy()``; rows
        whose embedding failed are absent.
    """
    os.makedirs(save_dir, exist_ok=True)

    unique_rows = df.drop_duplicates(subset=['uniprot_id', 'position', 'mutation'])
    embeddings, logs = {}, []

    for idx, row in unique_rows.iterrows():
        print(f"\n🔄 WT row {idx} — UniProt: {row['uniprot_id']} Pos: {row['position']}")
        embedding = process_row_to_wt_embedding(row, model)
        key = f"{row['uniprot_id']}_{row['position']}_{row['mutation']}"

        succeeded = embedding is not None
        if succeeded:
            embeddings[key] = embedding.numpy()
        # One log record per row regardless of outcome.
        logs.append({**row.to_dict(), "key": key, "embedding_generated": succeeded})

    embeddings_path = os.path.join(save_dir, "wt_embeddings.pt")
    log_path = os.path.join(save_dir, "wt_embedding_log.csv")
    torch.save(embeddings, embeddings_path)
    pd.DataFrame(logs).to_csv(log_path, index=False)

    print(f"\n✅ Generated {len(embeddings)} WT embeddings.")
    return embeddings

In [12]:


import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class ProteinGNNEncoder(nn.Module):
    """Two-layer GCN encoder with mean-pool readout and an MLP head.

    Pipeline: ``GCNConv -> ReLU -> GCNConv -> ReLU -> global_mean_pool ->
    Linear -> ReLU -> Linear``.

    Args:
        in_dim: Size of each node feature vector.
        hidden_dim: Width of both graph-conv layers and the hidden MLP layer.
        out_dim: Size of the embedding produced per graph.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Message-passing layers
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Graph-level pooling function (kept as an attribute so it can be replaced)
        self.readout = global_mean_pool
        # Projection head applied to the pooled representation
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, batch):
        """Return one embedding per graph indicated by ``batch``."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        pooled = self.readout(hidden, batch)
        return self.mlp(pooled)

# Build the encoder. No trained weights are loaded here: the model is used
# with its random initialisation, so fix the RNG seed first to make the
# resulting embeddings reproducible across kernel restarts.
torch.manual_seed(42)

# in_dim=23 = 20 one-hot amino-acid entries + is_catalytic + is_essential + relative_pos
model = ProteinGNNEncoder(in_dim=23, hidden_dim=64, out_dim=128)

# Run embedding pipeline
wt_embeddings = process_wildtype_embeddings(df_subset, model)


🔄 WT row 5818 — UniProt: P06654 Pos: 1

🔄 WT row 5820 — UniProt: P06654 Pos: 1

🔄 WT row 5822 — UniProt: P06654 Pos: 1

🔄 WT row 5824 — UniProt: P06654 Pos: 1

🔄 WT row 5826 — UniProt: P06654 Pos: 1

🔄 WT row 5828 — UniProt: P06654 Pos: 1

🔄 WT row 5830 — UniProt: P06654 Pos: 1

🔄 WT row 5832 — UniProt: P06654 Pos: 1

🔄 WT row 5834 — UniProt: P06654 Pos: 1

🔄 WT row 5836 — UniProt: P06654 Pos: 228

🔄 WT row 5838 — UniProt: P06654 Pos: 228

🔄 WT row 5840 — UniProt: P06654 Pos: 228

🔄 WT row 5842 — UniProt: P06654 Pos: 228

🔄 WT row 5844 — UniProt: P06654 Pos: 228

🔄 WT row 5846 — UniProt: P06654 Pos: 228

🔄 WT row 5848 — UniProt: P06654 Pos: 228

🔄 WT row 5850 — UniProt: P06654 Pos: 228

🔄 WT row 5852 — UniProt: P06654 Pos: 228

🔄 WT row 5854 — UniProt: P06654 Pos: 228

🔄 WT row 5856 — UniProt: P06654 Pos: 228

🔄 WT row 5858 — UniProt: P06654 Pos: 228

🔄 WT row 5860 — UniProt: P06654 Pos: 228

🔄 WT row 5862 — UniProt: P06654 Pos: 228

🔄 WT row 5864 — UniProt: P06654 Pos: 229

🔄 WT row 

In [16]:
# NOTE(review): this stores the wild-type embeddings under
# outputs/embeddings.pt, the file a later cell loads as `df_embeddings` and
# subtracts the wild-type embeddings from — confirm this is not meant to
# hold mutant embeddings instead.
os.makedirs("outputs", exist_ok=True)
torch.save(obj=wt_embeddings, f=os.path.join("outputs", "embeddings.pt"))


In [17]:
import torch
import pandas as pd

# Read the saved {key: vector} dictionary back from disk.
# weights_only=False unpickles arbitrary objects — only use on files produced locally.
embedding_dict = torch.load("outputs/embeddings.pt", weights_only=False)

# One row per key, one "dim_<i>" column per embedding dimension
df_embeddings = (
    pd.DataFrame.from_dict(embedding_dict, orient='index')
    .rename(columns=lambda i: f"dim_{i}")
    .rename_axis(index="row_idx")
)

# Preview
print(f"✅ Loaded {len(df_embeddings)} embeddings")
display(df_embeddings.head())


✅ Loaded 4548 embeddings


Unnamed: 0_level_0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_118,dim_119,dim_120,dim_121,dim_122,dim_123,dim_124,dim_125,dim_126,dim_127
row_idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
P06654_1_A,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_D,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_E,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_F,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_G,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272


In [19]:
import torch
import pandas as pd

# Load the wild-type embedding dictionary written by process_wildtype_embeddings.
# weights_only=False unpickles arbitrary objects — only use on files produced locally.
embedding_dict = torch.load("outputs_wt/wt_embeddings.pt", weights_only=False)

# Convert to DataFrame: one row per key, one "dim_<i>" column per dimension
df_wt_embeddings = pd.DataFrame.from_dict(embedding_dict, orient='index')
# Name the columns from this frame's own width (previously taken from df_embeddings)
df_wt_embeddings.columns = [f"dim_{i}" for i in range(df_wt_embeddings.shape[1])]
df_wt_embeddings.index.name = "row_idx"

# Preview — report the size of the frame that was just loaded
print(f"✅ Loaded {len(df_wt_embeddings)} embeddings")
display(df_wt_embeddings.head())


✅ Loaded 4548 embeddings


Unnamed: 0_level_0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_118,dim_119,dim_120,dim_121,dim_122,dim_123,dim_124,dim_125,dim_126,dim_127
row_idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
P06654_1_A,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_D,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_E,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_F,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272
P06654_1_G,0.015131,-0.076039,0.0203,-0.127062,0.175872,-0.054542,0.07257,-0.129026,-0.005289,-0.026273,...,-0.067685,-0.044786,-0.073017,0.058983,0.003985,-0.056319,-0.148701,-0.067055,0.060501,-0.093272


In [21]:
import pandas as pd


# Element-wise difference (mutant - wild-type), aligned on index and columns;
# the index is then moved into a regular `row_idx` column.
# NOTE(review): in this notebook both operands are loaded from files written
# from the same `wt_embeddings` object, so the difference is all zeros unless
# outputs/embeddings.pt is replaced by real mutant embeddings — confirm.
df_delta = df_embeddings.sub(df_wt_embeddings).reset_index()

# Persist the deltas as CSV
df_delta.to_csv("delta_embeddings.csv", index=False)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/4548 [00:00<?, ?it/s]


KeyError: 'sequence'