In [35]:
import pandas as pd
df = pd.read_csv('../train.csv')

# get the sequence
seq = open('../sequence.fasta', 'r').read()
seq = seq.split("\n")[1]

# create each mutated sequence using the info
sequences = []
indices = []
for i in df['mutant']:
    ind = int(i[1:-1])
    tmp = seq[:ind] + i[-1] + seq[ind+1:]
    sequences.append(tmp)
    indices.append(ind)
df['Sequence'] = sequences
df['Position'] = indices

In [36]:
df

Unnamed: 0,mutant,DMS_score,Sequence,Position
0,M0Y,0.2730,YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,0
1,M0W,0.2857,WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,0
2,M0V,0.2153,VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,0
3,M0T,0.3122,TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,0
4,M0S,0.2180,SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,0
...,...,...,...,...
1135,P347D,0.3876,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,347
1136,P347C,0.1837,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,347
1137,P347A,0.4611,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,347
1138,P347M,0.2412,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,347


In [37]:
import torch
import esm

# Load ESM-2 model (esm2_t6_8M is small and fast, scale up later if needed)
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disable dropout for eval mode

# Example wild-type and mutant sequences

wild_type_sequence = seq
mutant_sequence = df.loc[1]['Sequence']

# Prepare data (must be a list of (name, sequence) tuples)
data = [("wt", wild_type_sequence), ("mut", mutant_sequence)]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

with torch.no_grad():
    results = model(batch_tokens, repr_layers=[6], return_contacts=False)

# Extract per-residue embeddings from the desired layer (layer 6 here)
# Shape: [batch_size, sequence_length, embedding_dim]
token_representations = results["representations"][6]

# Get the per-residue embeddings for WT and mutant (excluding padding and start/end tokens)
wt_embedding = token_representations[0, 1:-1]  # shape: [L, D]
mut_embedding = token_representations[1, 1:-1]  # shape: [L, D]

# Example: get embedding at mutation position (e.g., position 20 → index 19)
mutation_pos = 0
wt_residue_vec = wt_embedding[mutation_pos]
mut_residue_vec = mut_embedding[mutation_pos]

# Cosine similarity between WT and mutant residue embedding
cos_sim = torch.nn.functional.cosine_similarity(wt_residue_vec, mut_residue_vec, dim=0)
print(f"Cosine similarity at position {mutation_pos + 1}: {cos_sim.item():.4f}")


Cosine similarity at position 1: 0.8984


In [61]:
import torch
import esm
import pandas as pd
from tqdm import tqdm
import pyarrow.parquet as pq  # Make sure pyarrow is installed


