In [1]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ProteInfer Class Activation Mapping (CAM)




## Initial setup (code/data download)

In [2]:
!pip install seaborn

Collecting seaborn
  Downloading seaborn-0.11.2-py3-none-any.whl (292 kB)
     |████████████████████████████████| 292 kB 8.5 MB/s            
Installing collected packages: seaborn
Successfully installed seaborn-0.11.2


In [3]:

import pandas as pd
import tensorflow
from proteinfer import inference, parenthood_lib, baseline_utils, utils, colab_evaluation

import subprocess
import shlex
import tqdm 
import sklearn
import numpy as np
import plotly.express as px
import seaborn as sns

from plotnine import ggplot, geom_point, geom_point, geom_line, aes, stat_smooth, facet_wrap, xlim,coord_cartesian,theme_bw,labs,ggsave


In [4]:
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/models/zipped_models/noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140.tar.gz
!tar xzf noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140.tar.gz
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/colab_support/parenthood.json.gz
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/blast_baseline/fasta_files/SWISSPROT_RANDOM_EC/eval_test.fasta



In [5]:
def get_ec_num_mapping():
  tree = ET.parse('enzyme-data.xml')
  root = tree.getroot()
  rows = root[0][3].findall('row')
  rows = root.findall(".//field[@name='accepted_name']..")
  ec_nums = {}
  for row in rows:
      ec_num = row.find(".//*[@name='ec_num']").text
      name = row.find(".//*[@name='accepted_name']").text
      try:
        ec_nums[ec_num]=name
      except TypeError:
        continue
  return ec_nums

def download_dataset():
  total = 13
  file_shard_names = ['https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/random/test-{:05d}-of-{:05d}.tfrecord'.format(i,total) for i in range(total)]

  for shard_name in tqdm.tqdm(file_shard_names, position=0,desc="Downloading"):
    subprocess.check_output(shlex.split(f'wget {shard_name}'))
  return 

In [6]:

import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 100
!wget -qN https://www.enzyme-database.org/downloads/enzyme-data.xml.gz
!gunzip -f enzyme-data.xml.gz


import xml.etree.ElementTree as ET


ec_nums = get_ec_num_mapping()
download_dataset()

Downloading: 100%|██████████| 13/13 [00:07<00:00,  1.74it/s]


##Read in the test dataset

In [12]:
from proteinfer import protein_dataset

In [14]:
protein_dataset

<module 'proteinfer.protein_dataset' from '/app/proteinfer/protein_dataset.py'>

In [16]:
sequence_iterator = protein_dataset.yield_examples("./test*.tfrecord")
sequences = []
labels = []
ids = []
for example in tqdm.tqdm(sequence_iterator):
  ids.append(example[protein_dataset.SEQUENCE_ID_KEY])
  sequences.append(example[protein_dataset.SEQUENCE_KEY])
  labels.append(example[protein_dataset.LABEL_KEY])

# If we want to optimise for inference speed we should sort the dataset by
# sequence length:
seq_lengths = [len(x) for x in sequences]
indices = np.argsort(-np.array(seq_lengths)).tolist()

ids = [ids[indices[x]] for x in range(len(indices))]
sequences = [sequences[indices[x]] for x in range(len(indices))]
labels = [set(labels[indices[x]]) for x in range(len(indices))]

W0419 22:28:08.482895 140425461663552 deprecation.py:323] From /app/proteinfer/protein_dataset.py:293: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
54285it [00:26, 2061.59it/s]


## Load the saved model

In [18]:
inferrer = inference.Inferrer(
    'noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140',use_tqdm= True, batch_size=32,activation_type="representation"
)

label_vocab = list(inferrer.get_variable('label_vocab:0').astype(str))
label_normalizer = parenthood_lib.get_applicable_label_dict(
    'parenthood.json.gz')


kernel = inferrer.get_variable("logits/kernel/read:0")



W0419 22:30:36.525437 140425461663552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/ragged/ragged_tensor.py:1586: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [19]:
def get_multi_full_ec(labels, desired_number=3):
  subset = [x for x in labels if x.startswith(b"EC:")]
  subset = [x for x in subset if b'-' not in x]
  if len(subset)==desired_number:
    return subset
  else:
    return []

In [20]:
def moving_average(a, n=3) :
    new = np.zeros_like(a)
    length_dim = a.shape[0]
    for i in range(length_dim):
      new[i,:]=np.mean(a[np.maximum(i-n,0):np.minimum(i+n,length_dim),:],axis=0)
    return new


In [21]:
from matplotlib import colors as clr
from matplotlib import pyplot as plt
palette = clr.LinearSegmentedColormap.from_list('custom blue', ['#FFFFFF','#EEEEEE','#00EE00'], N=256)

In [22]:
items_that_satisfy_criteria = []
for i in range(len(sequences)):
  new_lab = get_multi_full_ec(labels[i],desired_number=2)
  new_seq = sequences[i]
  new_id = ids[i]
  if len(new_lab)>0:
    items_that_satisfy_criteria.append({'labels':new_lab, 'sequence':new_seq, 'id':new_id})


## Perform inference

In [25]:
counter = 0

sns.set_style("whitegrid")
sns.set(rc={'figure.figsize': (15, 35)})

one_by_one = False
ids = ['Q4LB35', 'Q54QE4', 'P54889', 'Q9PLG1', 'O94632', 'P19835', 'Q3MEJ8']
if True:
    sns.set(rc={'figure.figsize': (15, 2)})

for item in items_that_satisfy_criteria:
    if item['id'].decode() in ids or one_by_one:

        the_labels = item['labels']
        representation = inferrer.get_activations(
            [item['sequence'].decode("utf-8")])
        label_ids = [label_vocab.index(x.decode('utf-8')) for x in the_labels]
        print(f"""
representation: {representation.shape}
kernel: {kernel.shape}
label_ids: {len(label_ids)}
               """)
        contributions = np.matmul(representation.squeeze(), kernel)[:, label_ids]

        sum_contributions = contributions.sum(axis=1)
        contributions = np.maximum(contributions, 0)
        contributions = moving_average(contributions, 80)
        contributions = contributions / contributions.max(axis=0,
                                                          keepdims=True)

        df = pd.DataFrame(contributions.T)
        try:
            df.index = [
                ec_nums[x.decode().replace("EC:", "")].replace(
                    "<em>",
                    "").replace("</em>",
                                "").replace("<small>",
                                            "").replace("</small>", "")
                for x in the_labels
            ]
        except KeyError:
            continue
        print(item['id'].decode())
        ax = None
        if False:
            ax = axes[counter]
        try:
            g = sns.heatmap(df, cmap=palette, xticklabels=500, ax=ax, vmin=0)
            counter += 1
            g.set_title(item['id'].decode())
            for _, spine in g.spines.items():
                spine.set_visible(True)
            plt.yticks(rotation=0)
            plt.xticks(rotation=0)
            plt.subplots_adjust(hspace=0.7)
        except ValueError:
            continue

        if True:
            plt.show()

plt.show()


Annotating batches of sequences:   0%|          | 0/1 [00:00<?, ?it/s]

Annotating batches of sequences: 100%|██████████| 1/1 [00:00<00:00,  2.93it/s]


representation: (1,)
kernel: (1100, 5134)
label_ids: 2
               





ValueError: matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)