# Sequence Representation Visualization with ProkBERT

This guide outlines the steps to visualize sequence embeddings using ProkBERT, specifically focusing on the genomic features of ESKAPE pathogens with ProkBERT-mini. 
The workflow:
1. **Model Loading**: Load the ProkBERT model designed for genomic sequence analysis.
2. **Dataset Preparation**: Ready your dataset for ProkBERT by performing necessary preprocessing.
3. **Model Evaluation**: Process your dataset through ProkBERT to generate embeddings.
4. **Results Visualization**: Visualize these embeddings to identify patterns and insights into the genomic features.

In this example, we visualize different genomic features of the ESKAPE pathogens using ProkBERT-mini.


### Setup and Installation

Before we start, let's install the libraries this notebook needs: the `prokbert` package itself, `umap-learn` for dimensionality reduction, and `seaborn` for visualization.


In [None]:
# ProkBERT
!pip install git+https://github.com/nbrg-ppcu/prokbert
# Ensure umap-learn is installed
!pip install umap-learn
# Ensure seaborn is installed for visualization
!pip install seaborn

# Imports
from prokbert.training_utils import get_default_pretrained_model_parameters, get_torch_data_from_segmentdb_classification
from torch.utils.data import DataLoader
from prokbert.prok_datasets import ProkBERTTrainingDatasetPT
from datasets import load_dataset
import torch
import numpy as np
import umap
import seaborn as sns
import matplotlib.pyplot as plt


### Enabling and testing the GPU (if you are using Google Colab)

First, enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select GPU from the Hardware Accelerator drop-down
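
To confirm that a GPU is actually visible once the setting is enabled, a quick check like the one below can help. This is a minimal sketch that relies only on the `nvidia-smi` command-line tool that ships with the NVIDIA driver. The helper name `gpu_summary` is illustrative and is not part of ProkBERT.

```python
import shutil
import subprocess

def gpu_summary():
    """Return the GPU listing from nvidia-smi, or None if no NVIDIA driver is visible."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return None
    # "-L" asks nvidia-smi to list the GPUs it can see, one per line
    result = subprocess.run([exe, "-L"], capture_output=True, text=True)
    return result.stdout.strip() if result.returncode == 0 else None

summary = gpu_summary()
print(summary if summary else "No GPU detected; the notebook will run on the CPU.")
```

If no GPU is reported, the rest of the notebook still runs, only more slowly.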


### Loading the model
In this step, we load the pretrained ProkBERT-mini model. We use the base model because we only need its sequence embeddings. Make sure the tokenizer matches the model, especially when loading directly from Hugging Face, so that the tokenization parameters are compatible.

**Embeddings:**

Embeddings are dense vector representations of data, in this case, genomic sequences, where similar sequences are closer in the vector space. This representation allows the model to capture the context and semantic meanings of sequences, facilitating more effective analysis and comparison. By extracting embeddings from the ProkBERT model, we can leverage these rich, contextually informed representations for various bioinformatics applications, such as clustering, similarity searches, or as features for downstream machine learning models.
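
As a small illustration of what "closer in the vector space" means, the sketch below compares toy vectors with cosine similarity, one common closeness measure. The vectors are made-up values standing in for real embeddings, and `cosine_similarity` is a helper written for this example, not a ProkBERT function.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: near 1 for the same direction, 0 for orthogonal."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy 4-dimensional vectors (made-up values, not real ProkBERT embeddings)
seq_a = np.array([1.0, 0.0, 2.0, 0.0])
seq_b = np.array([2.0, 0.0, 4.0, 0.0])  # points in the same direction as seq_a
seq_c = np.array([0.0, 3.0, 0.0, 1.0])  # orthogonal to seq_a

print("a vs b:", cosine_similarity(seq_a, seq_b))
print("a vs c:", cosine_similarity(seq_a, seq_c))
```

The first pair scores much higher than the second, which is the property that clustering and similarity search build on.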


In [None]:
model_name_path = 'neuralbioinfo/prokbert-mini'
model, tokenizer = get_default_pretrained_model_parameters(
    model_name=model_name_path, 
    model_class='MegatronBertModel', 
    output_hidden_states=False, 
    output_attentions=False,
    move_to_gpu=False
)




### Dataset Preparation

In this section, we download a preprocessed and filtered dataset from Hugging Face. The dataset is then converted into a pandas DataFrame for easy manipulation and analysis. To efficiently manage memory and computation, especially when working with large datasets, we sample a subset (`Nsample`) of the original data. 

Next, we prepare the data for evaluation by creating a PyTorch dataset. This involves tokenizing the sequences and formatting the data as the ProkBERT model expects. The `batch_size` parameter sets how many sequences are processed in a single batch. Choose it so that evaluation is efficient while staying within the memory limits of your hardware: larger batches make better use of the GPU but need more memory.
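
To make the effect of `batch_size` concrete, the arithmetic below works out how many batches the DataLoader would yield for the values used in this notebook (`batch_size=64`, 1000 sampled rows). It assumes exactly one tokenized segment per sampled row. The real count depends on how many segments the tokenization step produces.

```python
import math

batch_size = 64
n_sequences = 1000  # assumption: one tokenized segment per sampled row

# Without drop_last, the final batch holds whatever remains
n_batches = math.ceil(n_sequences / batch_size)
last_batch = n_sequences - (n_batches - 1) * batch_size
print(f"{n_batches} batches; the last one holds {last_batch} sequences")
# prints "16 batches; the last one holds 40 sequences"
```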




In [None]:
# Batch size for the DataLoader and number of rows to sample
batch_size = 64
Nsample = 1000
# Load the dataset from the Hugging Face Hub
dataset = load_dataset("neuralbioinfo/ESKAPE-genomic-features")
eskape = dataset['ESKAPE'].to_pandas()
# Sample a subset to keep memory use and runtime manageable
eskape_features_sample = eskape.sample(Nsample)
# Dummy label column (all zeros); labels are dropped before the forward pass
eskape_features_sample['y'] = 0

# Prepare the data for PyTorch model training/evaluation
[X, y, torchdb] = get_torch_data_from_segmentdb_classification(tokenizer, eskape_features_sample)
ds = ProkBERTTrainingDatasetPT(X, y, AddAttentionMask=True)
eval_dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)





### Evaluating the Dataset

Next, we will execute a forward pass to obtain the output embeddings from the last layer. The output dimension is \( \text{batch size} \times \text{sequence length} \times \text{embedding size} \). To assign one vector to each sequence, rather than to each token, we need to aggregate the vectors. Here, we apply a simple mean function across the sequence length dimension.
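
If the sequences in a batch are padded to a common length, a plain mean over the sequence-length dimension also averages over the padding positions. One common alternative is to weight the mean by the attention mask. The sketch below shows the arithmetic with NumPy on made-up values. `masked_mean_pool` is a hypothetical helper, not part of ProkBERT, and the same operations carry over directly to PyTorch tensors.

```python
import numpy as np

def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring positions where the mask is 0."""
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                    # (batch, embedding)
    counts = np.clip(mask.sum(axis=1), 1.0, None)                  # guard against empty masks
    return summed / counts

# Toy batch: 2 sequences, 3 token positions, embedding size 2 (made-up values)
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                   [[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]])
mask = np.array([[1, 1, 0],   # third position of the first sequence is padding
                 [1, 1, 1]])

pooled = masked_mean_pool(hidden, mask)
plain = hidden.mean(axis=1)
print("masked mean:", pooled)
print("plain mean: ", plain)
```

In the loop of this notebook, the corresponding inputs would be `outputs.last_hidden_state` and the attention mask of the batch, assuming the dataset exposes it under the key `attention_mask`.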



In [None]:
# Use the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # inference mode: disables dropout


representations = []
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    batch.pop('labels')
    with torch.no_grad():
        outputs = model(**batch)
        last_hidden_states = outputs.last_hidden_state
        # Average over the sequence-length dimension: one vector per sequence
        mean_pooled = torch.mean(last_hidden_states, dim=1)
        # Move the result to the CPU and convert it to NumPy to free GPU memory
        representations_batch = mean_pooled.detach().cpu().numpy()
        representations.append(representations_batch)
representations = np.concatenate(representations)


### Visualization

In this section, we will visualize the high-dimensional representations using Uniform Manifold Approximation and Projection (UMAP). UMAP helps in visualizing complex datasets by projecting them into a lower-dimensional space, typically 2D or 3D, while preserving the original data's global and local structure as much as possible.

**UMAP Parameters:**

- **`n_neighbors`**: Controls how UMAP balances local versus global structure in the data. It determines the number of neighboring points used in local approximations of manifold structure. Higher values can help preserve more of the global structure.
- **`min_dist`**: Sets the minimum distance between points in the low-dimensional representation. Smaller values allow UMAP to focus on finer details, while larger values help to preserve the broader data topology.
- **`random_state`**: Ensures reproducibility of your results by fixing the random seed used by UMAP's stochastic optimization process.

**Visualization Process:**

Following the dimensionality reduction with UMAP, we will plot the embeddings, categorizing them by specific features (e.g., "strand" and "label") to observe how these characteristics distribute across the 2D space. This visualization can uncover patterns, similarities, and differences within the data, providing insights that are not readily apparent in the high-dimensional space.

**Fine-Tuning UMAP Parameters:**

Fine-tuning UMAP's parameters is crucial for achieving meaningful visualizations. Here's how to approach it:

- **Exploring `n_neighbors`**: Start with a value around 10-50 and adjust based on your dataset's size and complexity. Smaller datasets or those with intricate structures may require adjusting this parameter to better capture the data's nuances.
- **Adjusting `min_dist`**: Experiment with values between 0 and 0.99. A smaller `min_dist` allows UMAP to create more focused clusters, ideal for identifying subtle groupings or patterns in the data.
- **Setting `random_state`**: Use a fixed value if you need consistent output across multiple runs, essential for comparative analysis or publication.


In [None]:
umap_random_state = 42
n_neighbors = 20
min_dist = 0.4

eskape_embd = torchdb.merge(eskape, how='left', left_on='segment_id', right_on='segment_id')

reducer = umap.UMAP(random_state=umap_random_state, n_neighbors=n_neighbors, min_dist=min_dist)
print('Running UMAP ....')
umap_embeddings = reducer.fit_transform(representations)
eskape_embd['umap_1']=umap_embeddings[:, 0]
eskape_embd['umap_2']=umap_embeddings[:, 1]

g = sns.FacetGrid(eskape_embd, col="strand", hue="label", palette="Set1", height=6)
# Apply a scatterplot to each subplot
g.map(sns.scatterplot, "umap_1", "umap_2")
# Add a legend
g.add_legend()

# Display the plot
plt.show()
