# Embed input DNA sequences

In this notebook we extract the features (embeddings) of the input DNA sequences using the pretrained enformer.

## Setup

In [1]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import sonnet as snt
import sys
import os
import time
from tqdm import tqdm

# Make sure the GPU is enabled
print(tf.test.gpu_device_name())

# path to the TF enformer model
sys.path.append("../enformer/")

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.python.training import checkpoint_utils as cp
from enformer import Enformer # not a package, but a module
from utils import *

/device:GPU:0


## Loading the pretrained weights

Since the saved TF Enformer model does not have a function to extract features (embeddings), we need to initiate our Enformer object defined in the file `enformer/enformer.py` (where we added an extra function to extract the embeddings) and load the weights from the pretrained model in the folder `weights/`. Unfortunately, this is not so straight forward, because the weights are named differently. This next code chunk takes care of this.

In [2]:
SEQUENCE_LENGTH = 393216

# one simple input
input1 = pd.read_csv("../data/dna_example_393_216.txt", header=None).loc[0,0]

# loading the pretrained model and initializing it with a prediction (probably not the best practice, but it works)
enformer = Enformer()
_ = enformer.predict_on_batch(tf.stop_gradient(one_hot_encode(input1)[np.newaxis]))

# Updating the randomly initialized variables with the pretrained
saved_variables = cp.list_variables("weights/variables/variables")

saved_names = [i[0] for i in saved_variables][1:-1]
enformer_names = [i.name for i in enformer.variables]

df = dict(originals = saved_names,
          transformed = [rename_stored_variable(i) for i in saved_names]) # rename_stored_variable is a function from utils.py
df = pd.DataFrame(df)
df["check"] = [str(i) in enformer_names for i in df.transformed]
df["values"] = [cp.load_variable("weights/variables/variables", i) for i in df.originals]

print("Correctly renamed " + str(sum(df.check)) + " variables out of " + str(len(df.check)))
print("Duplicate names", len(df["transformed"]) - len(df["transformed"].unique() ))

# Assumption: the duplicate variables maintain the same order in both lists
for i in range(11):
    xx = f'enformer/trunk/transformer/transformer/transformer_block_{i}/transformer_block_{i}/mlp/mlp/project_out/'
    ids = [xx in j for j in df["transformed"]]
    tmp = df["transformed"][ids].to_list()
    xx2 = f'enformer/trunk/transformer/transformer/transformer_block_{i}/transformer_block_{i}/mlp/mlp/project_in/'
    tmp[:2] = [j.replace(xx, xx2) for j in tmp[:2]]
    df["transformed"][ids] = tmp   
    
df2 = df.copy()
for v in enformer.variables:
    a = int(df2["transformed"][df2["transformed"] == v.name].index.values[0])
    v.assign(df2["values"][a])
    df2.loc[a,:] = None
print("All variables are updated with the correct values. Duplicates are considered!")

Correctly renamed 327 variables out of 329
Duplicate names 22


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["transformed"][ids] = tmp


All variables are updated with the correct values. Duplicates are considered!


## Extracting and storing the features

In [5]:
#TODO!
seq_dir = "/data/DNA_seqs/"
seq_files= os.listdir(seq_dir)
feat_dir = "/data/DNA_feats"

with tqdm(total = len(seq_files), desc = "Extracting features(embeddings) ...", unit = 'seq') as prog_bar:
    for i, seq_name in enumerate(seq_files):
        
        #TODO!
        #Read sequence
        seq = pd.read_csv(seq_dir+seq_name)

        name = seq_dir + seq_name
        feat_file = seq_name.split(".")[0] + ".npy"
        
        feat = enformer.extract_features(tf.stop_gradient(one_hot_encode(seq)[np.newaxis]))
        feat = feat.numpy()

        np.save(feat_dir + feat_file, feat)
        prog_bar.update(1)

Extracting features(embeddings) ...: 100%|████████████████████████████████████████| 6149/6149 [44:40<00:00,  2.29seq/s]


***

# Checking if the model works correctly...

