
<table style="background-color: transparent; border: none;">   
  <tr>     
    <td><img src="https://cdn.prod.website-files.com/6606dc3fd5f6645318003df4/6678476dc198b5a75b8c8873_ES_Logo_Black_5.png" width="100" alt="img"/></td>     
    <td><h1>Custom Embeddings + <code>XGBoost</code></h1></td>   
  </tr>
</table>

<br>

> __Updated On: `04.04`__


__Key Notes:__

- This model is a complete overhaul of previous implementations, which focused on `NN` (_Neural Network_) models.
- Here we explore the capabilities of [_XGBoost_](https://xgboost.readthedocs.io/en/release_3.0.0/) and [_Random Search_](https://www.yourdatateacher.com/2021/05/19/hyperparameter-tuning-grid-search-and-random-search/) hyperparameter tuning.
- Validated __Spearman__ from Gradescope: `0.44`



---

### Required Imports

```python
from copy import deepcopy
import pandas as pd
import os
import time
import shutil
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from scipy.stats import spearmanr
import xgboost as xgb
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import make_scorer, mean_squared_error
```

---

### Data Collection and Cleaning



> __This section relies on having a `sequence.fasta`, `train.csv`, `query.csv` and `test.csv` in your runtime.__

We start by loading the wild-type _sequence_ from the `sequence.fasta` file and inspecting its first residues and its length. This _sequence_ is the starting point for generating mutated sequences from mutation codes, as discussed below.

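```python
with open('sequence.fasta', 'r') as f:
    data = f.readlines()

# Join every non-header line so that a single record wrapped over
# several lines is read correctly (assumes the file holds one record)
sequence_wt = ''.join(line.strip() for line in data if not line.startswith('>'))

print('Sequence:\n')
print(f'{sequence_wt[:100]}...')

print(f'\nLength: {len(sequence_wt)}')
```

```text
Sequence:

MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLREKMRRRLESGDKWFSLEFFPPRTAEGAVNLISRFDRMAAGGPLYIDVTWHPAGD...

Length: 656
```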


To use the _training_ data in `train.csv`, we need a function that turns a mutation code into the full mutated sequence derived from the __WT__ (_Wild Type_). Each code has the form `<WT residue><position><mutant residue>` with 0-based positions, so `M0Y` replaces the `M` at position 0 with `Y`. This step is essential because our embeddings are computed from the full 656-residue mutated sequence, _not_ from the mutation code alone.

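```python
def get_mutated_sequence(mut, sequence_wt):
    # e.g. "M0Y" -> wt="M", pos=0, mt="Y" (positions are 0-based)
    wt, pos, mt = mut[0], int(mut[1:-1]), mut[-1]

    # Python strings are immutable, so slicing already returns a new string (no deepcopy needed)
    return sequence_wt[:pos] + mt + sequence_wt[pos + 1:]
```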
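The helper above trusts its input. A stricter variant could validate each code before applying it. In the self-contained sketch below, `parse_mutation` and `apply_mutation` are hypothetical helpers, not part of the original notebook. It checks the code format with a regular expression, checks the position bounds, and confirms that the stated WT residue matches the sequence. The two usage lines reproduce the 0-based convention seen in the `train.csv` and `test.csv` outputs (`M0Y`, `V1D`):

```python
import re

# Hypothetical helper: a code such as "M0Y" is <WT residue><0-based position><mutant residue>
MUTATION_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_mutation(code):
    """Split a single-point mutation code into (wt, pos, mt)."""
    match = MUTATION_RE.match(code)
    if match is None:
        raise ValueError(f"Malformed mutation code: {code!r}")
    wt, pos, mt = match.groups()
    return wt, int(pos), mt


def apply_mutation(code, sequence_wt):
    """Return sequence_wt with the residue at `pos` replaced, after checking the WT residue."""
    wt, pos, mt = parse_mutation(code)
    if pos >= len(sequence_wt):
        raise IndexError(f"Position {pos} is outside a sequence of length {len(sequence_wt)}")
    if sequence_wt[pos] != wt:
        raise ValueError(f"{code}: expected {wt} at position {pos}, found {sequence_wt[pos]}")
    return sequence_wt[:pos] + mt + sequence_wt[pos + 1:]


# Usage on the first residues of the wild type shown above
print(apply_mutation("M0Y", "MVNEARG"))  # YVNEARG
print(apply_mutation("V1D", "MVNEARG"))  # MDNEARG
```

Raising on a WT mismatch catches off-by-one indexing mistakes early, before any embedding is computed.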

Next, we generate the mutated sequence for every data point in `train.csv` from the __WT__ and its mutation code:

```python
df_train = pd.read_csv('train.csv')
df_train['sequence'] = df_train.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))

print(df_train.head(5))
```

```text
  mutant  DMS_score                                           sequence
0    M0Y     0.2730  YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1    M0W     0.2857  WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2    M0V     0.2153  VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3    M0T     0.3122  TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4    M0S     0.2180  SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
```


Now we also need to incorporate the _new_ training data generated via an active learning query into `df_train`. This code accepts _all_ active learning query files of the form `query*.csv`:

```python
# Sort the file names so that query files are appended in a reproducible order
for filename in sorted(os.listdir()):
    if filename.startswith('query') and filename.endswith('.csv'):
        df_query = pd.read_csv(filename)
        df_query['sequence'] = df_query.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))
        # ignore_index=True avoids duplicate row labels after concatenation
        df_train = pd.concat([df_train, df_query], ignore_index=True)

print(df_train['DMS_score'].describe())
```

```text
count    1440.000000
mean        0.333433
std         0.296895
min         0.006700
25%         0.086850
50%         0.234650
75%         0.536943
max         0.995700
Name: DMS_score, dtype: float64
```
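Before embedding, it may be worth confirming that the merged frame has no repeated mutants. A mutant that appears in both `train.csv` and a query file would otherwise be counted twice. The sketch below is one possible check, not part of the original pipeline. `merge_labeled_data` is a hypothetical helper, and the toy frames (including their scores) are placeholders standing in for the real CSV files:

```python
import pandas as pd


def merge_labeled_data(frames):
    """Hypothetical helper: concatenate labeled frames and report repeated mutants."""
    merged = pd.concat(frames, ignore_index=True)
    # keep=False marks every occurrence of a repeated mutant, not just the later ones
    duplicated = merged[merged.duplicated(subset="mutant", keep=False)]
    if not duplicated.empty:
        print(f"Warning: {duplicated['mutant'].nunique()} mutant(s) appear more than once")
    return merged, duplicated


# Toy frames standing in for train.csv and one query*.csv file (placeholder scores)
df_a = pd.DataFrame({"mutant": ["M0Y", "M0W"], "DMS_score": [0.2730, 0.2857]})
df_b = pd.DataFrame({"mutant": ["V1D", "M0W"], "DMS_score": [0.1000, 0.2900]})

merged, duplicated = merge_labeled_data([df_a, df_b])
print(merged.shape)                           # (4, 2)
print(sorted(duplicated["mutant"].unique()))  # ['M0W']
```

How to resolve a duplicate (keep the newest label, average the labels, or drop the row) is a separate modeling decision that this sketch leaves open.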


We apply the same process to the data points in `test.csv`, which have no `DMS_score` column:

```python
df_test = pd.read_csv('test.csv')
df_test['sequence'] = df_test.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))

print(df_test.head(5))
```

```text
  mutant                                           sequence
0    V1D  MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1    V1Y  MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2    V1C  MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3    V1A  MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4    V1E  MENEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
```


---

### Embedding Our Data



To create a vector representation of each mutant, we take per-residue embeddings from the pretrained `ESM-2` model `esm2_t33_650M_UR50D` and pool them into a single fixed-length vector.

For more information about the `ESM` model and its capabilities, consult [this resource](https://github.com/facebookresearch/esm).

> __Understanding The `ProteinDataset` Class__

- The helper method `_compute_and_save_embedding` isolates the logic for computing and caching the per-residue embedding for each protein.
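- Before computing an embedding, the class checks whether a cached file for that sample already exists in the `esm_embeddings` directory. If it does, the class skips the ESM forward pass.
- The class does not return the full per-residue embedding of shape $(L, D)$. Instead, it returns a pooled, fixed-length vector of shape $(D,)$ for each protein, where $L$ is the sequence length and $D = 1280$ for this `ESM` model.
- This `ProteinDataset` class supports two pooling modes: `max` (element-wise maximum over all residues) and `gauss` (a weighted average whose Gaussian weights, with standard deviation `std`, are centred on the mutated position).
- __Warning! Running the next cell deletes the cached embeddings in `esm_embeddings`. If you do not intend to recompute embeddings, skip the `shutil.rmtree` call. Cached files are keyed by row index, so the cache must be rebuilt whenever `df_train` or `df_test` changes (e.g. after adding a new `query*.csv` file).__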


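```python
# Deletes cached embeddings; ignore_errors avoids a FileNotFoundError on the first run
shutil.rmtree("esm_embeddings", ignore_errors=True)
```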

```python
class ProteinDataset(Dataset):
    """Caches per-residue ESM-2 embeddings on disk and returns one pooled vector per mutant."""

    def __init__(self, df, istrain=True, device='cuda', mode='max', std=10):
        # Fail fast instead of raising on the first __getitem__ call
        if mode not in ("max", "gauss"):
            raise ValueError(f"Unknown mode: {mode}")

        self.df = df.reset_index(drop=True)
        self.istrain = istrain
        self.device = device
        self.mode = mode
        self.std = std  # width (in residues) of the Gaussian window used by mode="gauss"

        self.embedding_dir = "esm_embeddings"
        os.makedirs(self.embedding_dir, exist_ok=True)

        self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main",
                                                   "esm2_t33_650M_UR50D")
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model = self.model.to(self.device)
        self.model.eval()

        # Compute embeddings only for samples that are not cached yet
        for idx, row in self.df.iterrows():
            if not os.path.exists(self._embedding_path(idx)):
                self._compute_and_save_embedding(idx, row['sequence'])

        if self.istrain:
            self.targets = self.df['DMS_score'].values

    def _embedding_path(self, idx):
        prefix = "train" if self.istrain else "test"
        return os.path.join(self.embedding_dir, f"{prefix}_seq_{idx}.pt")

    def _compute_and_save_embedding(self, idx, seq):
        data = [(f"protein{idx}", seq)]
        _, _, batch_tokens = self.batch_converter(data)
        batch_tokens = batch_tokens.to(self.device)
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)

        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[33], return_contacts=False)
        token_representations = results["representations"][33]
        # Drop the first (BOS) and last (EOS) tokens -> shape (L, D)
        rep = token_representations[0, 1:batch_lens.item() - 1].cpu()
        torch.save(rep, self._embedding_path(idx))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        embedding = torch.load(self._embedding_path(idx))  # (L, D)

        # 0-based mutation position parsed from codes such as "M0Y", clamped to the sequence
        pos = int(self.df.loc[idx, "mutant"][1:-1])
        pos = min(pos, embedding.shape[0] - 1)

        if self.mode == "max":
            # Element-wise maximum over all residues -> (D,)
            feature = embedding.max(dim=0).values
        else:  # "gauss"
            # Gaussian weights centred on the mutated residue, normalised to sum to 1
            indices = torch.arange(embedding.shape[0], dtype=torch.float32)
            weights = torch.exp(-((indices - pos) ** 2) / (2 * self.std ** 2))
            weights = weights / weights.sum()
            feature = (weights.unsqueeze(1) * embedding).sum(dim=0)

        if self.istrain:
            return feature, torch.tensor(self.targets[idx], dtype=torch.float32)
        return feature
```
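Here $e_i \in \mathbb{R}^D$ is the cached embedding of residue $i$ ($0 \le i < L$), $p$ is the 0-based mutation position (clamped to $L-1$), and $\sigma$ is the `std` argument. Written out, the two pooling branches of `__getitem__` compute:

$$
\begin{aligned}
\texttt{max}: \quad & f_d = \max_{0 \le i < L} e_{i,d}, \qquad d = 1, \dots, D \\
\texttt{gauss}: \quad & w_i = \frac{\exp\left(-\frac{(i - p)^2}{2\sigma^2}\right)}{\sum_{j=0}^{L-1} \exp\left(-\frac{(j - p)^2}{2\sigma^2}\right)}, \qquad f = \sum_{i=0}^{L-1} w_i \, e_i
\end{aligned}
$$

With the default $\sigma = 10$, a residue $3\sigma = 30$ positions from the mutation receives $e^{-4.5} \approx 1\%$ of the peak weight. The `gauss` feature therefore mostly summarizes the mutation's local context, while `max` does not use the mutation position at all.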


We now build one training dataset per pooling mode from `df_train`. The cache file names do not depend on `mode`, so the second dataset reuses the embeddings computed by the first and only the pooling step differs:

```python
train_dataset_max = ProteinDataset(df_train, mode="max")
train_dataset_gauss = ProteinDataset(df_train, mode="gauss")

embedding, _ = train_dataset_max[0]
print(f"\nSample embedding shape: ({embedding.shape[0]},)  ; Total samples: {len(train_dataset_max)}")
```

```text
Downloading: "https://github.com/facebookresearch/esm/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt
Using cache found in /root/.cache/torch/hub/facebookresearch_esm_main

Sample embedding shape: (1280,)  ; Total samples: 1440
```
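The `(1280,)` shape above comes from reducing an $(L, D)$ matrix over its residues. The pooling arithmetic can be checked without loading ESM at all. This NumPy sketch mirrors the two branches of `ProteinDataset.__getitem__` on a random stand-in matrix; the `pool` function and the toy size $D = 8$ are assumptions for illustration only. It also shows how `std` controls how local the `gauss` feature is:

```python
import numpy as np


def pool(embedding, pos, mode="max", std=10.0):
    """Reduce an (L, D) per-residue embedding to a (D,) vector, mirroring ProteinDataset."""
    L = embedding.shape[0]
    pos = min(pos, L - 1)  # same clamping as __getitem__
    if mode == "max":
        return embedding.max(axis=0)
    if mode == "gauss":
        indices = np.arange(L, dtype=np.float64)
        weights = np.exp(-((indices - pos) ** 2) / (2 * std ** 2))
        weights /= weights.sum()
        return weights @ embedding  # weighted average over residues
    raise ValueError(f"Unknown mode: {mode}")


rng = np.random.default_rng(0)
emb = rng.normal(size=(656, 8))  # toy stand-in: L=656 residues, D=8 instead of 1280

print(pool(emb, pos=1, mode="max").shape)    # (8,)
print(pool(emb, pos=1, mode="gauss").shape)  # (8,)

# A very narrow window collapses onto the mutated residue's own embedding
print(np.allclose(pool(emb, pos=100, mode="gauss", std=0.1), emb[100]))  # True
```

At the other extreme, a very large `std` makes the weights nearly uniform, and `gauss` then approaches plain mean pooling.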


---

### `Train-Validation` Splitting


To validate our model, we need a `train-validation` split. One part of the labeled data trains the model, and the held-out part measures its performance with pre-defined metrics (Spearman correlation and MSE).

There are numerous ways to generate this split, including [__K-Fold Cross-Validation__](https://machinelearningmastery.com/k-fold-cross-validation/), but we use a simple `80/20` split (training/validation).

```python
X_max, X_gauss, y_full = [], [], []

for i in range(len(train_dataset_max)):
    feature_max, target = train_dataset_max[i]
    feature_gauss, _ = train_dataset_gauss[i]

    X_max.append(feature_max.numpy())
    X_gauss.append(feature_gauss.numpy())
    y_full.append(target.item())

X_max = np.vstack(X_max)
X_gauss = np.vstack(X_gauss)
y_full = np.array(y_full)

print("Max Approach Shape:", X_max.shape)
print("Gaussian Approach Shape:", X_gauss.shape)
print("y_full shape:", y_full.shape)

# A single call shuffles all three arrays with the same permutation,
# so row i of X_max_train and X_gauss_train always refers to the same mutant
X_max_train, X_max_val, X_gauss_train, X_gauss_val, y_train, y_val = train_test_split(
    X_max, X_gauss, y_full, test_size=0.2, random_state=42)

print("Train shapes:", X_max_train.shape, X_gauss_train.shape)
print("Val shapes:", X_max_val.shape, X_gauss_val.shape)
```

```text
Max Approach Shape: (1440, 1280)
Gaussian Approach Shape: (1440, 1280)
y_full shape: (1440,)
Train shapes: (1152, 1280) (1152, 1280)
Val shapes: (288, 1280) (288, 1280)
```
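For comparison, a K-Fold estimate of validation Spearman could look like the sketch below. This is not what the notebook runs. It assumes scikit-learn's `KFold`, uses `Ridge` as a lightweight stand-in for the XGBoost model, and replaces the ESM embeddings with synthetic features:

```python
import numpy as np
from scipy.stats import spearmanr
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

# Synthetic stand-in data: 200 samples, 16 features, a noisy linear target
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 16))
y = X @ rng.normal(size=16) + 0.5 * rng.normal(size=200)

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
fold_scores = []
for fold, (train_idx, val_idx) in enumerate(kfold.split(X)):
    model = Ridge(alpha=1.0).fit(X[train_idx], y[train_idx])
    rho, _ = spearmanr(y[val_idx], model.predict(X[val_idx]))
    fold_scores.append(rho)
    print(f"Fold {fold}: Spearman = {rho:.3f}")

print(f"Mean Spearman over {len(fold_scores)} folds: {np.mean(fold_scores):.3f} "
      f"(std {np.std(fold_scores):.3f})")
```

Unlike a single `80/20` split, K-Fold uses every labeled mutant for validation exactly once. The spread across folds gives a rough sense of how stable the validation Spearman is, at the cost of training the model `n_splits` times.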


---

### Constructing a Model

In this section, we train __XGBoost__ regressors (one per pooling mode) on our __pooled `ESM` embeddings__ to predict fitness. The goal is a high Spearman correlation between the predicted and true fitness scores, i.e. a model that ranks the mutants correctly.
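For reference, Spearman's $\rho$ is the Pearson correlation of the ranks. When there are no ties, it reduces to the formula below, where $d_i$ is the difference between the rank of the $i$-th true score and the rank of its prediction over $n$ samples:

$$
\rho = 1 - \frac{6 \sum_{i=1}^{n} d_i^2}{n(n^2 - 1)}
$$

Because $\rho$ depends only on ranks, any strictly increasing rescaling of the predictions leaves it unchanged. It therefore measures ranking quality, while calibration is left to the MSE. `scipy.stats.spearmanr` handles ties by assigning average ranks.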

The training pipeline involves the following steps:

- **Custom Evaluation Metric:**  We define a custom scoring function based on Spearman correlation, which `RandomizedSearchCV` uses to select hyperparameters. XGBoost still minimizes squared error (MSE) during training, but candidate settings are compared by how well they rank the samples.

- **Hyperparameter Tuning:**  A search space is defined over key __XGBoost__ hyperparameters such as `max_depth`, `learning_rate`, `n_estimators`, and others. We use 3-fold cross-validation to search for the parameters that maximize the Spearman correlation on the training split. To avoid repeating this computation, the derived optimal hyperparameters are hard-coded in the code below:
  - `learning_rate`: `0.1`
  - `subsample`: `0.8`
  - `reg_lambda`: `1`
  - `reg_alpha`: `0.01`
  - `n_estimators`: `300`
  - `max_depth`: `6`
  - `gamma`: `0`
  - `colsample_bytree`: `0.8`

- **Validation Evaluation:**  We train one model per pooling mode (`max` and `gauss`) with these hyperparameters. We then evaluate each model, and the average of their predictions, on the held-out validation set using both the Spearman correlation and the MSE.

- **Final Model Training:**  Both models are then retrained on all labeled data (training and validation splits) before making predictions on the test set.


```python
def spearman_score(y_true, y_pred):
    rho, _ = spearmanr(y_true, y_pred)
    return rho

# Scorer for scikit-learn model selection (e.g. RandomizedSearchCV)
spearman_scorer = make_scorer(spearman_score, greater_is_better=True)

# Hyperparameters derived from the random search
best_params = {
    'objective': 'reg:squarederror',
    'learning_rate': 0.1,
    'max_depth': 6,
    'n_estimators': 300,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'gamma': 0,
    'reg_alpha': 0.01,
    'reg_lambda': 1,
    'random_state': 42,
    'tree_method': 'hist',
    # Trains on the GPU; predicting on NumPy (CPU) arrays then emits a
    # device-mismatch warning and falls back to DMatrix prediction
    'device': 'cuda'
}

# One model per pooling mode
model_max = xgb.XGBRegressor(**best_params)
model_max.fit(X_max_train, y_train)

model_gauss = xgb.XGBRegressor(**best_params)
model_gauss.fit(X_gauss_train, y_train)

y_val_max = model_max.predict(X_max_val)
y_val_gauss = model_gauss.predict(X_gauss_val)

# Simple ensemble: average the two models' predictions
y_val_ensemble = (y_val_max + y_val_gauss) / 2

print("\n[MAX] Validation Spearman correlation:", spearman_score(y_val, y_val_max))
print("[MAX] Validation MSE:", mean_squared_error(y_val, y_val_max))

print("\n[GAUSS] Validation Spearman correlation:", spearman_score(y_val, y_val_gauss))
print("[GAUSS] Validation MSE:", mean_squared_error(y_val, y_val_gauss))

print("\n[ENSEMBLE] Validation Spearman correlation:", spearman_score(y_val, y_val_ensemble))
print("[ENSEMBLE] Validation MSE:", mean_squared_error(y_val, y_val_ensemble))

# Retrain both models on all labeled data before predicting the test set
model_max.fit(X_max, y_full)
model_gauss.fit(X_gauss, y_full)
```

```text
[MAX] Validation Spearman correlation: 0.6951324562323872
[MAX] Validation MSE: 0.03348663990562158

[GAUSS] Validation Spearman correlation: 0.6841055409994233
[GAUSS] Validation MSE: 0.03401722020659545

[ENSEMBLE] Validation Spearman correlation: 0.7116224681218501
[ENSEMBLE] Validation MSE: 0.0316011555531743
```


---

### Testing a Model


We average the predictions of the two final __XGBoost__ regressors (`max`- and `gauss`-pooled) to score every mutant in the test set.

The predicted fitness scores (DMS scores) are written to a CSV file (`predictions.csv`) with two columns:
- **`mutant`:** The mutation identifier in the format `<WT residue><position><mutant residue>`, e.g. `V1D`.
- **`DMS_score_predicted`:** The predicted fitness score.

Finally, we sort the predictions and save the identifiers of the top 10 mutants to a text file (`top10.txt`).


```python
test_dataset_max = ProteinDataset(df_test, istrain=False, mode="max")
test_dataset_gauss = ProteinDataset(df_test, istrain=False, mode="gauss")

X_test_max = np.vstack([test_dataset_max[i].numpy() for i in range(len(test_dataset_max))])
X_test_gauss = np.vstack([test_dataset_gauss[i].numpy() for i in range(len(test_dataset_gauss))])

print("X_test_max shape:", X_test_max.shape)
print("X_test_gauss shape:", X_test_gauss.shape)

# Ensemble: average the predictions of the two retrained models
y_test_max = model_max.predict(X_test_max)
y_test_gauss = model_gauss.predict(X_test_gauss)

y_test_pred = (y_test_max + y_test_gauss) / 2

df_results = pd.DataFrame({
    "mutant": df_test["mutant"],
    "DMS_score_predicted": y_test_pred
})
df_results.to_csv("predictions.csv", index=False)

# Write the 10 highest-scoring mutants, one per line
df_top10 = df_results.sort_values(by="DMS_score_predicted", ascending=False).head(10)
with open("top10.txt", "w") as f:
    for mutant in df_top10["mutant"]:
        f.write(mutant + "\n")
```

```text
Using cache found in /root/.cache/torch/hub/facebookresearch_esm_main
Using cache found in /root/.cache/torch/hub/facebookresearch_esm_main

X_test_max shape: (11324, 1280)
X_test_gauss shape: (11324, 1280)
```
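Before submitting, the two output files can be sanity-checked. The sketch below defines `check_submission`, a hypothetical helper that is not part of the original notebook. It assumes the column names and the one-mutant-per-line layout produced above; for the real files, `n_expected` would be `len(df_test)` (11324 here). The usage runs on toy files in a temporary directory:

```python
import os
import tempfile

import pandas as pd


def check_submission(pred_path, top_path, n_expected, k=10):
    """Hypothetical sanity check for predictions.csv and top10.txt; raises ValueError on problems."""
    preds = pd.read_csv(pred_path)
    if list(preds.columns) != ["mutant", "DMS_score_predicted"]:
        raise ValueError(f"Unexpected columns: {list(preds.columns)}")
    if len(preds) != n_expected:
        raise ValueError(f"Expected {n_expected} rows, got {len(preds)}")
    if not preds["mutant"].is_unique:
        raise ValueError("Duplicate mutants in predictions")
    if preds["DMS_score_predicted"].isna().any():
        raise ValueError("Missing predictions")

    with open(top_path) as f:
        top = [line.strip() for line in f if line.strip()]
    # Compared as sets; exact ties in the scores could legitimately change membership
    expected_top = preds.nlargest(k, "DMS_score_predicted")["mutant"].tolist()
    if len(top) != k or set(top) != set(expected_top):
        raise ValueError("Top list does not match the k highest predictions")
    return True


# Usage on toy files written to a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    pred_path = os.path.join(tmp, "predictions.csv")
    top_path = os.path.join(tmp, "top10.txt")

    toy = pd.DataFrame({"mutant": [f"V1{aa}" for aa in "ACDEFGHIKLMN"],
                        "DMS_score_predicted": [i / 12 for i in range(12)]})
    toy.to_csv(pred_path, index=False)
    with open(top_path, "w") as f:
        for mutant in toy.sort_values("DMS_score_predicted", ascending=False).head(10)["mutant"]:
            f.write(mutant + "\n")

    ok = check_submission(pred_path, top_path, n_expected=12)
    print(ok)  # True
```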
