# **Tutorial 2** Segment-NT: Inferring embeddings

SegmentNT models use a Nucleotide Transformer (NT) backbone whose language-modeling head is replaced by a 1-dimensional U-Net segmentation head.

In this tutorial, we will load a pretrained SegmentNT model from the GitHub repo and use it to predict, for every nucleotide of a DNA sequence, the probability that it belongs to each of 14 genomic features.

![](https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/main/imgs/segment_nt_panel1_screen.png)

# **Installing dependencies**

In [None]:
!git clone https://github.com/instadeepai/nucleotide-transformer.git

In [None]:
!pip install nucleotide-transformer/.
!pip install biopython
!pip install matplotlib

# Clear the verbose installation output from the cell
from IPython.display import clear_output
clear_output()

# **Import libraries**

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model
from nucleotide_transformer.pretrained import get_pretrained_model

In [None]:
from Bio import SeqIO
import gzip
import numpy as np
import seaborn as sns
from typing import List
import matplotlib.pyplot as plt

In [None]:
# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# **Defining probability plots**

In [None]:
# seaborn settings
sns.set_style("whitegrid")
sns.set_context(
    "notebook",
    font_scale=1,
    rc={
        "font.size": 14,
        "axes.titlesize": 18,
        "axes.labelsize": 18,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
        }
)

plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True

# set colors
colors = sns.color_palette("Set2").as_hex()
colors2 = sns.color_palette("husl").as_hex()


# Rearrange order of the features to match Fig.3 from the paper
features_rearranged = [
 'protein_coding_gene',
 'lncRNA',
 '5UTR',
 '3UTR',
 'exon',
 'intron',
 'splice_donor',
 'splice_acceptor',
 'promoter_Tissue_specific',
 'promoter_Tissue_invariant',
 'enhancer_Tissue_specific',
 'enhancer_Tissue_invariant',
 'CTCF-bound',
 'polyA_signal',
]

def plot_features(
    predicted_probabilities_all,
    seq_length: int,
    features: List[str],
    order_to_plot: List[str],
    fig_width=8,
):
    """
    Function to plot labels and predicted probabilities.

    Args:
        predicted_probabilities_all: Probabilities per genomic feature for each
            nucleotide in the DNA sequence.
        seq_length: DNA sequence length.
        features: Names of the genomic features, in the order of the model's output.
        order_to_plot: Order in which to plot the genomic features. This needs to be
            specified in order to match the order presented in Fig. 3 of the paper.
        fig_width: Width of the figure
    """

    sc = 1.8
    n_panels = 7

    # fig, axes = plt.subplots(n_panels, 1, figsize=(fig_width * sc, (n_panels + 2) * sc), height_ratios=[6] + [2] * (n_panels-1))
    _, axes = plt.subplots(n_panels, 1, figsize=(fig_width * sc, (n_panels + 4) * sc))

    for n, feat in enumerate(order_to_plot):
        feat_id = features.index(feat)
        prob_dist = predicted_probabilities_all[:, feat_id]

        # Use the appropriate subplot
        ax = axes[n // 2]

        # `colors` (Set2) is shorter than the feature list; fall back to `colors2` (husl)
        try:
            id_color = colors[feat_id]
        except IndexError:
            id_color = colors2[feat_id - 8]
        ax.plot(
            prob_dist,
            color=id_color,
            label=feat,
            linestyle="-",
            linewidth=1.5,
        )
        ax.set_xlim(0, seq_length)
        ax.grid(False)
        ax.spines['bottom'].set_color('black')
        ax.spines['top'].set_color('black')
        ax.spines['right'].set_color('black')
        ax.spines['left'].set_color('black')

    for a in range(n_panels):
        axes[a].set_ylim(0, 1.05)
        axes[a].set_ylabel("Prob.")
        axes[a].legend(loc="upper left", bbox_to_anchor=(1, 1), borderaxespad=0)
        if a != (n_panels-1):
            axes[a].tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

    # Set common x-axis label
    axes[-1].set_xlabel("Nucleotides")
    # axes[0].axis('off')  # Turn off the axis
    axes[n_panels-1].grid(False)
    axes[n_panels-1].tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelright=False)

    axes[0].set_title("Probabilities predicted over all genomics features", fontweight="bold")

    plt.show()

# **Use-case**: Studying a chromosome!

Feel free to use your favourite chromosome :)

Just change the chromosome name in the download URL below, i.e. the value at ...chromosome.<>.fa.gz

In [None]:
!wget https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz

In [None]:
# change name here appropriately depending on chromosome
fasta_path = "Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    # rename this variable too if you use another chromosome (it is reused below)
    chr20 = str(record.seq)
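Ensembl `dna` FASTA files represent unsequenced stretches (for example near chromosome ends) as runs of `N`, so it can be worth checking a candidate window before feeding it to the model. The sketch below is a minimal illustration on a toy string; `n_fraction` is a hypothetical helper, not part of Biopython or the nucleotide-transformer package.

```python
def n_fraction(sequence: str, start: int, length: int) -> float:
    """Hypothetical helper: fraction of 'N' bases in sequence[start:start + length]."""
    window = sequence[start:start + length].upper()
    if not window:
        raise ValueError("The requested window is empty")
    return window.count("N") / len(window)


# Toy "chromosome": 12 unknown bases followed by 24 known ones.
toy_chromosome = "N" * 12 + "ACGT" * 6

masked_fraction = n_fraction(toy_chromosome, start=0, length=12)
clean_fraction = n_fraction(toy_chromosome, start=12, length=24)
print(masked_fraction, clean_fraction)
```

With the real data, the same call on `chr20` with your chosen start coordinate and window length would tell you whether the window is free of `N`.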

## Inferring a 10 kb genomic sequence

### Instantiate SegmentNT inference function

The following code cell enables you to download the weights of a Segment-NT model. It provides access to the weights dictionary, the haiku forward function, the tokenizer, and the configuration dictionary.

You have the flexibility to specify:
- The layers from which you wish to extract embeddings (e.g., (5, 10, 20) to retrieve embeddings at layers 5, 10, and 20).
- The attention maps you desire (e.g., ((1, 4), (7, 18)) for attention maps corresponding to layer 1, head 4 and layer 7, head 18). Please refer to the model's configuration for the specific number of layers and heads.
- The maximum sequence length for inference. It's advisable to keep this number minimal for optimized memory usage and faster inference times, up to the limit specified in the model's configuration (including the automatically added class token at the sequence start).
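The sequence-length limit is easiest to reason about in tokens rather than base pairs. The sketch below assumes one token per non-overlapping 6-mer (which only holds for sequences without `N`) and one extra position for the class token. It also assumes that the number of DNA tokens should be divisible by 4, which, as far as we understand, follows from the two downsampling steps of the U-Net head. `token_budget` is a hypothetical helper used only for this illustration.

```python
def token_budget(seq_len_bp: int, k: int = 6, divisor: int = 4) -> dict:
    """Hypothetical helper: token counts for an N-free sequence of seq_len_bp bases."""
    if seq_len_bp % k != 0:
        raise ValueError(f"{seq_len_bp} bp is not a whole number of {k}-mers")
    num_dna_tokens = seq_len_bp // k
    return {
        "num_dna_tokens": num_dna_tokens,
        # One extra position for the class token added at the sequence start.
        "max_positions": num_dna_tokens + 1,
        # Assumed constraint of the segmentation head (two downsampling steps).
        "divisible": num_dna_tokens % divisor == 0,
    }


budget = token_budget(10_008)
print(budget)
```

For the 10,008 bp window used in this tutorial this gives 1668 DNA tokens and 1669 positions, matching `max_num_dna_tokens` and `max_positions=max_num_dna_tokens + 1` in the next cell.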

### **Tokenization**

In [None]:
# set maximum token number
max_num_dna_tokens = 1668

# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_dna_tokens + 1,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=True
)

forward_fn = hk.transform(forward_fn)

In [None]:
# Get data and tokenize it
# Set start co-ordinate here
# For chromosome 20, it's 2650520, find yours :)

idx_start = 2650520
# each token covers 6 nucleotides, so the window spans max_num_dna_tokens * 6 bases
idx_stop = idx_start + max_num_dna_tokens*6

sequences = [chr20[idx_start:idx_stop]]

# tokenize once; each batch entry is a (token strings, token ids) pair
tokenized = tokenizer.batch_tokenize(sequences)
tokens_str = [b[0] for b in tokenized]
tokens_ids = [b[1] for b in tokenized]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
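To build some intuition for why the window is `max_num_dna_tokens*6` bases long, here is a toy 6-mer tokenizer. It is a simplified sketch and not the library's implementation: it greedily emits a 6-mer when the next six bases contain no `N`, and otherwise falls back to a single-base token. `toy_kmer_tokenize` is a hypothetical name.

```python
from typing import List


def toy_kmer_tokenize(sequence: str, k: int = 6) -> List[str]:
    """Toy k-mer tokenizer (illustrative only, not the library's tokenizer)."""
    tokens: List[str] = []
    i = 0
    while i < len(sequence):
        chunk = sequence[i:i + k]
        if len(chunk) == k and "N" not in chunk:
            # A full, N-free k-mer becomes one token.
            tokens.append(chunk)
            i += k
        else:
            # Otherwise emit a single base and move on by one position.
            tokens.append(sequence[i])
            i += 1
    return tokens


clean_tokens = toy_kmer_tokenize("ACGTACGTACGT")  # 12 bases, no N
noisy_tokens = toy_kmer_tokenize("ACGTACNCGTAC")  # 12 bases, one N
print(len(clean_tokens), clean_tokens)
print(len(noisy_tokens), noisy_tokens)
```

In this toy version, a single `N` turns 2 tokens into 7 for the same 12 bases. The takeaway, under these assumptions, is that a window containing `N` can need more tokens than its length divided by 6.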

### **Infer** from resulting batch

In [None]:
# Initialize random key
random_key = jax.random.PRNGKey(0)

In [None]:
# Infer
outs = forward_fn.apply(parameters, random_key, tokens)

In [None]:
# Obtain the logits over the genomic features
logits = outs["logits"]

In [None]:
# Transform them into probabilities: softmax over the last axis, then keep its last entry
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
print(f"Probabilities shape: {probabilities.shape}")
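The slicing in the cell above suggests that the logits have a final axis of size 2 for each feature, and that the last entry is the "feature present" class. Assuming a layout of `(batch, nucleotides, features, 2)`, the following NumPy sketch reproduces the same computation on toy logits. NumPy is used here only to keep the example light; `feature_probabilities` is a hypothetical helper.

```python
import numpy as np


def feature_probabilities(logits: np.ndarray) -> np.ndarray:
    """Softmax over the last axis, keeping only its last entry."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    return probs[..., -1]


# Toy logits: batch of 1, 3 nucleotides, 2 features, 2 classes (assumed layout).
toy_logits = np.zeros((1, 3, 2, 2))
toy_logits[0, 0, 0] = [0.0, 2.0]  # feature 0 favoured at nucleotide 0

toy_probs = feature_probabilities(toy_logits)
print(toy_probs.shape)  # (1, 3, 2)
print(toy_probs[0, :, 0])
```

Because each feature gets its own two-class softmax, the probabilities of different features at one nucleotide are independent and need not sum to 1.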

In [None]:
print(f"Features inferred: {config.features}")

In [None]:
plot_features(
    probabilities[0],
    probabilities.shape[-2],
    fig_width=20,
    features=config.features,
    order_to_plot=features_rearranged
)
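Beyond plotting, one simple way to read these per-nucleotide probabilities is to threshold them into intervals. The sketch below is one possible post-processing step, not something the model or the paper prescribes: `call_segments` is a hypothetical helper and the 0.5 threshold is an arbitrary assumption.

```python
from typing import List, Sequence, Tuple


def call_segments(
    probs: Sequence[float], offset: int = 0, threshold: float = 0.5
) -> List[Tuple[int, int]]:
    """Return half-open [start, stop) intervals where probs exceeds threshold."""
    segments: List[Tuple[int, int]] = []
    start = None
    for i, p in enumerate(probs):
        if p > threshold and start is None:
            start = i  # a new segment opens here
        elif p <= threshold and start is not None:
            segments.append((offset + start, offset + i))
            start = None
    if start is not None:
        # The last segment runs to the end of the window.
        segments.append((offset + start, offset + len(probs)))
    return segments


toy_probs_1d = [0.1, 0.9, 0.8, 0.2, 0.7]
toy_segments = call_segments(toy_probs_1d, offset=100)
print(toy_segments)  # [(101, 103), (104, 105)]
```

With the notebook's variables, something like `call_segments(np.asarray(probabilities[0, :, config.features.index("exon")]), offset=idx_start)` would give candidate exon intervals. These are 0-based Python indices into `chr20`, so add 1 to the start when comparing with 1-based genome browser coordinates.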

🎉 **Woohoo, you did it!** 🎉

You’ve just aced the tutorial on inferring genomic-feature probabilities with SegmentNT on a query sequence.

Now, it’s time to unleash your creativity! Dive into experimenting with different tokenization methods, explore various embeddings, and have fun with more fascinating genomic regions.

The genomic playground is all yours! 🚀🔬🌟