To be sure that the features are correct and we did not messed up we have to checked 2 things:
- If the "adapted" model still works well on the original task
- If the "adapted" model gets the same (or quite similar) output as the original


## The original model

In [None]:
SEQUENCE_LENGTH = 393216

class Enformer_Original:

  def __init__(self, tfhub_url):
    self._model = hub.load(tfhub_url).model

  def predict_on_batch(self, inputs):
    predictions = self._model.predict_on_batch(inputs)
    return {k: v.numpy() for k, v in predictions.items()}

  @tf.function
  def contribution_input_grad(self, input_sequence,
                              target_mask, output_head='human'):
    input_sequence = input_sequence[tf.newaxis]

    target_mask_mass = tf.reduce_sum(target_mask)
    with tf.GradientTape() as tape:
      tape.watch(input_sequence)
      prediction = tf.reduce_sum(
          target_mask[tf.newaxis] *
          self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass

    input_grad = tape.gradient(prediction, input_sequence) * input_sequence
    input_grad = tf.squeeze(input_grad, axis=0)
    return tf.reduce_sum(input_grad, axis=-1)

### Performance on the original task

In [59]:
def evaluate_model(model, dataset, head, max_steps=None):
  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
  @tf.function
  def predict(x):
    return model(x, is_training=False)

  for i, batch in tqdm(enumerate(dataset)):
    if max_steps is not None and i > max_steps:
      break
    metric.update_state(batch['target'], predict(batch['sequence']))

  return metric.result()

In [60]:
# need to download the data first...
seq_id = {i.split("_")[1] for i in os.listdir("data/test_data/target/")}
tar_id = {i.split("_")[1] for i in os.listdir("data/test_data/sequence/")}

ds = []
for i in seq_id.intersection(tar_id):
    ds = ds + [{'target': tf.convert_to_tensor(np.load("data/test_data/target"+"tar_"+ i)),
                'sequence': tf.convert_to_tensor(np.load("data/test_data/sequence/"+"seq_"+ i))
               }]
               
metrics_mouse = evaluate_model(enformer,
                               dataset= ds,
                               head='mouse',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_mouse.items()})

101it [01:57,  1.17s/it]



{'PearsonR': 0.7012999}


### Comparison between models

In [70]:
import tensorflow_hub as hub
#TODO: select an input sequence
# seq = pd.read_csv("../data/").loc[0,0]
x = tf.stop_gradient(one_hot_encode(seq)[np.newaxis])
enformer_hub = Enformer_Original('https://tfhub.dev/deepmind/enformer/1')

with tf.device('/cpu:0'):
    my_pred = enformer.predict_on_batch(x)
    pred_hub = enformer_hub.predict_on_batch(x)


In [75]:
# my_pred 
np.mean((pred_hub["mouse"] - my_pred )**2)

0.14715277

In [80]:
np.corrcoef(tf.reshape(my_pred, (896*1643)), pred_hub["mouse"].reshape((896*1643)))


array([[1.        , 0.99503214],
       [0.99503214, 1.        ]])

In [74]:
my_pred 

<tf.Tensor: shape=(1, 896, 1643), dtype=float32, numpy=
array([[[0.4128465 , 0.608387  , 1.6581078 , ..., 0.7503728 ,
         1.8233689 , 1.6437178 ],
        [0.8164966 , 1.2658758 , 2.6417766 , ..., 2.0881827 ,
         7.3151054 , 4.889168  ],
        [1.7602785 , 2.9756105 , 4.58759   , ..., 1.7699678 ,
         3.6342628 , 3.8983936 ],
        ...,
        [0.05717006, 0.10506745, 0.08914189, ..., 0.308557  ,
         0.4925663 , 0.517363  ],
        [0.06500754, 0.08711809, 0.08176833, ..., 0.4391661 ,
         1.1487088 , 0.7453046 ],
        [0.04310616, 0.06987476, 0.05819134, ..., 0.25947037,
         0.5015786 , 0.4432723 ]]], dtype=float32)>