def get_embed(df: pd.DataFrame, output_file: str, include_fitness=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load ESM-2 model
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model = model.to(device)

    batch_converter = alphabet.get_batch_converter()
    model.eval()

    # Initialize list to hold data
    data_rows = []

    # Loop through all rows in the mutation DataFrame
    for i, row in tqdm(df.iterrows(), total=len(df)):
        try:
            mutant_sequence = row["Sequence"]
            mutation_pos = int(row["Position"])  # 0-based index
            if include_fitness:
                fitness_score = float(row["DMS_score"])

            # Prepare batch for ESM
            batch_data = [("wt", seq), ("mut", mutant_sequence)]
            _, _, batch_tokens = batch_converter(batch_data)

            with torch.no_grad():
                results = model(batch_tokens.to(device), repr_layers=[6], return_contacts=False)

            reps = results["representations"][6]
            wt_embedding = reps[0, 1:-1]
            mut_embedding = reps[1, 1:-1]

            # Skip if mutation position is invalid
            if mutation_pos >= wt_embedding.shape[0]:
                continue

            # Extract residue embeddings
            wt_vec = wt_embedding[mutation_pos]
            mut_vec = mut_embedding[mutation_pos]

            # Combine features (concat + delta)
            feature_vec = torch.cat((wt_vec, mut_vec), dim=0).cpu().numpy()

            # Store in list
            if include_fitness:
                data_rows.append({
                    "features": feature_vec,
                    "fitness": fitness_score
                })
            else:
                data_rows.append({
                    "features": feature_vec
                })

        except Exception as e:
            print(f"Error on row {i}: {e}")
            continue

    # Convert to DataFrame
    features_df = pd.DataFrame(data_rows)
    # Expand the feature vectors into columns
    features_df = features_df.join(pd.DataFrame(features_df.pop("features").tolist()))

    # Save as parquet
    features_df.to_parquet(output_file, index=False)
    print(f"Saved to {output_file}.parquet")


In [41]:
wt_embedding[17:21].shape

torch.Size([4, 1280])

In [42]:
features_df

Unnamed: 0,fitness,0,1,2,3,4,5,6,7,8,...,2550,2551,2552,2553,2554,2555,2556,2557,2558,2559
0,0.2730,-0.533403,-0.399560,-1.236379,-2.145120,2.079403,0.564469,0.836047,-4.388066,1.294325,...,-1.000465,-0.634368,-0.443642,2.415246,0.283359,0.953022,-0.457921,0.114484,0.583813,1.564510
1,0.2857,-0.533403,-0.399560,-1.236379,-2.145120,2.079403,0.564469,0.836047,-4.388066,1.294325,...,-0.343154,-0.148721,-0.197519,2.299424,0.382421,1.193208,-0.570886,0.827278,-0.316071,1.634797
2,0.2153,-0.533403,-0.399560,-1.236379,-2.145120,2.079403,0.564469,0.836047,-4.388066,1.294325,...,-0.018773,-0.439191,-0.198621,0.772897,1.220283,1.361502,-0.292921,-0.066517,0.559170,1.120865
3,0.3122,-0.533403,-0.399560,-1.236379,-2.145120,2.079403,0.564469,0.836047,-4.388066,1.294325,...,-0.388565,-0.809664,0.128564,2.385023,1.454805,1.018796,-0.260638,0.508979,0.993335,1.467477
4,0.2180,-0.533403,-0.399560,-1.236379,-2.145120,2.079403,0.564469,0.836047,-4.388066,1.294325,...,-1.828906,-1.389767,0.198599,1.565156,1.401824,1.014213,-0.129383,0.548280,1.218197,0.916413
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1135,0.3876,-0.574614,-0.288829,-1.011021,-0.893898,0.292528,0.022261,0.611964,2.468987,0.829979,...,-2.543934,-0.258886,0.302302,1.456124,0.298548,1.670200,-0.619644,-0.793980,1.118043,-1.863922
1136,0.1837,-0.574614,-0.288829,-1.011021,-0.893898,0.292528,0.022261,0.611964,2.468987,0.829979,...,-1.837595,-0.043557,-0.421330,1.869205,-0.455002,1.328572,-0.906081,-0.118915,0.573976,-1.424718
1137,0.4611,-0.574614,-0.288829,-1.011021,-0.893898,0.292528,0.022261,0.611964,2.468987,0.829979,...,-0.898046,0.075524,-0.162744,1.624811,-0.042466,1.741800,-0.638047,-0.471511,0.359480,-2.026469
1138,0.2412,-0.574614,-0.288829,-1.011021,-0.893898,0.292528,0.022261,0.611964,2.468987,0.829979,...,-2.277321,-0.162311,0.051905,2.154739,-0.478884,1.328230,-0.774814,-0.145511,0.458649,-2.031186


In [53]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from tqdm import tqdm

# Load the dataset
df = pd.read_parquet("protein_mutation_fitness.parquet")

# Extract X (features) and y (fitness score)
X = df.drop(columns=["fitness"]).values.astype(np.float32)
y = df["fitness"].values.astype(np.float32).reshape(-1, 1)

# Normalize features
scaler_X = StandardScaler()
X_scaled = scaler_X.fit_transform(X)

scaler_y = StandardScaler()
y_scaled = scaler_y.fit_transform(y)

# Convert to tensors
X_tensor = torch.tensor(X_scaled)
y_tensor = torch.tensor(y_scaled)

# Create dataset
dataset = TensorDataset(X_tensor, y_tensor)

# Train/val/test split
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
test_loader = DataLoader(test_set, batch_size=64)

# Define MLP model
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.model(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP(input_dim=X.shape[1]).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * len(xb)

    val_loss = 0
    model.eval()
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            val_loss += criterion(pred, yb).item() * len(xb)

    print(f"Epoch {epoch+1:02d}: Train Loss = {total_loss / train_size:.4f}, Val Loss = {val_loss / val_size:.4f}")

# Evaluate on test set
model.eval()
y_preds = []
y_true = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        preds = model(xb).cpu()
        y_preds.append(preds)
        y_true.append(yb)

y_preds = torch.cat(y_preds).numpy()
y_true = torch.cat(y_true).numpy()

# Unscale predictions
y_preds = scaler_y.inverse_transform(y_preds)
y_true = scaler_y.inverse_transform(y_true)

mse = mean_squared_error(y_true, y_preds)
r2 = r2_score(y_true, y_preds)
print(f"\n🧪 Test MSE: {mse:.4f}")
print(f"📈 R² Score: {r2:.4f}")

Epoch 01: Train Loss = 1.0030, Val Loss = 0.8706
Epoch 02: Train Loss = 0.9177, Val Loss = 0.8822
Epoch 03: Train Loss = 0.8542, Val Loss = 0.8465
Epoch 04: Train Loss = 0.8108, Val Loss = 0.8038
Epoch 05: Train Loss = 0.7768, Val Loss = 0.9222
Epoch 06: Train Loss = 0.7613, Val Loss = 0.8685
Epoch 07: Train Loss = 0.7186, Val Loss = 0.8454
Epoch 08: Train Loss = 0.6684, Val Loss = 0.8314
Epoch 09: Train Loss = 0.6009, Val Loss = 0.8753
Epoch 10: Train Loss = 0.6073, Val Loss = 0.8051
Epoch 11: Train Loss = 0.5664, Val Loss = 0.8695
Epoch 12: Train Loss = 0.4888, Val Loss = 0.8580
Epoch 13: Train Loss = 0.4638, Val Loss = 0.9173
Epoch 14: Train Loss = 0.4289, Val Loss = 0.9557
Epoch 15: Train Loss = 0.3635, Val Loss = 0.9678
Epoch 16: Train Loss = 0.3480, Val Loss = 0.8494
Epoch 17: Train Loss = 0.3592, Val Loss = 0.9040
Epoch 18: Train Loss = 0.3176, Val Loss = 0.9197
Epoch 19: Train Loss = 0.2441, Val Loss = 0.9181
Epoch 20: Train Loss = 0.2073, Val Loss = 0.9294
Epoch 21: Train Loss

In [54]:
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr
import numpy as np

# Split your dataset
X_train, X_val, y_train, y_val = train_test_split(X_scaled, y.ravel(), test_size=0.2, random_state=42)

# Create LightGBM datasets
train_set = lgb.Dataset(X_train, label=y_train)
val_set = lgb.Dataset(X_val, label=y_val)

# Define parameters for regression
params = {
    "objective": "regression",
    "metric": "rmse",
    "learning_rate": 0.05,
    "num_leaves": 64,
    "verbosity": -1
}

# Train the model with early stopping
model = lgb.train(
    params,
    train_set,
    valid_sets=[val_set],
    num_boost_round=1000,
    callbacks=[lgb.early_stopping(stopping_rounds=10), lgb.log_evaluation(period=10)]
)

# Predict on validation set
y_pred = model.predict(X_val)

# Evaluation
mse = mean_squared_error(y_val, y_pred)
spearman = spearmanr(y_val, y_pred).correlation

print(f"\n✅ RMSE: {np.sqrt(mse):.4f}")
print(f"📈 Spearman correlation: {spearman:.4f}")


Training until validation scores don't improve for 10 rounds
[10]	valid_0's rmse: 0.206362
[20]	valid_0's rmse: 0.199665
[30]	valid_0's rmse: 0.197501
[40]	valid_0's rmse: 0.195936
[50]	valid_0's rmse: 0.195923
Early stopping, best iteration is:
[46]	valid_0's rmse: 0.195436

✅ RMSE: 0.1954
📈 Spearman correlation: 0.4907


In [59]:
df_test = pd.read_csv('../test.csv')

sequences = []
indices = []
for i in df_test['mutant']:
    ind = int(i[1:-1])
    tmp = seq[:ind] + i[-1] + seq[ind+1:]
    sequences.append(tmp)
    indices.append(ind)
df_test['Sequence'] = sequences
df_test['Position'] = indices


In [63]:
get_embed(df_test,'protein_mutation_fitness_test.parquet')

 26%|██▌       | 2940/11324 [09:50<48:58,  2.85it/s]