# Week 8 Workshop

This week we work with the protein language model ESM-2. The model family is available on Hugging Face; we use the 35M-parameter variant: https://huggingface.co/facebook/esm2_t12_35M_UR50D

We can load and run the model with the `transformers` library: https://huggingface.co/docs/transformers/en/index

As always, we begin by importing the required dependencies.

In [None]:
import torch
import polars as pl
from tqdm import tqdm
import plotnine as pln
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, EsmForMaskedLM

Next we load the tokenizer and the model.

**Note:** If the following code runs extremely slowly or errors out, you may have to update your Python installation. Alternatively, you can try running it in a Python console instead of Jupyter; older versions of Jupyter have a bug that prevents this code from running correctly.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")
model.to(device)

## Exploring the β-lactamase dataset

We will be working with a deep mutational scanning dataset of the protein β-lactamase. The data comes from:

> Stiffler MA, Hekstra DR, Ranganathan R (2015). Evolvability as a function of purifying selection in TEM-1 β-lactamase. Cell 160:882-892. https://doi.org/10.1016/j.cell.2015.01.035

First we read the data.

In [None]:
data_complete = pl.read_csv('./data/B-Lactamase_Ranganathan2015.csv')
data_complete

Let's look at the distribution of the `target` value, which represents the fitness of each mutant.

In [None]:
density_plot = (
    pln.ggplot(data_complete, pln.aes(x='target'))
    + pln.geom_density()
    + pln.theme_bw()
    )
density_plot.show()

The distribution is bimodal. We will treat mutants with fitness values above -1 as having high activity and all others as having low activity. We'll add an `activity` column to the data to express this.

In [None]:
data_complete = data_complete.with_columns(
    pl.when(pl.col('target') > -1)
      .then(pl.lit('high'))
      .otherwise(pl.lit('low'))
      .alias('activity')
)
data_complete

Next we randomly down-sample to speed up the subsequent calculations. Set the fraction to 1 to work with the complete dataset.

In [None]:
target_fract = 0.1 # set to 1 to work with the complete dataset
data = data_complete.sample(fraction = target_fract, seed=8592153)
data.shape

Let's see how many high and low activity mutants we have in the resulting dataset.

In [None]:
data['activity'].value_counts()

The dataset is roughly balanced, so we don't have to worry about class imbalance.
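
To make "roughly balanced" concrete, it can help to compute the class proportions and the accuracy of a trivial classifier that always predicts the majority class; a trained model should beat that number. The following is a minimal sketch on made-up labels. The `labels` list and the helper `majority_baseline` are illustrative assumptions, not part of the workshop data or code.

```python
from collections import Counter

def majority_baseline(labels):
    """Return (class proportions, accuracy of always predicting the majority class)."""
    counts = Counter(labels)
    total = sum(counts.values())
    proportions = {label: n / total for label, n in counts.items()}
    return proportions, max(proportions.values())

# hypothetical labels standing in for data['activity'].to_list()
labels = ['high'] * 55 + ['low'] * 45
proportions, baseline = majority_baseline(labels)
print(proportions)
print(f"Majority-class baseline accuracy: {baseline:.2%}")
```

The closer the baseline is to 50%, the more informative plain accuracy is as a performance measure.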

## Predicting mutant activity

We want to fit a classifier to distinguish between high activity and low activity mutants. To do this, we will first calculate embeddings, then do a PCA to reduce the embedding space, and then fit a logistic regression to the dimension-reduced embedding space. (We don't have enough data points to fit to the full space.)

First we calculate embeddings.

In [None]:
mean_representations = {}
with torch.no_grad():  # disable gradient calculations
    for seq_id, sequence in tqdm( # iterate using a nice progress bar
        data.select(['ID', 'sequence']).iter_rows(), 
        desc = "Processing sequences", 
        leave = False,
        total = len(data)
    ):
        # tokenize without padding or truncation
        tokens = tokenizer(sequence, return_tensors = "pt", padding = False, truncation = False)
        tokens = tokens.to(device)
        
        # get model output (hidden states are the embedding layers)
        output = model(tokens['input_ids'], output_hidden_states = True)
 
        # get the last hidden state
        # (one sequence at a time is not efficient; for larger datasets we should use batches)
        embeddings = output.hidden_states[-1][0]  # last layer, first and only sequence (batch size = 1)

        # keep the per-residue embeddings, dropping the special <cls> and <eos> tokens
        representations = embeddings[1:-1, :].detach().cpu()
        
        # compute mean representation of the sequence
        mean_representations[seq_id] = representations.mean(dim=0)

# join embeddings and original data frame
embed = pl.DataFrame(mean_representations).transpose(include_header=True, header_name='ID')
embed_df = data.join(embed, on='ID')
embed_df
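
The loop above handles one sequence at a time, so slicing off the first and last token is enough. Once sequences are processed in padded batches, the special and padding tokens sit at different positions in each row, and the mean has to be taken with a mask instead. Below is a minimal sketch of such masked mean pooling, with random NumPy arrays standing in for the model's hidden states. The function name `masked_mean` and the toy shapes are assumptions for illustration. In a real batch, the mask could plausibly be built from the tokenizer's `attention_mask` together with the mask returned when calling the tokenizer with `return_special_tokens_mask=True`.

```python
import numpy as np

def masked_mean(hidden, keep_mask):
    """Average hidden states over the positions where keep_mask is 1.

    hidden:    (batch, length, dim) array of per-token embeddings
    keep_mask: (batch, length) array, 1 for residues, 0 for special/padding tokens
    """
    mask = keep_mask[:, :, None].astype(hidden.dtype)  # broadcast over the embedding dim
    summed = (hidden * mask).sum(axis=1)               # sum of the kept embeddings
    counts = mask.sum(axis=1)                          # number of kept tokens per row
    return summed / counts

rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 6, 4))  # toy batch: 2 sequences, 6 tokens, 4 dims

# row 0: <cls> + 4 residues + <eos>; row 1: <cls> + 2 residues + <eos> + 2 pads
keep_mask = np.array([
    [0, 1, 1, 1, 1, 0],
    [0, 1, 1, 0, 0, 0],
])

pooled = masked_mean(hidden, keep_mask)
print(pooled.shape)
```

Each row of `pooled` is the mean over that sequence's residues only, which matches what the per-sequence loop computes. The sketch assumes every row keeps at least one token; otherwise the division would be by zero.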

Next we do a PCA and visualize high- and low-activity mutants in PC space.

In [None]:
# select only the embedding columns (there are 480 for the 35M parameter model)
features = embed_df.select(pl.col([f'column_{str(i)}' for i in range(480)]))

# run PCA and extract 20 components
n_components = 20
X_pca  = PCA(n_components=n_components).fit_transform(features)

# create data frame of first two components for visualization
pca_df = pl.DataFrame({
    'PC1': X_pca[:, 0],
    'PC2': X_pca[:, 1],
    'Activity': embed_df['activity']
})
pca_df
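
The choice of 20 components is somewhat arbitrary. One way to judge it is to check how much of the total variance the retained components explain, which scikit-learn exposes as `explained_variance_ratio_` on a fitted `PCA` object. The sketch below uses random data with a low-rank structure in place of the real embeddings; the array `X_demo` and its dimensions are made up for illustration.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# synthetic stand-in for the embedding matrix: 200 samples, 50 features,
# generated from 5 latent factors plus a little noise
latent = rng.normal(size=(200, 5))
loadings = rng.normal(size=(5, 50))
X_demo = latent @ loadings + 0.1 * rng.normal(size=(200, 50))

pca = PCA(n_components=20).fit(X_demo)
cumulative = np.cumsum(pca.explained_variance_ratio_)

for k in (1, 2, 5, 10, 20):
    print(f"first {k:2d} components explain {cumulative[k - 1]:.1%} of the variance")
```

To apply the same check to the workshop data, one would fit on `features` instead of `X_demo`.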

In [None]:
pca_plot = (
    pln.ggplot(pca_df, pln.aes(x='PC1', y='PC2', color='Activity'))
    + pln.geom_point(size=1.5, alpha=.7)
    + pln.theme_bw()
)
pca_plot.show()

There seems to be some separation between high and low activity mutants. Let's see if a logistic regression on 20 PC dimensions can successfully classify the data.

In [None]:
# first we create a numeric response variable
# 0 = 'low', 1 = 'high' (boolean comparison cast to integer)
y = (embed_df['activity'] == 'high').cast(pl.Int8).to_numpy()

# now we create the training/test split
random_state = 16492345 # change for different train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X_pca, y, test_size = 0.2, random_state = random_state, stratify = y
)
    
print(f"\nTraining samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Feature dimension for classification: {X_train.shape[1]}")
    
# train logistic regression classifier on PCA components
print(f"\nTraining classifier on {n_components} principal components...")
clf = LogisticRegression(max_iter = 1000)
clf.fit(X_train, y_train)
    
# make predictions
y_train_pred = clf.predict(X_train)
y_test_pred = clf.predict(X_test)
    
# evaluate performance
train_accuracy = accuracy_score(y_train, y_train_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
    
print(f"\n{'='*50}")
print("CLASSIFICATION RESULTS (using PCA components)")
print(f"{'='*50}")
print(f"Training Accuracy: {train_accuracy:.2%}")
print(f"Test Accuracy: {test_accuracy:.2%}")
   
print("\nTest Set Classification Report:")
print(classification_report(y_test, y_test_pred, 
                            target_names = ['low', 'high']))
    
print("Test Set Confusion Matrix:")
cm = confusion_matrix(y_test, y_test_pred)
print(f"                  Predicted")
print(f"               low        high")
print(f"Actual low     {cm[0,0]:3d}        {cm[0,1]:3d}")
print(f"       high    {cm[1,0]:3d}        {cm[1,1]:3d}")
    
# create confusion matrix visualization
cm_df = pl.DataFrame({
    'actual': ['low', 'low', 'high', 'high'],
    'predicted': ['low', 'high', 'low', 'high'],
    'count': [cm[0, 0], cm[0, 1], cm[1, 0], cm[1, 1]]
})
    
cm_plot = (
    pln.ggplot(cm_df, pln.aes(x = 'predicted', y = 'actual', fill = 'count'))
    + pln.geom_tile(color = 'white', size = 1.5)
    + pln.geom_text(pln.aes(label = 'count'), size = 20, color = 'white')
    + pln.scale_fill_gradient(low = '#3498db', high = '#e74c3c')
    + pln.scale_x_discrete(limits = ['high', 'low'])
    + pln.scale_y_discrete(limits = ['low', 'high'])
    + pln.labs(
        x='Predicted',
        y='Actual',
        fill='Count'
    )
    + pln.theme_minimal()
)

cm_plot.show()
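
Two caveats apply to the analysis above. The PCA was fit on all sequences before the train/test split, so the test set influenced the projection, and a single split yields only one accuracy estimate. A common way to address both is to wrap PCA and the classifier in a scikit-learn pipeline and evaluate it with cross-validation, so that PCA is refit on the training portion of each fold. The following is a sketch on synthetic data from `make_classification`, not on the β-lactamase embeddings; the sample and feature counts are arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline

# synthetic stand-in for (embeddings, activity labels)
X_demo, y_demo = make_classification(
    n_samples=400, n_features=100, n_informative=10, random_state=0
)

# PCA and classifier are fit together, separately within each training fold
pipeline = make_pipeline(
    PCA(n_components=20),
    LogisticRegression(max_iter=1000),
)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipeline, X_demo, y_demo, cv=cv, scoring="accuracy")

print(f"Fold accuracies: {scores.round(3)}")
print(f"Mean accuracy: {scores.mean():.2%} (std {scores.std():.2%})")
```

The spread across folds gives a rough sense of how much the reported accuracy depends on the particular split.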

## Exercises

- Use more than 10% of the data and see how this changes the results
- Change the number of components in the PCA before classification
- Use a different random seed in the training/test split
- Use embeddings from different layers